diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..8693ea7bdbee695fd0c934b72ea9485f5ffe7147
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,32 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..8ca5096f34244d1078a811f3c6d21d19eb9d3a44
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,151 @@
+### Project ignore
+
+/ParallelWaveGAN
+/wavegan_pretrained*
+/pretrained_models
+rsync
+.idea
+.DS_Store
+bak
+tmp
+*.tar.gz
+# mfa and kaldi
+kaldi_align/exp
+mfa
+montreal-forced-aligner
+mos
+nbs
+/configs_usr/*
+!/configs_usr/.gitkeep
+/fast_transformers
+/rnnoise
+/usr/*
+!/usr/.gitkeep
+
+# Created by .ignore support plugin (hsz.mobi)
+### Python template
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..563c90c56cc4bea1f10c6307123319d92441b556
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 Jinglin Liu
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7db9b17d3c23482bd97236d4cd83de574d8f1561
--- /dev/null
+++ b/README.md
@@ -0,0 +1,10 @@
+---
+title: ProDiff
+emoji: 🤗
+colorFrom: yellow
+colorTo: orange
+sdk: gradio
+app_file: "inference/gradio/infer.py"
+pinned: false
+---
+
diff --git a/checkpoints/FastDiff/config.yaml b/checkpoints/FastDiff/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6cf5325101f782ab41eff945a16ab23a32f73d65
--- /dev/null
+++ b/checkpoints/FastDiff/config.yaml
@@ -0,0 +1,149 @@
+N: ''
+T: 1000
+accumulate_grad_batches: 1
+amp: false
+audio_channels: 1
+audio_num_mel_bins: 80
+audio_sample_rate: 22050
+aux_context_window: 0
+beta_0: 1.0e-06
+beta_T: 0.01
+binarization_args:
+  reset_phone_dict: true
+  reset_word_dict: true
+  shuffle: false
+  trim_eos_bos: false
+  with_align: false
+  with_f0: false
+  with_f0cwt: false
+  with_linear: false
+  with_spk_embed: false
+  with_spk_id: true
+  with_txt: false
+  with_wav: true
+  with_word: false
+binarizer_cls: data_gen.tts.vocoder_binarizer.VocoderBinarizer
+binary_data_dir: data/binary/LJSpeech
+check_val_every_n_epoch: 10
+clip_grad_norm: 1
+clip_grad_value: 0
+cond_channels: 80
+debug: false
+dec_ffn_kernel_size: 9
+dec_layers: 4
+dict_dir: ''
+diffusion_step_embed_dim_in: 128
+diffusion_step_embed_dim_mid: 512
+diffusion_step_embed_dim_out: 512
+disc_start_steps: 40000
+discriminator_grad_norm: 1
+dropout: 0.0
+ds_workers: 1
+enc_ffn_kernel_size: 9
+enc_layers: 4
+endless_ds: true
+eval_max_batches: -1
+ffn_act: gelu
+ffn_padding: SAME
+fft_size: 1024
+fmax: 7600
+fmin: 80
+frames_multiple: 1
+gen_dir_name: ''
+generator_grad_norm: 10
+griffin_lim_iters: 60
+hidden_size: 256
+hop_size: 256
+infer: false
+inner_channels: 32
+kpnet_conv_size: 3
+kpnet_hidden_channels: 64
+load_ckpt: ''
+loud_norm: false
+lr: 2e-4
+lvc_kernel_size: 3
+lvc_layers_each_block: 4
+max_epochs: 1000
+max_frames: 1548
+max_input_tokens: 1550
+max_samples: 25600
+max_sentences: 20
+max_tokens: 30000
+max_updates: 1000000
+max_valid_sentences: 1
+max_valid_tokens: 60000
+mel_loss: l1
+mel_vmax: 1.5
+mel_vmin: -6
+mfa_version: 2
+min_frames: 0
+min_level_db: -100
+noise_schedule: ''
+num_ckpt_keep: 3
+num_heads: 2
+num_mels: 80
+num_sanity_val_steps: -1
+num_spk: 400
+num_test_samples: 0
+num_valid_plots: 10
+optimizer_adam_beta1: 0.9
+optimizer_adam_beta2: 0.98
+out_wav_norm: false
+pitch_extractor: parselmouth
+pre_align_args:
+  allow_no_txt: false
+  denoise: false
+  nsample_per_mfa_group: 1000
+  sox_resample: false
+  sox_to_wav: false
+  trim_sil: false
+  txt_processor: en
+  use_tone: true
+pre_align_cls: egs.datasets.audio.pre_align.PreAlign
+print_nan_grads: false
+processed_data_dir: data/processed/LJSpeech
+profile_infer: false
+raw_data_dir: data/raw/LJSpeech-1.1
+ref_level_db: 20
+rename_tmux: true
+resume_from_checkpoint: 0
+save_best: true
+save_codes: []
+save_f0: false
+save_gt: true
+scheduler: rsqrt
+seed: 1234
+sort_by_len: true
+task_cls: modules.FastDiff.task.FastDiff.FastDiffTask
+tb_log_interval: 100
+test_ids: []
+test_input_dir: ''
+test_mel_dir: ''
+test_num: 100
+test_set_name: test
+train_set_name: train
+train_sets: ''
+upsample_ratios:
+- 8
+- 8
+- 4
+use_pitch_embed: false
+use_spk_embed: false
+use_spk_id: false
+use_split_spk_id: false
+use_wav: true
+use_weight_norm: true
+use_word_input: false
+val_check_interval: 2000
+valid_infer_interval: 10000
+valid_monitor_key: val_loss
+valid_monitor_mode: min
+valid_set_name: valid
+vocoder_denoise_c: 0.0
+warmup_updates: 8000
+weight_decay: 0
+win_length: null
+win_size: 1024
+window: hann
+word_size: 30000
+work_dir: checkpoints/FastDiff
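Note: configs like the one above are loaded through `utils.hparams.set_hparams` / `hparams`, and dotted class paths such as `task_cls` and `binarizer_cls` are resolved with `importlib` (see `data_gen/tts/bin/binarize.py` below). A minimal sketch of that resolution step, assuming only PyYAML is installed and bypassing the repo's own loader:

```python
import importlib
import yaml

# Illustration only: the repo's real entry point is utils.hparams.set_hparams.
with open("checkpoints/FastDiff/config.yaml") as f:
    config = yaml.safe_load(f)

# Resolve a dotted class path, e.g. 'modules.FastDiff.task.FastDiff.FastDiffTask'.
pkg, cls_name = config["task_cls"].rsplit(".", 1)
task_cls = getattr(importlib.import_module(pkg), cls_name)
print(task_cls)
```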
diff --git a/checkpoints/FastDiff/model_ckpt_steps_500000.ckpt b/checkpoints/FastDiff/model_ckpt_steps_500000.ckpt
new file mode 100644
index 0000000000000000000000000000000000000000..226e9f776d4950cf1711a08420a51c0c1aa0c526
--- /dev/null
+++ b/checkpoints/FastDiff/model_ckpt_steps_500000.ckpt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ee7b6022e525c71a6025b41eeeafff9d6186b52cba76b580d6986bc8674902f3
+size 183951271
diff --git a/checkpoints/ProDiff/config.yaml b/checkpoints/ProDiff/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..27aa0559523fb9aa3fc0c92ef77a07ab9eece19f
--- /dev/null
+++ b/checkpoints/ProDiff/config.yaml
@@ -0,0 +1,205 @@
+accumulate_grad_batches: 1
+amp: false
+audio_num_mel_bins: 80
+audio_sample_rate: 22050
+base_config:
+- ./base.yaml
+binarization_args:
+  reset_phone_dict: true
+  reset_word_dict: true
+  shuffle: false
+  trim_eos_bos: false
+  trim_sil: false
+  with_align: true
+  with_f0: true
+  with_f0cwt: false
+  with_linear: false
+  with_spk_embed: false
+  with_spk_id: true
+  with_txt: true
+  with_wav: false
+  with_word: true
+binarizer_cls: data_gen.tts.base_binarizer.BaseBinarizer
+binary_data_dir: data/binary/LJSpeech
+check_val_every_n_epoch: 10
+clip_grad_norm: 1
+clip_grad_value: 0
+conv_use_pos: false
+cwt_add_f0_loss: false
+cwt_hidden_size: 128
+cwt_layers: 2
+cwt_loss: l1
+cwt_std_scale: 0.8
+debug: false
+dec_dilations:
+- 1
+- 1
+- 1
+- 1
+dec_ffn_kernel_size: 9
+dec_inp_add_noise: false
+dec_kernel_size: 5
+dec_layers: 4
+dec_num_heads: 2
+decoder_rnn_dim: 0
+decoder_type: fft
+dict_dir: ''
+diff_decoder_type: wavenet
+diff_loss_type: l1
+dilation_cycle_length: 1
+dropout: 0.1
+ds_workers: 2
+dur_enc_hidden_stride_kernel:
+- 0,2,3
+- 0,2,3
+- 0,1,3
+dur_loss: mse
+dur_predictor_kernel: 3
+dur_predictor_layers: 2
+enc_dec_norm: ln
+enc_dilations:
+- 1
+- 1
+- 1
+- 1
+enc_ffn_kernel_size: 9
+enc_kernel_size: 5
+enc_layers: 4
+encoder_K: 8
+encoder_type: fft
+endless_ds: true
+ffn_act: gelu
+ffn_hidden_size: 1024
+ffn_padding: SAME
+fft_size: 1024
+fmax: 7600
+fmin: 80
+frames_multiple: 1
+gen_dir_name: ''
+gen_tgt_spk_id: -1
+griffin_lim_iters: 60
+hidden_size: 256
+hop_size: 256
+infer: false
+keep_bins: 80
+lambda_commit: 0.25
+lambda_energy: 0.1
+lambda_f0: 1.0
+lambda_ph_dur: 0.1
+lambda_sent_dur: 1.0
+lambda_uv: 1.0
+lambda_word_dur: 1.0
+layers_in_block: 2
+load_ckpt: ''
+loud_norm: false
+lr: 1.0
+max_beta: 0.06
+max_epochs: 1000
+max_frames: 1548
+max_input_tokens: 1550
+max_sentences: 48
+max_tokens: 32000
+max_updates: 200000
+max_valid_sentences: 1
+max_valid_tokens: 60000
+mel_loss: ssim:0.5|l1:0.5
+mel_vmax: 1.5
+mel_vmin: -6
+min_frames: 0
+min_level_db: -100
+num_ckpt_keep: 3
+num_heads: 2
+num_sanity_val_steps: -1
+num_spk: 1
+num_test_samples: 0
+num_valid_plots: 10
+optimizer_adam_beta1: 0.9
+optimizer_adam_beta2: 0.98
+out_wav_norm: false
+pitch_ar: false
+pitch_embed_type: 0
+pitch_enc_hidden_stride_kernel:
+- 0,2,5
+- 0,2,5
+- 0,2,5
+pitch_extractor: parselmouth
+pitch_loss: l1
+pitch_norm: standard
+pitch_ssim_win: 11
+pitch_type: frame
+pre_align_args:
+  allow_no_txt: false
+  denoise: false
+  sox_resample: false
+  sox_to_wav: false
+  trim_sil: false
+  txt_processor: en
+  use_tone: true
+pre_align_cls: ''
+predictor_dropout: 0.5
+predictor_grad: 0.1
+predictor_hidden: -1
+predictor_kernel: 5
+predictor_layers: 2
+pretrain_fs_ckpt: ''
+print_nan_grads: false
+processed_data_dir: data/processed/LJSpeech
+profile_infer: false
+raw_data_dir: data/raw/LJSpeech
+ref_hidden_stride_kernel:
+- 0,3,5
+- 0,3,5
+- 0,2,5
+- 0,2,5
+- 0,2,5
+ref_level_db: 20
+ref_norm_layer: bn
+rename_tmux: true
+residual_channels: 256
+residual_layers: 20
+resume_from_checkpoint: 0
+save_best: true
+save_codes: []
+save_f0: false
+save_gt: true
+schedule_type: vpsde
+scheduler: rsqrt
+seed: 1234
+sil_add_noise: false
+sort_by_len: true
+spec_max: []
+spec_min: []
+task_cls: modules.ProDiff.task.ProDiff_task.ProDiff_Task
+tb_log_interval: 100
+teacher_ckpt: checkpoints/ProDiff_Teacher/model_ckpt_steps_188000.ckpt
+test_ids: []
+test_input_dir: ''
+test_num: 100
+test_set_name: test
+timesteps: 4
+train_set_name: train
+train_sets: ''
+use_cond_disc: true
+use_energy_embed: true
+use_gt_dur: true
+use_gt_f0: true
+use_pitch_embed: true
+use_pos_embed: true
+use_ref_enc: false
+use_spk_embed: false
+use_spk_id: false
+use_split_spk_id: false
+use_uv: true
+use_var_enc: false
+val_check_interval: 2000
+valid_infer_interval: 10000
+valid_monitor_key: val_loss
+valid_monitor_mode: min
+valid_set_name: valid
+var_enc_vq_codes: 64
+vocoder_denoise_c: 0.0
+warmup_updates: 2000
+weight_decay: 0
+win_size: 1024
+word_size: 30000
+work_dir: checkpoints/ProDiff
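Note: `mel_loss: ssim:0.5|l1:0.5` encodes a weighted combination of losses. The actual parsing lives in the task code, which is not part of this diff; the sketch below only illustrates the apparent `name:weight|name:weight` format:

```python
def parse_loss_spec(spec: str) -> dict:
    """Turn a spec like 'ssim:0.5|l1:0.5' into {'ssim': 0.5, 'l1': 0.5} (sketch only)."""
    weights = {}
    for term in spec.split("|"):
        name, _, weight = term.partition(":")
        weights[name] = float(weight) if weight else 1.0
    return weights


print(parse_loss_spec("ssim:0.5|l1:0.5"))  # {'ssim': 0.5, 'l1': 0.5}
```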
diff --git a/checkpoints/ProDiff/model_ckpt_steps_200000.ckpt b/checkpoints/ProDiff/model_ckpt_steps_200000.ckpt
new file mode 100644
index 0000000000000000000000000000000000000000..2e8e845b8cd31da66a969c32db251f7bd55af470
--- /dev/null
+++ b/checkpoints/ProDiff/model_ckpt_steps_200000.ckpt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8cc8aad355c297b010e2c362341f736b3477744af76e02f6c9965409a7e9113a
+size 349055740
diff --git a/checkpoints/ProDiff_Teacher/config.yaml b/checkpoints/ProDiff_Teacher/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fd6983c1d8889954b0a1dfdf34c66f72a4407bbf
--- /dev/null
+++ b/checkpoints/ProDiff_Teacher/config.yaml
@@ -0,0 +1,205 @@
+accumulate_grad_batches: 1
+amp: false
+audio_num_mel_bins: 80
+audio_sample_rate: 22050
+base_config:
+- ./base.yaml
+binarization_args:
+  reset_phone_dict: true
+  reset_word_dict: true
+  shuffle: false
+  trim_eos_bos: false
+  trim_sil: false
+  with_align: true
+  with_f0: true
+  with_f0cwt: false
+  with_linear: false
+  with_spk_embed: false
+  with_spk_id: true
+  with_txt: true
+  with_wav: false
+  with_word: true
+binarizer_cls: data_gen.tts.base_binarizer.BaseBinarizer
+binary_data_dir: data/binary/LJSpeech
+check_val_every_n_epoch: 10
+clip_grad_norm: 1
+clip_grad_value: 0
+conv_use_pos: false
+cwt_add_f0_loss: false
+cwt_hidden_size: 128
+cwt_layers: 2
+cwt_loss: l1
+cwt_std_scale: 0.8
+debug: false
+dec_dilations:
+- 1
+- 1
+- 1
+- 1
+dec_ffn_kernel_size: 9
+dec_inp_add_noise: false
+dec_kernel_size: 5
+dec_layers: 4
+dec_num_heads: 2
+decoder_rnn_dim: 0
+decoder_type: fft
+dict_dir: ''
+diff_decoder_type: wavenet
+diff_loss_type: l1
+dilation_cycle_length: 1
+dropout: 0.1
+ds_workers: 2
+dur_enc_hidden_stride_kernel:
+- 0,2,3
+- 0,2,3
+- 0,1,3
+dur_loss: mse
+dur_predictor_kernel: 3
+dur_predictor_layers: 2
+enc_dec_norm: ln
+enc_dilations:
+- 1
+- 1
+- 1
+- 1
+enc_ffn_kernel_size: 9
+enc_kernel_size: 5
+enc_layers: 4
+encoder_K: 8
+encoder_type: fft
+endless_ds: true
+ffn_act: gelu
+ffn_hidden_size: 1024
+ffn_padding: SAME
+fft_size: 1024
+fmax: 7600
+fmin: 80
+frames_multiple: 1
+gen_dir_name: ''
+gen_tgt_spk_id: -1
+griffin_lim_iters: 60
+hidden_size: 256
+hop_size: 256
+infer: false
+keep_bins: 80
+lambda_commit: 0.25
+lambda_energy: 0.1
+lambda_f0: 1.0
+lambda_ph_dur: 0.1
+lambda_sent_dur: 1.0
+lambda_uv: 1.0
+lambda_word_dur: 1.0
+layers_in_block: 2
+load_ckpt: ''
+loud_norm: false
+lr: 1.0
+max_beta: 0.06
+max_epochs: 1000
+max_frames: 1548
+max_input_tokens: 1550
+max_sentences: 48
+max_tokens: 32000
+max_updates: 200000
+max_valid_sentences: 1
+max_valid_tokens: 60000
+mel_loss: ssim:0.5|l1:0.5
+mel_vmax: 1.5
+mel_vmin: -6
+min_frames: 0
+min_level_db: -100
+num_ckpt_keep: 3
+num_heads: 2
+num_sanity_val_steps: -1
+num_spk: 1
+num_test_samples: 20
+num_valid_plots: 10
+optimizer_adam_beta1: 0.9
+optimizer_adam_beta2: 0.98
+out_wav_norm: false
+pitch_ar: false
+pitch_embed_type: 0
+pitch_enc_hidden_stride_kernel:
+- 0,2,5
+- 0,2,5
+- 0,2,5
+pitch_extractor: parselmouth
+pitch_loss: l1
+pitch_norm: standard
+pitch_ssim_win: 11
+pitch_type: frame
+pre_align_args:
+  allow_no_txt: false
+  denoise: false
+  sox_resample: false
+  sox_to_wav: false
+  trim_sil: false
+  txt_processor: en
+  use_tone: true
+pre_align_cls: egs.datasets.audio.lj.pre_align.LJPreAlign
+predictor_dropout: 0.5
+predictor_grad: 0.1
+predictor_hidden: -1
+predictor_kernel: 5
+predictor_layers: 2
+pretrain_fs_ckpt: ''
+print_nan_grads: false
+processed_data_dir: data/processed/LJSpeech
+profile_infer: false
+raw_data_dir: data/raw/LJSpeech
+ref_hidden_stride_kernel:
+- 0,3,5
+- 0,3,5
+- 0,2,5
+- 0,2,5
+- 0,2,5
+ref_level_db: 20
+ref_norm_layer: bn
+rename_tmux: true
+residual_channels: 256
+residual_layers: 20
+resume_from_checkpoint: 0
+save_best: true
+save_codes: []
+save_f0: false
+save_gt: true
+schedule_type: vpsde
+scheduler: rsqrt
+seed: 1234
+sil_add_noise: false
+sort_by_len: true
+spec_max: []
+spec_min: []
+task_cls: modules.ProDiff.task.ProDiff_teacher_task.ProDiff_teacher_Task
+tb_log_interval: 100
+test_ids: []
+test_input_dir: ''
+test_num: 100
+test_set_name: test
+timescale: 1
+timesteps: 4
+train_set_name: train
+train_sets: ''
+use_cond_disc: true
+use_energy_embed: true
+use_gt_dur: true
+use_gt_f0: true
+use_pitch_embed: true
+use_pos_embed: true
+use_ref_enc: false
+use_spk_embed: false
+use_spk_id: false
+use_split_spk_id: false
+use_uv: true
+use_var_enc: false
+val_check_interval: 2000
+valid_infer_interval: 10000
+valid_monitor_key: val_loss
+valid_monitor_mode: min
+valid_set_name: valid
+var_enc_vq_codes: 64
+vocoder_denoise_c: 0.0
+warmup_updates: 2000
+weight_decay: 0
+win_size: 1024
+word_size: 30000
+work_dir: checkpoints/ProDiff_Teacher1
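Note: this teacher config mirrors `checkpoints/ProDiff/config.yaml` almost field for field; the student config mainly adds `teacher_ckpt` and swaps `task_cls`. A quick sketch (assuming PyYAML) for listing exactly which keys differ between the two checkpoints:

```python
import yaml

with open("checkpoints/ProDiff_Teacher/config.yaml") as f:
    teacher = yaml.safe_load(f)
with open("checkpoints/ProDiff/config.yaml") as f:
    student = yaml.safe_load(f)

for key in sorted(set(teacher) | set(student)):
    if teacher.get(key) != student.get(key):
        print(f"{key}: teacher={teacher.get(key)!r} student={student.get(key)!r}")
```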
diff --git a/checkpoints/ProDiff_Teacher/model_ckpt_steps_188000.ckpt b/checkpoints/ProDiff_Teacher/model_ckpt_steps_188000.ckpt
new file mode 100644
index 0000000000000000000000000000000000000000..a3ca44b2160a0ff0ade91e6e52a6c06b87f0d1f7
--- /dev/null
+++ b/checkpoints/ProDiff_Teacher/model_ckpt_steps_188000.ckpt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5d3d02a215431c69dd54c1413b9a02cdc32795e2039ad9be857b12e85c470eea
+size 342252871
diff --git a/data/binary/LJSpeech/phone_set.json b/data/binary/LJSpeech/phone_set.json
new file mode 100644
index 0000000000000000000000000000000000000000..1d097037b62382bae4893512cc291d8699c0b049
--- /dev/null
+++ b/data/binary/LJSpeech/phone_set.json
@@ -0,0 +1 @@
+["!", ",", ".", ":", ";", "<BOS>", "<EOS>", "?", "AA0", "AA1", "AA2", "AE0", "AE1", "AE2", "AH0", "AH1", "AH2", "AO0", "AO1", "AO2", "AW0", "AW1", "AW2", "AY0", "AY1", "AY2", "B", "CH", "D", "DH", "EH0", "EH1", "EH2", "ER0", "ER1", "ER2", "EY0", "EY1", "EY2", "F", "G", "HH", "IH0", "IH1", "IH2", "IY0", "IY1", "IY2", "JH", "K", "L", "M", "N", "NG", "OW0", "OW1", "OW2", "OY0", "OY1", "OY2", "P", "R", "S", "SH", "T", "TH", "UH0", "UH1", "UH2", "UW0", "UW1", "UW2", "V", "W", "Y", "Z", "ZH", "|"]
\ No newline at end of file
diff --git a/data/binary/LJSpeech/spk_map.json b/data/binary/LJSpeech/spk_map.json
new file mode 100644
index 0000000000000000000000000000000000000000..15bba8f120494a14ecd1308c6f534ec7e4322391
--- /dev/null
+++ b/data/binary/LJSpeech/spk_map.json
@@ -0,0 +1 @@
+{"SPK1": 0}
\ No newline at end of file
diff --git a/data/binary/LJSpeech/train_f0s_mean_std.npy b/data/binary/LJSpeech/train_f0s_mean_std.npy
new file mode 100644
index 0000000000000000000000000000000000000000..42b6fc952934d7e7aedc61c7975ef437c4981d08
--- /dev/null
+++ b/data/binary/LJSpeech/train_f0s_mean_std.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8790d5a84d77143690ae71a1f1e7fc81359e69ead263dc440366f2164c739efd
+size 144
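Note: `train_f0s_mean_std.npy` is written by `BaseBinarizer.process_data` below as a two-element array `[mean, std]` over voiced frames. A hedged usage sketch for standardizing an f0 contour with it (how the repo itself normalizes pitch is not shown in this diff):

```python
import numpy as np

f0_mean, f0_std = np.load("data/binary/LJSpeech/train_f0s_mean_std.npy")


def standardize_f0(f0: np.ndarray) -> np.ndarray:
    # Standardize voiced frames only (f0 > 0); unvoiced frames stay at 0.
    f0 = f0.copy()
    voiced = f0 > 0
    f0[voiced] = (f0[voiced] - f0_mean) / f0_std
    return f0
```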
diff --git a/data_gen/tts/base_binarizer.py b/data_gen/tts/base_binarizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b30a20c1cdc3403214ff527d68a50806befafeb9
--- /dev/null
+++ b/data_gen/tts/base_binarizer.py
@@ -0,0 +1,224 @@
+import os
+os.environ["OMP_NUM_THREADS"] = "1"
+
+from utils.multiprocess_utils import chunked_multiprocess_run
+import random
+import traceback
+import json
+from resemblyzer import VoiceEncoder
+from tqdm import tqdm
+from data_gen.tts.data_gen_utils import get_mel2ph, get_pitch, build_phone_encoder
+from utils.hparams import set_hparams, hparams
+import numpy as np
+from utils.indexed_datasets import IndexedDatasetBuilder
+from vocoders.base_vocoder import VOCODERS
+import pandas as pd
+
+
+class BinarizationError(Exception):
+    pass
+
+
+class BaseBinarizer:
+    def __init__(self, processed_data_dir=None):
+        if processed_data_dir is None:
+            processed_data_dir = hparams['processed_data_dir']
+        self.processed_data_dirs = processed_data_dir.split(",")
+        self.binarization_args = hparams['binarization_args']
+        self.pre_align_args = hparams['pre_align_args']
+        self.forced_align = self.pre_align_args['forced_align']
+        tg_dir = None
+        if self.forced_align == 'mfa':
+            tg_dir = 'mfa_outputs'
+        if self.forced_align == 'kaldi':
+            tg_dir = 'kaldi_outputs'
+        self.item2txt = {}
+        self.item2ph = {}
+        self.item2wavfn = {}
+        self.item2tgfn = {}
+        self.item2spk = {}
+        for ds_id, processed_data_dir in enumerate(self.processed_data_dirs):
+            self.meta_df = pd.read_csv(f"{processed_data_dir}/metadata_phone.csv", dtype=str)
+            for r_idx, r in self.meta_df.iterrows():
+                item_name = raw_item_name = r['item_name']
+                if len(self.processed_data_dirs) > 1:
+                    item_name = f'ds{ds_id}_{item_name}'
+                self.item2txt[item_name] = r['txt']
+                self.item2ph[item_name] = r['ph']
+                self.item2wavfn[item_name] = os.path.join(hparams['raw_data_dir'], 'wavs', os.path.basename(r['wav_fn']).split('_')[1])
+                self.item2spk[item_name] = r.get('spk', 'SPK1')
+                if len(self.processed_data_dirs) > 1:
+                    self.item2spk[item_name] = f"ds{ds_id}_{self.item2spk[item_name]}"
+                if tg_dir is not None:
+                    self.item2tgfn[item_name] = f"{processed_data_dir}/{tg_dir}/{raw_item_name}.TextGrid"
+        self.item_names = sorted(list(self.item2txt.keys()))
+        if self.binarization_args['shuffle']:
+            random.seed(1234)
+            random.shuffle(self.item_names)
+
+    @property
+    def train_item_names(self):
+        return self.item_names[hparams['test_num']+hparams['valid_num']:]
+
+    @property
+    def valid_item_names(self):
+        return self.item_names[0: hparams['test_num']+hparams['valid_num']]  #
+
+    @property
+    def test_item_names(self):
+        return self.item_names[0: hparams['test_num']]  # Audios for MOS testing are in 'test_ids'
+
+    def build_spk_map(self):
+        spk_map = set()
+        for item_name in self.item_names:
+            spk_name = self.item2spk[item_name]
+            spk_map.add(spk_name)
+        spk_map = {x: i for i, x in enumerate(sorted(list(spk_map)))}
+        assert len(spk_map) == 0 or len(spk_map) <= hparams['num_spk'], len(spk_map)
+        return spk_map
+
+    def item_name2spk_id(self, item_name):
+        return self.spk_map[self.item2spk[item_name]]
+
+    def _phone_encoder(self):
+        ph_set_fn = f"{hparams['binary_data_dir']}/phone_set.json"
+        ph_set = []
+        if hparams['reset_phone_dict'] or not os.path.exists(ph_set_fn):
+            for processed_data_dir in self.processed_data_dirs:
+                ph_set += [x.split(' ')[0] for x in open(f'{processed_data_dir}/dict.txt').readlines()]
+            ph_set = sorted(set(ph_set))
+            json.dump(ph_set, open(ph_set_fn, 'w'))
+        else:
+            ph_set = json.load(open(ph_set_fn, 'r'))
+        print("| phone set: ", ph_set)
+        return build_phone_encoder(hparams['binary_data_dir'])
+
+    def meta_data(self, prefix):
+        if prefix == 'valid':
+            item_names = self.valid_item_names
+        elif prefix == 'test':
+            item_names = self.test_item_names
+        else:
+            item_names = self.train_item_names
+        for item_name in item_names:
+            ph = self.item2ph[item_name]
+            txt = self.item2txt[item_name]
+            tg_fn = self.item2tgfn.get(item_name)
+            wav_fn = self.item2wavfn[item_name]
+            spk_id = self.item_name2spk_id(item_name)
+            yield item_name, ph, txt, tg_fn, wav_fn, spk_id
+
+    def process(self):
+        os.makedirs(hparams['binary_data_dir'], exist_ok=True)
+        self.spk_map = self.build_spk_map()
+        print("| spk_map: ", self.spk_map)
+        spk_map_fn = f"{hparams['binary_data_dir']}/spk_map.json"
+        json.dump(self.spk_map, open(spk_map_fn, 'w'))
+
+        self.phone_encoder = self._phone_encoder()
+        self.process_data('valid')
+        self.process_data('test')
+        self.process_data('train')
+
+    def process_data(self, prefix):
+        data_dir = hparams['binary_data_dir']
+        args = []
+        builder = IndexedDatasetBuilder(f'{data_dir}/{prefix}')
+        lengths = []
+        f0s = []
+        total_sec = 0
+        if self.binarization_args['with_spk_embed']:
+            voice_encoder = VoiceEncoder().cuda()
+
+        meta_data = list(self.meta_data(prefix))
+        for m in meta_data:
+            args.append(list(m) + [self.phone_encoder, self.binarization_args])
+        num_workers = int(os.getenv('N_PROC', os.cpu_count() // 3))
+        for f_id, (_, item) in enumerate(
+                zip(tqdm(meta_data), chunked_multiprocess_run(self.process_item, args, num_workers=num_workers))):
+            if item is None:
+                continue
+            item['spk_embed'] = voice_encoder.embed_utterance(item['wav']) \
+                if self.binarization_args['with_spk_embed'] else None
+            if not self.binarization_args['with_wav'] and 'wav' in item:
+                print("del wav")
+                del item['wav']
+            builder.add_item(item)
+            lengths.append(item['len'])
+            total_sec += item['sec']
+            if item.get('f0') is not None:
+                f0s.append(item['f0'])
+        builder.finalize()
+        np.save(f'{data_dir}/{prefix}_lengths.npy', lengths)
+        if len(f0s) > 0:
+            f0s = np.concatenate(f0s, 0)
+            f0s = f0s[f0s != 0]
+            np.save(f'{data_dir}/{prefix}_f0s_mean_std.npy', [np.mean(f0s).item(), np.std(f0s).item()])
+        print(f"| {prefix} total duration: {total_sec:.3f}s")
+
+    @classmethod
+    def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
+        if hparams['vocoder'] in VOCODERS:
+            wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn)
+        else:
+            wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn)
+        res = {
+            'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
+            'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
+        }
+        try:
+            if binarization_args['with_f0']:
+                cls.get_pitch(wav, mel, res)
+                if binarization_args['with_f0cwt']:
+                    cls.get_f0cwt(res['f0'], res)
+            if binarization_args['with_txt']:
+                try:
+                    phone_encoded = res['phone'] = encoder.encode(ph)
+                except:
+                    traceback.print_exc()
+                    raise BinarizationError(f"Empty phoneme")
+                if binarization_args['with_align']:
+                    cls.get_align(tg_fn, ph, mel, phone_encoded, res)
+        except BinarizationError as e:
+            print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
+            return None
+        return res
+
+    @staticmethod
+    def get_align(tg_fn, ph, mel, phone_encoded, res):
+        if tg_fn is not None and os.path.exists(tg_fn):
+            mel2ph, dur = get_mel2ph(tg_fn, ph, mel, hparams)
+        else:
+            raise BinarizationError(f"Align not found")
+        if mel2ph.max() - 1 >= len(phone_encoded):
+            raise BinarizationError(
+                f"Align does not match: mel2ph.max() - 1: {mel2ph.max() - 1}, len(phone_encoded): {len(phone_encoded)}")
+        res['mel2ph'] = mel2ph
+        res['dur'] = dur
+
+    @staticmethod
+    def get_pitch(wav, mel, res):
+        f0, pitch_coarse = get_pitch(wav, mel, hparams)
+        if sum(f0) == 0:
+            raise BinarizationError("Empty f0")
+        res['f0'] = f0
+        res['pitch'] = pitch_coarse
+
+    @staticmethod
+    def get_f0cwt(f0, res):
+        from utils.cwt import get_cont_lf0, get_lf0_cwt
+        uv, cont_lf0_lpf = get_cont_lf0(f0)
+        logf0s_mean_org, logf0s_std_org = np.mean(cont_lf0_lpf), np.std(cont_lf0_lpf)
+        cont_lf0_lpf_norm = (cont_lf0_lpf - logf0s_mean_org) / logf0s_std_org
+        Wavelet_lf0, scales = get_lf0_cwt(cont_lf0_lpf_norm)
+        if np.any(np.isnan(Wavelet_lf0)):
+            raise BinarizationError("NaN CWT")
+        res['cwt_spec'] = Wavelet_lf0
+        res['cwt_scales'] = scales
+        res['f0_mean'] = logf0s_mean_org
+        res['f0_std'] = logf0s_std_org
+
+
+if __name__ == "__main__":
+    set_hparams()
+    BaseBinarizer().process()
diff --git a/data_gen/tts/base_preprocess.py b/data_gen/tts/base_preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c0b2cda06076d32b4eda800b134415e20d0f730
--- /dev/null
+++ b/data_gen/tts/base_preprocess.py
@@ -0,0 +1,245 @@
+import json
+import os
+import random
+import re
+import traceback
+from collections import Counter
+from functools import partial
+
+import librosa
+from tqdm import tqdm
+from data_gen.tts.txt_processors.base_text_processor import get_txt_processor_cls
+from data_gen.tts.wav_processors.base_processor import get_wav_processor_cls
+from utils.hparams import hparams
+from utils.multiprocess_utils import multiprocess_run_tqdm
+from utils.os_utils import link_file, move_file, remove_file
+from data_gen.tts.data_gen_utils import is_sil_phoneme, build_token_encoder
+
+
+class BasePreprocessor:
+    def __init__(self):
+        self.preprocess_args = hparams['preprocess_args']
+        txt_processor = self.preprocess_args['txt_processor']
+        self.txt_processor = get_txt_processor_cls(txt_processor)
+        self.raw_data_dir = hparams['raw_data_dir']
+        self.processed_dir = hparams['processed_data_dir']
+        self.spk_map_fn = f"{self.processed_dir}/spk_map.json"
+
+    def meta_data(self):
+        """
+        :return: {'item_name': Str, 'wav_fn': Str, 'txt': Str, 'spk_name': Str, 'txt_loader': None or Func}
+        """
+        raise NotImplementedError
+
+    def process(self):
+        processed_dir = self.processed_dir
+        wav_processed_tmp_dir = f'{processed_dir}/processed_tmp'
+        remove_file(wav_processed_tmp_dir)
+        os.makedirs(wav_processed_tmp_dir, exist_ok=True)
+        wav_processed_dir = f'{processed_dir}/{self.wav_processed_dirname}'
+        remove_file(wav_processed_dir)
+        os.makedirs(wav_processed_dir, exist_ok=True)
+
+        meta_data = list(tqdm(self.meta_data(), desc='Load meta data'))
+        item_names = [d['item_name'] for d in meta_data]
+        assert len(item_names) == len(set(item_names)), 'Key `item_name` must be unique.'
+
+        # preprocess data
+        phone_list = []
+        word_list = []
+        spk_names = set()
+        process_item = partial(self.preprocess_first_pass,
+                               txt_processor=self.txt_processor,
+                               wav_processed_dir=wav_processed_dir,
+                               wav_processed_tmp=wav_processed_tmp_dir,
+                               preprocess_args=self.preprocess_args)
+        items = []
+        args = [{
+            'item_name': item_raw['item_name'],
+            'txt_raw': item_raw['txt'],
+            'wav_fn': item_raw['wav_fn'],
+            'txt_loader': item_raw.get('txt_loader'),
+            'others': item_raw.get('others', None)
+        } for item_raw in meta_data]
+        for item_, (item_id, item) in zip(meta_data, multiprocess_run_tqdm(process_item, args, desc='Preprocess')):
+            if item is not None:
+                item_.update(item)
+                item = item_
+                if 'txt_loader' in item:
+                    del item['txt_loader']
+                item['id'] = item_id
+                item['spk_name'] = item.get('spk_name', '<SINGLE_SPK>')
+                item['others'] = item.get('others', None)
+                phone_list += item['ph'].split(" ")
+                word_list += item['word'].split(" ")
+                spk_names.add(item['spk_name'])
+                items.append(item)
+
+        # add encoded tokens
+        ph_encoder, word_encoder = self._phone_encoder(phone_list), self._word_encoder(word_list)
+        spk_map = self.build_spk_map(spk_names)
+        args = [{
+            'ph': item['ph'], 'word': item['word'], 'spk_name': item['spk_name'],
+            'word_encoder': word_encoder, 'ph_encoder': ph_encoder, 'spk_map': spk_map
+        } for item in items]
+        for idx, item_new_kv in multiprocess_run_tqdm(self.preprocess_second_pass, args, desc='Add encoded tokens'):
+            items[idx].update(item_new_kv)
+
+        # build mfa data
+        if self.preprocess_args['use_mfa']:
+            mfa_dict = set()
+            mfa_input_dir = f'{processed_dir}/mfa_inputs'
+            remove_file(mfa_input_dir)
+            # group MFA inputs for better parallelism
+            mfa_groups = [i // self.preprocess_args['nsample_per_mfa_group'] for i in range(len(items))]
+            if self.preprocess_args['mfa_group_shuffle']:
+                random.seed(hparams['seed'])
+                random.shuffle(mfa_groups)
+            args = [{
+                'item': item, 'mfa_input_dir': mfa_input_dir,
+                'mfa_group': mfa_group, 'wav_processed_tmp': wav_processed_tmp_dir,
+                'preprocess_args': self.preprocess_args
+            } for item, mfa_group in zip(items, mfa_groups)]
+            for i, (ph_gb_word_nosil, new_wav_align_fn) in multiprocess_run_tqdm(
+                    self.build_mfa_inputs, args, desc='Build MFA data'):
+                items[i]['wav_align_fn'] = new_wav_align_fn
+                for w in ph_gb_word_nosil.split(" "):
+                    mfa_dict.add(f"{w} {w.replace('_', ' ')}")
+            mfa_dict = sorted(mfa_dict)
+            with open(f'{processed_dir}/mfa_dict.txt', 'w') as f:
+                f.writelines([f'{l}\n' for l in mfa_dict])
+        with open(f"{processed_dir}/{self.meta_csv_filename}.json", 'w') as f:
+            f.write(re.sub(r'\n\s+([\d+\]])', r'\1', json.dumps(items, ensure_ascii=False, sort_keys=False, indent=1)))
+        remove_file(wav_processed_tmp_dir)
+
+    @classmethod
+    def preprocess_first_pass(cls, item_name, txt_raw, txt_processor,
+                              wav_fn, wav_processed_dir, wav_processed_tmp,
+                              preprocess_args, txt_loader=None, others=None):
+        try:
+            if txt_loader is not None:
+                txt_raw = txt_loader(txt_raw)
+            ph, txt, word, ph2word, ph_gb_word = cls.txt_to_ph(txt_processor, txt_raw, preprocess_args)
+            wav_fn, wav_align_fn = cls.process_wav(
+                item_name, wav_fn,
+                hparams['processed_data_dir'],
+                wav_processed_tmp, preprocess_args)
+
+            # wav for binarization
+            ext = os.path.splitext(wav_fn)[1]
+            os.makedirs(wav_processed_dir, exist_ok=True)
+            new_wav_fn = f"{wav_processed_dir}/{item_name}{ext}"
+            move_link_func = move_file if os.path.dirname(wav_fn) == wav_processed_tmp else link_file
+            move_link_func(wav_fn, new_wav_fn)
+            return {
+                'txt': txt, 'txt_raw': txt_raw, 'ph': ph,
+                'word': word, 'ph2word': ph2word, 'ph_gb_word': ph_gb_word,
+                'wav_fn': new_wav_fn, 'wav_align_fn': wav_align_fn,
+                'others': others
+            }
+        except:
+            traceback.print_exc()
+            print(f"| Error is caught. item_name: {item_name}.")
+            return None
+
+    @staticmethod
+    def txt_to_ph(txt_processor, txt_raw, preprocess_args):
+        txt_struct, txt = txt_processor.process(txt_raw, preprocess_args)
+        ph = [p for w in txt_struct for p in w[1]]
+        words = [w[0] for w in txt_struct]
+        ph_gb_word = ["_".join(w[1]) for w in txt_struct]  # phonemes grouped per word, e.g. "HH_AH0"
+        ph2word = [w_id + 1 for w_id, w in enumerate(txt_struct) for _ in w[1]]  # 1-based word index per phoneme
+        return " ".join(ph), txt, " ".join(words), ph2word, " ".join(ph_gb_word)
+
+    @staticmethod
+    def process_wav(item_name, wav_fn, processed_dir, wav_processed_tmp, preprocess_args):
+        processors = [get_wav_processor_cls(v) for v in preprocess_args['wav_processors']]
+        processors = [k() for k in processors if k is not None]
+        if len(processors) >= 1:
+            sr_file = librosa.core.get_samplerate(wav_fn)
+            output_fn_for_align = None
+            ext = os.path.splitext(wav_fn)[1]
+            input_fn = f"{wav_processed_tmp}/{item_name}{ext}"
+            link_file(wav_fn, input_fn)
+            for p in processors:
+                outputs = p.process(input_fn, sr_file, wav_processed_tmp, processed_dir, item_name, preprocess_args)
+                if len(outputs) == 3:
+                    input_fn, sr, output_fn_for_align = outputs
+                else:
+                    input_fn, sr = outputs
+            return input_fn, output_fn_for_align
+        else:
+            return wav_fn, wav_fn
+
+    def _phone_encoder(self, ph_set):
+        ph_set_fn = f"{self.processed_dir}/phone_set.json"
+        if self.preprocess_args['reset_phone_dict'] or not os.path.exists(ph_set_fn):
+            ph_set = sorted(set(ph_set))
+            json.dump(ph_set, open(ph_set_fn, 'w'), ensure_ascii=False)
+            print("| Build phone set: ", ph_set)
+        else:
+            ph_set = json.load(open(ph_set_fn, 'r'))
+            print("| Load phone set: ", ph_set)
+        return build_token_encoder(ph_set_fn)
+
+    def _word_encoder(self, word_set):
+        word_set_fn = f"{self.processed_dir}/word_set.json"
+        if self.preprocess_args['reset_word_dict']:
+            word_set = Counter(word_set)
+            total_words = sum(word_set.values())
+            word_set = word_set.most_common(hparams['word_dict_size'])
+            num_unk_words = total_words - sum([x[1] for x in word_set])
+            word_set = ['<BOS>', '<EOS>'] + [x[0] for x in word_set]
+            word_set = sorted(set(word_set))
+            json.dump(word_set, open(word_set_fn, 'w'), ensure_ascii=False)
+            print(f"| Build word set. Size: {len(word_set)}, #total words: {total_words},"
+                  f" #unk_words: {num_unk_words}, word_set[:10]:, {word_set[:10]}.")
+        else:
+            word_set = json.load(open(word_set_fn, 'r'))
+            print("| Load word set. Size: ", len(word_set), word_set[:10])
+        return build_token_encoder(word_set_fn)
+
+    @classmethod
+    def preprocess_second_pass(cls, word, ph, spk_name, word_encoder, ph_encoder, spk_map):
+        word_token = word_encoder.encode(word)
+        ph_token = ph_encoder.encode(ph)
+        spk_id = spk_map[spk_name]
+        return {'word_token': word_token, 'ph_token': ph_token, 'spk_id': spk_id}
+
+    def build_spk_map(self, spk_names):
+        spk_map = {x: i for i, x in enumerate(sorted(list(spk_names)))}
+        assert len(spk_map) == 0 or len(spk_map) <= hparams['num_spk'], len(spk_map)
+        print(f"| Number of spks: {len(spk_map)}, spk_map: {spk_map}")
+        json.dump(spk_map, open(self.spk_map_fn, 'w'), ensure_ascii=False)
+        return spk_map
+
+    @classmethod
+    def build_mfa_inputs(cls, item, mfa_input_dir, mfa_group, wav_processed_tmp, preprocess_args):
+        item_name = item['item_name']
+        wav_align_fn = item['wav_align_fn']
+        ph_gb_word = item['ph_gb_word']
+        ext = os.path.splitext(wav_align_fn)[1]
+        mfa_input_group_dir = f'{mfa_input_dir}/{mfa_group}'
+        os.makedirs(mfa_input_group_dir, exist_ok=True)
+        new_wav_align_fn = f"{mfa_input_group_dir}/{item_name}{ext}"
+        move_link_func = move_file if os.path.dirname(wav_align_fn) == wav_processed_tmp else link_file
+        move_link_func(wav_align_fn, new_wav_align_fn)
+        ph_gb_word_nosil = " ".join(["_".join([p for p in w.split("_") if not is_sil_phoneme(p)])
+                                     for w in ph_gb_word.split(" ") if not is_sil_phoneme(w)])
+        with open(f'{mfa_input_group_dir}/{item_name}.lab', 'w') as f_txt:
+            f_txt.write(ph_gb_word_nosil)
+        return ph_gb_word_nosil, new_wav_align_fn
+
+    def load_spk_map(self, base_dir):
+        spk_map_fn = f"{base_dir}/spk_map.json"
+        spk_map = json.load(open(spk_map_fn, 'r'))
+        return spk_map
+
+    def load_dict(self, base_dir):
+        ph_encoder = build_token_encoder(f'{base_dir}/phone_set.json')
+        return ph_encoder
+
+    @property
+    def meta_csv_filename(self):
+        return 'metadata'
+
+    @property
+    def wav_processed_dirname(self):
+        return 'wav_processed'
\ No newline at end of file
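Note: `BasePreprocessor.meta_data` is abstract; its docstring above fixes the item keys. Below is a minimal sketch of a concrete subclass for an LJSpeech-style `metadata.csv` (`id|text|normalized text`); the dataset-specific preprocessors that ship with the repo (under `egs/`) are not part of this diff:

```python
from data_gen.tts.base_preprocess import BasePreprocessor
from utils.hparams import hparams


class LJSpeechLikePreprocessor(BasePreprocessor):
    def meta_data(self):
        raw_dir = hparams['raw_data_dir']
        with open(f"{raw_dir}/metadata.csv", encoding='utf-8') as f:
            for line in f:
                item_name, _, txt = line.strip().split("|")
                yield {
                    'item_name': item_name,
                    'wav_fn': f"{raw_dir}/wavs/{item_name}.wav",
                    'txt': txt,
                    'spk_name': 'SPK1',   # single-speaker dataset
                    'txt_loader': None,
                }
```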
diff --git a/data_gen/tts/bin/binarize.py b/data_gen/tts/bin/binarize.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bd3c1f69fa59ed52fdd32eb80e746dedbae7535
--- /dev/null
+++ b/data_gen/tts/bin/binarize.py
@@ -0,0 +1,20 @@
+import os
+
+os.environ["OMP_NUM_THREADS"] = "1"
+
+import importlib
+from utils.hparams import set_hparams, hparams
+
+
+def binarize():
+    binarizer_cls = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizer.BaseBinarizer')
+    pkg = ".".join(binarizer_cls.split(".")[:-1])
+    cls_name = binarizer_cls.split(".")[-1]
+    binarizer_cls = getattr(importlib.import_module(pkg), cls_name)
+    print("| Binarizer: ", binarizer_cls)
+    binarizer_cls().process()
+
+
+if __name__ == '__main__':
+    set_hparams()
+    binarize()
diff --git a/data_gen/tts/data_gen_utils.py b/data_gen/tts/data_gen_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b6bf10862cf3f9a8b2aee560ae5d44eabbf00bc
--- /dev/null
+++ b/data_gen/tts/data_gen_utils.py
@@ -0,0 +1,352 @@
+import warnings
+
+warnings.filterwarnings("ignore")
+
+# import parselmouth  # imported lazily inside get_pitch() below
+import os
+import torch
+from skimage.transform import resize
+from utils.text_encoder import TokenTextEncoder
+from utils.pitch_utils import f0_to_coarse
+import struct
+import webrtcvad
+from scipy.ndimage.morphology import binary_dilation
+import librosa
+import numpy as np
+from utils import audio
+import pyloudnorm as pyln
+import re
+import json
+from collections import OrderedDict
+
+PUNCS = '!,.?;:'
+
+int16_max = (2 ** 15) - 1
+
+
+def trim_long_silences(path, sr=None, return_raw_wav=False, norm=True, vad_max_silence_length=12):
+    """
+    Ensures that segments without voice in the waveform remain no longer than a
+    threshold determined by the VAD parameters in params.py.
+    :param wav: the raw waveform as a numpy array of floats
+    :param vad_max_silence_length: Maximum number of consecutive silent frames a segment can have.
+    :return: the same waveform with silences trimmed away (length <= original wav length)
+    """
+
+    ## Voice Activation Detection
+    # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
+    # This sets the granularity of the VAD. Should not need to be changed.
+    sampling_rate = 16000
+    wav_raw, sr = librosa.core.load(path, sr=sr)
+
+    if norm:
+        meter = pyln.Meter(sr)  # create BS.1770 meter
+        loudness = meter.integrated_loudness(wav_raw)
+        wav_raw = pyln.normalize.loudness(wav_raw, loudness, -20.0)
+        if np.abs(wav_raw).max() > 1.0:
+            wav_raw = wav_raw / np.abs(wav_raw).max()
+
+    wav = librosa.resample(wav_raw, sr, sampling_rate, res_type='kaiser_best')
+
+    vad_window_length = 30  # In milliseconds
+    # Number of frames to average together when performing the moving average smoothing.
+    # The larger this value, the larger the VAD variations must be to not get smoothed out.
+    vad_moving_average_width = 8
+
+    # Compute the voice detection window size
+    samples_per_window = (vad_window_length * sampling_rate) // 1000
+
+    # Trim the end of the audio to have a multiple of the window size
+    wav = wav[:len(wav) - (len(wav) % samples_per_window)]
+
+    # Convert the float waveform to 16-bit mono PCM
+    pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
+
+    # Perform voice activation detection
+    voice_flags = []
+    vad = webrtcvad.Vad(mode=3)
+    for window_start in range(0, len(wav), samples_per_window):
+        window_end = window_start + samples_per_window
+        voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
+                                         sample_rate=sampling_rate))
+    voice_flags = np.array(voice_flags)
+
+    # Smooth the voice detection with a moving average
+    def moving_average(array, width):
+        array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
+        ret = np.cumsum(array_padded, dtype=float)
+        ret[width:] = ret[width:] - ret[:-width]
+        return ret[width - 1:] / width
+
+    audio_mask = moving_average(voice_flags, vad_moving_average_width)
+    audio_mask = np.round(audio_mask).astype(bool)  # np.bool is deprecated/removed in recent NumPy
+
+    # Dilate the voiced regions
+    audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
+    audio_mask = np.repeat(audio_mask, samples_per_window)
+    audio_mask = resize(audio_mask, (len(wav_raw),)) > 0
+    if return_raw_wav:
+        return wav_raw, audio_mask, sr
+    return wav_raw[audio_mask], audio_mask, sr
+
+
+def process_utterance(wav_path,
+                      fft_size=1024,
+                      hop_size=256,
+                      win_length=1024,
+                      window="hann",
+                      num_mels=80,
+                      fmin=80,
+                      fmax=7600,
+                      eps=1e-6,
+                      sample_rate=22050,
+                      loud_norm=False,
+                      min_level_db=-100,
+                      return_linear=False,
+                      trim_long_sil=False, vocoder='pwg'):
+    if isinstance(wav_path, str):
+        if trim_long_sil:
+            wav, _, _ = trim_long_silences(wav_path, sample_rate)
+        else:
+            wav, _ = librosa.core.load(wav_path, sr=sample_rate)
+    else:
+        wav = wav_path
+
+    if loud_norm:
+        meter = pyln.Meter(sample_rate)  # create BS.1770 meter
+        loudness = meter.integrated_loudness(wav)
+        wav = pyln.normalize.loudness(wav, loudness, -22.0)
+        if np.abs(wav).max() > 1:
+            wav = wav / np.abs(wav).max()
+
+    # get amplitude spectrogram
+    x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size,
+                          win_length=win_length, window=window, pad_mode="constant")
+    spc = np.abs(x_stft)  # (n_bins, T)
+
+    # get mel basis
+    fmin = 0 if fmin == -1 else fmin
+    fmax = sample_rate / 2 if fmax == -1 else fmax
+    mel_basis = librosa.filters.mel(sample_rate, fft_size, num_mels, fmin, fmax)
+    mel = mel_basis @ spc
+
+    if vocoder == 'pwg':
+        mel = np.log10(np.maximum(eps, mel))  # (n_mel_bins, T)
+    else:
+        assert False, f'"{vocoder}" is not in ["pwg"].'
+
+    l_pad, r_pad = audio.librosa_pad_lr(wav, fft_size, hop_size, 1)
+    wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0)
+    wav = wav[:mel.shape[1] * hop_size]
+
+    if not return_linear:
+        return wav, mel
+    else:
+        spc = audio.amp_to_db(spc)
+        spc = audio.normalize(spc, {'min_level_db': min_level_db})
+        return wav, mel, spc
+
+
+def get_pitch(wav_data, mel, hparams):
+    """
+
+    :param wav_data: [T]
+    :param mel: [T, 80]
+    :param hparams:
+    :return:
+    """
+    time_step = hparams['hop_size'] / hparams['audio_sample_rate'] * 1000
+    f0_min = 80
+    f0_max = 750
+
+    if hparams['hop_size'] == 128:
+        pad_size = 4
+    elif hparams['hop_size'] == 256:
+        pad_size = 2
+    else:
+        assert False
+
+    import parselmouth  # imported here because the module-level import is disabled above
+    f0 = parselmouth.Sound(wav_data, hparams['audio_sample_rate']).to_pitch_ac(
+        time_step=time_step / 1000, voicing_threshold=0.6,
+        pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
+    lpad = pad_size * 2
+    rpad = len(mel) - len(f0) - lpad
+    f0 = np.pad(f0, [[lpad, rpad]], mode='constant')
+    # mel and f0 are extracted by two different libraries, so we force them to the same length.
+    # Attention: newer versions of some libraries can make ``rpad`` negative...
+    # To be safe, we recommend setting up the same environment as ours via requirements_auto.txt (e.g., with Anaconda).
+    delta_l = len(mel) - len(f0)
+    assert np.abs(delta_l) <= 8
+    if delta_l > 0:
+        f0 = np.concatenate([f0, [f0[-1]] * delta_l], 0)
+    f0 = f0[:len(mel)]
+    pitch_coarse = f0_to_coarse(f0)
+    return f0, pitch_coarse
+
+
+def remove_empty_lines(text):
+    """remove empty lines"""
+    assert (len(text) > 0)
+    assert (isinstance(text, list))
+    text = [t.strip() for t in text]
+    if "" in text:
+        text.remove("")
+    return text
+
+
+class TextGrid(object):
+    def __init__(self, text):
+        text = remove_empty_lines(text)
+        self.text = text
+        self.line_count = 0
+        self._get_type()
+        self._get_time_intval()
+        self._get_size()
+        self.tier_list = []
+        self._get_item_list()
+
+    def _extract_pattern(self, pattern, inc):
+        """
+        Parameters
+        ----------
+        pattern : regex to extract pattern
+        inc : increment of line count after extraction
+        Returns
+        -------
+        group : extracted info
+        """
+        try:
+            group = re.match(pattern, self.text[self.line_count]).group(1)
+            self.line_count += inc
+        except AttributeError:
+            raise ValueError("File format error at line %d:%s" % (self.line_count, self.text[self.line_count]))
+        return group
+
+    def _get_type(self):
+        self.file_type = self._extract_pattern(r"File type = \"(.*)\"", 2)
+
+    def _get_time_intval(self):
+        self.xmin = self._extract_pattern(r"xmin = (.*)", 1)
+        self.xmax = self._extract_pattern(r"xmax = (.*)", 2)
+
+    def _get_size(self):
+        self.size = int(self._extract_pattern(r"size = (.*)", 2))
+
+    def _get_item_list(self):
+        """Only supports IntervalTier currently"""
+        for itemIdx in range(1, self.size + 1):
+            tier = OrderedDict()
+            item_list = []
+            tier_idx = self._extract_pattern(r"item \[(.*)\]:", 1)
+            tier_class = self._extract_pattern(r"class = \"(.*)\"", 1)
+            if tier_class != "IntervalTier":
+                raise NotImplementedError("Only IntervalTier class is supported currently")
+            tier_name = self._extract_pattern(r"name = \"(.*)\"", 1)
+            tier_xmin = self._extract_pattern(r"xmin = (.*)", 1)
+            tier_xmax = self._extract_pattern(r"xmax = (.*)", 1)
+            tier_size = self._extract_pattern(r"intervals: size = (.*)", 1)
+            for i in range(int(tier_size)):
+                item = OrderedDict()
+                item["idx"] = self._extract_pattern(r"intervals \[(.*)\]", 1)
+                item["xmin"] = self._extract_pattern(r"xmin = (.*)", 1)
+                item["xmax"] = self._extract_pattern(r"xmax = (.*)", 1)
+                item["text"] = self._extract_pattern(r"text = \"(.*)\"", 1)
+                item_list.append(item)
+            tier["idx"] = tier_idx
+            tier["class"] = tier_class
+            tier["name"] = tier_name
+            tier["xmin"] = tier_xmin
+            tier["xmax"] = tier_xmax
+            tier["size"] = tier_size
+            tier["items"] = item_list
+            self.tier_list.append(tier)
+
+    def toJson(self):
+        _json = OrderedDict()
+        _json["file_type"] = self.file_type
+        _json["xmin"] = self.xmin
+        _json["xmax"] = self.xmax
+        _json["size"] = self.size
+        _json["tiers"] = self.tier_list
+        return json.dumps(_json, ensure_ascii=False, indent=2)
+
+
+def get_mel2ph(tg_fn, ph, mel, hparams):
+    """Align phonemes to mel frames from an MFA TextGrid; returns (mel2ph, dur)."""
+    ph_list = ph.split(" ")
+    with open(tg_fn, "r") as f:
+        tg = f.readlines()
+    tg = remove_empty_lines(tg)
+    tg = TextGrid(tg)
+    tg = json.loads(tg.toJson())
+    split = np.ones(len(ph_list) + 1, float) * -1  # np.float is deprecated/removed in recent NumPy
+    tg_idx = 0
+    ph_idx = 0
+    tg_align = [x for x in tg['tiers'][-1]['items']]
+    tg_align_ = []
+    for x in tg_align:
+        x['xmin'] = float(x['xmin'])
+        x['xmax'] = float(x['xmax'])
+        if x['text'] in ['sil', 'sp', '', 'SIL', 'PUNC']:
+            x['text'] = ''
+            if len(tg_align_) > 0 and tg_align_[-1]['text'] == '':
+                tg_align_[-1]['xmax'] = x['xmax']
+                continue
+        tg_align_.append(x)
+    tg_align = tg_align_
+    tg_len = len([x for x in tg_align if x['text'] != ''])
+    ph_len = len([x for x in ph_list if not is_sil_phoneme(x)])
+    assert tg_len == ph_len, (tg_len, ph_len, tg_align, ph_list, tg_fn)
+    while tg_idx < len(tg_align) or ph_idx < len(ph_list):
+        if tg_idx == len(tg_align) and is_sil_phoneme(ph_list[ph_idx]):
+            split[ph_idx] = 1e8
+            ph_idx += 1
+            continue
+        x = tg_align[tg_idx]
+        if x['text'] == '' and ph_idx == len(ph_list):
+            tg_idx += 1
+            continue
+        assert ph_idx < len(ph_list), (tg_len, ph_len, tg_align, ph_list, tg_fn)
+        ph = ph_list[ph_idx]
+        if x['text'] == '' and not is_sil_phoneme(ph):
+            assert False, (ph_list, tg_align)
+        if x['text'] != '' and is_sil_phoneme(ph):
+            ph_idx += 1
+        else:
+            assert (x['text'] == '' and is_sil_phoneme(ph)) \
+                   or x['text'].lower() == ph.lower() \
+                   or x['text'].lower() == 'sil', (x['text'], ph)
+            split[ph_idx] = x['xmin']
+            if ph_idx > 0 and split[ph_idx - 1] == -1 and is_sil_phoneme(ph_list[ph_idx - 1]):
+                split[ph_idx - 1] = split[ph_idx]
+            ph_idx += 1
+            tg_idx += 1
+    assert tg_idx == len(tg_align), (tg_idx, [x['text'] for x in tg_align])
+    assert ph_idx >= len(ph_list) - 1, (ph_idx, ph_list, len(ph_list), [x['text'] for x in tg_align], tg_fn)
+    mel2ph = np.zeros([mel.shape[0]], int)
+    split[0] = 0
+    split[-1] = 1e8
+    for i in range(len(split) - 1):
+        assert split[i] != -1 and split[i] <= split[i + 1], (split[:-1],)
+    split = [int(s * hparams['audio_sample_rate'] / hparams['hop_size'] + 0.5) for s in split]
+    for ph_idx in range(len(ph_list)):
+        mel2ph[split[ph_idx]:split[ph_idx + 1]] = ph_idx + 1
+    mel2ph_torch = torch.from_numpy(mel2ph)
+    T_t = len(ph_list)
+    dur = mel2ph_torch.new_zeros([T_t + 1]).scatter_add(0, mel2ph_torch, torch.ones_like(mel2ph_torch))
+    dur = dur[1:].numpy()
+    return mel2ph, dur
+
+
+def build_phone_encoder(data_dir):
+    phone_list_file = os.path.join(data_dir, 'phone_set.json')
+    phone_list = json.load(open(phone_list_file))
+    return TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')
+
+
+def is_sil_phoneme(p):
+    return not p[0].isalpha()
+
+
+def build_token_encoder(token_list_file):
+    token_list = json.load(open(token_list_file))
+    return TokenTextEncoder(None, vocab_list=token_list, replace_oov='<UNK>')
diff --git a/data_gen/tts/txt_processors/__init__.py b/data_gen/tts/txt_processors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bff3e9af7d634363116c6605f22a52aad614dea
--- /dev/null
+++ b/data_gen/tts/txt_processors/__init__.py
@@ -0,0 +1 @@
+from . import en
\ No newline at end of file
diff --git a/data_gen/tts/txt_processors/base_text_processor.py b/data_gen/tts/txt_processors/base_text_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..69d51201dcb191c1c208ae1c87a34b5c97e6307f
--- /dev/null
+++ b/data_gen/tts/txt_processors/base_text_processor.py
@@ -0,0 +1,47 @@
+from data_gen.tts.data_gen_utils import is_sil_phoneme
+
+REGISTERED_TEXT_PROCESSORS = {}
+
+def register_txt_processors(name):
+    def _f(cls):
+        REGISTERED_TEXT_PROCESSORS[name] = cls
+        return cls
+
+    return _f
+
+
+def get_txt_processor_cls(name):
+    return REGISTERED_TEXT_PROCESSORS.get(name, None)
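+# Illustrative lookup (hypothetical call site): a preprocessor can resolve the
+# processor named by its `txt_processor` hparam, e.g.
+#   cls = get_txt_processor_cls('en')  # registered in txt_processors/en.py
+#   txt_struct, txt = cls.process(text, preprocess_args)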
+
+
+class BaseTxtProcessor:
+    @staticmethod
+    def sp_phonemes():
+        return ['|']
+
+    @classmethod
+    def process(cls, txt, preprocess_args):
+        raise NotImplementedError
+
+    @classmethod
+    def postprocess(cls, txt_struct, preprocess_args):
+        # remove sil phoneme in head and tail
+        while len(txt_struct) > 0 and is_sil_phoneme(txt_struct[0][0]):
+            txt_struct = txt_struct[1:]
+        while len(txt_struct) > 0 and is_sil_phoneme(txt_struct[-1][0]):
+            txt_struct = txt_struct[:-1]
+        if preprocess_args['with_phsep']:
+            txt_struct = cls.add_bdr(txt_struct)
+        if preprocess_args['add_eos_bos']:
+            txt_struct = [["<BOS>", ["<BOS>"]]] + txt_struct + [["<EOS>", ["<EOS>"]]]
+        return txt_struct
+
+    @classmethod
+    def add_bdr(cls, txt_struct):
+        txt_struct_ = []
+        for i, ts in enumerate(txt_struct):
+            txt_struct_.append(ts)
+            if i != len(txt_struct) - 1 and \
+                    not is_sil_phoneme(txt_struct[i][0]) and not is_sil_phoneme(txt_struct[i + 1][0]):
+                txt_struct_.append(['|', ['|']])
+        return txt_struct_
\ No newline at end of file
diff --git a/data_gen/tts/txt_processors/en.py b/data_gen/tts/txt_processors/en.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f755d5ab1f2cf4407daee08cc3639a05e941a97
--- /dev/null
+++ b/data_gen/tts/txt_processors/en.py
@@ -0,0 +1,77 @@
+import re
+import unicodedata
+
+from g2p_en import G2p
+from g2p_en.expand import normalize_numbers
+from nltk import pos_tag
+from nltk.tokenize import TweetTokenizer
+
+from data_gen.tts.txt_processors.base_text_processor import BaseTxtProcessor, register_txt_processors
+from data_gen.tts.data_gen_utils import is_sil_phoneme, PUNCS
+
+class EnG2p(G2p):
+    word_tokenize = TweetTokenizer().tokenize
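+    # Re-implements G2p.__call__ with NLTK's TweetTokenizer (punctuation stays as
+    # separate tokens); tokens without a-z characters skip G2P and are passed through,
+    # presumably because preprocess_text below already normalises the raw text.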
+
+    def __call__(self, text):
+        # preprocessing
+        words = EnG2p.word_tokenize(text)
+        tokens = pos_tag(words)  # tuples of (word, tag)
+
+        # steps
+        prons = []
+        for word, pos in tokens:
+            if re.search("[a-z]", word) is None:
+                pron = [word]
+
+            elif word in self.homograph2features:  # Check homograph
+                pron1, pron2, pos1 = self.homograph2features[word]
+                if pos.startswith(pos1):
+                    pron = pron1
+                else:
+                    pron = pron2
+            elif word in self.cmu:  # lookup CMU dict
+                pron = self.cmu[word][0]
+            else:  # predict for oov
+                pron = self.predict(word)
+
+            prons.extend(pron)
+            prons.extend([" "])
+
+        return prons[:-1]
+
+
+@register_txt_processors('en')
+class TxtProcessor(BaseTxtProcessor):
+    g2p = EnG2p()
+
+    @staticmethod
+    def preprocess_text(text):
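+        # Illustrative effect (assuming PUNCS covers ',', '!', etc.):
+        #   "Hello, world!!" -> roughly "hello , world !" (process() strips leftover whitespace)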
+        text = normalize_numbers(text)
+        text = ''.join(char for char in unicodedata.normalize('NFD', text)
+                       if unicodedata.category(char) != 'Mn')  # Strip accents
+        text = text.lower()
+        text = re.sub("[\'\"()]+", "", text)
+        text = re.sub("[-]+", " ", text)
+        text = re.sub(f"[^ a-z{PUNCS}]", "", text)
+        text = re.sub(f" ?([{PUNCS}]) ?", r"\1", text)  # !! -> !
+        text = re.sub(f"([{PUNCS}])+", r"\1", text)  # !! -> !
+        text = text.replace("i.e.", "that is")
+        text = text.replace("i.e.", "that is")
+        text = text.replace("etc.", "etc")
+        text = re.sub(f"([{PUNCS}])", r" \1 ", text)
+        text = re.sub(rf"\s+", r" ", text)
+        return text
+
+    @classmethod
+    def process(cls, txt, preprocess_args):
+        txt = cls.preprocess_text(txt).strip()
+        phs = cls.g2p(txt)
+        txt_struct = [[w, []] for w in txt.split(" ")]
+        i_word = 0
+        for p in phs:
+            if p == ' ':
+                i_word += 1
+            else:
+                txt_struct[i_word][1].append(p)
+        txt_struct = cls.postprocess(txt_struct, preprocess_args)
+        return txt_struct, txt
\ No newline at end of file
diff --git a/data_gen/tts/wav_processors/__init__.py b/data_gen/tts/wav_processors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4be97b377dcb95a0e6bceb876ac0ce93c8290249
--- /dev/null
+++ b/data_gen/tts/wav_processors/__init__.py
@@ -0,0 +1,2 @@
+from . import base_processor
+from . import common_processors
diff --git a/data_gen/tts/wav_processors/base_processor.py b/data_gen/tts/wav_processors/base_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8200dc58a9388ac94a5ec34b8a65f75e380255b
--- /dev/null
+++ b/data_gen/tts/wav_processors/base_processor.py
@@ -0,0 +1,25 @@
+REGISTERED_WAV_PROCESSORS = {}
+
+
+def register_wav_processors(name):
+    def _f(cls):
+        REGISTERED_WAV_PROCESSORS[name] = cls
+        return cls
+
+    return _f
+
+
+def get_wav_processor_cls(name):
+    return REGISTERED_WAV_PROCESSORS.get(name, None)
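+# Illustrative usage (hypothetical driver loop): processors are chained, each one
+# returning the (possibly rewritten) wav path and sample rate:
+#   for name in ['sox_to_wav', 'sox_resample', 'trim_sil']:
+#       wav_fn, sr = get_wav_processor_cls(name)().process(
+#           wav_fn, sr, tmp_dir, processed_dir, item_name, preprocess_args)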
+
+
+class BaseWavProcessor:
+    @property
+    def name(self):
+        raise NotImplementedError
+
+    def output_fn(self, input_fn):
+        return f'{input_fn[:-4]}_{self.name}.wav'
+
+    def process(self, input_fn, sr, tmp_dir, processed_dir, item_name, preprocess_args):
+        raise NotImplementedError
diff --git a/data_gen/tts/wav_processors/common_processors.py b/data_gen/tts/wav_processors/common_processors.py
new file mode 100644
index 0000000000000000000000000000000000000000..de0b49f4a31cb6737f2cffc6c8d010d88d11c853
--- /dev/null
+++ b/data_gen/tts/wav_processors/common_processors.py
@@ -0,0 +1,86 @@
+import os
+import subprocess
+import librosa
+import numpy as np
+from data_gen.tts.wav_processors.base_processor import BaseWavProcessor, register_wav_processors
+from data_gen.tts.data_gen_utils import trim_long_silences
+from utils.audio import save_wav
+from utils.rnnoise import rnnoise
+from utils.hparams import hparams
+
+
+@register_wav_processors(name='sox_to_wav')
+class ConvertToWavProcessor(BaseWavProcessor):
+    @property
+    def name(self):
+        return 'ToWav'
+
+    def process(self, input_fn, sr, tmp_dir, processed_dir, item_name, preprocess_args):
+        if input_fn[-4:] == '.wav':
+            return input_fn, sr
+        else:
+            output_fn = self.output_fn(input_fn)
+            subprocess.check_call(f'sox -v 0.95 "{input_fn}" -t wav "{output_fn}"', shell=True)
+            return output_fn, sr
+
+
+@register_wav_processors(name='sox_resample')
+class ResampleProcessor(BaseWavProcessor):
+    @property
+    def name(self):
+        return 'Resample'
+
+    def process(self, input_fn, sr, tmp_dir, processed_dir, item_name, preprocess_args):
+        output_fn = self.output_fn(input_fn)
+        sr_file = librosa.core.get_samplerate(input_fn)
+        if sr != sr_file:
+            subprocess.check_call(f'sox -v 0.95 "{input_fn}" -r{sr} "{output_fn}"', shell=True)
+            y, _ = librosa.core.load(input_fn, sr=sr)
+            y, _ = librosa.effects.trim(y)
+            save_wav(y, output_fn, sr)
+            return output_fn, sr
+        else:
+            return input_fn, sr
+
+
+@register_wav_processors(name='trim_sil')
+class TrimSILProcessor(BaseWavProcessor):
+    @property
+    def name(self):
+        return 'TrimSIL'
+
+    def process(self, input_fn, sr, tmp_dir, processed_dir, item_name, preprocess_args):
+        output_fn = self.output_fn(input_fn)
+        y, _ = librosa.core.load(input_fn, sr=sr)
+        y, _ = librosa.effects.trim(y)
+        save_wav(y, output_fn, sr)
+        return output_fn, sr
+
+
+@register_wav_processors(name='trim_all_sil')
+class TrimAllSILProcessor(BaseWavProcessor):
+    @property
+    def name(self):
+        return 'TrimAllSIL'
+
+    def process(self, input_fn, sr, tmp_dir, processed_dir, item_name, preprocess_args):
+        output_fn = self.output_fn(input_fn)
+        y, audio_mask, _ = trim_long_silences(
+            input_fn, vad_max_silence_length=preprocess_args.get('vad_max_silence_length', 12))
+        save_wav(y, output_fn, sr)
+        if preprocess_args['save_sil_mask']:
+            os.makedirs(f'{processed_dir}/sil_mask', exist_ok=True)
+            np.save(f'{processed_dir}/sil_mask/{item_name}.npy', audio_mask)
+        return output_fn, sr
+
+
+@register_wav_processors(name='denoise')
+class DenoiseProcessor(BaseWavProcessor):
+    @property
+    def name(self):
+        return 'Denoise'
+
+    def process(self, input_fn, sr, tmp_dir, processed_dir, item_name, preprocess_args):
+        output_fn = self.output_fn(input_fn)
+        rnnoise(input_fn, output_fn, out_sample_rate=sr)
+        return output_fn, sr
diff --git a/egs/datasets/audio/libritts/base_text2mel.yaml b/egs/datasets/audio/libritts/base_text2mel.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..85c389a45a7bc2be4927867d07fa881754814a2c
--- /dev/null
+++ b/egs/datasets/audio/libritts/base_text2mel.yaml
@@ -0,0 +1,14 @@
+raw_data_dir: 'data/raw/LibriTTS'
+processed_data_dir: 'data/processed/libritts'
+binary_data_dir: 'data/binary/libritts'
+pre_align_cls: egs.datasets.audio.libritts.pre_align.LibrittsPreAlign
+binarization_args:
+  shuffle: true
+use_spk_id: true
+test_num: 200
+num_spk: 2320
+pitch_type: frame
+min_frames: 128
+num_test_samples: 30
+mel_loss: "ssim:0.5|l1:0.5"
+vocoder_ckpt: ''
\ No newline at end of file
diff --git a/egs/datasets/audio/libritts/fs2.yaml b/egs/datasets/audio/libritts/fs2.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..95ae09730837aa10606f9771f41f40429b99d9ac
--- /dev/null
+++ b/egs/datasets/audio/libritts/fs2.yaml
@@ -0,0 +1,3 @@
+base_config:
+  - egs/egs_bases/tts/fs2.yaml
+  - ./base_text2mel.yaml
diff --git a/egs/datasets/audio/libritts/pre_align.py b/egs/datasets/audio/libritts/pre_align.py
new file mode 100755
index 0000000000000000000000000000000000000000..335b43d913edb02fb02e10c2479aa3dd9e07bb2f
--- /dev/null
+++ b/egs/datasets/audio/libritts/pre_align.py
@@ -0,0 +1,18 @@
+import os
+
+from data_gen.tts.base_pre_align import BasePreAlign
+import glob
+
+
+class LibrittsPreAlign(BasePreAlign):
+    def meta_data(self):
+        wav_fns = sorted(glob.glob(f'{self.raw_data_dir}/*/*/*/*.wav'))
+        for wav_fn in wav_fns:
+            item_name = os.path.basename(wav_fn)[:-4]
+            txt_fn = f'{wav_fn[:-4]}.normalized.txt'
+            spk = item_name.split("_")[0]
+            yield item_name, wav_fn, (self.load_txt, txt_fn), spk
+
+
+if __name__ == "__main__":
+    LibrittsPreAlign().process()
diff --git a/egs/datasets/audio/libritts/pwg.yaml b/egs/datasets/audio/libritts/pwg.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..a0fd70869274c3f8a7ac02a9ca5d7e1202dc895d
--- /dev/null
+++ b/egs/datasets/audio/libritts/pwg.yaml
@@ -0,0 +1,8 @@
+base_config: egs/egs_bases/tts/vocoder/pwg.yaml
+raw_data_dir: 'data/raw/LibriTTS'
+processed_data_dir: 'data/processed/libritts'
+binary_data_dir: 'data/binary/libritts_wav'
+generator_params:
+  kernel_size: 5
+num_spk: 400
+max_samples: 20480
diff --git a/egs/datasets/audio/lj/base_mel2wav.yaml b/egs/datasets/audio/lj/base_mel2wav.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..df4355bc38d0568c0b3acfa4f7fc040cef5995d6
--- /dev/null
+++ b/egs/datasets/audio/lj/base_mel2wav.yaml
@@ -0,0 +1,5 @@
+raw_data_dir: 'data/raw/LJSpeech-1.1'
+processed_data_dir: 'data/processed/ljspeech'
+binary_data_dir: 'data/binary/ljspeech_wav'
+binarization_args:
+  with_spk_embed: false
\ No newline at end of file
diff --git a/egs/datasets/audio/lj/pre_align.py b/egs/datasets/audio/lj/pre_align.py
new file mode 100755
index 0000000000000000000000000000000000000000..847b9f87b4e74cd634dd5bb2313f78afd5602ad7
--- /dev/null
+++ b/egs/datasets/audio/lj/pre_align.py
@@ -0,0 +1,13 @@
+from data_gen.tts.base_preprocess import BasePreprocessor
+
+
+class LJPreAlign(BasePreprocessor):
+    def meta_data(self):
+        for l in open(f'{self.raw_data_dir}/metadata.csv').readlines():
+            item_name, _, txt = l.strip().split("|")
+            wav_fn = f"{self.raw_data_dir}/wavs/{item_name}.wav"
+            yield item_name, wav_fn, txt, 'SPK1'
+
+
+if __name__ == "__main__":
+    LJPreAlign().process()
diff --git a/egs/datasets/audio/lj/pwg.yaml b/egs/datasets/audio/lj/pwg.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..e0c6dc6da4367bad9bafd1a7ba492a5cfb18a347
--- /dev/null
+++ b/egs/datasets/audio/lj/pwg.yaml
@@ -0,0 +1,3 @@
+base_config:
+  - egs/egs_bases/tts/vocoder/pwg.yaml
+  - ./base_mel2wav.yaml
\ No newline at end of file
diff --git a/egs/datasets/audio/vctk/base_mel2wav.yaml b/egs/datasets/audio/vctk/base_mel2wav.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..b5210a1361259d30449430f70849061d17a8e59f
--- /dev/null
+++ b/egs/datasets/audio/vctk/base_mel2wav.yaml
@@ -0,0 +1,3 @@
+raw_data_dir: 'data/raw/VCTK-Corpus'
+processed_data_dir: 'data/processed/vctk'
+binary_data_dir: 'data/binary/vctk_wav'
diff --git a/egs/datasets/audio/vctk/fs2.yaml b/egs/datasets/audio/vctk/fs2.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..49bb983bd98caca74af0ea70414e71aa34334b38
--- /dev/null
+++ b/egs/datasets/audio/vctk/fs2.yaml
@@ -0,0 +1,12 @@
+base_config:
+  - egs/egs_bases/tts/fs2.yaml
+raw_data_dir: 'data/raw/VCTK-Corpus'
+processed_data_dir: 'data/processed/vctk'
+binary_data_dir: 'data/binary/vctk'
+pre_align_cls: egs.datasets.audio.vctk.pre_align.VCTKPreAlign
+use_spk_id: true
+test_num: 200
+num_spk: 400
+binarization_args:
+  shuffle: true
+  trim_eos_bos: true
\ No newline at end of file
diff --git a/egs/datasets/audio/vctk/pre_align.py b/egs/datasets/audio/vctk/pre_align.py
new file mode 100755
index 0000000000000000000000000000000000000000..a03b3e12af245fa603403432f4487c53e8b13eab
--- /dev/null
+++ b/egs/datasets/audio/vctk/pre_align.py
@@ -0,0 +1,22 @@
+import os
+
+from data_gen.tts.base_pre_align import BasePreAlign
+import glob
+
+
+class VCTKPreAlign(BasePreAlign):
+    def meta_data(self):
+        wav_fns = glob.glob(f'{self.raw_data_dir}/wav48/*/*.wav')
+        for wav_fn in wav_fns:
+            item_name = os.path.basename(wav_fn)[:-4]
+            spk = item_name.split("_")[0]
+            txt_fn = wav_fn.split("/")
+            txt_fn[-1] = f'{item_name}.txt'
+            txt_fn[-3] = f'txt'
+            txt_fn = "/".join(txt_fn)
+            if os.path.exists(txt_fn) and os.path.exists(wav_fn):
+                yield item_name, wav_fn, (self.load_txt, txt_fn), spk
+
+
+if __name__ == "__main__":
+    VCTKPreAlign().process()
diff --git a/egs/datasets/audio/vctk/pwg.yaml b/egs/datasets/audio/vctk/pwg.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..8c7c557fee02b0a62881f36a2abe010ede0d5f39
--- /dev/null
+++ b/egs/datasets/audio/vctk/pwg.yaml
@@ -0,0 +1,6 @@
+base_config:
+  - egs/egs_bases/tts/vocoder/pwg.yaml
+  - ./base_mel2wav.yaml
+
+num_spk: 400
+max_samples: 20480
diff --git a/egs/egs_bases/config_base.yaml b/egs/egs_bases/config_base.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..39240ceb8cd52e4c0a146aa579f1ea046c727c84
--- /dev/null
+++ b/egs/egs_bases/config_base.yaml
@@ -0,0 +1,46 @@
+# task
+binary_data_dir: ''
+work_dir: '' # experiment directory.
+infer: false # inference
+amp: false
+seed: 1234
+debug: false
+save_codes: []
+#  - configs
+#  - modules
+#  - tasks
+#  - utils
+#  - usr
+
+#############
+# dataset
+#############
+ds_workers: 1
+test_num: 100
+endless_ds: false
+sort_by_len: true
+
+#########
+# train and eval
+#########
+print_nan_grads: false
+load_ckpt: ''
+save_best: true
+num_ckpt_keep: 3
+clip_grad_norm: 0
+accumulate_grad_batches: 1
+tb_log_interval: 100
+num_sanity_val_steps: 5  # steps of validation at the beginning
+check_val_every_n_epoch: 10
+val_check_interval: 2000
+valid_monitor_key: 'val_loss'
+valid_monitor_mode: 'min'
+max_epochs: 1000
+max_updates: 1000000
+max_tokens: 31250
+max_sentences: 100000
+max_valid_tokens: -1
+max_valid_sentences: -1
+test_input_dir: ''
+resume_from_checkpoint: 0
+rename_tmux: true
\ No newline at end of file
diff --git a/egs/egs_bases/tts/base.yaml b/egs/egs_bases/tts/base.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..255e4cce1bb52d58d24443bc277621d70732583f
--- /dev/null
+++ b/egs/egs_bases/tts/base.yaml
@@ -0,0 +1,112 @@
+# task
+base_config: ../config_base.yaml
+task_cls: ''
+#############
+# dataset
+#############
+raw_data_dir: ''
+processed_data_dir: ''
+binary_data_dir: ''
+dict_dir: ''
+pre_align_cls: ''
+binarizer_cls: data_gen.tts.base_binarizer.BaseBinarizer
+pre_align_args:
+  txt_processor: en
+  use_tone: true # for ZH
+  sox_resample: false
+  sox_to_wav: false
+  allow_no_txt: false
+  trim_sil: false
+  denoise: false
+binarization_args:
+  shuffle: false
+  with_txt: true
+  with_wav: false
+  with_align: true
+  with_spk_embed: false
+  with_spk_id: true
+  with_f0: true
+  with_f0cwt: false
+  with_linear: false
+  with_word: true
+  trim_sil: false
+  trim_eos_bos: false
+  reset_phone_dict: true
+  reset_word_dict: true
+word_size: 30000
+pitch_extractor: parselmouth
+
+loud_norm: false
+endless_ds: true
+
+test_num: 100
+min_frames: 0
+max_frames: 1548
+frames_multiple: 1
+max_input_tokens: 1550
+audio_num_mel_bins: 80
+audio_sample_rate: 22050
+hop_size: 256  # For 22050Hz, 256 ~= 11.6 ms (frame shift)
+win_size: 1024  # For 22050Hz, 1024 ~= 46.4 ms (If None, win_size: fft_size)
+fmin: 80  # Set to 55 for male speakers; ~95 can help remove noise for female speakers. (Tune per dataset. Pitch range: male ~[65, 260] Hz, female ~[100, 525] Hz)
+fmax: 7600  # To be increased/reduced depending on data.
+fft_size: 1024  # Extra window size is filled with 0 paddings to match this parameter
+min_level_db: -100
+ref_level_db: 20
+griffin_lim_iters: 60
+num_spk: 1
+mel_vmin: -6
+mel_vmax: 1.5
+ds_workers: 1
+
+#########
+# model
+#########
+dropout: 0.1
+enc_layers: 4
+dec_layers: 4
+hidden_size: 256
+num_heads: 2
+enc_ffn_kernel_size: 9
+dec_ffn_kernel_size: 9
+ffn_act: gelu
+ffn_padding: 'SAME'
+use_spk_id: false
+use_split_spk_id: false
+use_spk_embed: false
+
+
+###########
+# optimization
+###########
+lr: 2.0
+scheduler: rsqrt # rsqrt|none
+warmup_updates: 8000
+optimizer_adam_beta1: 0.9
+optimizer_adam_beta2: 0.98
+weight_decay: 0
+clip_grad_norm: 1
+clip_grad_value: 0
+
+
+###########
+# train and eval
+###########
+max_tokens: 30000
+max_sentences: 100000
+max_valid_sentences: 1
+max_valid_tokens: 60000
+valid_infer_interval: 10000
+train_set_name: 'train'
+train_sets: ''
+valid_set_name: 'valid'
+test_set_name: 'test'
+num_test_samples: 0
+num_valid_plots: 10
+test_ids: [ ]
+vocoder_denoise_c: 0.0
+profile_infer: false
+out_wav_norm: false
+save_gt: true
+save_f0: false
+gen_dir_name: ''
\ No newline at end of file
diff --git a/egs/egs_bases/tts/fs2.yaml b/egs/egs_bases/tts/fs2.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..c200a50ad2e04ff685bd30c7d5c3be69ff7cdf3f
--- /dev/null
+++ b/egs/egs_bases/tts/fs2.yaml
@@ -0,0 +1,102 @@
+base_config: ./base.yaml
+task_cls: tasks.tts.fs2.FastSpeech2Task
+
+# model
+hidden_size: 256
+dropout: 0.1
+encoder_type: fft # rel_fft|fft|tacotron|tacotron2|conformer
+decoder_type: fft # fft|rnn|conv|conformer|wn
+
+# rnn enc/dec
+encoder_K: 8
+decoder_rnn_dim: 0 # for rnn decoder, 0 -> hidden_size * 2
+
+# fft enc/dec
+use_pos_embed: true
+dec_num_heads: 2
+dec_layers: 4
+ffn_hidden_size: 1024
+enc_ffn_kernel_size: 9
+dec_ffn_kernel_size: 9
+
+# conv enc/dec
+enc_dec_norm: ln
+conv_use_pos: false
+layers_in_block: 2
+enc_dilations: [ 1, 1, 1, 1 ]
+enc_kernel_size: 5
+dec_dilations: [ 1, 1, 1, 1 ] # for conv decoder
+dec_kernel_size: 5
+dur_loss: mse # huber|mol
+
+# duration
+predictor_hidden: -1
+predictor_kernel: 5
+predictor_layers: 2
+dur_predictor_kernel: 3
+dur_predictor_layers: 2
+predictor_dropout: 0.5
+
+# pitch and energy
+pitch_norm: standard # standard|log
+use_pitch_embed: true
+pitch_type: frame # frame|ph|cwt
+use_uv: true
+cwt_hidden_size: 128
+cwt_layers: 2
+cwt_loss: l1
+cwt_add_f0_loss: false
+cwt_std_scale: 0.8
+
+pitch_ar: false
+pitch_embed_type: 0
+pitch_loss: 'l1' # l1|l2|ssim
+pitch_ssim_win: 11
+use_energy_embed: false
+
+# reference encoder and speaker embedding
+use_ref_enc: false
+use_var_enc: false
+lambda_commit: 0.25
+var_enc_vq_codes: 64
+ref_norm_layer: bn
+dec_inp_add_noise: false
+sil_add_noise: false
+ref_hidden_stride_kernel:
+  - 0,3,5 # conv_hidden_size, conv_stride, conv_kernel_size. conv_hidden_size=0: use hidden_size
+  - 0,3,5
+  - 0,2,5
+  - 0,2,5
+  - 0,2,5
+pitch_enc_hidden_stride_kernel:
+  - 0,2,5 # conv_hidden_size, conv_stride, conv_kernel_size. conv_hidden_size=0: use hidden_size
+  - 0,2,5
+  - 0,2,5
+dur_enc_hidden_stride_kernel:
+  - 0,2,3 # conv_hidden_size, conv_stride, conv_kernel_size. conv_hidden_size=0: use hidden_size
+  - 0,2,3
+  - 0,1,3
+
+# mel
+mel_loss: l1:0.5|ssim:0.5 # l1|l2|gdl|ssim or l1:0.5|ssim:0.5
+
+# loss lambda
+lambda_f0: 1.0
+lambda_uv: 1.0
+lambda_energy: 0.1
+lambda_ph_dur: 0.1
+lambda_sent_dur: 1.0
+lambda_word_dur: 1.0
+predictor_grad: 0.1
+
+# train and eval
+pretrain_fs_ckpt: ''
+warmup_updates: 2000
+max_tokens: 32000
+max_sentences: 100000
+max_valid_sentences: 1
+max_updates: 120000
+use_gt_dur: false
+use_gt_f0: false
+ds_workers: 2
+lr: 1.0
diff --git a/egs/egs_bases/tts/vocoder/base.yaml b/egs/egs_bases/tts/vocoder/base.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..92a1a74e269f5245af09a1bc00a5998f069b6801
--- /dev/null
+++ b/egs/egs_bases/tts/vocoder/base.yaml
@@ -0,0 +1,34 @@
+base_config: ../base.yaml
+binarization_args:
+  with_wav: true
+  with_spk_embed: false
+  with_align: false
+  with_word: false
+  with_txt: false
+
+###########
+# train and eval
+###########
+max_samples: 25600
+max_sentences: 5
+max_valid_sentences: 1
+max_updates: 1000000
+val_check_interval: 2000
+
+###########################################################
+#                FEATURE EXTRACTION SETTING               #
+###########################################################
+fft_size: 1024           # FFT size.
+hop_size: 256            # Hop size.
+win_length: null         # Window length.
+# If set to null, it will be the same as fft_size.
+window: "hann"           # Window function.
+num_mels: 80             # Number of mel basis.
+fmin: 80                 # Minimum freq in mel basis calculation.
+fmax: 7600               # Maximum frequency in mel basis calculation.
+aux_context_window: 0 # Context window size for auxiliary feature.
+use_pitch_embed: false
+
+generator_grad_norm: 10    # Generator's gradient norm.
+discriminator_grad_norm: 1 # Discriminator's gradient norm.
+disc_start_steps: 40000 # Number of steps to start to train discriminator.
diff --git a/egs/egs_bases/tts/vocoder/pwg.yaml b/egs/egs_bases/tts/vocoder/pwg.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..2d95bbd92abcdf70fb7a38b15877509390696fd2
--- /dev/null
+++ b/egs/egs_bases/tts/vocoder/pwg.yaml
@@ -0,0 +1,82 @@
+base_config: ./base.yaml
+task_cls: tasks.vocoder.pwg.PwgTask
+
+aux_context_window: 2 # Context window size for auxiliary feature. If set to 2, previous 2 and future 2 frames will be considered.
+use_pitch_embed: false
+###########################################################
+#         GENERATOR NETWORK ARCHITECTURE SETTING          #
+###########################################################
+generator_params:
+  in_channels: 1        # Number of input channels.
+  out_channels: 1       # Number of output channels.
+  kernel_size: 3        # Kernel size of dilated convolution.
+  layers: 30            # Number of residual block layers.
+  stacks: 3             # Number of stacks i.e., dilation cycles.
+  residual_channels: 64 # Number of channels in residual conv.
+  gate_channels: 128    # Number of channels in gated conv.
+  skip_channels: 64     # Number of channels in skip conv.
+  aux_channels: 80      # Number of channels for auxiliary feature conv.
+  # Must be the same as num_mels.
+  dropout: 0.0          # Dropout rate. 0.0 means no dropout applied.
+  use_weight_norm: true # Whether to use weight norm.
+  # If set to true, it will be applied to all of the conv layers.
+  upsample_net: "ConvInUpsampleNetwork" # Upsampling network architecture.
+  upsample_params:                      # Upsampling network parameters.
+    upsample_scales: [4, 4, 4, 4]     # Upsampling scales. Product of these must equal the hop size.
+  use_pitch_embed: false
+  use_nsf: false
+###########################################################
+#       DISCRIMINATOR NETWORK ARCHITECTURE SETTING        #
+###########################################################
+discriminator_params:
+  in_channels: 1        # Number of input channels.
+  out_channels: 1       # Number of output channels.
+  kernel_size: 3        # Kernel size of conv layers.
+  layers: 10            # Number of conv layers.
+  conv_channels: 64     # Number of channels in conv layers.
+  bias: true            # Whether to use bias parameter in conv.
+  use_weight_norm: true # Whether to use weight norm.
+  # If set to true, it will be applied to all of the conv layers.
+  nonlinear_activation: "LeakyReLU" # Nonlinear function after each conv.
+  nonlinear_activation_params:      # Nonlinear function parameters
+    negative_slope: 0.2           # Alpha in LeakyReLU.
+rerun_gen: true
+
+###########################################################
+#                   STFT LOSS SETTING                     #
+###########################################################
+stft_loss_params:
+  fft_sizes: [1024, 2048, 512]  # List of FFT size for STFT-based loss.
+  hop_sizes: [120, 240, 50]     # List of hop size for STFT-based loss
+  win_lengths: [600, 1200, 240] # List of window length for STFT-based loss.
+  window: "hann_window"         # Window function for STFT-based loss
+use_mel_loss: false
+
+###########################################################
+#               ADVERSARIAL LOSS SETTING                  #
+###########################################################
+lambda_adv: 4.0  # Loss balancing coefficient.
+
+###########################################################
+#             OPTIMIZER & SCHEDULER SETTING               #
+###########################################################
+generator_optimizer_params:
+  lr: 0.0001             # Generator's learning rate.
+  eps: 1.0e-6            # Generator's epsilon.
+  weight_decay: 0.0      # Generator's weight decay coefficient.
+generator_scheduler_params:
+  step_size: 200000      # Generator's scheduler step size.
+  gamma: 0.5             # Generator's scheduler gamma.
+  # At each step size, lr will be multiplied by this parameter.
+generator_grad_norm: 10    # Generator's gradient norm.
+discriminator_optimizer_params:
+  lr: 0.00005            # Discriminator's learning rate.
+  eps: 1.0e-6            # Discriminator's epsilon.
+  weight_decay: 0.0      # Discriminator's weight decay coefficient.
+discriminator_scheduler_params:
+  step_size: 200000      # Discriminator's scheduler step size.
+  gamma: 0.5             # Discriminator's scheduler gamma.
+  # At each step size, lr will be multiplied by this parameter.
+discriminator_grad_norm: 1 # Discriminator's gradient norm.
+disc_start_steps: 40000 # Number of steps to start to train discriminator.
diff --git a/inference/ProDiff.py b/inference/ProDiff.py
new file mode 100644
index 0000000000000000000000000000000000000000..945497f353e50c7b9314476a615f4b1b94a2a940
--- /dev/null
+++ b/inference/ProDiff.py
@@ -0,0 +1,49 @@
+import torch
+from inference.base_tts_infer import BaseTTSInfer
+from utils.ckpt_utils import load_ckpt, get_last_checkpoint
+from utils.hparams import hparams
+from modules.ProDiff.model.ProDiff import GaussianDiffusion
+from usr.diff.net import DiffNet
+import os
+import numpy as np
+from functools import partial
+
+class ProDiffInfer(BaseTTSInfer):
+    def build_model(self):
+        f0_stats_fn = f'{hparams["binary_data_dir"]}/train_f0s_mean_std.npy'
+        if os.path.exists(f0_stats_fn):
+            hparams['f0_mean'], hparams['f0_std'] = np.load(f0_stats_fn)
+            hparams['f0_mean'] = float(hparams['f0_mean'])
+            hparams['f0_std'] = float(hparams['f0_std'])
+        model = GaussianDiffusion(
+            phone_encoder=self.ph_encoder,
+            out_dims=80, denoise_fn=DiffNet(hparams['audio_num_mel_bins']),
+            timesteps=hparams['timesteps'],
+            loss_type=hparams['diff_loss_type'],
+            spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
+        )
+        checkpoint = torch.load(hparams['teacher_ckpt'], map_location='cpu')["state_dict"]['model']
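+        # Progressive distillation: the student halves the teacher's diffusion steps
+        # and doubles its timescale, both read from the teacher checkpoint buffers.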
+        teacher_timesteps = int(checkpoint['timesteps'].item())
+        teacher_timescales = int(checkpoint['timescale'].item())
+        student_timesteps = teacher_timesteps // 2
+        student_timescales = teacher_timescales * 2
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+        model.register_buffer('timesteps', to_torch(student_timesteps))    # student diffusion steps (teacher / 2)
+        model.register_buffer('timescale', to_torch(student_timescales))   # student timescale (teacher * 2)
+        model.eval()
+        load_ckpt(model, hparams['work_dir'], 'model')
+        return model
+
+    def forward_model(self, inp):
+        sample = self.input_to_batch(inp)
+        txt_tokens = sample['txt_tokens']  # [B, T_t]
+        with torch.no_grad():
+            output = self.model(txt_tokens, infer=True)
+            mel_out = output['mel_out']
+            wav_out = self.run_vocoder(mel_out)
+        wav_out = wav_out.squeeze().cpu().numpy()
+        return wav_out
+
+
+if __name__ == '__main__':
+    ProDiffInfer.example_run()
diff --git a/inference/ProDiff_Teacher.py b/inference/ProDiff_Teacher.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e10278e1864d0709c8667a7dd2106dfb4b1cd38
--- /dev/null
+++ b/inference/ProDiff_Teacher.py
@@ -0,0 +1,41 @@
+import torch
+from inference.base_tts_infer import BaseTTSInfer
+from utils.ckpt_utils import load_ckpt, get_last_checkpoint
+from utils.hparams import hparams
+from modules.ProDiff.model.ProDiff_teacher import GaussianDiffusion
+from usr.diff.net import DiffNet
+import os
+import numpy as np
+
+class ProDiffTeacherInfer(BaseTTSInfer):
+    def build_model(self):
+        f0_stats_fn = f'{hparams["binary_data_dir"]}/train_f0s_mean_std.npy'
+        if os.path.exists(f0_stats_fn):
+            hparams['f0_mean'], hparams['f0_std'] = np.load(f0_stats_fn)
+            hparams['f0_mean'] = float(hparams['f0_mean'])
+            hparams['f0_std'] = float(hparams['f0_std'])
+        model = GaussianDiffusion(
+            phone_encoder=self.ph_encoder,
+            out_dims=80, denoise_fn=DiffNet(hparams['audio_num_mel_bins']),
+            timesteps=hparams['timesteps'],
+            loss_type=hparams['diff_loss_type'],
+            spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
+        )
+
+        model.eval()
+        load_ckpt(model, hparams['work_dir'], 'model')
+        return model
+
+    def forward_model(self, inp):
+        sample = self.input_to_batch(inp)
+        txt_tokens = sample['txt_tokens']  # [B, T_t]
+        with torch.no_grad():
+            output = self.model(txt_tokens, infer=True)
+            mel_out = output['mel_out']
+            wav_out = self.run_vocoder(mel_out)
+        wav_out = wav_out.squeeze().cpu().numpy()
+        return wav_out
+
+
+if __name__ == '__main__':
+    ProDiffTeacherInfer.example_run()
diff --git a/inference/base_tts_infer.py b/inference/base_tts_infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f34f207ace872dc6f075cf645a5692c536c640b6
--- /dev/null
+++ b/inference/base_tts_infer.py
@@ -0,0 +1,167 @@
+import os
+
+import torch
+
+from tasks.tts.dataset_utils import FastSpeechWordDataset
+from tasks.tts.tts_utils import load_data_preprocessor
+import numpy as np
+from modules.FastDiff.module.util import compute_hyperparams_given_schedule, sampling_given_noise_schedule
+
+from modules.FastDiff.module.FastDiff_model import FastDiff
+from utils.ckpt_utils import load_ckpt
+from utils.hparams import set_hparams
+
+
+class BaseTTSInfer:
+    def __init__(self, hparams, device=None):
+        if device is None:
+            device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.hparams = hparams
+        self.device = device
+        self.data_dir = hparams['binary_data_dir']
+        self.preprocessor, self.preprocess_args = load_data_preprocessor()
+        self.ph_encoder = self.preprocessor.load_dict(self.data_dir)
+        self.spk_map = self.preprocessor.load_spk_map(self.data_dir)
+        self.ds_cls = FastSpeechWordDataset
+        self.model = self.build_model()
+        self.model.eval()
+        self.model.to(self.device)
+        self.vocoder, self.diffusion_hyperparams, self.noise_schedule = self.build_vocoder()
+        self.vocoder.eval()
+        self.vocoder.to(self.device)
+
+    def build_model(self):
+        raise NotImplementedError
+
+    def forward_model(self, inp):
+        raise NotImplementedError
+
+    def build_vocoder(self):
+        base_dir = self.hparams['vocoder_ckpt']
+        config_path = f'{base_dir}/config.yaml'
+        config = set_hparams(config_path, global_hparams=False)
+        vocoder = FastDiff(audio_channels=config['audio_channels'],
+                 inner_channels=config['inner_channels'],
+                 cond_channels=config['cond_channels'],
+                 upsample_ratios=config['upsample_ratios'],
+                 lvc_layers_each_block=config['lvc_layers_each_block'],
+                 lvc_kernel_size=config['lvc_kernel_size'],
+                 kpnet_hidden_channels=config['kpnet_hidden_channels'],
+                 kpnet_conv_size=config['kpnet_conv_size'],
+                 dropout=config['dropout'],
+                 diffusion_step_embed_dim_in=config['diffusion_step_embed_dim_in'],
+                 diffusion_step_embed_dim_mid=config['diffusion_step_embed_dim_mid'],
+                 diffusion_step_embed_dim_out=config['diffusion_step_embed_dim_out'],
+                 use_weight_norm=config['use_weight_norm'])
+        load_ckpt(vocoder, base_dir, 'model')
+
+        # Init hyperparameters by linear schedule
+        noise_schedule = torch.linspace(float(config["beta_0"]), float(config["beta_T"]), int(config["T"]))
+        diffusion_hyperparams = compute_hyperparams_given_schedule(noise_schedule)
+
+        if config['noise_schedule'] != '':
+            noise_schedule = config['noise_schedule']
+            if isinstance(noise_schedule, list):
+                noise_schedule = torch.FloatTensor(noise_schedule)
+        else:
+            # Select Schedule
+            try:
+                reverse_step = int(self.hparams.get('N'))
+            except (TypeError, ValueError):
+                print(
+                    'Please specify N (the number of reverse iterations) in the config file. Denoising with 4 iterations by default.')
+                reverse_step = 4
+            if reverse_step == 1000:
+                noise_schedule = torch.linspace(0.000001, 0.01, 1000)
+            elif reverse_step == 200:
+                noise_schedule = torch.linspace(0.0001, 0.02, 200)
+
+            # Below are schedules derived by the noise predictor.
+            # We will release the noise-predictor training and noise-scheduling code soon. Please stay tuned!
+            elif reverse_step == 8:
+                noise_schedule = [6.689325005027058e-07, 1.0033881153503899e-05, 0.00015496854030061513,
+                                  0.002387222135439515, 0.035597629845142365, 0.3681158423423767, 0.4735414385795593,
+                                  0.5]
+            elif reverse_step == 6:
+                noise_schedule = [1.7838445955931093e-06, 2.7984189728158526e-05, 0.00043231004383414984,
+                                  0.006634317338466644, 0.09357017278671265, 0.6000000238418579]
+            elif reverse_step == 4:
+                noise_schedule = [3.2176e-04, 2.5743e-03, 2.5376e-02, 7.0414e-01]
+            elif reverse_step == 3:
+                noise_schedule = [9.0000e-05, 9.0000e-03, 6.0000e-01]
+            else:
+                raise NotImplementedError
+
+        if isinstance(noise_schedule, list):
+            noise_schedule = torch.FloatTensor(noise_schedule)
+
+        return vocoder, diffusion_hyperparams, noise_schedule
+
+    def run_vocoder(self, c):
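+        # c: mel spectrogram [B, T, num_mels]; FastDiff expects [B, num_mels, T] and
+        # generates a waveform of length T * hop_size via the selected noise schedule.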
+        c = c.transpose(2, 1)
+        audio_length = c.shape[-1] * self.hparams["hop_size"]
+        y = sampling_given_noise_schedule(
+            self.vocoder, (1, 1, audio_length), self.diffusion_hyperparams, self.noise_schedule, condition=c, ddim=False, return_sequence=False)
+        return y
+
+    def preprocess_input(self, inp):
+        """
+        :param inp: {'text': str, 'item_name': (str, optional), 'spk_name': (str, optional)}
+        :return:
+        """
+        preprocessor, preprocess_args = self.preprocessor, self.preprocess_args
+        text_raw = inp['text']
+        item_name = inp.get('item_name', '<ITEM_NAME>')
+        spk_name = inp.get('spk_name', 'SPK1')
+        ph, txt = preprocessor.txt_to_ph(
+            preprocessor.txt_processor, text_raw, preprocess_args)
+        ph_token = self.ph_encoder.encode(ph)
+        spk_id = self.spk_map[spk_name]
+        item = {'item_name': item_name, 'text': txt, 'ph': ph, 'spk_id': spk_id, 'ph_token': ph_token}
+        item['ph_len'] = len(item['ph_token'])
+        return item
+
+    def input_to_batch(self, item):
+        item_names = [item['item_name']]
+        text = [item['text']]
+        ph = [item['ph']]
+        txt_tokens = torch.LongTensor(item['ph_token'])[None, :].to(self.device)
+        txt_lengths = torch.LongTensor([txt_tokens.shape[1]]).to(self.device)
+        spk_ids = torch.LongTensor([item['spk_id']])[None, :].to(self.device)
+        batch = {
+            'item_name': item_names,
+            'text': text,
+            'ph': ph,
+            'txt_tokens': txt_tokens,
+            'txt_lengths': txt_lengths,
+            'spk_ids': spk_ids,
+        }
+        return batch
+
+    def postprocess_output(self, output):
+        return output
+
+    def infer_once(self, inp):
+        inp = self.preprocess_input(inp)
+        output = self.forward_model(inp)
+        output = self.postprocess_output(output)
+        return output
+
+    @classmethod
+    def example_run(cls):
+        from utils.hparams import set_hparams
+        from utils.hparams import hparams as hp
+        from utils.audio import save_wav
+
+        set_hparams()
+        inp = {
+            'text': hp['text']
+        }
+        infer_ins = cls(hp)
+        out = infer_ins.infer_once(inp)
+        os.makedirs('infer_out', exist_ok=True)
+        save_wav(out, f'infer_out/{hp["text"]}.wav', hp['audio_sample_rate'])
diff --git a/inference/gradio/gradio_settings.yaml b/inference/gradio/gradio_settings.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6cefb0cd6edbd36e5c544532adf5de7048e5f047
--- /dev/null
+++ b/inference/gradio/gradio_settings.yaml
@@ -0,0 +1,41 @@
+title: 'Extremely-Fast diffusion text-to-speech synthesis pipeline with ProDiff and FastDiff'
+description: |
+  Gradio demo for **2-iter** ProDiff and **4-iter** FastDiff. To use it, simply enter your text, or click one of the examples to load them. **This space is running on CPU, inference will be slower.**
+
+  ## Key Features
+  - **Extremely-Fast** diffusion text-to-speech synthesis pipeline for potential **industrial deployment**.
+  - **Tutorial and code base** for speech diffusion models.
+  - More **supported diffusion mechanisms** (e.g., guided diffusion) will be available.
+
+
+article: |
+  ## Reference
+  Link to <a href='https://github.com/Rongjiehuang/ProDiff' style='color:blue;' target='_blank\'>ProDiff Github REPO</a>
+  
+  If you find this code useful in your research, please cite our work:
+  ```
+    @inproceedings{huang2022prodiff,
+      title={ProDiff: Progressive Fast Diffusion Model For High-Quality Text-to-Speech},
+      author={Huang, Rongjie and Zhao, Zhou and Liu, Huadai and Liu, Jinglin and Cui, Chenye and Ren, Yi},
+      booktitle={Proceedings of the 30th ACM International Conference on Multimedia},
+      year={2022}
+    }
+
+    @inproceedings{huang2022fastdiff,
+      title={FastDiff: A Fast Conditional Diffusion Model for High-Quality Speech Synthesis},
+      author={Huang, Rongjie and Lam, Max WY and Wang, Jun and Su, Dan and Yu, Dong and Ren, Yi and Zhao, Zhou},
+      booktitle = {Proceedings of the Thirty-First International Joint Conference on Artificial Intelligence, {IJCAI-22}},
+      year={2022}
+    }
+  ```
+  
+  ## Disclaimer
+  Any organization or individual is prohibited from using any technology mentioned in this paper to generate someone's speech without their consent, including but not limited to government leaders, political figures, and celebrities. If you do not comply with this requirement, you could be in violation of copyright law.
+
+example_inputs:
+  - |-
+    the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing.
+  - |-
+    Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition.
+inference_cls: inference.ProDiff.ProDiffInfer
+exp_name: ProDiff
+config: modules/ProDiff/config/prodiff.yaml
\ No newline at end of file
diff --git a/inference/gradio/infer.py b/inference/gradio/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..27acc399c78e024672013cd03048448c22e59df4
--- /dev/null
+++ b/inference/gradio/infer.py
@@ -0,0 +1,69 @@
+import importlib
+import re
+
+import gradio as gr
+import yaml
+from gradio.inputs import Textbox
+
+from inference.base_tts_infer import BaseTTSInfer
+from utils.hparams import set_hparams
+from utils.hparams import hparams as hp
+import numpy as np
+
+from data_gen.tts.data_gen_utils import is_sil_phoneme, PUNCS
+
+class GradioInfer:
+    def __init__(self, exp_name, config, inference_cls, title, description, article, example_inputs):
+        self.exp_name = exp_name
+        self.config = config
+        self.title = title
+        self.description = description
+        self.article = article
+        self.example_inputs = example_inputs
+        pkg = ".".join(inference_cls.split(".")[:-1])
+        cls_name = inference_cls.split(".")[-1]
+        self.inference_cls = getattr(importlib.import_module(pkg), cls_name)
+
+    def greet(self, text):
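+        # Split the input at punctuation, accumulate sentences into chunks of roughly
+        # 400 characters, synthesise each chunk, and join them with 0.3 s of silence;
+        # returns (sample_rate, int16 waveform).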
+        sents = re.split(rf'([{PUNCS}])', text.replace('\n', ','))
+        if sents[-1] not in list(PUNCS):
+            sents = sents + ['.']
+        audio_outs = []
+        s = ""
+        for i in range(0, len(sents), 2):
+            if len(sents[i]) > 0:
+                s += sents[i] + sents[i + 1]
+            if len(s) >= 400 or (i >= len(sents) - 2 and len(s) > 0):
+                audio_out = self.infer_ins.infer_once({
+                    'text': s
+                })
+                audio_out = audio_out * 32767
+                audio_out = audio_out.astype(np.int16)
+                audio_outs.append(audio_out)
+                audio_outs.append(np.zeros(int(hp['audio_sample_rate'] * 0.3)).astype(np.int16))
+                s = ""
+        audio_outs = np.concatenate(audio_outs)
+        return hp['audio_sample_rate'], audio_outs
+
+    def run(self):
+        set_hparams(exp_name=self.exp_name, config=self.config)
+        infer_cls = self.inference_cls
+        self.infer_ins: BaseTTSInfer = infer_cls(hp)
+        example_inputs = self.example_inputs
+        iface = gr.Interface(fn=self.greet,
+                             inputs=Textbox(
+                                 lines=10, placeholder=None, default=example_inputs[0], label="input text"),
+                             outputs="audio",
+                             allow_flagging="never",
+                             title=self.title,
+                             description=self.description,
+                             article=self.article,
+                             examples=example_inputs,
+                             enable_queue=True)
+        iface.launch(share=True, cache_examples=True)
+
+
+if __name__ == '__main__':
+    gradio_config = yaml.safe_load(open('inference/gradio/gradio_settings.yaml'))
+    g = GradioInfer(**gradio_config)
+    g.run()
diff --git a/modules/FastDiff/config/FastDiff.yaml b/modules/FastDiff/config/FastDiff.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..af6f1d1511deea17dd0d5fc08bf0270682a2623e
--- /dev/null
+++ b/modules/FastDiff/config/FastDiff.yaml
@@ -0,0 +1,7 @@
+base_config:
+  - ./base.yaml
+
+audio_sample_rate: 22050
+raw_data_dir: 'data/raw/LJSpeech-1.1'
+processed_data_dir: 'data/processed/LJSpeech'
+binary_data_dir: 'data/binary/LJSpeech'
diff --git a/modules/FastDiff/config/FastDiff_libritts.yaml b/modules/FastDiff/config/FastDiff_libritts.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..372f3e07e7ecbe546ee91a6ac32f5bd370de81d9
--- /dev/null
+++ b/modules/FastDiff/config/FastDiff_libritts.yaml
@@ -0,0 +1,7 @@
+base_config:
+  - ./base.yaml
+
+audio_sample_rate: 22050
+raw_data_dir: 'data/raw/LibriTTS'
+processed_data_dir: 'data/processed/LibriTTS'
+binary_data_dir: 'data/binary/LibriTTS'
\ No newline at end of file
diff --git a/modules/FastDiff/config/FastDiff_sc09.yaml b/modules/FastDiff/config/FastDiff_sc09.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..23322fd8b6d50595ff1e2ae4b3d080ed707ade74
--- /dev/null
+++ b/modules/FastDiff/config/FastDiff_sc09.yaml
@@ -0,0 +1,25 @@
+base_config:
+  - egs/egs_bases/tts/vocoder/base.yaml
+  - egs/datasets/audio/lj/base_mel2wav.yaml
+  - ./base.yaml
+
+#raw_data_dir: '/home1/huangrongjie/dataset/sc09/data/'
+#processed_data_dir: 'data/processed/SC09'
+#binary_data_dir: 'data/binary/SC09'
+
+raw_data_dir: '/home1/huangrongjie/Project/AdaGrad/data/raw/SC09/'
+processed_data_dir: 'data/processed/SC09_ten_processed'
+binary_data_dir: 'data/binary/SC09_ten_processed'
+
+pre_align_cls: egs.datasets.audio.sc09.pre_align.Sc09PreAlign
+audio_sample_rate: 16000
+max_samples: 12800
+
+pre_align_args:
+  sox_resample: false
+  sox_to_wav: false
+  allow_no_txt: true
+  trim_sil: true
+  denoise: true
+
+loud_norm: true
\ No newline at end of file
diff --git a/modules/FastDiff/config/FastDiff_tacotron.yaml b/modules/FastDiff/config/FastDiff_tacotron.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1c180b02789ac6a38b22b88496f000dcc259b330
--- /dev/null
+++ b/modules/FastDiff/config/FastDiff_tacotron.yaml
@@ -0,0 +1,58 @@
+base_config:
+  - egs/egs_bases/tts/vocoder/pwg.yaml
+  - egs/egs_bases/tts/base_mel2wav.yaml
+  - egs/datasets/audio/lj/pwg.yaml
+
+raw_data_dir: 'data/raw/LJSpeech-1.1'
+processed_data_dir: 'data/processed/LJSpeech_FastDiff'
+#binary_data_dir: 'data/binary/LJSpeech_Taco'
+binary_data_dir: /apdcephfs/private_nlphuang/preprocess/AdaGrad/data/binary/LJSpeech_Taco
+
+binarizer_cls: data_gen.tts.vocoder_binarizer.VocoderBinarizer
+pre_align_cls: egs.datasets.audio.lj.pre_align.LJPreAlign
+task_cls: modules.FastDiff.task.FastDiff.FastDiffTask
+binarization_args:
+  with_wav: true
+  with_spk_embed: false
+  with_align: false
+  with_word: false
+  with_txt: false
+  with_f0: false
+
+# data
+num_spk: 400
+max_samples: 25600
+aux_context_window: 0
+max_sentences: 20
+test_input_dir: '' # 'wavs' # wav->wav inference
+test_mel_dir: '' # 'mels' # mel->wav inference
+use_wav: True # mel->wav inference
+
+# training
+num_sanity_val_steps: -1
+max_updates: 1000000
+lr: 2e-4
+weight_decay: 0
+
+# FastDiff
+audio_channels: 1
+inner_channels: 32
+cond_channels: 80
+upsample_ratios: [8, 8, 4]
+lvc_layers_each_block: 4
+lvc_kernel_size: 3
+kpnet_hidden_channels: 64
+kpnet_conv_size: 3
+dropout: 0.0
+diffusion_step_embed_dim_in: 128
+diffusion_step_embed_dim_mid: 512
+diffusion_step_embed_dim_out: 512
+use_weight_norm: True
+
+# Diffusion
+T: 1000
+beta_0: 0.000001
+beta_T: 0.01
+noise_schedule: ''
+N: ''
+
diff --git a/modules/FastDiff/config/FastDiff_vctk.yaml b/modules/FastDiff/config/FastDiff_vctk.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..54bb0000db9351348d7923c921e1040bfd48890e
--- /dev/null
+++ b/modules/FastDiff/config/FastDiff_vctk.yaml
@@ -0,0 +1,7 @@
+base_config:
+  - ./base.yaml
+
+audio_sample_rate: 22050
+raw_data_dir: 'data/raw/VCTK'
+processed_data_dir: 'data/processed/VCTK'
+binary_data_dir: 'data/binary/VCTK'
\ No newline at end of file
diff --git a/modules/FastDiff/config/base.yaml b/modules/FastDiff/config/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7d8ace1c05811117c1798d1ef49f307d0d7929e5
--- /dev/null
+++ b/modules/FastDiff/config/base.yaml
@@ -0,0 +1,157 @@
+#############
+# Custom dataset preprocess
+#############
+audio_num_mel_bins: 80
+audio_sample_rate: 22050
+hop_size: 256  # For 22050Hz, 256 ~= 11.6 ms (frame shift)
+win_size: 1024  # For 22050Hz, 1024 ~= 46.4 ms (If None, win_size: fft_size)
+fmin: 80  # Set to 55 for male speakers; ~95 can help remove noise for female speakers. (Tune per dataset. Pitch range: male ~[65, 260] Hz, female ~[100, 525] Hz)
+fmax: 7600  # To be increased/reduced depending on data.
+fft_size: 1024  # Extra window size is filled with 0 paddings to match this parameter
+min_level_db: -100
+ref_level_db: 20
+griffin_lim_iters: 60
+num_spk: 1 # number of speakers
+mel_vmin: -6
+mel_vmax: 1.5
+
+#############
+# FastDiff Model
+#############
+audio_channels: 1
+inner_channels: 32
+cond_channels: 80
+upsample_ratios: [8, 8, 4]
+lvc_layers_each_block: 4
+lvc_kernel_size: 3
+kpnet_hidden_channels: 64
+kpnet_conv_size: 3
+dropout: 0.0
+diffusion_step_embed_dim_in: 128
+diffusion_step_embed_dim_mid: 512
+diffusion_step_embed_dim_out: 512
+use_weight_norm: True
+
+###########
+# Diffusion
+###########
+T: 1000
+beta_0: 0.000001
+beta_T: 0.01
+noise_schedule: ''
+N: ''
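+# T / beta_0 / beta_T define the linear training noise schedule; at inference,
+# `noise_schedule` (an explicit list) or `N` (number of reverse steps) can
+# override it (see inference/base_tts_infer.py).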
+
+
+###########
+# train and eval
+###########
+task_cls: modules.FastDiff.task.FastDiff.FastDiffTask
+max_updates: 1000000 # max training steps
+max_samples: 25600 # audio length in training
+max_sentences: 20 # max batch size in training
+num_sanity_val_steps: -1
+max_valid_sentences: 1
+valid_infer_interval: 10000
+val_check_interval: 2000
+num_test_samples: 0
+num_valid_plots: 10
+
+
+#############
+# Stage 1 of data processing
+#############
+pre_align_cls: egs.datasets.audio.pre_align.PreAlign
+pre_align_args:
+  nsample_per_mfa_group: 1000
+  txt_processor: en
+  use_tone: true # for ZH
+  sox_resample: false
+  sox_to_wav: false
+  allow_no_txt: true
+  trim_sil: false
+  denoise: false
+
+
+#############
+# Stage 2 of data processing
+#############
+binarizer_cls: data_gen.tts.vocoder_binarizer.VocoderBinarizer
+binarization_args:
+  with_wav: true
+  with_spk_embed: false
+  with_align: false
+  with_word: false
+  with_txt: false
+  with_f0: false
+  shuffle: false
+  with_spk_id: true
+  with_f0cwt: false
+  with_linear: false
+  trim_eos_bos: false
+  reset_phone_dict: true
+  reset_word_dict: true
+
+
+###########
+# optimization
+###########
+lr: 2e-4    # learning rate
+weight_decay: 0
+scheduler: rsqrt # rsqrt|none
+optimizer_adam_beta1: 0.9
+optimizer_adam_beta2: 0.98
+clip_grad_norm: 1
+clip_grad_value: 0
+
+#############
+# Setting for this Pytorch framework
+#############
+max_input_tokens: 1550
+frames_multiple: 1
+use_word_input: false
+vocoder: FastDiff
+vocoder_ckpt: checkpoints/FastDiff
+vocoder_denoise_c: 0.0
+max_tokens: 30000
+max_valid_tokens: 60000
+test_ids: [ ]
+profile_infer: false
+out_wav_norm: false
+save_gt: true
+save_f0: false
+aux_context_window: 0
+test_input_dir: '' # 'wavs' # wav->wav inference
+test_mel_dir: '' # 'mels' # mel->wav inference
+use_wav: True # mel->wav inference
+pitch_extractor: parselmouth
+loud_norm: false
+endless_ds: true
+test_num: 100
+min_frames: 0
+max_frames: 1548
+ds_workers: 1
+gen_dir_name: ''
+accumulate_grad_batches: 1
+tb_log_interval: 100
+print_nan_grads: false
+work_dir: '' # experiment directory.
+infer: false # inference
+amp: false
+debug: false
+save_codes: []
+save_best: true
+num_ckpt_keep: 3
+sort_by_len: true
+load_ckpt: ''
+check_val_every_n_epoch: 10
+max_epochs: 1000
+eval_max_batches: -1
+resume_from_checkpoint: 0
+rename_tmux: true
+valid_monitor_key: 'val_loss'
+valid_monitor_mode: 'min'
+train_set_name: 'train'
+train_sets: ''
+valid_set_name: 'valid'
+test_set_name: 'test'
+seed: 1234
\ No newline at end of file
diff --git a/modules/FastDiff/module/FastDiff_model.py b/modules/FastDiff/module/FastDiff_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e52a488bf53432eadc79f3127a64e46bda4d4532
--- /dev/null
+++ b/modules/FastDiff/module/FastDiff_model.py
@@ -0,0 +1,123 @@
+import torch.nn as nn
+import torch
+import logging
+from modules.FastDiff.module.modules import DiffusionDBlock, TimeAware_LVCBlock
+from modules.FastDiff.module.util import calc_diffusion_step_embedding
+
+def swish(x):
+    return x * torch.sigmoid(x)
+
+class FastDiff(nn.Module):
+    """FastDiff module."""
+
+    def __init__(self,
+                 audio_channels=1,
+                 inner_channels=32,
+                 cond_channels=80,
+                 upsample_ratios=[8, 8, 4],
+                 lvc_layers_each_block=4,
+                 lvc_kernel_size=3,
+                 kpnet_hidden_channels=64,
+                 kpnet_conv_size=3,
+                 dropout=0.0,
+                 diffusion_step_embed_dim_in=128,
+                 diffusion_step_embed_dim_mid=512,
+                 diffusion_step_embed_dim_out=512,
+                 use_weight_norm=True):
+        super().__init__()
+
+        self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in
+
+        self.audio_channels = audio_channels
+        self.cond_channels = cond_channels
+        self.lvc_block_nums = len(upsample_ratios)
+        self.first_audio_conv = nn.Conv1d(1, inner_channels,
+                                    kernel_size=7, padding=(7 - 1) // 2,
+                                    dilation=1, bias=True)
+
+        # define residual blocks
+        self.lvc_blocks = nn.ModuleList()
+        self.downsample = nn.ModuleList()
+
+        # the layer-specific fc for noise scale embedding
+        self.fc_t = nn.ModuleList()
+        self.fc_t1 = nn.Linear(diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid)
+        self.fc_t2 = nn.Linear(diffusion_step_embed_dim_mid, diffusion_step_embed_dim_out)
+
+        cond_hop_length = 1
+        for n in range(self.lvc_block_nums):
+            cond_hop_length = cond_hop_length * upsample_ratios[n]
+            lvcb = TimeAware_LVCBlock(
+                in_channels=inner_channels,
+                cond_channels=cond_channels,
+                upsample_ratio=upsample_ratios[n],
+                conv_layers=lvc_layers_each_block,
+                conv_kernel_size=lvc_kernel_size,
+                cond_hop_length=cond_hop_length,
+                kpnet_hidden_channels=kpnet_hidden_channels,
+                kpnet_conv_size=kpnet_conv_size,
+                kpnet_dropout=dropout,
+                noise_scale_embed_dim_out=diffusion_step_embed_dim_out
+            )
+            self.lvc_blocks += [lvcb]
+            self.downsample.append(DiffusionDBlock(inner_channels, inner_channels, upsample_ratios[self.lvc_block_nums-n-1]))
+
+
+        # define output layers
+        self.final_conv = nn.Sequential(nn.Conv1d(inner_channels, audio_channels, kernel_size=7, padding=(7 - 1) // 2,
+                                        dilation=1, bias=True))
+
+        # apply weight norm
+        if use_weight_norm:
+            self.apply_weight_norm()
+
+    def forward(self, data):
+        """Calculate forward propagation.
+        Args:
+            data (tuple): a tuple (audio, c, diffusion_steps) of
+                audio (Tensor): input noise signal (B, 1, T),
+                c (Tensor): local conditioning features, e.g. a mel-spectrogram (B, C, T'),
+                diffusion_steps (Tensor): diffusion step indices (B, 1).
+        Returns:
+            Tensor: Output tensor (B, out_channels, T)
+        """
+        audio, c, diffusion_steps = data
+
+        # embed diffusion step t
+        diffusion_step_embed = calc_diffusion_step_embedding(diffusion_steps, self.diffusion_step_embed_dim_in)
+        diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed))
+        diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed))
+
+        audio = self.first_audio_conv(audio)
+        downsample = []
+        for down_layer in self.downsample:
+            downsample.append(audio)
+            audio = down_layer(audio)
+
+        x = audio
+        for n, audio_down in enumerate(reversed(downsample)):
+            x = self.lvc_blocks[n]((x, audio_down, c, diffusion_step_embed))
+
+        # apply final layers
+        x = self.final_conv(x)
+
+        return x
+
+    def remove_weight_norm(self):
+        """Remove weight normalization module from all of the layers."""
+        def _remove_weight_norm(m):
+            try:
+                logging.debug(f"Weight norm is removed from {m}.")
+                torch.nn.utils.remove_weight_norm(m)
+            except ValueError:  # this module didn't have weight norm
+                return
+
+        self.apply(_remove_weight_norm)
+
+    def apply_weight_norm(self):
+        """Apply weight normalization to all of the layers."""
+        def _apply_weight_norm(m):
+            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
+                torch.nn.utils.weight_norm(m)
+                logging.debug(f"Weight norm is applied to {m}.")
+
+        self.apply(_apply_weight_norm)
+
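+# Rough usage sketch (hypothetical tensors, mirroring the FastDiff task): the model
+# predicts the noise added to a waveform, conditioned on a mel-spectrogram and a
+# diffusion-step index.
+#   >>> model = FastDiff()                        # upsample_ratios=[8, 8, 4] -> hop size 256
+#   >>> x_t = torch.randn(2, 1, 256 * 100)        # noisy audio
+#   >>> mel = torch.randn(2, 80, 100)             # conditioning mel frames
+#   >>> t = torch.randint(0, 1000, (2, 1))        # diffusion step indices
+#   >>> model((x_t, mel, t)).shape
+#   torch.Size([2, 1, 25600])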
diff --git a/modules/FastDiff/module/WaveNet.py b/modules/FastDiff/module/WaveNet.py
new file mode 100644
index 0000000000000000000000000000000000000000..15f5fdc75ff696646c86551642deaebf2dd89ead
--- /dev/null
+++ b/modules/FastDiff/module/WaveNet.py
@@ -0,0 +1,189 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from modules.FastDiff.module.util import calc_noise_scale_embedding
+
+
+def swish(x):
+    return x * torch.sigmoid(x)
+
+
+# dilated conv layer with kaiming_normal initialization
+# from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py
+class Conv(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
+        super(Conv, self).__init__()
+        self.padding = dilation * (kernel_size - 1) // 2
+        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=self.padding)
+        self.conv = nn.utils.weight_norm(self.conv)
+        nn.init.kaiming_normal_(self.conv.weight)
+
+    def forward(self, x):
+        out = self.conv(x)
+        return out
+
+
+# conv1x1 layer with zero initialization
+# from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py but the scale parameter is removed
+class ZeroConv1d(nn.Module):
+    def __init__(self, in_channel, out_channel):
+        super(ZeroConv1d, self).__init__()
+        self.conv = nn.Conv1d(in_channel, out_channel, kernel_size=1, padding=0)
+        self.conv.weight.data.zero_()
+        self.conv.bias.data.zero_()
+
+    def forward(self, x):
+        out = self.conv(x)
+        return out
+
+
+# every residual block (named residual layer in paper)
+# contains one noncausal dilated conv
+class Residual_block(nn.Module):
+    def __init__(self, res_channels, skip_channels, dilation, 
+                 noise_scale_embed_dim_out, multiband=True):
+        super(Residual_block, self).__init__()
+        self.res_channels = res_channels
+
+        # the layer-specific fc for noise scale embedding
+        self.fc_t = nn.Linear(noise_scale_embed_dim_out, self.res_channels)
+
+        # dilated conv layer
+        self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels, kernel_size=3, dilation=dilation)
+
+        # add mel spectrogram upsampler and conditioner conv1x1 layer
+        self.upsample_conv2d = torch.nn.ModuleList()
+        if multiband is True:
+            params = 8
+        else:
+            params = 16
+        for s in [params, params]:  # two ConvTranspose2d layers, each upsampling the time axis by s
+            conv_trans2d = torch.nn.ConvTranspose2d(1, 1, (3, 2 * s), padding=(1, s // 2), stride=(1, s))
+            conv_trans2d = torch.nn.utils.weight_norm(conv_trans2d)
+            torch.nn.init.kaiming_normal_(conv_trans2d.weight)
+            self.upsample_conv2d.append(conv_trans2d)
+        self.mel_conv = Conv(80, 2 * self.res_channels, kernel_size=1)  # 80 is mel bands
+
+        # residual conv1x1 layer, connect to next residual layer
+        self.res_conv = nn.Conv1d(res_channels, res_channels, kernel_size=1)
+        self.res_conv = nn.utils.weight_norm(self.res_conv)
+        nn.init.kaiming_normal_(self.res_conv.weight)
+
+        # skip conv1x1 layer, add to all skip outputs through skip connections
+        self.skip_conv = nn.Conv1d(res_channels, skip_channels, kernel_size=1)
+        self.skip_conv = nn.utils.weight_norm(self.skip_conv)
+        nn.init.kaiming_normal_(self.skip_conv.weight)
+
+    def forward(self, input_data):
+        x, mel_spec, noise_scale_embed = input_data
+        h = x
+        B, C, L = x.shape   # B, res_channels, L
+        assert C == self.res_channels
+
+        # add in noise scale embedding
+        part_t = self.fc_t(noise_scale_embed)
+        part_t = part_t.view([B, self.res_channels, 1])
+        h += part_t
+
+        # dilated conv layer
+        h = self.dilated_conv_layer(h)
+
+        # add mel spectrogram as (local) conditioner
+        assert mel_spec is not None
+
+        # Upsample spectrogram to size of audio
+        mel_spec = torch.unsqueeze(mel_spec, dim=1)  # (B, 1, 80, T')
+        mel_spec = F.leaky_relu(self.upsample_conv2d[0](mel_spec), 0.4)
+        mel_spec = F.leaky_relu(self.upsample_conv2d[1](mel_spec), 0.4)
+        mel_spec = torch.squeeze(mel_spec, dim=1)
+
+        assert(mel_spec.size(2) >= L)
+        if mel_spec.size(2) > L:
+            mel_spec = mel_spec[:, :, :L]
+
+        mel_spec = self.mel_conv(mel_spec)
+        h += mel_spec
+
+        # gated-tanh nonlinearity
+        out = torch.tanh(h[:,:self.res_channels,:]) * torch.sigmoid(h[:,self.res_channels:,:])
+
+        # residual and skip outputs
+        res = self.res_conv(out)
+        assert x.shape == res.shape
+        skip = self.skip_conv(out)
+
+        return (x + res) * math.sqrt(0.5), skip  # normalize for training stability
+
+
+class Residual_group(nn.Module):
+    def __init__(self, res_channels, skip_channels, num_res_layers, dilation_cycle, 
+                 noise_scale_embed_dim_in, 
+                 noise_scale_embed_dim_mid,
+                 noise_scale_embed_dim_out, multiband):
+        super(Residual_group, self).__init__()
+        self.num_res_layers = num_res_layers
+        self.noise_scale_embed_dim_in = noise_scale_embed_dim_in
+
+        # the shared two fc layers for noise scale embedding
+        self.fc_t1 = nn.Linear(noise_scale_embed_dim_in, noise_scale_embed_dim_mid)
+        self.fc_t2 = nn.Linear(noise_scale_embed_dim_mid, noise_scale_embed_dim_out)
+
+        # stack all residual blocks with dilations 1, 2, ... , 512, ... , 1, 2, ..., 512
+        self.residual_blocks = nn.ModuleList()
+        for n in range(self.num_res_layers):
+            self.residual_blocks.append(Residual_block(res_channels, skip_channels, 
+                                                       dilation=2 ** (n % dilation_cycle),
+                                                       noise_scale_embed_dim_out=noise_scale_embed_dim_out, multiband=multiband))
+
+    def forward(self, input_data):
+        x, mel_spectrogram, noise_scales = input_data
+
+        # embed noise scale
+        noise_scale_embed = calc_noise_scale_embedding(noise_scales, self.noise_scale_embed_dim_in)
+        noise_scale_embed = swish(self.fc_t1(noise_scale_embed))
+        noise_scale_embed = swish(self.fc_t2(noise_scale_embed))
+
+        # pass all residual layers
+        h = x
+        skip = 0
+        for n in range(self.num_res_layers):
+            h, skip_n = self.residual_blocks[n]((h, mel_spectrogram, noise_scale_embed))  # use the output from last residual layer
+            skip += skip_n  # accumulate all skip outputs
+
+        return skip * math.sqrt(1.0 / self.num_res_layers)  # normalize for training stability
+
+
+class WaveNet_vocoder(nn.Module):
+    def __init__(self, in_channels, res_channels, skip_channels, out_channels, 
+                 num_res_layers, dilation_cycle, 
+                 noise_scale_embed_dim_in, 
+                 noise_scale_embed_dim_mid,
+                 noise_scale_embed_dim_out, multiband):
+        super(WaveNet_vocoder, self).__init__()
+
+        # initial conv1x1 with relu
+        self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU())
+        
+        # all residual layers
+        self.residual_layer = Residual_group(res_channels=res_channels, 
+                                             skip_channels=skip_channels, 
+                                             num_res_layers=num_res_layers, 
+                                             dilation_cycle=dilation_cycle,
+                                             noise_scale_embed_dim_in=noise_scale_embed_dim_in,
+                                             noise_scale_embed_dim_mid=noise_scale_embed_dim_mid,
+                                             noise_scale_embed_dim_out=noise_scale_embed_dim_out, multiband=multiband)
+        
+        # final conv1x1 -> relu -> zeroconv1x1
+        self.final_conv = nn.Sequential(Conv(skip_channels, skip_channels, kernel_size=1),
+                                        nn.ReLU(),
+                                        ZeroConv1d(skip_channels, out_channels))
+
+    def forward(self, input_data):
+        audio, mel_spectrogram, noise_scales = input_data  # b x band x T, b x 80 x T', b x 1
+        x = audio
+        x = self.init_conv(x)
+        x = self.residual_layer((x, mel_spectrogram, noise_scales))
+        x = self.final_conv(x)
+
+        return x
+
diff --git a/modules/FastDiff/module/modules.py b/modules/FastDiff/module/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..29b0f42123b10a0518093c23592277b9622b5266
--- /dev/null
+++ b/modules/FastDiff/module/modules.py
@@ -0,0 +1,343 @@
+import math
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch.nn import Conv1d
+
+LRELU_SLOPE = 0.1
+
+
+
+def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
+    ''' Sinusoid position encoding table '''
+
+    def cal_angle(position, hid_idx):
+        return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)
+
+    def get_posi_angle_vec(position):
+        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
+
+    sinusoid_table = np.array([get_posi_angle_vec(pos_i)
+                               for pos_i in range(n_position)])
+
+    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
+    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
+
+    if padding_idx is not None:
+        # zero vector for padding dimension
+        sinusoid_table[padding_idx] = 0.
+
+    return torch.FloatTensor(sinusoid_table)
+
+
+def overlap_and_add(signal, frame_step):
+    """Reconstructs a signal from a framed representation.
+
+    Adds potentially overlapping frames of a signal with shape
+    `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
+    The resulting tensor has shape `[..., output_size]` where
+
+        output_size = (frames - 1) * frame_step + frame_length
+
+    Args:
+        signal: A [..., frames, frame_length] Tensor. All dimensions may be unknown, and rank must be at least 2.
+        frame_step: An integer denoting overlap offsets. Must be less than or equal to frame_length.
+
+    Returns:
+        A Tensor with shape [..., output_size] containing the overlap-added frames of signal's inner-most two dimensions.
+        output_size = (frames - 1) * frame_step + frame_length
+
+    Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
+    """
+    outer_dimensions = signal.size()[:-2]
+    frames, frame_length = signal.size()[-2:]
+
+    # gcd=Greatest Common Divisor
+    subframe_length = math.gcd(frame_length, frame_step)
+    subframe_step = frame_step // subframe_length
+    subframes_per_frame = frame_length // subframe_length
+    output_size = frame_step * (frames - 1) + frame_length
+    output_subframes = output_size // subframe_length
+
+    subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
+
+    frame = torch.arange(0, output_subframes).unfold(0, subframes_per_frame, subframe_step)
+    frame = signal.new_tensor(frame).long()  # signal may be on GPU or CPU
+    frame = frame.contiguous().view(-1)
+
+    result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
+    device_of_result = result.device
+    result.index_add_(-2, frame.to(device_of_result), subframe_signal)
+    result = result.view(*outer_dimensions, -1)
+    return result
+
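+# Quick shape check for overlap_and_add (hypothetical example, not part of the
+# original module): with frames=4, frame_length=8 and frame_step=4 the output
+# length is (4 - 1) * 4 + 8 = 20.
+#   >>> overlap_and_add(torch.randn(2, 4, 8), 4).shape
+#   torch.Size([2, 20])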
+
+class LastLayer(nn.Module):
+    def __init__(self, in_channels, out_channels,
+                 nonlinear_activation, nonlinear_activation_params,
+                 pad, kernel_size, pad_params, bias):
+        super(LastLayer, self).__init__()
+        self.activation = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)
+        self.pad = getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params)
+        self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size, bias=bias)
+
+    def forward(self, x):
+        x = self.activation(x)
+        x = self.pad(x)
+        x = self.conv(x)
+        return x
+
+
+class WeightConv1d(Conv1d):
+    """Conv1d module with customized initialization."""
+
+    def __init__(self, *args, **kwargs):
+        """Initialize Conv1d module."""
+        super(WeightConv1d, self).__init__(*args, **kwargs)
+
+    def reset_parameters(self):
+        """Reset parameters."""
+        torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
+        if self.bias is not None:
+            torch.nn.init.constant_(self.bias, 0.0)
+
+
+class Conv1d1x1(Conv1d):
+    """1x1 Conv1d with customized initialization."""
+
+    def __init__(self, in_channels, out_channels, bias):
+        """Initialize 1x1 Conv1d module."""
+        super(Conv1d1x1, self).__init__(in_channels, out_channels,
+                                        kernel_size=1, padding=0,
+                                        dilation=1, bias=bias)
+
+class DiffusionDBlock(nn.Module):
+    def __init__(self, input_size, hidden_size, factor):
+        super().__init__()
+        self.factor = factor
+        self.residual_dense = Conv1d(input_size, hidden_size, 1)
+        self.conv = nn.ModuleList([
+            Conv1d(input_size, hidden_size, 3, dilation=1, padding=1),
+            Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2),
+            Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4),
+        ])
+
+    def forward(self, x):
+        size = x.shape[-1] // self.factor
+
+        residual = self.residual_dense(x)
+        residual = F.interpolate(residual, size=size)
+
+        x = F.interpolate(x, size=size)
+        for layer in self.conv:
+            x = F.leaky_relu(x, 0.2)
+            x = layer(x)
+
+        return x + residual
+
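+# Example (a sketch, not in the original file): a DiffusionDBlock with factor=4
+# shortens the time axis by 4x while keeping the channel count.
+#   >>> block = DiffusionDBlock(32, 32, factor=4)
+#   >>> block(torch.randn(2, 32, 1024)).shape
+#   torch.Size([2, 32, 256])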
+
+class TimeAware_LVCBlock(torch.nn.Module):
+    ''' time-aware location-variable convolutions
+    '''
+    def __init__(self,
+                 in_channels,
+                 cond_channels,
+                 upsample_ratio,
+                 conv_layers=4,
+                 conv_kernel_size=3,
+                 cond_hop_length=256,
+                 kpnet_hidden_channels=64,
+                 kpnet_conv_size=3,
+                 kpnet_dropout=0.0,
+                 noise_scale_embed_dim_out=512
+                 ):
+        super().__init__()
+
+        self.cond_hop_length = cond_hop_length
+        self.conv_layers = conv_layers
+        self.conv_kernel_size = conv_kernel_size
+        self.convs = torch.nn.ModuleList()
+
+        self.upsample = torch.nn.ConvTranspose1d(in_channels, in_channels,
+                                    kernel_size=upsample_ratio*2, stride=upsample_ratio,
+                                    padding=upsample_ratio // 2 + upsample_ratio % 2,
+                                    output_padding=upsample_ratio % 2)
+
+
+        self.kernel_predictor = KernelPredictor(
+            cond_channels=cond_channels,
+            conv_in_channels=in_channels,
+            conv_out_channels=2 * in_channels,
+            conv_layers=conv_layers,
+            conv_kernel_size=conv_kernel_size,
+            kpnet_hidden_channels=kpnet_hidden_channels,
+            kpnet_conv_size=kpnet_conv_size,
+            kpnet_dropout=kpnet_dropout
+        )
+
+        # the layer-specific fc for noise scale embedding
+        self.fc_t = torch.nn.Linear(noise_scale_embed_dim_out, cond_channels)
+
+        for i in range(conv_layers):
+            padding = (3 ** i) * int((conv_kernel_size - 1) / 2)
+            conv = torch.nn.Conv1d(in_channels, in_channels, kernel_size=conv_kernel_size, padding=padding, dilation=3 ** i)
+
+            self.convs.append(conv)
+
+
+    def forward(self, data):
+        ''' forward propagation of the time-aware location-variable convolutions.
+        Args:
+            data (tuple): a tuple (x, audio_down, c, noise_embedding) of
+                x (Tensor): the input sequence (batch, in_channels, in_length),
+                audio_down (Tensor): the skip input from the downsampling path,
+                c (Tensor): the conditioning sequence (batch, cond_channels, cond_length),
+                noise_embedding (Tensor): the diffusion-step embedding (batch, noise_scale_embed_dim_out).
+
+        Returns:
+            Tensor: the output sequence (batch, in_channels, in_length * upsample_ratio)
+        '''
+        x, audio_down, c, noise_embedding = data
+        batch, in_channels, in_length = x.shape
+
+        noise = (self.fc_t(noise_embedding)).unsqueeze(-1)  # (B, cond_channels, 1)
+        condition = c + noise  # (B, 80, T)
+        kernels, bias = self.kernel_predictor(condition)
+        x = F.leaky_relu(x, 0.2)
+        x = self.upsample(x)
+
+        for i in range(self.conv_layers):
+            x += audio_down
+            y = F.leaky_relu(x, 0.2)
+            y = self.convs[i](y)
+            y = F.leaky_relu(y, 0.2)
+
+            k = kernels[:, i, :, :, :, :]
+            b = bias[:, i, :, :]
+            y = self.location_variable_convolution(y, k, b, 1, self.cond_hop_length)
+            x = x + torch.sigmoid(y[:, :in_channels, :]) * torch.tanh(y[:, in_channels:, :])
+        return x
+
+    def location_variable_convolution(self, x, kernel, bias, dilation, hop_size):
+        ''' perform location-variable convolution operation on the input sequence (x) using the local convolution kernel.
+        Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), tested on an NVIDIA V100.
+        Args:
+            x (Tensor): the input sequence (batch, in_channels, in_length).
+            kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
+            bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
+            dilation (int): the dilation of convolution.
+            hop_size (int): the hop_size of the conditioning sequence.
+        Returns:
+            (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
+        '''
+        batch, in_channels, in_length = x.shape
+        batch, in_channels, out_channels, kernel_size, kernel_length = kernel.shape
+
+
+        assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
+
+        padding = dilation * int((kernel_size - 1) / 2)
+        x = F.pad(x, (padding, padding), 'constant', 0)  # (batch, in_channels, in_length + 2*padding)
+        x = x.unfold(2, hop_size + 2 * padding, hop_size)  # (batch, in_channels, kernel_length, hop_size + 2*padding)
+
+        if hop_size < dilation:
+            x = F.pad(x, (0, dilation), 'constant', 0)
+        x = x.unfold(3, dilation,
+                     dilation)  # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
+        x = x[:, :, :, :, :hop_size]
+        x = x.transpose(3, 4)  # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
+        x = x.unfold(4, kernel_size, 1)  # (batch, in_channels, kernel_length, dilation, _, kernel_size)
+
+        o = torch.einsum('bildsk,biokl->bolsd', x, kernel)
+        o = o + bias.unsqueeze(-1).unsqueeze(-1)
+        o = o.contiguous().view(batch, out_channels, -1)
+        return o
+
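+    # Note on location_variable_convolution: the assert above ties the shapes
+    # together (in_length == kernel_length * hop_size); e.g. with hop_size=256 and
+    # a 100-frame kernel sequence the input and output time axes are 25600 samples.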
+
+
+class KernelPredictor(torch.nn.Module):
+    ''' Kernel predictor for the time-aware location-variable convolutions
+    '''
+
+    def __init__(self,
+                 cond_channels,
+                 conv_in_channels,
+                 conv_out_channels,
+                 conv_layers,
+                 conv_kernel_size=3,
+                 kpnet_hidden_channels=64,
+                 kpnet_conv_size=3,
+                 kpnet_dropout=0.0,
+                 kpnet_nonlinear_activation="LeakyReLU",
+                 kpnet_nonlinear_activation_params={"negative_slope": 0.1}
+                 ):
+        '''
+        Args:
+            cond_channels (int): number of channels of the conditioning sequence
+            conv_in_channels (int): number of input channels of the location-variable convolutions
+            conv_out_channels (int): number of output channels of the location-variable convolutions
+            conv_layers (int): number of location-variable convolution layers
+            kpnet_hidden_channels / kpnet_conv_size / kpnet_dropout: hyperparameters of the kernel-prediction network
+        '''
+        super().__init__()
+
+        self.conv_in_channels = conv_in_channels
+        self.conv_out_channels = conv_out_channels
+        self.conv_kernel_size = conv_kernel_size
+        self.conv_layers = conv_layers
+
+        l_w = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers
+        l_b = conv_out_channels * conv_layers
+
+        padding = (kpnet_conv_size - 1) // 2
+        self.input_conv = torch.nn.Sequential(
+            torch.nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=(5 - 1) // 2, bias=True),
+            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
+        )
+
+        self.residual_conv = torch.nn.Sequential(
+            torch.nn.Dropout(kpnet_dropout),
+            torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
+            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
+            torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
+            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
+            torch.nn.Dropout(kpnet_dropout),
+            torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
+            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
+            torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
+            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
+            torch.nn.Dropout(kpnet_dropout),
+            torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
+            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
+            torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
+            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
+        )
+
+        self.kernel_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_w, kpnet_conv_size,
+                                           padding=padding, bias=True)
+        self.bias_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_b, kpnet_conv_size, padding=padding,
+                                         bias=True)
+
+    def forward(self, c):
+        '''
+        Args:
+            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
+        Returns:
+            kernels (Tensor): (batch, conv_layers, conv_in_channels, conv_out_channels, conv_kernel_size, cond_length)
+            bias (Tensor): (batch, conv_layers, conv_out_channels, cond_length)
+        '''
+        batch, cond_channels, cond_length = c.shape
+
+        c = self.input_conv(c)
+        c = c + self.residual_conv(c)
+        k = self.kernel_conv(c)
+        b = self.bias_conv(c)
+
+        kernels = k.contiguous().view(batch,
+                                      self.conv_layers,
+                                      self.conv_in_channels,
+                                      self.conv_out_channels,
+                                      self.conv_kernel_size,
+                                      cond_length)
+        bias = b.contiguous().view(batch,
+                                   self.conv_layers,
+                                   self.conv_out_channels,
+                                   cond_length)
+        return kernels, bias
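+
+# Rough usage sketch for KernelPredictor (hypothetical numbers, not part of the
+# original code): per-frame LVC kernels and biases predicted from an 80-bin mel.
+#   >>> kp = KernelPredictor(cond_channels=80, conv_in_channels=32,
+#   ...                      conv_out_channels=64, conv_layers=4)
+#   >>> kernels, bias = kp(torch.randn(2, 80, 100))
+#   >>> kernels.shape, bias.shape
+#   (torch.Size([2, 4, 32, 64, 3, 100]), torch.Size([2, 4, 64, 100]))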
diff --git a/modules/FastDiff/module/util.py b/modules/FastDiff/module/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f3b5ff412c70ae6674596ed5e5903d347ad167b
--- /dev/null
+++ b/modules/FastDiff/module/util.py
@@ -0,0 +1,429 @@
+import os
+import numpy as np
+import torch
+import torch.nn as nn
+import copy
+from tqdm import tqdm
+
+
+def flatten(v):
+    """
+    Flatten a list of lists/tuples
+    """
+
+    return [x for y in v for x in y]
+
+
+def rescale(x):
+    """
+    Rescale a tensor to 0-1
+    """
+
+    return (x - x.min()) / (x.max() - x.min())
+
+
+def find_max_epoch(path):
+    """
+    Find maximum epoch/iteration in path, formatted ${n_iter}.pkl
+    E.g. 100000.pkl
+
+    Parameters:
+    path (str): checkpoint path
+    
+    Returns:
+    maximum iteration, -1 if there is no (valid) checkpoint
+    """
+
+    files = os.listdir(path)
+    epoch = -1
+    for f in files:
+        if len(f) <= 4:
+            continue
+        if f[-4:] == '.pkl':
+            try:
+                epoch = max(epoch, int(f[:-4]))
+            except ValueError:
+                continue
+    #print(path, epoch, flush=True)
+    return epoch
+
+
+def print_size(net):
+    """
+    Print the number of parameters of a network
+    """
+
+    if net is not None and isinstance(net, torch.nn.Module):
+        module_parameters = filter(lambda p: p.requires_grad, net.parameters())
+        params = sum([np.prod(p.size()) for p in module_parameters])
+        print("{} Parameters: {:.6f}M".format(
+            net.__class__.__name__, params / 1e6), flush=True)
+
+
+# Utilities for diffusion models
+
+def std_normal(size):
+    """
+    Generate the standard Gaussian variable of a certain size
+    """
+
+    return torch.normal(0, 1, size=size)
+
+
+def calc_noise_scale_embedding(noise_scales, noise_scale_embed_dim_in):
+    """
+    Embed a noise scale $t$ into a higher dimensional space
+    E.g. the embedding vector in the 128-dimensional space is
+    [sin(t * 10^(0*4/63)), ... , sin(t * 10^(63*4/63)), cos(t * 10^(0*4/63)), ... , cos(t * 10^(63*4/63))]
+
+    Parameters:
+    noise_scales (torch.long tensor, shape=(batchsize, 1)):     
+                                noise scales for batch data
+    noise_scale_embed_dim_in (int, default=128):  
+                                dimensionality of the embedding space for discrete noise scales
+    
+    Returns:
+    the embedding vectors (torch.tensor, shape=(batchsize, noise_scale_embed_dim_in)):
+    """
+
+    assert noise_scale_embed_dim_in % 2 == 0
+
+    half_dim = noise_scale_embed_dim_in // 2
+    _embed = np.log(10000) / (half_dim - 1)
+    _embed = torch.exp(torch.arange(half_dim) * -_embed)
+    _embed = noise_scales * _embed
+    noise_scale_embed = torch.cat((torch.sin(_embed), 
+                                      torch.cos(_embed)), 1)
+    
+    return noise_scale_embed
+
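+# Example (a sketch, not in the original file): two noise scales embedded into a
+# 128-dimensional sinusoidal space.
+#   >>> t = torch.tensor([[0.1], [0.5]])
+#   >>> calc_noise_scale_embedding(t, 128).shape
+#   torch.Size([2, 128])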
+
+def calc_diffusion_hyperparams_given_beta(beta):
+    """
+    Compute diffusion process hyperparameters
+
+    Parameters:
+    beta (tensor):  beta schedule 
+    
+    Returns:
+    a dictionary of diffusion hyperparameters including:
+        T (int), beta/alpha/sigma (torch.tensor on cpu, shape=(T, ))
+        These cpu tensors are changed to cuda tensors on each individual gpu
+    """
+
+    T = len(beta)
+    alpha = 1 - beta
+    sigma = beta + 0
+    for t in range(1, T):
+        alpha[t] *= alpha[t-1]  # \alpha^2_t = \prod_{s=1}^t (1-\beta_s)
+        sigma[t] *= (1-alpha[t-1]) / (1-alpha[t])  # \sigma^2_t = \beta_t * (1-\alpha_{t-1}) / (1-\alpha_t)
+    alpha = torch.sqrt(alpha)
+    sigma = torch.sqrt(sigma)
+    
+    _dh = {}
+    _dh["T"], _dh["beta"], _dh["alpha"], _dh["sigma"] = T, beta, alpha, sigma
+    diffusion_hyperparams = _dh
+    return diffusion_hyperparams
+
+
+def calc_diffusion_hyperparams(T, beta_0, beta_T, tau, N, beta_N, alpha_N, rho):
+    """
+    Compute diffusion process hyperparameters
+
+    Parameters:
+    T (int):                    number of noise scales
+    beta_0 and beta_T (float):  beta schedule start/end value, 
+                                where any beta_t in the middle is linearly interpolated
+    
+    Returns:
+    a dictionary of diffusion hyperparameters including:
+        T (int), beta/alpha/sigma (torch.tensor on cpu, shape=(T, ))
+        These cpu tensors are changed to cuda tensors on each individual gpu
+    """
+
+    beta = torch.linspace(beta_0, beta_T, T)
+    alpha = 1 - beta
+    sigma = beta + 0
+    for t in range(1, T):
+        alpha[t] *= alpha[t-1]  # \alpha^2_t = \prod_{s=1}^t (1-\beta_s)
+        sigma[t] *= (1-alpha[t-1]) / (1-alpha[t])  # \sigma^2_t = \beta_t * (1-\alpha_{t-1}) / (1-\alpha_t)
+    alpha = torch.sqrt(alpha)
+    sigma = torch.sqrt(sigma)
+    
+    _dh = {}
+    _dh["T"], _dh["beta"], _dh["alpha"], _dh["sigma"] = T, beta, alpha, sigma
+    _dh["tau"], _dh["N"], _dh["betaN"], _dh["alphaN"], _dh["rho"] = tau, N, beta_N, alpha_N, rho
+    diffusion_hyperparams = _dh
+    return diffusion_hyperparams
+
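+# Example (illustrative values only: beta_0/beta_T/T follow the 1000-step linear
+# schedule used elsewhere in this repo, the remaining arguments are placeholders):
+# beta/alpha/sigma all come back as length-T tensors.
+#   >>> dh = calc_diffusion_hyperparams(T=1000, beta_0=1e-6, beta_T=0.01,
+#   ...                                 tau=200, N=4, beta_N=0.5, alpha_N=0.1, rho=1e-9)
+#   >>> dh["beta"].shape, dh["alpha"].shape, dh["sigma"].shape
+#   (torch.Size([1000]), torch.Size([1000]), torch.Size([1000]))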
+
+def sampling_given_noise_schedule(
+        net,
+        size,
+        diffusion_hyperparams,
+        inference_noise_schedule,
+        condition=None,
+        ddim=False,
+        return_sequence=False):
+    """
+    Perform the complete sampling step according to p(x_0|x_T) = \prod_{t=1}^T p_{\theta}(x_{t-1}|x_t)
+
+    Parameters:
+    net (torch network):            the wavenet models
+    size (tuple):                   size of tensor to be generated,
+                                    usually is (number of audios to generate, channels=1, length of audio)
+    diffusion_hyperparams (dict):   dictionary of diffusion hyperparameters returned by calc_diffusion_hyperparams
+                                    note, the tensors need to be cuda tensors
+    condition (torch.tensor):       ground truth mel spectrogram read from disk
+                                    None if used for unconditional generation
+
+    Returns:
+    the generated audio(s) in torch.tensor, shape=size
+    """
+
+    _dh = diffusion_hyperparams
+    T, alpha = _dh["T"], _dh["alpha"]
+    assert len(alpha) == T
+    assert len(size) == 3
+
+    N = len(inference_noise_schedule)
+    beta_infer = inference_noise_schedule
+    alpha_infer = 1 - beta_infer
+    sigma_infer = beta_infer + 0
+    for n in range(1, N):
+        alpha_infer[n] *= alpha_infer[n - 1]
+        sigma_infer[n] *= (1 - alpha_infer[n - 1]) / (1 - alpha_infer[n])
+    alpha_infer = torch.sqrt(alpha_infer)
+    sigma_infer = torch.sqrt(sigma_infer)
+
+    # Mapping noise scales to time steps
+    steps_infer = []
+    for n in range(N):
+        step = map_noise_scale_to_time_step(alpha_infer[n], alpha)
+        if step >= 0:
+            steps_infer.append(step)
+    steps_infer = torch.FloatTensor(steps_infer)
+
+    # N may change since alpha_infer can be out of the range of alpha
+    N = len(steps_infer)
+
+    x = std_normal(size)
+    if return_sequence:
+        x_ = copy.deepcopy(x)
+        xs = [x_]
+    with torch.no_grad():
+        for n in tqdm(range(N - 1, -1, -1), desc='FastDiff sample time step', total=N):
+            diffusion_steps = (steps_infer[n] * torch.ones((size[0], 1)))
+            epsilon_theta = net((x, condition, diffusion_steps,))
+            if ddim:
+                alpha_next = alpha_infer[n] / (1 - beta_infer[n]).sqrt()
+                c1 = alpha_next / alpha_infer[n]
+                c2 = -(1 - alpha_infer[n] ** 2.).sqrt() * c1
+                c3 = (1 - alpha_next ** 2.).sqrt()
+                x = c1 * x + c2 * epsilon_theta + c3 * epsilon_theta  # std_normal(size)
+            else:
+                x -= beta_infer[n] / torch.sqrt(1 - alpha_infer[n] ** 2.) * epsilon_theta
+                x /= torch.sqrt(1 - beta_infer[n])
+                if n > 0:
+                    x = x + sigma_infer[n] * std_normal(size)
+            if return_sequence:
+                x_ = copy.deepcopy(x)
+                xs.append(x_)
+    if return_sequence:
+        return xs
+    return x
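+
+# Inference sketch (assumed tensors, mirroring FastDiffTask.test_step): the target
+# length is hop_size * number of mel frames, and the 4-step schedule below is the
+# one hard-coded in the task code.
+#   >>> schedule = torch.FloatTensor([3.2176e-04, 2.5743e-03, 2.5376e-02, 7.0414e-01])
+#   >>> wav = sampling_given_noise_schedule(net, (1, 1, mels.shape[-1] * hop_size),
+#   ...                                     diffusion_hyperparams, schedule,
+#   ...                                     condition=mels, ddim=False)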
+
+def noise_scheduling(net, size, diffusion_hyperparams, condition=None, ddim=False):
+    """
+    Search an inference noise schedule with the noise predictor by reversing the diffusion process
+
+    Parameters:
+    net (torch network):            the wavenet models
+    size (tuple):                   size of tensor to be generated,
+                                    usually is (number of audios to generate, channels=1, length of audio)
+    diffusion_hyperparams (dict):   dictionary of diffusion hyperparameters returned by calc_diffusion_hyperparams
+                                    note, the tensors need to be cuda tensors
+    condition (torch.tensor):       ground truth mel spectrogram read from disk
+                                    None if used for unconditional generation
+
+    Returns:
+    noise schedule:                 a list of noise scales in torch.tensor, length <= N
+    """
+
+    _dh = diffusion_hyperparams
+    N, betaN, alphaN, rho, alpha = _dh["N"], _dh["betaN"], _dh["alphaN"], _dh["rho"], _dh["alpha"]
+
+    print('begin noise scheduling, maximum number of reverse steps = %d' % (N))
+
+    betas = []
+    x = std_normal(size)
+    with torch.no_grad():
+        beta_cur = torch.ones(1, 1, 1).cuda() * betaN
+        alpha_cur = torch.ones(1, 1, 1).cuda() * alphaN
+        for n in range(N - 1, -1, -1):
+            # print(n, beta_cur.squeeze().item(), alpha_cur.squeeze().item())
+            step = map_noise_scale_to_time_step(alpha_cur.squeeze().item(), alpha)
+            if step >= 0:
+                betas.append(beta_cur.squeeze().item())
+            diffusion_steps = (step * torch.ones((size[0], 1))).cuda()
+            epsilon_theta = net((x, condition, diffusion_steps,))
+            if ddim:
+                alpha_nxt = alpha_cur / (1 - beta_cur).sqrt()
+                c1 = alpha_nxt / alpha_cur
+                c2 = -(1 - alpha_cur ** 2.).sqrt() * c1
+                c3 = (1 - alpha_nxt ** 2.).sqrt()
+                x = c1 * x + c2 * epsilon_theta + c3 * epsilon_theta  # std_normal(size)
+            else:
+                x -= beta_cur / torch.sqrt(1 - alpha_cur ** 2.) * epsilon_theta
+                x /= torch.sqrt(1 - beta_cur)
+            alpha_nxt, beta_nxt = alpha_cur, beta_cur
+            alpha_cur = alpha_nxt / (1 - beta_nxt).sqrt()
+            if alpha_cur > 1:
+                break
+            beta_cur = net.noise_pred(
+                x.squeeze(1), (beta_nxt.view(-1, 1), (1 - alpha_cur ** 2.).view(-1, 1)))
+            if beta_cur.squeeze().item() < rho:
+                break
+    return torch.FloatTensor(betas[::-1]).cuda()
+
+
+def theta_timestep_loss(net, X, diffusion_hyperparams, reverse=False):
+    """
+    Compute the training loss for learning theta
+
+    Parameters:
+    net (torch network):            the wavenet models
+    X (tuple, shape=(2,)):          training data in tuple form (mel_spectrograms, audios)
+                                    mel_spectrograms: torch.tensor, shape is batchsize followed by each mel_spectrogram shape
+                                    audios: torch.tensor, shape=(batchsize, 1, length of audio)
+    diffusion_hyperparams (dict):   dictionary of diffusion hyperparameters returned by calc_diffusion_hyperparams
+                                    note, the tensors need to be cuda tensors
+
+    Returns:
+    theta loss
+    """
+    assert type(X) == tuple and len(X) == 2
+    loss_fn = nn.MSELoss()
+
+    _dh = diffusion_hyperparams
+    T, alpha = _dh["T"], _dh["alpha"]
+
+    mel_spectrogram, audio = X
+    B, C, L = audio.shape  # B is batchsize, C=1, L is audio length
+    ts = torch.randint(T, size=(B, 1, 1)).cuda()  # randomly sample diffusion steps from {0, ..., T-1}
+    z = std_normal(audio.shape)
+    delta = (1 - alpha[ts] ** 2.).sqrt()
+    alpha_cur = alpha[ts]
+    noisy_audio = alpha_cur * audio + delta * z  # compute x_t from q(x_t|x_0)
+    epsilon_theta = net((noisy_audio, mel_spectrogram, ts.view(B, 1),))
+
+    if reverse:
+        x0 = (noisy_audio - delta * epsilon_theta) / alpha_cur
+        return loss_fn(epsilon_theta, z), x0
+
+    return loss_fn(epsilon_theta, z)
+
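+# Training-step sketch (hypothetical tensors, mirroring FastDiffTask._training_step):
+#   >>> X = (mels, wavs)                     # (B, 80, T'), (B, 1, T), both on GPU
+#   >>> loss = theta_timestep_loss(model, X, diffusion_hyperparams)
+#   >>> loss.backward()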
+
+def phi_loss(net, X, diffusion_hyperparams):
+    """
+    Compute the training loss for learning phi
+    Parameters:
+    net (torch network):            the wavenet models
+    X (tuple, shape=(2,)):          training data in tuple form (mel_spectrograms, audios)
+                                    mel_spectrograms: torch.tensor, shape is batchsize followed by each mel_spectrogram shape
+                                    audios: torch.tensor, shape=(batchsize, 1, length of audio)
+    diffusion_hyperparams (dict):   dictionary of diffusion hyperparameters returned by calc_diffusion_hyperparams
+                                    note, the tensors need to be cuda tensors
+
+    Returns:
+    phi loss
+    """
+    assert type(X) == tuple and len(X) == 2
+    _dh = diffusion_hyperparams
+    T, alpha, tau = _dh["T"], _dh["alpha"], _dh["tau"]
+
+    mel_spectrogram, audio = X
+    B, C, L = audio.shape  # B is batchsize, C=1, L is audio length
+    ts = torch.randint(tau, T - tau, size=(B,)).cuda()  # randomly sample steps from {tau, ..., T-tau-1}
+    alpha_cur = alpha.index_select(0, ts).view(B, 1, 1)
+    alpha_nxt = alpha.index_select(0, ts + tau).view(B, 1, 1)
+    beta_nxt = 1 - (alpha_nxt / alpha_cur) ** 2.
+    delta = (1 - alpha_cur ** 2.).sqrt()
+    z = std_normal(audio.shape)
+    noisy_audio = alpha_cur * audio + delta * z  # compute x_t from q(x_t|x_0)
+    epsilon_theta = net((noisy_audio, mel_spectrogram, ts.view(B, 1),))
+    beta_est = net.noise_pred(noisy_audio.squeeze(1), (beta_nxt.view(B, 1), delta.view(B, 1) ** 2.))
+    phi_loss = 1 / (2. * (delta ** 2. - beta_est)) * (
+            delta * z - beta_est / delta * epsilon_theta) ** 2.
+    phi_loss += torch.log(1e-8 + delta ** 2. / (beta_est + 1e-8)) / 4.
+    phi_loss = (torch.mean(phi_loss, -1, keepdim=True) + beta_est / delta ** 2 / 2.).mean()
+
+    return phi_loss
+
+
+def compute_hyperparams_given_schedule(beta):
+    """
+    Compute diffusion process hyperparameters
+
+    Parameters:
+    beta (tensor):  beta schedule
+
+    Returns:
+    a dictionary of diffusion hyperparameters including:
+        T (int), beta/alpha/sigma (torch.tensor on cpu, shape=(T, ))
+        These cpu tensors are changed to cuda tensors on each individual gpu
+    """
+
+    T = len(beta)
+    alpha = 1 - beta
+    sigma = beta + 0
+    for t in range(1, T):
+        alpha[t] *= alpha[t - 1]  # \alpha^2_t = \prod_{s=1}^t (1-\beta_s)
+        sigma[t] *= (1 - alpha[t - 1]) / (1 - alpha[t])  # \sigma^2_t = \beta_t * (1-\alpha_{t-1}) / (1-\alpha_t)
+    alpha = torch.sqrt(alpha)
+    sigma = torch.sqrt(sigma)
+
+    _dh = {}
+    _dh["T"], _dh["beta"], _dh["alpha"], _dh["sigma"] = T, beta, alpha, sigma
+    diffusion_hyperparams = _dh
+    return diffusion_hyperparams
+
+
+
+def map_noise_scale_to_time_step(alpha_infer, alpha):
+    if alpha_infer < alpha[-1]:
+        return len(alpha) - 1
+    if alpha_infer > alpha[0]:
+        return 0
+    for t in range(len(alpha) - 1):
+        if alpha[t+1] <= alpha_infer <= alpha[t]:
+             step_diff = alpha[t] - alpha_infer
+             step_diff /= alpha[t] - alpha[t+1]
+             return t + step_diff.item()
+    return -1
+
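+# Note: map_noise_scale_to_time_step returns a *fractional* index; when alpha_infer
+# lies between alpha[t+1] and alpha[t], the result is t plus the linear interpolation
+# weight, so the sampler can condition on non-integer diffusion steps.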
+
+def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
+    """
+    Embed a diffusion step $t$ into a higher dimensional space
+    E.g. the embedding vector in the 128-dimensional space is
+    [sin(t * 10^(0*4/63)), ... , sin(t * 10^(63*4/63)), cos(t * 10^(0*4/63)), ... , cos(t * 10^(63*4/63))]
+
+    Parameters:
+    diffusion_steps (torch.long tensor, shape=(batchsize, 1)):
+                                diffusion steps for batch data
+    diffusion_step_embed_dim_in (int, default=128):
+                                dimensionality of the embedding space for discrete diffusion steps
+
+    Returns:
+    the embedding vectors (torch.tensor, shape=(batchsize, diffusion_step_embed_dim_in)):
+    """
+
+    assert diffusion_step_embed_dim_in % 2 == 0
+
+    half_dim = diffusion_step_embed_dim_in // 2
+    _embed = np.log(10000) / (half_dim - 1)
+    _embed = torch.exp(torch.arange(half_dim) * -_embed)
+    _embed = diffusion_steps * _embed
+    diffusion_step_embed = torch.cat((torch.sin(_embed),
+                                      torch.cos(_embed)), 1)
+
+    return diffusion_step_embed
\ No newline at end of file
diff --git a/modules/FastDiff/task/FastDiff.py b/modules/FastDiff/task/FastDiff.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8902b4309ff45b4c1b88707e45c43238f52b795
--- /dev/null
+++ b/modules/FastDiff/task/FastDiff.py
@@ -0,0 +1,133 @@
+import os
+
+import torch
+import utils
+from modules.FastDiff.module.FastDiff_model import FastDiff
+from tasks.vocoder.vocoder_base import VocoderBaseTask
+from utils import audio
+from utils.hparams import hparams
+from modules.FastDiff.module.util import theta_timestep_loss, compute_hyperparams_given_schedule, sampling_given_noise_schedule
+
+
+class FastDiffTask(VocoderBaseTask):
+    def __init__(self):
+        super(FastDiffTask, self).__init__()
+
+    def build_model(self):
+        self.model = FastDiff(audio_channels=hparams['audio_channels'],
+                 inner_channels=hparams['inner_channels'],
+                 cond_channels=hparams['cond_channels'],
+                 upsample_ratios=hparams['upsample_ratios'],
+                 lvc_layers_each_block=hparams['lvc_layers_each_block'],
+                 lvc_kernel_size=hparams['lvc_kernel_size'],
+                 kpnet_hidden_channels=hparams['kpnet_hidden_channels'],
+                 kpnet_conv_size=hparams['kpnet_conv_size'],
+                 dropout=hparams['dropout'],
+                 diffusion_step_embed_dim_in=hparams['diffusion_step_embed_dim_in'],
+                 diffusion_step_embed_dim_mid=hparams['diffusion_step_embed_dim_mid'],
+                 diffusion_step_embed_dim_out=hparams['diffusion_step_embed_dim_out'],
+                 use_weight_norm=hparams['use_weight_norm'])
+        utils.print_arch(self.model)
+
+        # Init hyperparameters by linear schedule
+        noise_schedule = torch.linspace(float(hparams["beta_0"]), float(hparams["beta_T"]), int(hparams["T"])).cuda()
+        diffusion_hyperparams = compute_hyperparams_given_schedule(noise_schedule)
+
+        # map diffusion hyperparameters to gpu
+        for key in diffusion_hyperparams:
+            if key in ["beta", "alpha", "sigma"]:
+                diffusion_hyperparams[key] = diffusion_hyperparams[key].cuda()
+        self.diffusion_hyperparams = diffusion_hyperparams
+
+        return self.model
+
+    def _training_step(self, sample, batch_idx, optimizer_idx):
+        mels = sample['mels']
+        y = sample['wavs']
+        X = (mels, y)
+        loss = theta_timestep_loss(self.model, X, self.diffusion_hyperparams)
+        return loss, {'loss': loss}
+
+
+    def validation_step(self, sample, batch_idx):
+        mels = sample['mels']
+        y = sample['wavs']
+        X = (mels, y)
+        loss = theta_timestep_loss(self.model, X, self.diffusion_hyperparams)
+        return loss, {'loss': loss}
+
+
+    def test_step(self, sample, batch_idx):
+        mels = sample['mels']
+        y = sample['wavs']
+        loss_output = {}
+
+        if hparams['noise_schedule'] != '':
+            noise_schedule = hparams['noise_schedule']
+            if isinstance(noise_schedule, list):
+                noise_schedule = torch.FloatTensor(noise_schedule).cuda()
+        else:
+            # Select Schedule
+            try:
+                reverse_step = int(hparams.get('N'))
+            except (TypeError, ValueError):
+                print('Please specify N (the number of reverse iterations) in the config file. Falling back to 4 iterations.')
+                reverse_step = 4
+            if reverse_step == 1000:
+                noise_schedule = torch.linspace(0.000001, 0.01, 1000).cuda()
+            elif reverse_step == 200:
+                noise_schedule = torch.linspace(0.0001, 0.02, 200).cuda()
+
+            # Below are schedules derived by the noise predictor.
+            # We will release the noise predictor training and noise scheduling code soon. Please stay tuned!
+            elif reverse_step == 8:
+                noise_schedule = [6.689325005027058e-07, 1.0033881153503899e-05, 0.00015496854030061513,
+                                 0.002387222135439515, 0.035597629845142365, 0.3681158423423767, 0.4735414385795593, 0.5]
+            elif reverse_step == 6:
+                noise_schedule = [1.7838445955931093e-06, 2.7984189728158526e-05, 0.00043231004383414984,
+                                  0.006634317338466644, 0.09357017278671265, 0.6000000238418579]
+            elif reverse_step == 4:
+                noise_schedule = [3.2176e-04, 2.5743e-03, 2.5376e-02, 7.0414e-01]
+            elif reverse_step == 3:
+                noise_schedule = [9.0000e-05, 9.0000e-03, 6.0000e-01]
+            else:
+                raise NotImplementedError
+
+        if isinstance(noise_schedule, list):
+            noise_schedule = torch.FloatTensor(noise_schedule).cuda()
+
+        audio_length = mels.shape[-1] * hparams["hop_size"]
+        # generate using DDPM reverse process
+
+        y_ = sampling_given_noise_schedule(
+            self.model, (1, 1, audio_length), self.diffusion_hyperparams, noise_schedule,
+            condition=mels, ddim=False, return_sequence=False)
+        gen_dir = os.path.join(hparams['work_dir'], f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}')
+        os.makedirs(gen_dir, exist_ok=True)
+
+        if len(y) == 0:
+            # Inference from mel
+            for idx, (wav_pred, item_name) in enumerate(zip(y_, sample["item_name"])):
+                wav_pred = wav_pred / wav_pred.abs().max()
+                audio.save_wav(wav_pred.view(-1).cpu().float().numpy(), f'{gen_dir}/{item_name}_pred.wav',
+                               hparams['audio_sample_rate'])
+        else:
+            for idx, (wav_pred, wav_gt, item_name) in enumerate(zip(y_, y, sample["item_name"])):
+                wav_gt = wav_gt / wav_gt.abs().max()
+                wav_pred = wav_pred / wav_pred.abs().max()
+                audio.save_wav(wav_gt.view(-1).cpu().float().numpy(), f'{gen_dir}/{item_name}_gt.wav', hparams['audio_sample_rate'])
+                audio.save_wav(wav_pred.view(-1).cpu().float().numpy(), f'{gen_dir}/{item_name}_pred.wav', hparams['audio_sample_rate'])
+        return loss_output
+        
+    def build_optimizer(self, model):
+        self.optimizer = optimizer = torch.optim.AdamW(
+            self.model.parameters(),
+            lr=float(hparams['lr']), weight_decay=float(hparams['weight_decay']))
+        return optimizer
+
+    def compute_rtf(self, sample, generation_time, sample_rate=22050):
+        """
+        Computes the real-time factor (RTF) for a given sample, i.e. generation time divided by audio duration.
+        """
+        total_length = sample.shape[-1]
+        return float(generation_time * sample_rate / total_length)
\ No newline at end of file
diff --git a/modules/ProDiff/config/base.yaml b/modules/ProDiff/config/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..696546e3302a55ee97f633136ead9c44d3e45a12
--- /dev/null
+++ b/modules/ProDiff/config/base.yaml
@@ -0,0 +1,67 @@
+base_config:
+  - egs/egs_bases/tts/fs2.yaml
+
+# diffusion model
+diff_decoder_type: 'wavenet'
+dilation_cycle_length: 1
+residual_layers: 20
+residual_channels: 256
+keep_bins: 80
+spec_min: [ ]
+spec_max: [ ]
+diff_loss_type: l1
+timesteps: 100
+max_beta: 0.06
+
+# train
+max_sentences: 48
+max_updates: 200000
+
+
+# FastDiff vocoder
+vocoder: FastDiff
+N: 4 # denoising steps
+vocoder_ckpt: checkpoints/FastDiff
+
+# eval
+use_gt_dur: true
+use_gt_f0: true
+gen_tgt_spk_id: -1
+num_sanity_val_steps: -1
+num_valid_plots: 10
+use_cond_disc: true
+save_gt: true
+num_test_samples: 20
+max_valid_sentences: 1
+text: the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing.
+
+
+# variation
+pitch_type: frame
+pitch_extractor: 'parselmouth'
+use_pitch_embed: true
+use_energy_embed: true
+mel_loss: "ssim:0.5|l1:0.5"
+
+
+# dataset
+preprocess_cls: egs.datasets.audio.lj.pre_align.LJPreAlign
+preprocess_args:
+  nsample_per_mfa_group: 1000
+  # text process
+  txt_processor: en
+  use_mfa: true
+  with_phsep: true
+  reset_phone_dict: true
+  reset_word_dict: true
+  add_eos_bos: true
+  # mfa
+  mfa_group_shuffle: false
+  mfa_offset: 0.02
+  # wav processors
+  wav_processors: [ ]
+  save_sil_mask: true
+  vad_max_silence_length: 12
+
+
+
diff --git a/modules/ProDiff/config/prodiff.yaml b/modules/ProDiff/config/prodiff.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..64ee8eb61c97f2453726750e58cba71f08fef532
--- /dev/null
+++ b/modules/ProDiff/config/prodiff.yaml
@@ -0,0 +1,16 @@
+base_config:
+  - ./base.yaml
+
+raw_data_dir: 'data/raw/LJSpeech'
+processed_data_dir: 'data/processed/LJSpeech'
+binary_data_dir: 'data/binary/LJSpeech'
+
+task_cls: modules.ProDiff.task.ProDiff_task.ProDiff_Task
+
+
+# diffusion
+timesteps: 4
+teacher_ckpt: checkpoints/ProDiff_Teacher/model_ckpt_steps_188000.ckpt
+diff_decoder_type: 'wavenet'
+schedule_type: 'vpsde'
+
diff --git a/modules/ProDiff/config/prodiff_teacher.yaml b/modules/ProDiff/config/prodiff_teacher.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1e2e38cfc4a2a8895783664203098a94878688bf
--- /dev/null
+++ b/modules/ProDiff/config/prodiff_teacher.yaml
@@ -0,0 +1,13 @@
+base_config:
+  - ./base.yaml
+
+raw_data_dir: 'data/raw/LJSpeech'
+processed_data_dir: 'data/processed/LJSpeech'
+binary_data_dir: 'data/binary/LJSpeech'
+
+task_cls: modules.ProDiff.task.ProDiff_teacher_task.ProDiff_teacher_Task
+
+# diffusion
+timesteps: 4
+timescale: 1
+schedule_type: 'vpsde'
diff --git a/modules/ProDiff/model/ProDiff.py b/modules/ProDiff/model/ProDiff.py
new file mode 100644
index 0000000000000000000000000000000000000000..4caba78873c81691e92014c099bf9a36d0ba076b
--- /dev/null
+++ b/modules/ProDiff/model/ProDiff.py
@@ -0,0 +1,210 @@
+import math
+import random
+from functools import partial
+from usr.diff.shallow_diffusion_tts import *
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from tqdm import tqdm
+from einops import rearrange
+
+from modules.fastspeech.fs2 import FastSpeech2
+from utils.hparams import hparams
+
+
+class GaussianDiffusion(nn.Module):
+    def __init__(self, phone_encoder, out_dims, denoise_fn, teacher_steps=4,
+                 timesteps=4, time_scale=1, loss_type='l1', betas=None, spec_min=None, spec_max=None):
+        super().__init__()
+        self.denoise_fn = denoise_fn
+        self.fs2 = FastSpeech2(phone_encoder, out_dims)
+        self.fs2.decoder = None
+        self.mel_bins = out_dims
+
+        if exists(betas):
+            betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas
+        else:
+            betas = get_noise_schedule_list(
+                schedule_mode=hparams['schedule_type'],
+                timesteps=teacher_steps + 1,
+                min_beta=0.1,
+                max_beta=40,
+                s=0.008,
+            )
+
+        alphas = 1. - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+        self.time_scale = time_scale
+        self.num_timesteps = int(timesteps)
+        self.loss_type = loss_type
+
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+
+        self.register_buffer('betas', to_torch(betas))      # beta
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) # alphacum_t
+        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) # alphacum_{t-1}
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+        # calculations for posterior q(x_{t-1} | x_t, x_0)
+        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
+        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+        self.register_buffer('posterior_variance', to_torch(posterior_variance))
+        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+        self.register_buffer('posterior_mean_coef1', to_torch(
+            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+        self.register_buffer('posterior_mean_coef2', to_torch(
+            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+
+        self.register_buffer('spec_min', torch.FloatTensor(spec_min)[None, None, :hparams['keep_bins']])
+        self.register_buffer('spec_max', torch.FloatTensor(spec_max)[None, None, :hparams['keep_bins']])
+
+    def q_mean_variance(self, x_start, t):
+        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
+        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+        return mean, variance, log_variance
+
+    def predict_start_from_noise(self, x_t, t, noise):
+        return (
+                extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+                extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+        )
+
+    def q_posterior(self, x_start, x_t, t):
+        posterior_mean = (
+                extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+                extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
+        )
+        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
+        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
+        return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+    def q_posterior_sample(self, x_start, x_t, t, repeat_noise=False):
+        b, *_, device = *x_start.shape, x_start.device
+        model_mean, _, model_log_variance = self.q_posterior(x_start=x_start, x_t=x_t, t=t)
+        noise = noise_like(x_start.shape, device, repeat_noise)
+        # no noise when t == 0
+        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_start.shape) - 1)))
+        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+    @torch.no_grad()
+    def p_sample(self, x_t, t, cond, spk_emb=None, clip_denoised=True, repeat_noise=False):
+        b, *_, device = *x_t.shape, x_t.device
+        x_0_pred = self.denoise_fn(x_t, t, cond)
+
+        return self.q_posterior_sample(x_start=x_0_pred, x_t=x_t, t=t)
+
+    def sample_q(self, x_0, ts, epsilon=None):
+        """
+        Sample from q(x_t | x_0) for a batch of x_0.
+        """
+        epsilon = default(epsilon, lambda: torch.randn_like(x_0))
+        alpha, sigma = self.get_schedule(x_0, ts)
+        return alpha * x_0 + sigma * epsilon
+
+    @torch.no_grad()
+    def p_sample_ddim(self, x_t, t, cond):
+        b, *_, device = *x_t.shape, x_t.device
+        x_0_pred = self.denoise_fn(x_t, t, cond)
+        alpha, sigma = self.get_schedule(x_t, t)
+        eps = (x_t - x_0_pred * alpha) / sigma
+        return self.sample_q(x_0_pred, t-self.time_scale, eps)
+
+
+    def q_sample(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+
+        return (
+                extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+                extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+        )
+
+    def get_schedule(self, x_t, t):
+        return extract(self.sqrt_alphas_cumprod, t, x_t.shape), extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
+
+    def diffuse_fn(self, x_start, t, noise=None):
+        x_start = self.norm_spec(x_start)
+        x_start = x_start.transpose(1, 2)[:, None, :, :]  # [B, 1, M, T]
+        zero_idx = t < 0 # for items where t is -1
+        t[zero_idx] = 0
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        out = self.q_sample(x_start=x_start, t=t, noise=noise)
+        out[zero_idx] = x_start[zero_idx] # set x_{-1} as the gt mel
+        return out
+
+    def forward(self, txt_tokens, teacher_fn=None, mel2ph=None, spk_embed=None,
+                ref_mels=None, f0=None, uv=None, energy=None, infer=False):
+        b, *_, device = *txt_tokens.shape, txt_tokens.device
+        ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
+                       skip_decoder=True, infer=infer)
+        cond = ret['decoder_inp'].transpose(1, 2)
+        if not infer:
+            with torch.no_grad():
+                t = self.time_scale * torch.randint(1, self.num_timesteps+1, (b,), device=device).long()  # multiples of time_scale in [time_scale, time_scale * num_timesteps]
+                nonpadding = (mel2ph != 0).float().unsqueeze(1).unsqueeze(1)
+                noise = torch.randn_like(ref_mels.transpose(1, 2)[:, None, :, :])
+
+                # Diffusion
+                x_t = self.diffuse_fn(ref_mels, t, noise) * nonpadding
+
+                # Teacher: two DDIM steps from t down to t - time_scale
+                x0_pred = teacher_fn.denoise_fn(x_t, t, cond) * nonpadding  # teacher estimate of p(x_0 | x_t, t)
+                alpha, sigma = self.get_schedule(x_t, t)
+                alpha_pre, sigma_pre = self.get_schedule(x_t, t - self.time_scale // 2)
+                alpha_pre_pre, sigma_pre_pre = self.get_schedule(x_t, t - self.time_scale)
+                x_t_pre = alpha_pre * x0_pred + sigma_pre / sigma * (x_t - alpha * x0_pred)  # DDIM step to t - time_scale // 2
+                x0_pred1 = teacher_fn.denoise_fn(x_t_pre, t - self.time_scale // 2, cond) * nonpadding  # second teacher estimate of x_0
+                x_t_pre_pre = alpha_pre_pre * x0_pred1 + sigma_pre_pre / sigma_pre * (
+                            x_t_pre - alpha_pre * x0_pred1)  # DDIM step to t - time_scale
+                # distillation target: the x_0 a single student step must predict to land on x_t_pre_pre
+                x_target = (x_t_pre_pre - (sigma_pre_pre / sigma) * x_t) / (alpha_pre_pre - sigma_pre_pre / sigma * alpha) * nonpadding
+
+            x_pred = self.denoise_fn(x_t, t - self.time_scale, cond) * nonpadding  # student estimate of x_0 (single step)
+            x_t_prev = self.diffuse_fn(ref_mels, t - self.time_scale - 1, noise) * nonpadding  # ground-truth state at step t - time_scale - 1 (the clean mel when the index is -1)
+            x_t_prev_pred = self.q_posterior_sample(x_pred, x_t, t - self.time_scale) * nonpadding  # sample from q(x_{t-1} | x_t, x_0_pred)
+
+            if self.loss_type == 'l1':
+                if nonpadding is not None:  # nonpadding: [B, 1, 1, T]
+                    loss = ((x_pred - x_target).abs() * nonpadding).mean()  # mean over [B, 1, M, T]
+                else:
+                    # print('are you sure w/o nonpadding?')
+                    loss = (x_pred - x_target).abs().mean()
+
+            elif self.loss_type == 'l2':
+                loss = F.mse_loss(x_pred, x_target)
+            else:
+                raise NotImplementedError()
+
+            ret['mel_out'] = loss  # scalar distillation loss (stored under the 'mel_out' key)
+            ret['x_t'] = x_t[:, 0].transpose(1, 2)
+            ret['x_t_prev'] = x_t_prev[:, 0].transpose(1, 2)
+            ret['x_t_prev_pred'] = x_t_prev_pred[:, 0].transpose(1, 2)
+            ret['t'] = t
+        else:
+            shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
+            x = torch.randn(shape, device=device)  # noise
+            sample_steps = [self.time_scale * i for i in range(0, self.num_timesteps)]
+            for i in tqdm(reversed(sample_steps), desc='ProDiff sample time step', total=len(sample_steps)):
+                x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)  # x(mel), t, condition(phoneme)
+            x = x[:, 0].transpose(1, 2)
+            ret['mel_out'] = self.denorm_spec(x)  # undo spec normalization
+        return ret
+
+
+    def norm_spec(self, x):
+        return x
+
+    def denorm_spec(self, x):
+        return x
+
+    def out2mel(self, x):
+        return x
\ No newline at end of file
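Editor's note on the training branch above: x_target is exactly the x_0 a single student step must predict so that its one DDIM update lands where two consecutive teacher DDIM steps would land. A minimal standalone sanity check (not part of the diff; the alpha/sigma values are made-up stand-ins for sqrt_alphas_cumprod / sqrt_one_minus_alphas_cumprod, and the teacher is assumed to predict x_0 exactly, in which case the target collapses back to x_0):

import torch

def distill_target(x_t, x0_teacher, alpha, sigma, alpha_pre, sigma_pre, alpha_pp, sigma_pp):
    # two teacher DDIM steps (t -> t - time_scale // 2 -> t - time_scale),
    # then solve for the x_0 a single student step would need to predict;
    # mirrors the x_target computation in GaussianDiffusion.forward above
    x_t_pre = alpha_pre * x0_teacher + sigma_pre / sigma * (x_t - alpha * x0_teacher)
    x_t_pp = alpha_pp * x0_teacher + sigma_pp / sigma_pre * (x_t_pre - alpha_pre * x0_teacher)
    return (x_t_pp - (sigma_pp / sigma) * x_t) / (alpha_pp - sigma_pp / sigma * alpha)

x0, eps = torch.randn(2, 1, 80, 100), torch.randn(2, 1, 80, 100)
alpha, sigma = 0.6, (1 - 0.6 ** 2) ** 0.5          # schedule at t (made-up values)
alpha_pre, sigma_pre = 0.8, (1 - 0.8 ** 2) ** 0.5  # at t - time_scale // 2
alpha_pp, sigma_pp = 0.9, (1 - 0.9 ** 2) ** 0.5    # at t - time_scale
x_t = alpha * x0 + sigma * eps
assert torch.allclose(distill_target(x_t, x0, alpha, sigma, alpha_pre, sigma_pre, alpha_pp, sigma_pp), x0, atol=1e-4)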
diff --git a/modules/ProDiff/model/ProDiff_teacher.py b/modules/ProDiff/model/ProDiff_teacher.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa93b0976c672eef9771999962e8204e7668e9db
--- /dev/null
+++ b/modules/ProDiff/model/ProDiff_teacher.py
@@ -0,0 +1,190 @@
+import math
+import random
+from functools import partial
+from usr.diff.shallow_diffusion_tts import *
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from tqdm import tqdm
+from einops import rearrange
+
+from modules.fastspeech.fs2 import FastSpeech2
+from utils.hparams import hparams
+
+
+
+class GaussianDiffusion(nn.Module):
+    def __init__(self, phone_encoder, out_dims, denoise_fn,
+                 timesteps=1000, time_scale=1, loss_type='l1', betas=None, spec_min=None, spec_max=None):
+        super().__init__()
+        self.denoise_fn = denoise_fn
+        self.fs2 = FastSpeech2(phone_encoder, out_dims)
+        self.fs2.decoder = None
+        self.mel_bins = out_dims
+
+        if exists(betas):
+            betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas
+        else:
+            betas = get_noise_schedule_list(
+                schedule_mode=hparams['schedule_type'],
+                timesteps=timesteps + 1,
+                min_beta=0.1,
+                max_beta=40,
+                s=0.008,
+            )
+
+        alphas = 1. - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+        self.time_scale = time_scale
+        self.num_timesteps = int(timesteps)
+        self.loss_type = loss_type
+
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+
+        self.register_buffer('timesteps', to_torch(self.num_timesteps))    # saved in the checkpoint so student tasks can read the schedule length
+        self.register_buffer('timescale', to_torch(self.time_scale))       # saved in the checkpoint so student tasks can read the time scale
+        self.register_buffer('betas', to_torch(betas))      # beta
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) # alphacum_t
+        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) # alphacum_{t-1}
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+        # calculations for posterior q(x_{t-1} | x_t, x_0)
+        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
+        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+        self.register_buffer('posterior_variance', to_torch(posterior_variance))
+        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+        self.register_buffer('posterior_mean_coef1', to_torch(
+            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+        self.register_buffer('posterior_mean_coef2', to_torch(
+            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+
+        self.register_buffer('spec_min', torch.FloatTensor(spec_min)[None, None, :hparams['keep_bins']])
+        self.register_buffer('spec_max', torch.FloatTensor(spec_max)[None, None, :hparams['keep_bins']])
+
+    def q_mean_variance(self, x_start, t):
+        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
+        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+        return mean, variance, log_variance
+
+    def predict_start_from_noise(self, x_t, t, noise):
+        return (
+                extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+                extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+        )
+
+    def q_posterior(self, x_start, x_t, t):
+        posterior_mean = (
+                extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+                extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
+        )
+        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
+        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
+        return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+    def q_posterior_sample(self, x_start, x_t, t, repeat_noise=False):
+        b, *_, device = *x_start.shape, x_start.device
+        model_mean, _, model_log_variance = self.q_posterior(x_start=x_start, x_t=x_t, t=t)
+        noise = noise_like(x_start.shape, device, repeat_noise)
+        # no noise when t == 0
+        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_start.shape) - 1)))
+        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+    @torch.no_grad()
+    def p_sample(self, x_t, t, cond, spk_emb=None, clip_denoised=True, repeat_noise=False):
+        b, *_, device = *x_t.shape, x_t.device
+        x_0_pred = self.denoise_fn(x_t, t, cond)
+
+        return self.q_posterior_sample(x_start=x_0_pred, x_t=x_t, t=t)
+
+    @torch.no_grad()
+    def interpolate(self, x1, x2, t, cond, spk_emb, lam=0.5):
+        b, *_, device = *x1.shape, x1.device
+        t = default(t, self.num_timesteps - 1)
+
+        assert x1.shape == x2.shape
+
+        t_batched = torch.stack([torch.tensor(t, device=device)] * b)
+        xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))
+
+        x = (1 - lam) * xt1 + lam * xt2
+        for i in tqdm(reversed(range(0, t)), desc="interpolation sample time step", total=t):
+            x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond, spk_emb)
+        x = x[:, 0].transpose(1, 2)
+        return self.denorm_spec(x)
+
+    def q_sample(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+
+        return (
+                extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+                extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+        )
+
+    def diffuse_trace(self, x_start, mask):
+        b, *_, device = *x_start.shape, x_start.device
+        trace = [self.norm_spec(x_start).clamp(-1., 1.) * ~mask.unsqueeze(-1)]  # out-of-place clamp so the ground-truth mel is not modified
+        for t in range(self.num_timesteps):
+            t = torch.full((b,), t, device=device, dtype=torch.long)
+            trace.append(
+                self.diffuse_fn(x_start, t)[:, 0].transpose(1, 2) * ~mask.unsqueeze(-1)
+            )
+        return trace
+
+    def diffuse_fn(self, x_start, t, noise=None):
+        x_start = self.norm_spec(x_start)
+        x_start = x_start.transpose(1, 2)[:, None, :, :]  # [B, 1, M, T]
+        zero_idx = t < 0 # for items where t is -1
+        t[zero_idx] = 0
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        out = self.q_sample(x_start=x_start, t=t, noise=noise)
+        out[zero_idx] = x_start[zero_idx] # set x_{-1} as the gt mel
+        return out
+
+    def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
+                ref_mels=None, f0=None, uv=None, energy=None, infer=False):
+        b, *_, device = *txt_tokens.shape, txt_tokens.device
+        ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
+                       skip_decoder=True, infer=infer)
+        nonpadding = (ret['mel2ph'] != 0).float().unsqueeze(1).unsqueeze(1)  # [B, 1, 1, T]
+        cond = ret['decoder_inp'].transpose(1, 2)
+        if not infer:
+            t = torch.randint(0, self.num_timesteps + 1, (b,), device=device).long()
+            # Diffusion
+            x_t = self.diffuse_fn(ref_mels, t) * nonpadding
+
+            # Predict x_{start}
+            x_0_pred = self.denoise_fn(x_t, t, cond) * nonpadding
+
+            ret['mel_out'] = x_0_pred[:, 0].transpose(1, 2) # [B, T, 80]
+        else:
+            t = self.num_timesteps  # total number of reverse steps
+            shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
+            x = torch.randn(shape, device=device)  # noise
+            for i in tqdm(reversed(range(0, t)), desc='ProDiff sample time step', total=t):
+                x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)  # x(mel), t, condition(phoneme)
+            x = x[:, 0].transpose(1, 2)
+            ret['mel_out'] = self.denorm_spec(x)  # undo spec normalization
+        return ret
+
+    def norm_spec(self, x):
+        return x
+
+    def denorm_spec(self, x):
+        return x
+
+    def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
+        return self.fs2.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
+
+    def out2mel(self, x):
+        return x
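As a quick reference for the teacher formulation above: q_sample draws x_t = sqrt(alphas_cumprod_t) * x_0 + sqrt(1 - alphas_cumprod_t) * eps, and predict_start_from_noise is its exact inverse. A small numpy check with a made-up 4-step linear beta schedule (editorial sketch, not part of the diff; the real schedule comes from get_noise_schedule_list):

import numpy as np

betas = np.linspace(1e-4, 0.06, 4)   # toy schedule, placeholder values
acp = np.cumprod(1. - betas)         # alphas_cumprod
x0, eps, t = np.random.randn(80), np.random.randn(80), 2
x_t = np.sqrt(acp[t]) * x0 + np.sqrt(1. - acp[t]) * eps                 # q_sample
x0_rec = np.sqrt(1. / acp[t]) * x_t - np.sqrt(1. / acp[t] - 1.) * eps   # predict_start_from_noise
assert np.allclose(x0_rec, x0)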
diff --git a/modules/ProDiff/task/ProDiff_task.py b/modules/ProDiff/task/ProDiff_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..795752e4414a8d28dd12ecf020465f9f299b6d0f
--- /dev/null
+++ b/modules/ProDiff/task/ProDiff_task.py
@@ -0,0 +1,137 @@
+import torch
+from torch import nn
+import utils
+from functools import partial
+from utils.hparams import hparams
+from modules.ProDiff.model.ProDiff import GaussianDiffusion
+from usr.diff.net import DiffNet
+from tasks.tts.fs2 import FastSpeech2Task
+from vocoders.base_vocoder import get_vocoder_cls, BaseVocoder
+from utils.pitch_utils import denorm_f0
+from tasks.tts.fs2_utils import FastSpeechDataset
+DIFF_DECODERS = {
+    'wavenet': lambda hp: DiffNet(hp['audio_num_mel_bins']),
+}
+
+
+class ProDiff_Task(FastSpeech2Task):
+    def __init__(self):
+        super(ProDiff_Task, self).__init__()
+        self.dataset_cls = FastSpeechDataset
+        self.vocoder: BaseVocoder = get_vocoder_cls(hparams)()
+
+    def build_model(self):
+        self.build_tts_model()
+        if hparams['load_ckpt'] != '':
+            self.load_ckpt(hparams['load_ckpt'], strict=False)
+        utils.num_params(self.model, print_out=True, model_name="Generator: student")
+        utils.num_params(self.teacher, print_out=True, model_name="Generator: teacher")
+        if not hasattr(self, 'gen_params'):
+            self.gen_params = list(self.model.parameters())
+        return self.model
+
+    def build_tts_model(self):
+        mel_bins = hparams['audio_num_mel_bins']
+        checkpoint = torch.load(hparams['teacher_ckpt'], map_location='cpu')["state_dict"]['model']
+        teacher_timesteps = int(checkpoint['timesteps'].item())
+        teacher_timescales = int(checkpoint['timescale'].item())
+        student_timesteps = teacher_timesteps // 2
+        student_timescales = teacher_timescales * 2
+
+        self.teacher = GaussianDiffusion(
+            phone_encoder=self.phone_encoder,
+            out_dims=mel_bins, denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
+            loss_type=hparams['diff_loss_type'],
+            timesteps=teacher_timesteps, time_scale=teacher_timescales,
+            spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
+        )
+        self.model = GaussianDiffusion(
+            phone_encoder=self.phone_encoder,
+            out_dims=mel_bins, denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
+            timesteps=student_timesteps, time_scale=student_timescales,
+            loss_type=hparams['diff_loss_type'],
+            spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
+        )
+
+        utils.load_ckpt(self.teacher, hparams['teacher_ckpt'], 'model', strict=False)
+        utils.load_ckpt(self.model, hparams['teacher_ckpt'], 'model', strict=False)
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+        self.model.num_timesteps = student_timesteps
+        self.model.time_scale = student_timescales
+        self.model.register_buffer('timesteps', to_torch(student_timesteps))    # overwrite the checkpoint value with the halved student schedule
+        self.model.register_buffer('timescale', to_torch(student_timescales))   # and with the doubled time scale
+
+        # freeze the FastSpeech2 encoder/variance modules; only the student denoiser is trained
+        for k, v in self.model.fs2.named_parameters():
+            if 'denoise_fn' not in k:
+                v.requires_grad = False
+
+        for param in self.teacher.parameters():
+            param.requires_grad = False
+
+
+    def run_model(self, model, sample, return_output=False, infer=False):
+        txt_tokens = sample['txt_tokens']  # [B, T_t]
+        target = sample['mels']  # [B, T_s, 80]
+        # mel2ph = sample['mel2ph'] if hparams['use_gt_dur'] else None # [B, T_s]
+        mel2ph = sample['mel2ph']
+        f0 = sample['f0']
+        uv = sample['uv']
+        energy = sample['energy']
+        spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids')
+        output = model(txt_tokens, self.teacher, mel2ph=mel2ph, spk_embed=spk_embed,
+                       ref_mels=target, f0=f0, uv=uv, energy=energy, infer=infer)
+
+        losses = {}
+        losses['l1'] = output['mel_out']
+        self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses)
+        if hparams['use_pitch_embed']:
+            self.add_pitch_loss(output, sample, losses)
+        if hparams['use_energy_embed']:
+            self.add_energy_loss(output['energy_pred'], energy, losses)
+        if not return_output:
+            return losses
+        else:
+            return losses, output
+
+    def validation_step(self, sample, batch_idx):
+        outputs = {}
+        txt_tokens = sample['txt_tokens']  # [B, T_t]
+
+        energy = sample['energy']
+        spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids')
+        mel2ph = sample['mel2ph']
+        f0 = sample['f0']
+        uv = sample['uv']
+
+        outputs['losses'] = {}
+        outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True, infer=False)
+
+        outputs['total_loss'] = sum(outputs['losses'].values())
+        outputs['nsamples'] = sample['nsamples']
+        outputs = utils.tensors_to_scalars(outputs)
+        if batch_idx < hparams['num_valid_plots']:
+            # model_out = self.model(
+            #     txt_tokens, spk_embed=spk_embed, mel2ph=None, f0=None, uv=None, energy=None, ref_mels=None, inference=True)
+            # self.plot_mel(batch_idx, model_out['mel_out'], model_out['fs2_mel'], name=f'diffspeech_vs_fs2_{batch_idx}')
+            model_out = self.model(
+                txt_tokens, spk_embed=spk_embed, mel2ph=mel2ph, f0=f0, uv=uv, energy=energy, ref_mels=None, infer=True)
+            gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams)
+            self.plot_wav(batch_idx, sample['mels'], model_out['mel_out'], is_mel=True, gt_f0=gt_f0, f0=model_out.get('f0_denorm'))
+            self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'])
+        return outputs
+
+
+    ############
+    # validation plots
+    ############
+    def plot_wav(self, batch_idx, gt_wav, wav_out, is_mel=False, gt_f0=None, f0=None, name=None):
+        gt_wav = gt_wav[0].cpu().numpy()
+        wav_out = wav_out[0].cpu().numpy()
+        gt_f0 = gt_f0[0].cpu().numpy()
+        f0 = f0[0].cpu().numpy()
+        if is_mel:
+            gt_wav = self.vocoder.spec2wav(gt_wav, f0=gt_f0)
+            wav_out = self.vocoder.spec2wav(wav_out, f0=f0)
+        self.logger.add_audio(f'gt_{batch_idx}', gt_wav, sample_rate=hparams['audio_sample_rate'], global_step=self.global_step)
+        self.logger.add_audio(f'wav_{batch_idx}', wav_out, sample_rate=hparams['audio_sample_rate'], global_step=self.global_step)
+
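build_tts_model above reads the teacher's timesteps/timescale buffers from the checkpoint and halves/doubles them for the student, so each distillation round halves the number of sampling steps while keeping steps * scale constant. A small sketch, assuming a hypothetical teacher trained with timesteps=8, time_scale=1:

teacher = (8, 1)  # (num_timesteps, time_scale) as stored in the 'timesteps'/'timescale' buffers
schedule = [teacher]
while schedule[-1][0] > 1:
    steps, scale = schedule[-1]
    schedule.append((steps // 2, scale * 2))  # student_timesteps, student_timescales
print(schedule)  # [(8, 1), (4, 2), (2, 4), (1, 8)]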
diff --git a/modules/ProDiff/task/ProDiff_teacher_task.py b/modules/ProDiff/task/ProDiff_teacher_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e4f89645ce09c0486c18429b0571b5bf6605b79
--- /dev/null
+++ b/modules/ProDiff/task/ProDiff_teacher_task.py
@@ -0,0 +1,101 @@
+import torch
+
+import utils
+from utils.hparams import hparams
+from modules.ProDiff.model.ProDiff_teacher import GaussianDiffusion
+from usr.diff.net import DiffNet
+from tasks.tts.fs2 import FastSpeech2Task
+from vocoders.base_vocoder import get_vocoder_cls, BaseVocoder
+from utils.pitch_utils import denorm_f0
+from tasks.tts.fs2_utils import FastSpeechDataset
+
+DIFF_DECODERS = {
+    'wavenet': lambda hp: DiffNet(hp['audio_num_mel_bins']),
+}
+
+
+class ProDiff_teacher_Task(FastSpeech2Task):
+    def __init__(self):
+        super(ProDiff_teacher_Task, self).__init__()
+        self.dataset_cls = FastSpeechDataset
+        self.vocoder: BaseVocoder = get_vocoder_cls(hparams)()
+
+    def build_model(self):
+        self.build_tts_model()
+        utils.num_params(self.model)
+        return self.model
+
+    def build_tts_model(self):
+        self.model = GaussianDiffusion(
+            phone_encoder=self.phone_encoder,
+            out_dims=hparams['audio_num_mel_bins'], denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
+            timesteps=hparams['timesteps'], time_scale=hparams['timescale'],
+            loss_type=hparams['diff_loss_type'],
+            spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
+        )
+
+
+    def run_model(self, model, sample, return_output=False, infer=False):
+        txt_tokens = sample['txt_tokens']  # [B, T_t]
+        target = sample['mels']  # [B, T_s, 80]
+        mel2ph = sample['mel2ph']
+        f0 = sample['f0']
+        uv = sample['uv']
+        energy = sample['energy']
+        spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids')
+        output = model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed,
+                       ref_mels=target, f0=f0, uv=uv, energy=energy, infer=infer)
+
+        losses = {}
+        self.add_mel_loss(output['mel_out'], target, losses)
+        self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses)
+        if hparams['use_pitch_embed']:
+            self.add_pitch_loss(output, sample, losses)
+        if hparams['use_energy_embed']:
+            self.add_energy_loss(output['energy_pred'], energy, losses)
+        if not return_output:
+            return losses
+        else:
+            return losses, output
+
+    def validation_step(self, sample, batch_idx):
+        outputs = {}
+        txt_tokens = sample['txt_tokens']  # [B, T_t]
+
+        energy = sample['energy']
+        spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids')
+        mel2ph = sample['mel2ph']
+        f0 = sample['f0']
+        uv = sample['uv']
+
+        outputs['losses'] = {}
+        outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True, infer=False)
+
+        outputs['total_loss'] = sum(outputs['losses'].values())
+        outputs['nsamples'] = sample['nsamples']
+        outputs = utils.tensors_to_scalars(outputs)
+        if batch_idx < hparams['num_valid_plots']:
+            # model_out = self.model(
+            #     txt_tokens, spk_embed=spk_embed, mel2ph=None, f0=None, uv=None, energy=None, ref_mels=None, inference=True)
+            # self.plot_mel(batch_idx, model_out['mel_out'], model_out['fs2_mel'], name=f'diffspeech_vs_fs2_{batch_idx}')
+            model_out = self.model(
+                txt_tokens, spk_embed=spk_embed, mel2ph=mel2ph, f0=f0, uv=uv, energy=energy, ref_mels=None, infer=True)
+            gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams)
+            self.plot_wav(batch_idx, sample['mels'], model_out['mel_out'], is_mel=True, gt_f0=gt_f0, f0=model_out.get('f0_denorm'))
+            self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'])
+        return outputs
+
+    ############
+    # validation plots
+    ############
+    def plot_wav(self, batch_idx, gt_wav, wav_out, is_mel=False, gt_f0=None, f0=None, name=None):
+        gt_wav = gt_wav[0].cpu().numpy()
+        wav_out = wav_out[0].cpu().numpy()
+        gt_f0 = gt_f0[0].cpu().numpy()
+        f0 = f0[0].cpu().numpy()
+        if is_mel:
+            gt_wav = self.vocoder.spec2wav(gt_wav, f0=gt_f0)
+            wav_out = self.vocoder.spec2wav(wav_out, f0=f0)
+        self.logger.add_audio(f'gt_{batch_idx}', gt_wav, sample_rate=hparams['audio_sample_rate'], global_step=self.global_step)
+        self.logger.add_audio(f'wav_{batch_idx}', wav_out, sample_rate=hparams['audio_sample_rate'], global_step=self.global_step)
+
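For orientation, these are the hparams keys the teacher task consumes in build_tts_model, run_model and the validation plots; the values below are hypothetical placeholders (the real values come from the experiment config YAMLs), shown only to make the wiring explicit:

example_hparams = {
    'audio_num_mel_bins': 80,        # out_dims / DiffNet width
    'diff_decoder_type': 'wavenet',  # selects DiffNet via DIFF_DECODERS
    'timesteps': 4,                  # number of teacher diffusion steps (placeholder)
    'timescale': 1,
    'diff_loss_type': 'l1',
    'spec_min': [-6.0] * 80,         # placeholder per-bin normalization bounds
    'spec_max': [1.5] * 80,
    'use_pitch_embed': True,
    'use_energy_embed': False,
    'use_spk_id': False,
    'num_valid_plots': 5,
    'audio_sample_rate': 22050,      # placeholder, used when logging audio
}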
diff --git a/modules/__init__.py b/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/commons/common_layers.py b/modules/commons/common_layers.py
new file mode 100755
index 0000000000000000000000000000000000000000..fe8c664acd66ddc737ccc38b56d8bb077d636bf2
--- /dev/null
+++ b/modules/commons/common_layers.py
@@ -0,0 +1,971 @@
+import math
+import torch
+from torch import nn
+from torch.nn import Parameter
+import torch.onnx.operators
+import torch.nn.functional as F
+from utils.tts_utils import make_positions, softmax, get_incremental_state, set_incremental_state
+
+
+class Reshape(nn.Module):
+    def __init__(self, *args):
+        super(Reshape, self).__init__()
+        self.shape = args
+
+    def forward(self, x):
+        return x.view(self.shape)
+
+
+class Permute(nn.Module):
+    def __init__(self, *args):
+        super(Permute, self).__init__()
+        self.args = args
+
+    def forward(self, x):
+        return x.permute(self.args)
+
+
+class LinearNorm(torch.nn.Module):
+    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
+        super(LinearNorm, self).__init__()
+        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
+
+        torch.nn.init.xavier_uniform_(
+            self.linear_layer.weight,
+            gain=torch.nn.init.calculate_gain(w_init_gain))
+
+    def forward(self, x):
+        return self.linear_layer(x)
+
+
+class ConvNorm(torch.nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
+                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
+        super(ConvNorm, self).__init__()
+        if padding is None:
+            assert (kernel_size % 2 == 1)
+            padding = int(dilation * (kernel_size - 1) / 2)
+
+        self.conv = torch.nn.Conv1d(in_channels, out_channels,
+                                    kernel_size=kernel_size, stride=stride,
+                                    padding=padding, dilation=dilation,
+                                    bias=bias)
+
+        torch.nn.init.xavier_uniform_(
+            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
+
+    def forward(self, signal):
+        conv_signal = self.conv(signal)
+        return conv_signal
+
+
+def Embedding(num_embeddings, embedding_dim, padding_idx=None):
+    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
+    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
+    if padding_idx is not None:
+        nn.init.constant_(m.weight[padding_idx], 0)
+    return m
+
+
+class GroupNorm1DTBC(nn.GroupNorm):
+    def forward(self, input):
+        return super(GroupNorm1DTBC, self).forward(input.permute(1, 2, 0)).permute(2, 0, 1)
+
+
+def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
+    if not export and torch.cuda.is_available():
+        try:
+            from apex.normalization import FusedLayerNorm
+            return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
+        except ImportError:
+            pass
+    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
+
+
+def Linear(in_features, out_features, bias=True):
+    m = nn.Linear(in_features, out_features, bias)
+    nn.init.xavier_uniform_(m.weight)
+    if bias:
+        nn.init.constant_(m.bias, 0.)
+    return m
+
+
+class SinusoidalPositionalEmbedding(nn.Module):
+    """This module produces sinusoidal positional embeddings of any length.
+
+    Padding symbols are ignored.
+    """
+
+    def __init__(self, embedding_dim, padding_idx, init_size=1024):
+        super().__init__()
+        self.embedding_dim = embedding_dim
+        self.padding_idx = padding_idx
+        self.weights = SinusoidalPositionalEmbedding.get_embedding(
+            init_size,
+            embedding_dim,
+            padding_idx,
+        )
+        self.register_buffer('_float_tensor', torch.FloatTensor(1))
+
+    @staticmethod
+    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
+        """Build sinusoidal embeddings.
+
+        This matches the implementation in tensor2tensor, but differs slightly
+        from the description in Section 3.5 of "Attention Is All You Need".
+        """
+        half_dim = embedding_dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
+        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
+        if embedding_dim % 2 == 1:
+            # zero pad
+            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+        if padding_idx is not None:
+            emb[padding_idx, :] = 0
+        return emb
+
+    def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
+        """Input is expected to be of size [bsz x seqlen]."""
+        bsz, seq_len = input.shape[:2]
+        max_pos = self.padding_idx + 1 + seq_len
+        if self.weights is None or max_pos > self.weights.size(0):
+            # recompute/expand embeddings if needed
+            self.weights = SinusoidalPositionalEmbedding.get_embedding(
+                max_pos,
+                self.embedding_dim,
+                self.padding_idx,
+            )
+        self.weights = self.weights.to(self._float_tensor)
+
+        if incremental_state is not None:
+            # positions is the same for every token when decoding a single step
+            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
+            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
+
+        positions = make_positions(input, self.padding_idx) if positions is None else positions
+        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
+
+    def max_positions(self):
+        """Maximum number of supported positions."""
+        return int(1e5)  # an arbitrary large number
+
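A short usage sketch for SinusoidalPositionalEmbedding (editorial, assuming the fairseq-style make_positions convention where positions start at padding_idx + 1 and padded entries stay at padding_idx, whose embedding row is zeroed in get_embedding):

import torch
from modules.commons.common_layers import SinusoidalPositionalEmbedding

emb = SinusoidalPositionalEmbedding(embedding_dim=256, padding_idx=0, init_size=1024)
tokens = torch.tensor([[5, 7, 9, 0, 0],   # 0 is padding
                       [3, 4, 0, 0, 0]])
pos = emb(tokens)                         # [2, 5, 256]
print(pos.shape, pos[0, 3].abs().sum())   # padded positions map to the zeroed padding row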
+
+class ConvTBC(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
+        super(ConvTBC, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.padding = padding
+
+        self.weight = torch.nn.Parameter(torch.Tensor(
+            self.kernel_size, in_channels, out_channels))
+        self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
+
+    def forward(self, input):
+        return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding)
+
+
+class MultiheadAttention(nn.Module):
+    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
+                 add_bias_kv=False, add_zero_attn=False, self_attention=False,
+                 encoder_decoder_attention=False):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.kdim = kdim if kdim is not None else embed_dim
+        self.vdim = vdim if vdim is not None else embed_dim
+        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+        self.scaling = self.head_dim ** -0.5
+
+        self.self_attention = self_attention
+        self.encoder_decoder_attention = encoder_decoder_attention
+
+        assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
+                                                             'value to be of the same size'
+
+        if self.qkv_same_dim:
+            self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
+        else:
+            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
+            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
+            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
+
+        if bias:
+            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
+        else:
+            self.register_parameter('in_proj_bias', None)
+
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+        if add_bias_kv:
+            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
+            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
+        else:
+            self.bias_k = self.bias_v = None
+
+        self.add_zero_attn = add_zero_attn
+
+        self.reset_parameters()
+
+        self.enable_torch_version = hasattr(F, "multi_head_attention_forward")
+        self.last_attn_probs = None
+
+    def reset_parameters(self):
+        if self.qkv_same_dim:
+            nn.init.xavier_uniform_(self.in_proj_weight)
+        else:
+            nn.init.xavier_uniform_(self.k_proj_weight)
+            nn.init.xavier_uniform_(self.v_proj_weight)
+            nn.init.xavier_uniform_(self.q_proj_weight)
+
+        nn.init.xavier_uniform_(self.out_proj.weight)
+        if self.in_proj_bias is not None:
+            nn.init.constant_(self.in_proj_bias, 0.)
+            nn.init.constant_(self.out_proj.bias, 0.)
+        if self.bias_k is not None:
+            nn.init.xavier_normal_(self.bias_k)
+        if self.bias_v is not None:
+            nn.init.xavier_normal_(self.bias_v)
+
+    def forward(
+            self,
+            query, key, value,
+            key_padding_mask=None,
+            incremental_state=None,
+            need_weights=True,
+            static_kv=False,
+            attn_mask=None,
+            before_softmax=False,
+            need_head_weights=False,
+            enc_dec_attn_constraint_mask=None,
+            reset_attn_weight=None
+    ):
+        """Input shape: Time x Batch x Channel
+
+        Args:
+            key_padding_mask (ByteTensor, optional): mask to exclude
+                keys that are pads, of shape `(batch, src_len)`, where
+                padding elements are indicated by 1s.
+            need_weights (bool, optional): return the attention weights,
+                averaged over heads (default: False).
+            attn_mask (ByteTensor, optional): typically used to
+                implement causal attention, where the mask prevents the
+                attention from looking forward in time (default: None).
+            before_softmax (bool, optional): return the raw attention
+                weights and values before the attention softmax.
+            need_head_weights (bool, optional): return the attention
+                weights for each head. Implies *need_weights*. Default:
+                return the average attention weights over all heads.
+        """
+        if need_head_weights:
+            need_weights = True
+
+        tgt_len, bsz, embed_dim = query.size()
+        assert embed_dim == self.embed_dim
+        assert list(query.size()) == [tgt_len, bsz, embed_dim]
+        if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
+            if self.qkv_same_dim:
+                return F.multi_head_attention_forward(query, key, value,
+                                                      self.embed_dim, self.num_heads,
+                                                      self.in_proj_weight,
+                                                      self.in_proj_bias, self.bias_k, self.bias_v,
+                                                      self.add_zero_attn, self.dropout,
+                                                      self.out_proj.weight, self.out_proj.bias,
+                                                      self.training, key_padding_mask, need_weights,
+                                                      attn_mask)
+            else:
+                return F.multi_head_attention_forward(query, key, value,
+                                                      self.embed_dim, self.num_heads,
+                                                      torch.empty([0]),
+                                                      self.in_proj_bias, self.bias_k, self.bias_v,
+                                                      self.add_zero_attn, self.dropout,
+                                                      self.out_proj.weight, self.out_proj.bias,
+                                                      self.training, key_padding_mask, need_weights,
+                                                      attn_mask, use_separate_proj_weight=True,
+                                                      q_proj_weight=self.q_proj_weight,
+                                                      k_proj_weight=self.k_proj_weight,
+                                                      v_proj_weight=self.v_proj_weight)
+
+        if incremental_state is not None:
+            saved_state = self._get_input_buffer(incremental_state)
+            if 'prev_key' in saved_state:
+                # previous time steps are cached - no need to recompute
+                # key and value if they are static
+                if static_kv:
+                    assert self.encoder_decoder_attention and not self.self_attention
+                    key = value = None
+        else:
+            saved_state = None
+
+        if self.self_attention:
+            # self-attention
+            q, k, v = self.in_proj_qkv(query)
+        elif self.encoder_decoder_attention:
+            # encoder-decoder attention
+            q = self.in_proj_q(query)
+            if key is None:
+                assert value is None
+                k = v = None
+            else:
+                k = self.in_proj_k(key)
+                v = self.in_proj_v(key)
+
+        else:
+            q = self.in_proj_q(query)
+            k = self.in_proj_k(key)
+            v = self.in_proj_v(value)
+        q *= self.scaling
+
+        if self.bias_k is not None:
+            assert self.bias_v is not None
+            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+            if attn_mask is not None:
+                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+            if key_padding_mask is not None:
+                key_padding_mask = torch.cat(
+                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
+
+        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+        if k is not None:
+            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+        if v is not None:
+            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+        if saved_state is not None:
+            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+            if 'prev_key' in saved_state:
+                prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
+                if static_kv:
+                    k = prev_key
+                else:
+                    k = torch.cat((prev_key, k), dim=1)
+            if 'prev_value' in saved_state:
+                prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
+                if static_kv:
+                    v = prev_value
+                else:
+                    v = torch.cat((prev_value, v), dim=1)
+            if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
+                prev_key_padding_mask = saved_state['prev_key_padding_mask']
+                if static_kv:
+                    key_padding_mask = prev_key_padding_mask
+                else:
+                    key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)
+
+            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
+            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
+            saved_state['prev_key_padding_mask'] = key_padding_mask
+
+            self._set_input_buffer(incremental_state, saved_state)
+
+        src_len = k.size(1)
+
+        # This is part of a workaround to get around fork/join parallelism
+        # not supporting Optional types.
+        if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
+            key_padding_mask = None
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.size(0) == bsz
+            assert key_padding_mask.size(1) == src_len
+
+        if self.add_zero_attn:
+            src_len += 1
+            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+            if attn_mask is not None:
+                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+            if key_padding_mask is not None:
+                key_padding_mask = torch.cat(
+                    [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
+
+        attn_weights = torch.bmm(q, k.transpose(1, 2))
+        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+
+        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+        if attn_mask is not None:
+            if len(attn_mask.shape) == 2:
+                attn_mask = attn_mask.unsqueeze(0)
+            elif len(attn_mask.shape) == 3:
+                attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
+                    bsz * self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights + attn_mask
+
+        if enc_dec_attn_constraint_mask is not None:  # bs x head x L_kv
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.masked_fill(
+                enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
+                -1e8,
+            )
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if key_padding_mask is not None:
+            # don't attend to padding symbols
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.masked_fill(
+                key_padding_mask.unsqueeze(1).unsqueeze(2),
+                -1e8,
+            )
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+
+        if before_softmax:
+            return attn_weights, v
+
+        attn_weights_float = softmax(attn_weights, dim=-1)
+        attn_weights = attn_weights_float.type_as(attn_weights)
+        attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
+
+        if reset_attn_weight is not None:
+            if reset_attn_weight:
+                self.last_attn_probs = attn_probs.detach()
+            else:
+                assert self.last_attn_probs is not None
+                attn_probs = self.last_attn_probs
+        attn = torch.bmm(attn_probs, v)
+        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+        attn = self.out_proj(attn)
+
+        if need_weights:
+            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+            if not need_head_weights:
+                # average attention weights over heads
+                attn_weights = attn_weights.mean(dim=0)
+        else:
+            attn_weights = None
+
+        return attn, (attn_weights, attn_logits)
+
+    def in_proj_qkv(self, query):
+        return self._in_proj(query).chunk(3, dim=-1)
+
+    def in_proj_q(self, query):
+        if self.qkv_same_dim:
+            return self._in_proj(query, end=self.embed_dim)
+        else:
+            bias = self.in_proj_bias
+            if bias is not None:
+                bias = bias[:self.embed_dim]
+            return F.linear(query, self.q_proj_weight, bias)
+
+    def in_proj_k(self, key):
+        if self.qkv_same_dim:
+            return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
+        else:
+            weight = self.k_proj_weight
+            bias = self.in_proj_bias
+            if bias is not None:
+                bias = bias[self.embed_dim:2 * self.embed_dim]
+            return F.linear(key, weight, bias)
+
+    def in_proj_v(self, value):
+        if self.qkv_same_dim:
+            return self._in_proj(value, start=2 * self.embed_dim)
+        else:
+            weight = self.v_proj_weight
+            bias = self.in_proj_bias
+            if bias is not None:
+                bias = bias[2 * self.embed_dim:]
+            return F.linear(value, weight, bias)
+
+    def _in_proj(self, input, start=0, end=None):
+        weight = self.in_proj_weight
+        bias = self.in_proj_bias
+        weight = weight[start:end, :]
+        if bias is not None:
+            bias = bias[start:end]
+        return F.linear(input, weight, bias)
+
+    def _get_input_buffer(self, incremental_state):
+        return get_incremental_state(
+            self,
+            incremental_state,
+            'attn_state',
+        ) or {}
+
+    def _set_input_buffer(self, incremental_state, buffer):
+        set_incremental_state(
+            self,
+            incremental_state,
+            'attn_state',
+            buffer,
+        )
+
+    def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
+        return attn_weights
+
+    def clear_buffer(self, incremental_state=None):
+        if incremental_state is not None:
+            saved_state = self._get_input_buffer(incremental_state)
+            if 'prev_key' in saved_state:
+                del saved_state['prev_key']
+            if 'prev_value' in saved_state:
+                del saved_state['prev_value']
+            self._set_input_buffer(incremental_state, saved_state)
+
+
+class Swish(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, i):
+        result = i * torch.sigmoid(i)
+        ctx.save_for_backward(i)
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        i = ctx.saved_tensors[0]
+        sigmoid_i = torch.sigmoid(i)
+        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
+
+
+class CustomSwish(nn.Module):
+    def forward(self, input_tensor):
+        return Swish.apply(input_tensor)
+
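The hand-written Swish backward above, sigmoid(x) * (1 + x * (1 - sigmoid(x))), can be verified against finite differences with torch.autograd.gradcheck; a minimal editorial check in double precision:

import torch
from modules.commons.common_layers import Swish

x = torch.randn(8, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(Swish.apply, (x,))  # compares the analytic backward to numerical gradients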
+
+class TransformerFFNLayer(nn.Module):
+    def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'):
+        super().__init__()
+        self.kernel_size = kernel_size
+        self.dropout = dropout
+        self.act = act
+        if padding == 'SAME':
+            self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2)
+        elif padding == 'LEFT':
+            self.ffn_1 = nn.Sequential(
+                nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
+                nn.Conv1d(hidden_size, filter_size, kernel_size)
+            )
+        self.ffn_2 = Linear(filter_size, hidden_size)
+        if self.act == 'swish':
+            self.swish_fn = CustomSwish()
+
+    def forward(self, x, incremental_state=None):
+        # x: T x B x C
+        if incremental_state is not None:
+            saved_state = self._get_input_buffer(incremental_state)
+            if 'prev_input' in saved_state:
+                prev_input = saved_state['prev_input']
+                x = torch.cat((prev_input, x), dim=0)
+            x = x[-self.kernel_size:]
+            saved_state['prev_input'] = x
+            self._set_input_buffer(incremental_state, saved_state)
+
+        x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
+        x = x * self.kernel_size ** -0.5
+
+        if incremental_state is not None:
+            x = x[-1:]
+        if self.act == 'gelu':
+            x = F.gelu(x)
+        if self.act == 'relu':
+            x = F.relu(x)
+        if self.act == 'swish':
+            x = self.swish_fn(x)
+        x = F.dropout(x, self.dropout, training=self.training)
+        x = self.ffn_2(x)
+        return x
+
+    def _get_input_buffer(self, incremental_state):
+        return get_incremental_state(
+            self,
+            incremental_state,
+            'f',
+        ) or {}
+
+    def _set_input_buffer(self, incremental_state, buffer):
+        set_incremental_state(
+            self,
+            incremental_state,
+            'f',
+            buffer,
+        )
+
+    def clear_buffer(self, incremental_state):
+        if incremental_state is not None:
+            saved_state = self._get_input_buffer(incremental_state)
+            if 'prev_input' in saved_state:
+                del saved_state['prev_input']
+            self._set_input_buffer(incremental_state, saved_state)
+
+
+class BatchNorm1dTBC(nn.Module):
+    def __init__(self, c):
+        super(BatchNorm1dTBC, self).__init__()
+        self.bn = nn.BatchNorm1d(c)
+
+    def forward(self, x):
+        """
+
+        :param x: [T, B, C]
+        :return: [T, B, C]
+        """
+        x = x.permute(1, 2, 0)  # [B, C, T]
+        x = self.bn(x)  # [B, C, T]
+        x = x.permute(2, 0, 1)  # [T, B, C]
+        return x
+
+
+class EncSALayer(nn.Module):
+    def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
+                 relu_dropout=0.1, kernel_size=9, padding='SAME', norm='ln', act='gelu'):
+        super().__init__()
+        self.c = c
+        self.dropout = dropout
+        self.num_heads = num_heads
+        if num_heads > 0:
+            if norm == 'ln':
+                self.layer_norm1 = LayerNorm(c)
+            elif norm == 'bn':
+                self.layer_norm1 = BatchNorm1dTBC(c)
+            elif norm == 'gn':
+                self.layer_norm1 = GroupNorm1DTBC(8, c)
+            self.self_attn = MultiheadAttention(
+                self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False)
+        if norm == 'ln':
+            self.layer_norm2 = LayerNorm(c)
+        elif norm == 'bn':
+            self.layer_norm2 = BatchNorm1dTBC(c)
+        elif norm == 'gn':
+            self.layer_norm2 = GroupNorm1DTBC(8, c)
+        self.ffn = TransformerFFNLayer(
+            c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)
+
+    def forward(self, x, encoder_padding_mask=None, **kwargs):
+        layer_norm_training = kwargs.get('layer_norm_training', None)
+        if layer_norm_training is not None:
+            self.layer_norm1.training = layer_norm_training
+            self.layer_norm2.training = layer_norm_training
+        if self.num_heads > 0:
+            residual = x
+            x = self.layer_norm1(x)
+            x, _, = self.self_attn(
+                query=x,
+                key=x,
+                value=x,
+                key_padding_mask=encoder_padding_mask
+            )
+            x = F.dropout(x, self.dropout, training=self.training)
+            x = residual + x
+            x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
+
+        residual = x
+        x = self.layer_norm2(x)
+        x = self.ffn(x)
+        x = F.dropout(x, self.dropout, training=self.training)
+        x = residual + x
+        x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
+        return x
+
+
+class DecSALayer(nn.Module):
+    def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
+                 kernel_size=9, act='gelu', norm='ln'):
+        super().__init__()
+        self.c = c
+        self.dropout = dropout
+        if norm == 'ln':
+            self.layer_norm1 = LayerNorm(c)
+        elif norm == 'gn':
+            self.layer_norm1 = GroupNorm1DTBC(8, c)
+        self.self_attn = MultiheadAttention(
+            c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
+        )
+        if norm == 'ln':
+            self.layer_norm2 = LayerNorm(c)
+        elif norm == 'gn':
+            self.layer_norm2 = GroupNorm1DTBC(8, c)
+        self.encoder_attn = MultiheadAttention(
+            c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
+        )
+        if norm == 'ln':
+            self.layer_norm3 = LayerNorm(c)
+        elif norm == 'gn':
+            self.layer_norm3 = GroupNorm1DTBC(8, c)
+        self.ffn = TransformerFFNLayer(
+            c, 4 * c, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
+
+    def forward(
+            self,
+            x,
+            encoder_out=None,
+            encoder_padding_mask=None,
+            incremental_state=None,
+            self_attn_mask=None,
+            self_attn_padding_mask=None,
+            attn_out=None,
+            reset_attn_weight=None,
+            **kwargs,
+    ):
+        layer_norm_training = kwargs.get('layer_norm_training', None)
+        if layer_norm_training is not None:
+            self.layer_norm1.training = layer_norm_training
+            self.layer_norm2.training = layer_norm_training
+            self.layer_norm3.training = layer_norm_training
+        residual = x
+        x = self.layer_norm1(x)
+        x, _ = self.self_attn(
+            query=x,
+            key=x,
+            value=x,
+            key_padding_mask=self_attn_padding_mask,
+            incremental_state=incremental_state,
+            attn_mask=self_attn_mask
+        )
+        x = F.dropout(x, self.dropout, training=self.training)
+        x = residual + x
+
+        attn_logits = None
+        if encoder_out is not None or attn_out is not None:
+            residual = x
+            x = self.layer_norm2(x)
+        if encoder_out is not None:
+            x, attn = self.encoder_attn(
+                query=x,
+                key=encoder_out,
+                value=encoder_out,
+                key_padding_mask=encoder_padding_mask,
+                incremental_state=incremental_state,
+                static_kv=True,
+                enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state,
+                                                                   'enc_dec_attn_constraint_mask'),
+                reset_attn_weight=reset_attn_weight
+            )
+            attn_logits = attn[1]
+        elif attn_out is not None:
+            x = self.encoder_attn.in_proj_v(attn_out)
+        if encoder_out is not None or attn_out is not None:
+            x = F.dropout(x, self.dropout, training=self.training)
+            x = residual + x
+
+        residual = x
+        x = self.layer_norm3(x)
+        x = self.ffn(x, incremental_state=incremental_state)
+        x = F.dropout(x, self.dropout, training=self.training)
+        x = residual + x
+        return x, attn_logits
+
+    def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
+        self.encoder_attn.clear_buffer(incremental_state)
+        self.ffn.clear_buffer(incremental_state)
+
+    def set_buffer(self, name, tensor, incremental_state):
+        return set_incremental_state(self, incremental_state, name, tensor)
+
+
+class ConvBlock(nn.Module):
+    def __init__(self, idim=80, n_chans=256, kernel_size=3, stride=1, norm='gn', dropout=0):
+        super().__init__()
+        self.conv = ConvNorm(idim, n_chans, kernel_size, stride=stride)
+        self.norm = norm
+        if self.norm == 'bn':
+            self.norm = nn.BatchNorm1d(n_chans)
+        elif self.norm == 'in':
+            self.norm = nn.InstanceNorm1d(n_chans, affine=True)
+        elif self.norm == 'gn':
+            self.norm = nn.GroupNorm(n_chans // 16, n_chans)
+        elif self.norm == 'ln':
+            self.norm = LayerNorm(n_chans)
+        elif self.norm == 'wn':
+            self.conv = torch.nn.utils.weight_norm(self.conv.conv)
+        self.dropout = nn.Dropout(dropout)
+        self.relu = nn.ReLU()
+
+    def forward(self, x):
+        """
+
+        :param x: [B, C, T]
+        :return: [B, C, T]
+        """
+        x = self.conv(x)
+        if not isinstance(self.norm, str):
+            # self.norm is a module for 'bn'/'in'/'gn'/'ln'; for 'none'/'wn' it stays a plain string
+            if isinstance(self.norm, nn.LayerNorm):
+                # LayerNorm normalizes the channel dim, so operate on [B, T, C]
+                x = self.norm(x.transpose(1, 2)).transpose(1, 2)
+            else:
+                x = self.norm(x)
+        x = self.relu(x)
+        x = self.dropout(x)
+        return x
+
+
+class ConvStacks(nn.Module):
+    def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn',
+                 dropout=0, strides=None, res=True):
+        super().__init__()
+        self.conv = torch.nn.ModuleList()
+        self.kernel_size = kernel_size
+        self.res = res
+        self.in_proj = Linear(idim, n_chans)
+        if strides is None:
+            strides = [1] * n_layers
+        else:
+            assert len(strides) == n_layers
+        for idx in range(n_layers):
+            self.conv.append(ConvBlock(
+                n_chans, n_chans, kernel_size, stride=strides[idx], norm=norm, dropout=dropout))
+        self.out_proj = Linear(n_chans, odim)
+
+    def forward(self, x, return_hiddens=False):
+        """
+
+        :param x: [B, T, H]
+        :return: [B, T, H]
+        """
+        x = self.in_proj(x)
+        x = x.transpose(1, -1)  # (B, idim, Tmax)
+        hiddens = []
+        for f in self.conv:
+            x_ = f(x)
+            x = x + x_ if self.res else x_  # (B, C, Tmax)
+            hiddens.append(x)
+        x = x.transpose(1, -1)
+        x = self.out_proj(x)  # (B, Tmax, H)
+        if return_hiddens:
+            hiddens = torch.stack(hiddens, 1)  # [B, L, C, T]
+            return x, hiddens
+        return x
+
+
+class ConvGlobalStacks(nn.Module):
+    def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn', dropout=0,
+                 strides=[2, 2, 2, 2, 2]):
+        super().__init__()
+        self.conv = torch.nn.ModuleList()
+        self.pooling = torch.nn.ModuleList()
+        self.kernel_size = kernel_size
+        self.in_proj = Linear(idim, n_chans)
+        for idx in range(n_layers):
+            self.conv.append(ConvBlock(n_chans, n_chans, kernel_size, stride=strides[idx],
+                                       norm=norm, dropout=dropout))
+            self.pooling.append(nn.MaxPool1d(strides[idx]))
+        self.out_proj = Linear(n_chans, odim)
+
+    def forward(self, x):
+        """
+
+        :param x: [B, T, H]
+        :return: [B, T, H]
+        """
+        x = self.in_proj(x)
+        x = x.transpose(1, -1)  # (B, idim, Tmax)
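+        # NOTE: the MaxPool1d modules in self.pooling are zipped but never applied;
+        # temporal downsampling comes from the strided ConvBlocks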
+        for f, p in zip(self.conv, self.pooling):
+            x = f(x)  # (B, C, T)
+        x = x.transpose(1, -1)
+        x = self.out_proj(x.mean(1))  # (B, H)
+        return x
+
+
+class ConvLSTMStacks(nn.Module):
+    def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=3, norm='gn', dropout=0):
+        super().__init__()
+        self.conv = torch.nn.ModuleList()
+        self.kernel_size = kernel_size
+        self.in_proj = Linear(idim, n_chans)
+        for idx in range(n_layers):
+            self.conv.append(ConvBlock(n_chans, n_chans, kernel_size, stride=1, norm=norm, dropout=dropout))
+        self.lstm = nn.LSTM(n_chans, n_chans, 1, batch_first=True, bidirectional=True)
+        self.out_proj = Linear(n_chans * 2, odim)
+
+    def forward(self, x):
+        """
+
+        :param x: [B, T, H]
+        :return: [B, T, H]
+        """
+        x = self.in_proj(x)
+        x = x.transpose(1, -1)  # (B, idim, Tmax)
+        for f in self.conv:
+            x = x + f(x)  # (B, C, Tmax)
+        x = x.transpose(1, -1)
+        x, _ = self.lstm(x)  # (B, Tmax, H*2)
+        x = self.out_proj(x)  # (B, Tmax, H)
+        return x
+
+
+class ResidualLayer(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, padding):
+        super(ResidualLayer, self).__init__()
+        self.conv1d_layer = nn.Sequential(nn.Conv1d(in_channels=in_channels,
+                                                    out_channels=out_channels,
+                                                    kernel_size=kernel_size,
+                                                    stride=1,
+                                                    padding=padding),
+                                          nn.InstanceNorm1d(num_features=out_channels,
+                                                            affine=True))
+
+        self.conv_layer_gates = nn.Sequential(nn.Conv1d(in_channels=in_channels,
+                                                        out_channels=out_channels,
+                                                        kernel_size=kernel_size,
+                                                        stride=1,
+                                                        padding=padding),
+                                              nn.InstanceNorm1d(num_features=out_channels,
+                                                                affine=True))
+
+        self.conv1d_out_layer = nn.Sequential(nn.Conv1d(in_channels=out_channels,
+                                                        out_channels=in_channels,
+                                                        kernel_size=kernel_size,
+                                                        stride=1,
+                                                        padding=padding),
+                                              nn.InstanceNorm1d(num_features=in_channels,
+                                                                affine=True))
+
+    def forward(self, input):
+        """
+
+        :param input: [B, H, T]
+        :return: [B, H, T]
+        """
+        h1_norm = self.conv1d_layer(input)
+        h1_gates_norm = self.conv_layer_gates(input)
+
+        # GLU
+        h1_glu = h1_norm * torch.sigmoid(h1_gates_norm)
+
+        h2_norm = self.conv1d_out_layer(h1_glu)
+        return input + h2_norm
+
+
+class ConvGLUStacks(nn.Module):
+    def __init__(self, idim=80, n_layers=3, n_chans=256, odim=32, kernel_size=5, dropout=0):
+        super().__init__()
+        self.convs = []
+        self.kernel_size = kernel_size
+        self.in_proj = Linear(idim, n_chans)
+        for idx in range(n_layers):
+            self.convs.append(
+                nn.Sequential(ResidualLayer(
+                    n_chans, n_chans, kernel_size, kernel_size // 2),
+                    nn.Dropout(dropout)
+                ))
+        self.convs = nn.Sequential(*self.convs)
+        self.out_proj = Linear(n_chans, odim)
+
+    def forward(self, x):
+        """
+
+        :param x: [B, T, H]
+        :return: [B, T, H]
+        """
+        x = self.in_proj(x)
+        x = x.transpose(1, -1)  # (B, idim, Tmax)
+        x = self.convs(x)  # (B, C, Tmax)
+        x = x.transpose(1, -1)
+        x = self.out_proj(x)  # (B, Tmax, H)
+        return x
diff --git a/modules/commons/espnet_positional_embedding.py b/modules/commons/espnet_positional_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..74decb6ab300951490ae08a4b93041a0542b5bb7
--- /dev/null
+++ b/modules/commons/espnet_positional_embedding.py
@@ -0,0 +1,113 @@
+import math
+import torch
+
+
+class PositionalEncoding(torch.nn.Module):
+    """Positional encoding.
+    Args:
+        d_model (int): Embedding dimension.
+        dropout_rate (float): Dropout rate.
+        max_len (int): Maximum input length.
+        reverse (bool): Whether to reverse the input position.
+    """
+
+    def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
+        """Construct an PositionalEncoding object."""
+        super(PositionalEncoding, self).__init__()
+        self.d_model = d_model
+        self.reverse = reverse
+        self.xscale = math.sqrt(self.d_model)
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.pe = None
+        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+    def extend_pe(self, x):
+        """Reset the positional encodings."""
+        if self.pe is not None:
+            if self.pe.size(1) >= x.size(1):
+                if self.pe.dtype != x.dtype or self.pe.device != x.device:
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        pe = torch.zeros(x.size(1), self.d_model)
+        if self.reverse:
+            position = torch.arange(
+                x.size(1) - 1, -1, -1.0, dtype=torch.float32
+            ).unsqueeze(1)
+        else:
+            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
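+        # sinusoid frequencies follow the Transformer recipe: 1 / 10000^(2i / d_model)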
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, x: torch.Tensor):
+        """Add positional encoding.
+        Args:
+            x (torch.Tensor): Input tensor (batch, time, `*`).
+        Returns:
+            torch.Tensor: Encoded tensor (batch, time, `*`).
+        """
+        self.extend_pe(x)
+        x = x * self.xscale + self.pe[:, : x.size(1)]
+        return self.dropout(x)
+
+
+class ScaledPositionalEncoding(PositionalEncoding):
+    """Scaled positional encoding module.
+    See Sec. 3.2  https://arxiv.org/abs/1809.08895
+    Args:
+        d_model (int): Embedding dimension.
+        dropout_rate (float): Dropout rate.
+        max_len (int): Maximum input length.
+    """
+
+    def __init__(self, d_model, dropout_rate, max_len=5000):
+        """Initialize class."""
+        super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
+        self.alpha = torch.nn.Parameter(torch.tensor(1.0))
+
+    def reset_parameters(self):
+        """Reset parameters."""
+        self.alpha.data = torch.tensor(1.0)
+
+    def forward(self, x):
+        """Add positional encoding.
+        Args:
+            x (torch.Tensor): Input tensor (batch, time, `*`).
+        Returns:
+            torch.Tensor: Encoded tensor (batch, time, `*`).
+        """
+        self.extend_pe(x)
+        x = x + self.alpha * self.pe[:, : x.size(1)]
+        return self.dropout(x)
+
+
+class RelPositionalEncoding(PositionalEncoding):
+    """Relative positional encoding module.
+    See : Appendix B in https://arxiv.org/abs/1901.02860
+    Args:
+        d_model (int): Embedding dimension.
+        dropout_rate (float): Dropout rate.
+        max_len (int): Maximum input length.
+    """
+
+    def __init__(self, d_model, dropout_rate, max_len=5000):
+        """Initialize class."""
+        super().__init__(d_model, dropout_rate, max_len, reverse=True)
+
+    def forward(self, x):
+        """Compute positional encoding.
+        Args:
+            x (torch.Tensor): Input tensor (batch, time, `*`).
+        Returns:
+            torch.Tensor: Scaled and position-encoded tensor (batch, time, `*`).
+        """
+        self.extend_pe(x)
+        x = x * self.xscale
+        pos_emb = self.pe[:, : x.size(1)]
+        return self.dropout(x) + self.dropout(pos_emb)
\ No newline at end of file
diff --git a/modules/commons/ssim.py b/modules/commons/ssim.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d0241f267ef58b24979e022b05f2a9adf768826
--- /dev/null
+++ b/modules/commons/ssim.py
@@ -0,0 +1,391 @@
+# '''
+# https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ssim.py
+# '''
+#
+# import torch
+# import torch.jit
+# import torch.nn.functional as F
+#
+#
+# @torch.jit.script
+# def create_window(window_size: int, sigma: float, channel: int):
+#     '''
+#     Create 1-D gauss kernel
+#     :param window_size: the size of gauss kernel
+#     :param sigma: sigma of normal distribution
+#     :param channel: input channel
+#     :return: 1D kernel
+#     '''
+#     coords = torch.arange(window_size, dtype=torch.float)
+#     coords -= window_size // 2
+#
+#     g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
+#     g /= g.sum()
+#
+#     g = g.reshape(1, 1, 1, -1).repeat(channel, 1, 1, 1)
+#     return g
+#
+#
+# @torch.jit.script
+# def _gaussian_filter(x, window_1d, use_padding: bool):
+#     '''
+#     Blur input with 1-D kernel
+#     :param x: batch of tensors to be blured
+#     :param window_1d: 1-D gauss kernel
+#     :param use_padding: padding image before conv
+#     :return: blured tensors
+#     '''
+#     C = x.shape[1]
+#     padding = 0
+#     if use_padding:
+#         window_size = window_1d.shape[3]
+#         padding = window_size // 2
+#     out = F.conv2d(x, window_1d, stride=1, padding=(0, padding), groups=C)
+#     out = F.conv2d(out, window_1d.transpose(2, 3), stride=1, padding=(padding, 0), groups=C)
+#     return out
+#
+#
+# @torch.jit.script
+# def ssim(X, Y, window, data_range: float, use_padding: bool = False):
+#     '''
+#     Calculate ssim index for X and Y
+#     :param X: images [B, C, H, N_bins]
+#     :param Y: images [B, C, H, N_bins]
+#     :param window: 1-D gauss kernel
+#     :param data_range: value range of input images. (usually 1.0 or 255)
+#     :param use_padding: padding image before conv
+#     :return:
+#     '''
+#
+#     K1 = 0.01
+#     K2 = 0.03
+#     compensation = 1.0
+#
+#     C1 = (K1 * data_range) ** 2
+#     C2 = (K2 * data_range) ** 2
+#
+#     mu1 = _gaussian_filter(X, window, use_padding)
+#     mu2 = _gaussian_filter(Y, window, use_padding)
+#     sigma1_sq = _gaussian_filter(X * X, window, use_padding)
+#     sigma2_sq = _gaussian_filter(Y * Y, window, use_padding)
+#     sigma12 = _gaussian_filter(X * Y, window, use_padding)
+#
+#     mu1_sq = mu1.pow(2)
+#     mu2_sq = mu2.pow(2)
+#     mu1_mu2 = mu1 * mu2
+#
+#     sigma1_sq = compensation * (sigma1_sq - mu1_sq)
+#     sigma2_sq = compensation * (sigma2_sq - mu2_sq)
+#     sigma12 = compensation * (sigma12 - mu1_mu2)
+#
+#     cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
+#     # Fixed the issue that the negative value of cs_map caused ms_ssim to output Nan.
+#     cs_map = cs_map.clamp_min(0.)
+#     ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
+#
+#     ssim_val = ssim_map.mean(dim=(1, 2, 3))  # reduce along CHW
+#     cs = cs_map.mean(dim=(1, 2, 3))
+#
+#     return ssim_val, cs
+#
+#
+# @torch.jit.script
+# def ms_ssim(X, Y, window, data_range: float, weights, use_padding: bool = False, eps: float = 1e-8):
+#     '''
+#     interface of ms-ssim
+#     :param X: a batch of images, (N,C,H,W)
+#     :param Y: a batch of images, (N,C,H,W)
+#     :param window: 1-D gauss kernel
+#     :param data_range: value range of input images. (usually 1.0 or 255)
+#     :param weights: weights for different levels
+#     :param use_padding: padding image before conv
+#     :param eps: use for avoid grad nan.
+#     :return:
+#     '''
+#     levels = weights.shape[0]
+#     cs_vals = []
+#     ssim_vals = []
+#     for _ in range(levels):
+#         ssim_val, cs = ssim(X, Y, window=window, data_range=data_range, use_padding=use_padding)
+#         # Use for fix a issue. When c = a ** b and a is 0, c.backward() will cause the a.grad become inf.
+#         ssim_val = ssim_val.clamp_min(eps)
+#         cs = cs.clamp_min(eps)
+#         cs_vals.append(cs)
+#
+#         ssim_vals.append(ssim_val)
+#         padding = (X.shape[2] % 2, X.shape[3] % 2)
+#         X = F.avg_pool2d(X, kernel_size=2, stride=2, padding=padding)
+#         Y = F.avg_pool2d(Y, kernel_size=2, stride=2, padding=padding)
+#
+#     cs_vals = torch.stack(cs_vals, dim=0)
+#     ms_ssim_val = torch.prod((cs_vals[:-1] ** weights[:-1].unsqueeze(1)) * (ssim_vals[-1] ** weights[-1]), dim=0)
+#     return ms_ssim_val
+#
+#
+# class SSIM(torch.jit.ScriptModule):
+#     __constants__ = ['data_range', 'use_padding']
+#
+#     def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False):
+#         '''
+#         :param window_size: the size of gauss kernel
+#         :param window_sigma: sigma of normal distribution
+#         :param data_range: value range of input images. (usually 1.0 or 255)
+#         :param channel: input channels (default: 3)
+#         :param use_padding: padding image before conv
+#         '''
+#         super().__init__()
+#         assert window_size % 2 == 1, 'Window size must be odd.'
+#         window = create_window(window_size, window_sigma, channel)
+#         self.register_buffer('window', window)
+#         self.data_range = data_range
+#         self.use_padding = use_padding
+#
+#     @torch.jit.script_method
+#     def forward(self, X, Y):
+#         r = ssim(X, Y, window=self.window, data_range=self.data_range, use_padding=self.use_padding)
+#         return r[0]
+#
+#
+# class MS_SSIM(torch.jit.ScriptModule):
+#     __constants__ = ['data_range', 'use_padding', 'eps']
+#
+#     def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False, weights=None,
+#                  levels=None, eps=1e-8):
+#         '''
+#         class for ms-ssim
+#         :param window_size: the size of gauss kernel
+#         :param window_sigma: sigma of normal distribution
+#         :param data_range: value range of input images. (usually 1.0 or 255)
+#         :param channel: input channels
+#         :param use_padding: padding image before conv
+#         :param weights: weights for different levels. (default [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
+#         :param levels: number of downsampling
+#         :param eps: Use for fix a issue. When c = a ** b and a is 0, c.backward() will cause the a.grad become inf.
+#         '''
+#         super().__init__()
+#         assert window_size % 2 == 1, 'Window size must be odd.'
+#         self.data_range = data_range
+#         self.use_padding = use_padding
+#         self.eps = eps
+#
+#         window = create_window(window_size, window_sigma, channel)
+#         self.register_buffer('window', window)
+#
+#         if weights is None:
+#             weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
+#         weights = torch.tensor(weights, dtype=torch.float)
+#
+#         if levels is not None:
+#             weights = weights[:levels]
+#             weights = weights / weights.sum()
+#
+#         self.register_buffer('weights', weights)
+#
+#     @torch.jit.script_method
+#     def forward(self, X, Y):
+#         return ms_ssim(X, Y, window=self.window, data_range=self.data_range, weights=self.weights,
+#                        use_padding=self.use_padding, eps=self.eps)
+#
+#
+# if __name__ == '__main__':
+#     print('Simple Test')
+#     im = torch.randint(0, 255, (5, 3, 256, 256), dtype=torch.float, device='cuda')
+#     img1 = im / 255
+#     img2 = img1 * 0.5
+#
+#     losser = SSIM(data_range=1.).cuda()
+#     loss = losser(img1, img2).mean()
+#
+#     losser2 = MS_SSIM(data_range=1.).cuda()
+#     loss2 = losser2(img1, img2).mean()
+#
+#     print(loss.item())
+#     print(loss2.item())
+#
+# if __name__ == '__main__':
+#     print('Training Test')
+#     import cv2
+#     import torch.optim
+#     import numpy as np
+#     import imageio
+#     import time
+#
+#     out_test_video = False
+#     # it's best not to write a GIF directly (the file gets very large); write an MKV first and convert it to GIF with ffmpeg
+#     video_use_gif = False
+#
+#     im = cv2.imread('test_img1.jpg', 1)
+#     t_im = torch.from_numpy(im).cuda().permute(2, 0, 1).float()[None] / 255.
+#
+#     if out_test_video:
+#         if video_use_gif:
+#             fps = 0.5
+#             out_wh = (im.shape[1] // 2, im.shape[0] // 2)
+#             suffix = '.gif'
+#         else:
+#             fps = 5
+#             out_wh = (im.shape[1], im.shape[0])
+#             suffix = '.mkv'
+#         video_last_time = time.perf_counter()
+#         video = imageio.get_writer('ssim_test' + suffix, fps=fps)
+#
+#     # test SSIM
+#     print('Training SSIM')
+#     rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
+#     rand_im.requires_grad = True
+#     optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
+#     losser = SSIM(data_range=1., channel=t_im.shape[1]).cuda()
+#     ssim_score = 0
+#     while ssim_score < 0.999:
+#         optim.zero_grad()
+#         loss = losser(rand_im, t_im)
+#         (-loss).sum().backward()
+#         ssim_score = loss.item()
+#         optim.step()
+#         r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
+#         r_im = cv2.putText(r_im, 'ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
+#
+#         if out_test_video:
+#             if time.perf_counter() - video_last_time > 1. / fps:
+#                 video_last_time = time.perf_counter()
+#                 out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
+#                 out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
+#                 if isinstance(out_frame, cv2.UMat):
+#                     out_frame = out_frame.get()
+#                 video.append_data(out_frame)
+#
+#         cv2.imshow('ssim', r_im)
+#         cv2.setWindowTitle('ssim', 'ssim %f' % ssim_score)
+#         cv2.waitKey(1)
+#
+#     if out_test_video:
+#         video.close()
+#
+#     # test MS-SSIM
+#     if out_test_video:
+#         if video_use_gif:
+#             fps = 0.5
+#             out_wh = (im.shape[1] // 2, im.shape[0] // 2)
+#             suffix = '.gif'
+#         else:
+#             fps = 5
+#             out_wh = (im.shape[1], im.shape[0])
+#             suffix = '.mkv'
+#         video_last_time = time.perf_counter()
+#         video = imageio.get_writer('ms_ssim_test' + suffix, fps=fps)
+#
+#     print('Training MS_SSIM')
+#     rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
+#     rand_im.requires_grad = True
+#     optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
+#     losser = MS_SSIM(data_range=1., channel=t_im.shape[1]).cuda()
+#     ssim_score = 0
+#     while ssim_score < 0.999:
+#         optim.zero_grad()
+#         loss = losser(rand_im, t_im)
+#         (-loss).sum().backward()
+#         ssim_score = loss.item()
+#         optim.step()
+#         r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
+#         r_im = cv2.putText(r_im, 'ms_ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
+#
+#         if out_test_video:
+#             if time.perf_counter() - video_last_time > 1. / fps:
+#                 video_last_time = time.perf_counter()
+#                 out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
+#                 out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
+#                 if isinstance(out_frame, cv2.UMat):
+#                     out_frame = out_frame.get()
+#                 video.append_data(out_frame)
+#
+#         cv2.imshow('ms_ssim', r_im)
+#         cv2.setWindowTitle('ms_ssim', 'ms_ssim %f' % ssim_score)
+#         cv2.waitKey(1)
+#
+#     if out_test_video:
+#         video.close()
+
+"""
+Adapted from https://github.com/Po-Hsun-Su/pytorch-ssim
+"""
+
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+import numpy as np
+from math import exp
+
+
+def gaussian(window_size, sigma):
+    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
+    return gauss / gauss.sum()
+
+
+def create_window(window_size, channel):
+    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
+    return window
+
+
+def _ssim(img1, img2, window, window_size, channel, size_average=True):
+    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
+    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
+
+    mu1_sq = mu1.pow(2)
+    mu2_sq = mu2.pow(2)
+    mu1_mu2 = mu1 * mu2
+
+    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
+    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
+    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
+
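+    # SSIM stability constants (K1=0.01, K2=0.03), assuming inputs scaled to a [0, 1] data range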
+    C1 = 0.01 ** 2
+    C2 = 0.03 ** 2
+
+    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+
+    if size_average:
+        return ssim_map.mean()
+    else:
+        return ssim_map.mean(1)
+
+
+class SSIM(torch.nn.Module):
+    def __init__(self, window_size=11, size_average=True):
+        super(SSIM, self).__init__()
+        self.window_size = window_size
+        self.size_average = size_average
+        self.channel = 1
+        self.window = create_window(window_size, self.channel)
+
+    def forward(self, img1, img2):
+        (_, channel, _, _) = img1.size()
+
+        if channel == self.channel and self.window.data.type() == img1.data.type():
+            window = self.window
+        else:
+            window = create_window(self.window_size, channel)
+
+            if img1.is_cuda:
+                window = window.cuda(img1.get_device())
+            window = window.type_as(img1)
+
+            self.window = window
+            self.channel = channel
+
+        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
+
+
+window = None
+
+
+def ssim(img1, img2, window_size=11, size_average=True):
+    (_, channel, _, _) = img1.size()
+    global window
+    if window is None:
+        window = create_window(window_size, channel)
+        if img1.is_cuda:
+            window = window.cuda(img1.get_device())
+        window = window.type_as(img1)
+    return _ssim(img1, img2, window, window_size, channel, size_average)
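+
+
+# Minimal usage sketch (illustrative only, not called anywhere in this module): the SSIM above is
+# typically applied to mel-spectrograms treated as single-channel images, e.g.
+#
+#   pred = torch.rand(4, 1, 80, 200)     # [B, 1, n_mel_bins, T]
+#   target = torch.rand(4, 1, 80, 200)
+#   loss = 1 - ssim(pred, target)        # SSIM is a similarity (1 = identical), so 1 - SSIM works as a loss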
diff --git a/modules/fastspeech/fs2.py b/modules/fastspeech/fs2.py
new file mode 100644
index 0000000000000000000000000000000000000000..52b4ac4aaa7ae49f06736a038bde83ca2cfa8483
--- /dev/null
+++ b/modules/fastspeech/fs2.py
@@ -0,0 +1,255 @@
+from modules.commons.common_layers import *
+from modules.commons.common_layers import Embedding
+from modules.fastspeech.tts_modules import FastspeechDecoder, DurationPredictor, LengthRegulator, PitchPredictor, \
+    EnergyPredictor, FastspeechEncoder
+from utils.cwt import cwt2f0
+from utils.hparams import hparams
+from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0
+
+FS_ENCODERS = {
+    'fft': lambda hp, embed_tokens, d: FastspeechEncoder(
+        embed_tokens, hp['hidden_size'], hp['enc_layers'], hp['enc_ffn_kernel_size'],
+        num_heads=hp['num_heads']),
+}
+
+FS_DECODERS = {
+    'fft': lambda hp: FastspeechDecoder(
+        hp['hidden_size'], hp['dec_layers'], hp['dec_ffn_kernel_size'], hp['num_heads']),
+}
+
+
+class FastSpeech2(nn.Module):
+    def __init__(self, dictionary, out_dims=None):
+        super().__init__()
+        self.dictionary = dictionary
+        self.padding_idx = dictionary.pad()
+        self.enc_layers = hparams['enc_layers']
+        self.dec_layers = hparams['dec_layers']
+        self.hidden_size = hparams['hidden_size']
+        self.encoder_embed_tokens = self.build_embedding(self.dictionary, self.hidden_size)
+        self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams, self.encoder_embed_tokens, self.dictionary)
+        self.decoder = FS_DECODERS[hparams['decoder_type']](hparams)
+        self.out_dims = out_dims
+        if out_dims is None:
+            self.out_dims = hparams['audio_num_mel_bins']
+        self.mel_out = Linear(self.hidden_size, self.out_dims, bias=True)
+
+        if hparams['use_spk_id']:
+            self.spk_embed_proj = Embedding(hparams['num_spk'] + 1, self.hidden_size)
+            if hparams['use_split_spk_id']:
+                self.spk_embed_f0 = Embedding(hparams['num_spk'] + 1, self.hidden_size)
+                self.spk_embed_dur = Embedding(hparams['num_spk'] + 1, self.hidden_size)
+        elif hparams['use_spk_embed']:
+            self.spk_embed_proj = Linear(256, self.hidden_size, bias=True)
+        predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
+        self.dur_predictor = DurationPredictor(
+            self.hidden_size,
+            n_chans=predictor_hidden,
+            n_layers=hparams['dur_predictor_layers'],
+            dropout_rate=hparams['predictor_dropout'], padding=hparams['ffn_padding'],
+            kernel_size=hparams['dur_predictor_kernel'])
+        self.length_regulator = LengthRegulator()
+        if hparams['use_pitch_embed']:
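+            # embedding table for the coarse F0 indices produced by f0_to_coarse (300 bins)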
+            self.pitch_embed = Embedding(300, self.hidden_size, self.padding_idx)
+            if hparams['pitch_type'] == 'cwt':
+                h = hparams['cwt_hidden_size']
+                cwt_out_dims = 10
+                if hparams['use_uv']:
+                    cwt_out_dims = cwt_out_dims + 1
+                self.cwt_predictor = nn.Sequential(
+                    nn.Linear(self.hidden_size, h),
+                    PitchPredictor(
+                        h,
+                        n_chans=predictor_hidden,
+                        n_layers=hparams['predictor_layers'],
+                        dropout_rate=hparams['predictor_dropout'], odim=cwt_out_dims,
+                        padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel']))
+                self.cwt_stats_layers = nn.Sequential(
+                    nn.Linear(self.hidden_size, h), nn.ReLU(),
+                    nn.Linear(h, h), nn.ReLU(), nn.Linear(h, 2)
+                )
+            else:
+                self.pitch_predictor = PitchPredictor(
+                    self.hidden_size,
+                    n_chans=predictor_hidden,
+                    n_layers=hparams['predictor_layers'],
+                    dropout_rate=hparams['predictor_dropout'],
+                    odim=2 if hparams['pitch_type'] == 'frame' else 1,
+                    padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
+        if hparams['use_energy_embed']:
+            self.energy_embed = Embedding(256, self.hidden_size, self.padding_idx)
+            self.energy_predictor = EnergyPredictor(
+                self.hidden_size,
+                n_chans=predictor_hidden,
+                n_layers=hparams['predictor_layers'],
+                dropout_rate=hparams['predictor_dropout'], odim=1,
+                padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
+
+    def build_embedding(self, dictionary, embed_dim):
+        num_embeddings = len(dictionary)
+        emb = Embedding(num_embeddings, embed_dim, self.padding_idx)
+        return emb
+
+    def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
+                ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False,
+                spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs):
+        ret = {}
+        encoder_out = self.encoder(txt_tokens)  # [B, T, C]
+        src_nonpadding = (txt_tokens > 0).float()[:, :, None]
+
+        # reference style embedding / variance encoder: not implemented here,
+        # so the variance embedding is simply zero
+        var_embed = 0
+
+        # encoder_out_dur denotes the encoder outputs fed to the duration predictor;
+        # in speech adaptation, the duration predictor uses the old speaker embedding
+        if hparams['use_spk_embed']:
+            spk_embed_dur = spk_embed_f0 = spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
+        elif hparams['use_spk_id']:
+            spk_embed_id = spk_embed
+            if spk_embed_dur_id is None:
+                spk_embed_dur_id = spk_embed_id
+            if spk_embed_f0_id is None:
+                spk_embed_f0_id = spk_embed_id
+            spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :]
+            spk_embed_dur = spk_embed_f0 = spk_embed
+            if hparams['use_split_spk_id']:
+                spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :]
+                spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :]
+        else:
+            spk_embed_dur = spk_embed_f0 = spk_embed = 0
+
+        # add dur
+        dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding
+
+        mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret)
+
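+        # expand phoneme-level encoder outputs to frame level: prepend a zero vector so that
+        # mel2ph == 0 (padding frames) gathers zeros, then gather along the time axis with mel2ph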
+        decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])
+
+        mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
+        decoder_inp_origin = decoder_inp = torch.gather(decoder_inp, 1, mel2ph_)  # [B, T, H]
+
+        tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
+
+        # add pitch and energy embed
+        pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding
+        if hparams['use_pitch_embed']:
+            pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
+            decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
+        if hparams['use_energy_embed']:
+            decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)
+
+        ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding
+
+        if skip_decoder:
+            return ret
+        ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
+
+        return ret
+
+    def add_dur(self, dur_input, mel2ph, txt_tokens, ret):
+        """
+
+        :param dur_input: [B, T_txt, H]
+        :param mel2ph: [B, T_mel]
+        :param txt_tokens: [B, T_txt]
+        :param ret:
+        :return:
+        """
+        src_padding = txt_tokens == 0
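+        # straight-through style gradient scaling: the forward value is unchanged, but the gradient
+        # flowing back into the encoder is scaled by predictor_grad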
+        dur_input = dur_input.detach() + hparams['predictor_grad'] * (dur_input - dur_input.detach())
+        if mel2ph is None:
+            dur, xs = self.dur_predictor.inference(dur_input, src_padding)
+            ret['dur'] = xs
+            ret['dur_choice'] = dur
+            mel2ph = self.length_regulator(dur, src_padding).detach()
+            # from modules.fastspeech.fake_modules import FakeLengthRegulator
+            # fake_lr = FakeLengthRegulator()
+            # fake_mel2ph = fake_lr(dur, (1 - src_padding.long()).sum(-1))[..., 0].detach()
+            # print(mel2ph == fake_mel2ph)
+        else:
+            ret['dur'] = self.dur_predictor(dur_input, src_padding)
+        ret['mel2ph'] = mel2ph
+        return mel2ph
+
+    def add_energy(self, decoder_inp, energy, ret):
+        decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
+        ret['energy_pred'] = energy_pred = self.energy_predictor(decoder_inp)[:, :, 0]
+        if energy is None:
+            energy = energy_pred
+        energy = torch.clamp(energy * 256 // 4, max=255).long()
+        energy_embed = self.energy_embed(energy)
+        return energy_embed
+
+    def add_pitch(self, decoder_inp, f0, uv, mel2ph, ret, encoder_out=None):
+        if hparams['pitch_type'] == 'ph':
+            pitch_pred_inp = encoder_out.detach() + hparams['predictor_grad'] * (encoder_out - encoder_out.detach())
+            pitch_padding = encoder_out.sum().abs() == 0
+            ret['pitch_pred'] = pitch_pred = self.pitch_predictor(pitch_pred_inp)
+            if f0 is None:
+                f0 = pitch_pred[:, :, 0]
+            ret['f0_denorm'] = f0_denorm = denorm_f0(f0, None, hparams, pitch_padding=pitch_padding)
+            pitch = f0_to_coarse(f0_denorm)  # start from 0 [B, T_txt]
+            pitch = F.pad(pitch, [1, 0])
+            pitch = torch.gather(pitch, 1, mel2ph)  # [B, T_mel]
+            pitch_embed = self.pitch_embed(pitch)
+            return pitch_embed
+        decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
+
+        pitch_padding = mel2ph == 0
+
+        if hparams['pitch_type'] == 'cwt':
+            pitch_padding = None
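+            # predict a 10-scale CWT spectrogram of F0 (plus an optional UV channel) together with
+            # per-utterance F0 mean/std; when no ground-truth F0 is given, the contour is
+            # reconstructed via cwt2f0_norm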
+            ret['cwt'] = cwt_out = self.cwt_predictor(decoder_inp)
+            stats_out = self.cwt_stats_layers(encoder_out[:, 0, :])  # [B, 2]
+            mean = ret['f0_mean'] = stats_out[:, 0]
+            std = ret['f0_std'] = stats_out[:, 1]
+            cwt_spec = cwt_out[:, :, :10]
+            if f0 is None:
+                std = std * hparams['cwt_std_scale']
+                f0 = self.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
+                if hparams['use_uv']:
+                    assert cwt_out.shape[-1] == 11
+                    uv = cwt_out[:, :, -1] > 0
+        elif hparams['pitch_ar']:
+            ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp, f0 if self.training else None)
+            if f0 is None:
+                f0 = pitch_pred[:, :, 0]
+        else:
+            ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp)
+            if f0 is None:
+                f0 = pitch_pred[:, :, 0]
+            if hparams['use_uv'] and uv is None:
+                uv = pitch_pred[:, :, 1] > 0
+        ret['f0_denorm'] = f0_denorm = denorm_f0(f0, uv, hparams, pitch_padding=pitch_padding)
+        if pitch_padding is not None:
+            f0[pitch_padding] = 0
+
+        pitch = f0_to_coarse(f0_denorm)  # start from 0
+        pitch_embed = self.pitch_embed(pitch)
+        return pitch_embed
+
+    def run_decoder(self, decoder_inp, tgt_nonpadding, ret, infer, **kwargs):
+        x = decoder_inp  # [B, T, H]
+        x = self.decoder(x)
+        x = self.mel_out(x)
+        return x * tgt_nonpadding
+
+    def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
+        f0 = cwt2f0(cwt_spec, mean, std, hparams['cwt_scales'])
+        f0 = torch.cat(
+            [f0] + [f0[:, -1:]] * (mel2ph.shape[1] - f0.shape[1]), 1)
+        f0_norm = norm_f0(f0, None, hparams)
+        return f0_norm
+
+    def out2mel(self, out):
+        return out
+
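+    # mel_norm / mel_denorm map log-mel values in roughly [-5.5, 0.8] to / from [-1, 1]
+    # using hard-coded constants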
+    @staticmethod
+    def mel_norm(x):
+        return (x + 5.5) / (6.3 / 2) - 1
+
+    @staticmethod
+    def mel_denorm(x):
+        return (x + 1) * (6.3 / 2) - 5.5
diff --git a/modules/fastspeech/pe.py b/modules/fastspeech/pe.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9fa5098b378bb4ed10f97f05a9ff725d1d2239c
--- /dev/null
+++ b/modules/fastspeech/pe.py
@@ -0,0 +1,149 @@
+from modules.commons.common_layers import *
+from utils.hparams import hparams
+from modules.fastspeech.tts_modules import PitchPredictor
+from utils.pitch_utils import denorm_f0
+
+
+class Prenet(nn.Module):
+    def __init__(self, in_dim=80, out_dim=256, kernel=5, n_layers=3, strides=None):
+        super(Prenet, self).__init__()
+        padding = kernel // 2
+        self.layers = []
+        self.strides = strides if strides is not None else [1] * n_layers
+        for l in range(n_layers):
+            self.layers.append(nn.Sequential(
+                nn.Conv1d(in_dim, out_dim, kernel_size=kernel, padding=padding, stride=self.strides[l]),
+                nn.ReLU(),
+                nn.BatchNorm1d(out_dim)
+            ))
+            in_dim = out_dim
+        self.layers = nn.ModuleList(self.layers)
+        self.out_proj = nn.Linear(out_dim, out_dim)
+
+    def forward(self, x):
+        """
+
+        :param x: [B, T, 80]
+        :return: [L, B, T, H], [B, T, H]
+        """
+        padding_mask = x.abs().sum(-1).eq(0).data  # [B, T]
+        nonpadding_mask_TB = 1 - padding_mask.float()[:, None, :]  # [B, 1, T]
+        x = x.transpose(1, 2)
+        hiddens = []
+        for i, l in enumerate(self.layers):
+            nonpadding_mask_TB = nonpadding_mask_TB[:, :, ::self.strides[i]]
+            x = l(x) * nonpadding_mask_TB
+            hiddens.append(x)  # collect every layer's output so the stacked result is [L, B, H, T]
+        hiddens = torch.stack(hiddens, 0)  # [L, B, H, T]
+        hiddens = hiddens.transpose(2, 3)  # [L, B, T, H]
+        x = self.out_proj(x.transpose(1, 2))  # [B, T, H]
+        x = x * nonpadding_mask_TB.transpose(1, 2)
+        return hiddens, x
+
+
+class ConvBlock(nn.Module):
+    def __init__(self, idim=80, n_chans=256, kernel_size=3, stride=1, norm='gn', dropout=0):
+        super().__init__()
+        self.conv = ConvNorm(idim, n_chans, kernel_size, stride=stride)
+        self.norm = norm
+        if self.norm == 'bn':
+            self.norm = nn.BatchNorm1d(n_chans)
+        elif self.norm == 'in':
+            self.norm = nn.InstanceNorm1d(n_chans, affine=True)
+        elif self.norm == 'gn':
+            self.norm = nn.GroupNorm(n_chans // 16, n_chans)
+        elif self.norm == 'ln':
+            self.norm = LayerNorm(n_chans)
+        elif self.norm == 'wn':
+            self.conv = torch.nn.utils.weight_norm(self.conv.conv)
+        self.dropout = nn.Dropout(dropout)
+        self.relu = nn.ReLU()
+
+    def forward(self, x):
+        """
+
+        :param x: [B, C, T]
+        :return: [B, C, T]
+        """
+        x = self.conv(x)
+        if not isinstance(self.norm, str):
+            # self.norm is a module for 'bn'/'in'/'gn'/'ln'; for 'none'/'wn' it stays a plain string
+            if isinstance(self.norm, nn.LayerNorm):
+                # LayerNorm normalizes the channel dim, so operate on [B, T, C]
+                x = self.norm(x.transpose(1, 2)).transpose(1, 2)
+            else:
+                x = self.norm(x)
+        x = self.relu(x)
+        x = self.dropout(x)
+        return x
+
+
+class ConvStacks(nn.Module):
+    def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn',
+                 dropout=0, strides=None, res=True):
+        super().__init__()
+        self.conv = torch.nn.ModuleList()
+        self.kernel_size = kernel_size
+        self.res = res
+        self.in_proj = Linear(idim, n_chans)
+        if strides is None:
+            strides = [1] * n_layers
+        else:
+            assert len(strides) == n_layers
+        for idx in range(n_layers):
+            self.conv.append(ConvBlock(
+                n_chans, n_chans, kernel_size, stride=strides[idx], norm=norm, dropout=dropout))
+        self.out_proj = Linear(n_chans, odim)
+
+    def forward(self, x, return_hiddens=False):
+        """
+
+        :param x: [B, T, H]
+        :return: [B, T, H]
+        """
+        x = self.in_proj(x)
+        x = x.transpose(1, -1)  # (B, idim, Tmax)
+        hiddens = []
+        for f in self.conv:
+            x_ = f(x)
+            x = x + x_ if self.res else x_  # (B, C, Tmax)
+            hiddens.append(x)
+        x = x.transpose(1, -1)
+        x = self.out_proj(x)  # (B, Tmax, H)
+        if return_hiddens:
+            hiddens = torch.stack(hiddens, 1)  # [B, L, C, T]
+            return x, hiddens
+        return x
+
+
+class PitchExtractor(nn.Module):
+    def __init__(self, n_mel_bins=80, conv_layers=2):
+        super().__init__()
+        self.hidden_size = hparams['hidden_size']
+        self.predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
+        self.conv_layers = conv_layers
+
+        self.mel_prenet = Prenet(n_mel_bins, self.hidden_size, strides=[1, 1, 1])
+        if self.conv_layers > 0:
+            self.mel_encoder = ConvStacks(
+                    idim=self.hidden_size, n_chans=self.hidden_size, odim=self.hidden_size, n_layers=self.conv_layers)
+        self.pitch_predictor = PitchPredictor(
+            self.hidden_size, n_chans=self.predictor_hidden,
+            n_layers=5, dropout_rate=0.1, odim=2,
+            padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
+
+    def forward(self, mel_input=None):
+        ret = {}
+        mel_hidden = self.mel_prenet(mel_input)[1]
+        if self.conv_layers > 0:
+            mel_hidden = self.mel_encoder(mel_hidden)
+
+        ret['pitch_pred'] = pitch_pred = self.pitch_predictor(mel_hidden)
+
+        pitch_padding = mel_input.abs().sum(-1) == 0
+        use_uv = hparams['pitch_type'] == 'frame' and hparams['use_uv']
+
+        ret['f0_denorm_pred'] = denorm_f0(
+            pitch_pred[:, :, 0], (pitch_pred[:, :, 1] > 0) if use_uv else None,
+            hparams, pitch_padding=pitch_padding)
+        return ret
\ No newline at end of file
diff --git a/modules/fastspeech/tts_modules.py b/modules/fastspeech/tts_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..195eff279de781dd2565cfb2da65533c58f6c332
--- /dev/null
+++ b/modules/fastspeech/tts_modules.py
@@ -0,0 +1,357 @@
+import logging
+import math
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+from modules.commons.espnet_positional_embedding import RelPositionalEncoding
+from modules.commons.common_layers import SinusoidalPositionalEmbedding, Linear, EncSALayer, DecSALayer, BatchNorm1dTBC
+from utils.hparams import hparams
+
+DEFAULT_MAX_SOURCE_POSITIONS = 2000
+DEFAULT_MAX_TARGET_POSITIONS = 2000
+
+
+class TransformerEncoderLayer(nn.Module):
+    def __init__(self, hidden_size, dropout, kernel_size=None, num_heads=2, norm='ln'):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.dropout = dropout
+        self.num_heads = num_heads
+        self.op = EncSALayer(
+            hidden_size, num_heads, dropout=dropout,
+            attention_dropout=0.0, relu_dropout=dropout,
+            kernel_size=kernel_size
+            if kernel_size is not None else hparams['enc_ffn_kernel_size'],
+            padding=hparams['ffn_padding'],
+            norm=norm, act=hparams['ffn_act'])
+
+    def forward(self, x, **kwargs):
+        return self.op(x, **kwargs)
+
+
+######################
+# fastspeech modules
+######################
+class LayerNorm(torch.nn.LayerNorm):
+    """Layer normalization module.
+    :param int nout: output dim size
+    :param int dim: dimension to be normalized
+    """
+
+    def __init__(self, nout, dim=-1):
+        """Construct an LayerNorm object."""
+        super(LayerNorm, self).__init__(nout, eps=1e-12)
+        self.dim = dim
+
+    def forward(self, x):
+        """Apply layer normalization.
+        :param torch.Tensor x: input tensor
+        :return: layer normalized tensor
+        :rtype torch.Tensor
+        """
+        if self.dim == -1:
+            return super(LayerNorm, self).forward(x)
+        return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
+
+
+class DurationPredictor(torch.nn.Module):
+    """Duration predictor module.
+    This is a module of duration predictor described in `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
+    The duration predictor predicts a duration of each frame in log domain from the hidden embeddings of encoder.
+    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
+        https://arxiv.org/pdf/1905.09263.pdf
+    Note:
+        The output domain differs between `forward` and `inference`: `forward` returns durations in the
+        log domain, while `inference` returns them in the linear domain.
+    """
+
+    def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0, padding='SAME'):
+        """Initilize duration predictor module.
+        Args:
+            idim (int): Input dimension.
+            n_layers (int, optional): Number of convolutional layers.
+            n_chans (int, optional): Number of channels of convolutional layers.
+            kernel_size (int, optional): Kernel size of convolutional layers.
+            dropout_rate (float, optional): Dropout rate.
+            offset (float, optional): Offset value to avoid nan in log domain.
+        """
+        super(DurationPredictor, self).__init__()
+        self.offset = offset
+        self.conv = torch.nn.ModuleList()
+        self.kernel_size = kernel_size
+        self.padding = padding
+        for idx in range(n_layers):
+            in_chans = idim if idx == 0 else n_chans
+            self.conv += [torch.nn.Sequential(
+                torch.nn.ConstantPad1d(((kernel_size - 1) // 2, (kernel_size - 1) // 2)
+                                       if padding == 'SAME'
+                                       else (kernel_size - 1, 0), 0),
+                torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0),
+                torch.nn.ReLU(),
+                LayerNorm(n_chans, dim=1),
+                torch.nn.Dropout(dropout_rate)
+            )]
+        if hparams['dur_loss'] in ['mse', 'huber']:
+            odims = 1
+        elif hparams['dur_loss'] == 'mog':
+            odims = 15
+        elif hparams['dur_loss'] == 'crf':
+            odims = 32
+            from torchcrf import CRF
+            self.crf = CRF(odims, batch_first=True)
+        self.linear = torch.nn.Linear(n_chans, odims)
+
+    def _forward(self, xs, x_masks=None, is_inference=False):
+        xs = xs.transpose(1, -1)  # (B, idim, Tmax)
+        for f in self.conv:
+            xs = f(xs)  # (B, C, Tmax)
+            if x_masks is not None:
+                xs = xs * (1 - x_masks.float())[:, None, :]
+
+        xs = self.linear(xs.transpose(1, -1))  # [B, T, C]
+        if x_masks is not None:
+            xs = xs * (1 - x_masks.float())[:, :, None]  # (B, T, C)
+        if is_inference:
+            return self.out2dur(xs), xs
+        else:
+            if hparams['dur_loss'] in ['mse']:
+                xs = xs.squeeze(-1)  # (B, Tmax)
+        return xs
+
+    def out2dur(self, xs):
+        if hparams['dur_loss'] in ['mse']:
+            # NOTE: calculate in log domain
+            xs = xs.squeeze(-1)  # (B, Tmax)
+            dur = torch.clamp(torch.round(xs.exp() - self.offset), min=0).long()  # avoid negative value
+        elif hparams['dur_loss'] == 'mog':
+            raise NotImplementedError
+        elif hparams['dur_loss'] == 'crf':
+            dur = torch.LongTensor(self.crf.decode(xs)).to(xs.device)
+        return dur
+
+    def forward(self, xs, x_masks=None):
+        """Calculate forward propagation.
+        Args:
+            xs (Tensor): Batch of input sequences (B, Tmax, idim).
+            x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
+        Returns:
+            Tensor: Batch of predicted durations in log domain (B, Tmax).
+        """
+        return self._forward(xs, x_masks, False)
+
+    def inference(self, xs, x_masks=None):
+        """Inference duration.
+        Args:
+            xs (Tensor): Batch of input sequences (B, Tmax, idim).
+            x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
+        Returns:
+            LongTensor: Batch of predicted durations in linear domain (B, Tmax).
+        """
+        return self._forward(xs, x_masks, True)
+
+
+class LengthRegulator(torch.nn.Module):
+    def __init__(self, pad_value=0.0):
+        super(LengthRegulator, self).__init__()
+        self.pad_value = pad_value
+
+    def forward(self, dur, dur_padding=None, alpha=1.0):
+        """
+        Example (no batch dim version):
+            1. dur = [2,2,3]
+            2. token_idx = [[1],[2],[3]], dur_cumsum = [2,4,7], dur_cumsum_prev = [0,2,4]
+            3. token_mask = [[1,1,0,0,0,0,0],
+                             [0,0,1,1,0,0,0],
+                             [0,0,0,0,1,1,1]]
+            4. token_idx * token_mask = [[1,1,0,0,0,0,0],
+                                         [0,0,2,2,0,0,0],
+                                         [0,0,0,0,3,3,3]]
+            5. (token_idx * token_mask).sum(0) = [1,1,2,2,3,3,3]
+
+        :param dur: Batch of durations of each frame (B, T_txt)
+        :param dur_padding: Batch of padding of each frame (B, T_txt)
+        :param alpha: duration rescale coefficient
+        :return:
+            mel2ph (B, T_speech)
+        """
+        assert alpha > 0
+        dur = torch.round(dur.float() * alpha).long()
+        if dur_padding is not None:
+            dur = dur * (1 - dur_padding.long())
+        token_idx = torch.arange(1, dur.shape[1] + 1)[None, :, None].to(dur.device)
+        dur_cumsum = torch.cumsum(dur, 1)
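+        # shift the cumulative durations right by one (pad a zero on the left, drop the last entry)
+        # so each token's start frame is available alongside its end frame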
+        dur_cumsum_prev = F.pad(dur_cumsum, [1, -1], mode='constant', value=0)
+
+        pos_idx = torch.arange(dur.sum(-1).max())[None, None].to(dur.device)
+        token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None])
+        mel2ph = (token_idx * token_mask.long()).sum(1)
+        return mel2ph
+
+
+class PitchPredictor(torch.nn.Module):
+    def __init__(self, idim, n_layers=5, n_chans=384, odim=2, kernel_size=5,
+                 dropout_rate=0.1, padding='SAME'):
+        """Initilize pitch predictor module.
+        Args:
+            idim (int): Input dimension.
+            n_layers (int, optional): Number of convolutional layers.
+            n_chans (int, optional): Number of channels of convolutional layers.
+            kernel_size (int, optional): Kernel size of convolutional layers.
+            dropout_rate (float, optional): Dropout rate.
+        """
+        super(PitchPredictor, self).__init__()
+        self.conv = torch.nn.ModuleList()
+        self.kernel_size = kernel_size
+        self.padding = padding
+        for idx in range(n_layers):
+            in_chans = idim if idx == 0 else n_chans
+            self.conv += [torch.nn.Sequential(
+                torch.nn.ConstantPad1d(((kernel_size - 1) // 2, (kernel_size - 1) // 2)
+                                       if padding == 'SAME'
+                                       else (kernel_size - 1, 0), 0),
+                torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0),
+                torch.nn.ReLU(),
+                LayerNorm(n_chans, dim=1),
+                torch.nn.Dropout(dropout_rate)
+            )]
+        self.linear = torch.nn.Linear(n_chans, odim)
+        self.embed_positions = SinusoidalPositionalEmbedding(idim, 0, init_size=4096)
+        self.pos_embed_alpha = nn.Parameter(torch.Tensor([1]))
+
+    def forward(self, xs):
+        """
+
+        :param xs: [B, T, H]
+        :return: [B, T, H]
+        """
+        positions = self.pos_embed_alpha * self.embed_positions(xs[..., 0])
+        xs = xs + positions
+        xs = xs.transpose(1, -1)  # (B, idim, Tmax)
+        for f in self.conv:
+            xs = f(xs)  # (B, C, Tmax)
+        # NOTE: calculate in log domain
+        xs = self.linear(xs.transpose(1, -1))  # (B, Tmax, H)
+        return xs
+
+
+class EnergyPredictor(PitchPredictor):
+    pass
+
+
+def mel2ph_to_dur(mel2ph, T_txt, max_dur=None):
+    B, _ = mel2ph.shape
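+    # count the mel frames assigned to each token via scatter_add; index 0 collects padding
+    # frames and is dropped below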
+    dur = mel2ph.new_zeros(B, T_txt + 1).scatter_add(1, mel2ph, torch.ones_like(mel2ph))
+    dur = dur[:, 1:]
+    if max_dur is not None:
+        dur = dur.clamp(max=max_dur)
+    return dur
+
+
+class FFTBlocks(nn.Module):
+    def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=None, num_heads=2,
+                 use_pos_embed=True, use_last_norm=True, norm='ln', use_pos_embed_alpha=True):
+        super().__init__()
+        self.num_layers = num_layers
+        embed_dim = self.hidden_size = hidden_size
+        self.dropout = dropout if dropout is not None else hparams['dropout']
+        self.use_pos_embed = use_pos_embed
+        self.use_last_norm = use_last_norm
+        if use_pos_embed:
+            self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
+            self.padding_idx = 0
+            self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
+            self.embed_positions = SinusoidalPositionalEmbedding(
+                embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
+            )
+
+        self.layers = nn.ModuleList([])
+        self.layers.extend([
+            TransformerEncoderLayer(self.hidden_size, self.dropout,
+                                    kernel_size=ffn_kernel_size, num_heads=num_heads)
+            for _ in range(self.num_layers)
+        ])
+        if self.use_last_norm:
+            if norm == 'ln':
+                self.layer_norm = nn.LayerNorm(embed_dim)
+            elif norm == 'bn':
+                self.layer_norm = BatchNorm1dTBC(embed_dim)
+        else:
+            self.layer_norm = None
+
+    def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
+        """
+        :param x: [B, T, C]
+        :param padding_mask: [B, T]
+        :return: [B, T, C] or [L, B, T, C]
+        """
+        padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
+        nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None]  # [T, B, 1]
+        if self.use_pos_embed:
+            positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
+            x = x + positions
+            x = F.dropout(x, p=self.dropout, training=self.training)
+        # B x T x C -> T x B x C
+        x = x.transpose(0, 1) * nonpadding_mask_TB
+        hiddens = []
+        for layer in self.layers:
+            x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
+            hiddens.append(x)
+        if self.use_last_norm:
+            x = self.layer_norm(x) * nonpadding_mask_TB
+        if return_hiddens:
+            x = torch.stack(hiddens, 0)  # [L, T, B, C]
+            x = x.transpose(1, 2)  # [L, B, T, C]
+        else:
+            x = x.transpose(0, 1)  # [B, T, C]
+        return x
+
+
+class FastspeechEncoder(FFTBlocks):
+    def __init__(self, embed_tokens, hidden_size=None, num_layers=None, kernel_size=None, num_heads=2):
+        hidden_size = hparams['hidden_size'] if hidden_size is None else hidden_size
+        kernel_size = hparams['enc_ffn_kernel_size'] if kernel_size is None else kernel_size
+        num_layers = hparams['dec_layers'] if num_layers is None else num_layers
+        super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads,
+                         use_pos_embed=False)  # use_pos_embed_alpha for compatibility
+        self.embed_tokens = embed_tokens
+        self.embed_scale = math.sqrt(hidden_size)
+        self.padding_idx = 0
+        if hparams.get('rel_pos') is not None and hparams['rel_pos']:
+            self.embed_positions = RelPositionalEncoding(hidden_size, dropout_rate=0.0)
+        else:
+            self.embed_positions = SinusoidalPositionalEmbedding(
+                hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
+            )
+
+    def forward(self, txt_tokens):
+        """
+
+        :param txt_tokens: [B, T]
+        :return: [B, T, C]
+        """
+        encoder_padding_mask = txt_tokens.eq(self.padding_idx).data
+        x = self.forward_embedding(txt_tokens)  # [B, T, H]
+        x = super(FastspeechEncoder, self).forward(x, encoder_padding_mask)
+        return x
+
+    def forward_embedding(self, txt_tokens):
+        # embed tokens and positions
+        x = self.embed_scale * self.embed_tokens(txt_tokens)
+        if hparams['use_pos_embed']:
+            positions = self.embed_positions(txt_tokens)
+            x = x + positions
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        return x
+
+
+class FastspeechDecoder(FFTBlocks):
+    def __init__(self, hidden_size=None, num_layers=None, kernel_size=None, num_heads=None):
+        num_heads = hparams['num_heads'] if num_heads is None else num_heads
+        hidden_size = hparams['hidden_size'] if hidden_size is None else hidden_size
+        kernel_size = hparams['dec_ffn_kernel_size'] if kernel_size is None else kernel_size
+        num_layers = hparams['dec_layers'] if num_layers is None else num_layers
+        super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads)
+
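+
+if __name__ == '__main__':
+    # Minimal sanity-check sketch (illustrative only; run as a module from the
+    # repository root so the imports above resolve): mel2ph_to_dur inverts the
+    # frame-to-token map produced by the duration expansion above, e.g. the docstring
+    # example [1, 1, 2, 2, 3, 3, 3] maps back to the durations [2, 2, 3].
+    _mel2ph = torch.LongTensor([[1, 1, 2, 2, 3, 3, 3]])
+    print(mel2ph_to_dur(_mel2ph, T_txt=3))  # expected: tensor([[2, 2, 3]])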
diff --git a/modules/hifigan/hifigan.py b/modules/hifigan/hifigan.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae7e61f56b00d60bcc49a18ece3edbe54746f7ea
--- /dev/null
+++ b/modules/hifigan/hifigan.py
@@ -0,0 +1,365 @@
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+
+from modules.parallel_wavegan.layers import UpsampleNetwork, ConvInUpsampleNetwork
+from modules.parallel_wavegan.models.source import SourceModuleHnNSF
+import numpy as np
+
+LRELU_SLOPE = 0.1
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def apply_weight_norm(m):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        weight_norm(m)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+class ResBlock1(torch.nn.Module):
+    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super(ResBlock1, self).__init__()
+        self.h = h
+        self.convs1 = nn.ModuleList([
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+                               padding=get_padding(kernel_size, dilation[0]))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+                               padding=get_padding(kernel_size, dilation[1]))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
+                               padding=get_padding(kernel_size, dilation[2])))
+        ])
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList([
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                               padding=get_padding(kernel_size, 1))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                               padding=get_padding(kernel_size, 1))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                               padding=get_padding(kernel_size, 1)))
+        ])
+        self.convs2.apply(init_weights)
+
+    def forward(self, x):
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c1(xt)
+            xt = F.leaky_relu(xt, LRELU_SLOPE)
+            xt = c2(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+
+class ResBlock2(torch.nn.Module):
+    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
+        super(ResBlock2, self).__init__()
+        self.h = h
+        self.convs = nn.ModuleList([
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+                               padding=get_padding(kernel_size, dilation[0]))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+                               padding=get_padding(kernel_size, dilation[1])))
+        ])
+        self.convs.apply(init_weights)
+
+    def forward(self, x):
+        for c in self.convs:
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs:
+            remove_weight_norm(l)
+
+
+class Conv1d1x1(Conv1d):
+    """1x1 Conv1d with customized initialization."""
+
+    def __init__(self, in_channels, out_channels, bias):
+        """Initialize 1x1 Conv1d module."""
+        super(Conv1d1x1, self).__init__(in_channels, out_channels,
+                                        kernel_size=1, padding=0,
+                                        dilation=1, bias=bias)
+
+
+class HifiGanGenerator(torch.nn.Module):
+    def __init__(self, h, c_out=1):
+        super(HifiGanGenerator, self).__init__()
+        self.h = h
+        self.num_kernels = len(h['resblock_kernel_sizes'])
+        self.num_upsamples = len(h['upsample_rates'])
+
+        if h['use_pitch_embed']:
+            self.harmonic_num = 8
+            self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h['upsample_rates']))
+            self.m_source = SourceModuleHnNSF(
+                sampling_rate=h['audio_sample_rate'],
+                harmonic_num=self.harmonic_num)
+            self.noise_convs = nn.ModuleList()
+        self.conv_pre = weight_norm(Conv1d(80, h['upsample_initial_channel'], 7, 1, padding=3))
+        resblock = ResBlock1 if h['resblock'] == '1' else ResBlock2
+
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(h['upsample_rates'], h['upsample_kernel_sizes'])):
+            c_cur = h['upsample_initial_channel'] // (2 ** (i + 1))
+            self.ups.append(weight_norm(
+                ConvTranspose1d(c_cur * 2, c_cur, k, u, padding=(k - u) // 2)))
+            if h['use_pitch_embed']:
+                if i + 1 < len(h['upsample_rates']):
+                    stride_f0 = np.prod(h['upsample_rates'][i + 1:])
+                    self.noise_convs.append(Conv1d(
+                        1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
+                else:
+                    self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = h['upsample_initial_channel'] // (2 ** (i + 1))
+            for j, (k, d) in enumerate(zip(h['resblock_kernel_sizes'], h['resblock_dilation_sizes'])):
+                self.resblocks.append(resblock(h, ch, k, d))
+
+        self.conv_post = weight_norm(Conv1d(ch, c_out, 7, 1, padding=3))
+        self.ups.apply(init_weights)
+        self.conv_post.apply(init_weights)
+
+    def forward(self, x, f0=None):
+        if f0 is not None:
+            # harmonic-source signal, noise-source signal, uv flag
+            f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)
+            har_source, noi_source, uv = self.m_source(f0)
+            har_source = har_source.transpose(1, 2)
+
+        x = self.conv_pre(x)
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            x = self.ups[i](x)
+            if f0 is not None:
+                x_source = self.noise_convs[i](har_source)
+                x = x + x_source
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_weight_norm(self):
+        print('Removing weight norm...')
+        for l in self.ups:
+            remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
+        remove_weight_norm(self.conv_pre)
+        remove_weight_norm(self.conv_post)
+
+
+class DiscriminatorP(torch.nn.Module):
+    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, use_cond=False, c_in=1):
+        super(DiscriminatorP, self).__init__()
+        self.use_cond = use_cond
+        if use_cond:
+            from utils.hparams import hparams
+            t = hparams['hop_size']
+            self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
+            c_in = 2
+
+        self.period = period
+        norm_f = spectral_norm if use_spectral_norm else weight_norm
+        self.convs = nn.ModuleList([
+            norm_f(Conv2d(c_in, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
+        ])
+        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+    def forward(self, x, mel):
+        fmap = []
+        if self.use_cond:
+            x_mel = self.cond_net(mel)
+            x = torch.cat([x_mel, x], 1)
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0:  # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class MultiPeriodDiscriminator(torch.nn.Module):
+    def __init__(self, use_cond=False, c_in=1):
+        super(MultiPeriodDiscriminator, self).__init__()
+        self.discriminators = nn.ModuleList([
+            DiscriminatorP(2, use_cond=use_cond, c_in=c_in),
+            DiscriminatorP(3, use_cond=use_cond, c_in=c_in),
+            DiscriminatorP(5, use_cond=use_cond, c_in=c_in),
+            DiscriminatorP(7, use_cond=use_cond, c_in=c_in),
+            DiscriminatorP(11, use_cond=use_cond, c_in=c_in),
+        ])
+
+    def forward(self, y, y_hat, mel=None):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y, mel)
+            y_d_g, fmap_g = d(y_hat, mel)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorS(torch.nn.Module):
+    def __init__(self, use_spectral_norm=False, use_cond=False, upsample_rates=None, c_in=1):
+        super(DiscriminatorS, self).__init__()
+        self.use_cond = use_cond
+        if use_cond:
+            t = np.prod(upsample_rates)
+            self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
+            c_in = 2
+        norm_f = spectral_norm if use_spectral_norm else weight_norm
+        self.convs = nn.ModuleList([
+            norm_f(Conv1d(c_in, 128, 15, 1, padding=7)),
+            norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
+            norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
+            norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
+            norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
+            norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
+            norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+        ])
+        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+    def forward(self, x, mel):
+        if self.use_cond:
+            x_mel = self.cond_net(mel)
+            x = torch.cat([x_mel, x], 1)
+        fmap = []
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class MultiScaleDiscriminator(torch.nn.Module):
+    def __init__(self, use_cond=False, c_in=1):
+        super(MultiScaleDiscriminator, self).__init__()
+        from utils.hparams import hparams
+        self.discriminators = nn.ModuleList([
+            DiscriminatorS(use_spectral_norm=True, use_cond=use_cond,
+                           upsample_rates=[4, 4, hparams['hop_size'] // 16],
+                           c_in=c_in),
+            DiscriminatorS(use_cond=use_cond,
+                           upsample_rates=[4, 4, hparams['hop_size'] // 32],
+                           c_in=c_in),
+            DiscriminatorS(use_cond=use_cond,
+                           upsample_rates=[4, 4, hparams['hop_size'] // 64],
+                           c_in=c_in),
+        ])
+        self.meanpools = nn.ModuleList([
+            AvgPool1d(4, 2, padding=1),
+            AvgPool1d(4, 2, padding=1)
+        ])
+
+    def forward(self, y, y_hat, mel=None):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            if i != 0:
+                y = self.meanpools[i - 1](y)
+                y_hat = self.meanpools[i - 1](y_hat)
+            y_d_r, fmap_r = d(y, mel)
+            y_d_g, fmap_g = d(y_hat, mel)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+def feature_loss(fmap_r, fmap_g):
+    loss = 0
+    for dr, dg in zip(fmap_r, fmap_g):
+        for rl, gl in zip(dr, dg):
+            loss += torch.mean(torch.abs(rl - gl))
+
+    return loss * 2
+
+
+def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+    r_losses = 0
+    g_losses = 0
+    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+        r_loss = torch.mean((1 - dr) ** 2)
+        g_loss = torch.mean(dg ** 2)
+        r_losses += r_loss
+        g_losses += g_loss
+    r_losses = r_losses / len(disc_real_outputs)
+    g_losses = g_losses / len(disc_real_outputs)
+    return r_losses, g_losses
+
+
+def cond_discriminator_loss(outputs):
+    loss = 0
+    for dg in outputs:
+        g_loss = torch.mean(dg ** 2)
+        loss += g_loss
+    loss = loss / len(outputs)
+    return loss
+
+
+def generator_loss(disc_outputs):
+    loss = 0
+    for dg in disc_outputs:
+        l = torch.mean((1 - dg) ** 2)
+        loss += l
+    loss = loss / len(disc_outputs)
+    return loss
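+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative only; run as `python -m modules.hifigan.hifigan`
+    # from the repository root). The config values below are common HiFi-GAN defaults
+    # assumed for illustration, not necessarily this repository's actual settings.
+    _h = {
+        'resblock': '1',
+        'resblock_kernel_sizes': [3, 7, 11],
+        'resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+        'upsample_rates': [8, 8, 2, 2],
+        'upsample_kernel_sizes': [16, 16, 4, 4],
+        'upsample_initial_channel': 512,
+        'use_pitch_embed': False,
+    }
+    _generator = HifiGanGenerator(_h)
+    _mel = torch.randn(1, 80, 32)   # [B, n_mels, T_frames]
+    _wav = _generator(_mel)         # [B, 1, T_frames * prod(upsample_rates)]
+    print(_wav.shape)               # expected: torch.Size([1, 1, 8192])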
diff --git a/modules/hifigan/mel_utils.py b/modules/hifigan/mel_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..06e0f7d4d16fa3e4aefc8949347455f5a6e938da
--- /dev/null
+++ b/modules/hifigan/mel_utils.py
@@ -0,0 +1,80 @@
+import numpy as np
+import torch
+import torch.utils.data
+from librosa.filters import mel as librosa_mel_fn
+from scipy.io.wavfile import read
+
+MAX_WAV_VALUE = 32768.0
+
+
+def load_wav(full_path):
+    sampling_rate, data = read(full_path)
+    return data, sampling_rate
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+    return np.exp(x) / C
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+    return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+    return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+    output = dynamic_range_compression_torch(magnitudes)
+    return output
+
+
+def spectral_de_normalize_torch(magnitudes):
+    output = dynamic_range_decompression_torch(magnitudes)
+    return output
+
+
+mel_basis = {}
+hann_window = {}
+
+
+def mel_spectrogram(y, hparams, center=False, complex=False):
+    # hop_size: 512  # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate)
+    # win_size: 2048  # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate)
+    # fmin: 55  # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
+    # fmax: 10000  # To be increased/reduced depending on data.
+    # fft_size: 2048  # Extra window size is filled with 0 paddings to match this parameter
+    # n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax,
+    n_fft = hparams['fft_size']
+    num_mels = hparams['audio_num_mel_bins']
+    sampling_rate = hparams['audio_sample_rate']
+    hop_size = hparams['hop_size']
+    win_size = hparams['win_size']
+    fmin = hparams['fmin']
+    fmax = hparams['fmax']
+    y = y.clamp(min=-1., max=1.)
+    global mel_basis, hann_window
+    # cache the mel basis and window per (fmax, device) so they are built only once
+    if str(fmax) + '_' + str(y.device) not in mel_basis:
+        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+        mel_basis[str(fmax) + '_' + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+                                mode='reflect')
+    y = y.squeeze(1)
+
+    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
+                      center=center, pad_mode='reflect', normalized=False, onesided=True)
+
+    if not complex:
+        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+        spec = torch.matmul(mel_basis[str(fmax) + '_' + str(y.device)], spec)
+        spec = spectral_normalize_torch(spec)
+    else:
+        B, C, T, _ = spec.shape
+        spec = spec.transpose(1, 2)  # [B, T, n_fft, 2]
+    return spec
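+
+
+# Shape note (assumed from the code above): `hparams` must provide 'fft_size',
+# 'audio_num_mel_bins', 'audio_sample_rate', 'hop_size', 'win_size', 'fmin' and 'fmax'.
+# With center=False and the manual reflect padding above, a waveform batch y of shape
+# [B, T_samples] (T_samples a multiple of hop_size) yields a log-mel spectrogram of
+# shape [B, audio_num_mel_bins, T_samples // hop_size].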
diff --git a/modules/parallel_wavegan/__init__.py b/modules/parallel_wavegan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/parallel_wavegan/layers/__init__.py b/modules/parallel_wavegan/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e477f51116a3157781b1aefefbaf32fe4d4bd1f0
--- /dev/null
+++ b/modules/parallel_wavegan/layers/__init__.py
@@ -0,0 +1,5 @@
+from .causal_conv import *  # NOQA
+from .pqmf import *  # NOQA
+from .residual_block import *  # NOQA
+from modules.parallel_wavegan.layers.residual_stack import *  # NOQA
+from .upsample import *  # NOQA
diff --git a/modules/parallel_wavegan/layers/causal_conv.py b/modules/parallel_wavegan/layers/causal_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..fca77daf65f234e6fbe355ed148fc8f0ee85038a
--- /dev/null
+++ b/modules/parallel_wavegan/layers/causal_conv.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Tomoki Hayashi
+#  MIT License (https://opensource.org/licenses/MIT)
+
+"""Causal convolusion layer modules."""
+
+
+import torch
+
+
+class CausalConv1d(torch.nn.Module):
+    """CausalConv1d module with customized initialization."""
+
+    def __init__(self, in_channels, out_channels, kernel_size,
+                 dilation=1, bias=True, pad="ConstantPad1d", pad_params={"value": 0.0}):
+        """Initialize CausalConv1d module."""
+        super(CausalConv1d, self).__init__()
+        self.pad = getattr(torch.nn, pad)((kernel_size - 1) * dilation, **pad_params)
+        self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size,
+                                    dilation=dilation, bias=bias)
+
+    def forward(self, x):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Input tensor (B, in_channels, T).
+
+        Returns:
+            Tensor: Output tensor (B, out_channels, T).
+
+        """
+        return self.conv(self.pad(x))[:, :, :x.size(2)]
+
+
+class CausalConvTranspose1d(torch.nn.Module):
+    """CausalConvTranspose1d module with customized initialization."""
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True):
+        """Initialize CausalConvTranspose1d module."""
+        super(CausalConvTranspose1d, self).__init__()
+        self.deconv = torch.nn.ConvTranspose1d(
+            in_channels, out_channels, kernel_size, stride, bias=bias)
+        self.stride = stride
+
+    def forward(self, x):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Input tensor (B, in_channels, T_in).
+
+        Returns:
+            Tensor: Output tensor (B, out_channels, T_out).
+
+        """
+        return self.deconv(x)[:, :, :-self.stride]
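+
+
+if __name__ == '__main__':
+    # Minimal sanity-check sketch (illustrative only): CausalConv1d pads by
+    # (kernel_size - 1) * dilation and trims the output back to the input length,
+    # so each output step depends only on current and past inputs.
+    _x = torch.randn(1, 3, 10)
+    _conv = CausalConv1d(3, 5, kernel_size=3, dilation=2)
+    print(_conv(_x).shape)  # expected: torch.Size([1, 5, 10])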
diff --git a/modules/parallel_wavegan/layers/pqmf.py b/modules/parallel_wavegan/layers/pqmf.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac21074fd32a370a099fa2facb62cfd3253d7579
--- /dev/null
+++ b/modules/parallel_wavegan/layers/pqmf.py
@@ -0,0 +1,129 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Tomoki Hayashi
+#  MIT License (https://opensource.org/licenses/MIT)
+
+"""Pseudo QMF modules."""
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from scipy.signal import kaiser
+
+
+def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0):
+    """Design prototype filter for PQMF.
+
+    This method is based on `A Kaiser window approach for the design of prototype
+    filters of cosine modulated filterbanks`_.
+
+    Args:
+        taps (int): The number of filter taps.
+        cutoff_ratio (float): Cut-off frequency ratio.
+        beta (float): Beta coefficient for kaiser window.
+
+    Returns:
+        ndarray: Impulse response of prototype filter (taps + 1,).
+
+    .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
+        https://ieeexplore.ieee.org/abstract/document/681427
+
+    """
+    # check the arguments are valid
+    assert taps % 2 == 0, "The number of taps must be an even number."
+    assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."
+
+    # make initial filter
+    omega_c = np.pi * cutoff_ratio
+    with np.errstate(invalid='ignore'):
+        h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) \
+            / (np.pi * (np.arange(taps + 1) - 0.5 * taps))
+    h_i[taps // 2] = np.cos(0) * cutoff_ratio  # fix nan due to indeterminate form
+
+    # apply kaiser window
+    w = kaiser(taps + 1, beta)
+    h = h_i * w
+
+    return h
+
+
+class PQMF(torch.nn.Module):
+    """PQMF module.
+
+    This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.
+
+    .. _`Near-perfect-reconstruction pseudo-QMF banks`:
+        https://ieeexplore.ieee.org/document/258122
+
+    """
+
+    def __init__(self, subbands=4, taps=62, cutoff_ratio=0.15, beta=9.0):
+        """Initilize PQMF module.
+
+        Args:
+            subbands (int): The number of subbands.
+            taps (int): The number of filter taps.
+            cutoff_ratio (float): Cut-off frequency ratio.
+            beta (float): Beta coefficient for kaiser window.
+
+        """
+        super(PQMF, self).__init__()
+
+        # define filter coefficient
+        h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
+        h_analysis = np.zeros((subbands, len(h_proto)))
+        h_synthesis = np.zeros((subbands, len(h_proto)))
+        for k in range(subbands):
+            h_analysis[k] = 2 * h_proto * np.cos(
+                (2 * k + 1) * (np.pi / (2 * subbands)) *
+                (np.arange(taps + 1) - ((taps - 1) / 2)) +
+                (-1) ** k * np.pi / 4)
+            h_synthesis[k] = 2 * h_proto * np.cos(
+                (2 * k + 1) * (np.pi / (2 * subbands)) *
+                (np.arange(taps + 1) - ((taps - 1) / 2)) -
+                (-1) ** k * np.pi / 4)
+
+        # convert to tensor
+        analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1)
+        synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0)
+
+        # register coefficients as buffers
+        self.register_buffer("analysis_filter", analysis_filter)
+        self.register_buffer("synthesis_filter", synthesis_filter)
+
+        # filter for downsampling & upsampling
+        updown_filter = torch.zeros((subbands, subbands, subbands)).float()
+        for k in range(subbands):
+            updown_filter[k, k, 0] = 1.0
+        self.register_buffer("updown_filter", updown_filter)
+        self.subbands = subbands
+
+        # keep padding info
+        self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)
+
+    def analysis(self, x):
+        """Analysis with PQMF.
+
+        Args:
+            x (Tensor): Input tensor (B, 1, T).
+
+        Returns:
+            Tensor: Output tensor (B, subbands, T // subbands).
+
+        """
+        x = F.conv1d(self.pad_fn(x), self.analysis_filter)
+        return F.conv1d(x, self.updown_filter, stride=self.subbands)
+
+    def synthesis(self, x):
+        """Synthesis with PQMF.
+
+        Args:
+            x (Tensor): Input tensor (B, subbands, T // subbands).
+
+        Returns:
+            Tensor: Output tensor (B, 1, T).
+
+        """
+        x = F.conv_transpose1d(x, self.updown_filter * self.subbands, stride=self.subbands)
+        return F.conv1d(self.pad_fn(x), self.synthesis_filter)
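+
+
+if __name__ == '__main__':
+    # Minimal sanity-check sketch (illustrative only): analysis splits a waveform into
+    # `subbands` critically sampled bands, and synthesis maps them back to a
+    # near-perfect reconstruction of the input.
+    _pqmf = PQMF(subbands=4)
+    _x = torch.randn(1, 1, 8192)
+    _bands = _pqmf.analysis(_x)       # expected shape: [1, 4, 2048]
+    _x_hat = _pqmf.synthesis(_bands)  # expected shape: [1, 1, 8192]
+    print(_bands.shape, _x_hat.shape)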
diff --git a/modules/parallel_wavegan/layers/residual_block.py b/modules/parallel_wavegan/layers/residual_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a267a86c1fa521c2824addf9dda304c43f1ff1f
--- /dev/null
+++ b/modules/parallel_wavegan/layers/residual_block.py
@@ -0,0 +1,129 @@
+# -*- coding: utf-8 -*-
+
+"""Residual block module in WaveNet.
+
+This code is modified from https://github.com/r9y9/wavenet_vocoder.
+
+"""
+
+import math
+
+import torch
+import torch.nn.functional as F
+
+
+class Conv1d(torch.nn.Conv1d):
+    """Conv1d module with customized initialization."""
+
+    def __init__(self, *args, **kwargs):
+        """Initialize Conv1d module."""
+        super(Conv1d, self).__init__(*args, **kwargs)
+
+    def reset_parameters(self):
+        """Reset parameters."""
+        torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
+        if self.bias is not None:
+            torch.nn.init.constant_(self.bias, 0.0)
+
+
+class Conv1d1x1(Conv1d):
+    """1x1 Conv1d with customized initialization."""
+
+    def __init__(self, in_channels, out_channels, bias):
+        """Initialize 1x1 Conv1d module."""
+        super(Conv1d1x1, self).__init__(in_channels, out_channels,
+                                        kernel_size=1, padding=0,
+                                        dilation=1, bias=bias)
+
+
+class ResidualBlock(torch.nn.Module):
+    """Residual block module in WaveNet."""
+
+    def __init__(self,
+                 kernel_size=3,
+                 residual_channels=64,
+                 gate_channels=128,
+                 skip_channels=64,
+                 aux_channels=80,
+                 dropout=0.0,
+                 dilation=1,
+                 bias=True,
+                 use_causal_conv=False
+                 ):
+        """Initialize ResidualBlock module.
+
+        Args:
+            kernel_size (int): Kernel size of dilation convolution layer.
+            residual_channels (int): Number of channels for residual connection.
+            skip_channels (int): Number of channels for skip connection.
+            aux_channels (int): Local conditioning channels i.e. auxiliary input dimension.
+            dropout (float): Dropout probability.
+            dilation (int): Dilation factor.
+            bias (bool): Whether to add bias parameter in convolution layers.
+            use_causal_conv (bool): Whether to use causal or non-causal convolution.
+
+        """
+        super(ResidualBlock, self).__init__()
+        self.dropout = dropout
+        # no future time stamps available
+        if use_causal_conv:
+            padding = (kernel_size - 1) * dilation
+        else:
+            assert (kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."
+            padding = (kernel_size - 1) // 2 * dilation
+        self.use_causal_conv = use_causal_conv
+
+        # dilation conv
+        self.conv = Conv1d(residual_channels, gate_channels, kernel_size,
+                           padding=padding, dilation=dilation, bias=bias)
+
+        # local conditioning
+        if aux_channels > 0:
+            self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False)
+        else:
+            self.conv1x1_aux = None
+
+        # conv output is split into two groups
+        gate_out_channels = gate_channels // 2
+        self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias)
+        self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_channels, bias=bias)
+
+    def forward(self, x, c):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Input tensor (B, residual_channels, T).
+            c (Tensor): Local conditioning auxiliary tensor (B, aux_channels, T).
+
+        Returns:
+            Tensor: Output tensor for residual connection (B, residual_channels, T).
+            Tensor: Output tensor for skip connection (B, skip_channels, T).
+
+        """
+        residual = x
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        x = self.conv(x)
+
+        # remove future time steps if using causal conv
+        x = x[:, :, :residual.size(-1)] if self.use_causal_conv else x
+
+        # split into two parts for gated activation
+        splitdim = 1
+        xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)
+
+        # local conditioning
+        if c is not None:
+            assert self.conv1x1_aux is not None
+            c = self.conv1x1_aux(c)
+            ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
+            xa, xb = xa + ca, xb + cb
+
+        x = torch.tanh(xa) * torch.sigmoid(xb)
+
+        # for skip connection
+        s = self.conv1x1_skip(x)
+
+        # for residual connection
+        x = (self.conv1x1_out(x) + residual) * math.sqrt(0.5)
+
+        return x, s
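+
+
+if __name__ == '__main__':
+    # Minimal sanity-check sketch (illustrative only): with the default settings the
+    # block preserves the time axis and returns the residual output together with the
+    # skip-connection output.
+    _x = torch.randn(2, 64, 100)   # (B, residual_channels, T)
+    _c = torch.randn(2, 80, 100)   # (B, aux_channels, T)
+    _res, _skip = ResidualBlock()(_x, _c)
+    print(_res.shape, _skip.shape)  # expected: torch.Size([2, 64, 100]) for both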
diff --git a/modules/parallel_wavegan/layers/residual_stack.py b/modules/parallel_wavegan/layers/residual_stack.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e07c8803ad348dd923f6b7c0f7aff14aab9cf78
--- /dev/null
+++ b/modules/parallel_wavegan/layers/residual_stack.py
@@ -0,0 +1,75 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Tomoki Hayashi
+#  MIT License (https://opensource.org/licenses/MIT)
+
+"""Residual stack module in MelGAN."""
+
+import torch
+
+from . import CausalConv1d
+
+
+class ResidualStack(torch.nn.Module):
+    """Residual stack module introduced in MelGAN."""
+
+    def __init__(self,
+                 kernel_size=3,
+                 channels=32,
+                 dilation=1,
+                 bias=True,
+                 nonlinear_activation="LeakyReLU",
+                 nonlinear_activation_params={"negative_slope": 0.2},
+                 pad="ReflectionPad1d",
+                 pad_params={},
+                 use_causal_conv=False,
+                 ):
+        """Initialize ResidualStack module.
+
+        Args:
+            kernel_size (int): Kernel size of dilation convolution layer.
+            channels (int): Number of channels of convolution layers.
+            dilation (int): Dilation factor.
+            bias (bool): Whether to add bias parameter in convolution layers.
+            nonlinear_activation (str): Activation function module name.
+            nonlinear_activation_params (dict): Hyperparameters for activation function.
+            pad (str): Padding function module name before dilated convolution layer.
+            pad_params (dict): Hyperparameters for padding function.
+            use_causal_conv (bool): Whether to use causal convolution.
+
+        """
+        super(ResidualStack, self).__init__()
+
+        # define residual stack part
+        if not use_causal_conv:
+            assert (kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."
+            self.stack = torch.nn.Sequential(
+                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+                getattr(torch.nn, pad)((kernel_size - 1) // 2 * dilation, **pad_params),
+                torch.nn.Conv1d(channels, channels, kernel_size, dilation=dilation, bias=bias),
+                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+                torch.nn.Conv1d(channels, channels, 1, bias=bias),
+            )
+        else:
+            self.stack = torch.nn.Sequential(
+                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+                CausalConv1d(channels, channels, kernel_size, dilation=dilation,
+                             bias=bias, pad=pad, pad_params=pad_params),
+                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+                torch.nn.Conv1d(channels, channels, 1, bias=bias),
+            )
+
+        # define extra layer for skip connection
+        self.skip_layer = torch.nn.Conv1d(channels, channels, 1, bias=bias)
+
+    def forward(self, c):
+        """Calculate forward propagation.
+
+        Args:
+            c (Tensor): Input tensor (B, channels, T).
+
+        Returns:
+            Tensor: Output tensor (B, channels, T).
+
+        """
+        return self.stack(c) + self.skip_layer(c)
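+
+
+if __name__ == '__main__':
+    # Minimal sanity-check sketch (illustrative only; run as
+    # `python -m modules.parallel_wavegan.layers.residual_stack` so the relative
+    # import resolves): the stack is shape-preserving.
+    _c = torch.randn(1, 32, 50)
+    print(ResidualStack()(_c).shape)  # expected: torch.Size([1, 32, 50])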
diff --git a/modules/parallel_wavegan/layers/tf_layers.py b/modules/parallel_wavegan/layers/tf_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0f46bd755c161cda2ac904fe37f3f3c6357a88d
--- /dev/null
+++ b/modules/parallel_wavegan/layers/tf_layers.py
@@ -0,0 +1,129 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 MINH ANH (@dathudeptrai)
+#  MIT License (https://opensource.org/licenses/MIT)
+
+"""Tensorflow Layer modules complatible with pytorch."""
+
+import tensorflow as tf
+
+
+class TFReflectionPad1d(tf.keras.layers.Layer):
+    """Tensorflow ReflectionPad1d module."""
+
+    def __init__(self, padding_size):
+        """Initialize TFReflectionPad1d module.
+
+        Args:
+            padding_size (int): Padding size.
+
+        """
+        super(TFReflectionPad1d, self).__init__()
+        self.padding_size = padding_size
+
+    @tf.function
+    def call(self, x):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Input tensor (B, T, 1, C).
+
+        Returns:
+            Tensor: Padded tensor (B, T + 2 * padding_size, 1, C).
+
+        """
+        return tf.pad(x, [[0, 0], [self.padding_size, self.padding_size], [0, 0], [0, 0]], "REFLECT")
+
+
+class TFConvTranspose1d(tf.keras.layers.Layer):
+    """Tensorflow ConvTranspose1d module."""
+
+    def __init__(self, channels, kernel_size, stride, padding):
+        """Initialize TFConvTranspose1d( module.
+
+        Args:
+            channels (int): Number of channels.
+            kernel_size (int): kernel size.
+            stride (int): Stride width.
+            padding (str): Padding type ("same" or "valid").
+
+        """
+        super(TFConvTranspose1d, self).__init__()
+        self.conv1d_transpose = tf.keras.layers.Conv2DTranspose(
+            filters=channels,
+            kernel_size=(kernel_size, 1),
+            strides=(stride, 1),
+            padding=padding,
+        )
+
+    @tf.function
+    def call(self, x):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Input tensor (B, T, 1, C).
+
+        Returns:
+            Tensor: Output tensor (B, T', 1, C').
+
+        """
+        x = self.conv1d_transpose(x)
+        return x
+
+
+class TFResidualStack(tf.keras.layers.Layer):
+    """Tensorflow ResidualStack module."""
+
+    def __init__(self,
+                 kernel_size,
+                 channels,
+                 dilation,
+                 bias,
+                 nonlinear_activation,
+                 nonlinear_activation_params,
+                 padding,
+                 ):
+        """Initialize TFResidualStack module.
+
+        Args:
+            kernel_size (int): Kernel size.
+            channels (int): Number of channels.
+            dilation (int): Dilation size.
+            bias (bool): Whether to add bias parameter in convolution layers.
+            nonlinear_activation (str): Activation function module name.
+            nonlinear_activation_params (dict): Hyperparameters for activation function.
+            padding (str): Padding type ("same" or "valid").
+
+        """
+        super(TFResidualStack, self).__init__()
+        self.block = [
+            getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params),
+            TFReflectionPad1d(dilation),
+            tf.keras.layers.Conv2D(
+                filters=channels,
+                kernel_size=(kernel_size, 1),
+                dilation_rate=(dilation, 1),
+                use_bias=bias,
+                padding="valid",
+            ),
+            getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params),
+            tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias)
+        ]
+        self.shortcut = tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias)
+
+    @tf.function
+    def call(self, x):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Input tensor (B, T, 1, C).
+
+        Returns:
+            Tensor: Output tensor (B, T, 1, C).
+
+        """
+        _x = tf.identity(x)
+        for i, layer in enumerate(self.block):
+            _x = layer(_x)
+        shortcut = self.shortcut(x)
+        return shortcut + _x
diff --git a/modules/parallel_wavegan/layers/upsample.py b/modules/parallel_wavegan/layers/upsample.py
new file mode 100644
index 0000000000000000000000000000000000000000..18c6397c420a81fadc5320e3a48f3249534decd8
--- /dev/null
+++ b/modules/parallel_wavegan/layers/upsample.py
@@ -0,0 +1,183 @@
+# -*- coding: utf-8 -*-
+
+"""Upsampling module.
+
+This code is modified from https://github.com/r9y9/wavenet_vocoder.
+
+"""
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from . import Conv1d
+
+
+class Stretch2d(torch.nn.Module):
+    """Stretch2d module."""
+
+    def __init__(self, x_scale, y_scale, mode="nearest"):
+        """Initialize Stretch2d module.
+
+        Args:
+            x_scale (int): X scaling factor (Time axis in spectrogram).
+            y_scale (int): Y scaling factor (Frequency axis in spectrogram).
+            mode (str): Interpolation mode.
+
+        """
+        super(Stretch2d, self).__init__()
+        self.x_scale = x_scale
+        self.y_scale = y_scale
+        self.mode = mode
+
+    def forward(self, x):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Input tensor (B, C, F, T).
+
+        Returns:
+            Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale).
+
+        """
+        return F.interpolate(
+            x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode)
+
+
+class Conv2d(torch.nn.Conv2d):
+    """Conv2d module with customized initialization."""
+
+    def __init__(self, *args, **kwargs):
+        """Initialize Conv2d module."""
+        super(Conv2d, self).__init__(*args, **kwargs)
+
+    def reset_parameters(self):
+        """Reset parameters."""
+        self.weight.data.fill_(1. / np.prod(self.kernel_size))
+        if self.bias is not None:
+            torch.nn.init.constant_(self.bias, 0.0)
+
+
+class UpsampleNetwork(torch.nn.Module):
+    """Upsampling network module."""
+
+    def __init__(self,
+                 upsample_scales,
+                 nonlinear_activation=None,
+                 nonlinear_activation_params={},
+                 interpolate_mode="nearest",
+                 freq_axis_kernel_size=1,
+                 use_causal_conv=False,
+                 ):
+        """Initialize upsampling network module.
+
+        Args:
+            upsample_scales (list): List of upsampling scales.
+            nonlinear_activation (str): Activation function name.
+            nonlinear_activation_params (dict): Arguments for specified activation function.
+            interpolate_mode (str): Interpolation mode.
+            freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
+
+        """
+        super(UpsampleNetwork, self).__init__()
+        self.use_causal_conv = use_causal_conv
+        self.up_layers = torch.nn.ModuleList()
+        for scale in upsample_scales:
+            # interpolation layer
+            stretch = Stretch2d(scale, 1, interpolate_mode)
+            self.up_layers += [stretch]
+
+            # conv layer
+            assert (freq_axis_kernel_size - 1) % 2 == 0, "Even freq axis kernel sizes are not supported."
+            freq_axis_padding = (freq_axis_kernel_size - 1) // 2
+            kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
+            if use_causal_conv:
+                padding = (freq_axis_padding, scale * 2)
+            else:
+                padding = (freq_axis_padding, scale)
+            conv = Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
+            self.up_layers += [conv]
+
+            # nonlinear
+            if nonlinear_activation is not None:
+                nonlinear = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)
+                self.up_layers += [nonlinear]
+
+    def forward(self, c):
+        """Calculate forward propagation.
+
+        Args:
+            c : Input tensor (B, C, T).
+
+        Returns:
+            Tensor: Upsampled tensor (B, C, T'), where T' = T * prod(upsample_scales).
+
+        """
+        c = c.unsqueeze(1)  # (B, 1, C, T)
+        for f in self.up_layers:
+            if self.use_causal_conv and isinstance(f, Conv2d):
+                c = f(c)[..., :c.size(-1)]
+            else:
+                c = f(c)
+        return c.squeeze(1)  # (B, C, T')
+
+
+class ConvInUpsampleNetwork(torch.nn.Module):
+    """Convolution + upsampling network module."""
+
+    def __init__(self,
+                 upsample_scales,
+                 nonlinear_activation=None,
+                 nonlinear_activation_params={},
+                 interpolate_mode="nearest",
+                 freq_axis_kernel_size=1,
+                 aux_channels=80,
+                 aux_context_window=0,
+                 use_causal_conv=False
+                 ):
+        """Initialize convolution + upsampling network module.
+
+        Args:
+            upsample_scales (list): List of upsampling scales.
+            nonlinear_activation (str): Activation function name.
+            nonlinear_activation_params (dict): Arguments for specified activation function.
+            interpolate_mode (str): Interpolation mode.
+            freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
+            aux_channels (int): Number of channels of pre-convolutional layer.
+            aux_context_window (int): Context window size of the pre-convolutional layer.
+            use_causal_conv (bool): Whether to use causal structure.
+
+        """
+        super(ConvInUpsampleNetwork, self).__init__()
+        self.aux_context_window = aux_context_window
+        self.use_causal_conv = use_causal_conv and aux_context_window > 0
+        # To capture wide-context information in conditional features
+        kernel_size = aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1
+        # NOTE(kan-bayashi): Here do not use padding because the input is already padded
+        self.conv_in = Conv1d(aux_channels, aux_channels, kernel_size=kernel_size, bias=False)
+        self.upsample = UpsampleNetwork(
+            upsample_scales=upsample_scales,
+            nonlinear_activation=nonlinear_activation,
+            nonlinear_activation_params=nonlinear_activation_params,
+            interpolate_mode=interpolate_mode,
+            freq_axis_kernel_size=freq_axis_kernel_size,
+            use_causal_conv=use_causal_conv,
+        )
+
+    def forward(self, c):
+        """Calculate forward propagation.
+
+        Args:
+            c : Input tensor (B, C, T').
+
+        Returns:
+            Tensor: Upsampled tensor (B, C, T),
+                where T = (T' - aux_context_window * 2) * prod(upsample_scales).
+
+        Note:
+            The length of inputs considers the context window size.
+
+        """
+        c_ = self.conv_in(c)
+        c = c_[:, :, :-self.aux_context_window] if self.use_causal_conv else c_
+        return self.upsample(c)
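+
+
+if __name__ == '__main__':
+    # Minimal sanity-check sketch (illustrative only; run as
+    # `python -m modules.parallel_wavegan.layers.upsample` so the relative import
+    # resolves). The scales and context window below are assumed example values.
+    _net = ConvInUpsampleNetwork(upsample_scales=[4, 4, 4, 4],
+                                 aux_channels=80, aux_context_window=2)
+    _c = torch.randn(1, 80, 20)  # (B, aux_channels, T' = T + 2 * aux_context_window)
+    print(_net(_c).shape)        # expected: torch.Size([1, 80, 4096]) = (20 - 4) * 256 frames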
diff --git a/modules/parallel_wavegan/losses/__init__.py b/modules/parallel_wavegan/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b03080a907cb5cb4b316ceb74866ddbc406b33bf
--- /dev/null
+++ b/modules/parallel_wavegan/losses/__init__.py
@@ -0,0 +1 @@
+from .stft_loss import *  # NOQA
diff --git a/modules/parallel_wavegan/losses/stft_loss.py b/modules/parallel_wavegan/losses/stft_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d2aa21ad30ba094c406366e652067462f49cd2
--- /dev/null
+++ b/modules/parallel_wavegan/losses/stft_loss.py
@@ -0,0 +1,153 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Tomoki Hayashi
+#  MIT License (https://opensource.org/licenses/MIT)
+
+"""STFT-based Loss modules."""
+
+import torch
+import torch.nn.functional as F
+
+
+def stft(x, fft_size, hop_size, win_length, window):
+    """Perform STFT and convert to magnitude spectrogram.
+
+    Args:
+        x (Tensor): Input signal tensor (B, T).
+        fft_size (int): FFT size.
+        hop_size (int): Hop size.
+        win_length (int): Window length.
+        window (str): Window function type.
+
+    Returns:
+        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
+
+    """
+    x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
+    real = x_stft[..., 0]
+    imag = x_stft[..., 1]
+
+    # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
+    return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
+
+
+class SpectralConvergengeLoss(torch.nn.Module):
+    """Spectral convergence loss module."""
+
+    def __init__(self):
+        """Initilize spectral convergence loss module."""
+        super(SpectralConvergengeLoss, self).__init__()
+
+    def forward(self, x_mag, y_mag):
+        """Calculate forward propagation.
+
+        Args:
+            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+
+        Returns:
+            Tensor: Spectral convergence loss value.
+
+        """
+        return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
+
+
+class LogSTFTMagnitudeLoss(torch.nn.Module):
+    """Log STFT magnitude loss module."""
+
+    def __init__(self):
+        """Initilize los STFT magnitude loss module."""
+        super(LogSTFTMagnitudeLoss, self).__init__()
+
+    def forward(self, x_mag, y_mag):
+        """Calculate forward propagation.
+
+        Args:
+            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+
+        Returns:
+            Tensor: Log STFT magnitude loss value.
+
+        """
+        return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
+
+
+class STFTLoss(torch.nn.Module):
+    """STFT loss module."""
+
+    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
+        """Initialize STFT loss module."""
+        super(STFTLoss, self).__init__()
+        self.fft_size = fft_size
+        self.shift_size = shift_size
+        self.win_length = win_length
+        self.window = getattr(torch, window)(win_length)
+        self.spectral_convergenge_loss = SpectralConvergengeLoss()
+        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
+
+    def forward(self, x, y):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Predicted signal (B, T).
+            y (Tensor): Groundtruth signal (B, T).
+
+        Returns:
+            Tensor: Spectral convergence loss value.
+            Tensor: Log STFT magnitude loss value.
+
+        """
+        x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
+        y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
+        sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
+        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
+
+        return sc_loss, mag_loss
+
+
+class MultiResolutionSTFTLoss(torch.nn.Module):
+    """Multi resolution STFT loss module."""
+
+    def __init__(self,
+                 fft_sizes=[1024, 2048, 512],
+                 hop_sizes=[120, 240, 50],
+                 win_lengths=[600, 1200, 240],
+                 window="hann_window"):
+        """Initialize Multi resolution STFT loss module.
+
+        Args:
+            fft_sizes (list): List of FFT sizes.
+            hop_sizes (list): List of hop sizes.
+            win_lengths (list): List of window lengths.
+            window (str): Window function type.
+
+        """
+        super(MultiResolutionSTFTLoss, self).__init__()
+        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
+        self.stft_losses = torch.nn.ModuleList()
+        for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
+            self.stft_losses += [STFTLoss(fs, ss, wl, window)]
+
+    def forward(self, x, y):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Predicted signal (B, T).
+            y (Tensor): Groundtruth signal (B, T).
+
+        Returns:
+            Tensor: Multi resolution spectral convergence loss value.
+            Tensor: Multi resolution log STFT magnitude loss value.
+
+        """
+        sc_loss = 0.0
+        mag_loss = 0.0
+        for f in self.stft_losses:
+            sc_l, mag_l = f(x, y)
+            sc_loss += sc_l
+            mag_loss += mag_l
+        sc_loss /= len(self.stft_losses)
+        mag_loss /= len(self.stft_losses)
+
+        return sc_loss, mag_loss
diff --git a/modules/parallel_wavegan/models/__init__.py b/modules/parallel_wavegan/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4803ba6b2a0afc8022e756ae5b3f4c7403c3c1bd
--- /dev/null
+++ b/modules/parallel_wavegan/models/__init__.py
@@ -0,0 +1,2 @@
+from .melgan import *  # NOQA
+from .parallel_wavegan import *  # NOQA
diff --git a/modules/parallel_wavegan/models/melgan.py b/modules/parallel_wavegan/models/melgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..e021ae4817a8c1c97338e61b00b230c881836fd8
--- /dev/null
+++ b/modules/parallel_wavegan/models/melgan.py
@@ -0,0 +1,427 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Tomoki Hayashi
+#  MIT License (https://opensource.org/licenses/MIT)
+
+"""MelGAN Modules."""
+
+import logging
+
+import numpy as np
+import torch
+
+from modules.parallel_wavegan.layers import CausalConv1d
+from modules.parallel_wavegan.layers import CausalConvTranspose1d
+from modules.parallel_wavegan.layers import ResidualStack
+
+
+class MelGANGenerator(torch.nn.Module):
+    """MelGAN generator module."""
+
+    def __init__(self,
+                 in_channels=80,
+                 out_channels=1,
+                 kernel_size=7,
+                 channels=512,
+                 bias=True,
+                 upsample_scales=[8, 8, 2, 2],
+                 stack_kernel_size=3,
+                 stacks=3,
+                 nonlinear_activation="LeakyReLU",
+                 nonlinear_activation_params={"negative_slope": 0.2},
+                 pad="ReflectionPad1d",
+                 pad_params={},
+                 use_final_nonlinear_activation=True,
+                 use_weight_norm=True,
+                 use_causal_conv=False,
+                 ):
+        """Initialize MelGANGenerator module.
+
+        Args:
+            in_channels (int): Number of input channels.
+            out_channels (int): Number of output channels.
+            kernel_size (int): Kernel size of initial and final conv layer.
+            channels (int): Initial number of channels for conv layer.
+            bias (bool): Whether to add bias parameter in convolution layers.
+            upsample_scales (list): List of upsampling scales.
+            stack_kernel_size (int): Kernel size of dilated conv layers in residual stack.
+            stacks (int): Number of stacks in a single residual stack.
+            nonlinear_activation (str): Activation function module name.
+            nonlinear_activation_params (dict): Hyperparameters for activation function.
+            pad (str): Padding function module name before dilated convolution layer.
+            pad_params (dict): Hyperparameters for padding function.
+            use_final_nonlinear_activation (bool): Whether to apply a nonlinear activation (Tanh) at the final layer.
+            use_weight_norm (bool): Whether to use weight norm.
+                If set to true, it will be applied to all of the conv layers.
+            use_causal_conv (bool): Whether to use causal convolution.
+
+        """
+        super(MelGANGenerator, self).__init__()
+
+        # check hyper parameters are valid
+        assert channels >= np.prod(upsample_scales)
+        assert channels % (2 ** len(upsample_scales)) == 0
+        if not use_causal_conv:
+            assert (kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."
+
+        # add initial layer
+        layers = []
+        if not use_causal_conv:
+            layers += [
+                getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params),
+                torch.nn.Conv1d(in_channels, channels, kernel_size, bias=bias),
+            ]
+        else:
+            layers += [
+                CausalConv1d(in_channels, channels, kernel_size,
+                             bias=bias, pad=pad, pad_params=pad_params),
+            ]
+
+        for i, upsample_scale in enumerate(upsample_scales):
+            # add upsampling layer
+            layers += [getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)]
+            if not use_causal_conv:
+                layers += [
+                    torch.nn.ConvTranspose1d(
+                        channels // (2 ** i),
+                        channels // (2 ** (i + 1)),
+                        upsample_scale * 2,
+                        stride=upsample_scale,
+                        padding=upsample_scale // 2 + upsample_scale % 2,
+                        output_padding=upsample_scale % 2,
+                        bias=bias,
+                    )
+                ]
+            else:
+                layers += [
+                    CausalConvTranspose1d(
+                        channels // (2 ** i),
+                        channels // (2 ** (i + 1)),
+                        upsample_scale * 2,
+                        stride=upsample_scale,
+                        bias=bias,
+                    )
+                ]
+
+            # add residual stack
+            for j in range(stacks):
+                layers += [
+                    ResidualStack(
+                        kernel_size=stack_kernel_size,
+                        channels=channels // (2 ** (i + 1)),
+                        dilation=stack_kernel_size ** j,
+                        bias=bias,
+                        nonlinear_activation=nonlinear_activation,
+                        nonlinear_activation_params=nonlinear_activation_params,
+                        pad=pad,
+                        pad_params=pad_params,
+                        use_causal_conv=use_causal_conv,
+                    )
+                ]
+
+        # add final layer
+        layers += [getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)]
+        if not use_causal_conv:
+            layers += [
+                getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params),
+                torch.nn.Conv1d(channels // (2 ** (i + 1)), out_channels, kernel_size, bias=bias),
+            ]
+        else:
+            layers += [
+                CausalConv1d(channels // (2 ** (i + 1)), out_channels, kernel_size,
+                             bias=bias, pad=pad, pad_params=pad_params),
+            ]
+        if use_final_nonlinear_activation:
+            layers += [torch.nn.Tanh()]
+
+        # define the model as a single function
+        self.melgan = torch.nn.Sequential(*layers)
+
+        # apply weight norm
+        if use_weight_norm:
+            self.apply_weight_norm()
+
+        # reset parameters
+        self.reset_parameters()
+
+    def forward(self, c):
+        """Calculate forward propagation.
+
+        Args:
+            c (Tensor): Input tensor (B, channels, T).
+
+        Returns:
+            Tensor: Output tensor (B, 1, T * prod(upsample_scales)).
+
+        """
+        return self.melgan(c)
+
+    def remove_weight_norm(self):
+        """Remove weight normalization module from all of the layers."""
+        def _remove_weight_norm(m):
+            try:
+                logging.debug(f"Weight norm is removed from {m}.")
+                torch.nn.utils.remove_weight_norm(m)
+            except ValueError:  # this module didn't have weight norm
+                return
+
+        self.apply(_remove_weight_norm)
+
+    def apply_weight_norm(self):
+        """Apply weight normalization module from all of the layers."""
+        def _apply_weight_norm(m):
+            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d):
+                torch.nn.utils.weight_norm(m)
+                logging.debug(f"Weight norm is applied to {m}.")
+
+        self.apply(_apply_weight_norm)
+
+    def reset_parameters(self):
+        """Reset parameters.
+
+        This initialization follows official implementation manner.
+        https://github.com/descriptinc/melgan-neurips/blob/master/spec2wav/modules.py
+
+        """
+        def _reset_parameters(m):
+            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d):
+                m.weight.data.normal_(0.0, 0.02)
+                logging.debug(f"Reset parameters in {m}.")
+
+        self.apply(_reset_parameters)
+
+
+class MelGANDiscriminator(torch.nn.Module):
+    """MelGAN discriminator module."""
+
+    def __init__(self,
+                 in_channels=1,
+                 out_channels=1,
+                 kernel_sizes=[5, 3],
+                 channels=16,
+                 max_downsample_channels=1024,
+                 bias=True,
+                 downsample_scales=[4, 4, 4, 4],
+                 nonlinear_activation="LeakyReLU",
+                 nonlinear_activation_params={"negative_slope": 0.2},
+                 pad="ReflectionPad1d",
+                 pad_params={},
+                 ):
+        """Initilize MelGAN discriminator module.
+
+        Args:
+            in_channels (int): Number of input channels.
+            out_channels (int): Number of output channels.
+            kernel_sizes (list): List of two kernel sizes. The prod will be used for the first conv layer,
+                and the first and the second kernel sizes will be used for the last two layers.
+                For example if kernel_sizes = [5, 3], the first layer kernel size will be 5 * 3 = 15,
+                the last two layers' kernel size will be 5 and 3, respectively.
+            channels (int): Initial number of channels for conv layer.
+            max_downsample_channels (int): Maximum number of channels for downsampling layers.
+            bias (bool): Whether to add bias parameter in convolution layers.
+            downsample_scales (list): List of downsampling scales.
+            nonlinear_activation (str): Activation function module name.
+            nonlinear_activation_params (dict): Hyperparameters for activation function.
+            pad (str): Padding function module name before dilated convolution layer.
+            pad_params (dict): Hyperparameters for padding function.
+
+        """
+        super(MelGANDiscriminator, self).__init__()
+        self.layers = torch.nn.ModuleList()
+
+        # check kernel size is valid
+        assert len(kernel_sizes) == 2
+        assert kernel_sizes[0] % 2 == 1
+        assert kernel_sizes[1] % 2 == 1
+
+        # add first layer
+        self.layers += [
+            torch.nn.Sequential(
+                getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params),
+                torch.nn.Conv1d(in_channels, channels, np.prod(kernel_sizes), bias=bias),
+                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+            )
+        ]
+
+        # add downsample layers
+        in_chs = channels
+        for downsample_scale in downsample_scales:
+            out_chs = min(in_chs * downsample_scale, max_downsample_channels)
+            self.layers += [
+                torch.nn.Sequential(
+                    torch.nn.Conv1d(
+                        in_chs, out_chs,
+                        kernel_size=downsample_scale * 10 + 1,
+                        stride=downsample_scale,
+                        padding=downsample_scale * 5,
+                        groups=in_chs // 4,
+                        bias=bias,
+                    ),
+                    getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+                )
+            ]
+            in_chs = out_chs
+
+        # add final layers
+        out_chs = min(in_chs * 2, max_downsample_channels)
+        self.layers += [
+            torch.nn.Sequential(
+                torch.nn.Conv1d(
+                    in_chs, out_chs, kernel_sizes[0],
+                    padding=(kernel_sizes[0] - 1) // 2,
+                    bias=bias,
+                ),
+                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+            )
+        ]
+        self.layers += [
+            torch.nn.Conv1d(
+                out_chs, out_channels, kernel_sizes[1],
+                padding=(kernel_sizes[1] - 1) // 2,
+                bias=bias,
+            ),
+        ]
+
+    def forward(self, x):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Input noise signal (B, 1, T).
+
+        Returns:
+            List: List of output tensors of each layer.
+
+        """
+        outs = []
+        for f in self.layers:
+            x = f(x)
+            outs += [x]
+
+        return outs
+
+
+class MelGANMultiScaleDiscriminator(torch.nn.Module):
+    """MelGAN multi-scale discriminator module."""
+
+    def __init__(self,
+                 in_channels=1,
+                 out_channels=1,
+                 scales=3,
+                 downsample_pooling="AvgPool1d",
+                 # follow the official implementation setting
+                 downsample_pooling_params={
+                     "kernel_size": 4,
+                     "stride": 2,
+                     "padding": 1,
+                     "count_include_pad": False,
+                 },
+                 kernel_sizes=[5, 3],
+                 channels=16,
+                 max_downsample_channels=1024,
+                 bias=True,
+                 downsample_scales=[4, 4, 4, 4],
+                 nonlinear_activation="LeakyReLU",
+                 nonlinear_activation_params={"negative_slope": 0.2},
+                 pad="ReflectionPad1d",
+                 pad_params={},
+                 use_weight_norm=True,
+                 ):
+        """Initilize MelGAN multi-scale discriminator module.
+
+        Args:
+            in_channels (int): Number of input channels.
+            out_channels (int): Number of output channels.
+            scales (int): Number of multi-scale discriminators.
+            downsample_pooling (str): Pooling module name for downsampling of the inputs.
+            downsample_pooling_params (dict): Parameters for the above pooling module.
+            kernel_sizes (list): List of two kernel sizes. The prod will be used for the first conv layer,
+                and the first and the second kernel sizes will be used for the last two layers.
+            channels (int): Initial number of channels for conv layer.
+            max_downsample_channels (int): Maximum number of channels for downsampling layers.
+            bias (bool): Whether to add bias parameter in convolution layers.
+            downsample_scales (list): List of downsampling scales.
+            nonlinear_activation (str): Activation function module name.
+            nonlinear_activation_params (dict): Hyperparameters for activation function.
+            pad (str): Padding function module name before dilated convolution layer.
+            pad_params (dict): Hyperparameters for padding function.
+            use_weight_norm (bool): Whether to use weight norm.
+                If set to true, it will be applied to all of the conv layers.
+
+        """
+        super(MelGANMultiScaleDiscriminator, self).__init__()
+        self.discriminators = torch.nn.ModuleList()
+
+        # add discriminators
+        for _ in range(scales):
+            self.discriminators += [
+                MelGANDiscriminator(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    kernel_sizes=kernel_sizes,
+                    channels=channels,
+                    max_downsample_channels=max_downsample_channels,
+                    bias=bias,
+                    downsample_scales=downsample_scales,
+                    nonlinear_activation=nonlinear_activation,
+                    nonlinear_activation_params=nonlinear_activation_params,
+                    pad=pad,
+                    pad_params=pad_params,
+                )
+            ]
+        self.pooling = getattr(torch.nn, downsample_pooling)(**downsample_pooling_params)
+
+        # apply weight norm
+        if use_weight_norm:
+            self.apply_weight_norm()
+
+        # reset parameters
+        self.reset_parameters()
+
+    def forward(self, x):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Input noise signal (B, 1, T).
+
+        Returns:
+            List: List of list of each discriminator outputs, which consists of each layer output tensors.
+
+        """
+        outs = []
+        for f in self.discriminators:
+            outs += [f(x)]
+            x = self.pooling(x)
+
+        return outs
+
+    def remove_weight_norm(self):
+        """Remove weight normalization module from all of the layers."""
+        def _remove_weight_norm(m):
+            try:
+                logging.debug(f"Weight norm is removed from {m}.")
+                torch.nn.utils.remove_weight_norm(m)
+            except ValueError:  # this module didn't have weight norm
+                return
+
+        self.apply(_remove_weight_norm)
+
+    def apply_weight_norm(self):
+        """Apply weight normalization module from all of the layers."""
+        def _apply_weight_norm(m):
+            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d):
+                torch.nn.utils.weight_norm(m)
+                logging.debug(f"Weight norm is applied to {m}.")
+
+        self.apply(_apply_weight_norm)
+
+    def reset_parameters(self):
+        """Reset parameters.
+
+        This initialization follows official implementation manner.
+        https://github.com/descriptinc/melgan-neurips/blob/master/spec2wav/modules.py
+
+        """
+        def _reset_parameters(m):
+            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d):
+                m.weight.data.normal_(0.0, 0.02)
+                logging.debug(f"Reset parameters in {m}.")
+
+        self.apply(_reset_parameters)
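+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative only; shapes assume the default
+    # hyper-parameters): with upsample_scales [8, 8, 2, 2] the generator
+    # upsamples the 80-dim mel input by 8 * 8 * 2 * 2 = 256 samples per frame.
+    generator = MelGANGenerator()
+    discriminator = MelGANMultiScaleDiscriminator()
+    mel = torch.randn(2, 80, 40)  # dummy mel spectrogram (B, mel_bins, frames)
+    wav = generator(mel)          # -> (2, 1, 40 * 256)
+    outs = discriminator(wav)     # 3 scales, each a list of per-layer features
+    print(wav.shape, len(outs), len(outs[0]))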
diff --git a/modules/parallel_wavegan/models/parallel_wavegan.py b/modules/parallel_wavegan/models/parallel_wavegan.py
new file mode 100644
index 0000000000000000000000000000000000000000..c63b59f67aa48342179415c1d1beac68574a5498
--- /dev/null
+++ b/modules/parallel_wavegan/models/parallel_wavegan.py
@@ -0,0 +1,434 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Tomoki Hayashi
+#  MIT License (https://opensource.org/licenses/MIT)
+
+"""Parallel WaveGAN Modules."""
+
+import logging
+import math
+
+import torch
+from torch import nn
+
+from modules.parallel_wavegan.layers import Conv1d
+from modules.parallel_wavegan.layers import Conv1d1x1
+from modules.parallel_wavegan.layers import ResidualBlock
+from modules.parallel_wavegan.layers import upsample
+from modules.parallel_wavegan import models
+
+
+class ParallelWaveGANGenerator(torch.nn.Module):
+    """Parallel WaveGAN Generator module."""
+
+    def __init__(self,
+                 in_channels=1,
+                 out_channels=1,
+                 kernel_size=3,
+                 layers=30,
+                 stacks=3,
+                 residual_channels=64,
+                 gate_channels=128,
+                 skip_channels=64,
+                 aux_channels=80,
+                 aux_context_window=2,
+                 dropout=0.0,
+                 bias=True,
+                 use_weight_norm=True,
+                 use_causal_conv=False,
+                 upsample_conditional_features=True,
+                 upsample_net="ConvInUpsampleNetwork",
+                 upsample_params={"upsample_scales": [4, 4, 4, 4]},
+                 use_pitch_embed=False,
+                 ):
+        """Initialize Parallel WaveGAN Generator module.
+
+        Args:
+            in_channels (int): Number of input channels.
+            out_channels (int): Number of output channels.
+            kernel_size (int): Kernel size of dilated convolution.
+            layers (int): Number of residual block layers.
+            stacks (int): Number of stacks i.e., dilation cycles.
+            residual_channels (int): Number of channels in residual conv.
+            gate_channels (int):  Number of channels in gated conv.
+            skip_channels (int): Number of channels in skip conv.
+            aux_channels (int): Number of channels for auxiliary feature conv.
+            aux_context_window (int): Context window size for auxiliary feature.
+            dropout (float): Dropout rate. 0.0 means no dropout applied.
+            bias (bool): Whether to use bias parameter in conv layer.
+            use_weight_norm (bool): Whether to use weight norm.
+                If set to true, it will be applied to all of the conv layers.
+            use_causal_conv (bool): Whether to use causal structure.
+            upsample_conditional_features (bool): Whether to use upsampling network.
+            upsample_net (str): Upsampling network architecture.
+            upsample_params (dict): Upsampling network parameters.
+
+        """
+        super(ParallelWaveGANGenerator, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.aux_channels = aux_channels
+        self.layers = layers
+        self.stacks = stacks
+        self.kernel_size = kernel_size
+
+        # check the number of layers and stacks
+        assert layers % stacks == 0
+        layers_per_stack = layers // stacks
+
+        # define first convolution
+        self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True)
+
+        # define conv + upsampling network
+        if upsample_conditional_features:
+            upsample_params.update({
+                "use_causal_conv": use_causal_conv,
+            })
+            if upsample_net == "MelGANGenerator":
+                assert aux_context_window == 0
+                upsample_params.update({
+                    "use_weight_norm": False,  # not to apply twice
+                    "use_final_nonlinear_activation": False,
+                })
+                self.upsample_net = getattr(models, upsample_net)(**upsample_params)
+            else:
+                if upsample_net == "ConvInUpsampleNetwork":
+                    upsample_params.update({
+                        "aux_channels": aux_channels,
+                        "aux_context_window": aux_context_window,
+                    })
+                self.upsample_net = getattr(upsample, upsample_net)(**upsample_params)
+        else:
+            self.upsample_net = None
+
+        # define residual blocks
+        self.conv_layers = torch.nn.ModuleList()
+        for layer in range(layers):
+            dilation = 2 ** (layer % layers_per_stack)
+            conv = ResidualBlock(
+                kernel_size=kernel_size,
+                residual_channels=residual_channels,
+                gate_channels=gate_channels,
+                skip_channels=skip_channels,
+                aux_channels=aux_channels,
+                dilation=dilation,
+                dropout=dropout,
+                bias=bias,
+                use_causal_conv=use_causal_conv,
+            )
+            self.conv_layers += [conv]
+
+        # define output layers
+        self.last_conv_layers = torch.nn.ModuleList([
+            torch.nn.ReLU(inplace=True),
+            Conv1d1x1(skip_channels, skip_channels, bias=True),
+            torch.nn.ReLU(inplace=True),
+            Conv1d1x1(skip_channels, out_channels, bias=True),
+        ])
+
+        self.use_pitch_embed = use_pitch_embed
+        if use_pitch_embed:
+            self.pitch_embed = nn.Embedding(300, aux_channels, 0)
+            self.c_proj = nn.Linear(2 * aux_channels, aux_channels)
+
+        # apply weight norm
+        if use_weight_norm:
+            self.apply_weight_norm()
+
+    def forward(self, x, c=None, pitch=None, **kwargs):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Input noise signal (B, C_in, T).
+            c (Tensor): Local conditioning auxiliary features (B, C ,T').
+            pitch (Tensor): Local conditioning pitch (B, T').
+
+        Returns:
+            Tensor: Output tensor (B, C_out, T)
+
+        """
+        # perform upsampling
+        if c is not None and self.upsample_net is not None:
+            if self.use_pitch_embed:
+                p = self.pitch_embed(pitch)
+                c = self.c_proj(torch.cat([c.transpose(1, 2), p], -1)).transpose(1, 2)
+            c = self.upsample_net(c)
+            assert c.size(-1) == x.size(-1), (c.size(-1), x.size(-1))
+
+        # encode to hidden representation
+        x = self.first_conv(x)
+        skips = 0
+        for f in self.conv_layers:
+            x, h = f(x, c)
+            skips += h
+        skips *= math.sqrt(1.0 / len(self.conv_layers))
+
+        # apply final layers
+        x = skips
+        for f in self.last_conv_layers:
+            x = f(x)
+
+        return x
+
+    def remove_weight_norm(self):
+        """Remove weight normalization module from all of the layers."""
+        def _remove_weight_norm(m):
+            try:
+                logging.debug(f"Weight norm is removed from {m}.")
+                torch.nn.utils.remove_weight_norm(m)
+            except ValueError:  # this module didn't have weight norm
+                return
+
+        self.apply(_remove_weight_norm)
+
+    def apply_weight_norm(self):
+        """Apply weight normalization module from all of the layers."""
+        def _apply_weight_norm(m):
+            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
+                torch.nn.utils.weight_norm(m)
+                logging.debug(f"Weight norm is applied to {m}.")
+
+        self.apply(_apply_weight_norm)
+
+    @staticmethod
+    def _get_receptive_field_size(layers, stacks, kernel_size,
+                                  dilation=lambda x: 2 ** x):
+        assert layers % stacks == 0
+        layers_per_cycle = layers // stacks
+        dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
+        return (kernel_size - 1) * sum(dilations) + 1
+
+    @property
+    def receptive_field_size(self):
+        """Return receptive field size."""
+        return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
+
+
+class ParallelWaveGANDiscriminator(torch.nn.Module):
+    """Parallel WaveGAN Discriminator module."""
+
+    def __init__(self,
+                 in_channels=1,
+                 out_channels=1,
+                 kernel_size=3,
+                 layers=10,
+                 conv_channels=64,
+                 dilation_factor=1,
+                 nonlinear_activation="LeakyReLU",
+                 nonlinear_activation_params={"negative_slope": 0.2},
+                 bias=True,
+                 use_weight_norm=True,
+                 ):
+        """Initialize Parallel WaveGAN Discriminator module.
+
+        Args:
+            in_channels (int): Number of input channels.
+            out_channels (int): Number of output channels.
+            kernel_size (int): Kernel size of conv layers.
+            layers (int): Number of conv layers.
+            conv_channels (int): Number of channels in conv layers.
+            dilation_factor (int): Dilation factor. For example, if dilation_factor = 2,
+                the dilation will be 2, 4, 8, ..., and so on.
+            nonlinear_activation (str): Nonlinear function after each conv.
+            nonlinear_activation_params (dict): Nonlinear function parameters.
+            bias (bool): Whether to use bias parameter in conv.
+            use_weight_norm (bool): Whether to use weight norm.
+                If set to true, it will be applied to all of the conv layers.
+
+        """
+        super(ParallelWaveGANDiscriminator, self).__init__()
+        assert (kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."
+        assert dilation_factor > 0, "Dilation factor must be > 0."
+        self.conv_layers = torch.nn.ModuleList()
+        conv_in_channels = in_channels
+        for i in range(layers - 1):
+            if i == 0:
+                dilation = 1
+            else:
+                dilation = i if dilation_factor == 1 else dilation_factor ** i
+                conv_in_channels = conv_channels
+            padding = (kernel_size - 1) // 2 * dilation
+            conv_layer = [
+                Conv1d(conv_in_channels, conv_channels,
+                       kernel_size=kernel_size, padding=padding,
+                       dilation=dilation, bias=bias),
+                getattr(torch.nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params)
+            ]
+            self.conv_layers += conv_layer
+        padding = (kernel_size - 1) // 2
+        last_conv_layer = Conv1d(
+            conv_in_channels, out_channels,
+            kernel_size=kernel_size, padding=padding, bias=bias)
+        self.conv_layers += [last_conv_layer]
+
+        # apply weight norm
+        if use_weight_norm:
+            self.apply_weight_norm()
+
+    def forward(self, x):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Input noise signal (B, 1, T).
+
+        Returns:
+            Tensor: Output tensor (B, 1, T)
+
+        """
+        for f in self.conv_layers:
+            x = f(x)
+        return x
+
+    def apply_weight_norm(self):
+        """Apply weight normalization module from all of the layers."""
+        def _apply_weight_norm(m):
+            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
+                torch.nn.utils.weight_norm(m)
+                logging.debug(f"Weight norm is applied to {m}.")
+
+        self.apply(_apply_weight_norm)
+
+    def remove_weight_norm(self):
+        """Remove weight normalization module from all of the layers."""
+        def _remove_weight_norm(m):
+            try:
+                logging.debug(f"Weight norm is removed from {m}.")
+                torch.nn.utils.remove_weight_norm(m)
+            except ValueError:  # this module didn't have weight norm
+                return
+
+        self.apply(_remove_weight_norm)
+
+
+class ResidualParallelWaveGANDiscriminator(torch.nn.Module):
+    """Parallel WaveGAN Discriminator module."""
+
+    def __init__(self,
+                 in_channels=1,
+                 out_channels=1,
+                 kernel_size=3,
+                 layers=30,
+                 stacks=3,
+                 residual_channels=64,
+                 gate_channels=128,
+                 skip_channels=64,
+                 dropout=0.0,
+                 bias=True,
+                 use_weight_norm=True,
+                 use_causal_conv=False,
+                 nonlinear_activation="LeakyReLU",
+                 nonlinear_activation_params={"negative_slope": 0.2},
+                 ):
+        """Initialize Parallel WaveGAN Discriminator module.
+
+        Args:
+            in_channels (int): Number of input channels.
+            out_channels (int): Number of output channels.
+            kernel_size (int): Kernel size of dilated convolution.
+            layers (int): Number of residual block layers.
+            stacks (int): Number of stacks i.e., dilation cycles.
+            residual_channels (int): Number of channels in residual conv.
+            gate_channels (int):  Number of channels in gated conv.
+            skip_channels (int): Number of channels in skip conv.
+            dropout (float): Dropout rate. 0.0 means no dropout applied.
+            bias (bool): Whether to use bias parameter in conv.
+            use_weight_norm (bool): Whether to use weight norm.
+                If set to true, it will be applied to all of the conv layers.
+            use_causal_conv (bool): Whether to use causal structure.
+            nonlinear_activation (str): Nonlinear activation function module name.
+            nonlinear_activation_params (dict): Nonlinear activation function parameters.
+
+        """
+        super(ResidualParallelWaveGANDiscriminator, self).__init__()
+        assert (kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.layers = layers
+        self.stacks = stacks
+        self.kernel_size = kernel_size
+
+        # check the number of layers and stacks
+        assert layers % stacks == 0
+        layers_per_stack = layers // stacks
+
+        # define first convolution
+        self.first_conv = torch.nn.Sequential(
+            Conv1d1x1(in_channels, residual_channels, bias=True),
+            getattr(torch.nn, nonlinear_activation)(
+                inplace=True, **nonlinear_activation_params),
+        )
+
+        # define residual blocks
+        self.conv_layers = torch.nn.ModuleList()
+        for layer in range(layers):
+            dilation = 2 ** (layer % layers_per_stack)
+            conv = ResidualBlock(
+                kernel_size=kernel_size,
+                residual_channels=residual_channels,
+                gate_channels=gate_channels,
+                skip_channels=skip_channels,
+                aux_channels=-1,
+                dilation=dilation,
+                dropout=dropout,
+                bias=bias,
+                use_causal_conv=use_causal_conv,
+            )
+            self.conv_layers += [conv]
+
+        # define output layers
+        self.last_conv_layers = torch.nn.ModuleList([
+            getattr(torch.nn, nonlinear_activation)(
+                inplace=True, **nonlinear_activation_params),
+            Conv1d1x1(skip_channels, skip_channels, bias=True),
+            getattr(torch.nn, nonlinear_activation)(
+                inplace=True, **nonlinear_activation_params),
+            Conv1d1x1(skip_channels, out_channels, bias=True),
+        ])
+
+        # apply weight norm
+        if use_weight_norm:
+            self.apply_weight_norm()
+
+    def forward(self, x):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Input noise signal (B, 1, T).
+
+        Returns:
+            Tensor: Output tensor (B, 1, T)
+
+        """
+        x = self.first_conv(x)
+
+        skips = 0
+        for f in self.conv_layers:
+            x, h = f(x, None)
+            skips += h
+        skips *= math.sqrt(1.0 / len(self.conv_layers))
+
+        # apply final layers
+        x = skips
+        for f in self.last_conv_layers:
+            x = f(x)
+        return x
+
+    def apply_weight_norm(self):
+        """Apply weight normalization module from all of the layers."""
+        def _apply_weight_norm(m):
+            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
+                torch.nn.utils.weight_norm(m)
+                logging.debug(f"Weight norm is applied to {m}.")
+
+        self.apply(_apply_weight_norm)
+
+    def remove_weight_norm(self):
+        """Remove weight normalization module from all of the layers."""
+        def _remove_weight_norm(m):
+            try:
+                logging.debug(f"Weight norm is removed from {m}.")
+                torch.nn.utils.remove_weight_norm(m)
+            except ValueError:  # this module didn't have weight norm
+                return
+
+        self.apply(_remove_weight_norm)
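+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative only; aux_context_window is set to 0
+    # here so that the upsampled condition matches the noise length exactly,
+    # otherwise the condition would need extra context-window padding).
+    generator = ParallelWaveGANGenerator(aux_context_window=0)
+    discriminator = ParallelWaveGANDiscriminator()
+    c = torch.randn(2, 80, 40)       # dummy mel condition (B, aux_channels, frames)
+    z = torch.randn(2, 1, 40 * 256)  # Gaussian noise; 4 * 4 * 4 * 4 = 256x upsampling
+    wav = generator(z, c)            # -> (2, 1, 10240)
+    score = discriminator(wav)       # -> (2, 1, 10240)
+    print(wav.shape, score.shape, generator.receptive_field_size)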
diff --git a/modules/parallel_wavegan/models/source.py b/modules/parallel_wavegan/models/source.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2a006e53c0e2194036fd08ea9d6ed4d9a10d6cf
--- /dev/null
+++ b/modules/parallel_wavegan/models/source.py
@@ -0,0 +1,538 @@
+import torch
+import numpy as np
+import sys
+import torch.nn.functional as torch_nn_func
+
+
+class SineGen(torch.nn.Module):
+    """ Definition of sine generator
+    SineGen(samp_rate, harmonic_num = 0,
+            sine_amp = 0.1, noise_std = 0.003,
+            voiced_threshold = 0,
+            flag_for_pulse=False)
+
+    samp_rate: sampling rate in Hz
+    harmonic_num: number of harmonic overtones (default 0)
+    sine_amp: amplitude of sine waveform (default 0.1)
+    noise_std: std of Gaussian noise (default 0.003)
+    voiced_threshold: F0 threshold for U/V classification (default 0)
+    flag_for_pulse: this SineGen is used inside PulseGen (default False)
+
+    Note: when flag_for_pulse is True, the first time step of a voiced
+        segment is always sin(np.pi) or cos(0)
+    """
+
+    def __init__(self, samp_rate, harmonic_num=0,
+                 sine_amp=0.1, noise_std=0.003,
+                 voiced_threshold=0,
+                 flag_for_pulse=False):
+        super(SineGen, self).__init__()
+        self.sine_amp = sine_amp
+        self.noise_std = noise_std
+        self.harmonic_num = harmonic_num
+        self.dim = self.harmonic_num + 1
+        self.sampling_rate = samp_rate
+        self.voiced_threshold = voiced_threshold
+        self.flag_for_pulse = flag_for_pulse
+
+    def _f02uv(self, f0):
+        # generate uv signal
+        uv = torch.ones_like(f0)
+        uv = uv * (f0 > self.voiced_threshold)
+        return uv
+
+    def _f02sine(self, f0_values):
+        """ f0_values: (batchsize, length, dim)
+            where dim indicates fundamental tone and overtones
+        """
+        # convert to F0 in rad. The integer part n can be ignored
+        # because 2 * np.pi * n doesn't affect phase
+        rad_values = (f0_values / self.sampling_rate) % 1
+
+        # initial phase noise (no noise for fundamental component)
+        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
+                              device=f0_values.device)
+        rand_ini[:, 0] = 0
+        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
+
+        # instantaneous phase sine[t] = sin(2 * pi * \sum_{i=1}^{t} rad)
+        if not self.flag_for_pulse:
+            # for normal case
+
+            # To prevent torch.cumsum numerical overflow,
+            # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
+            # Buffer tmp_over_one_idx indicates the time step to add -1.
+            # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
+            tmp_over_one = torch.cumsum(rad_values, 1) % 1
+            tmp_over_one_idx = (tmp_over_one[:, 1:, :] -
+                                tmp_over_one[:, :-1, :]) < 0
+            cumsum_shift = torch.zeros_like(rad_values)
+            cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
+
+            sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1)
+                              * 2 * np.pi)
+        else:
+            # If necessary, make sure that the first time step of every
+            # voiced segment is sin(pi) or cos(0)
+            # This is used for pulse-train generation
+
+            # identify the last time step in unvoiced segments
+            uv = self._f02uv(f0_values)
+            uv_1 = torch.roll(uv, shifts=-1, dims=1)
+            uv_1[:, -1, :] = 1
+            u_loc = (uv < 1) * (uv_1 > 0)
+
+            # get the instantaneous phase
+            tmp_cumsum = torch.cumsum(rad_values, dim=1)
+            # different batch needs to be processed differently
+            for idx in range(f0_values.shape[0]):
+                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
+                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
+                # stores the accumulation of i.phase within
+                # each voiced segment
+                tmp_cumsum[idx, :, :] = 0
+                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
+
+            # rad_values - tmp_cumsum: remove the accumulation of i.phase
+            # within the previous voiced segment.
+            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
+
+            # get the sines
+            sines = torch.cos(i_phase * 2 * np.pi)
+        return sines
+
+    def forward(self, f0):
+        """ sine_tensor, uv = forward(f0)
+        input F0: tensor(batchsize=1, length, dim=1)
+                  f0 for unvoiced steps should be 0
+        output sine_tensor: tensor(batchsize=1, length, dim)
+        output uv: tensor(batchsize=1, length, 1)
+        """
+        with torch.no_grad():
+            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
+                                 device=f0.device)
+            # fundamental component
+            f0_buf[:, :, 0] = f0[:, :, 0]
+            for idx in np.arange(self.harmonic_num):
+                # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
+                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
+
+            # generate sine waveforms
+            sine_waves = self._f02sine(f0_buf) * self.sine_amp
+
+            # generate uv signal
+            # uv = torch.ones(f0.shape)
+            # uv = uv * (f0 > self.voiced_threshold)
+            uv = self._f02uv(f0)
+
+            # noise: for unvoiced should be similar to sine_amp
+            #        std = self.sine_amp/3 -> max value ~ self.sine_amp
+            #        for voiced regions, the std is self.noise_std
+            noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+            noise = noise_amp * torch.randn_like(sine_waves)
+
+            # first: set the unvoiced part to 0 by uv
+            # then: additive noise
+            sine_waves = sine_waves * uv + noise
+        return sine_waves, uv, noise
+
+
+class PulseGen(torch.nn.Module):
+    """ Definition of Pulse train generator
+
+    There are many ways to implement a pulse generator.
+    Here, PulseGen is based on SineGen.
+    """
+    def __init__(self, samp_rate, pulse_amp = 0.1,
+                 noise_std = 0.003, voiced_threshold = 0):
+        super(PulseGen, self).__init__()
+        self.pulse_amp = pulse_amp
+        self.sampling_rate = samp_rate
+        self.voiced_threshold = voiced_threshold
+        self.noise_std = noise_std
+        self.l_sinegen = SineGen(self.sampling_rate, harmonic_num=0, \
+                                 sine_amp=self.pulse_amp, noise_std=0, \
+                                 voiced_threshold=self.voiced_threshold, \
+                                 flag_for_pulse=True)
+
+    def forward(self, f0):
+        """ Pulse train generator
+        pulse_train, uv = forward(f0)
+        input F0: tensor(batchsize=1, length, dim=1)
+                  f0 for unvoiced steps should be 0
+        output pulse_train: tensor(batchsize=1, length, dim)
+        output uv: tensor(batchsize=1, length, 1)
+
+        Note: self.l_sinegen doesn't make sure that the initial phase of
+        a voiced segment is np.pi, so the first pulse in a voiced segment
+        may not be at the first time step within a voiced segment
+        """
+        with torch.no_grad():
+            sine_wav, uv, noise = self.l_sinegen(f0)
+
+            # sine without additive noise
+            pure_sine = sine_wav - noise
+
+            # step t corresponds to a pulse if
+            # sine[t] > sine[t+1] & sine[t] > sine[t-1]
+            # & sine[t-1], sine[t+1], and sine[t] are voiced
+            # or
+            # sine[t] is voiced, sine[t-1] is unvoiced
+            # we use torch.roll to simulate sine[t+1] and sine[t-1]
+            sine_1 = torch.roll(pure_sine, shifts=1, dims=1)
+            uv_1 = torch.roll(uv, shifts=1, dims=1)
+            uv_1[:, 0, :] = 0
+            sine_2 = torch.roll(pure_sine, shifts=-1, dims=1)
+            uv_2 = torch.roll(uv, shifts=-1, dims=1)
+            uv_2[:, -1, :] = 0
+
+            loc = (pure_sine > sine_1) * (pure_sine > sine_2) \
+                  * (uv_1 > 0) * (uv_2 > 0) * (uv > 0) \
+                  + (uv_1 < 1) * (uv > 0)
+
+            # pulse train without noise
+            pulse_train = pure_sine * loc
+
+            # additive noise to pulse train
+            # note that noise from sinegen is zero in voiced regions
+            pulse_noise = torch.randn_like(pure_sine) * self.noise_std
+
+            # with additive noise on pulse, and unvoiced regions
+            pulse_train += pulse_noise * loc + pulse_noise * (1 - uv)
+        return pulse_train, sine_wav, uv, pulse_noise
+
+
+class SignalsConv1d(torch.nn.Module):
+    """ Filtering input signal with time invariant filter
+    Note: FIRFilter conducts filtering given a fixed FIR weight;
+          SignalsConv1d convolves two signals
+    Note: this is based on torch.nn.functional.conv1d
+
+    """
+
+    def __init__(self):
+        super(SignalsConv1d, self).__init__()
+
+    def forward(self, signal, system_ir):
+        """ output = forward(signal, system_ir)
+
+        signal:    (batchsize, length1, dim)
+        system_ir: (length2, dim)
+
+        output:    (batchsize, length1, dim)
+        """
+        if signal.shape[-1] != system_ir.shape[-1]:
+            print("Error: SignalsConv1d expects shape:")
+            print("signal    (batchsize, length1, dim)")
+            print("system_id (batchsize, length2, dim)")
+            print("But received signal: {:s}".format(str(signal.shape)))
+            print(" system_ir: {:s}".format(str(system_ir.shape)))
+            sys.exit(1)
+        padding_length = system_ir.shape[0] - 1
+        groups = signal.shape[-1]
+
+        # pad signal on the left
+        signal_pad = torch_nn_func.pad(signal.permute(0, 2, 1), \
+                                       (padding_length, 0))
+        # prepare system impulse response as (dim, 1, length2)
+        # also flip the impulse response
+        ir = torch.flip(system_ir.unsqueeze(1).permute(2, 1, 0), \
+                        dims=[2])
+        # convolve
+        output = torch_nn_func.conv1d(signal_pad, ir, groups=groups)
+        return output.permute(0, 2, 1)
+
+
+class CyclicNoiseGen_v1(torch.nn.Module):
+    """ CyclicnoiseGen_v1
+    Cyclic noise with a single parameter of beta.
+    Pytorch v1 implementation assumes f_t is also fixed
+    """
+
+    def __init__(self, samp_rate,
+                 noise_std=0.003, voiced_threshold=0):
+        super(CyclicNoiseGen_v1, self).__init__()
+        self.samp_rate = samp_rate
+        self.noise_std = noise_std
+        self.voiced_threshold = voiced_threshold
+
+        self.l_pulse = PulseGen(samp_rate, pulse_amp=1.0,
+                                noise_std=noise_std,
+                                voiced_threshold=voiced_threshold)
+        self.l_conv = SignalsConv1d()
+
+    def noise_decay(self, beta, f0mean):
+        """ decayed_noise = noise_decay(beta, f0mean)
+        decayed_noise = n[t] * exp(-t * f_mean / beta / samp_rate)
+
+        beta: (dim=1) or (batchsize=1, 1, dim=1)
+        f0mean (batchsize=1, 1, dim=1)
+
+        decayed_noise (batchsize=1, length, dim=1)
+        """
+        with torch.no_grad():
+            # exp(-1.0 n / T) < 0.01 => n > -log(0.01)*T = 4.60*T
+            # truncate the noise when decayed by -40 dB
+            length = 4.6 * self.samp_rate / f0mean
+            length = length.int()
+            time_idx = torch.arange(0, length, device=beta.device)
+            time_idx = time_idx.unsqueeze(0).unsqueeze(2)
+            time_idx = time_idx.repeat(beta.shape[0], 1, beta.shape[2])
+
+        noise = torch.randn(time_idx.shape, device=beta.device)
+
+        # due to Pytorch implementation, use f0_mean as the f0 factor
+        decay = torch.exp(-time_idx * f0mean / beta / self.samp_rate)
+        return noise * self.noise_std * decay
+
+    def forward(self, f0s, beta):
+        """ Producde cyclic-noise
+        """
+        # pulse train
+        pulse_train, sine_wav, uv, noise = self.l_pulse(f0s)
+        pure_pulse = pulse_train - noise
+
+        # decayed_noise (length, dim=1)
+        if (uv < 1).all():
+            # all unvoiced
+            cyc_noise = torch.zeros_like(sine_wav)
+        else:
+            f0mean = f0s[uv > 0].mean()
+
+            decayed_noise = self.noise_decay(beta, f0mean)[0, :, :]
+            # convolve
+            cyc_noise = self.l_conv(pure_pulse, decayed_noise)
+
+        # add noise in unvoiced segments
+        cyc_noise = cyc_noise + noise * (1.0 - uv)
+        return cyc_noise, pulse_train, sine_wav, uv, noise
+
+
+class SineGen(torch.nn.Module):
+    """ Definition of sine generator
+    SineGen(samp_rate, harmonic_num = 0,
+            sine_amp = 0.1, noise_std = 0.003,
+            voiced_threshold = 0,
+            flag_for_pulse=False)
+
+    samp_rate: sampling rate in Hz
+    harmonic_num: number of harmonic overtones (default 0)
+    sine_amp: amplitude of sine waveform (default 0.1)
+    noise_std: std of Gaussian noise (default 0.003)
+    voiced_threshold: F0 threshold for U/V classification (default 0)
+    flag_for_pulse: this SineGen is used inside PulseGen (default False)
+
+    Note: when flag_for_pulse is True, the first time step of a voiced
+        segment is always sin(np.pi) or cos(0)
+    """
+
+    def __init__(self, samp_rate, harmonic_num=0,
+                 sine_amp=0.1, noise_std=0.003,
+                 voiced_threshold=0,
+                 flag_for_pulse=False):
+        super(SineGen, self).__init__()
+        self.sine_amp = sine_amp
+        self.noise_std = noise_std
+        self.harmonic_num = harmonic_num
+        self.dim = self.harmonic_num + 1
+        self.sampling_rate = samp_rate
+        self.voiced_threshold = voiced_threshold
+        self.flag_for_pulse = flag_for_pulse
+
+    def _f02uv(self, f0):
+        # generate uv signal
+        uv = torch.ones_like(f0)
+        uv = uv * (f0 > self.voiced_threshold)
+        return uv
+
+    def _f02sine(self, f0_values):
+        """ f0_values: (batchsize, length, dim)
+            where dim indicates fundamental tone and overtones
+        """
+        # convert to F0 in rad. The integer part n can be ignored
+        # because 2 * np.pi * n doesn't affect phase
+        rad_values = (f0_values / self.sampling_rate) % 1
+
+        # initial phase noise (no noise for fundamental component)
+        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
+                              device=f0_values.device)
+        rand_ini[:, 0] = 0
+        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
+
+        # instantaneous phase sine[t] = sin(2 * pi * \sum_{i=1}^{t} rad)
+        if not self.flag_for_pulse:
+            # for normal case
+
+            # To prevent torch.cumsum numerical overflow,
+            # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
+            # Buffer tmp_over_one_idx indicates the time step to add -1.
+            # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
+            tmp_over_one = torch.cumsum(rad_values, 1) % 1
+            tmp_over_one_idx = (tmp_over_one[:, 1:, :] -
+                                tmp_over_one[:, :-1, :]) < 0
+            cumsum_shift = torch.zeros_like(rad_values)
+            cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
+
+            sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1)
+                              * 2 * np.pi)
+        else:
+            # If necessary, make sure that the first time step of every
+            # voiced segment is sin(pi) or cos(0)
+            # This is used for pulse-train generation
+
+            # identify the last time step in unvoiced segments
+            uv = self._f02uv(f0_values)
+            uv_1 = torch.roll(uv, shifts=-1, dims=1)
+            uv_1[:, -1, :] = 1
+            u_loc = (uv < 1) * (uv_1 > 0)
+
+            # get the instantaneous phase
+            tmp_cumsum = torch.cumsum(rad_values, dim=1)
+            # different batch needs to be processed differently
+            for idx in range(f0_values.shape[0]):
+                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
+                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
+                # stores the accumulation of i.phase within
+                # each voiced segment
+                tmp_cumsum[idx, :, :] = 0
+                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
+
+            # rad_values - tmp_cumsum: remove the accumulation of i.phase
+            # within the previous voiced segment.
+            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
+
+            # get the sines
+            sines = torch.cos(i_phase * 2 * np.pi)
+        return sines
+
+    def forward(self, f0):
+        """ sine_tensor, uv = forward(f0)
+        input F0: tensor(batchsize=1, length, dim=1)
+                  f0 for unvoiced steps should be 0
+        output sine_tensor: tensor(batchsize=1, length, dim)
+        output uv: tensor(batchsize=1, length, 1)
+        """
+        with torch.no_grad():
+            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, \
+                                 device=f0.device)
+            # fundamental component
+            f0_buf[:, :, 0] = f0[:, :, 0]
+            for idx in np.arange(self.harmonic_num):
+                # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
+                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
+
+            # generate sine waveforms
+            sine_waves = self._f02sine(f0_buf) * self.sine_amp
+
+            # generate uv signal
+            # uv = torch.ones(f0.shape)
+            # uv = uv * (f0 > self.voiced_threshold)
+            uv = self._f02uv(f0)
+
+            # noise: for unvoiced should be similar to sine_amp
+            #        std = self.sine_amp/3 -> max value ~ self.sine_amp
+            #        for voiced regions, the std is self.noise_std
+            noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+            noise = noise_amp * torch.randn_like(sine_waves)
+
+            # first: set the unvoiced part to 0 by uv
+            # then: additive noise
+            sine_waves = sine_waves * uv + noise
+        return sine_waves, uv, noise
+
+
+class SourceModuleCycNoise_v1(torch.nn.Module):
+    """ SourceModuleCycNoise_v1
+    SourceModule(sampling_rate, noise_std=0.003, voiced_threshod=0)
+    sampling_rate: sampling_rate in Hz
+
+    noise_std: std of Gaussian noise (default: 0.003)
+    voiced_threshold: threshold to set U/V given F0 (default: 0)
+
+    cyc, noise, uv = SourceModuleCycNoise_v1(F0_upsampled, beta)
+    F0_upsampled (batchsize, length, 1)
+    beta (1)
+    cyc (batchsize, length, 1)
+    noise (batchsize, length, 1)
+    uv (batchsize, length, 1)
+    """
+
+    def __init__(self, sampling_rate, noise_std=0.003, voiced_threshod=0):
+        super(SourceModuleCycNoise_v1, self).__init__()
+        self.sampling_rate = sampling_rate
+        self.noise_std = noise_std
+        self.l_cyc_gen = CyclicNoiseGen_v1(sampling_rate, noise_std,
+                                           voiced_threshod)
+
+    def forward(self, f0_upsamped, beta):
+        """
+        cyc, noise, uv = SourceModuleCycNoise_v1(F0, beta)
+        F0_upsampled (batchsize, length, 1)
+        beta (1)
+        cyc (batchsize, length, 1)
+        noise (batchsize, length, 1)
+        uv (batchsize, length, 1)
+        """
+        # source for harmonic branch
+        cyc, pulse, sine, uv, add_noi = self.l_cyc_gen(f0_upsamped, beta)
+
+        # source for noise branch, in the same shape as uv
+        noise = torch.randn_like(uv) * self.noise_std / 3
+        return cyc, noise, uv
+
+
+class SourceModuleHnNSF(torch.nn.Module):
+    """ SourceModule for hn-nsf
+    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
+                 add_noise_std=0.003, voiced_threshod=0)
+    sampling_rate: sampling_rate in Hz
+    harmonic_num: number of harmonic above F0 (default: 0)
+    sine_amp: amplitude of sine source signal (default: 0.1)
+    add_noise_std: std of additive Gaussian noise (default: 0.003)
+        note that amplitude of noise in unvoiced is decided
+        by sine_amp
+    voiced_threshold: threshold to set U/V given F0 (default: 0)
+
+    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+    F0_sampled (batchsize, length, 1)
+    Sine_source (batchsize, length, 1)
+    noise_source (batchsize, length, 1)
+    uv (batchsize, length, 1)
+    """
+
+    def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1,
+                 add_noise_std=0.003, voiced_threshod=0):
+        super(SourceModuleHnNSF, self).__init__()
+
+        self.sine_amp = sine_amp
+        self.noise_std = add_noise_std
+
+        # to produce sine waveforms
+        self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
+                                 sine_amp, add_noise_std, voiced_threshod)
+
+        # to merge source harmonics into a single excitation
+        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
+        self.l_tanh = torch.nn.Tanh()
+
+    def forward(self, x):
+        """
+        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+        F0_sampled (batchsize, length, 1)
+        Sine_source (batchsize, length, 1)
+        noise_source (batchsize, length, 1)
+        """
+        # source for harmonic branch
+        sine_wavs, uv, _ = self.l_sin_gen(x)
+        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+
+        # source for noise branch, in the same shape as uv
+        noise = torch.randn_like(uv) * self.sine_amp / 3
+        return sine_merge, noise, uv
+
+
+if __name__ == '__main__':
+    source = SourceModuleCycNoise_v1(24000)
+    x = torch.randn(16, 25600, 1)
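+
+    # Minimal usage sketch (illustrative only): SourceModuleHnNSF converts an
+    # F0 contour in Hz of shape (B, T, 1) into a merged sine excitation, a
+    # noise branch and a U/V mask, each of shape (B, T, 1).
+    hn_source = SourceModuleHnNSF(sampling_rate=24000, harmonic_num=8)
+    f0 = torch.abs(torch.randn(2, 1000, 1)) * 100 + 100  # dummy F0 around 100-400 Hz
+    sine_merge, noise, uv = hn_source(f0)
+    print(sine_merge.shape, noise.shape, uv.shape)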
+
+
diff --git a/modules/parallel_wavegan/optimizers/__init__.py b/modules/parallel_wavegan/optimizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0e0c5932838281e912079e5784d84d43444a61a
--- /dev/null
+++ b/modules/parallel_wavegan/optimizers/__init__.py
@@ -0,0 +1,2 @@
+from torch.optim import *  # NOQA
+from .radam import *  # NOQA
diff --git a/modules/parallel_wavegan/optimizers/radam.py b/modules/parallel_wavegan/optimizers/radam.py
new file mode 100644
index 0000000000000000000000000000000000000000..e805d7e34921bee436e1e7fd9e1f753c7609186b
--- /dev/null
+++ b/modules/parallel_wavegan/optimizers/radam.py
@@ -0,0 +1,91 @@
+# -*- coding: utf-8 -*-
+
+"""RAdam optimizer.
+
+This code is derived from https://github.com/LiyuanLucasLiu/RAdam.
+"""
+
+import math
+import torch
+
+from torch.optim.optimizer import Optimizer
+
+
+class RAdam(Optimizer):
+    """Rectified Adam optimizer."""
+
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
+        """Initilize RAdam optimizer."""
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        self.buffer = [[None, None, None] for ind in range(10)]
+        super(RAdam, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        """Set state."""
+        super(RAdam, self).__setstate__(state)
+
+    def step(self, closure=None):
+        """Run one step."""
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data.float()
+                if grad.is_sparse:
+                    raise RuntimeError('RAdam does not support sparse gradients')
+
+                p_data_fp32 = p.data.float()
+
+                state = self.state[p]
+
+                if len(state) == 0:
+                    state['step'] = 0
+                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
+                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+                else:
+                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
+                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                exp_avg.mul_(beta1).add_(1 - beta1, grad)
+
+                state['step'] += 1
+                buffered = self.buffer[int(state['step'] % 10)]
+                if state['step'] == buffered[0]:
+                    N_sma, step_size = buffered[1], buffered[2]
+                else:
+                    buffered[0] = state['step']
+                    beta2_t = beta2 ** state['step']
+                    N_sma_max = 2 / (1 - beta2) - 1
+                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+                    buffered[1] = N_sma
+
+                    # more conservative since it's an approximated value
+                    if N_sma >= 5:
+                        step_size = math.sqrt(
+                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])  # NOQA
+                    else:
+                        step_size = 1.0 / (1 - beta1 ** state['step'])
+                    buffered[2] = step_size
+
+                if group['weight_decay'] != 0:
+                    p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+
+                # more conservative since it's an approximated value
+                if N_sma >= 5:
+                    denom = exp_avg_sq.sqrt().add_(group['eps'])
+                    p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
+                else:
+                    p_data_fp32.add_(-step_size * group['lr'], exp_avg)
+
+                p.data.copy_(p_data_fp32)
+
+        return loss
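+
+
+if __name__ == '__main__':
+    # Minimal sketch (not part of the original RAdam code): sanity-check that a
+    # single optimization step runs on a tiny linear model with random data.
+    model = torch.nn.Linear(4, 1)
+    optimizer = RAdam(model.parameters(), lr=1e-3)
+    x, y = torch.randn(8, 4), torch.randn(8, 1)
+    loss = (model(x) - y).pow(2).mean()
+    loss.backward()
+    optimizer.step()
+    print("loss:", loss.item())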
diff --git a/modules/parallel_wavegan/stft_loss.py b/modules/parallel_wavegan/stft_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..229e6c777dc9ec7f710842d1e648dba1189ec8b4
--- /dev/null
+++ b/modules/parallel_wavegan/stft_loss.py
@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Tomoki Hayashi
+#  MIT License (https://opensource.org/licenses/MIT)
+
+"""STFT-based Loss modules."""
+import librosa
+import torch
+
+from modules.parallel_wavegan.losses import LogSTFTMagnitudeLoss, SpectralConvergengeLoss, stft
+
+
+class STFTLoss(torch.nn.Module):
+    """STFT loss module."""
+
+    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window",
+                 use_mel_loss=False):
+        """Initialize STFT loss module."""
+        super(STFTLoss, self).__init__()
+        self.fft_size = fft_size
+        self.shift_size = shift_size
+        self.win_length = win_length
+        self.window = getattr(torch, window)(win_length)
+        self.spectral_convergenge_loss = SpectralConvergengeLoss()
+        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
+        self.use_mel_loss = use_mel_loss
+        self.mel_basis = None
+
+    def forward(self, x, y):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Predicted signal (B, T).
+            y (Tensor): Groundtruth signal (B, T).
+
+        Returns:
+            Tensor: Spectral convergence loss value.
+            Tensor: Log STFT magnitude loss value.
+
+        """
+        x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
+        y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
+        if self.use_mel_loss:
+            if self.mel_basis is None:
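+                # NOTE: the mel filterbank is built lazily with a hard-coded
+                # 22050 Hz sampling rate and 80 mel bins, and is moved to GPU.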
+                self.mel_basis = torch.from_numpy(librosa.filters.mel(22050, self.fft_size, 80)).cuda().T
+            x_mag = x_mag @ self.mel_basis
+            y_mag = y_mag @ self.mel_basis
+
+        sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
+        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
+
+        return sc_loss, mag_loss
+
+
+class MultiResolutionSTFTLoss(torch.nn.Module):
+    """Multi resolution STFT loss module."""
+
+    def __init__(self,
+                 fft_sizes=[1024, 2048, 512],
+                 hop_sizes=[120, 240, 50],
+                 win_lengths=[600, 1200, 240],
+                 window="hann_window",
+                 use_mel_loss=False):
+        """Initialize Multi resolution STFT loss module.
+
+        Args:
+            fft_sizes (list): List of FFT sizes.
+            hop_sizes (list): List of hop sizes.
+            win_lengths (list): List of window lengths.
+            window (str): Window function type.
+
+        """
+        super(MultiResolutionSTFTLoss, self).__init__()
+        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
+        self.stft_losses = torch.nn.ModuleList()
+        for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
+            self.stft_losses += [STFTLoss(fs, ss, wl, window, use_mel_loss)]
+
+    def forward(self, x, y):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Predicted signal (B, T).
+            y (Tensor): Groundtruth signal (B, T).
+
+        Returns:
+            Tensor: Multi resolution spectral convergence loss value.
+            Tensor: Multi resolution log STFT magnitude loss value.
+
+        """
+        sc_loss = 0.0
+        mag_loss = 0.0
+        for f in self.stft_losses:
+            sc_l, mag_l = f(x, y)
+            sc_loss += sc_l
+            mag_loss += mag_l
+        sc_loss /= len(self.stft_losses)
+        mag_loss /= len(self.stft_losses)
+
+        return sc_loss, mag_loss
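+
+
+if __name__ == '__main__':
+    # Minimal sketch (not part of the original file): compare two random
+    # 1-second "waveforms" with the default resolutions; both losses are scalars.
+    criterion = MultiResolutionSTFTLoss()
+    x, y = torch.randn(2, 22050), torch.randn(2, 22050)
+    sc_loss, mag_loss = criterion(x, y)
+    print(sc_loss.item(), mag_loss.item())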
diff --git a/modules/parallel_wavegan/utils/__init__.py b/modules/parallel_wavegan/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8fa95a020706b5412c3959fbf6e5980019c0d5f
--- /dev/null
+++ b/modules/parallel_wavegan/utils/__init__.py
@@ -0,0 +1 @@
+from .utils import *  # NOQA
diff --git a/modules/parallel_wavegan/utils/utils.py b/modules/parallel_wavegan/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d48a5ed28e8555d4b8cfb15fdee86426bbb9e368
--- /dev/null
+++ b/modules/parallel_wavegan/utils/utils.py
@@ -0,0 +1,169 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Tomoki Hayashi
+#  MIT License (https://opensource.org/licenses/MIT)
+
+"""Utility functions."""
+
+import fnmatch
+import logging
+import os
+import sys
+
+import h5py
+import numpy as np
+
+
+def find_files(root_dir, query="*.wav", include_root_dir=True):
+    """Find files recursively.
+
+    Args:
+        root_dir (str): Root directory to search.
+        query (str): Glob-style pattern to match (e.g. "*.wav").
+        include_root_dir (bool): If False, the root_dir prefix is stripped from the returned paths.
+
+    Returns:
+        list: List of found filenames.
+
+    """
+    files = []
+    for root, dirnames, filenames in os.walk(root_dir, followlinks=True):
+        for filename in fnmatch.filter(filenames, query):
+            files.append(os.path.join(root, filename))
+    if not include_root_dir:
+        files = [file_.replace(root_dir + "/", "") for file_ in files]
+
+    return files
+
+
+def read_hdf5(hdf5_name, hdf5_path):
+    """Read hdf5 dataset.
+
+    Args:
+        hdf5_name (str): Filename of hdf5 file.
+        hdf5_path (str): Dataset name in hdf5 file.
+
+    Returns:
+        any: Dataset values.
+
+    """
+    if not os.path.exists(hdf5_name):
+        logging.error(f"There is no such a hdf5 file ({hdf5_name}).")
+        sys.exit(1)
+
+    hdf5_file = h5py.File(hdf5_name, "r")
+
+    if hdf5_path not in hdf5_file:
+        logging.error(f"There is no such a data in hdf5 file. ({hdf5_path})")
+        sys.exit(1)
+
+    hdf5_data = hdf5_file[hdf5_path][()]
+    hdf5_file.close()
+
+    return hdf5_data
+
+
+def write_hdf5(hdf5_name, hdf5_path, write_data, is_overwrite=True):
+    """Write dataset to hdf5.
+
+    Args:
+        hdf5_name (str): Hdf5 dataset filename.
+        hdf5_path (str): Dataset path in hdf5.
+        write_data (ndarray): Data to write.
+        is_overwrite (bool): Whether to overwrite dataset.
+
+    """
+    # convert to numpy array
+    write_data = np.array(write_data)
+
+    # check folder existence
+    folder_name, _ = os.path.split(hdf5_name)
+    if not os.path.exists(folder_name) and len(folder_name) != 0:
+        os.makedirs(folder_name)
+
+    # check hdf5 existence
+    if os.path.exists(hdf5_name):
+        # if already exists, open with r+ mode
+        hdf5_file = h5py.File(hdf5_name, "r+")
+        # check dataset existence
+        if hdf5_path in hdf5_file:
+            if is_overwrite:
+                logging.warning("Dataset in hdf5 file already exists. "
+                                "recreate dataset in hdf5.")
+                hdf5_file.__delitem__(hdf5_path)
+            else:
+                logging.error("Dataset in hdf5 file already exists. "
+                              "if you want to overwrite, please set is_overwrite = True.")
+                hdf5_file.close()
+                sys.exit(1)
+    else:
+        # if not exists, open with w mode
+        hdf5_file = h5py.File(hdf5_name, "w")
+
+    # write data to hdf5
+    hdf5_file.create_dataset(hdf5_path, data=write_data)
+    hdf5_file.flush()
+    hdf5_file.close()
+
+
+class HDF5ScpLoader(object):
+    """Loader class for a fests.scp file of hdf5 file.
+
+    Examples:
+        key1 /some/path/a.h5:feats
+        key2 /some/path/b.h5:feats
+        key3 /some/path/c.h5:feats
+        key4 /some/path/d.h5:feats
+        ...
+        >>> loader = HDF5ScpLoader("hdf5.scp")
+        >>> array = loader["key1"]
+
+        key1 /some/path/a.h5
+        key2 /some/path/b.h5
+        key3 /some/path/c.h5
+        key4 /some/path/d.h5
+        ...
+        >>> loader = HDF5ScpLoader("hdf5.scp", "feats")
+        >>> array = loader["key1"]
+
+    """
+
+    def __init__(self, feats_scp, default_hdf5_path="feats"):
+        """Initialize HDF5 scp loader.
+
+        Args:
+            feats_scp (str): Kaldi-style feats.scp file with hdf5 format.
+            default_hdf5_path (str): Path in hdf5 file. Not used if the scp line already contains the path.
+
+        """
+        self.default_hdf5_path = default_hdf5_path
+        with open(feats_scp) as f:
+            lines = [line.replace("\n", "") for line in f.readlines()]
+        self.data = {}
+        for line in lines:
+            key, value = line.split()
+            self.data[key] = value
+
+    def get_path(self, key):
+        """Get hdf5 file path for a given key."""
+        return self.data[key]
+
+    def __getitem__(self, key):
+        """Get ndarray for a given key."""
+        p = self.data[key]
+        if ":" in p:
+            return read_hdf5(*p.split(":"))
+        else:
+            return read_hdf5(p, self.default_hdf5_path)
+
+    def __len__(self):
+        """Return the length of the scp file."""
+        return len(self.data)
+
+    def __iter__(self):
+        """Return the iterator of the scp file."""
+        return iter(self.data)
+
+    def keys(self):
+        """Return the keys of the scp file."""
+        return self.data.keys()
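+
+
+if __name__ == '__main__':
+    # Minimal sketch (not part of the original file): round-trip a small array
+    # through write_hdf5/read_hdf5 using a temporary directory.
+    import tempfile
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        h5_path = os.path.join(tmp_dir, "example.h5")
+        write_hdf5(h5_path, "feats", np.arange(10, dtype=np.float32))
+        print(read_hdf5(h5_path, "feats"))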
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b8f177662183b701add1f77712a22612682e45e2
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,23 @@
+matplotlib
+librosa==0.8.0
+tqdm
+pandas
+numba==0.53.1
+PyYAML==5.3.1
+tensorboardX
+pyloudnorm
+setuptools>=41.0.0
+g2p_en
+resemblyzer
+webrtcvad
+tensorboard==2.6.0
+scikit-image
+textgrid
+jiwer
+pycwt
+PyWavelets
+praat-parselmouth==0.3.3
+jieba
+einops
+chardet
+h5py
diff --git a/tasks/base_task.py b/tasks/base_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa31903693c814af1e9a75cd64071e883dca4aa1
--- /dev/null
+++ b/tasks/base_task.py
@@ -0,0 +1,355 @@
+from itertools import chain
+
+from torch.utils.data import ConcatDataset
+from torch.utils.tensorboard import SummaryWriter
+import subprocess
+import traceback
+from datetime import datetime
+from functools import wraps
+from utils.hparams import hparams
+import random
+import sys
+import numpy as np
+from utils.trainer import Trainer
+from torch import nn
+import torch.utils.data
+import utils
+import logging
+import os
+
+torch.multiprocessing.set_sharing_strategy(os.getenv('TORCH_SHARE_STRATEGY', 'file_system'))
+
+log_format = '%(asctime)s %(message)s'
+logging.basicConfig(stream=sys.stdout, level=logging.INFO,
+                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
+
+
+def data_loader(fn):
+    """
+    Decorator to make any fx with this use the lazy property
+    :param fn:
+    :return:
+    """
+
+    attr_name = '_lazy_' + fn.__name__
+
+    @wraps(fn)
+    def _get_data_loader(self):
+        try:
+            value = getattr(self, attr_name)
+        except AttributeError:
+            try:
+                value = fn(self)  # Lazy evaluation, done only once.
+            except AttributeError as e:
+                # Guard against AttributeError suppression. (Issue #142)
+                traceback.print_exc()
+                error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e)
+                raise RuntimeError(error) from e
+            setattr(self, attr_name, value)  # Memoize evaluation.
+        return value
+
+    return _get_data_loader
+
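+# Typical usage (sketch with illustrative names): subclasses of BaseTask below
+# decorate their dataloader builders, e.g.
+#
+#     @data_loader
+#     def train_dataloader(self):
+#         return self.build_dataloader(self.train_dataset, shuffle=True)
+#
+# so the loader is constructed once and cached on the instance afterwards.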
+
+class BaseDataset(torch.utils.data.Dataset):
+    def __init__(self, shuffle):
+        super().__init__()
+        self.hparams = hparams
+        self.shuffle = shuffle
+        self.sort_by_len = hparams['sort_by_len']
+        self.sizes = None
+
+    @property
+    def _sizes(self):
+        return self.sizes
+
+    def __getitem__(self, index):
+        raise NotImplementedError
+
+    def collater(self, samples):
+        raise NotImplementedError
+
+    def __len__(self):
+        return len(self._sizes)
+
+    def num_tokens(self, index):
+        return self.size(index)
+
+    def size(self, index):
+        """Return an example's size as a float or tuple. This value is used when
+        filtering a dataset with ``--max-positions``."""
+        return min(self._sizes[index], hparams['max_frames'])
+
+    def ordered_indices(self):
+        """Return an ordered list of indices. Batches will be constructed based
+        on this order."""
+        if self.shuffle:
+            indices = np.random.permutation(len(self))
+            if self.sort_by_len:
+                indices = indices[np.argsort(np.array(self._sizes)[indices], kind='mergesort')]
+        else:
+            indices = np.arange(len(self))
+        return indices
+
+    @property
+    def num_workers(self):
+        return int(os.getenv('NUM_WORKERS', hparams['ds_workers']))
+
+
+class BaseConcatDataset(ConcatDataset):
+    def collater(self, samples):
+        return self.datasets[0].collater(samples)
+
+    @property
+    def _sizes(self):
+        if not hasattr(self, 'sizes'):
+            self.sizes = list(chain.from_iterable([d._sizes for d in self.datasets]))
+        return self.sizes
+
+    def size(self, index):
+        return min(self._sizes[index], hparams['max_frames'])
+
+    def num_tokens(self, index):
+        return self.size(index)
+
+    def ordered_indices(self):
+        """Return an ordered list of indices. Batches will be constructed based
+        on this order."""
+        if self.datasets[0].shuffle:
+            indices = np.random.permutation(len(self))
+            if self.datasets[0].sort_by_len:
+                indices = indices[np.argsort(np.array(self._sizes)[indices], kind='mergesort')]
+        else:
+            indices = np.arange(len(self))
+        return indices
+
+    @property
+    def num_workers(self):
+        return self.datasets[0].num_workers
+
+
+class BaseTask(nn.Module):
+    def __init__(self, *args, **kwargs):
+        # dataset configs
+        super(BaseTask, self).__init__()
+        self.current_epoch = 0
+        self.global_step = 0
+        self.trainer = None
+        self.use_ddp = False
+        self.gradient_clip_norm = hparams['clip_grad_norm']
+        self.gradient_clip_val = hparams.get('clip_grad_value', 0)
+        self.model = None
+        self.training_losses_meter = None
+        self.logger: SummaryWriter = None
+
+    ######################
+    # build model, dataloaders, optimizer, scheduler and tensorboard
+    ######################
+    def build_model(self):
+        raise NotImplementedError
+
+    @data_loader
+    def train_dataloader(self):
+        raise NotImplementedError
+
+    @data_loader
+    def test_dataloader(self):
+        raise NotImplementedError
+
+    @data_loader
+    def val_dataloader(self):
+        raise NotImplementedError
+
+    def build_scheduler(self, optimizer):
+        return None
+
+    def build_optimizer(self, model):
+        raise NotImplementedError
+
+    def configure_optimizers(self):
+        optm = self.build_optimizer(self.model)
+        self.scheduler = self.build_scheduler(optm)
+        if isinstance(optm, (list, tuple)):
+            return optm
+        return [optm]
+
+    def build_tensorboard(self, save_dir, name, version, **kwargs):
+        root_dir = os.path.join(save_dir, name)
+        os.makedirs(root_dir, exist_ok=True)
+        log_dir = os.path.join(root_dir, "version_" + str(version))
+        self.logger = SummaryWriter(log_dir=log_dir, **kwargs)
+
+    ######################
+    # training
+    ######################
+    def on_train_start(self):
+        pass
+
+    def on_epoch_start(self):
+        self.training_losses_meter = {'total_loss': utils.AvgrageMeter()}
+
+    def _training_step(self, sample, batch_idx, optimizer_idx):
+        """
+
+        :param sample:
+        :param batch_idx:
+        :return: total loss: torch.Tensor, loss_log: dict
+        """
+        raise NotImplementedError
+
+    def training_step(self, sample, batch_idx, optimizer_idx=-1):
+        """
+
+        :param sample:
+        :param batch_idx:
+        :param optimizer_idx:
+        :return: {'loss': torch.Tensor, 'progress_bar': dict, 'tb_log': dict}
+        """
+        loss_ret = self._training_step(sample, batch_idx, optimizer_idx)
+        if loss_ret is None:
+            return {'loss': None}
+        total_loss, log_outputs = loss_ret
+        log_outputs = utils.tensors_to_scalars(log_outputs)
+        for k, v in log_outputs.items():
+            if k not in self.training_losses_meter:
+                self.training_losses_meter[k] = utils.AvgrageMeter()
+            if not np.isnan(v):
+                self.training_losses_meter[k].update(v)
+        self.training_losses_meter['total_loss'].update(total_loss.item())
+
+        if optimizer_idx >= 0:
+            log_outputs[f'lr_{optimizer_idx}'] = self.trainer.optimizers[optimizer_idx].param_groups[0]['lr']
+
+        progress_bar_log = log_outputs
+        tb_log = {f'tr/{k}': v for k, v in log_outputs.items()}
+        return {
+            'loss': total_loss,
+            'progress_bar': progress_bar_log,
+            'tb_log': tb_log
+        }
+
+    def on_before_optimization(self, opt_idx):
+        if self.gradient_clip_norm > 0:
+            torch.nn.utils.clip_grad_norm_(self.parameters(), self.gradient_clip_norm)
+        if self.gradient_clip_val > 0:
+            torch.nn.utils.clip_grad_value_(self.parameters(), self.gradient_clip_val)
+
+    def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx):
+        if self.scheduler is not None:
+            self.scheduler.step(self.global_step // hparams['accumulate_grad_batches'])
+
+    def on_epoch_end(self):
+        loss_outputs = {k: round(v.avg, 4) for k, v in self.training_losses_meter.items()}
+        print(f"Epoch {self.current_epoch} ended. Steps: {self.global_step}. {loss_outputs}")
+
+    def on_train_end(self):
+        pass
+
+    ######################
+    # validation
+    ######################
+    def validation_step(self, sample, batch_idx):
+        """
+
+        :param sample:
+        :param batch_idx:
+        :return: output: {"losses": {...}, "total_loss": float, ...} or (total loss: torch.Tensor, loss_log: dict)
+        """
+        raise NotImplementedError
+
+    def validation_end(self, outputs):
+        """
+
+        :param outputs:
+        :return: loss_output: dict
+        """
+        all_losses_meter = {'total_loss': utils.AvgrageMeter()}
+        for output in outputs:
+            if output is None or len(output) == 0:
+                continue
+            if isinstance(output, dict):
+                assert 'losses' in output, 'Key "losses" should exist in validation output.'
+                n = output.pop('nsamples', 1)
+                losses = utils.tensors_to_scalars(output['losses'])
+                total_loss = output.get('total_loss', sum(losses.values()))
+            else:
+                assert len(output) == 2, 'Validation output should only consist of two elements: (total_loss, losses)'
+                n = 1
+                total_loss, losses = output
+                losses = utils.tensors_to_scalars(losses)
+            if isinstance(total_loss, torch.Tensor):
+                total_loss = total_loss.item()
+            for k, v in losses.items():
+                if k not in all_losses_meter:
+                    all_losses_meter[k] = utils.AvgrageMeter()
+                all_losses_meter[k].update(v, n)
+            all_losses_meter['total_loss'].update(total_loss, n)
+        loss_output = {k: round(v.avg, 4) for k, v in all_losses_meter.items()}
+        print(f"| Valid results: {loss_output}")
+        return {
+            'tb_log': {f'val/{k}': v for k, v in loss_output.items()},
+            'val_loss': loss_output['total_loss']
+        }
+
+    ######################
+    # testing
+    ######################
+    def test_start(self):
+        pass
+
+    def test_step(self, sample, batch_idx):
+        return self.validation_step(sample, batch_idx)
+
+    def test_end(self, outputs):
+        return self.validation_end(outputs)
+
+    ######################
+    # utils
+    ######################
+    def load_ckpt(self, ckpt_base_dir, current_model_name=None, model_name='model', force=True, strict=True):
+        if current_model_name is None:
+            current_model_name = model_name
+        utils.load_ckpt(self.__getattr__(current_model_name), ckpt_base_dir, current_model_name, force, strict)
+
+    ######################
+    # start training/testing
+    ######################
+    @classmethod
+    def start(cls):
+        os.environ['MASTER_PORT'] = str(random.randint(15000, 30000))
+        random.seed(hparams['seed'])
+        np.random.seed(hparams['seed'])
+        work_dir = hparams['work_dir']
+        trainer = Trainer(
+            work_dir=work_dir,
+            val_check_interval=hparams['val_check_interval'],
+            tb_log_interval=hparams['tb_log_interval'],
+            max_updates=hparams['max_updates'],
+            num_sanity_val_steps=hparams['num_sanity_val_steps'] if not hparams['validate'] else 10000,
+            accumulate_grad_batches=hparams['accumulate_grad_batches'],
+            print_nan_grads=hparams['print_nan_grads'],
+            resume_from_checkpoint=hparams.get('resume_from_checkpoint', 0),
+            amp=hparams['amp'],
+            # save ckpt
+            monitor_key=hparams['valid_monitor_key'],
+            monitor_mode=hparams['valid_monitor_mode'],
+            num_ckpt_keep=hparams['num_ckpt_keep'],
+            save_best=hparams['save_best'],
+            seed=hparams['seed'],
+            debug=hparams['debug']
+        )
+        if not hparams['inference']:  # train
+            if len(hparams['save_codes']) > 0:
+                t = datetime.now().strftime('%Y%m%d%H%M%S')
+                code_dir = f'{work_dir}/codes/{t}'
+                subprocess.check_call(f'mkdir -p "{code_dir}"', shell=True)
+                for c in hparams['save_codes']:
+                    if os.path.exists(c):
+                        subprocess.check_call(f'rsync -av --exclude=__pycache__  "{c}" "{code_dir}/"', shell=True)
+                print(f"| Copied codes to {code_dir}.")
+            trainer.fit(cls)
+        else:
+            trainer.test(cls)
+
+    def on_keyboard_interrupt(self):
+        pass
diff --git a/tasks/run.py b/tasks/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..82c7559cec873eebf7c2c0ab6554895e21de7e7c
--- /dev/null
+++ b/tasks/run.py
@@ -0,0 +1,15 @@
+import importlib
+from utils.hparams import set_hparams, hparams
+
+
+def run_task():
+    assert hparams['task_cls'] != ''
+    pkg = ".".join(hparams["task_cls"].split(".")[:-1])
+    cls_name = hparams["task_cls"].split(".")[-1]
+    task_cls = getattr(importlib.import_module(pkg), cls_name)
+    task_cls.start()
+
+
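+# Example (sketch): a config that sets
+#     task_cls: tasks.tts.fs2.FastSpeech2Task
+# makes run_task() import tasks.tts.fs2 and call FastSpeech2Task.start().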
+if __name__ == '__main__':
+    set_hparams()
+    run_task()
diff --git a/tasks/tts/dataset_utils.py b/tasks/tts/dataset_utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..488e616dd63cb8fdf30c47e037a2acc21c41c7f3
--- /dev/null
+++ b/tasks/tts/dataset_utils.py
@@ -0,0 +1,260 @@
+from utils.cwt import get_lf0_cwt
+import torch.optim
+import torch.utils.data
+import importlib
+from utils.indexed_datasets import IndexedDataset
+from utils.pitch_utils import norm_interp_f0, denorm_f0, f0_to_coarse
+import numpy as np
+from tasks.base_task import BaseDataset
+import torch
+import utils
+import torch.distributions
+from utils.hparams import hparams
+from resemblyzer import VoiceEncoder
+import json
+from data_gen.tts.data_gen_utils import build_phone_encoder
+
+
+class BaseTTSDataset(BaseDataset):
+    def __init__(self, prefix, shuffle=False, test_items=None, test_sizes=None, data_dir=None):
+        super().__init__(shuffle)
+        self.data_dir = hparams['binary_data_dir'] if data_dir is None else data_dir
+        self.prefix = prefix
+        self.hparams = hparams
+        self.indexed_ds = None
+        self.ext_mel2ph = None
+
+        def load_size():
+            self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy')
+
+        if prefix == 'test' or hparams['inference']:
+            if test_items is not None:
+                self.indexed_ds, self.sizes = test_items, test_sizes
+            else:
+                load_size()
+            if hparams['num_test_samples'] > 0:
+                self.avail_idxs = [x for x in range(hparams['num_test_samples']) \
+                                   if x < len(self.sizes)]
+                if len(hparams['test_ids']) > 0:
+                    self.avail_idxs = hparams['test_ids'] + self.avail_idxs
+            else:
+                self.avail_idxs = list(range(len(self.sizes)))
+        else:
+            load_size()
+            self.avail_idxs = list(range(len(self.sizes)))
+
+        if hparams['min_frames'] > 0:
+            self.avail_idxs = [
+                x for x in self.avail_idxs if self.sizes[x] >= hparams['min_frames']]
+        self.sizes = [self.sizes[i] for i in self.avail_idxs]
+
+    def _get_item(self, index):
+        if hasattr(self, 'avail_idxs') and self.avail_idxs is not None:
+            index = self.avail_idxs[index]
+        if self.indexed_ds is None:
+            self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
+        return self.indexed_ds[index]
+
+    def __getitem__(self, index):
+        hparams = self.hparams
+        item = self._get_item(index)
+        assert len(item['mel']) == self.sizes[index], (len(item['mel']), self.sizes[index])
+        max_frames = hparams['max_frames']
+        spec = torch.Tensor(item['mel'])[:max_frames]
+        max_frames = spec.shape[0] // hparams['frames_multiple'] * hparams['frames_multiple']
+        spec = spec[:max_frames]
+        phone = torch.LongTensor(item['phone'][:hparams['max_input_tokens']])
+        sample = {
+            "id": index,
+            "item_name": item['item_name'],
+            "text": item['txt'],
+            "txt_token": phone,
+            "mel": spec,
+            "mel_nonpadding": spec.abs().sum(-1) > 0,
+        }
+        if hparams['use_spk_embed']:
+            sample["spk_embed"] = torch.Tensor(item['spk_embed'])
+        if hparams['use_spk_id']:
+            sample["spk_id"] = item['spk_id']
+        return sample
+
+    def collater(self, samples):
+        if len(samples) == 0:
+            return {}
+        hparams = self.hparams
+        id = torch.LongTensor([s['id'] for s in samples])
+        item_names = [s['item_name'] for s in samples]
+        text = [s['text'] for s in samples]
+        txt_tokens = utils.collate_1d([s['txt_token'] for s in samples], 0)
+        mels = utils.collate_2d([s['mel'] for s in samples], 0.0)
+        txt_lengths = torch.LongTensor([s['txt_token'].numel() for s in samples])
+        mel_lengths = torch.LongTensor([s['mel'].shape[0] for s in samples])
+
+        batch = {
+            'id': id,
+            'item_name': item_names,
+            'nsamples': len(samples),
+            'text': text,
+            'txt_tokens': txt_tokens,
+            'txt_lengths': txt_lengths,
+            'mels': mels,
+            'mel_lengths': mel_lengths,
+        }
+
+        if hparams['use_spk_embed']:
+            spk_embed = torch.stack([s['spk_embed'] for s in samples])
+            batch['spk_embed'] = spk_embed
+        if hparams['use_spk_id']:
+            spk_ids = torch.LongTensor([s['spk_id'] for s in samples])
+            batch['spk_ids'] = spk_ids
+        return batch
+
+
+class FastSpeechDataset(BaseTTSDataset):
+    def __init__(self, prefix, shuffle=False, test_items=None, test_sizes=None, data_dir=None):
+        super().__init__(prefix, shuffle, test_items, test_sizes, data_dir)
+        self.f0_mean, self.f0_std = hparams.get('f0_mean', None), hparams.get('f0_std', None)
+        if prefix == 'test' and hparams['test_input_dir'] != '':
+            self.data_dir = hparams['test_input_dir']
+            self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
+            self.indexed_ds = sorted(self.indexed_ds, key=lambda item: item['item_name'])
+            items = {}
+            for i in range(len(self.indexed_ds)):
+                speaker = self.indexed_ds[i]['item_name'].split('_')[0]
+                if speaker not in items.keys():
+                    items[speaker] = [i]
+                else:
+                    items[speaker].append(i)
+            sort_item = sorted(items.values(), key=lambda item_pre_speaker: len(item_pre_speaker), reverse=True)
+            self.avail_idxs = [n for a in sort_item for n in a][:hparams['num_test_samples']]
+            self.indexed_ds, self.sizes = self.load_test_inputs()
+            self.avail_idxs = [i for i in range(hparams['num_test_samples'])]
+
+        if hparams['pitch_type'] == 'cwt':
+            _, hparams['cwt_scales'] = get_lf0_cwt(np.ones(10))
+
+    def __getitem__(self, index):
+        sample = super(FastSpeechDataset, self).__getitem__(index)
+        item = self._get_item(index)
+        hparams = self.hparams
+        max_frames = hparams['max_frames']
+        spec = sample['mel']
+        T = spec.shape[0]
+        phone = sample['txt_token']
+        sample['energy'] = (spec.exp() ** 2).sum(-1).sqrt()
+        sample['mel2ph'] = mel2ph = torch.LongTensor(item['mel2ph'])[:T] if 'mel2ph' in item else None
+        if hparams['use_pitch_embed']:
+            assert 'f0' in item
+            if hparams.get('normalize_pitch', False):
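+                # Rescale the utterance's voiced F0 to the global f0_mean/f0_std
+                # from hparams, then clip to a plausible speech range (60-500 Hz).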
+                f0 = item["f0"]
+                if (f0 > 0).any() and f0[f0 > 0].std() > 0:
+                    f0[f0 > 0] = (f0[f0 > 0] - f0[f0 > 0].mean()) / f0[f0 > 0].std() * hparams['f0_std'] + \
+                                 hparams['f0_mean']
+                    f0[f0 > 0] = f0[f0 > 0].clip(min=60, max=500)
+                pitch = f0_to_coarse(f0)
+                pitch = torch.LongTensor(pitch[:max_frames])
+            else:
+                pitch = torch.LongTensor(item.get("pitch"))[:max_frames] if "pitch" in item else None
+            f0, uv = norm_interp_f0(item["f0"][:max_frames], hparams)
+            uv = torch.FloatTensor(uv)
+            f0 = torch.FloatTensor(f0)
+            if hparams['pitch_type'] == 'cwt':
+                cwt_spec = torch.Tensor(item['cwt_spec'])[:max_frames]
+                f0_mean = item.get('f0_mean', item.get('cwt_mean'))
+                f0_std = item.get('f0_std', item.get('cwt_std'))
+                sample.update({"cwt_spec": cwt_spec, "f0_mean": f0_mean, "f0_std": f0_std})
+            elif hparams['pitch_type'] == 'ph':
+                if "f0_ph" in item:
+                    f0 = torch.FloatTensor(item['f0_ph'])
+                else:
+                    f0 = denorm_f0(f0, None, hparams)
+                f0_phlevel_sum = torch.zeros_like(phone).float().scatter_add(0, mel2ph - 1, f0)
+                f0_phlevel_num = torch.zeros_like(phone).float().scatter_add(
+                    0, mel2ph - 1, torch.ones_like(f0)).clamp_min(1)
+                f0_ph = f0_phlevel_sum / f0_phlevel_num
+                f0, uv = norm_interp_f0(f0_ph, hparams)
+        else:
+            f0 = uv = torch.zeros_like(mel2ph)
+            pitch = None
+        sample["f0"], sample["uv"], sample["pitch"] = f0, uv, pitch
+        if hparams['use_spk_embed']:
+            sample["spk_embed"] = torch.Tensor(item['spk_embed'])
+        if hparams['use_spk_id']:
+            sample["spk_id"] = item['spk_id']
+        return sample
+
+    def collater(self, samples):
+        if len(samples) == 0:
+            return {}
+        hparams = self.hparams
+        batch = super(FastSpeechDataset, self).collater(samples)
+        f0 = utils.collate_1d([s['f0'] for s in samples], 0.0)
+        pitch = utils.collate_1d([s['pitch'] for s in samples]) if samples[0]['pitch'] is not None else None
+        uv = utils.collate_1d([s['uv'] for s in samples])
+        energy = utils.collate_1d([s['energy'] for s in samples], 0.0)
+        mel2ph = utils.collate_1d([s['mel2ph'] for s in samples], 0.0) \
+            if samples[0]['mel2ph'] is not None else None
+        batch.update({
+            'mel2ph': mel2ph,
+            'energy': energy,
+            'pitch': pitch,
+            'f0': f0,
+            'uv': uv,
+        })
+        if hparams['pitch_type'] == 'cwt':
+            cwt_spec = utils.collate_2d([s['cwt_spec'] for s in samples])
+            f0_mean = torch.Tensor([s['f0_mean'] for s in samples])
+            f0_std = torch.Tensor([s['f0_std'] for s in samples])
+            batch.update({'cwt_spec': cwt_spec, 'f0_mean': f0_mean, 'f0_std': f0_std})
+        return batch
+
+    def load_test_inputs(self):
+        binarizer_cls = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizerr.BaseBinarizer')
+        pkg = ".".join(binarizer_cls.split(".")[:-1])
+        cls_name = binarizer_cls.split(".")[-1]
+        binarizer_cls = getattr(importlib.import_module(pkg), cls_name)
+        ph_set_fn = f"{hparams['binary_data_dir']}/phone_set.json"
+        ph_set = json.load(open(ph_set_fn, 'r'))
+        print("| phone set: ", ph_set)
+        phone_encoder = build_phone_encoder(hparams['binary_data_dir'])
+        word_encoder = None
+        voice_encoder = VoiceEncoder().cuda()
+        encoder = [phone_encoder, word_encoder]
+        sizes = []
+        items = []
+        for i in range(len(self.avail_idxs)):
+            item = self._get_item(i)
+
+            item2tgfn = f"{hparams['test_input_dir'].replace('binary', 'processed')}/mfa_outputs/{item['item_name']}.TextGrid"
+            item = binarizer_cls.process_item(item['item_name'], item['ph'], item['txt'], item2tgfn,
+                                              item['wav_fn'], item['spk_id'], encoder, hparams['binarization_args'])
+            item['spk_embed'] = voice_encoder.embed_utterance(item['wav']) \
+                if hparams['binarization_args']['with_spk_embed'] else None  # decide whether to store the speaker embedding
+            items.append(item)
+            sizes.append(item['len'])
+        return items, sizes
+
+class FastSpeechWordDataset(FastSpeechDataset):
+    def __getitem__(self, index):
+        sample = super(FastSpeechWordDataset, self).__getitem__(index)
+        item = self._get_item(index)
+        max_frames = hparams['max_frames']
+        sample["ph_words"] = item["ph_words"]
+        sample["word_tokens"] = torch.LongTensor(item["word_tokens"])
+        sample["mel2word"] = torch.LongTensor(item.get("mel2word"))[:max_frames]
+        sample["ph2word"] = torch.LongTensor(item['ph2word'][:hparams['max_input_tokens']])
+        return sample
+
+    def collater(self, samples):
+        batch = super(FastSpeechWordDataset, self).collater(samples)
+        ph_words = [s['ph_words'] for s in samples]
+        batch['ph_words'] = ph_words
+        word_tokens = utils.collate_1d([s['word_tokens'] for s in samples], 0)
+        batch['word_tokens'] = word_tokens
+        mel2word = utils.collate_1d([s['mel2word'] for s in samples], 0)
+        batch['mel2word'] = mel2word
+        ph2word = utils.collate_1d([s['ph2word'] for s in samples], 0)
+        batch['ph2word'] = ph2word
+        return batch
diff --git a/tasks/tts/fs2.py b/tasks/tts/fs2.py
new file mode 100755
index 0000000000000000000000000000000000000000..473c514b523ecbd45acfdecdb33d7b633c59eb6c
--- /dev/null
+++ b/tasks/tts/fs2.py
@@ -0,0 +1,292 @@
+import matplotlib
+matplotlib.use('Agg')
+
+from tasks.tts.tts_base import TTSBaseTask
+from vocoders.base_vocoder import get_vocoder_cls
+from tasks.tts.dataset_utils import FastSpeechDataset
+from modules.commons.ssim import ssim
+import os
+from modules.fastspeech.tts_modules import mel2ph_to_dur
+from utils.hparams import hparams
+from utils.plot import spec_to_figure, dur_to_figure, f0_to_figure
+from utils.pitch_utils import denorm_f0
+from modules.fastspeech.fs2 import FastSpeech2
+import torch
+import torch.optim
+import torch.utils.data
+import torch.nn.functional as F
+import utils
+import torch.distributions
+import numpy as np
+
+
+class FastSpeech2Task(TTSBaseTask):
+    def __init__(self):
+        super(FastSpeech2Task, self).__init__()
+        self.dataset_cls = FastSpeechDataset
+        self.mse_loss_fn = torch.nn.MSELoss()
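+        # hparams['mel_loss'] is a "|"-separated list of loss names with optional
+        # ":"-weights, e.g. "l1:0.5|ssim:0.5" (example values, not a prescribed config).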
+        mel_losses = hparams['mel_loss'].split("|")
+        self.loss_and_lambda = {}
+        for i, l in enumerate(mel_losses):
+            if l == '':
+                continue
+            if ':' in l:
+                l, lbd = l.split(":")
+                lbd = float(lbd)
+            else:
+                lbd = 1.0
+            self.loss_and_lambda[l] = lbd
+        print("| Mel losses:", self.loss_and_lambda)
+        self.sil_ph = self.phone_encoder.sil_phonemes()
+        f0_stats_fn = f'{hparams["binary_data_dir"]}/train_f0s_mean_std.npy'
+        if os.path.exists(f0_stats_fn):
+            hparams['f0_mean'], hparams['f0_std'] = np.load(f0_stats_fn)
+            hparams['f0_mean'] = float(hparams['f0_mean'])
+            hparams['f0_std'] = float(hparams['f0_std'])
+
+    def build_tts_model(self):
+        self.model = FastSpeech2(self.phone_encoder)
+
+    def build_model(self):
+        self.build_tts_model()
+        if hparams['load_ckpt'] != '':
+            self.load_ckpt(hparams['load_ckpt'], strict=False)
+        utils.print_arch(self.model)
+        return self.model
+
+    def _training_step(self, sample, batch_idx, _):
+        loss_output = self.run_model(self.model, sample)
+        total_loss = sum([v for v in loss_output.values() if isinstance(v, torch.Tensor) and v.requires_grad])
+        loss_output['batch_size'] = sample['txt_tokens'].size()[0]
+        return total_loss, loss_output
+
+    def validation_step(self, sample, batch_idx):
+        outputs = {}
+        outputs['losses'] = {}
+        outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True)
+        outputs['total_loss'] = sum(outputs['losses'].values())
+        outputs['nsamples'] = sample['nsamples']
+        mel_out = self.model.out2mel(model_out['mel_out'])
+        outputs = utils.tensors_to_scalars(outputs)
+        if self.global_step % hparams['valid_infer_interval'] == 0 \
+                and batch_idx < hparams['num_valid_plots']:
+            vmin = hparams['mel_vmin']
+            vmax = hparams['mel_vmax']
+            self.plot_mel(batch_idx, sample['mels'], mel_out)
+            self.plot_dur(batch_idx, sample, model_out)
+            if hparams['use_pitch_embed']:
+                self.plot_pitch(batch_idx, sample, model_out)
+            if self.vocoder is None:
+                self.vocoder = get_vocoder_cls(hparams)()
+            if self.global_step > 0:
+                spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids')
+                # with gt duration
+                model_out = self.model(sample['txt_tokens'], mel2ph=sample['mel2ph'],
+                                       spk_embed=spk_embed, infer=True)
+                wav_pred = self.vocoder.spec2wav(model_out['mel_out'][0].cpu())
+                self.logger.add_audio(f'wav_gtdur_{batch_idx}', wav_pred, self.global_step,
+                                      hparams['audio_sample_rate'])
+                self.logger.add_figure(
+                    f'mel_gtdur_{batch_idx}',
+                    spec_to_figure(model_out['mel_out'][0], vmin, vmax), self.global_step)
+                # with pred duration
+                model_out = self.model(sample['txt_tokens'], spk_embed=spk_embed, infer=True)
+                self.logger.add_figure(
+                    f'mel_{batch_idx}',
+                    spec_to_figure(model_out['mel_out'][0], vmin, vmax), self.global_step)
+                wav_pred = self.vocoder.spec2wav(model_out['mel_out'][0].cpu())
+                self.logger.add_audio(f'wav_{batch_idx}', wav_pred, self.global_step, hparams['audio_sample_rate'])
+            # gt wav
+            if self.global_step <= hparams['valid_infer_interval']:
+                mel_gt = sample['mels'][0].cpu()
+                wav_gt = self.vocoder.spec2wav(mel_gt)
+                self.logger.add_audio(f'wav_gt_{batch_idx}', wav_gt, self.global_step, hparams['audio_sample_rate'])
+        return outputs
+
+    def run_model(self, model, sample, return_output=False):
+        txt_tokens = sample['txt_tokens']  # [B, T_t]
+        target = sample['mels']  # [B, T_s, 80]
+        mel2ph = sample['mel2ph']  # [B, T_s]
+        f0 = sample['f0']
+        uv = sample['uv']
+        energy = sample['energy']
+        spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids')
+        output = model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed,
+                       ref_mels=target, f0=f0, uv=uv, energy=energy,
+                       tgt_mels=target, infer=False)
+        losses = {}
+        self.add_mel_loss(output['mel_out'], target, losses)
+        self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses)
+        if hparams['use_pitch_embed']:
+            self.add_pitch_loss(output, sample, losses)
+        if not return_output:
+            return losses
+        else:
+            return losses, output
+
+    ############
+    # losses
+    ############
+    def add_mel_loss(self, mel_out, target, losses, postfix='', mel_mix_loss=None):
+        nonpadding = target.abs().sum(-1).ne(0).float()
+        for loss_name, lbd in self.loss_and_lambda.items():
+            if 'l1' == loss_name:
+                l = self.l1_loss(mel_out, target)
+            elif 'mse' == loss_name:
+                l = self.mse_loss(mel_out, target)
+            elif 'ssim' == loss_name:
+                l = self.ssim_loss(mel_out, target)
+            elif 'gdl' == loss_name:
+                l = self.gdl_loss_fn(mel_out, target, nonpadding) \
+                    * self.loss_and_lambda['gdl']
+            losses[f'{loss_name}{postfix}'] = l * lbd
+
+    def l1_loss(self, decoder_output, target):
+        # decoder_output : B x T x n_mel
+        # target : B x T x n_mel
+        l1_loss = F.l1_loss(decoder_output, target, reduction='none')
+        weights = self.weights_nonzero_speech(target)
+        l1_loss = (l1_loss * weights).sum() / weights.sum()
+        return l1_loss
+
+    def add_energy_loss(self, energy_pred, energy, losses):
+        nonpadding = (energy != 0).float()
+        loss = (F.mse_loss(energy_pred, energy, reduction='none') * nonpadding).sum() / nonpadding.sum()
+        loss = loss * hparams['lambda_energy']
+        losses['e'] = loss
+
+    def mse_loss(self, decoder_output, target):
+        # decoder_output : B x T x n_mel
+        # target : B x T x n_mel
+        assert decoder_output.shape == target.shape
+        mse_loss = F.mse_loss(decoder_output, target, reduction='none')
+        weights = self.weights_nonzero_speech(target)
+        mse_loss = (mse_loss * weights).sum() / weights.sum()
+        return mse_loss
+
+    def ssim_loss(self, decoder_output, target, bias=6.0):
+        # decoder_output : B x T x n_mel
+        # target : B x T x n_mel
+        assert decoder_output.shape == target.shape
+        weights = self.weights_nonzero_speech(target)
+        decoder_output = decoder_output[:, None] + bias
+        target = target[:, None] + bias
+        ssim_loss = 1 - ssim(decoder_output, target, size_average=False)
+        ssim_loss = (ssim_loss * weights).sum() / weights.sum()
+        return ssim_loss
+
+    def add_dur_loss(self, dur_pred, mel2ph, txt_tokens, losses=None):
+        """
+
+        :param dur_pred: [B, T], float, log scale
+        :param mel2ph: [B, T]
+        :param txt_tokens: [B, T]
+        :param losses:
+        :return:
+        """
+        B, T = txt_tokens.shape
+        nonpadding = (txt_tokens != 0).float()
+        dur_gt = mel2ph_to_dur(mel2ph, T).float() * nonpadding
+        is_sil = torch.zeros_like(txt_tokens).bool()
+        for p in self.sil_ph:
+            is_sil = is_sil | (txt_tokens == self.phone_encoder.encode(p)[0])
+        is_sil = is_sil.float()  # [B, T_txt]
+        losses['pdur'] = F.mse_loss(dur_pred, (dur_gt + 1).log(), reduction='none')
+        losses['pdur'] = (losses['pdur'] * nonpadding).sum() / nonpadding.sum()
+        losses['pdur'] = losses['pdur'] * hparams['lambda_ph_dur']
+        dur_pred = (dur_pred.exp() - 1).clamp(min=0)
+        # use linear scale for sent and word duration
+        if hparams['lambda_word_dur'] > 0:
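+            # Word boundaries are delimited by silence phones: the cumulative
+            # count of silences assigns a word index to every non-silence token,
+            # and scatter_add pools per-phone durations into per-word durations.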
+            word_id = (is_sil.cumsum(-1) * (1 - is_sil)).long()
+            word_dur_p = dur_pred.new_zeros([B, word_id.max() + 1]).scatter_add(1, word_id, dur_pred)[:, 1:]
+            word_dur_g = dur_gt.new_zeros([B, word_id.max() + 1]).scatter_add(1, word_id, dur_gt)[:, 1:]
+            wdur_loss = F.mse_loss((word_dur_p + 1).log(), (word_dur_g + 1).log(), reduction='none')
+            word_nonpadding = (word_dur_g > 0).float()
+            wdur_loss = (wdur_loss * word_nonpadding).sum() / word_nonpadding.sum()
+            losses['wdur'] = wdur_loss * hparams['lambda_word_dur']
+        if hparams['lambda_sent_dur'] > 0:
+            sent_dur_p = dur_pred.sum(-1)
+            sent_dur_g = dur_gt.sum(-1)
+            sdur_loss = F.mse_loss((sent_dur_p + 1).log(), (sent_dur_g + 1).log(), reduction='mean')
+            losses['sdur'] = sdur_loss.mean() * hparams['lambda_sent_dur']
+
+    def add_pitch_loss(self, output, sample, losses):
+        mel2ph = sample['mel2ph']  # [B, T_s]
+        f0 = sample['f0']
+        uv = sample['uv']
+        nonpadding = (mel2ph != 0).float() if hparams['pitch_type'] == 'frame' \
+            else (sample['txt_tokens'] != 0).float()
+        self.add_f0_loss(output['pitch_pred'], f0, uv, losses, nonpadding=nonpadding) # output['pitch_pred']: [B, T, 2], f0: [B, T], uv: [B, T]
+
+    def add_f0_loss(self, p_pred, f0, uv, losses, nonpadding, postfix=''):
+        assert p_pred[..., 0].shape == f0.shape
+        if hparams['use_uv'] and hparams['pitch_type'] == 'frame':
+            assert p_pred[..., 1].shape == uv.shape, (p_pred.shape, uv.shape)
+            losses[f'uv{postfix}'] = (F.binary_cross_entropy_with_logits(
+                p_pred[:, :, 1], uv, reduction='none') * nonpadding).sum() \
+                                     / nonpadding.sum() * hparams['lambda_uv']
+            nonpadding = nonpadding * (uv == 0).float()
+        f0_pred = p_pred[:, :, 0]
+        pitch_loss_fn = F.l1_loss if hparams['pitch_loss'] == 'l1' else F.mse_loss
+        losses[f'f0{postfix}'] = (pitch_loss_fn(f0_pred, f0, reduction='none') * nonpadding).sum() \
+                                 / nonpadding.sum() * hparams['lambda_f0']
+
+
+    ############
+    # validation plots
+    ############
+    def plot_dur(self, batch_idx, sample, model_out):
+        T_txt = sample['txt_tokens'].shape[1]
+        dur_gt = mel2ph_to_dur(sample['mel2ph'], T_txt)[0]
+        dur_pred = model_out['dur']
+        if hasattr(self.model, 'out2dur'):
+            dur_pred = self.model.out2dur(model_out['dur']).float()
+        txt = self.phone_encoder.decode(sample['txt_tokens'][0].cpu().numpy())
+        txt = txt.split(" ")
+        self.logger.add_figure(
+            f'dur_{batch_idx}', dur_to_figure(dur_gt, dur_pred, txt), self.global_step)
+
+    def plot_pitch(self, batch_idx, sample, model_out):
+        self.logger.add_figure(
+            f'f0_{batch_idx}',
+            f0_to_figure(model_out['f0_denorm'][0], None, model_out['f0_denorm_pred'][0]),
+            self.global_step)
+
+    ############
+    # inference
+    ############
+    def test_step(self, sample, batch_idx):
+        spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids')
+        txt_tokens = sample['txt_tokens']
+        mel2ph, uv, f0 = None, None, None
+        ref_mels = sample['mels']
+        if hparams['use_gt_dur']:
+            mel2ph = sample['mel2ph']
+        if hparams['use_gt_f0']:
+            f0 = sample['f0']
+            uv = sample['uv']
+        run_model = lambda: self.model(
+            txt_tokens, spk_embed=spk_embed, mel2ph=mel2ph, f0=f0, uv=uv, ref_mels=ref_mels, infer=True)
+        if hparams['profile_infer']:
+            mel2ph, uv, f0 = sample['mel2ph'], sample['uv'], sample['f0']
+            with utils.Timer('fs', enable=True):
+                outputs = run_model()
+            if 'gen_wav_time' not in self.stats:
+                self.stats['gen_wav_time'] = 0
+            wav_time = float(outputs["mels_out"].shape[1]) * hparams['hop_size'] / hparams["audio_sample_rate"]
+            self.stats['gen_wav_time'] += wav_time
+            print(f'[Timer] wav total seconds: {self.stats["gen_wav_time"]}')
+            from pytorch_memlab import LineProfiler
+            with LineProfiler(self.model.forward) as prof:
+                run_model()
+            prof.print_stats()
+        else:
+            outputs = run_model()
+            sample['outputs'] = self.model.out2mel(outputs['mel_out'])
+            sample['mel2ph_pred'] = outputs['mel2ph']
+            if hparams['use_pitch_embed']:
+                sample['f0'] = denorm_f0(sample['f0'], sample['uv'], hparams)
+                if hparams['pitch_type'] == 'ph':
+                    sample['f0'] = torch.gather(F.pad(sample['f0'], [1, 0]), 1, sample['mel2ph'])
+                sample['f0_pred'] = outputs.get('f0_denorm')
+            return self.after_infer(sample)
diff --git a/tasks/tts/fs2_utils.py b/tasks/tts/fs2_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..092550863d2fd72f008cc790bc6d950340e68182
--- /dev/null
+++ b/tasks/tts/fs2_utils.py
@@ -0,0 +1,173 @@
+import matplotlib
+
+matplotlib.use('Agg')
+
+import glob
+import importlib
+from utils.cwt import get_lf0_cwt
+import os
+import torch.optim
+import torch.utils.data
+from utils.indexed_datasets import IndexedDataset
+from utils.pitch_utils import norm_interp_f0
+import numpy as np
+from tasks.base_task import BaseDataset
+import torch
+import utils
+import torch.distributions
+from utils.hparams import hparams
+
+
+class FastSpeechDataset(BaseDataset):
+    def __init__(self, prefix, shuffle=False):
+        super().__init__(shuffle)
+        self.data_dir = hparams['binary_data_dir']
+        self.prefix = prefix
+        self.hparams = hparams
+        self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy')
+        self.indexed_ds = None
+        # self.name2spk_id={}
+
+        # pitch stats
+        f0_stats_fn = f'{self.data_dir}/train_f0s_mean_std.npy'
+        if os.path.exists(f0_stats_fn):
+            hparams['f0_mean'], hparams['f0_std'] = self.f0_mean, self.f0_std = np.load(f0_stats_fn)
+            hparams['f0_mean'] = float(hparams['f0_mean'])
+            hparams['f0_std'] = float(hparams['f0_std'])
+        else:
+            hparams['f0_mean'], hparams['f0_std'] = self.f0_mean, self.f0_std = None, None
+
+        if prefix == 'test':
+            if hparams['test_input_dir'] != '':
+                self.indexed_ds, self.sizes = self.load_test_inputs(hparams['test_input_dir'])
+            else:
+                if hparams['num_test_samples'] > 0:
+                    self.avail_idxs = list(range(hparams['num_test_samples'])) + hparams['test_ids']
+                    self.sizes = [self.sizes[i] for i in self.avail_idxs]
+
+        if hparams['pitch_type'] == 'cwt':
+            _, hparams['cwt_scales'] = get_lf0_cwt(np.ones(10))
+
+    def _get_item(self, index):
+        if hasattr(self, 'avail_idxs') and self.avail_idxs is not None:
+            index = self.avail_idxs[index]
+        if self.indexed_ds is None:
+            self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
+        return self.indexed_ds[index]
+
+    def __getitem__(self, index):
+        hparams = self.hparams
+        item = self._get_item(index)
+        max_frames = hparams['max_frames']
+        spec = torch.Tensor(item['mel'])[:max_frames]
+        energy = (spec.exp() ** 2).sum(-1).sqrt()
+        mel2ph = torch.LongTensor(item['mel2ph'])[:max_frames] if 'mel2ph' in item else None
+        f0, uv = norm_interp_f0(item["f0"][:max_frames], hparams)
+        phone = torch.LongTensor(item['phone'][:hparams['max_input_tokens']])
+        pitch = torch.LongTensor(item.get("pitch"))[:max_frames]
+        # print(item.keys(), item['mel'].shape, spec.shape)
+        sample = {
+            "id": index,
+            "item_name": item['item_name'],
+            "text": item['txt'],
+            "txt_token": phone,
+            "mel": spec,
+            "pitch": pitch,
+            "energy": energy,
+            "f0": f0,
+            "uv": uv,
+            "mel2ph": mel2ph,
+            "mel_nonpadding": spec.abs().sum(-1) > 0,
+        }
+        if self.hparams['use_spk_embed']:
+            sample["spk_embed"] = torch.Tensor(item['spk_embed'])
+        if self.hparams['use_spk_id']:
+            sample["spk_id"] = item['spk_id']
+            # sample['spk_id'] = 0
+            # for key in self.name2spk_id.keys():
+            #     if key in item['item_name']:
+            #         sample['spk_id'] = self.name2spk_id[key]
+            #         break
+        if self.hparams['pitch_type'] == 'cwt':
+            cwt_spec = torch.Tensor(item['cwt_spec'])[:max_frames]
+            f0_mean = item.get('f0_mean', item.get('cwt_mean'))
+            f0_std = item.get('f0_std', item.get('cwt_std'))
+            sample.update({"cwt_spec": cwt_spec, "f0_mean": f0_mean, "f0_std": f0_std})
+        elif self.hparams['pitch_type'] == 'ph':
+            f0_phlevel_sum = torch.zeros_like(phone).float().scatter_add(0, mel2ph - 1, f0)
+            f0_phlevel_num = torch.zeros_like(phone).float().scatter_add(
+                0, mel2ph - 1, torch.ones_like(f0)).clamp_min(1)
+            sample["f0_ph"] = f0_phlevel_sum / f0_phlevel_num
+        return sample
+
+    def collater(self, samples):
+        if len(samples) == 0:
+            return {}
+        id = torch.LongTensor([s['id'] for s in samples])
+        item_names = [s['item_name'] for s in samples]
+        text = [s['text'] for s in samples]
+        txt_tokens = utils.collate_1d([s['txt_token'] for s in samples], 0)
+        f0 = utils.collate_1d([s['f0'] for s in samples], 0.0)
+        pitch = utils.collate_1d([s['pitch'] for s in samples])
+        uv = utils.collate_1d([s['uv'] for s in samples])
+        energy = utils.collate_1d([s['energy'] for s in samples], 0.0)
+        mel2ph = utils.collate_1d([s['mel2ph'] for s in samples], 0.0) \
+            if samples[0]['mel2ph'] is not None else None
+        mels = utils.collate_2d([s['mel'] for s in samples], 0.0)
+        txt_lengths = torch.LongTensor([s['txt_token'].numel() for s in samples])
+        mel_lengths = torch.LongTensor([s['mel'].shape[0] for s in samples])
+
+        batch = {
+            'id': id,
+            'item_name': item_names,
+            'nsamples': len(samples),
+            'text': text,
+            'txt_tokens': txt_tokens,
+            'txt_lengths': txt_lengths,
+            'mels': mels,
+            'mel_lengths': mel_lengths,
+            'mel2ph': mel2ph,
+            'energy': energy,
+            'pitch': pitch,
+            'f0': f0,
+            'uv': uv,
+        }
+
+        if self.hparams['use_spk_embed']:
+            spk_embed = torch.stack([s['spk_embed'] for s in samples])
+            batch['spk_embed'] = spk_embed
+        if self.hparams['use_spk_id']:
+            spk_ids = torch.LongTensor([s['spk_id'] for s in samples])
+            batch['spk_ids'] = spk_ids
+        if self.hparams['pitch_type'] == 'cwt':
+            cwt_spec = utils.collate_2d([s['cwt_spec'] for s in samples])
+            f0_mean = torch.Tensor([s['f0_mean'] for s in samples])
+            f0_std = torch.Tensor([s['f0_std'] for s in samples])
+            batch.update({'cwt_spec': cwt_spec, 'f0_mean': f0_mean, 'f0_std': f0_std})
+        elif self.hparams['pitch_type'] == 'ph':
+            batch['f0'] = utils.collate_1d([s['f0_ph'] for s in samples])
+
+        return batch
+
+    def load_test_inputs(self, test_input_dir, spk_id=0):
+        inp_wav_paths = glob.glob(f'{test_input_dir}/*.wav') + glob.glob(f'{test_input_dir}/*.mp3')
+        sizes = []
+        items = []
+
+        binarizer_cls = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizer.BaseBinarizer')
+        pkg = ".".join(binarizer_cls.split(".")[:-1])
+        cls_name = binarizer_cls.split(".")[-1]
+        binarizer_cls = getattr(importlib.import_module(pkg), cls_name)
+        binarization_args = hparams['binarization_args']
+
+        for wav_fn in inp_wav_paths:
+            item_name = os.path.basename(wav_fn)
+            ph = txt = tg_fn = ''
+            encoder = None
+            item = binarizer_cls.process_item(item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args)
+            items.append(item)
+            sizes.append(item['len'])
+        return items, sizes
diff --git a/tasks/tts/pe.py b/tasks/tts/pe.py
new file mode 100644
index 0000000000000000000000000000000000000000..3880c80d0820c36e044c00bd38a07fd3cce73323
--- /dev/null
+++ b/tasks/tts/pe.py
@@ -0,0 +1,155 @@
+import matplotlib
+matplotlib.use('Agg')
+
+import torch
+import numpy as np
+import os
+
+from tasks.base_task import BaseDataset
+from tasks.tts.fs2 import FastSpeech2Task
+from modules.fastspeech.pe import PitchExtractor
+import utils
+from utils.indexed_datasets import IndexedDataset
+from utils.hparams import hparams
+from utils.plot import f0_to_figure
+from utils.pitch_utils import norm_interp_f0, denorm_f0
+
+
+class PeDataset(BaseDataset):
+    def __init__(self, prefix, shuffle=False):
+        super().__init__(shuffle)
+        self.data_dir = hparams['binary_data_dir']
+        self.prefix = prefix
+        self.hparams = hparams
+        self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy')
+        self.indexed_ds = None
+
+        # pitch stats
+        f0_stats_fn = f'{self.data_dir}/train_f0s_mean_std.npy'
+        if os.path.exists(f0_stats_fn):
+            hparams['f0_mean'], hparams['f0_std'] = self.f0_mean, self.f0_std = np.load(f0_stats_fn)
+            hparams['f0_mean'] = float(hparams['f0_mean'])
+            hparams['f0_std'] = float(hparams['f0_std'])
+        else:
+            hparams['f0_mean'], hparams['f0_std'] = self.f0_mean, self.f0_std = None, None
+
+        if prefix == 'test':
+            if hparams['num_test_samples'] > 0:
+                self.avail_idxs = list(range(hparams['num_test_samples'])) + hparams['test_ids']
+                self.sizes = [self.sizes[i] for i in self.avail_idxs]
+
+    def _get_item(self, index):
+        if hasattr(self, 'avail_idxs') and self.avail_idxs is not None:
+            index = self.avail_idxs[index]
+        if self.indexed_ds is None:
+            self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
+        return self.indexed_ds[index]
+
+    def __getitem__(self, index):
+        hparams = self.hparams
+        item = self._get_item(index)
+        max_frames = hparams['max_frames']
+        spec = torch.Tensor(item['mel'])[:max_frames]
+        # mel2ph = torch.LongTensor(item['mel2ph'])[:max_frames] if 'mel2ph' in item else None
+        f0, uv = norm_interp_f0(item["f0"][:max_frames], hparams)
+        pitch = torch.LongTensor(item.get("pitch"))[:max_frames]
+        # print(item.keys(), item['mel'].shape, spec.shape)
+        sample = {
+            "id": index,
+            "item_name": item['item_name'],
+            "text": item['txt'],
+            "mel": spec,
+            "pitch": pitch,
+            "f0": f0,
+            "uv": uv,
+            # "mel2ph": mel2ph,
+            # "mel_nonpadding": spec.abs().sum(-1) > 0,
+        }
+        return sample
+
+    def collater(self, samples):
+        if len(samples) == 0:
+            return {}
+        id = torch.LongTensor([s['id'] for s in samples])
+        item_names = [s['item_name'] for s in samples]
+        text = [s['text'] for s in samples]
+        f0 = utils.collate_1d([s['f0'] for s in samples], 0.0)
+        pitch = utils.collate_1d([s['pitch'] for s in samples])
+        uv = utils.collate_1d([s['uv'] for s in samples])
+        mels = utils.collate_2d([s['mel'] for s in samples], 0.0)
+        mel_lengths = torch.LongTensor([s['mel'].shape[0] for s in samples])
+        # mel2ph = utils.collate_1d([s['mel2ph'] for s in samples], 0.0) \
+        #     if samples[0]['mel2ph'] is not None else None
+        # mel_nonpaddings = utils.collate_1d([s['mel_nonpadding'].float() for s in samples], 0.0)
+
+        batch = {
+            'id': id,
+            'item_name': item_names,
+            'nsamples': len(samples),
+            'text': text,
+            'mels': mels,
+            'mel_lengths': mel_lengths,
+            'pitch': pitch,
+            # 'mel2ph': mel2ph,
+            # 'mel_nonpaddings': mel_nonpaddings,
+            'f0': f0,
+            'uv': uv,
+        }
+        return batch
+
+
+class PitchExtractionTask(FastSpeech2Task):
+    def __init__(self):
+        super().__init__()
+        self.dataset_cls = PeDataset
+
+    def build_tts_model(self):
+        self.model = PitchExtractor(conv_layers=hparams['pitch_extractor_conv_layers'])
+
+    # def build_scheduler(self, optimizer):
+    #     return torch.optim.lr_scheduler.StepLR(optimizer, hparams['decay_steps'], gamma=0.5)
+
+    def _training_step(self, sample, batch_idx, _):
+        loss_output = self.run_model(self.model, sample)
+        total_loss = sum([v for v in loss_output.values() if isinstance(v, torch.Tensor) and v.requires_grad])
+        loss_output['batch_size'] = sample['mels'].size()[0]
+        return total_loss, loss_output
+
+    def validation_step(self, sample, batch_idx):
+        outputs = {}
+        outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True, infer=True)
+        outputs['total_loss'] = sum(outputs['losses'].values())
+        outputs['nsamples'] = sample['nsamples']
+        outputs = utils.tensors_to_scalars(outputs)
+        if batch_idx < hparams['num_valid_plots']:
+            self.plot_pitch(batch_idx, model_out, sample)
+        return outputs
+
+    def run_model(self, model, sample, return_output=False, infer=False):
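+        # the pitch extractor predicts frame-level F0/UV directly from the mel spectrogram;
+        # the loss is computed against the (normalized) ground-truth f0/uv in `sample`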
+        f0 = sample['f0']
+        uv = sample['uv']
+        output = model(sample['mels'])
+        losses = {}
+        self.add_pitch_loss(output, sample, losses)
+        if not return_output:
+            return losses
+        else:
+            return losses, output
+
+    def plot_pitch(self, batch_idx, model_out, sample):
+        gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams)
+        self.logger.experiment.add_figure(
+            f'f0_{batch_idx}',
+            f0_to_figure(gt_f0[0], None, model_out['f0_denorm_pred'][0]),
+            self.global_step)
+
+    def add_pitch_loss(self, output, sample, losses):
+        # mel2ph = sample['mel2ph']  # [B, T_s]
+        mel = sample['mels']
+        f0 = sample['f0']
+        uv = sample['uv']
+        # nonpadding = (mel2ph != 0).float() if hparams['pitch_type'] == 'frame' \
+        #     else (sample['txt_tokens'] != 0).float()
+        nonpadding = (mel.abs().sum(-1) > 0).float()  # sample['mel_nonpaddings']
+        # print(nonpadding[0][-8:], nonpadding.shape)
+        self.add_f0_loss(output['pitch_pred'], f0, uv, losses, nonpadding=nonpadding)
\ No newline at end of file
diff --git a/tasks/tts/tts.py b/tasks/tts/tts.py
new file mode 100644
index 0000000000000000000000000000000000000000..f803c1e738137cb1eca19a1943196abd2884c0a5
--- /dev/null
+++ b/tasks/tts/tts.py
@@ -0,0 +1,131 @@
+from multiprocessing.pool import Pool
+
+import matplotlib
+
+from utils.pl_utils import data_loader
+from utils.training_utils import RSQRTSchedule
+from vocoders.base_vocoder import get_vocoder_cls, BaseVocoder
+from modules.fastspeech.pe import PitchExtractor
+
+matplotlib.use('Agg')
+import os
+import numpy as np
+from tqdm import tqdm
+import torch.distributed as dist
+
+from tasks.base_task import BaseTask
+from utils.hparams import hparams
+from utils.text_encoder import TokenTextEncoder
+import json
+
+import torch
+import torch.optim
+import torch.utils.data
+import utils
+
+
+
+class TtsTask(BaseTask):
+    def __init__(self, *args, **kwargs):
+        self.vocoder = None
+        self.phone_encoder = self.build_phone_encoder(hparams['binary_data_dir'])
+        self.padding_idx = self.phone_encoder.pad()
+        self.eos_idx = self.phone_encoder.eos()
+        self.seg_idx = self.phone_encoder.seg()
+        self.saving_result_pool = None
+        self.saving_results_futures = None
+        self.stats = {}
+        super().__init__(*args, **kwargs)
+
+    def build_scheduler(self, optimizer):
+        return RSQRTSchedule(optimizer)
+
+    def build_optimizer(self, model):
+        self.optimizer = optimizer = torch.optim.AdamW(
+            model.parameters(),
+            lr=hparams['lr'])
+        return optimizer
+
+    def build_dataloader(self, dataset, shuffle, max_tokens=None, max_sentences=None,
+                         required_batch_size_multiple=-1, endless=False, batch_by_size=True):
+        devices_cnt = torch.cuda.device_count()
+        if devices_cnt == 0:
+            devices_cnt = 1
+        if required_batch_size_multiple == -1:
+            required_batch_size_multiple = devices_cnt
+
+        def shuffle_batches(batches):
+            np.random.shuffle(batches)
+            return batches
+
+        if max_tokens is not None:
+            max_tokens *= devices_cnt
+        if max_sentences is not None:
+            max_sentences *= devices_cnt
+        indices = dataset.ordered_indices()
+        if batch_by_size:
+            batch_sampler = utils.batch_by_size(
+                indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
+                required_batch_size_multiple=required_batch_size_multiple,
+            )
+        else:
+            batch_sampler = []
+            for i in range(0, len(indices), max_sentences):
+                batch_sampler.append(indices[i:i + max_sentences])
+
+        if shuffle:
+            batches = shuffle_batches(list(batch_sampler))
+            if endless:
+                batches = [b for _ in range(1000) for b in shuffle_batches(list(batch_sampler))]
+        else:
+            batches = batch_sampler
+            if endless:
+                batches = [b for _ in range(1000) for b in batches]
+        num_workers = dataset.num_workers
+        if self.trainer.use_ddp:
+            num_replicas = dist.get_world_size()
+            rank = dist.get_rank()
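+            # shard the pre-built batches across DDP workers; batches that do not split evenly are dropped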
+            batches = [x[rank::num_replicas] for x in batches if len(x) % num_replicas == 0]
+        return torch.utils.data.DataLoader(dataset,
+                                           collate_fn=dataset.collater,
+                                           batch_sampler=batches,
+                                           num_workers=num_workers,
+                                           pin_memory=False)
+
+    def build_phone_encoder(self, data_dir):
+        phone_list_file = os.path.join(data_dir, 'phone_set.json')
+        with open(phone_list_file, 'r') as f:
+            phone_list = json.load(f)
+        return TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')
+
+    def test_start(self):
+        self.saving_result_pool = Pool(8)
+        self.saving_results_futures = []
+        self.vocoder: BaseVocoder = get_vocoder_cls(hparams)()
+        if hparams.get('pe_enable') is not None and hparams['pe_enable']:
+            self.pe = PitchExtractor().cuda()
+            utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
+            self.pe.eval()
+
+    def test_end(self, outputs):
+        self.saving_result_pool.close()
+        [f.get() for f in tqdm(self.saving_results_futures)]
+        self.saving_result_pool.join()
+        return {}
+
+    ##########
+    # utils
+    ##########
+    def weights_nonzero_speech(self, target):
+        # target : B x T x mel
+        # Assign weight 1.0 to all labels except for padding (id=0).
+        dim = target.size(-1)
+        return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim)
+
+if __name__ == '__main__':
+    TtsTask.start()
diff --git a/tasks/tts/tts_base.py b/tasks/tts/tts_base.py
new file mode 100755
index 0000000000000000000000000000000000000000..509740b54dbf23db6bafebd6bc46089ee83cf499
--- /dev/null
+++ b/tasks/tts/tts_base.py
@@ -0,0 +1,305 @@
+import filecmp
+
+import matplotlib
+
+from utils.plot import spec_to_figure
+
+matplotlib.use('Agg')
+
+from data_gen.tts.data_gen_utils import get_pitch
+from modules.fastspeech.tts_modules import mel2ph_to_dur
+from tasks.tts.dataset_utils import BaseTTSDataset
+from utils.tts_utils import sequence_mask
+from multiprocessing.pool import Pool
+from tasks.base_task import data_loader, BaseConcatDataset
+from utils.common_schedulers import RSQRTSchedule, NoneSchedule
+from vocoders.base_vocoder import get_vocoder_cls, BaseVocoder
+import os
+import numpy as np
+from tqdm import tqdm
+import torch.distributed as dist
+from tasks.base_task import BaseTask
+from utils.hparams import hparams
+from utils.text_encoder import TokenTextEncoder
+import json
+import matplotlib.pyplot as plt
+import torch
+import torch.optim
+import torch.utils.data
+import utils
+from utils import audio
+import pandas as pd
+
+
+class TTSBaseTask(BaseTask):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.dataset_cls = BaseTTSDataset
+        self.max_tokens = hparams['max_tokens']
+        self.max_sentences = hparams['max_sentences']
+        self.max_valid_tokens = hparams['max_valid_tokens']
+        if self.max_valid_tokens == -1:
+            hparams['max_valid_tokens'] = self.max_valid_tokens = self.max_tokens
+        self.max_valid_sentences = hparams['max_valid_sentences']
+        if self.max_valid_sentences == -1:
+            hparams['max_valid_sentences'] = self.max_valid_sentences = self.max_sentences
+        self.vocoder = None
+        self.phone_encoder = self.build_phone_encoder(hparams['binary_data_dir'])
+        self.padding_idx = self.phone_encoder.pad()
+        self.eos_idx = self.phone_encoder.eos()
+        self.seg_idx = self.phone_encoder.seg()
+        self.saving_result_pool = None
+        self.saving_results_futures = None
+        self.stats = {}
+
+    @data_loader
+    def train_dataloader(self):
+        if hparams['train_sets'] != '':
+            train_sets = hparams['train_sets'].split("|")
+            # check if all train_sets have the same spk map and dictionary
+            binary_data_dir = hparams['binary_data_dir']
+            file_to_cmp = ['phone_set.json']
+            if os.path.exists(f'{binary_data_dir}/word_set.json'):
+                file_to_cmp.append('word_set.json')
+            if hparams['use_spk_id']:
+                file_to_cmp.append('spk_map.json')
+            for f in file_to_cmp:
+                for ds_name in train_sets:
+                    base_file = os.path.join(binary_data_dir, f)
+                    ds_file = os.path.join(ds_name, f)
+                    assert filecmp.cmp(base_file, ds_file), \
+                        f'{f} in {ds_name} is not same with that in {binary_data_dir}.'
+            train_dataset = BaseConcatDataset([
+                self.dataset_cls(prefix='train', shuffle=True, data_dir=ds_name) for ds_name in train_sets])
+        else:
+            train_dataset = self.dataset_cls(prefix=hparams['train_set_name'], shuffle=True)
+        return self.build_dataloader(train_dataset, True, self.max_tokens, self.max_sentences,
+                                     endless=hparams['endless_ds'])
+
+    @data_loader
+    def val_dataloader(self):
+        valid_dataset = self.dataset_cls(prefix=hparams['valid_set_name'], shuffle=False)
+        return self.build_dataloader(valid_dataset, False, self.max_valid_tokens, self.max_valid_sentences)
+
+    @data_loader
+    def test_dataloader(self):
+        test_dataset = self.dataset_cls(prefix=hparams['test_set_name'], shuffle=False)
+        self.test_dl = self.build_dataloader(
+            test_dataset, False, self.max_valid_tokens,
+            self.max_valid_sentences, batch_by_size=False)
+        return self.test_dl
+
+    def build_dataloader(self, dataset, shuffle, max_tokens=None, max_sentences=None,
+                         required_batch_size_multiple=-1, endless=False, batch_by_size=True):
+        devices_cnt = torch.cuda.device_count()
+        if devices_cnt == 0:
+            devices_cnt = 1
+        if required_batch_size_multiple == -1:
+            required_batch_size_multiple = devices_cnt
+
+        def shuffle_batches(batches):
+            np.random.shuffle(batches)
+            return batches
+
+        if max_tokens is not None:
+            max_tokens *= devices_cnt
+        if max_sentences is not None:
+            max_sentences *= devices_cnt
+        indices = dataset.ordered_indices()
+        if batch_by_size:
+            batch_sampler = utils.batch_by_size(
+                indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
+                required_batch_size_multiple=required_batch_size_multiple,
+            )
+        else:
+            batch_sampler = []
+            for i in range(0, len(indices), max_sentences):
+                batch_sampler.append(indices[i:i + max_sentences])
+
+        if shuffle:
+            batches = shuffle_batches(list(batch_sampler))
+            if endless:
+                batches = [b for _ in range(1000) for b in shuffle_batches(list(batch_sampler))]
+        else:
+            batches = batch_sampler
+            if endless:
+                batches = [b for _ in range(1000) for b in batches]
+        num_workers = dataset.num_workers
+        if self.trainer.use_ddp:
+            num_replicas = dist.get_world_size()
+            rank = dist.get_rank()
+            batches = [x[rank::num_replicas] for x in batches if len(x) % num_replicas == 0]
+        return torch.utils.data.DataLoader(dataset,
+                                           collate_fn=dataset.collater,
+                                           batch_sampler=batches,
+                                           num_workers=num_workers,
+                                           pin_memory=False)
+
+    def build_phone_encoder(self, data_dir):
+        phone_list_file = os.path.join(data_dir, 'phone_set.json')
+        with open(phone_list_file, 'r') as f:
+            phone_list = json.load(f)
+        return TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')
+
+    def build_scheduler(self, optimizer):
+        if hparams['scheduler'] == 'rsqrt':
+            return RSQRTSchedule(optimizer)
+        else:
+            return NoneSchedule(optimizer)
+
+    def build_optimizer(self, model):
+        self.optimizer = optimizer = torch.optim.AdamW(
+            model.parameters(),
+            lr=hparams['lr'],
+            betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
+            weight_decay=hparams['weight_decay'])
+        return optimizer
+
+    def plot_mel(self, batch_idx, spec, spec_out, name=None):
+        spec_cat = torch.cat([spec, spec_out], -1)
+        name = f'mel_{batch_idx}' if name is None else name
+        vmin = hparams['mel_vmin']
+        vmax = hparams['mel_vmax']
+        self.logger.add_figure(name, spec_to_figure(spec_cat[0], vmin, vmax), self.global_step)
+
+    def test_start(self):
+        self.saving_result_pool = Pool(min(int(os.getenv('N_PROC', os.cpu_count())), 16))
+        self.saving_results_futures = []
+        self.results_id = 0
+        self.gen_dir = os.path.join(
+            hparams['work_dir'],
+            f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}')
+        self.vocoder: BaseVocoder = get_vocoder_cls(hparams)()
+
+    def after_infer(self, predictions, sil_start_frame=0):
+        predictions = utils.unpack_dict_to_list(predictions)
+        assert len(predictions) == 1, 'Only support batch_size=1 in inference.'
+        prediction = predictions[0]
+        prediction = utils.tensors_to_np(prediction)
+        item_name = prediction.get('item_name')
+        text = prediction.get('text')
+        ph_tokens = prediction.get('txt_tokens')
+        mel_gt = prediction["mels"]
+        mel2ph_gt = prediction.get("mel2ph")
+        mel_pred = prediction["outputs"]
+        mel2ph_pred = prediction.get("mel2ph_pred")
+        f0_gt = prediction.get("f0")
+        f0_pred = prediction.get("f0_pred")
+
+        str_phs = None
+        if self.phone_encoder is not None and 'txt_tokens' in prediction:
+            str_phs = self.phone_encoder.decode(prediction['txt_tokens'], strip_padding=True)
+
+        if 'encdec_attn' in prediction:
+            encdec_attn = prediction['encdec_attn']
+            encdec_attn = encdec_attn[encdec_attn.max(-1).sum(-1).argmax(-1)]
+            txt_lengths = prediction.get('txt_lengths')
+            encdec_attn = encdec_attn.T[:txt_lengths, :len(mel_gt)]
+        else:
+            encdec_attn = None
+
+        wav_pred = self.vocoder.spec2wav(mel_pred, f0=f0_pred)
+        wav_pred[:sil_start_frame * hparams['hop_size']] = 0
+        gen_dir = self.gen_dir
+        base_fn = f'[{self.results_id:06d}][{item_name}][%s]'
+        # if text is not None:
+        #     base_fn += text.replace(":", "%3A")[:80]
+        base_fn = base_fn.replace(' ', '_')
+        if not hparams['profile_infer']:
+            os.makedirs(gen_dir, exist_ok=True)
+            os.makedirs(f'{gen_dir}/wavs', exist_ok=True)
+            os.makedirs(f'{gen_dir}/plot', exist_ok=True)
+            if hparams.get('save_mel_npy', False):
+                os.makedirs(f'{gen_dir}/npy', exist_ok=True)
+            if 'encdec_attn' in prediction:
+                os.makedirs(f'{gen_dir}/attn_plot', exist_ok=True)
+            self.saving_results_futures.append(
+                self.saving_result_pool.apply_async(self.save_result, args=[
+                    wav_pred, mel_pred, base_fn % 'P', gen_dir, str_phs, mel2ph_pred, encdec_attn]))
+
+            if mel_gt is not None and hparams['save_gt']:
+                wav_gt = self.vocoder.spec2wav(mel_gt, f0=f0_gt)
+                self.saving_results_futures.append(
+                    self.saving_result_pool.apply_async(self.save_result, args=[
+                        wav_gt, mel_gt, base_fn % 'G', gen_dir, str_phs, mel2ph_gt]))
+                if hparams['save_f0']:
+                    import matplotlib.pyplot as plt
+                    f0_pred_, _ = get_pitch(wav_pred, mel_pred, hparams)
+                    f0_gt_, _ = get_pitch(wav_gt, mel_gt, hparams)
+                    fig = plt.figure()
+                    plt.plot(f0_pred_, label=r'$\hat{f_0}$')
+                    plt.plot(f0_gt_, label=r'$f_0$')
+                    plt.legend()
+                    plt.tight_layout()
+                    plt.savefig(f'{gen_dir}/plot/[F0][{item_name}]{text}.png', format='png')
+                    plt.close(fig)
+            print(f"Pred_shape: {mel_pred.shape}, gt_shape: {mel_gt.shape}")
+        self.results_id += 1
+        return {
+            'item_name': item_name,
+            'text': text,
+            'ph_tokens': self.phone_encoder.decode(ph_tokens.tolist()),
+            'wav_fn_pred': base_fn % 'P',
+            'wav_fn_gt': base_fn % 'G',
+        }
+
+    @staticmethod
+    def save_result(wav_out, mel, base_fn, gen_dir, str_phs=None, mel2ph=None, alignment=None):
+        audio.save_wav(wav_out, f'{gen_dir}/wavs/{base_fn}.wav', hparams['audio_sample_rate'],
+                       norm=hparams['out_wav_norm'])
+        fig = plt.figure(figsize=(14, 10))
+        spec_vmin = hparams['mel_vmin']
+        spec_vmax = hparams['mel_vmax']
+        heatmap = plt.pcolor(mel.T, vmin=spec_vmin, vmax=spec_vmax)
+        fig.colorbar(heatmap)
+        f0, _ = get_pitch(wav_out, mel, hparams)
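+        # scale F0 (Hz) down so the contour fits on the mel-bin axis of the plot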
+        f0 = f0 / 10 * (f0 > 0)
+        plt.plot(f0, c='white', linewidth=1, alpha=0.6)
+        if mel2ph is not None and str_phs is not None:
+            decoded_txt = str_phs.split(" ")
+            dur = mel2ph_to_dur(torch.LongTensor(mel2ph)[None, :], len(decoded_txt))[0].numpy()
+            dur = [0] + list(np.cumsum(dur))
+            for i in range(len(dur) - 1):
+                shift = (i % 20) + 1
+                plt.text(dur[i], shift, decoded_txt[i])
+                plt.hlines(shift, dur[i], dur[i + 1], colors='b' if decoded_txt[i] != '|' else 'black')
+                plt.vlines(dur[i], 0, 5, colors='b' if decoded_txt[i] != '|' else 'black',
+                           alpha=1, linewidth=1)
+        plt.tight_layout()
+        plt.savefig(f'{gen_dir}/plot/{base_fn}.png', format='png')
+        plt.close(fig)
+        if hparams.get('save_mel_npy', False):
+            np.save(f'{gen_dir}/npy/{base_fn}', mel)
+        if alignment is not None:
+            fig, ax = plt.subplots(figsize=(12, 16))
+            im = ax.imshow(alignment, aspect='auto', origin='lower',
+                           interpolation='none')
+            decoded_txt = str_phs.split(" ")
+            ax.set_yticks(np.arange(len(decoded_txt)))
+            ax.set_yticklabels(list(decoded_txt), fontsize=6)
+            fig.colorbar(im, ax=ax)
+            fig.savefig(f'{gen_dir}/attn_plot/{base_fn}_attn.png', format='png')
+            plt.close(fig)
+
+    def test_end(self, outputs):
+        pd.DataFrame(outputs).to_csv(f'{self.gen_dir}/meta.csv')
+        self.saving_result_pool.close()
+        [f.get() for f in tqdm(self.saving_results_futures)]
+        self.saving_result_pool.join()
+        return {}
+
+    ##########
+    # utils
+    ##########
+    def weights_nonzero_speech(self, target):
+        # target : B x T x mel
+        # Assign weight 1.0 to all labels except for padding (id=0).
+        dim = target.size(-1)
+        return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim)
+
+    def make_stop_target(self, target):
+        # target : B x T x mel
+        seq_mask = target.abs().sum(-1).ne(0).float()
+        seq_length = seq_mask.sum(1)
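+        # mask_r is 1 from the last valid frame onward, i.e. the stop-token target region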
+        mask_r = 1 - sequence_mask(seq_length - 1, target.size(1)).float()
+        return seq_mask, mask_r
diff --git a/tasks/tts/tts_utils.py b/tasks/tts/tts_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e13439ee72e4fda220605c5868b3159110d9129b
--- /dev/null
+++ b/tasks/tts/tts_utils.py
@@ -0,0 +1,54 @@
+import importlib
+
+from data_gen.tts.base_binarizer import BaseBinarizer
+from data_gen.tts.base_preprocess import BasePreprocessor
+from data_gen.tts.txt_processors.base_text_processor import get_txt_processor_cls
+from utils.hparams import hparams
+
+
+def parse_dataset_configs():
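+    # a value of -1 for the validation limits falls back to the corresponding training limits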
+    max_tokens = hparams['max_tokens']
+    max_sentences = hparams['max_sentences']
+    max_valid_tokens = hparams['max_valid_tokens']
+    if max_valid_tokens == -1:
+        hparams['max_valid_tokens'] = max_valid_tokens = max_tokens
+    max_valid_sentences = hparams['max_valid_sentences']
+    if max_valid_sentences == -1:
+        hparams['max_valid_sentences'] = max_valid_sentences = max_sentences
+    return max_tokens, max_sentences, max_valid_tokens, max_valid_sentences
+
+
+def parse_mel_losses():
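+    # hparams['mel_losses'] is a "|"-separated list such as "l1:0.5|ssim:0.5";
+    # an entry without an explicit ":weight" defaults to weight 1.0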
+    mel_losses = hparams['mel_losses'].split("|")
+    loss_and_lambda = {}
+    for i, l in enumerate(mel_losses):
+        if l == '':
+            continue
+        if ':' in l:
+            l, lbd = l.split(":")
+            lbd = float(lbd)
+        else:
+            lbd = 1.0
+        loss_and_lambda[l] = lbd
+    print("| Mel losses:", loss_and_lambda)
+    return loss_and_lambda
+
+
+def load_data_preprocessor():
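+    # hparams['preprocess_cls'] is a fully qualified dotted path (module + class name);
+    # the module is imported dynamically and the class is instantiated by name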
+    preprocess_cls = hparams["preprocess_cls"]
+    pkg = ".".join(preprocess_cls.split(".")[:-1])
+    cls_name = preprocess_cls.split(".")[-1]
+    preprocessor: BasePreprocessor = getattr(importlib.import_module(pkg), cls_name)()
+    preprocess_args = {}
+    preprocess_args.update(hparams['preprocess_args'])
+    return preprocessor, preprocess_args
+
+
+def load_data_binarizer():
+    binarizer_cls = hparams['binarizer_cls']
+    pkg = ".".join(binarizer_cls.split(".")[:-1])
+    cls_name = binarizer_cls.split(".")[-1]
+    binarizer: BaseBinarizer = getattr(importlib.import_module(pkg), cls_name)()
+    binarization_args = {}
+    binarization_args.update(hparams['binarization_args'])
+    return binarizer, binarization_args
\ No newline at end of file
diff --git a/tasks/vocoder/dataset_utils.py b/tasks/vocoder/dataset_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..05dcdaa524efde31575dd30b57b627d22744b53c
--- /dev/null
+++ b/tasks/vocoder/dataset_utils.py
@@ -0,0 +1,204 @@
+import glob
+import importlib
+import os
+from resemblyzer import VoiceEncoder
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch.utils.data import DistributedSampler
+import utils
+from tasks.base_task import BaseDataset
+from utils.hparams import hparams
+from utils.indexed_datasets import IndexedDataset
+from tqdm import tqdm
+
+class EndlessDistributedSampler(DistributedSampler):
+    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.shuffle = shuffle
+
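+        # pre-build a long, fixed index list: the (optionally shuffled) dataset indices are
+        # repeated 1000x so the sampler behaves as an endless stream across epochs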
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+        if self.shuffle:
+            indices = [i for _ in range(1000) for i in torch.randperm(
+                len(self.dataset), generator=g).tolist()]
+        else:
+            indices = [i for _ in range(1000) for i in list(range(len(self.dataset)))]
+        indices = indices[:len(indices) // self.num_replicas * self.num_replicas]
+        indices = indices[self.rank::self.num_replicas]
+        self.indices = indices
+
+    def __iter__(self):
+        return iter(self.indices)
+
+    def __len__(self):
+        return len(self.indices)
+
+
+class VocoderDataset(BaseDataset):
+    def __init__(self, prefix, shuffle=False):
+        super().__init__(shuffle)
+        self.hparams = hparams
+        self.prefix = prefix
+        self.data_dir = hparams['binary_data_dir']
+        self.is_infer = prefix == 'test'
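+        # at inference (prefix == 'test'), batch_max_frames = 0 means no fixed crop length, so whole
+        # utterances are used; during training, crops of max_samples audio samples
+        # (max_samples // hop_size mel frames) are taken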
+        self.batch_max_frames = 0 if self.is_infer else hparams['max_samples'] // hparams['hop_size']
+        self.aux_context_window = hparams['aux_context_window']
+        self.hop_size = hparams['hop_size']
+        if self.is_infer and hparams['test_input_dir'] != '':
+            self.indexed_ds, self.sizes = self.load_test_inputs(hparams['test_input_dir'])
+            self.avail_idxs = [i for i, _ in enumerate(self.sizes)]
+        elif self.is_infer and hparams['test_mel_dir'] != '':
+            self.indexed_ds, self.sizes = self.load_mel_inputs(hparams['test_mel_dir'])
+            self.avail_idxs = [i for i, _ in enumerate(self.sizes)]
+        else:
+            self.indexed_ds = None
+            self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy')
+            self.avail_idxs = [idx for idx, s in enumerate(self.sizes) if
+                               s - 2 * self.aux_context_window > self.batch_max_frames]
+            print(f"| {len(self.sizes) - len(self.avail_idxs)} short items are skipped in {prefix} set.")
+            self.sizes = [s for idx, s in enumerate(self.sizes) if
+                          s - 2 * self.aux_context_window > self.batch_max_frames]
+
+    def _get_item(self, index):
+        if self.indexed_ds is None:
+            self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
+        item = self.indexed_ds[index]
+        return item
+
+    def __getitem__(self, index):
+        index = self.avail_idxs[index]
+        item = self._get_item(index)
+        sample = {
+            "id": index,
+            "item_name": item['item_name'],
+            "mel": torch.FloatTensor(item['mel']),
+            "wav": torch.FloatTensor(item['wav'].astype(np.float32)),
+        }
+        if 'pitch' in item:
+            sample['pitch'] = torch.LongTensor(item['pitch'])
+            sample['f0'] = torch.FloatTensor(item['f0'])
+
+        if hparams.get('use_spk_embed', False):
+            sample["spk_embed"] = torch.Tensor(item['spk_embed'])
+        if hparams.get('use_emo_embed', False):
+            sample["emo_embed"] = torch.Tensor(item['emo_embed'])
+
+        return sample
+
+    def collater(self, batch):
+        if len(batch) == 0:
+            return {}
+
+        y_batch, c_batch, p_batch, f0_batch = [], [], [], []
+        item_name = []
+        have_pitch = 'pitch' in batch[0]
+        for idx in range(len(batch)):
+            item_name.append(batch[idx]['item_name'])
+            x = batch[idx]['wav'] if self.hparams['use_wav'] else None
+            c = batch[idx]['mel'].squeeze(0)
+            if have_pitch:
+                p = batch[idx]['pitch']
+                f0 = batch[idx]['f0']
+            if self.hparams['use_wav']:
+                self._assert_ready_for_upsampling(x, c, self.hop_size, 0)
+            if len(c) - 2 * self.aux_context_window > self.batch_max_frames:
+                # randomly pick a segment of batch_max_frames frames (batch_max_steps audio samples)
+                batch_max_frames = self.batch_max_frames if self.batch_max_frames != 0 else len(
+                    c) - 2 * self.aux_context_window - 1
+                batch_max_steps = batch_max_frames * self.hop_size
+                interval_start = self.aux_context_window
+                interval_end = len(c) - batch_max_frames - self.aux_context_window
+                start_frame = np.random.randint(interval_start, interval_end)
+                start_step = start_frame * self.hop_size
+                if self.hparams['use_wav']:
+                    y = x[start_step: start_step + batch_max_steps]
+                c = c[start_frame - self.aux_context_window:
+                      start_frame + self.aux_context_window + batch_max_frames]
+                if have_pitch:
+                    p = p[start_frame - self.aux_context_window:
+                          start_frame + self.aux_context_window + batch_max_frames]
+                    f0 = f0[start_frame - self.aux_context_window:
+                            start_frame + self.aux_context_window + batch_max_frames]
+                if self.hparams['use_wav']:
+                    self._assert_ready_for_upsampling(y, c, self.hop_size, self.aux_context_window)
+            else:
+                print(f"Removed short sample from batch (mel length={len(c)}).")
+                continue
+            if self.hparams['use_wav']:
+                y_batch += [y.reshape(-1, 1)]  # [(T, 1), (T, 1), ...]
+            c_batch += [c]  # [(T' C), (T' C), ...]
+            if have_pitch:
+                p_batch += [p]  # [(T',), (T',), ...]
+                f0_batch += [f0]  # [(T',), (T',), ...]
+
+        # convert each batch to tensors; assume that every item in the batch has the same length
+        if self.hparams['use_wav']:
+            y_batch = utils.collate_2d(y_batch, 0).transpose(2, 1)  # (B, 1, T)
+        c_batch = utils.collate_2d(c_batch, 0).transpose(2, 1)  # (B, C, T')
+        if have_pitch:
+            p_batch = utils.collate_1d(p_batch, 0)  # (B, T')
+            f0_batch = utils.collate_1d(f0_batch, 0)  # (B, T')
+        else:
+            p_batch, f0_batch = None, None
+
+        # make input noise signal batch tensor
+        if self.hparams['use_wav']:
+            z_batch = torch.randn(y_batch.size())  # (B, 1, T)
+        else:
+            z_batch = []
+        return {
+            'z': z_batch,
+            'mels': c_batch,
+            'wavs': y_batch,
+            'pitches': p_batch,
+            'f0': f0_batch,
+            'item_name': item_name
+        }
+
+    @staticmethod
+    def _assert_ready_for_upsampling(x, c, hop_size, context_window):
+        """Assert the audio and feature lengths are correctly adjusted for upsamping."""
+        assert len(x) == (len(c) - 2 * context_window) * hop_size
+
+    def load_test_inputs(self, test_input_dir, spk_id=0):
+        inp_wav_paths = sorted(glob.glob(f'{test_input_dir}/*.wav') + glob.glob(f'{test_input_dir}/**/*.mp3'))
+        sizes = []
+        items = []
+
+        binarizer_cls = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizer.BaseBinarizer')
+        pkg = ".".join(binarizer_cls.split(".")[:-1])
+        cls_name = binarizer_cls.split(".")[-1]
+        binarizer_cls = getattr(importlib.import_module(pkg), cls_name)
+        binarization_args = hparams['binarization_args']
+
+        for wav_fn in inp_wav_paths:
+            item_name = wav_fn[len(test_input_dir) + 1:].replace("/", "_")
+            item = binarizer_cls.process_item(
+                item_name, wav_fn, binarization_args)
+            items.append(item)
+            sizes.append(item['len'])
+        return items, sizes
+
+    def load_mel_inputs(self, test_input_dir, spk_id=0):
+        inp_mel_paths = sorted(glob.glob(f'{test_input_dir}/*.npy'))
+        sizes = []
+        items = []
+
+        binarizer_cls = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizer.BaseBinarizer')
+        pkg = ".".join(binarizer_cls.split(".")[:-1])
+        cls_name = binarizer_cls.split(".")[-1]
+        binarizer_cls = getattr(importlib.import_module(pkg), cls_name)
+        binarization_args = hparams['binarization_args']
+
+        for mel in inp_mel_paths:
+            mel_input = np.load(mel)
+            mel_input = torch.FloatTensor(mel_input)
+            item_name = mel[len(test_input_dir) + 1:].replace("/", "_")
+            item = binarizer_cls.process_mel_item(item_name, mel_input, None, binarization_args)
+            items.append(item)
+            sizes.append(item['len'])
+        return items, sizes
diff --git a/tasks/vocoder/vocoder_base.py b/tasks/vocoder/vocoder_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..04f45af60c8ac1c1f8303d091f8c6031ec8451bf
--- /dev/null
+++ b/tasks/vocoder/vocoder_base.py
@@ -0,0 +1,66 @@
+import os
+
+import torch
+import torch.distributed as dist
+from torch.utils.data import DistributedSampler
+
+from tasks.base_task import BaseTask
+from tasks.base_task import data_loader
+from tasks.vocoder.dataset_utils import VocoderDataset, EndlessDistributedSampler
+from utils.hparams import hparams
+
+
+class VocoderBaseTask(BaseTask):
+    def __init__(self):
+        super(VocoderBaseTask, self).__init__()
+        self.max_sentences = hparams['max_sentences']
+        self.max_valid_sentences = hparams['max_valid_sentences']
+        if self.max_valid_sentences == -1:
+            hparams['max_valid_sentences'] = self.max_valid_sentences = self.max_sentences
+        self.dataset_cls = VocoderDataset
+
+    @data_loader
+    def train_dataloader(self):
+        train_dataset = self.dataset_cls('train', shuffle=True)
+        return self.build_dataloader(train_dataset, True, self.max_sentences, hparams['endless_ds'])
+
+    @data_loader
+    def val_dataloader(self):
+        valid_dataset = self.dataset_cls('valid', shuffle=False)
+        return self.build_dataloader(valid_dataset, False, self.max_valid_sentences)
+
+    @data_loader
+    def test_dataloader(self):
+        test_dataset = self.dataset_cls('test', shuffle=False)
+        return self.build_dataloader(test_dataset, False, self.max_valid_sentences)
+
+    def build_dataloader(self, dataset, shuffle, max_sentences, endless=False):
+        world_size = 1
+        rank = 0
+        if dist.is_initialized():
+            world_size = dist.get_world_size()
+            rank = dist.get_rank()
+        sampler_cls = DistributedSampler if not endless else EndlessDistributedSampler
+        train_sampler = sampler_cls(
+            dataset=dataset,
+            num_replicas=world_size,
+            rank=rank,
+            shuffle=shuffle,
+        )
+        return torch.utils.data.DataLoader(
+            dataset=dataset,
+            shuffle=False,
+            collate_fn=dataset.collater,
+            batch_size=max_sentences,
+            num_workers=dataset.num_workers,
+            sampler=train_sampler,
+            pin_memory=True,
+        )
+
+    def test_start(self):
+        self.gen_dir = os.path.join(hparams['work_dir'],
+                                    f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}')
+        os.makedirs(self.gen_dir, exist_ok=True)
+
+    def test_end(self, outputs):
+        return {}
diff --git a/usr/.gitkeep b/usr/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/usr/__init__.py b/usr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/usr/diff/diffusion.py b/usr/diff/diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..e874d64d4636c0b842392b91e92c7586770cbe58
--- /dev/null
+++ b/usr/diff/diffusion.py
@@ -0,0 +1,333 @@
+import math
+import random
+from functools import partial
+from inspect import isfunction
+from pathlib import Path
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from tqdm import tqdm
+from einops import rearrange
+
+from modules.fastspeech.fs2 import FastSpeech2
+from utils.hparams import hparams
+
+
+
+def exists(x):
+    return x is not None
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+def cycle(dl):
+    while True:
+        for data in dl:
+            yield data
+
+
+def num_to_groups(num, divisor):
+    groups = num // divisor
+    remainder = num % divisor
+    arr = [divisor] * groups
+    if remainder > 0:
+        arr.append(remainder)
+    return arr
+
+
+class Residual(nn.Module):
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x, *args, **kwargs):
+        return self.fn(x, *args, **kwargs) + x
+
+
+class SinusoidalPosEmb(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        device = x.device
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+        emb = x[:, None] * emb[None, :]
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
+
+
+class Mish(nn.Module):
+    def forward(self, x):
+        return x * torch.tanh(F.softplus(x))
+
+
+class Upsample(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.conv = nn.ConvTranspose2d(dim, dim, 4, 2, 1)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+class Downsample(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.conv = nn.Conv2d(dim, dim, 3, 2, 1)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+class Rezero(nn.Module):
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+        self.g = nn.Parameter(torch.zeros(1))
+
+    def forward(self, x):
+        return self.fn(x) * self.g
+
+
+# building block modules
+
+class Block(nn.Module):
+    def __init__(self, dim, dim_out, groups=8):
+        super().__init__()
+        self.block = nn.Sequential(
+            nn.Conv2d(dim, dim_out, 3, padding=1),
+            nn.GroupNorm(groups, dim_out),
+            Mish()
+        )
+
+    def forward(self, x):
+        return self.block(x)
+
+
+class ResnetBlock(nn.Module):
+    def __init__(self, dim, dim_out, *, time_emb_dim, groups=8):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            Mish(),
+            nn.Linear(time_emb_dim, dim_out)
+        )
+
+        self.block1 = Block(dim, dim_out)
+        self.block2 = Block(dim_out, dim_out)
+        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
+
+    def forward(self, x, time_emb):
+        h = self.block1(x)
+        h += self.mlp(time_emb)[:, :, None, None]
+        h = self.block2(h)
+        return h + self.res_conv(x)
+
+
+class LinearAttention(nn.Module):
+    def __init__(self, dim, heads=4, dim_head=32):
+        super().__init__()
+        self.heads = heads
+        hidden_dim = dim_head * heads
+        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+    def forward(self, x):
+        b, c, h, w = x.shape
+        qkv = self.to_qkv(x)
+        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
+        k = k.softmax(dim=-1)
+        context = torch.einsum('bhdn,bhen->bhde', k, v)
+        out = torch.einsum('bhde,bhdn->bhen', context, q)
+        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
+        return self.to_out(out)
+
+
+# gaussian diffusion trainer class
+
+def extract(a, t, x_shape):
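+    # gather the per-sample schedule coefficient at step t and reshape it to broadcast over x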
+    b, *_ = t.shape
+    out = a.gather(-1, t)
+    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def noise_like(shape, device, repeat=False):
+    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+    noise = lambda: torch.randn(shape, device=device)
+    return repeat_noise() if repeat else noise()
+
+
+def cosine_beta_schedule(timesteps, s=0.008):
+    """
+    cosine schedule
+    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
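+    alpha_bar(t) = cos(((t/T + s) / (1 + s)) * pi/2)^2, beta_t = 1 - alpha_bar(t) / alpha_bar(t-1),
+    clipped to [0, 0.999]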
+    """
+    steps = timesteps + 1
+    x = np.linspace(0, steps, steps)
+    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
+    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+    return np.clip(betas, a_min=0, a_max=0.999)
+
+
+class GaussianDiffusion(nn.Module):
+    def __init__(self, phone_encoder, out_dims, denoise_fn,
+                 timesteps=1000, loss_type='l1', betas=None, spec_min=None, spec_max=None):
+        super().__init__()
+        self.denoise_fn = denoise_fn
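+        # NOTE: FastSpeech2MIDI (used when hparams['use_midi'] is set) is assumed to be provided
+        # elsewhere in this repo; only FastSpeech2 is imported at the top of this file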
+        if hparams.get('use_midi') is not None and hparams['use_midi']:
+            self.fs2 = FastSpeech2MIDI(phone_encoder, out_dims)
+        else:
+            self.fs2 = FastSpeech2(phone_encoder, out_dims)
+        self.fs2.decoder = None
+        self.mel_bins = out_dims
+
+        if exists(betas):
+            betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas
+        else:
+            betas = cosine_beta_schedule(timesteps)
+
+        alphas = 1. - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.loss_type = loss_type
+
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+
+        self.register_buffer('betas', to_torch(betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+        # calculations for posterior q(x_{t-1} | x_t, x_0)
+        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
+        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+        self.register_buffer('posterior_variance', to_torch(posterior_variance))
+        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+        self.register_buffer('posterior_mean_coef1', to_torch(
+            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+        self.register_buffer('posterior_mean_coef2', to_torch(
+            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+
+        self.register_buffer('spec_min', torch.FloatTensor(spec_min)[None, None, :hparams['keep_bins']])
+        self.register_buffer('spec_max', torch.FloatTensor(spec_max)[None, None, :hparams['keep_bins']])
+
+    def q_mean_variance(self, x_start, t):
+        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
+        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+        return mean, variance, log_variance
+
+    def predict_start_from_noise(self, x_t, t, noise):
+        return (
+                extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+                extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+        )
+
+    def q_posterior(self, x_start, x_t, t):
+        posterior_mean = (
+                extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+                extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
+        )
+        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
+        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
+        return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+    def p_mean_variance(self, x, t, cond, clip_denoised: bool):
+        noise_pred = self.denoise_fn(x, t, cond=cond)
+        x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred)
+
+        if clip_denoised:
+            x_recon.clamp_(-1., 1.)
+
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+        return model_mean, posterior_variance, posterior_log_variance
+
+    @torch.no_grad()
+    def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
+        b, *_, device = *x.shape, x.device
+        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised)
+        noise = noise_like(x.shape, device, repeat_noise)
+        # no noise when t == 0
+        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+    def q_sample(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        return (
+                extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+                extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+        )
+
+    def p_losses(self, x_start, t, cond, noise=None, nonpadding=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+
+        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+        x_recon = self.denoise_fn(x_noisy, t, cond)
+
+        if self.loss_type == 'l1':
+            if nonpadding is not None:
+                loss = ((noise - x_recon).abs() * nonpadding.unsqueeze(1)).mean()
+            else:
+                # print('are you sure w/o nonpadding?')
+                loss = (noise - x_recon).abs().mean()
+
+        elif self.loss_type == 'l2':
+            loss = F.mse_loss(noise, x_recon)
+        else:
+            raise NotImplementedError()
+
+        return loss
+
+    def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
+                ref_mels=None, f0=None, uv=None, energy=None, infer=False):
+        b, *_, device = *txt_tokens.shape, txt_tokens.device
+        ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
+                       skip_decoder=True, infer=infer)
+        cond = ret['decoder_inp'].transpose(1, 2)
+        if not infer:
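+            # training: pick a random diffusion step t, noise the normalized ground-truth mel
+            # via q_sample (inside p_losses) and regress the injected noise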
+            t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
+            x = ref_mels
+            x = self.norm_spec(x)
+            x = x.transpose(1, 2)[:, None, :, :]  # [B, 1, M, T]
+            nonpadding = (mel2ph != 0).float()
+            ret['diff_loss'] = self.p_losses(x, t, cond, nonpadding=nonpadding)
+        else:
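+            # inference: start from Gaussian noise and run the full reverse diffusion chain,
+            # conditioned on the FastSpeech2 encoder output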
+            t = self.num_timesteps
+            shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
+            x = torch.randn(shape, device=device)
+            for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
+                x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
+            x = x[:, 0].transpose(1, 2)
+            ret['mel_out'] = self.denorm_spec(x)
+
+        return ret
+
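+    # linearly map mels between the dataset range [spec_min, spec_max] and [-1, 1]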
+    def norm_spec(self, x):
+        return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
+
+    def denorm_spec(self, x):
+        return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
+
+    def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
+        return self.fs2.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
+
+    def out2mel(self, x):
+        return x
diff --git a/usr/diff/net.py b/usr/diff/net.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8811115eafb4f27165cf4d89c67c0d9455aac9d
--- /dev/null
+++ b/usr/diff/net.py
@@ -0,0 +1,130 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from math import sqrt
+
+from .diffusion import Mish
+from utils.hparams import hparams
+
+Linear = nn.Linear
+ConvTranspose2d = nn.ConvTranspose2d
+
+
+class AttrDict(dict):
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
+
+    def override(self, attrs):
+        if isinstance(attrs, dict):
+            self.__dict__.update(**attrs)
+        elif isinstance(attrs, (list, tuple, set)):
+            for attr in attrs:
+                self.override(attr)
+        elif attrs is not None:
+            raise NotImplementedError
+        return self
+
+
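+# Transformer-style sinusoidal embedding of the scalar diffusion step:
+# maps a [B] step index to a [B, dim] vector of concatenated sin/cos features.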
+class SinusoidalPosEmb(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        device = x.device
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+        emb = x[:, None] * emb[None, :]
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
+
+
+def Conv1d(*args, **kwargs):
+    layer = nn.Conv1d(*args, **kwargs)
+    nn.init.kaiming_normal_(layer.weight)
+    return layer
+
+
+@torch.jit.script
+def silu(x):
+    return x * torch.sigmoid(x)
+
+
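+# WaveNet-like gated residual block (as in DiffWave-style denoisers): the
+# diffusion-step embedding is added as a per-channel bias, the encoder
+# condition is added after a dilated conv, a sigmoid/tanh gate follows, and the
+# output is split into a residual path (scaled by 1/sqrt(2)) and a skip path.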
+class ResidualBlock(nn.Module):
+    def __init__(self, encoder_hidden, residual_channels, dilation):
+        super().__init__()
+        self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation)
+        self.diffusion_projection = Linear(residual_channels, residual_channels)
+        self.conditioner_projection = Conv1d(encoder_hidden, 2 * residual_channels, 1)
+        self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1)
+
+    def forward(self, x, conditioner, diffusion_step):
+        diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
+        conditioner = self.conditioner_projection(conditioner)
+        y = x + diffusion_step
+
+        y = self.dilated_conv(y) + conditioner
+
+        gate, filter = torch.chunk(y, 2, dim=1)
+        y = torch.sigmoid(gate) * torch.tanh(filter)
+
+        y = self.output_projection(y)
+        residual, skip = torch.chunk(y, 2, dim=1)
+        return (x + residual) / sqrt(2.0), skip
+
+
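+# Non-causal WaveNet denoiser: a 1x1 input projection, a sinusoidal step
+# embedding refined by a small MLP, `residual_layers` ResidualBlocks whose
+# dilations cycle as 2 ** (i % dilation_cycle_length), skip outputs summed and
+# scaled by 1/sqrt(num_layers), and a zero-initialized output projection back
+# to the mel bins.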
+class DiffNet(nn.Module):
+    def __init__(self, in_dims=80):
+        super().__init__()
+        self.params = params = AttrDict(
+            # Model params
+            encoder_hidden=hparams['hidden_size'],
+            residual_layers=hparams['residual_layers'],
+            residual_channels=hparams['residual_channels'],
+            dilation_cycle_length=hparams['dilation_cycle_length'],
+        )
+        self.input_projection = Conv1d(in_dims, params.residual_channels, 1)
+        self.diffusion_embedding = SinusoidalPosEmb(params.residual_channels)
+        dim = params.residual_channels
+        self.mlp = nn.Sequential(
+            nn.Linear(dim, dim * 4),
+            Mish(),
+            nn.Linear(dim * 4, dim)
+        )
+        self.residual_layers = nn.ModuleList([
+            ResidualBlock(params.encoder_hidden, params.residual_channels, 2 ** (i % params.dilation_cycle_length))
+            for i in range(params.residual_layers)
+        ])
+        self.skip_projection = Conv1d(params.residual_channels, params.residual_channels, 1)
+        self.output_projection = Conv1d(params.residual_channels, in_dims, 1)
+        nn.init.zeros_(self.output_projection.weight)
+
+    def forward(self, spec, diffusion_step, cond):
+        """
+
+        :param spec: [B, 1, M, T]
+        :param diffusion_step: [B, 1]
+        :param cond: [B, M, T]
+        :return:
+        """
+        x = spec[:, 0]
+        x = self.input_projection(x)  # x [B, residual_channel, T]
+
+        x = F.relu(x)
+        diffusion_step = self.diffusion_embedding(diffusion_step)
+        diffusion_step = self.mlp(diffusion_step)
+        skip = []
+        for layer_id, layer in enumerate(self.residual_layers):
+            x, skip_connection = layer(x, cond, diffusion_step)
+            skip.append(skip_connection)
+
+        x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers))
+        x = self.skip_projection(x)
+        x = F.relu(x)
+        x = self.output_projection(x)  # [B, 80, T]
+        return x[:, None, :, :]
diff --git a/usr/diff/shallow_diffusion_tts.py b/usr/diff/shallow_diffusion_tts.py
new file mode 100644
index 0000000000000000000000000000000000000000..d296fbc1c297a9703e004bd1d216ed34f0008446
--- /dev/null
+++ b/usr/diff/shallow_diffusion_tts.py
@@ -0,0 +1,307 @@
+import math
+import random
+from functools import partial
+from inspect import isfunction
+from pathlib import Path
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from tqdm import tqdm
+from einops import rearrange
+
+from modules.fastspeech.fs2 import FastSpeech2
+from utils.hparams import hparams
+
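+# Helpers for building noise schedules. `get_noise_schedule_list` returns a
+# length-`timesteps` array of betas for the "linear", "cosine" and "vpsde"
+# modes, and of log-SNR values for the "logsnr" mode.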
+def vpsde_beta_t(t, T, min_beta, max_beta):
+    t_coef = (2 * t - 1) / (T ** 2)
+    return 1. - np.exp(-min_beta / T - 0.5 * (max_beta - min_beta) * t_coef)
+
+def _logsnr_schedule_cosine(t, *, logsnr_min, logsnr_max):
+    b = np.arctan(np.exp(-0.5 * logsnr_max))
+    a = np.arctan(np.exp(-0.5 * logsnr_min)) - b
+    return -2. * np.log(np.tan(a * t + b))
+
+
+def get_noise_schedule_list(schedule_mode, timesteps, min_beta=0.0, max_beta=0.01, s=0.008):
+    if schedule_mode == "linear":
+        schedule_list = np.linspace(0.000001, 0.01, timesteps)
+    elif schedule_mode == "cosine":
+        steps = timesteps + 1
+        x = np.linspace(0, steps, steps)
+        alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
+        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+        schedule_list = np.clip(betas, a_min=0, a_max=0.999)
+    elif schedule_mode == "vpsde":
+        schedule_list = np.array([
+            vpsde_beta_t(t, timesteps, min_beta, max_beta) for t in range(1, timesteps + 1)])
+    elif schedule_mode == "logsnr":
+        schedule_list = np.array([
+            _logsnr_schedule_cosine(t / timesteps, logsnr_min=-20.0, logsnr_max=20.0) for t in range(1, timesteps + 1)])
+    else:
+        raise NotImplementedError
+    return schedule_list
+
+def exists(x):
+    return x is not None
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+# gaussian diffusion trainer class
+
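+# Gather the per-sample schedule coefficients at timesteps `t` and reshape them
+# to [B, 1, 1, ...] so they broadcast against a batch of shape `x_shape`.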
+def extract(a, t, x_shape):
+    b, *_ = t.shape
+    out = a.gather(-1, t)
+    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def noise_like(shape, device, repeat=False):
+    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+    noise = lambda: torch.randn(shape, device=device)
+    return repeat_noise() if repeat else noise()
+
+
+def linear_beta_schedule(timesteps, max_beta=hparams.get('max_beta', 0.01)):
+    """
+    linear schedule
+    """
+    betas = np.linspace(1e-4, max_beta, timesteps)
+    return betas
+
+
+def cosine_beta_schedule(timesteps, s=0.008):
+    """
+    cosine schedule
+    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
+    """
+    steps = timesteps + 1
+    x = np.linspace(0, steps, steps)
+    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
+    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+    return np.clip(betas, a_min=0, a_max=0.999)
+
+
+beta_schedule = {
+    "cosine": cosine_beta_schedule,
+    "linear": linear_beta_schedule,
+}
+
+
+class GaussianDiffusion(nn.Module):
+    def __init__(self, phone_encoder, out_dims, denoise_fn,
+                 timesteps=1000, K_step=1000, loss_type=hparams.get('diff_loss_type', 'l1'), betas=None, spec_min=None, spec_max=None):
+        super().__init__()
+        self.denoise_fn = denoise_fn
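+        # NOTE: FastSpeech2MIDI is referenced below but not imported in this
+        # file; when hparams['use_midi'] is enabled it is assumed to be made
+        # available by the MIDI-specific modules of the repo.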
+        if hparams.get('use_midi') is not None and hparams['use_midi']:
+            self.fs2 = FastSpeech2MIDI(phone_encoder, out_dims)
+        else:
+            self.fs2 = FastSpeech2(phone_encoder, out_dims)
+        self.mel_bins = out_dims
+
+        if exists(betas):
+            betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas
+        else:
+            if 'schedule_type' in hparams.keys():
+                betas = beta_schedule[hparams['schedule_type']](timesteps)
+            else:
+                betas = cosine_beta_schedule(timesteps)
+
+        alphas = 1. - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.K_step = K_step
+        self.loss_type = loss_type
+
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+
+        self.register_buffer('betas', to_torch(betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+        # calculations for posterior q(x_{t-1} | x_t, x_0)
+        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
+        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+        self.register_buffer('posterior_variance', to_torch(posterior_variance))
+        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+        self.register_buffer('posterior_mean_coef1', to_torch(
+            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+        self.register_buffer('posterior_mean_coef2', to_torch(
+            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+
+        self.register_buffer('spec_min', torch.FloatTensor(spec_min)[None, None, :hparams['keep_bins']])
+        self.register_buffer('spec_max', torch.FloatTensor(spec_max)[None, None, :hparams['keep_bins']])
+
+    def q_mean_variance(self, x_start, t):
+        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
+        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+        return mean, variance, log_variance
+
+    def predict_start_from_noise(self, x_t, t, noise):
+        return (
+                extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+                extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+        )
+
+    def q_posterior(self, x_start, x_t, t):
+        posterior_mean = (
+                extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+                extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
+        )
+        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
+        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
+        return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+    def p_mean_variance(self, x, t, cond, clip_denoised: bool):
+        noise_pred = self.denoise_fn(x, t, cond=cond)
+        x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred)
+
+        if clip_denoised:
+            x_recon.clamp_(-1., 1.)
+
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+        return model_mean, posterior_variance, posterior_log_variance
+
+    @torch.no_grad()
+    def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
+        b, *_, device = *x.shape, x.device
+        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised)
+        noise = noise_like(x.shape, device, repeat_noise)
+        # no noise when t == 0
+        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+    def q_sample(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        return (
+                extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+                extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+        )
+
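+    # Training objective: diffuse x_start to step t with q_sample, have the
+    # denoiser predict the injected noise, and penalize with L1 (optionally
+    # masked by non-padding frames) or L2.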
+    def p_losses(self, x_start, t, cond, noise=None, nonpadding=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+
+        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+        x_recon = self.denoise_fn(x_noisy, t, cond)
+
+        if self.loss_type == 'l1':
+            if nonpadding is not None:
+                loss = ((noise - x_recon).abs() * nonpadding.unsqueeze(1)).mean()
+            else:
+                # print('are you sure w/o nonpadding?')
+                loss = (noise - x_recon).abs().mean()
+
+        elif self.loss_type == 'l2':
+            loss = F.mse_loss(noise, x_recon)
+        else:
+            raise NotImplementedError()
+
+        return loss
+
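+    # Shallow-diffusion flow: during training, t is sampled from the first
+    # K_step steps only. At inference the FastSpeech2 mel is diffused to step
+    # K_step - 1 with q_sample and then denoised back over K_step reverse
+    # steps; if `gaussian_start` is set, sampling instead starts from pure
+    # Gaussian noise of the full spectrogram shape.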
+    def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
+                ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
+        b, *_, device = *txt_tokens.shape, txt_tokens.device
+        ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
+                       skip_decoder=(not infer), infer=infer, **kwargs)
+        cond = ret['decoder_inp'].transpose(1, 2)
+
+        if not infer:
+            t = torch.randint(0, self.K_step, (b,), device=device).long()
+            x = ref_mels
+            x = self.norm_spec(x)
+            x = x.transpose(1, 2)[:, None, :, :]  # [B, 1, M, T]
+            ret['diff_loss'] = self.p_losses(x, t, cond)
+            # nonpadding = (mel2ph != 0).float()
+            # ret['diff_loss'] = self.p_losses(x, t, cond, nonpadding=nonpadding)
+        else:
+            ret['fs2_mel'] = ret['mel_out']
+            fs2_mels = ret['mel_out']
+            t = self.K_step
+            fs2_mels = self.norm_spec(fs2_mels)
+            fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :]
+
+            x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long())
+            if hparams.get('gaussian_start') is not None and hparams['gaussian_start']:
+                print('===> gaussian start.')
+                shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
+                x = torch.randn(shape, device=device)
+            for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
+                x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
+            x = x[:, 0].transpose(1, 2)
+            if mel2ph is not None:  # for singing
+                ret['mel_out'] = self.denorm_spec(x) * ((mel2ph > 0).float()[:, :, None])
+            else:
+                ret['mel_out'] = self.denorm_spec(x)
+        return ret
+
+    # def norm_spec(self, x):
+    #     return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
+    #
+    # def denorm_spec(self, x):
+    #     return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
+
+    def norm_spec(self, x):
+        return x
+
+    def denorm_spec(self, x):
+        return x
+
+    def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
+        return self.fs2.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
+        
+    def out2mel(self, x):
+        return x
+
+
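+# Variant that consumes precomputed FastSpeech2 mels: `ref_mels` is expected to
+# be a pair (ground-truth mel, fs2 mel); the first element supervises the
+# diffusion loss and the second seeds the reverse process at inference time.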
+class OfflineGaussianDiffusion(GaussianDiffusion):
+    def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
+                ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
+        b, *_, device = *txt_tokens.shape, txt_tokens.device
+
+        ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
+                       skip_decoder=True, infer=True, **kwargs)
+        cond = ret['decoder_inp'].transpose(1, 2)
+        fs2_mels = ref_mels[1]
+        ref_mels = ref_mels[0]
+
+        if not infer:
+            t = torch.randint(0, self.K_step, (b,), device=device).long()
+            x = ref_mels
+            x = self.norm_spec(x)
+            x = x.transpose(1, 2)[:, None, :, :]  # [B, 1, M, T]
+            ret['diff_loss'] = self.p_losses(x, t, cond)
+        else:
+            t = self.K_step
+            fs2_mels = self.norm_spec(fs2_mels)
+            fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :]
+
+            x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long())
+
+            if hparams.get('gaussian_start') is not None and hparams['gaussian_start']:
+                print('===> gaussian start.')
+                shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
+                x = torch.randn(shape, device=device)
+            for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
+                x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
+            x = x[:, 0].transpose(1, 2)
+            ret['mel_out'] = self.denorm_spec(x)
+        return ret
\ No newline at end of file
diff --git a/usr/diffspeech_task.py b/usr/diffspeech_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4fca7e9e46fc378468188d58fc42bc989df824c
--- /dev/null
+++ b/usr/diffspeech_task.py
@@ -0,0 +1,122 @@
+import torch
+
+import utils
+from utils.hparams import hparams
+from .diff.net import DiffNet
+from .diff.shallow_diffusion_tts import GaussianDiffusion
+from .task import DiffFsTask
+from vocoders.base_vocoder import get_vocoder_cls, BaseVocoder
+from utils.pitch_utils import denorm_f0
+from tasks.tts.fs2_utils import FastSpeechDataset
+
+DIFF_DECODERS = {
+    'wavenet': lambda hp: DiffNet(hp['audio_num_mel_bins']),
+}
+
+
+class DiffSpeechTask(DiffFsTask):
+    def __init__(self):
+        super(DiffSpeechTask, self).__init__()
+        self.dataset_cls = FastSpeechDataset
+        self.vocoder: BaseVocoder = get_vocoder_cls(hparams)()
+
+    def build_tts_model(self):
+        mel_bins = hparams['audio_num_mel_bins']
+        self.model = GaussianDiffusion(
+            phone_encoder=self.phone_encoder,
+            out_dims=mel_bins, denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
+            timesteps=hparams['timesteps'],
+            K_step=hparams['K_step'],
+            loss_type=hparams['diff_loss_type'],
+            spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
+        )
+        if hparams['fs2_ckpt'] != '':
+            utils.load_ckpt(self.model.fs2, hparams['fs2_ckpt'], 'model', strict=True)
+        # self.model.fs2.decoder = None
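+        # Freeze the pre-trained FastSpeech2 weights except the variance
+        # predictor modules, so only the diffusion denoiser (and those
+        # predictors) receive gradients.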
+        for k, v in self.model.fs2.named_parameters():
+            if 'predictor' not in k:
+                v.requires_grad = False
+
+    def build_optimizer(self, model):
+        self.optimizer = optimizer = torch.optim.AdamW(
+            filter(lambda p: p.requires_grad, model.parameters()),
+            lr=hparams['lr'],
+            betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
+            weight_decay=hparams['weight_decay'])
+        return optimizer
+
+    def run_model(self, model, sample, return_output=False, infer=False):
+        txt_tokens = sample['txt_tokens']  # [B, T_t]
+        target = sample['mels']  # [B, T_s, 80]
+        # mel2ph = sample['mel2ph'] if hparams['use_gt_dur'] else None # [B, T_s]
+        mel2ph = sample['mel2ph']
+        f0 = sample['f0']
+        uv = sample['uv']
+        energy = sample['energy']
+        # fs2_mel = sample['fs2_mels']
+        spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids')
+        if hparams['pitch_type'] == 'cwt':
+            cwt_spec = sample['cwt_spec']
+            f0_mean = sample['f0_mean']
+            f0_std = sample['f0_std']
+            sample['f0_cwt'] = f0 = model.cwt2f0_norm(cwt_spec, f0_mean, f0_std, mel2ph)
+
+        output = model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed,
+                       ref_mels=target, f0=f0, uv=uv, energy=energy, infer=infer)
+
+        losses = {}
+        if 'diff_loss' in output:
+            losses['mel'] = output['diff_loss']
+        self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses)
+        if hparams['use_pitch_embed']:
+            self.add_pitch_loss(output, sample, losses)
+        if hparams['use_energy_embed']:
+            self.add_energy_loss(output['energy_pred'], energy, losses)
+        if not return_output:
+            return losses
+        else:
+            return losses, output
+
+    def validation_step(self, sample, batch_idx):
+        outputs = {}
+        txt_tokens = sample['txt_tokens']  # [B, T_t]
+
+        energy = sample['energy']
+        spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids')
+        mel2ph = sample['mel2ph']
+        f0 = sample['f0']
+        uv = sample['uv']
+
+        outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True, infer=False)
+
+        outputs['total_loss'] = sum(outputs['losses'].values())
+        outputs['nsamples'] = sample['nsamples']
+        outputs = utils.tensors_to_scalars(outputs)
+        if batch_idx < hparams['num_valid_plots']:
+            # model_out = self.model(
+            #     txt_tokens, spk_embed=spk_embed, mel2ph=None, f0=None, uv=None, energy=None, ref_mels=None, inference=True)
+            # self.plot_mel(batch_idx, model_out['mel_out'], model_out['fs2_mel'], name=f'diffspeech_vs_fs2_{batch_idx}')
+            model_out = self.model(
+                txt_tokens, spk_embed=spk_embed, mel2ph=mel2ph, f0=f0, uv=uv, energy=energy, ref_mels=None, infer=True)
+            gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams)
+            self.plot_wav(batch_idx, sample['mels'], model_out['mel_out'], is_mel=True, gt_f0=gt_f0, f0=model_out.get('f0_denorm'))
+            self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'])
+        return outputs
+
+    ############
+    # validation plots
+    ############
+    def plot_wav(self, batch_idx, gt_wav, wav_out, is_mel=False, gt_f0=None, f0=None, name=None):
+        gt_wav = gt_wav[0].cpu().numpy()
+        wav_out = wav_out[0].cpu().numpy()
+        gt_f0 = gt_f0[0].cpu().numpy()
+        f0 = f0[0].cpu().numpy()
+        if is_mel:
+            gt_wav = self.vocoder.spec2wav(gt_wav, f0=gt_f0)
+            wav_out = self.vocoder.spec2wav(wav_out, f0=f0)
+        self.logger.experiment.add_audio(f'gt_{batch_idx}', gt_wav, sample_rate=hparams['audio_sample_rate'], global_step=self.global_step)
+        self.logger.experiment.add_audio(f'wav_{batch_idx}', wav_out, sample_rate=hparams['audio_sample_rate'], global_step=self.global_step)
+
diff --git a/usr/task.py b/usr/task.py
new file mode 100644
index 0000000000000000000000000000000000000000..f05d66f0a8f7aa5995c95c202af7fa81efb8a28f
--- /dev/null
+++ b/usr/task.py
@@ -0,0 +1,73 @@
+import torch
+
+import utils
+from .diff.diffusion import GaussianDiffusion
+from .diff.net import DiffNet
+from tasks.tts.fs2 import FastSpeech2Task
+from utils.hparams import hparams
+
+
+DIFF_DECODERS = {
+    'wavenet': lambda hp: DiffNet(hp['audio_num_mel_bins']),
+}
+
+
+class DiffFsTask(FastSpeech2Task):
+    def build_tts_model(self):
+        mel_bins = hparams['audio_num_mel_bins']
+        self.model = GaussianDiffusion(
+            phone_encoder=self.phone_encoder,
+            out_dims=mel_bins, denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
+            timesteps=hparams['timesteps'],
+            loss_type=hparams['diff_loss_type'],
+            spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
+        )
+
+    def run_model(self, model, sample, return_output=False, infer=False):
+        txt_tokens = sample['txt_tokens']  # [B, T_t]
+        target = sample['mels']  # [B, T_s, 80]
+        mel2ph = sample['mel2ph']  # [B, T_s]
+        f0 = sample['f0']
+        uv = sample['uv']
+        energy = sample['energy']
+        spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids')
+        if hparams['pitch_type'] == 'cwt':
+            cwt_spec = sample['cwt_spec']
+            f0_mean = sample['f0_mean']
+            f0_std = sample['f0_std']
+            sample['f0_cwt'] = f0 = model.cwt2f0_norm(cwt_spec, f0_mean, f0_std, mel2ph)
+
+        output = model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed,
+                       ref_mels=target, f0=f0, uv=uv, energy=energy, infer=infer)
+
+        losses = {}
+        if 'diff_loss' in output:
+            losses['mel'] = output['diff_loss']
+        self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses)
+        if hparams['use_pitch_embed']:
+            self.add_pitch_loss(output, sample, losses)
+        if hparams['use_energy_embed']:
+            self.add_energy_loss(output['energy_pred'], energy, losses)
+        if not return_output:
+            return losses
+        else:
+            return losses, output
+
+    def _training_step(self, sample, batch_idx, _):
+        log_outputs = self.run_model(self.model, sample)
+        total_loss = sum([v for v in log_outputs.values() if isinstance(v, torch.Tensor) and v.requires_grad])
+        log_outputs['batch_size'] = sample['txt_tokens'].size()[0]
+        log_outputs['lr'] = self.scheduler.get_lr()[0]
+        return total_loss, log_outputs
+
+    def validation_step(self, sample, batch_idx):
+        outputs = {}
+        outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True, infer=False)
+        outputs['total_loss'] = sum(outputs['losses'].values())
+        outputs['nsamples'] = sample['nsamples']
+        outputs = utils.tensors_to_scalars(outputs)
+        if batch_idx < hparams['num_valid_plots']:
+            _, model_out = self.run_model(self.model, sample, return_output=True, infer=True)
+            self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'])
+        return outputs
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..63921bfea9a95629b15d90498677c6a22de9fec8
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,285 @@
+import time
+import sys
+import types
+
+import chardet
+import numpy as np
+import torch
+import torch.distributed as dist
+from utils.ckpt_utils import load_ckpt
+
+
+def reduce_tensors(metrics):
+    new_metrics = {}
+    for k, v in metrics.items():
+        if isinstance(v, torch.Tensor):
+            dist.all_reduce(v)
+            v = v / dist.get_world_size()
+        if type(v) is dict:
+            v = reduce_tensors(v)
+        new_metrics[k] = v
+    return new_metrics
+
+
+def tensors_to_scalars(tensors):
+    if isinstance(tensors, torch.Tensor):
+        tensors = tensors.item()
+        return tensors
+    elif isinstance(tensors, dict):
+        new_tensors = {}
+        for k, v in tensors.items():
+            v = tensors_to_scalars(v)
+            new_tensors[k] = v
+        return new_tensors
+    elif isinstance(tensors, list):
+        return [tensors_to_scalars(v) for v in tensors]
+    else:
+        return tensors
+
+
+def tensors_to_np(tensors):
+    if isinstance(tensors, dict):
+        new_np = {}
+        for k, v in tensors.items():
+            if isinstance(v, torch.Tensor):
+                v = v.cpu().numpy()
+            if type(v) is dict:
+                v = tensors_to_np(v)
+            new_np[k] = v
+    elif isinstance(tensors, list):
+        new_np = []
+        for v in tensors:
+            if isinstance(v, torch.Tensor):
+                v = v.cpu().numpy()
+            if type(v) is dict:
+                v = tensors_to_np(v)
+            new_np.append(v)
+    elif isinstance(tensors, torch.Tensor):
+        v = tensors
+        if isinstance(v, torch.Tensor):
+            v = v.cpu().numpy()
+        if type(v) is dict:
+            v = tensors_to_np(v)
+        new_np = v
+    else:
+        raise Exception(f'tensors_to_np does not support type {type(tensors)}.')
+    return new_np
+
+
+def move_to_cpu(tensors):
+    ret = {}
+    for k, v in tensors.items():
+        if isinstance(v, torch.Tensor):
+            v = v.cpu()
+        if type(v) is dict:
+            v = move_to_cpu(v)
+        ret[k] = v
+    return ret
+
+
+def move_to_cuda(batch, gpu_id=0):
+    # base case: object can be directly moved using `cuda` or `to`
+    if callable(getattr(batch, 'cuda', None)):
+        return batch.cuda(gpu_id, non_blocking=True)
+    elif callable(getattr(batch, 'to', None)):
+        return batch.to(torch.device('cuda', gpu_id), non_blocking=True)
+    elif isinstance(batch, list):
+        for i, x in enumerate(batch):
+            batch[i] = move_to_cuda(x, gpu_id)
+        return batch
+    elif isinstance(batch, tuple):
+        batch = list(batch)
+        for i, x in enumerate(batch):
+            batch[i] = move_to_cuda(x, gpu_id)
+        return tuple(batch)
+    elif isinstance(batch, dict):
+        for k, v in batch.items():
+            batch[k] = move_to_cuda(v, gpu_id)
+        return batch
+    return batch
+
+
+class AvgrageMeter(object):
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.avg = 0
+        self.sum = 0
+        self.cnt = 0
+
+    def update(self, val, n=1):
+        self.sum += val * n
+        self.cnt += n
+        self.avg = self.sum / self.cnt
+
+
+def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1):
+    """Convert a list of 1d tensors into a padded 2d tensor."""
+    size = max(v.size(0) for v in values) if max_len is None else max_len
+    res = values[0].new(len(values), size).fill_(pad_idx)
+
+    def copy_tensor(src, dst):
+        assert dst.numel() == src.numel()
+        if shift_right:
+            dst[1:] = src[:-1]
+            dst[0] = shift_id
+        else:
+            dst.copy_(src)
+
+    for i, v in enumerate(values):
+        copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
+    return res
+
+
+def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None):
+    """Convert a list of 2d tensors into a padded 3d tensor."""
+    size = max(v.size(0) for v in values) if max_len is None else max_len
+    res = values[0].new(len(values), size, values[0].shape[1]).fill_(pad_idx)
+
+    def copy_tensor(src, dst):
+        assert dst.numel() == src.numel()
+        if shift_right:
+            dst[1:] = src[:-1]
+        else:
+            dst.copy_(src)
+
+    for i, v in enumerate(values):
+        copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
+    return res
+
+
+def _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
+    if len(batch) == 0:
+        return 0
+    if len(batch) == max_sentences:
+        return 1
+    if num_tokens > max_tokens:
+        return 1
+    return 0
+
+
+def batch_by_size(
+        indices, num_tokens_fn, max_tokens=None, max_sentences=None,
+        required_batch_size_multiple=1, distributed=False
+):
+    """
+    Yield mini-batches of indices bucketed by size. Batches may contain
+    sequences of different lengths.
+
+    Args:
+        indices (List[int]): ordered list of dataset indices
+        num_tokens_fn (callable): function that returns the number of tokens at
+            a given index
+        max_tokens (int, optional): max number of tokens in each batch
+            (default: None).
+        max_sentences (int, optional): max number of sentences in each
+            batch (default: None).
+        required_batch_size_multiple (int, optional): require batch size to
+            be a multiple of N (default: 1).
+    """
+    max_tokens = max_tokens if max_tokens is not None else sys.maxsize
+    max_sentences = max_sentences if max_sentences is not None else sys.maxsize
+    bsz_mult = required_batch_size_multiple
+
+    if isinstance(indices, types.GeneratorType):
+        indices = np.fromiter(indices, dtype=np.int64, count=-1)
+
+    sample_len = 0
+    sample_lens = []
+    batch = []
+    batches = []
+    for i in range(len(indices)):
+        idx = indices[i]
+        num_tokens = num_tokens_fn(idx)
+        sample_lens.append(num_tokens)
+        sample_len = max(sample_len, num_tokens)
+
+        assert sample_len <= max_tokens, (
+            "sentence at index {} of size {} exceeds max_tokens "
+            "limit of {}!".format(idx, sample_len, max_tokens)
+        )
+        num_tokens = (len(batch) + 1) * sample_len
+
+        if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
+            mod_len = max(
+                bsz_mult * (len(batch) // bsz_mult),
+                len(batch) % bsz_mult,
+            )
+            batches.append(batch[:mod_len])
+            batch = batch[mod_len:]
+            sample_lens = sample_lens[mod_len:]
+            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
+        batch.append(idx)
+    if len(batch) > 0:
+        batches.append(batch)
+    return batches
+
+def unpack_dict_to_list(samples):
+    samples_ = []
+    bsz = samples.get('outputs').size(0)
+    for i in range(bsz):
+        res = {}
+        for k, v in samples.items():
+            try:
+                res[k] = v[i]
+            except:
+                pass
+        samples_.append(res)
+    return samples_
+
+
+def remove_padding(x, padding_idx=0):
+    if x is None:
+        return None
+    assert len(x.shape) in [1, 2]
+    if len(x.shape) == 2:  # [T, H]
+        return x[np.abs(x).sum(-1) != padding_idx]
+    elif len(x.shape) == 1:  # [T]
+        return x[x != padding_idx]
+
+
+class Timer:
+    timer_map = {}
+
+    def __init__(self, name, enable=False):
+        if name not in Timer.timer_map:
+            Timer.timer_map[name] = 0
+        self.name = name
+        self.enable = enable
+
+    def __enter__(self):
+        if self.enable:
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+            self.t = time.time()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self.enable:
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+            Timer.timer_map[self.name] += time.time() - self.t
+            print(f'[Timer] {self.name}: {Timer.timer_map[self.name]}')
+
+
+def print_arch(model, model_name='model'):
+    print(f"| {model_name} Arch: ", model)
+    num_params(model, model_name=model_name)
+
+
+def num_params(model, print_out=True, model_name="model"):
+    parameters = filter(lambda p: p.requires_grad, model.parameters())
+    parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
+    if print_out:
+        print(f'| {model_name} Trainable Parameters: %.3fM' % parameters)
+    return parameters
+
+
+def get_encoding(file):
+    with open(file, 'rb') as f:
+        encoding = chardet.detect(f.read())['encoding']
+    if encoding == 'GB2312':
+        encoding = 'GB18030'
+    return encoding
diff --git a/utils/audio.py b/utils/audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..aba7ab926cf793d085bbdc70c97f376001183fe1
--- /dev/null
+++ b/utils/audio.py
@@ -0,0 +1,56 @@
+import subprocess
+import matplotlib
+
+matplotlib.use('Agg')
+import librosa
+import librosa.filters
+import numpy as np
+from scipy import signal
+from scipy.io import wavfile
+
+
+def save_wav(wav, path, sr, norm=False):
+    if norm:
+        wav = wav / np.abs(wav).max()
+    wav *= 32767
+    # proposed by @dsmiller
+    wavfile.write(path, sr, wav.astype(np.int16))
+
+
+def get_hop_size(hparams):
+    hop_size = hparams['hop_size']
+    if hop_size is None:
+        assert hparams['frame_shift_ms'] is not None
+        hop_size = int(hparams['frame_shift_ms'] / 1000 * hparams['audio_sample_rate'])
+    return hop_size
+
+
+###########################################################################################
+def _stft(y, hparams):
+    return librosa.stft(y=y, n_fft=hparams['fft_size'], hop_length=get_hop_size(hparams),
+                        win_length=hparams['win_size'], pad_mode='constant')
+
+
+def _istft(y, hparams):
+    return librosa.istft(y, hop_length=get_hop_size(hparams), win_length=hparams['win_size'])
+
+
+def librosa_pad_lr(x, fsize, fshift, pad_sides=1):
+    '''compute right padding (final frame) or both sides padding (first and final frames)
+    '''
+    assert pad_sides in (1, 2)
+    # return int(fsize // 2)
+    pad = (x.shape[0] // fshift + 1) * fshift - x.shape[0]
+    if pad_sides == 1:
+        return 0, pad
+    else:
+        return pad // 2, pad // 2 + pad % 2
+
+
+# Conversions
+def amp_to_db(x):
+    return 20 * np.log10(np.maximum(1e-5, x))
+
+
+def normalize(S, hparams):
+    return (S - hparams['min_level_db']) / -hparams['min_level_db']
diff --git a/utils/ckpt_utils.py b/utils/ckpt_utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..fc321f9ba891ffffc374df65871c3085bf898afb
--- /dev/null
+++ b/utils/ckpt_utils.py
@@ -0,0 +1,68 @@
+import glob
+import logging
+import os
+import re
+import torch
+
+
+def get_last_checkpoint(work_dir, steps=None):
+    checkpoint = None
+    last_ckpt_path = None
+    ckpt_paths = get_all_ckpts(work_dir, steps)
+    if len(ckpt_paths) > 0:
+        last_ckpt_path = ckpt_paths[0]
+        checkpoint = torch.load(last_ckpt_path, map_location='cpu')
+        logging.info(f'load module from checkpoint: {last_ckpt_path}')
+    return checkpoint, last_ckpt_path
+
+
+def get_all_ckpts(work_dir, steps=None):
+    if steps is None:
+        ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt'
+    else:
+        ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt'
+    return sorted(glob.glob(ckpt_path_pattern),
+                  key=lambda x: -int(re.findall(r'.*steps_(\d+)\.ckpt', x)[0]))
+
+
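+# `ckpt_base_dir` may be either a checkpoint file or an experiment directory
+# (in which case the newest `model_ckpt_steps_*.ckpt` is used). Keys prefixed
+# with `<model_name>.` are stripped of that prefix (otherwise the nested
+# state_dict[model_name] is used), and with strict=False any parameter whose
+# shape differs from the current model is dropped before loading.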
+def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True):
+    if os.path.isfile(ckpt_base_dir):
+        base_dir = os.path.dirname(ckpt_base_dir)
+        ckpt_path = ckpt_base_dir
+        checkpoint = torch.load(ckpt_base_dir, map_location='cpu')
+    else:
+        base_dir = ckpt_base_dir
+        checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir)
+    if checkpoint is not None:
+        state_dict = checkpoint["state_dict"]
+        if len([k for k in state_dict.keys() if '.' in k]) > 0:
+            state_dict = {k[len(model_name) + 1:]: v for k, v in state_dict.items()
+                          if k.startswith(f'{model_name}.')}
+        else:
+            if '.' not in model_name:
+                state_dict = state_dict[model_name]
+            else:
+                base_model_name = model_name.split('.')[0]
+                rest_model_name = model_name[len(base_model_name) + 1:]
+                state_dict = {
+                    k[len(rest_model_name) + 1:]: v for k, v in state_dict[base_model_name].items()
+                    if k.startswith(f'{rest_model_name}.')}
+        if not strict:
+            cur_model_state_dict = cur_model.state_dict()
+            unmatched_keys = []
+            for key, param in state_dict.items():
+                if key in cur_model_state_dict:
+                    new_param = cur_model_state_dict[key]
+                    if new_param.shape != param.shape:
+                        unmatched_keys.append(key)
+                        print("| Unmatched keys: ", key, new_param.shape, param.shape)
+            for key in unmatched_keys:
+                del state_dict[key]
+        cur_model.load_state_dict(state_dict, strict=strict)
+        print(f"| load '{model_name}' from '{ckpt_path}'.")
+    else:
+        e_msg = f"| ckpt not found in {base_dir}."
+        if force:
+            assert False, e_msg
+        else:
+            print(e_msg)
diff --git a/utils/common_schedulers.py b/utils/common_schedulers.py
new file mode 100755
index 0000000000000000000000000000000000000000..41c6f4a9250b2d5954ce93cb7c04e7b55025cb51
--- /dev/null
+++ b/utils/common_schedulers.py
@@ -0,0 +1,50 @@
+from utils.hparams import hparams
+
+
+class NoneSchedule(object):
+    def __init__(self, optimizer):
+        super().__init__()
+        self.optimizer = optimizer
+        self.constant_lr = hparams['lr']
+        self.step(0)
+
+    def step(self, num_updates):
+        self.lr = self.constant_lr
+        for param_group in self.optimizer.param_groups:
+            param_group['lr'] = self.lr
+        return self.lr
+
+    def get_lr(self):
+        return self.optimizer.param_groups[0]['lr']
+
+    def get_last_lr(self):
+        return self.get_lr()
+
+
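+# Inverse-square-root ("Noam") schedule: linear warm-up over `warmup_updates`
+# steps, then decay proportional to num_updates ** -0.5, scaled by
+# hidden_size ** -0.5 and floored at 1e-7.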
+class RSQRTSchedule(object):
+    def __init__(self, optimizer):
+        super().__init__()
+        self.optimizer = optimizer
+        self.constant_lr = hparams['lr']
+        self.warmup_updates = hparams['warmup_updates']
+        self.hidden_size = hparams['hidden_size']
+        self.lr = hparams['lr']
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = self.lr
+        self.step(0)
+
+    def step(self, num_updates):
+        constant_lr = self.constant_lr
+        warmup = min(num_updates / self.warmup_updates, 1.0)
+        rsqrt_decay = max(self.warmup_updates, num_updates) ** -0.5
+        rsqrt_hidden = self.hidden_size ** -0.5
+        self.lr = max(constant_lr * warmup * rsqrt_decay * rsqrt_hidden, 1e-7)
+        for param_group in self.optimizer.param_groups:
+            param_group['lr'] = self.lr
+        return self.lr
+
+    def get_lr(self):
+        return self.optimizer.param_groups[0]['lr']
+
+    def get_last_lr(self):
+        return self.get_lr()
diff --git a/utils/cwt.py b/utils/cwt.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a08461b9e422aac614438e6240b7355b8e4bb2c
--- /dev/null
+++ b/utils/cwt.py
@@ -0,0 +1,146 @@
+import librosa
+import numpy as np
+from pycwt import wavelet
+from scipy.interpolate import interp1d
+
+
+def load_wav(wav_file, sr):
+    wav, _ = librosa.load(wav_file, sr=sr, mono=True)
+    return wav
+
+
+def convert_continuos_f0(f0):
+    '''Convert an F0 sequence into a continuous (interpolated) F0 sequence.
+    Args:
+        f0 (ndarray): original f0 sequence with the shape (T)
+    Return:
+        (ndarray): binary voiced/unvoiced mask with the shape (T)
+        (ndarray): continuous f0 with the shape (T)
+    '''
+    # get uv information as binary
+    f0 = np.copy(f0)
+    uv = np.float32(f0 != 0)
+
+    # get start and end of f0
+    if (f0 == 0).all():
+        print("| all of the f0 values are 0.")
+        return uv, f0
+    start_f0 = f0[f0 != 0][0]
+    end_f0 = f0[f0 != 0][-1]
+
+    # padding start and end of f0 sequence
+    start_idx = np.where(f0 == start_f0)[0][0]
+    end_idx = np.where(f0 == end_f0)[0][-1]
+    f0[:start_idx] = start_f0
+    f0[end_idx:] = end_f0
+
+    # get non-zero frame index
+    nz_frames = np.where(f0 != 0)[0]
+
+    # perform linear interpolation
+    f = interp1d(nz_frames, f0[nz_frames])
+    cont_f0 = f(np.arange(0, f0.shape[0]))
+
+    return uv, cont_f0
+
+
+def get_cont_lf0(f0, frame_period=5.0):
+    uv, cont_f0_lpf = convert_continuos_f0(f0)
+    # cont_f0_lpf = low_pass_filter(cont_f0_lpf, int(1.0 / (frame_period * 0.001)), cutoff=20)
+    cont_lf0_lpf = np.log(cont_f0_lpf)
+    return uv, cont_lf0_lpf
+
+
+def get_lf0_cwt(lf0):
+    '''
+    input:
+        signal of shape (N,)
+    output:
+        Wavelet_lf0 of shape (N, 10), scales of shape (10,)
+    '''
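+    # dt: sampling period of the f0 track (5 ms frame shift); dj: spacing
+    # between discrete scales; s0: smallest scale; J: number of scales minus
+    # one, so J + 1 = 10 wavelet scales are returned.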
+    mother = wavelet.MexicanHat()
+    dt = 0.005
+    dj = 1
+    s0 = dt * 2
+    J = 9
+
+    Wavelet_lf0, scales, _, _, _, _ = wavelet.cwt(np.squeeze(lf0), dt, dj, s0, J, mother)
+    # Wavelet.shape => (J + 1, len(lf0))
+    Wavelet_lf0 = np.real(Wavelet_lf0).T
+    return Wavelet_lf0, scales
+
+
+def norm_scale(Wavelet_lf0):
+    mean = Wavelet_lf0.mean(0)[None, :]
+    std = Wavelet_lf0.std(0)[None, :]
+    Wavelet_lf0_norm = (Wavelet_lf0 - mean) / std
+    return Wavelet_lf0_norm, mean, std
+
+
+def normalize_cwt_lf0(f0, mean, std):
+    uv, cont_lf0_lpf = get_cont_lf0(f0)
+    cont_lf0_norm = (cont_lf0_lpf - mean) / std
+    Wavelet_lf0, scales = get_lf0_cwt(cont_lf0_norm)
+    Wavelet_lf0_norm, _, _ = norm_scale(Wavelet_lf0)
+
+    return Wavelet_lf0_norm
+
+
+def get_lf0_cwt_norm(f0s, mean, std):
+    uvs = list()
+    cont_lf0_lpfs = list()
+    cont_lf0_lpf_norms = list()
+    Wavelet_lf0s = list()
+    Wavelet_lf0s_norm = list()
+    scaless = list()
+
+    means = list()
+    stds = list()
+    for f0 in f0s:
+        uv, cont_lf0_lpf = get_cont_lf0(f0)
+        cont_lf0_lpf_norm = (cont_lf0_lpf - mean) / std
+
+        Wavelet_lf0, scales = get_lf0_cwt(cont_lf0_lpf_norm)  # [560,10]
+        Wavelet_lf0_norm, mean_scale, std_scale = norm_scale(Wavelet_lf0)  # [560,10],[1,10],[1,10]
+
+        Wavelet_lf0s_norm.append(Wavelet_lf0_norm)
+        uvs.append(uv)
+        cont_lf0_lpfs.append(cont_lf0_lpf)
+        cont_lf0_lpf_norms.append(cont_lf0_lpf_norm)
+        Wavelet_lf0s.append(Wavelet_lf0)
+        scaless.append(scales)
+        means.append(mean_scale)
+        stds.append(std_scale)
+
+    return Wavelet_lf0s_norm, scaless, means, stds
+
+
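+# Approximate inverse CWT used to recover log-F0 from the wavelet spectrogram:
+# each scale is weighted by (i + 1 + 2.5) ** (-2.5), the weighted components
+# are summed over scales and the result is re-standardised; `cwt2f0` then
+# restores the utterance-level mean/std and exponentiates back from log-F0.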
+def inverse_cwt_torch(Wavelet_lf0, scales):
+    import torch
+    b = ((torch.arange(0, len(scales)).float().to(Wavelet_lf0.device)[None, None, :] + 1 + 2.5) ** (-2.5))
+    lf0_rec = Wavelet_lf0 * b
+    lf0_rec_sum = lf0_rec.sum(-1)
+    lf0_rec_sum = (lf0_rec_sum - lf0_rec_sum.mean(-1, keepdim=True)) / lf0_rec_sum.std(-1, keepdim=True)
+    return lf0_rec_sum
+
+
+def inverse_cwt(Wavelet_lf0, scales):
+    b = ((np.arange(0, len(scales))[None, None, :] + 1 + 2.5) ** (-2.5))
+    lf0_rec = Wavelet_lf0 * b
+    lf0_rec_sum = lf0_rec.sum(-1)
+    lf0_rec_sum = (lf0_rec_sum - lf0_rec_sum.mean(-1, keepdims=True)) / lf0_rec_sum.std(-1, keepdims=True)
+    return lf0_rec_sum
+
+
+def cwt2f0(cwt_spec, mean, std, cwt_scales):
+    assert len(mean.shape) == 1 and len(std.shape) == 1 and len(cwt_spec.shape) == 3
+    import torch
+    if isinstance(cwt_spec, torch.Tensor):
+        f0 = inverse_cwt_torch(cwt_spec, cwt_scales)
+        f0 = f0 * std[:, None] + mean[:, None]
+        f0 = f0.exp()  # [B, T]
+    else:
+        f0 = inverse_cwt(cwt_spec, cwt_scales)
+        f0 = f0 * std[:, None] + mean[:, None]
+        f0 = np.exp(f0)  # [B, T]
+    return f0
diff --git a/utils/ddp_utils.py b/utils/ddp_utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..4b529198c13a1ffc622baea6e5178407b24aee8f
--- /dev/null
+++ b/utils/ddp_utils.py
@@ -0,0 +1,137 @@
+from torch.nn.parallel import DistributedDataParallel
+from torch.nn.parallel.distributed import _find_tensors
+import torch.optim
+import torch.utils.data
+import torch
+from packaging import version
+
+class DDP(DistributedDataParallel):
+    """
+    Override the forward call in lightning so it goes to training and validation step respectively
+    """
+
+    def forward(self, *inputs, **kwargs):  # pragma: no cover
+        if version.parse(torch.__version__.split('+')[0]) < version.parse("1.11"):
+            self._sync_params()
+            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+            assert len(self.device_ids) == 1
+            if self.module.training:
+                output = self.module.training_step(*inputs[0], **kwargs[0])
+            elif self.module.testing:
+                output = self.module.test_step(*inputs[0], **kwargs[0])
+            else:
+                output = self.module.validation_step(*inputs[0], **kwargs[0])
+            if torch.is_grad_enabled():
+                # We'll return the output object verbatim since it is a freeform
+                # object. We need to find any tensors in this object, though,
+                # because we need to figure out which parameters were used during
+                # this forward pass, to ensure we short circuit reduction for any
+                # unused parameters. Only if `find_unused_parameters` is set.
+                if self.find_unused_parameters:
+                    self.reducer.prepare_for_backward(list(_find_tensors(output)))
+                else:
+                    self.reducer.prepare_for_backward([])
+        else:
+            from torch.nn.parallel.distributed import \
+                logging, Join, _DDPSink, _tree_flatten_with_rref, _tree_unflatten_with_rref
+            with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
+                if torch.is_grad_enabled() and self.require_backward_grad_sync:
+                    self.logger.set_runtime_stats_and_log()
+                    self.num_iterations += 1
+                    self.reducer.prepare_for_forward()
+
+                # Notify the join context that this process has not joined, if
+                # needed
+                work = Join.notify_join_context(self)
+                if work:
+                    self.reducer._set_forward_pass_work_handle(
+                        work, self._divide_by_initial_world_size
+                    )
+
+                # Calling _rebuild_buckets before forward computation,
+                # It may allocate new buckets before deallocating old buckets
+                # inside _rebuild_buckets. To save peak memory usage,
+                # call _rebuild_buckets before the peak memory usage increases
+                # during forward computation.
+                # This should be called only once during whole training period.
+                if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
+                    logging.info("Reducer buckets have been rebuilt in this iteration.")
+                    self._has_rebuilt_buckets = True
+
+                # sync params according to location (before/after forward) user
+                # specified as part of hook, if hook was specified.
+                buffer_hook_registered = hasattr(self, 'buffer_hook')
+                if self._check_sync_bufs_pre_fwd():
+                    self._sync_buffers()
+
+                if self._join_config.enable:
+                    # Notify joined ranks whether they should sync in backwards pass or not.
+                    self._check_global_requires_backward_grad_sync(is_joined_rank=False)
+
+                inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+                if self.module.training:
+                    output = self.module.training_step(*inputs[0], **kwargs[0])
+                elif self.module.testing:
+                    output = self.module.test_step(*inputs[0], **kwargs[0])
+                else:
+                    output = self.module.validation_step(*inputs[0], **kwargs[0])
+
+                # sync params according to location (before/after forward) user
+                # specified as part of hook, if hook was specified.
+                if self._check_sync_bufs_post_fwd():
+                    self._sync_buffers()
+
+                if torch.is_grad_enabled() and self.require_backward_grad_sync:
+                    self.require_forward_param_sync = True
+                    # We'll return the output object verbatim since it is a freeform
+                    # object. We need to find any tensors in this object, though,
+                    # because we need to figure out which parameters were used during
+                    # this forward pass, to ensure we short circuit reduction for any
+                    # unused parameters. Only if `find_unused_parameters` is set.
+                    if self.find_unused_parameters and not self.static_graph:
+                        # Do not need to populate this for static graph.
+                        self.reducer.prepare_for_backward(list(_find_tensors(output)))
+                    else:
+                        self.reducer.prepare_for_backward([])
+                else:
+                    self.require_forward_param_sync = False
+
+            # TODO: DDPSink is currently enabled for unused parameter detection and
+            # static graph training for first iteration.
+            if (self.find_unused_parameters and not self.static_graph) or (
+                    self.static_graph and self.num_iterations == 1
+            ):
+                state_dict = {
+                    'static_graph': self.static_graph,
+                    'num_iterations': self.num_iterations,
+                }
+
+                output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(
+                    output
+                )
+                output_placeholders = [None for _ in range(len(output_tensor_list))]
+                # Do not touch tensors that have no grad_fn, which can cause issues
+                # such as https://github.com/pytorch/pytorch/issues/60733
+                for i, output in enumerate(output_tensor_list):
+                    if torch.is_tensor(output) and output.grad_fn is None:
+                        output_placeholders[i] = output
+
+                # When find_unused_parameters=True, makes tensors which require grad
+                # run through the DDPSink backward pass. When not all outputs are
+                # used in loss, this makes those corresponding tensors receive
+                # undefined gradient which the reducer then handles to ensure
+                # param.grad field is not touched and we don't error out.
+                passthrough_tensor_list = _DDPSink.apply(
+                    self.reducer,
+                    state_dict,
+                    *output_tensor_list,
+                )
+                for i in range(len(output_placeholders)):
+                    if output_placeholders[i] is None:
+                        output_placeholders[i] = passthrough_tensor_list[i]
+
+                # Reconstruct output data structure.
+                output = _tree_unflatten_with_rref(
+                    output_placeholders, treespec, output_is_rref
+                )
+        return output
diff --git a/utils/hparams.py b/utils/hparams.py
new file mode 100644
index 0000000000000000000000000000000000000000..7efa3025ec3b52949d7b20d432b3457fc60713c4
--- /dev/null
+++ b/utils/hparams.py
@@ -0,0 +1,121 @@
+import argparse
+import os
+import yaml
+
+global_print_hparams = True
+hparams = {}
+
+
+class Args:
+    def __init__(self, **kwargs):
+        for k, v in kwargs.items():
+            self.__setattr__(k, v)
+
+
+def override_config(old_config: dict, new_config: dict):
+    for k, v in new_config.items():
+        if isinstance(v, dict) and k in old_config:
+            override_config(old_config[k], new_config[k])
+        else:
+            old_config[k] = v
+
+
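+# Builds the global `hparams` dict. Precedence (lowest to highest): the
+# `base_config` chain loaded depth-first, the given config file, the saved
+# `checkpoints/<exp_name>/config.yaml` (unless --reset is passed), and finally
+# any `--hparams "key1=val1,key2=val2"` command-line overrides.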
+def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
+    if config == '':
+        parser = argparse.ArgumentParser(description='neural music')
+        parser.add_argument('--config', type=str, default='',
+                            help='path to the config yaml')
+        parser.add_argument('--exp_name', type=str, default='', help='exp_name')
+        parser.add_argument('--hparams', type=str, default='',
+                            help='hparams overrides, e.g. "key1=val1,key2=val2"')
+        parser.add_argument('--infer', action='store_true', help='infer')
+        parser.add_argument('--validate', action='store_true', help='validate')
+        parser.add_argument('--reset', action='store_true', help='reset hparams')
+        parser.add_argument('--debug', action='store_true', help='debug')
+        args, unknown = parser.parse_known_args()
+    else:
+        args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
+                    infer=False, validate=False, reset=False, debug=False)
+    args_work_dir = ''
+    if args.exp_name != '':
+        args.work_dir = args.exp_name
+        args_work_dir = f'checkpoints/{args.work_dir}'
+
+    config_chains = []
+    loaded_config = set()
+
+    def load_config(config_fn):  # deep first
+        with open(config_fn) as f:
+            hparams_ = yaml.safe_load(f)
+        loaded_config.add(config_fn)
+        if 'base_config' in hparams_:
+            ret_hparams = {}
+            if not isinstance(hparams_['base_config'], list):
+                hparams_['base_config'] = [hparams_['base_config']]
+            for c in hparams_['base_config']:
+                if c not in loaded_config:
+                    if c.startswith('.'):
+                        c = f'{os.path.dirname(config_fn)}/{c}'
+                        c = os.path.normpath(c)
+                    override_config(ret_hparams, load_config(c))
+            override_config(ret_hparams, hparams_)
+        else:
+            ret_hparams = hparams_
+        config_chains.append(config_fn)
+        return ret_hparams
+
+    global hparams
+    assert args.config != '' or args_work_dir != ''
+    saved_hparams = {}
+    if args_work_dir != 'checkpoints/':
+        ckpt_config_path = f'{args_work_dir}/config.yaml'
+        if os.path.exists(ckpt_config_path):
+            try:
+                with open(ckpt_config_path) as f:
+                    saved_hparams.update(yaml.safe_load(f))
+            except Exception:
+                # a corrupt saved config should not block loading the experiment
+                pass
+        if args.config == '':
+            args.config = ckpt_config_path
+
+    hparams_ = {}
+    hparams_.update(load_config(args.config))
+
+    if not args.reset:
+        hparams_.update(saved_hparams)
+    hparams_['work_dir'] = args_work_dir
+
+    if args.hparams != "":
+        for new_hparam in args.hparams.split(","):
+            k, v = new_hparam.split("=")
+            if v in ['True', 'False'] or type(hparams_[k]) == bool:
+                hparams_[k] = eval(v)
+            else:
+                hparams_[k] = type(hparams_[k])(v)
+
+    if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer:
+        os.makedirs(hparams_['work_dir'], exist_ok=True)
+        with open(ckpt_config_path, 'w') as f:
+            yaml.safe_dump(hparams_, f)
+
+    hparams_['inference'] = args.infer
+    hparams_['debug'] = args.debug
+    hparams_['validate'] = args.validate
+    global global_print_hparams
+    if global_hparams:
+        hparams.clear()
+        hparams.update(hparams_)
+
+    if print_hparams and global_print_hparams and global_hparams:
+        print('| Hparams chains: ', config_chains)
+        print('| Hparams: ')
+        for i, (k, v) in enumerate(sorted(hparams_.items())):
+            print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
+        print("")
+        global_print_hparams = False
+    # print(hparams_.keys())
+    if hparams.get('exp_name') is None:
+        hparams['exp_name'] = args.exp_name
+    if hparams_.get('exp_name') is None:
+        hparams_['exp_name'] = args.exp_name
+    return hparams_
diff --git a/utils/indexed_datasets.py b/utils/indexed_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..e15632be30d6296a3c9aa80a1f351058003698b3
--- /dev/null
+++ b/utils/indexed_datasets.py
@@ -0,0 +1,71 @@
+import pickle
+from copy import deepcopy
+
+import numpy as np
+
+
+class IndexedDataset:
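+    """Random-access dataset over `{path}.data` (concatenated pickled items),
+    indexed by the byte `offsets` stored in `{path}.idx`; both files are
+    written by IndexedDatasetBuilder below.
+    """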
+    def __init__(self, path, num_cache=1):
+        super().__init__()
+        self.path = path
+        self.data_file = None
+        self.data_offsets = np.load(f"{path}.idx", allow_pickle=True).item()['offsets']
+        self.data_file = open(f"{path}.data", 'rb', buffering=-1)
+        self.cache = []
+        self.num_cache = num_cache
+
+    def check_index(self, i):
+        if i < 0 or i >= len(self.data_offsets) - 1:
+            raise IndexError('index out of range')
+
+    def __del__(self):
+        if self.data_file:
+            self.data_file.close()
+
+    def __getitem__(self, i):
+        self.check_index(i)
+        if self.num_cache > 0:
+            for c in self.cache:
+                if c[0] == i:
+                    return c[1]
+        self.data_file.seek(self.data_offsets[i])
+        b = self.data_file.read(self.data_offsets[i + 1] - self.data_offsets[i])
+        item = pickle.loads(b)
+        if self.num_cache > 0:
+            # keep the `num_cache` most recently read items
+            self.cache = [(i, deepcopy(item))] + self.cache[:self.num_cache - 1]
+        return item
+
+    def __len__(self):
+        return len(self.data_offsets) - 1
+
+class IndexedDatasetBuilder:
+    def __init__(self, path):
+        self.path = path
+        self.out_file = open(f"{path}.data", 'wb')
+        self.byte_offsets = [0]
+
+    def add_item(self, item):
+        s = pickle.dumps(item)
+        n_bytes = self.out_file.write(s)
+        self.byte_offsets.append(self.byte_offsets[-1] + n_bytes)
+
+    def finalize(self):
+        self.out_file.close()
+        with open(f"{self.path}.idx", 'wb') as f:
+            np.save(f, {'offsets': self.byte_offsets})
+
+
+if __name__ == "__main__":
+    import random
+    from tqdm import tqdm
+    ds_path = '/tmp/indexed_ds_example'
+    size = 100
+    items = [{"a": np.random.normal(size=[10000, 10]),
+              "b": np.random.normal(size=[10000, 10])} for i in range(size)]
+    builder = IndexedDatasetBuilder(ds_path)
+    for i in tqdm(range(size)):
+        builder.add_item(items[i])
+    builder.finalize()
+    ds = IndexedDataset(ds_path)
+    for i in tqdm(range(10000)):
+        idx = random.randint(0, size - 1)
+        assert (ds[idx]['a'] == items[idx]['a']).all()
diff --git a/utils/multiprocess_utils.py b/utils/multiprocess_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d3641a332eedfbaf27cda11dbd4a79b8a65072b
--- /dev/null
+++ b/utils/multiprocess_utils.py
@@ -0,0 +1,143 @@
+import os
+import traceback
+from functools import partial
+from tqdm import tqdm
+
+
+def chunked_worker(worker_id, args_queue=None, results_queue=None, init_ctx_func=None):
+    ctx = init_ctx_func(worker_id) if init_ctx_func is not None else None
+    while True:
+        args = args_queue.get()
+        if args == '<KILL>':
+            return
+        job_idx, map_func, arg = args
+        try:
+            map_func_ = partial(map_func, ctx=ctx) if ctx is not None else map_func
+            if isinstance(arg, dict):
+                res = map_func_(**arg)
+            elif isinstance(arg, (list, tuple)):
+                res = map_func_(*arg)
+            else:
+                res = map_func_(arg)
+            results_queue.put((job_idx, res))
+        except Exception:
+            traceback.print_exc()
+            results_queue.put((job_idx, None))
+
+
+class MultiprocessManager:
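+    """Dispatches jobs to a pool of worker processes (threads if `multithread=True`).
+
+    Illustrative usage (the job function and its inputs are made up):
+        manager = MultiprocessManager(num_workers=4)
+        for path in wav_paths:
+            manager.add_job(extract_features, (path,))
+        for job_id, res in manager.get_results():  # completion order, not submission order
+            ...
+    """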
+    def __init__(self, num_workers=None, init_ctx_func=None, multithread=False):
+        if multithread:
+            from multiprocessing.dummy import Queue, Process
+        else:
+            from multiprocessing import Queue, Process
+        if num_workers is None:
+            num_workers = int(os.getenv('N_PROC', os.cpu_count()))
+        self.num_workers = num_workers
+        self.results_queue = Queue(maxsize=-1)
+        self.args_queue = Queue(maxsize=-1)
+        self.workers = []
+        self.total_jobs = 0
+        for i in range(num_workers):
+            p = Process(target=chunked_worker,
+                        args=(i, self.args_queue, self.results_queue, init_ctx_func),
+                        daemon=True)
+            self.workers.append(p)
+            p.start()
+
+    def add_job(self, func, args):
+        self.args_queue.put((self.total_jobs, func, args))
+        self.total_jobs += 1
+
+    def get_results(self):
+        for w in range(self.num_workers):
+            self.args_queue.put("<KILL>")
+        self.n_finished = 0
+        while self.n_finished < self.total_jobs:
+            job_id, res = self.results_queue.get()
+            yield job_id, res
+            self.n_finished += 1
+        for w in self.workers:
+            w.join()
+
+    def __len__(self):
+        return self.total_jobs
+
+
+def multiprocess_run_tqdm(map_func, args, num_workers=None, ordered=True, init_ctx_func=None,
+                          multithread=False, desc=None):
+    for i, res in tqdm(enumerate(
+            multiprocess_run(map_func, args, num_workers, ordered, init_ctx_func, multithread)),
+            total=len(args), desc=desc):
+        yield i, res
+
+
+def multiprocess_run(map_func, args, num_workers=None, ordered=True, init_ctx_func=None, multithread=False):
+    """
+    Multiprocessing running chunked jobs.
+    Examples:
+    >>> for res in tqdm(multiprocess_run(job_func, args):
+    >>>     print(res)
+    :param map_func:
+    :param args:
+    :param num_workers:
+    :param ordered:
+    :param init_ctx_func:
+    :param q_max_size:
+    :param multithread:
+    :return:
+    """
+    if num_workers is None:
+        num_workers = int(os.getenv('N_PROC', os.cpu_count()))
+    manager = MultiprocessManager(num_workers, init_ctx_func, multithread)
+    for arg in args:
+        manager.add_job(map_func, arg)
+    if ordered:
+        n_jobs = len(args)
+        _PENDING = object()  # unique placeholder that cannot collide with real results
+        results = [_PENDING for _ in range(n_jobs)]
+        i_now = 0
+        for job_i, res in manager.get_results():
+            results[job_i] = res
+            while i_now < n_jobs and results[i_now] is not _PENDING:
+                yield results[i_now]
+                i_now += 1
+    else:
+        for res in manager.get_results():
+            yield res
+
+
+def chunked_worker_list(worker_id, map_func, args, results_queue=None, init_ctx_func=None):
+    # Like chunked_worker, but consumes a pre-chunked list of (job_idx, arg) pairs
+    # instead of pulling jobs from a shared queue.
+    ctx = init_ctx_func(worker_id) if init_ctx_func is not None else None
+    for job_idx, arg in args:
+        try:
+            map_func_ = partial(map_func, ctx=ctx) if ctx is not None else map_func
+            if isinstance(arg, dict):
+                res = map_func_(**arg)
+            elif isinstance(arg, (list, tuple)):
+                res = map_func_(*arg)
+            else:
+                res = map_func_(arg)
+            results_queue.put((job_idx, res))
+        except Exception:
+            traceback.print_exc()
+            results_queue.put((job_idx, None))
+
+
+def chunked_multiprocess_run(
+        map_func, args, num_workers=None, ordered=True,
+        init_ctx_func=None, q_max_size=1000, multithread=False):
+    if multithread:
+        from multiprocessing.dummy import Queue, Process
+    else:
+        from multiprocessing import Queue, Process
+    args = list(zip(range(len(args)), args))
+    n_jobs = len(args)
+    if num_workers is None:
+        num_workers = int(os.getenv('N_PROC', os.cpu_count()))
+    results_queues = []
+    if ordered:
+        # one bounded queue per worker so results can be drained in job order
+        for i in range(num_workers):
+            results_queues.append(Queue(maxsize=q_max_size // num_workers))
+    else:
+        results_queue = Queue(maxsize=q_max_size)
+        for i in range(num_workers):
+            results_queues.append(results_queue)
+    workers = []
+    for i in range(num_workers):
+        args_worker = args[i::num_workers]
+        p = Process(target=chunked_worker_list, args=(
+            i, map_func, args_worker, results_queues[i], init_ctx_func), daemon=True)
+        workers.append(p)
+        p.start()
+    for n_finished in range(n_jobs):
+        results_queue = results_queues[n_finished % num_workers]
+        job_idx, res = results_queue.get()
+        assert job_idx == n_finished or not ordered, (job_idx, n_finished)
+        yield res
+    for w in workers:
+        w.join()
+
diff --git a/utils/os_utils.py b/utils/os_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c78a44c04eadc3feb3c35f88c8a074f59ab23778
--- /dev/null
+++ b/utils/os_utils.py
@@ -0,0 +1,20 @@
+import os
+import subprocess
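+
+# Thin wrappers around POSIX shell commands (ln/mv/cp/rm); paths are interpolated
+# into the command line, so they should not contain double quotes or backticks.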
+
+
+def link_file(from_file, to_file):
+    subprocess.check_call(
+        f'ln -s "`realpath --relative-to="{os.path.dirname(to_file)}" "{from_file}"`" "{to_file}"', shell=True)
+
+
+def move_file(from_file, to_file):
+    subprocess.check_call(f'mv "{from_file}" "{to_file}"', shell=True)
+
+
+def copy_file(from_file, to_file):
+    subprocess.check_call(f'cp -r "{from_file}" "{to_file}"', shell=True)
+
+
+def remove_file(*fns):
+    for f in fns:
+        subprocess.check_call(f'rm -rf "{f}"', shell=True)
\ No newline at end of file
diff --git a/utils/pitch_utils.py b/utils/pitch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7fd166abd3a03bac5909e498669b482447435cf
--- /dev/null
+++ b/utils/pitch_utils.py
@@ -0,0 +1,76 @@
+#########
+# world
+##########
+import librosa
+import numpy as np
+import torch
+
+gamma = 0
+mcepInput = 3  # 0 for dB, 3 for magnitude
+alpha = 0.45
+en_floor = 10 ** (-80 / 20)
+FFT_SIZE = 2048
+
+
+f0_bin = 256
+f0_max = 1100.0
+f0_min = 50.0
+f0_mel_min = 1127 * np.log(1 + f0_min / 700)
+f0_mel_max = 1127 * np.log(1 + f0_max / 700)
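+
+# f0 is quantized on the mel scale (mel = 1127 * ln(1 + f / 700)): f0_to_coarse maps
+# [f0_min, f0_max] onto integer bins 1..f0_bin - 1, and unvoiced frames (f0 == 0)
+# fall into bin 1.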
+
+
+def f0_to_coarse(f0):
+    is_torch = isinstance(f0, torch.Tensor)
+    f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
+    f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1
+
+    f0_mel[f0_mel <= 1] = 1
+    f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
+    f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(int)
+    assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min())
+    return f0_coarse
+
+
+def norm_f0(f0, uv, hparams):
+    is_torch = isinstance(f0, torch.Tensor)
+    if hparams['pitch_norm'] == 'standard':
+        f0 = (f0 - hparams['f0_mean']) / hparams['f0_std']
+    if hparams['pitch_norm'] == 'log':
+        f0 = torch.log2(f0) if is_torch else np.log2(f0)
+    if uv is not None and hparams['use_uv']:
+        f0[uv > 0] = 0
+    return f0
+
+
+def norm_interp_f0(f0, hparams):
+    is_torch = isinstance(f0, torch.Tensor)
+    if is_torch:
+        device = f0.device
+        f0 = f0.data.cpu().numpy()
+    uv = f0 == 0
+    f0 = norm_f0(f0, uv, hparams)
+    if sum(uv) == len(f0):
+        f0[uv] = 0
+    elif sum(uv) > 0:
+        f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
+    uv = torch.FloatTensor(uv)
+    f0 = torch.FloatTensor(f0)
+    if is_torch:
+        f0 = f0.to(device)
+    return f0, uv
+
+
+def denorm_f0(f0, uv, hparams, pitch_padding=None, min=None, max=None):
+    if hparams['pitch_norm'] == 'standard':
+        f0 = f0 * hparams['f0_std'] + hparams['f0_mean']
+    if hparams['pitch_norm'] == 'log':
+        f0 = 2 ** f0
+    if min is not None:
+        f0 = f0.clamp(min=min)
+    if max is not None:
+        f0 = f0.clamp(max=max)
+    if uv is not None and hparams['use_uv']:
+        f0[uv > 0] = 0
+    if pitch_padding is not None:
+        f0[pitch_padding] = 0
+    return f0
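+
+
+# Illustrative round trip (hparams values are made up):
+#   hp = {'pitch_norm': 'log', 'use_uv': True}
+#   f0_norm, uv = norm_interp_f0(np.array([0., 220., 440.]), hp)  # log2 f0, interpolated where unvoiced
+#   f0_back = denorm_f0(f0_norm, uv, hp)                          # unvoiced frames restored to 0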
diff --git a/utils/pl_utils.py b/utils/pl_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a94ed6abe22e349c51c49afdbf052d52b8d98b
--- /dev/null
+++ b/utils/pl_utils.py
@@ -0,0 +1,1618 @@
+import matplotlib
+from torch.nn import DataParallel
+from torch.nn.parallel import DistributedDataParallel
+
+matplotlib.use('Agg')
+import glob
+import itertools
+import subprocess
+import threading
+import traceback
+
+from pytorch_lightning.callbacks import ModelCheckpoint
+
+from functools import wraps
+from torch.cuda._utils import _get_device_index
+import numpy as np
+import torch.optim
+import torch.utils.data
+import copy
+import logging
+import os
+import re
+import sys
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import tqdm
+from torch.optim.optimizer import Optimizer
+
+
+def get_a_var(obj):  # pragma: no cover
+    if isinstance(obj, torch.Tensor):
+        return obj
+
+    if isinstance(obj, list) or isinstance(obj, tuple):
+        for result in map(get_a_var, obj):
+            if isinstance(result, torch.Tensor):
+                return result
+    if isinstance(obj, dict):
+        for result in map(get_a_var, obj.items()):
+            if isinstance(result, torch.Tensor):
+                return result
+    return None
+
+
+def data_loader(fn):
+    """
+    Decorator to make any fx with this use the lazy property
+    :param fn:
+    :return:
+    """
+
+    wraps(fn)
+    attr_name = '_lazy_' + fn.__name__
+
+    def _get_data_loader(self):
+        try:
+            value = getattr(self, attr_name)
+        except AttributeError:
+            try:
+                value = fn(self)  # Lazy evaluation, done only once.
+                if (
+                        value is not None and
+                        not isinstance(value, list) and
+                        fn.__name__ in ['test_dataloader', 'val_dataloader']
+                ):
+                    value = [value]
+            except AttributeError as e:
+                # Guard against AttributeError suppression. (Issue #142)
+                traceback.print_exc()
+                error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e)
+                raise RuntimeError(error) from e
+            setattr(self, attr_name, value)  # Memoize evaluation.
+        return value
+
+    return _get_data_loader
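+
+# Illustrative use (class and method names are made up):
+#   class MyTask(...):
+#       @data_loader
+#       def val_dataloader(self):
+#           return build_val_loader()  # evaluated once, cached, and wrapped in a list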
+
+
+def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):  # pragma: no cover
+    r"""Applies each `module` in :attr:`modules` in parallel on arguments
+    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
+    on each of :attr:`devices`.
+
+    Args:
+        modules (Module): modules to be parallelized
+        inputs (tensor): inputs to the modules
+        devices (list of int or torch.device): CUDA devices
+
+    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
+    :attr:`devices` (if given) should all have same length. Moreover, each
+    element of :attr:`inputs` can either be a single object as the only argument
+    to a module, or a collection of positional arguments.
+    """
+    assert len(modules) == len(inputs)
+    if kwargs_tup is not None:
+        assert len(modules) == len(kwargs_tup)
+    else:
+        kwargs_tup = ({},) * len(modules)
+    if devices is not None:
+        assert len(modules) == len(devices)
+    else:
+        devices = [None] * len(modules)
+    devices = list(map(lambda x: _get_device_index(x, True), devices))
+    lock = threading.Lock()
+    results = {}
+    grad_enabled = torch.is_grad_enabled()
+
+    def _worker(i, module, input, kwargs, device=None):
+        torch.set_grad_enabled(grad_enabled)
+        if device is None:
+            device = get_a_var(input).get_device()
+        try:
+            with torch.cuda.device(device):
+                # this also avoids accidental slicing of `input` if it is a Tensor
+                if not isinstance(input, (list, tuple)):
+                    input = (input,)
+
+                # ---------------
+                # CHANGE
+                if module.training:
+                    output = module.training_step(*input, **kwargs)
+
+                elif module.testing:
+                    output = module.test_step(*input, **kwargs)
+
+                else:
+                    output = module.validation_step(*input, **kwargs)
+                # ---------------
+
+            with lock:
+                results[i] = output
+        except Exception as e:
+            with lock:
+                results[i] = e
+
+    # make sure each module knows what training state it's in...
+    # fixes weird bug where copies are out of sync
+    root_m = modules[0]
+    for m in modules[1:]:
+        m.training = root_m.training
+        m.testing = root_m.testing
+
+    if len(modules) > 1:
+        threads = [threading.Thread(target=_worker,
+                                    args=(i, module, input, kwargs, device))
+                   for i, (module, input, kwargs, device) in
+                   enumerate(zip(modules, inputs, kwargs_tup, devices))]
+
+        for thread in threads:
+            thread.start()
+        for thread in threads:
+            thread.join()
+    else:
+        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])
+
+    outputs = []
+    for i in range(len(inputs)):
+        output = results[i]
+        if isinstance(output, Exception):
+            raise output
+        outputs.append(output)
+    return outputs
+
+
+def _find_tensors(obj):  # pragma: no cover
+    r"""
+    Recursively find all tensors contained in the specified object.
+    """
+    if isinstance(obj, torch.Tensor):
+        return [obj]
+    if isinstance(obj, (list, tuple)):
+        return itertools.chain(*map(_find_tensors, obj))
+    if isinstance(obj, dict):
+        return itertools.chain(*map(_find_tensors, obj.values()))
+    return []
+
+
+class DDP(DistributedDataParallel):
+    """
+    Override the forward call in lightning so it goes to training and validation step respectively
+    """
+
+    def parallel_apply(self, replicas, inputs, kwargs):
+        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
+
+    def forward(self, *inputs, **kwargs):  # pragma: no cover
+        self._sync_params()
+        if self.device_ids:
+            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+            if len(self.device_ids) == 1:
+                # --------------
+                # LIGHTNING MOD
+                # --------------
+                # normal
+                # output = self.module(*inputs[0], **kwargs[0])
+                # lightning
+                if self.module.training:
+                    output = self.module.training_step(*inputs[0], **kwargs[0])
+                elif self.module.testing:
+                    output = self.module.test_step(*inputs[0], **kwargs[0])
+                else:
+                    output = self.module.validation_step(*inputs[0], **kwargs[0])
+            else:
+                outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
+                output = self.gather(outputs, self.output_device)
+        else:
+            # normal
+            output = self.module(*inputs, **kwargs)
+
+        if torch.is_grad_enabled():
+            # We'll return the output object verbatim since it is a freeform
+            # object. We need to find any tensors in this object, though,
+            # because we need to figure out which parameters were used during
+            # this forward pass, to ensure we short circuit reduction for any
+            # unused parameters. Only if `find_unused_parameters` is set.
+            if self.find_unused_parameters:
+                self.reducer.prepare_for_backward(list(_find_tensors(output)))
+            else:
+                self.reducer.prepare_for_backward([])
+        return output
+
+
+class DP(DataParallel):
+    """
+    Override the forward call in lightning so it goes to training and validation step respectively
+    """
+
+    def forward(self, *inputs, **kwargs):
+        if not self.device_ids:
+            return self.module(*inputs, **kwargs)
+
+        for t in itertools.chain(self.module.parameters(), self.module.buffers()):
+            if t.device != self.src_device_obj:
+                raise RuntimeError("module must have its parameters and buffers "
+                                   "on device {} (device_ids[0]) but found one of "
+                                   "them on device: {}".format(self.src_device_obj, t.device))
+
+        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+        if len(self.device_ids) == 1:
+            # lightning
+            if self.module.training:
+                return self.module.training_step(*inputs[0], **kwargs[0])
+            elif self.module.testing:
+                return self.module.test_step(*inputs[0], **kwargs[0])
+            else:
+                return self.module.validation_step(*inputs[0], **kwargs[0])
+
+        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
+        outputs = self.parallel_apply(replicas, inputs, kwargs)
+        return self.gather(outputs, self.output_device)
+
+    def parallel_apply(self, replicas, inputs, kwargs):
+        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
+
+
+class GradientAccumulationScheduler:
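+    """Per-epoch schedule for `trainer.accumulate_grad_batches`.
+
+    Illustrative schedule (epochs are 1-indexed; values are made up):
+        GradientAccumulationScheduler({1: 1, 5: 2, 10: 4})
+        # accumulate 1 batch before epoch 5, 2 from epoch 5, 4 from epoch 10
+    """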
+    def __init__(self, scheduling: dict):
+        if scheduling == {}:  # empty dict error
+            raise TypeError("Empty dict cannot be interpreted correct")
+
+        for key in scheduling.keys():
+            if not isinstance(key, int) or not isinstance(scheduling[key], int):
+                raise TypeError("All epoches and accumulation factor must be integers")
+
+        minimal_epoch = min(scheduling.keys())
+        if minimal_epoch < 1:
+            msg = f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct"
+            raise IndexError(msg)
+        elif minimal_epoch != 1:  # the user didn't define an accumulation factor for epoch 1
+            scheduling.update({1: 1})
+
+        self.scheduling = scheduling
+        self.epochs = sorted(scheduling.keys())
+
+    def on_epoch_begin(self, epoch, trainer):
+        epoch += 1  # indexing epochs from 1
+        for i in reversed(range(len(self.epochs))):
+            if epoch >= self.epochs[i]:
+                trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
+                break
+
+
+class LatestModelCheckpoint(ModelCheckpoint):
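+    """Checkpoint callback that keeps the `num_ckpt_keep` most recent
+    `{prefix}_ckpt_steps_*.ckpt` files and, when `save_best` is set, also saves
+    the best `monitor` value so far to `{prefix}_ckpt_best.pt`.
+    """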
+    def __init__(self, filepath, monitor='val_loss', verbose=0, num_ckpt_keep=5,
+                 save_weights_only=False, mode='auto', period=1, prefix='model', save_best=True):
+        super(ModelCheckpoint, self).__init__()
+        self.monitor = monitor
+        self.verbose = verbose
+        self.filepath = filepath
+        os.makedirs(filepath, exist_ok=True)
+        self.num_ckpt_keep = num_ckpt_keep
+        self.save_best = save_best
+        self.save_weights_only = save_weights_only
+        self.period = period
+        self.epochs_since_last_check = 0
+        self.prefix = prefix
+        self.best_k_models = {}
+        # {filename: monitor}
+        self.kth_best_model = ''
+        self.save_top_k = 1
+        self.task = None
+        if mode == 'min':
+            self.monitor_op = np.less
+            self.best = np.inf
+            self.mode = 'min'
+        elif mode == 'max':
+            self.monitor_op = np.greater
+            self.best = -np.inf
+            self.mode = 'max'
+        else:
+            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
+                self.monitor_op = np.greater
+                self.best = -np.inf
+                self.mode = 'max'
+            else:
+                self.monitor_op = np.less
+                self.best = np.inf
+                self.mode = 'min'
+        if os.path.exists(f'{self.filepath}/best_valid.npy'):
+            self.best = np.load(f'{self.filepath}/best_valid.npy')[0]
+
+    def get_all_ckpts(self):
+        return sorted(glob.glob(f'{self.filepath}/{self.prefix}_ckpt_steps_*.ckpt'),
+                      key=lambda x: -int(re.findall(r'.*steps_(\d+)\.ckpt', x)[0]))
+
+    def on_epoch_end(self, epoch, logs=None):
+        logs = logs or {}
+        self.epochs_since_last_check += 1
+        best_filepath = f'{self.filepath}/{self.prefix}_ckpt_best.pt'
+        if self.epochs_since_last_check >= self.period:
+            self.epochs_since_last_check = 0
+            filepath = f'{self.filepath}/{self.prefix}_ckpt_steps_{self.task.global_step}.ckpt'
+            if self.verbose > 0:
+                logging.info(f'Epoch {epoch:05d}@{self.task.global_step}: saving model to {filepath}')
+            self._save_model(filepath)
+            for old_ckpt in self.get_all_ckpts()[self.num_ckpt_keep:]:
+                subprocess.check_call(f'rm -rf "{old_ckpt}"', shell=True)
+                if self.verbose > 0:
+                    logging.info(f'Delete ckpt: {os.path.basename(old_ckpt)}')
+            current = logs.get(self.monitor)
+            if current is not None and self.save_best:
+                if self.monitor_op(current, self.best):
+                    self.best = current
+                    if self.verbose > 0:
+                        logging.info(
+                            f'Epoch {epoch:05d}@{self.task.global_step}: {self.monitor} reached'
+                            f' {current:0.5f} (best {self.best:0.5f}), saving model to'
+                            f' {best_filepath} as top 1')
+                    self._save_model(best_filepath)
+                    np.save(f'{self.filepath}/best_valid.npy', [self.best])
+
+
+class BaseTrainer:
+    def __init__(
+            self,
+            logger=True,
+            checkpoint_callback=True,
+            default_save_path=None,
+            gradient_clip_val=0,
+            process_position=0,
+            gpus=-1,
+            log_gpu_memory=None,
+            show_progress_bar=True,
+            track_grad_norm=-1,
+            check_val_every_n_epoch=1,
+            accumulate_grad_batches=1,
+            max_updates=1000,
+            min_epochs=1,
+            val_check_interval=1.0,
+            log_save_interval=100,
+            row_log_interval=10,
+            print_nan_grads=False,
+            weights_summary='full',
+            num_sanity_val_steps=5,
+            resume_from_checkpoint=None,
+    ):
+        self.log_gpu_memory = log_gpu_memory
+        self.gradient_clip_val = gradient_clip_val
+        self.check_val_every_n_epoch = check_val_every_n_epoch
+        self.track_grad_norm = track_grad_norm
+        self.on_gpu = True if (gpus and torch.cuda.is_available()) else False
+        self.process_position = process_position
+        self.weights_summary = weights_summary
+        self.max_updates = max_updates
+        self.min_epochs = min_epochs
+        self.num_sanity_val_steps = num_sanity_val_steps
+        self.print_nan_grads = print_nan_grads
+        self.resume_from_checkpoint = resume_from_checkpoint
+        self.default_save_path = default_save_path
+
+        # training bookkeeping
+        self.total_batch_idx = 0
+        self.running_loss = []
+        self.avg_loss = 0
+        self.batch_idx = 0
+        self.tqdm_metrics = {}
+        self.callback_metrics = {}
+        self.num_val_batches = 0
+        self.num_training_batches = 0
+        self.num_test_batches = 0
+        self.get_train_dataloader = None
+        self.get_test_dataloaders = None
+        self.get_val_dataloaders = None
+        self.is_iterable_train_dataloader = False
+
+        # training state
+        self.model = None
+        self.testing = False
+        self.disable_validation = False
+        self.lr_schedulers = []
+        self.optimizers = None
+        self.global_step = 0
+        self.current_epoch = 0
+        self.total_batches = 0
+
+        # configure checkpoint callback
+        self.checkpoint_callback = checkpoint_callback
+        self.checkpoint_callback.save_function = self.save_checkpoint
+        self.weights_save_path = self.checkpoint_callback.filepath
+
+        # accumulated grads
+        self.configure_accumulated_gradients(accumulate_grad_batches)
+
+        # allow int, string and gpu list
+        self.data_parallel_device_ids = [
+            int(x) for x in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",") if x != '']
+        if len(self.data_parallel_device_ids) == 0:
+            self.root_gpu = None
+            self.on_gpu = False
+        else:
+            self.root_gpu = self.data_parallel_device_ids[0]
+            self.on_gpu = True
+
+        # distributed backend choice
+        self.use_ddp = False
+        self.use_dp = False
+        self.single_gpu = False
+        self.distributed_backend = 'ddp' if self.num_gpus > 0 else 'dp'
+        self.set_distributed_mode(self.distributed_backend)
+
+        self.proc_rank = 0
+        self.world_size = 1
+        self.node_rank = 0
+
+        # can't init progress bar here because starting a new process
+        # means the progress_bar won't survive pickling
+        self.show_progress_bar = show_progress_bar
+
+        # logging
+        self.log_save_interval = log_save_interval
+        self.val_check_interval = val_check_interval
+        self.logger = logger
+        self.logger.rank = 0
+        self.row_log_interval = row_log_interval
+
+    @property
+    def num_gpus(self):
+        gpus = self.data_parallel_device_ids
+        if gpus is None:
+            return 0
+        else:
+            return len(gpus)
+
+    @property
+    def data_parallel(self):
+        return self.use_dp or self.use_ddp
+
+    def get_model(self):
+        is_dp_module = isinstance(self.model, (DDP, DP))
+        model = self.model.module if is_dp_module else self.model
+        return model
+
+    # -----------------------------
+    # MODEL TRAINING
+    # -----------------------------
+    def fit(self, model):
+        if self.use_ddp:
+            mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,))
+        else:
+            model.model = model.build_model()
+            if not self.testing:
+                self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
+            if self.use_dp:
+                model.cuda(self.root_gpu)
+                model = DP(model, device_ids=self.data_parallel_device_ids)
+            elif self.single_gpu:
+                model.cuda(self.root_gpu)
+            self.run_pretrain_routine(model)
+        return 1
+
+    def init_optimizers(self, optimizers):
+
+        # single optimizer
+        if isinstance(optimizers, Optimizer):
+            return [optimizers], []
+
+        # two lists
+        elif len(optimizers) == 2 and isinstance(optimizers[0], list):
+            optimizers, lr_schedulers = optimizers
+            return optimizers, lr_schedulers
+
+        # single list or tuple
+        elif isinstance(optimizers, list) or isinstance(optimizers, tuple):
+            return optimizers, []
+
+    def run_pretrain_routine(self, model):
+        """Sanity check a few things before starting actual training.
+
+        :param model:
+        """
+        ref_model = model
+        if self.data_parallel:
+            ref_model = model.module
+
+        # give model convenience properties
+        ref_model.trainer = self
+
+        # set local properties on the model
+        self.copy_trainer_model_properties(ref_model)
+
+        # link up experiment object
+        if self.logger is not None:
+            ref_model.logger = self.logger
+            self.logger.save()
+
+        if self.use_ddp:
+            dist.barrier()
+
+        # set up checkpoint callback
+        # self.configure_checkpoint_callback()
+
+        # transfer data loaders from model
+        self.get_dataloaders(ref_model)
+
+        # track model now.
+        # if cluster resets state, the model will update with the saved weights
+        self.model = model
+
+        # restore training and model before hpc call
+        self.restore_weights(model)
+
+        # when testing requested only run test and return
+        if self.testing:
+            self.run_evaluation(test=True)
+            return
+
+        # check if we should run validation during training
+        self.disable_validation = self.num_val_batches == 0
+
+        # run tiny validation (if validation defined)
+        # to make sure program won't crash during val
+        ref_model.on_sanity_check_start()
+        ref_model.on_train_start()
+        if not self.disable_validation and self.num_sanity_val_steps > 0:
+            # init progress bars for validation sanity check
+            pbar = tqdm.tqdm(desc='Validation sanity check',
+                             total=self.num_sanity_val_steps * len(self.get_val_dataloaders()),
+                             leave=False, position=2 * self.process_position,
+                             disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch')
+            self.main_progress_bar = pbar
+            # dummy validation progress bar
+            self.val_progress_bar = tqdm.tqdm(disable=True)
+
+            self.evaluate(model, self.get_val_dataloaders(), self.num_sanity_val_steps, self.testing)
+
+            # close progress bars
+            self.main_progress_bar.close()
+            self.val_progress_bar.close()
+
+        # init progress bar
+        pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
+                         disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch',
+                         file=sys.stdout)
+        self.main_progress_bar = pbar
+
+        # clear cache before training
+        if self.on_gpu:
+            torch.cuda.empty_cache()
+
+        # CORE TRAINING LOOP
+        self.train()
+
+    def test(self, model):
+        self.testing = True
+        self.fit(model)
+
+    @property
+    def training_tqdm_dict(self):
+        tqdm_dict = {
+            'step': '{}'.format(self.global_step),
+        }
+        tqdm_dict.update(self.tqdm_metrics)
+        return tqdm_dict
+
+    # --------------------
+    # restore ckpt
+    # --------------------
+    def restore_weights(self, model):
+        """
+        To restore weights we have two cases.
+        First, attempt to restore hpc weights. If successful, don't restore
+        other weights.
+
+        Otherwise, try to restore actual weights
+        :param model:
+        :return:
+        """
+        # clear cache before restore
+        if self.on_gpu:
+            torch.cuda.empty_cache()
+
+        if self.resume_from_checkpoint is not None:
+            self.restore(self.resume_from_checkpoint, on_gpu=self.on_gpu)
+        else:
+            # restore weights if same exp version
+            self.restore_state_if_checkpoint_exists(model)
+
+        # wait for all models to restore weights
+        if self.use_ddp:
+            # wait for all processes to catch up
+            dist.barrier()
+
+        # clear cache after restore
+        if self.on_gpu:
+            torch.cuda.empty_cache()
+
+    def restore_state_if_checkpoint_exists(self, model):
+        did_restore = False
+
+        # do nothing if there's no dir or callback
+        no_ckpt_callback = (self.checkpoint_callback is None) or (not self.checkpoint_callback)
+        if no_ckpt_callback or not os.path.exists(self.checkpoint_callback.filepath):
+            return did_restore
+
+        # restore trainer state and model if there is a weight for this experiment
+        last_steps = -1
+        last_ckpt_name = None
+
+        # find last epoch
+        checkpoints = os.listdir(self.checkpoint_callback.filepath)
+        for name in checkpoints:
+            if '.ckpt' in name and not name.endswith('part'):
+                if 'steps_' in name:
+                    steps = name.split('steps_')[1]
+                    steps = int(re.sub('[^0-9]', '', steps))
+
+                    if steps > last_steps:
+                        last_steps = steps
+                        last_ckpt_name = name
+
+        # restore last checkpoint
+        if last_ckpt_name is not None:
+            last_ckpt_path = os.path.join(self.checkpoint_callback.filepath, last_ckpt_name)
+            self.restore(last_ckpt_path, self.on_gpu)
+            logging.info(f'model and trainer restored from checkpoint: {last_ckpt_path}')
+            did_restore = True
+
+        return did_restore
+
+    def restore(self, checkpoint_path, on_gpu):
+        checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
+        # load model state
+        model = self.get_model()
+
+        # load the state_dict on the model automatically
+        model.load_state_dict(checkpoint['state_dict'], strict=False)
+        if on_gpu:
+            model.cuda(self.root_gpu)
+        # load training state (affects trainer only)
+        self.restore_training_state(checkpoint)
+        model.global_step = self.global_step
+        del checkpoint
+
+        try:
+            if dist.is_initialized() and dist.get_rank() > 0:
+                return
+        except Exception as e:
+            print(e)
+            return
+
+    def restore_training_state(self, checkpoint):
+        """
+        Restore trainer state.
+        Model will get its chance to update
+        :param checkpoint:
+        :return:
+        """
+        if self.checkpoint_callback is not None and self.checkpoint_callback is not False:
+            self.checkpoint_callback.best = checkpoint['checkpoint_callback_best']
+
+        self.global_step = checkpoint['global_step']
+        self.current_epoch = checkpoint['epoch']
+
+        if self.testing:
+            return
+
+        # restore the optimizers
+        optimizer_states = checkpoint['optimizer_states']
+        for optimizer, opt_state in zip(self.optimizers, optimizer_states):
+            if optimizer is None:
+                return
+            optimizer.load_state_dict(opt_state)
+
+            # move optimizer to GPU 1 weight at a time
+            # avoids OOM
+            if self.root_gpu is not None:
+                for state in optimizer.state.values():
+                    for k, v in state.items():
+                        if isinstance(v, torch.Tensor):
+                            state[k] = v.cuda(self.root_gpu)
+
+        # restore the lr schedulers
+        lr_schedulers = checkpoint['lr_schedulers']
+        for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers):
+            scheduler.load_state_dict(lrs_state)
+
+    # --------------------
+    # MODEL SAVE CHECKPOINT
+    # --------------------
+    def _atomic_save(self, checkpoint, filepath):
+        """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.
+
+        This will create a temporary checkpoint with a suffix of ``.part``, then atomically rename it to the final
+        location once saving is finished.
+
+        Args:
+            checkpoint (object): The object to save.
+                Built to be used with the ``dump_checkpoint`` method, but can deal with anything which ``torch.save``
+                accepts.
+            filepath (str|pathlib.Path): The path to which the checkpoint will be saved.
+                This points to the file that the checkpoint will be stored in.
+        """
+        tmp_path = str(filepath) + ".part"
+        torch.save(checkpoint, tmp_path)
+        os.replace(tmp_path, filepath)
+
+    def save_checkpoint(self, filepath):
+        checkpoint = self.dump_checkpoint()
+        self._atomic_save(checkpoint, filepath)
+
+    def dump_checkpoint(self):
+
+        checkpoint = {
+            'epoch': self.current_epoch,
+            'global_step': self.global_step
+        }
+
+        if self.checkpoint_callback is not None and self.checkpoint_callback is not False:
+            checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best
+
+        # save optimizers
+        optimizer_states = []
+        for i, optimizer in enumerate(self.optimizers):
+            if optimizer is not None:
+                optimizer_states.append(optimizer.state_dict())
+
+        checkpoint['optimizer_states'] = optimizer_states
+
+        # save lr schedulers
+        lr_schedulers = []
+        for i, scheduler in enumerate(self.lr_schedulers):
+            lr_schedulers.append(scheduler.state_dict())
+
+        checkpoint['lr_schedulers'] = lr_schedulers
+
+        # add the hparams and state_dict from the model
+        model = self.get_model()
+        checkpoint['state_dict'] = model.state_dict()
+        # give the model a chance to add a few things
+        model.on_save_checkpoint(checkpoint)
+
+        return checkpoint
+
+    def copy_trainer_model_properties(self, model):
+        if isinstance(model, DP):
+            ref_model = model.module
+        elif isinstance(model, DDP):
+            ref_model = model.module
+        else:
+            ref_model = model
+
+        for m in [model, ref_model]:
+            m.trainer = self
+            m.on_gpu = self.on_gpu
+            m.use_dp = self.use_dp
+            m.use_ddp = self.use_ddp
+            m.testing = self.testing
+            m.single_gpu = self.single_gpu
+
+    def transfer_batch_to_gpu(self, batch, gpu_id):
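+        """Recursively move a batch (tensor / list / tuple / dict of tensors) to `gpu_id`."""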
+        # base case: object can be directly moved using `cuda` or `to`
+        if callable(getattr(batch, 'cuda', None)):
+            return batch.cuda(gpu_id, non_blocking=True)
+
+        elif callable(getattr(batch, 'to', None)):
+            return batch.to(torch.device('cuda', gpu_id), non_blocking=True)
+
+        # when list
+        elif isinstance(batch, list):
+            for i, x in enumerate(batch):
+                batch[i] = self.transfer_batch_to_gpu(x, gpu_id)
+            return batch
+
+        # when tuple
+        elif isinstance(batch, tuple):
+            batch = list(batch)
+            for i, x in enumerate(batch):
+                batch[i] = self.transfer_batch_to_gpu(x, gpu_id)
+            return tuple(batch)
+
+        # when dict
+        elif isinstance(batch, dict):
+            for k, v in batch.items():
+                batch[k] = self.transfer_batch_to_gpu(v, gpu_id)
+
+            return batch
+
+        # nothing matches, return the value as is without transform
+        return batch
+
+    def set_distributed_mode(self, distributed_backend):
+        # skip for CPU
+        if self.num_gpus == 0:
+            return
+
+        # single GPU case
+        # in single gpu case we allow ddp so we can train on multiple
+        # nodes, 1 gpu per node
+        elif self.num_gpus == 1:
+            self.single_gpu = True
+            self.use_dp = False
+            self.use_ddp = False
+            self.root_gpu = 0
+            self.data_parallel_device_ids = [0]
+        else:
+            if distributed_backend is not None:
+                self.use_dp = distributed_backend == 'dp'
+                self.use_ddp = distributed_backend == 'ddp'
+            elif distributed_backend is None:
+                self.use_dp = True
+                self.use_ddp = False
+
+        logging.info(f'gpu available: {torch.cuda.is_available()}, used: {self.on_gpu}')
+
+    def ddp_train(self, gpu_idx, model):
+        """
+        Entry point into a DP thread
+        :param gpu_idx:
+        :param model:
+        :param cluster_obj:
+        :return:
+        """
+        # otherwise default to node rank 0
+        self.node_rank = 0
+
+        # show progressbar only on progress_rank 0
+        self.show_progress_bar = self.show_progress_bar and self.node_rank == 0 and gpu_idx == 0
+
+        # determine which process we are and world size
+        if self.use_ddp:
+            self.proc_rank = self.node_rank * self.num_gpus + gpu_idx
+            self.world_size = self.num_gpus
+
+        # let the exp know the rank to avoid overwriting logs
+        if self.logger is not None:
+            self.logger.rank = self.proc_rank
+
+        # set up server using proc 0's ip address
+        # try to init for 20 times at max in case ports are taken
+        # where to store ip_table
+        model.trainer = self
+        model.init_ddp_connection(self.proc_rank, self.world_size)
+
+        # CHOOSE OPTIMIZER
+        # allow for lr schedulers as well
+        model.model = model.build_model()
+        if not self.testing:
+            self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
+
+        # MODEL
+        # copy model to each gpu
+        if self.distributed_backend == 'ddp':
+            torch.cuda.set_device(gpu_idx)
+        model.cuda(gpu_idx)
+
+        # set model properties before going into wrapper
+        self.copy_trainer_model_properties(model)
+
+        # override root GPU
+        self.root_gpu = gpu_idx
+
+        if self.distributed_backend == 'ddp':
+            device_ids = [gpu_idx]
+        else:
+            device_ids = None
+
+        # allow user to configure ddp
+        model = model.configure_ddp(model, device_ids)
+
+        # continue training routine
+        self.run_pretrain_routine(model)
+
+    def resolve_root_node_address(self, root_node):
+        if '[' in root_node:
+            name = root_node.split('[')[0]
+            number = root_node.split(',')[0]
+            if '-' in number:
+                number = number.split('-')[0]
+
+            number = re.sub('[^0-9]', '', number)
+            root_node = name + number
+
+        return root_node
+
+    def log_metrics(self, metrics, grad_norm_dic, step=None):
+        """Logs the metric dict passed in.
+
+        :param metrics:
+        :param grad_norm_dic:
+        """
+        # added metrics by Lightning for convenience
+        metrics['epoch'] = self.current_epoch
+
+        # add norms
+        metrics.update(grad_norm_dic)
+
+        # turn all tensors to scalars
+        scalar_metrics = self.metrics_to_scalars(metrics)
+
+        step = step if step is not None else self.global_step
+        # log actual metrics
+        if self.proc_rank == 0 and self.logger is not None:
+            self.logger.log_metrics(scalar_metrics, step=step)
+            self.logger.save()
+
+    def add_tqdm_metrics(self, metrics):
+        for k, v in metrics.items():
+            if type(v) is torch.Tensor:
+                v = v.item()
+
+            self.tqdm_metrics[k] = v
+
+    def metrics_to_scalars(self, metrics):
+        new_metrics = {}
+        for k, v in metrics.items():
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+
+            if type(v) is dict:
+                v = self.metrics_to_scalars(v)
+
+            new_metrics[k] = v
+
+        return new_metrics
+
+    def process_output(self, output, train=False):
+        """Reduces output according to the training mode.
+
+        Separates loss from logging and tqdm metrics
+        :param output:
+        :return:
+        """
+        # ---------------
+        # EXTRACT CALLBACK KEYS
+        # ---------------
+        # all keys not progress_bar or log are candidates for callbacks
+        callback_metrics = {}
+        for k, v in output.items():
+            if k not in ['progress_bar', 'log', 'hiddens']:
+                callback_metrics[k] = v
+
+        if train and self.use_dp:
+            num_gpus = self.num_gpus
+            callback_metrics = self.reduce_distributed_output(callback_metrics, num_gpus)
+
+        for k, v in callback_metrics.items():
+            if isinstance(v, torch.Tensor):
+                callback_metrics[k] = v.item()
+
+        # ---------------
+        # EXTRACT PROGRESS BAR KEYS
+        # ---------------
+        try:
+            progress_output = output['progress_bar']
+
+            # reduce progress metrics for tqdm when using dp
+            if train and self.use_dp:
+                num_gpus = self.num_gpus
+                progress_output = self.reduce_distributed_output(progress_output, num_gpus)
+
+            progress_bar_metrics = progress_output
+        except Exception:
+            progress_bar_metrics = {}
+
+        # ---------------
+        # EXTRACT LOGGING KEYS
+        # ---------------
+        # extract metrics to log to experiment
+        try:
+            log_output = output['log']
+
+            # reduce progress metrics for tqdm when using dp
+            if train and self.use_dp:
+                num_gpus = self.num_gpus
+                log_output = self.reduce_distributed_output(log_output, num_gpus)
+
+            log_metrics = log_output
+        except Exception:
+            log_metrics = {}
+
+        # ---------------
+        # EXTRACT LOSS
+        # ---------------
+        # if output dict doesn't have the keyword loss
+        # then assume the output=loss if scalar
+        loss = None
+        if train:
+            try:
+                loss = output['loss']
+            except Exception:
+                if type(output) is torch.Tensor:
+                    loss = output
+                else:
+                    raise RuntimeError(
+                        'No `loss` value in the dictionary returned from `model.training_step()`.'
+                    )
+
+            # when using dp need to reduce the loss
+            if self.use_dp:
+                loss = self.reduce_distributed_output(loss, self.num_gpus)
+
+        # ---------------
+        # EXTRACT HIDDEN
+        # ---------------
+        hiddens = output.get('hiddens')
+
+        # use every metric passed in as a candidate for callback
+        callback_metrics.update(progress_bar_metrics)
+        callback_metrics.update(log_metrics)
+
+        # convert tensors to numpy
+        for k, v in callback_metrics.items():
+            if isinstance(v, torch.Tensor):
+                callback_metrics[k] = v.item()
+
+        return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens
+
+    def reduce_distributed_output(self, output, num_gpus):
+        if num_gpus <= 1:
+            return output
+
+        # when using DP, we get one output per gpu
+        # average outputs and return
+        if type(output) is torch.Tensor:
+            return output.mean()
+
+        for k, v in output.items():
+            # recurse on nested dicts
+            if isinstance(output[k], dict):
+                output[k] = self.reduce_distributed_output(output[k], num_gpus)
+
+            # do nothing when there's a scalar
+            elif isinstance(output[k], torch.Tensor) and output[k].dim() == 0:
+                pass
+
+            # reduce only metrics that have the same number of gpus
+            elif output[k].size(0) == num_gpus:
+                reduced = torch.mean(output[k])
+                output[k] = reduced
+        return output
+
+    def clip_gradients(self):
+        if self.gradient_clip_val > 0:
+            model = self.get_model()
+            torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip_val)
+
+    def print_nan_gradients(self):
+        model = self.get_model()
+        for param in model.parameters():
+            if (param.grad is not None) and torch.isnan(param.grad.float()).any():
+                logging.info(param, param.grad)
+
+    def configure_accumulated_gradients(self, accumulate_grad_batches):
+        self.accumulate_grad_batches = None
+
+        if isinstance(accumulate_grad_batches, dict):
+            self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
+        elif isinstance(accumulate_grad_batches, int):
+            schedule = {1: accumulate_grad_batches}
+            self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
+        else:
+            raise TypeError("Gradient accumulation supports only int and dict types")
+
+    def get_dataloaders(self, model):
+        if not self.testing:
+            self.init_train_dataloader(model)
+            self.init_val_dataloader(model)
+        else:
+            self.init_test_dataloader(model)
+
+        if self.use_ddp:
+            dist.barrier()
+            if not self.testing:
+                self.get_train_dataloader()
+                self.get_val_dataloaders()
+            else:
+                self.get_test_dataloaders()
+
+    def init_train_dataloader(self, model):
+        self.fisrt_epoch = True
+        self.get_train_dataloader = model.train_dataloader
+        if isinstance(self.get_train_dataloader(), torch.utils.data.DataLoader):
+            self.num_training_batches = len(self.get_train_dataloader())
+            self.num_training_batches = int(self.num_training_batches)
+        else:
+            self.num_training_batches = float('inf')
+            self.is_iterable_train_dataloader = True
+        if isinstance(self.val_check_interval, int):
+            self.val_check_batch = self.val_check_interval
+        else:
+            self._percent_range_check('val_check_interval')
+            self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
+            self.val_check_batch = max(1, self.val_check_batch)
+
+    def init_val_dataloader(self, model):
+        self.get_val_dataloaders = model.val_dataloader
+        self.num_val_batches = 0
+        if self.get_val_dataloaders() is not None:
+            if isinstance(self.get_val_dataloaders()[0], torch.utils.data.DataLoader):
+                self.num_val_batches = sum(len(dataloader) for dataloader in self.get_val_dataloaders())
+                self.num_val_batches = int(self.num_val_batches)
+            else:
+                self.num_val_batches = float('inf')
+
+    def init_test_dataloader(self, model):
+        self.get_test_dataloaders = model.test_dataloader
+        if self.get_test_dataloaders() is not None:
+            if isinstance(self.get_test_dataloaders()[0], torch.utils.data.DataLoader):
+                self.num_test_batches = sum(len(dataloader) for dataloader in self.get_test_dataloaders())
+                self.num_test_batches = int(self.num_test_batches)
+            else:
+                self.num_test_batches = float('inf')
+
+    def evaluate(self, model, dataloaders, max_batches, test=False):
+        """Run evaluation code.
+
+        :param model: PT model
+        :param dataloaders: list of PT dataloaders
+        :param max_batches: Scalar
+        :param test: boolean
+        :return:
+        """
+        # enable eval mode
+        model.zero_grad()
+        model.eval()
+
+        # copy properties for forward overrides
+        self.copy_trainer_model_properties(model)
+
+        # disable gradients to save memory
+        torch.set_grad_enabled(False)
+
+        if test:
+            self.get_model().test_start()
+        # bookkeeping
+        outputs = []
+
+        # run evaluation over each dataloader
+        for dataloader_idx, dataloader in enumerate(dataloaders):
+            dl_outputs = []
+            for batch_idx, batch in enumerate(dataloader):
+
+                if batch is None:  # pragma: no cover
+                    continue
+
+                # stop short when on fast_dev_run (sets max_batch=1)
+                if batch_idx >= max_batches:
+                    break
+
+                # -----------------
+                # RUN EVALUATION STEP
+                # -----------------
+                output = self.evaluation_forward(model,
+                                                 batch,
+                                                 batch_idx,
+                                                 dataloader_idx,
+                                                 test)
+
+                # track outputs for collation
+                dl_outputs.append(output)
+
+                # batch done
+                if test:
+                    self.test_progress_bar.update(1)
+                else:
+                    self.val_progress_bar.update(1)
+            outputs.append(dl_outputs)
+
+        # with a single dataloader don't pass an array
+        if len(dataloaders) == 1:
+            outputs = outputs[0]
+
+        # give the model a chance to aggregate the outputs (test_end / validation_end)
+        model = self.get_model()
+        if test:
+            eval_results_ = model.test_end(outputs)
+        else:
+            eval_results_ = model.validation_end(outputs)
+        eval_results = eval_results_
+
+        # enable train mode again
+        model.train()
+
+        # re-enable gradients for training
+        torch.set_grad_enabled(True)
+
+        return eval_results
+
+    def run_evaluation(self, test=False):
+        # when testing make sure user defined a test step
+        model = self.get_model()
+        model.on_pre_performance_check()
+
+        # select dataloaders
+        if test:
+            dataloaders = self.get_test_dataloaders()
+            max_batches = self.num_test_batches
+        else:
+            # val
+            dataloaders = self.get_val_dataloaders()
+            max_batches = self.num_val_batches
+
+        # init validation or test progress bar
+        # main progress bar will already be closed when testing so initial position is free
+        position = 2 * self.process_position + (not test)
+        desc = 'Testing' if test else 'Validating'
+        pbar = tqdm.tqdm(desc=desc, total=max_batches, leave=test, position=position,
+                         disable=not self.show_progress_bar, dynamic_ncols=True,
+                         unit='batch', file=sys.stdout)
+        setattr(self, f'{"test" if test else "val"}_progress_bar', pbar)
+
+        # run evaluation
+        eval_results = self.evaluate(self.model,
+                                     dataloaders,
+                                     max_batches,
+                                     test)
+        if eval_results is not None:
+            _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(
+                eval_results)
+
+            # add metrics to prog bar
+            self.add_tqdm_metrics(prog_bar_metrics)
+
+            # log metrics
+            self.log_metrics(log_metrics, {})
+
+            # track metrics for callbacks
+            self.callback_metrics.update(callback_metrics)
+
+        # hook
+        model.on_post_performance_check()
+
+        # add model specific metrics
+        tqdm_metrics = self.training_tqdm_dict
+        if not test:
+            self.main_progress_bar.set_postfix(**tqdm_metrics)
+
+        # close progress bar
+        if test:
+            self.test_progress_bar.close()
+        else:
+            self.val_progress_bar.close()
+
+        # model checkpointing
+        if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
+            self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch,
+                                                  logs=self.callback_metrics)
+
+    def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False):
+        # make dataloader_idx arg in validation_step optional
+        args = [batch, batch_idx]
+
+        if test and len(self.get_test_dataloaders()) > 1:
+            args.append(dataloader_idx)
+
+        elif not test and len(self.get_val_dataloaders()) > 1:
+            args.append(dataloader_idx)
+
+        # handle DP, DDP forward
+        if self.use_ddp or self.use_dp:
+            output = model(*args)
+            return output
+
+        # single GPU
+        if self.single_gpu:
+            # for single GPU put inputs on gpu manually
+            root_gpu = 0
+            if isinstance(self.data_parallel_device_ids, list):
+                root_gpu = self.data_parallel_device_ids[0]
+            batch = self.transfer_batch_to_gpu(batch, root_gpu)
+            args[0] = batch
+
+        # CPU
+        if test:
+            output = model.test_step(*args)
+        else:
+            output = model.validation_step(*args)
+
+        return output
+
+    def train(self):
+        model = self.get_model()
+        # run all epochs
+        for epoch in range(self.current_epoch, 1000000):
+            # set the epoch on the distributed sampler (so shuffling differs each epoch)
+            if self.use_ddp and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
+                self.get_train_dataloader().sampler.set_epoch(epoch)
+
+            # get model
+            model = self.get_model()
+
+            # update training progress in trainer and model
+            model.current_epoch = epoch
+            self.current_epoch = epoch
+
+            total_val_batches = 0
+            if not self.disable_validation:
+                # val can be checked multiple times in epoch
+                is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
+                val_checks_per_epoch = self.num_training_batches // self.val_check_batch
+                val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
+                total_val_batches = self.num_val_batches * val_checks_per_epoch
+
+            # total batches includes multiple val checks
+            self.total_batches = self.num_training_batches + total_val_batches
+            self.batch_loss_value = 0  # accumulated grads
+
+            if self.is_iterable_train_dataloader:
+                # for iterable train loader, the progress bar never ends
+                num_iterations = None
+            else:
+                num_iterations = self.total_batches
+
+            # reset progress bar
+            # .reset() doesn't work on disabled progress bar so we should check
+            desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else ''
+            self.main_progress_bar.set_description(desc)
+
+            # update the gradient accumulation factor according to the accumulation_scheduler
+            self.accumulation_scheduler.on_epoch_begin(epoch, self)
+
+            # -----------------
+            # RUN TRAINING EPOCH
+            # -----------------
+            self.run_training_epoch()
+
+            # update LR schedulers
+            if self.lr_schedulers is not None:
+                for lr_scheduler in self.lr_schedulers:
+                    lr_scheduler.step(epoch=self.current_epoch)
+
+        self.main_progress_bar.close()
+
+        model.on_train_end()
+
+        if self.logger is not None:
+            self.logger.finalize("success")
+
+    def run_training_epoch(self):
+        # before epoch hook
+        if self.is_function_implemented('on_epoch_start'):
+            model = self.get_model()
+            model.on_epoch_start()
+
+        # run epoch
+        for batch_idx, batch in enumerate(self.get_train_dataloader()):
+            # stop epoch if we limited the number of training batches
+            if batch_idx >= self.num_training_batches:
+                break
+
+            self.batch_idx = batch_idx
+
+            model = self.get_model()
+            model.global_step = self.global_step
+
+            # ---------------
+            # RUN TRAIN STEP
+            # ---------------
+            output = self.run_training_batch(batch, batch_idx)
+            batch_result, grad_norm_dic, batch_step_metrics = output
+
+            # when training_step returns -1, we end the epoch early
+            early_stop_epoch = batch_result == -1
+
+            # ---------------
+            # RUN VAL STEP
+            # ---------------
+            should_check_val = (not self.disable_validation
+                                and self.global_step % self.val_check_batch == 0
+                                and not self.first_epoch)
+            self.first_epoch = False
+
+            if should_check_val:
+                self.run_evaluation(test=self.testing)
+
+            # when logs should be saved
+            should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
+            if should_save_log:
+                if self.proc_rank == 0 and self.logger is not None:
+                    self.logger.save()
+
+            # when metrics should be logged
+            should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
+            if should_log_metrics:
+                # logs user requested information to logger
+                self.log_metrics(batch_step_metrics, grad_norm_dic)
+
+            self.global_step += 1
+            self.total_batch_idx += 1
+
+            # end epoch early
+            # stop when the early-stop flag is set or we've gone past the
+            # requested number of updates (max_updates)
+            if early_stop_epoch:
+                break
+            if self.global_step > self.max_updates:
+                print("| Training end..")
+                exit()
+
+        # epoch end hook
+        if self.is_function_implemented('on_epoch_end'):
+            model = self.get_model()
+            model.on_epoch_end()
+
+    def run_training_batch(self, batch, batch_idx):
+        # track grad norms
+        grad_norm_dic = {}
+
+        # track all metrics for callbacks
+        all_callback_metrics = []
+
+        # track metrics to log
+        all_log_metrics = []
+
+        if batch is None:
+            return 0, grad_norm_dic, {}
+
+        # hook
+        if self.is_function_implemented('on_batch_start'):
+            model_ref = self.get_model()
+            response = model_ref.on_batch_start(batch)
+
+            if response == -1:
+                return -1, grad_norm_dic, {}
+
+        splits = [batch]
+        self.hiddens = None
+        for split_idx, split_batch in enumerate(splits):
+            self.split_idx = split_idx
+
+            # call training_step once per optimizer
+            for opt_idx, optimizer in enumerate(self.optimizers):
+                if optimizer is None:
+                    continue
+                # make sure only the gradients of the current optimizer's parameters are calculated
+                # in the training step to prevent dangling gradients in multiple-optimizer setups.
+                if len(self.optimizers) > 1:
+                    for param in self.get_model().parameters():
+                        param.requires_grad = False
+                    for group in optimizer.param_groups:
+                        for param in group['params']:
+                            param.requires_grad = True
+
+                # wrap the forward step in a closure so second order methods work
+                def optimizer_closure():
+                    # forward pass
+                    output = self.training_forward(
+                        split_batch, batch_idx, opt_idx, self.hiddens)
+
+                    closure_loss = output[0]
+                    progress_bar_metrics = output[1]
+                    log_metrics = output[2]
+                    callback_metrics = output[3]
+                    self.hiddens = output[4]
+                    if closure_loss is None:
+                        return None
+
+                    # accumulate loss
+                    # (if accumulate_grad_batches = 1 no effect)
+                    closure_loss = closure_loss / self.accumulate_grad_batches
+
+                    # backward pass
+                    model_ref = self.get_model()
+                    if closure_loss.requires_grad:
+                        model_ref.backward(closure_loss, optimizer)
+
+                    # track metrics for callbacks
+                    all_callback_metrics.append(callback_metrics)
+
+                    # track progress bar metrics
+                    self.add_tqdm_metrics(progress_bar_metrics)
+                    all_log_metrics.append(log_metrics)
+
+                    # call the on_after_backward hook when implemented
+                    if self.is_function_implemented('on_after_backward'):
+                        model_ref = self.get_model()
+                        model_ref.on_after_backward()
+
+                    return closure_loss
+
+                # calculate loss
+                loss = optimizer_closure()
+                if loss is None:
+                    continue
+
+                # nan grads
+                if self.print_nan_grads:
+                    self.print_nan_gradients()
+
+                # track total loss for logging (avoid mem leaks)
+                self.batch_loss_value += loss.item()
+
+                # gradient update with accumulated gradients
+                if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
+
+                    # track gradient norms when requested
+                    if batch_idx % self.row_log_interval == 0:
+                        if self.track_grad_norm > 0:
+                            model = self.get_model()
+                            grad_norm_dic = model.grad_norm(
+                                self.track_grad_norm)
+
+                    # clip gradients
+                    self.clip_gradients()
+
+                    # calls .step(), .zero_grad()
+                    # override function to modify this behavior
+                    model = self.get_model()
+                    model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx)
+
+                    # calculate running loss for display
+                    self.running_loss.append(self.batch_loss_value)
+                    self.batch_loss_value = 0
+                    self.avg_loss = np.mean(self.running_loss[-100:])
+
+        # activate batch end hook
+        if self.is_function_implemented('on_batch_end'):
+            model = self.get_model()
+            model.on_batch_end()
+
+        # update progress bar
+        self.main_progress_bar.update(1)
+        self.main_progress_bar.set_postfix(**self.training_tqdm_dict)
+
+        # collapse all metrics into one dict
+        all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()}
+
+        # track all metrics for callbacks
+        self.callback_metrics.update({k: v for d in all_callback_metrics for k, v in d.items()})
+
+        return 0, grad_norm_dic, all_log_metrics
+
+    def training_forward(self, batch, batch_idx, opt_idx, hiddens):
+        """Handle the forward pass for each training case (DDP/DP, single GPU, CPU).
+
+        :param batch: current training batch
+        :param batch_idx: index of the batch within the epoch
+        :param opt_idx: index of the optimizer being stepped
+        :param hiddens: hidden states carried over between steps (may be None)
+        :return: processed output of the training step
+        """
+        # ---------------
+        # FORWARD
+        # ---------------
+        # enable not needing to add opt_idx to training_step
+        args = [batch, batch_idx, opt_idx]
+
+        # distributed forward
+        if self.use_ddp or self.use_dp:
+            output = self.model(*args)
+        # single GPU forward
+        elif self.single_gpu:
+            gpu_id = 0
+            if isinstance(self.data_parallel_device_ids, list):
+                gpu_id = self.data_parallel_device_ids[0]
+            batch = self.transfer_batch_to_gpu(copy.copy(batch), gpu_id)
+            args[0] = batch
+            output = self.model.training_step(*args)
+        # CPU forward
+        else:
+            output = self.model.training_step(*args)
+
+        # allow any mode to define training_end
+        model_ref = self.get_model()
+        output_ = model_ref.training_end(output)
+        if output_ is not None:
+            output = output_
+
+        # format and reduce outputs accordingly
+        output = self.process_output(output, train=True)
+
+        return output
+
+    # ---------------
+    # Utils
+    # ---------------
+    def is_function_implemented(self, f_name):
+        model = self.get_model()
+        f_op = getattr(model, f_name, None)
+        return callable(f_op)
+
+    def _percent_range_check(self, name):
+        value = getattr(self, name)
+        msg = f"`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}."
+        if name == "val_check_interval":
+            msg += " If you want to disable validation set `val_percent_check` to 0.0 instead."
+
+        if not 0. <= value <= 1.:
+            raise ValueError(msg)
diff --git a/utils/plot.py b/utils/plot.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdca62a8cd80869c707890cd9febd39966cd3658
--- /dev/null
+++ b/utils/plot.py
@@ -0,0 +1,56 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+LINE_COLORS = ['w', 'r', 'y', 'cyan', 'm', 'b', 'lime']
+
+
+def spec_to_figure(spec, vmin=None, vmax=None):
+    if isinstance(spec, torch.Tensor):
+        spec = spec.cpu().numpy()
+    fig = plt.figure(figsize=(12, 6))
+    plt.pcolor(spec.T, vmin=vmin, vmax=vmax)
+    return fig
+
+
+def spec_f0_to_figure(spec, f0s, figsize=None):
+    max_y = spec.shape[1]
+    if isinstance(spec, torch.Tensor):
+        spec = spec.detach().cpu().numpy()
+        f0s = {k: f0.detach().cpu().numpy() for k, f0 in f0s.items()}
+    f0s = {k: f0 / 10 for k, f0 in f0s.items()}
+    fig = plt.figure(figsize=(12, 6) if figsize is None else figsize)
+    plt.pcolor(spec.T)
+    for i, (k, f0) in enumerate(f0s.items()):
+        plt.plot(f0.clip(0, max_y), label=k, c=LINE_COLORS[i], linewidth=1, alpha=0.8)
+    plt.legend()
+    return fig
+
+
+def dur_to_figure(dur_gt, dur_pred, txt):
+    dur_gt = dur_gt.long().cpu().numpy()
+    dur_pred = dur_pred.long().cpu().numpy()
+    dur_gt = np.cumsum(dur_gt)
+    dur_pred = np.cumsum(dur_pred)
+    fig = plt.figure(figsize=(12, 6))
+    for i in range(len(dur_gt)):
+        shift = (i % 8) + 1
+        plt.text(dur_gt[i], shift, txt[i])
+        plt.text(dur_pred[i], 10 + shift, txt[i])
+        plt.vlines(dur_gt[i], 0, 10, colors='b')  # blue is gt
+        plt.vlines(dur_pred[i], 10, 20, colors='r')  # red is pred
+    return fig
+
+
+def f0_to_figure(f0_gt, f0_cwt=None, f0_pred=None):
+    fig = plt.figure()
+    f0_gt = f0_gt.cpu().numpy()
+    plt.plot(f0_gt, color='r', label='gt')
+    if f0_cwt is not None:
+        f0_cwt = f0_cwt.cpu().numpy()
+        plt.plot(f0_cwt, color='b', label='cwt')
+    if f0_pred is not None:
+        f0_pred = f0_pred.cpu().numpy()
+        plt.plot(f0_pred, color='green', label='pred')
+    plt.legend()
+    return fig
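+
+
+# Minimal usage sketch (illustrative only; the shapes below are assumptions):
+#   spec = np.random.randn(100, 80)            # [frames, mel bins]
+#   f0 = np.abs(np.random.randn(100)) * 100    # frame-level F0
+#   spec_to_figure(spec)
+#   spec_f0_to_figure(spec, {'f0_gt': f0})
+#   f0_to_figure(torch.from_numpy(f0))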
diff --git a/utils/rnnoise.py b/utils/rnnoise.py
new file mode 100755
index 0000000000000000000000000000000000000000..47f4eb6471918ca8144f217580a71d1720cd8c36
--- /dev/null
+++ b/utils/rnnoise.py
@@ -0,0 +1,48 @@
+# rnnoise.py, requirements: ffmpeg, sox, rnnoise, python
+import os
+import subprocess
+
+INSTALL_STR = """
+RNNoise library not found. Please install RNNoise (https://github.com/xiph/rnnoise) to $REPO/rnnoise:
+sudo apt-get install -y autoconf automake libtool ffmpeg sox
+git clone https://github.com/xiph/rnnoise.git
+rm -rf rnnoise/.git 
+cd rnnoise
+./autogen.sh && ./configure && make
+cd ..
+"""
+
+
+def rnnoise(filename, out_fn=None, verbose=False, out_sample_rate=22050):
+    assert os.path.exists('./rnnoise/examples/rnnoise_demo'), INSTALL_STR
+    if out_fn is None:
+        out_fn = f"{filename[:-4]}.denoised.wav"
+    out_48k_fn = f"{out_fn}.48000.wav"
+    tmp0_fn = f"{out_fn}.0.wav"
+    tmp1_fn = f"{out_fn}.1.wav"
+    tmp2_fn = f"{out_fn}.2.raw"
+    tmp3_fn = f"{out_fn}.3.raw"
+    if verbose:
+        print("Pre-processing audio...")  # wav to pcm raw
+    subprocess.check_call(
+        f'sox "{filename}" -G -r48000 "{tmp0_fn}"', shell=True, stdin=subprocess.PIPE)  # convert to raw
+    subprocess.check_call(
+        f'sox -v 0.95 "{tmp0_fn}" "{tmp1_fn}"', shell=True, stdin=subprocess.PIPE)  # convert to raw
+    subprocess.check_call(
+        f'ffmpeg -y -i "{tmp1_fn}" -loglevel quiet -f s16le -ac 1 -ar 48000 "{tmp2_fn}"',
+        shell=True, stdin=subprocess.PIPE)  # convert to raw
+    if verbose:
+        print("Applying rnnoise algorithm to audio...")  # rnnoise
+    subprocess.check_call(
+        f'./rnnoise/examples/rnnoise_demo "{tmp2_fn}" "{tmp3_fn}"', shell=True)
+
+    if verbose:
+        print("Post-processing audio...")  # pcm raw to wav
+    if filename == out_fn:
+        subprocess.check_call(f'rm -f "{out_fn}"', shell=True)
+    subprocess.check_call(
+        f'sox -t raw -r 48000 -b 16 -e signed-integer -c 1 "{tmp3_fn}" "{out_48k_fn}"', shell=True)
+    subprocess.check_call(f'sox "{out_48k_fn}" -G -r{out_sample_rate} "{out_fn}"', shell=True)
+    subprocess.check_call(f'rm -f "{tmp0_fn}" "{tmp1_fn}" "{tmp2_fn}" "{tmp3_fn}" "{out_48k_fn}"', shell=True)
+    if verbose:
+        print("Audio-filtering completed!")
diff --git a/utils/text_encoder.py b/utils/text_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9e0758abc7b4e1f452481cba9715df08ceab543
--- /dev/null
+++ b/utils/text_encoder.py
@@ -0,0 +1,304 @@
+import re
+import six
+from six.moves import range  # pylint: disable=redefined-builtin
+
+PAD = "<pad>"
+EOS = "<EOS>"
+UNK = "<UNK>"
+SEG = "|"
+RESERVED_TOKENS = [PAD, EOS, UNK]
+NUM_RESERVED_TOKENS = len(RESERVED_TOKENS)
+PAD_ID = RESERVED_TOKENS.index(PAD)  # Normally 0
+EOS_ID = RESERVED_TOKENS.index(EOS)  # Normally 1
+UNK_ID = RESERVED_TOKENS.index(UNK)  # Normally 2
+
+if six.PY2:
+    RESERVED_TOKENS_BYTES = RESERVED_TOKENS
+else:
+    RESERVED_TOKENS_BYTES = [bytes(t, "ascii") for t in RESERVED_TOKENS]
+
+# Regular expression for unescaping token strings.
+# '\u' is converted to '_'
+# '\\' is converted to '\'
+# '\213;' is converted to unichr(213)
+_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")
+_ESCAPE_CHARS = set(u"\\_u;0123456789")
+
+
+def strip_ids(ids, ids_to_strip):
+    """Strip ids_to_strip from the end ids."""
+    ids = list(ids)
+    while ids and ids[-1] in ids_to_strip:
+        ids.pop()
+    return ids
+
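+# Example (illustrative): strip_ids([5, 6, 1, 0, 0], ids_to_strip=[0, 1]) -> [5, 6]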
+
+class TextEncoder(object):
+    """Base class for converting from ints to/from human readable strings."""
+
+    def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS):
+        self._num_reserved_ids = num_reserved_ids
+
+    @property
+    def num_reserved_ids(self):
+        return self._num_reserved_ids
+
+    def encode(self, s):
+        """Transform a human-readable string into a sequence of int ids.
+
+        The ids should be in the range [num_reserved_ids, vocab_size). Ids [0,
+        num_reserved_ids) are reserved.
+
+        EOS is not appended.
+
+        Args:
+        s: human-readable string to be converted.
+
+        Returns:
+        ids: list of integers
+        """
+        return [int(w) + self._num_reserved_ids for w in s.split()]
+
+    def decode(self, ids, strip_extraneous=False):
+        """Transform a sequence of int ids into a human-readable string.
+
+        EOS is not expected in ids.
+
+        Args:
+        ids: list of integers to be converted.
+        strip_extraneous: bool, whether to strip off extraneous tokens
+            (EOS and PAD).
+
+        Returns:
+        s: human-readable string.
+        """
+        if strip_extraneous:
+            ids = strip_ids(ids, list(range(self._num_reserved_ids or 0)))
+        return " ".join(self.decode_list(ids))
+
+    def decode_list(self, ids):
+        """Transform a sequence of int ids into a their string versions.
+
+        This method supports transforming individual input/output ids to their
+        string versions so that sequence to/from text conversions can be visualized
+        in a human readable format.
+
+        Args:
+        ids: list of integers to be converted.
+
+        Returns:
+        strs: list of human-readable strings.
+        """
+        decoded_ids = []
+        for id_ in ids:
+            if 0 <= id_ < self._num_reserved_ids:
+                decoded_ids.append(RESERVED_TOKENS[int(id_)])
+            else:
+                decoded_ids.append(id_ - self._num_reserved_ids)
+        return [str(d) for d in decoded_ids]
+
+    @property
+    def vocab_size(self):
+        raise NotImplementedError()
+
+
+class ByteTextEncoder(TextEncoder):
+    """Encodes each byte to an id. For 8-bit strings only."""
+
+    def encode(self, s):
+        numres = self._num_reserved_ids
+        if six.PY2:
+            if isinstance(s, unicode):
+                s = s.encode("utf-8")
+            return [ord(c) + numres for c in s]
+        # Python3: explicitly convert to UTF-8
+        return [c + numres for c in s.encode("utf-8")]
+
+    def decode(self, ids, strip_extraneous=False):
+        if strip_extraneous:
+            ids = strip_ids(ids, list(range(self._num_reserved_ids or 0)))
+        numres = self._num_reserved_ids
+        decoded_ids = []
+        int2byte = six.int2byte
+        for id_ in ids:
+            if 0 <= id_ < numres:
+                decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)])
+            else:
+                decoded_ids.append(int2byte(id_ - numres))
+        if six.PY2:
+            return "".join(decoded_ids)
+        # Python3: join byte arrays and then decode string
+        return b"".join(decoded_ids).decode("utf-8", "replace")
+
+    def decode_list(self, ids):
+        numres = self._num_reserved_ids
+        decoded_ids = []
+        int2byte = six.int2byte
+        for id_ in ids:
+            if 0 <= id_ < numres:
+                decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)])
+            else:
+                decoded_ids.append(int2byte(id_ - numres))
+        # Python3: return the list of byte strings without joining
+        return decoded_ids
+
+    @property
+    def vocab_size(self):
+        return 2**8 + self._num_reserved_ids
+
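+# Round-trip sketch (illustrative only): ByteTextEncoder maps UTF-8 bytes to ids
+# offset by the number of reserved ids, e.g.
+#   enc = ByteTextEncoder()
+#   ids = enc.encode("hi")   # [107, 108] with the default 3 reserved tokens
+#   text = enc.decode(ids)   # "hi"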
+
+class ByteTextEncoderWithEos(ByteTextEncoder):
+    """Encodes each byte to an id and appends the EOS token."""
+
+    def encode(self, s):
+        return super(ByteTextEncoderWithEos, self).encode(s) + [EOS_ID]
+
+
+class TokenTextEncoder(TextEncoder):
+    """Encoder based on a user-supplied vocabulary (file or list)."""
+
+    def __init__(self,
+                 vocab_filename,
+                 reverse=False,
+                 vocab_list=None,
+                 replace_oov=None,
+                 num_reserved_ids=NUM_RESERVED_TOKENS):
+        """Initialize from a file or list, one token per line.
+
+        Handling of reserved tokens works as follows:
+        - When initializing from a list, we add reserved tokens to the vocab.
+        - When initializing from a file, we do not add reserved tokens to the vocab.
+        - When saving vocab files, we save reserved tokens to the file.
+
+        Args:
+            vocab_filename: If not None, the full filename to read vocab from. If this
+                is not None, then vocab_list should be None.
+            reverse: Boolean indicating if tokens should be reversed during encoding
+                and decoding.
+            vocab_list: If not None, a list of elements of the vocabulary. If this is
+                not None, then vocab_filename should be None.
+            replace_oov: If not None, every out-of-vocabulary token seen when
+                encoding will be replaced by this string (which must be in vocab).
+            num_reserved_ids: Number of IDs to save for reserved tokens like <EOS>.
+        """
+        super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
+        self._reverse = reverse
+        self._replace_oov = replace_oov
+        if vocab_filename:
+            self._init_vocab_from_file(vocab_filename)
+        else:
+            assert vocab_list is not None
+            self._init_vocab_from_list(vocab_list)
+        self.pad_index = self._token_to_id[PAD]
+        self.eos_index = self._token_to_id[EOS]
+        self.unk_index = self._token_to_id[UNK]
+        self.seg_index = self._token_to_id[SEG] if SEG in self._token_to_id else self.eos_index
+
+    def encode(self, s):
+        """Converts a space-separated string of tokens to a list of ids."""
+        sentence = s
+        tokens = sentence.strip().split()
+        if self._replace_oov is not None:
+            tokens = [t if t in self._token_to_id else self._replace_oov
+                        for t in tokens]
+        ret = [self._token_to_id[tok] for tok in tokens]
+        return ret[::-1] if self._reverse else ret
+
+    def decode(self, ids, strip_eos=False, strip_padding=False):
+        if strip_padding and self.pad() in list(ids):
+            pad_pos = list(ids).index(self.pad())
+            ids = ids[:pad_pos]
+        if strip_eos and self.eos() in list(ids):
+            eos_pos = list(ids).index(self.eos())
+            ids = ids[:eos_pos]
+        return " ".join(self.decode_list(ids))
+
+    def decode_list(self, ids):
+        seq = reversed(ids) if self._reverse else ids
+        return [self._safe_id_to_token(i) for i in seq]
+
+    @property
+    def vocab_size(self):
+        return len(self._id_to_token)
+
+    def __len__(self):
+        return self.vocab_size
+
+    def _safe_id_to_token(self, idx):
+        return self._id_to_token.get(idx, "ID_%d" % idx)
+
+    def _init_vocab_from_file(self, filename):
+        """Load vocab from a file.
+
+        Args:
+        filename: The file to load vocabulary from.
+        """
+        with open(filename) as f:
+            tokens = [token.strip() for token in f.readlines()]
+
+        def token_gen():
+            for token in tokens:
+                yield token
+
+        self._init_vocab(token_gen(), add_reserved_tokens=False)
+
+    def _init_vocab_from_list(self, vocab_list):
+        """Initialize tokens from a list of tokens.
+
+        It is ok if reserved tokens appear in the vocab list. They will be
+        removed. The set of tokens in vocab_list should be unique.
+
+        Args:
+        vocab_list: A list of tokens.
+        """
+        def token_gen():
+            for token in vocab_list:
+                if token not in RESERVED_TOKENS:
+                    yield token
+
+        self._init_vocab(token_gen())
+
+    def _init_vocab(self, token_generator, add_reserved_tokens=True):
+        """Initialize vocabulary with tokens from token_generator."""
+
+        self._id_to_token = {}
+        non_reserved_start_index = 0
+
+        if add_reserved_tokens:
+            self._id_to_token.update(enumerate(RESERVED_TOKENS))
+            non_reserved_start_index = len(RESERVED_TOKENS)
+
+        self._id_to_token.update(
+            enumerate(token_generator, start=non_reserved_start_index))
+
+        # _token_to_id is the reverse of _id_to_token
+        self._token_to_id = dict((v, k)
+                                for k, v in six.iteritems(self._id_to_token))
+
+    def pad(self):
+        return self.pad_index
+
+    def eos(self):
+        return self.eos_index
+
+    def unk(self):
+        return self.unk_index
+
+    def seg(self):
+        return self.seg_index
+
+    def store_to_file(self, filename):
+        """Write vocab file to disk.
+
+        Vocab files have one token per line. The file ends in a newline. Reserved
+        tokens are written to the vocab file as well.
+
+        Args:
+        filename: Full path of the file to store the vocab to.
+        """
+        with open(filename, "w") as f:
+            for i in range(len(self._id_to_token)):
+                f.write(self._id_to_token[i] + "\n")
+
+    def sil_phonemes(self):
+        return [p for p in self._id_to_token.values() if not p[0].isalpha()]
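+
+
+# Usage sketch for TokenTextEncoder (the vocabulary below is hypothetical):
+#   encoder = TokenTextEncoder(None, vocab_list=['a', 'b', 'c', '|'])
+#   encoder.encode('a c b')   # -> [3, 5, 4]; ids 0..2 are the reserved tokens
+#   encoder.decode([3, 5, 4]) # -> 'a c b'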
diff --git a/utils/text_norm.py b/utils/text_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0973cebc91e0525aeb6657e70012a1d37b5e6ff
--- /dev/null
+++ b/utils/text_norm.py
@@ -0,0 +1,790 @@
+# coding=utf-8
+# Authors:
+#   2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git)
+#   2019.9 Jiayu DU
+#
+# requirements:
+#   - python 3.X
+# notes: python 2.X WILL fail or produce misleading results
+
+import sys, os, argparse, codecs, string, re
+
+# ================================================================================ #
+#                                    basic constant
+# ================================================================================ #
+CHINESE_DIGIS = u'零一二三四五六七八九'
+BIG_CHINESE_DIGIS_SIMPLIFIED = u'零壹贰叁肆伍陆柒捌玖'
+BIG_CHINESE_DIGIS_TRADITIONAL = u'零壹貳參肆伍陸柒捌玖'
+SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = u'十百千万'
+SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = u'拾佰仟萬'
+LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'亿兆京垓秭穰沟涧正载'
+LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'億兆京垓秭穰溝澗正載'
+SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'十百千万'
+SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'拾佰仟萬'
+
+ZERO_ALT = u'〇'
+ONE_ALT = u'幺'
+TWO_ALTS = [u'两', u'兩']
+
+POSITIVE = [u'正', u'正']
+NEGATIVE = [u'负', u'負']
+POINT = [u'点', u'點']
+# PLUS = [u'加', u'加']
+# SIL = [u'杠', u'槓']
+
+# 中文数字系统类型
+NUMBERING_TYPES = ['low', 'mid', 'high']
+
+CURRENCY_NAMES = '(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|' \
+                 '里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)'
+CURRENCY_UNITS = '((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)'
+COM_QUANTIFIERS = '(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|' \
+                  '砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|' \
+                  '针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|' \
+                  '毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|' \
+                  '盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|' \
+                  '纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)'
+
+# punctuation information are based on Zhon project (https://github.com/tsroten/zhon.git)
+CHINESE_PUNC_STOP = '!?。。'
+CHINESE_PUNC_NON_STOP = '"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏'
+CHINESE_PUNC_LIST = CHINESE_PUNC_STOP + CHINESE_PUNC_NON_STOP
+
+
+# ================================================================================ #
+#                                    basic class
+# ================================================================================ #
+class ChineseChar(object):
+    """
+    中文字符
+    每个字符对应简体和繁体,
+    e.g. 简体 = '负', 繁体 = '負'
+    转换时可转换为简体或繁体
+    """
+
+    def __init__(self, simplified, traditional):
+        self.simplified = simplified
+        self.traditional = traditional
+        # self.__repr__ = self.__str__
+
+    def __str__(self):
+        return self.simplified or self.traditional or None
+
+    def __repr__(self):
+        return self.__str__()
+
+
+class ChineseNumberUnit(ChineseChar):
+    """
+    中文数字/数位字符
+    每个字符除繁简体外还有一个额外的大写字符
+    e.g. '陆' 和 '陸'
+    """
+
+    def __init__(self, power, simplified, traditional, big_s, big_t):
+        super(ChineseNumberUnit, self).__init__(simplified, traditional)
+        self.power = power
+        self.big_s = big_s
+        self.big_t = big_t
+
+    def __str__(self):
+        return '10^{}'.format(self.power)
+
+    @classmethod
+    def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
+
+        if small_unit:
+            return ChineseNumberUnit(power=index + 1,
+                                     simplified=value[0], traditional=value[1], big_s=value[1], big_t=value[1])
+        elif numbering_type == NUMBERING_TYPES[0]:
+            return ChineseNumberUnit(power=index + 8,
+                                     simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
+        elif numbering_type == NUMBERING_TYPES[1]:
+            return ChineseNumberUnit(power=(index + 2) * 4,
+                                     simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
+        elif numbering_type == NUMBERING_TYPES[2]:
+            return ChineseNumberUnit(power=pow(2, index + 3),
+                                     simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
+        else:
+            raise ValueError(
+                'Counting type should be in {0} ({1} provided).'.format(NUMBERING_TYPES, numbering_type))
+
+
+class ChineseNumberDigit(ChineseChar):
+    """
+    中文数字字符
+    """
+
+    def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None):
+        super(ChineseNumberDigit, self).__init__(simplified, traditional)
+        self.value = value
+        self.big_s = big_s
+        self.big_t = big_t
+        self.alt_s = alt_s
+        self.alt_t = alt_t
+
+    def __str__(self):
+        return str(self.value)
+
+    @classmethod
+    def create(cls, i, v):
+        return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])
+
+
+class ChineseMath(ChineseChar):
+    """
+    中文数位字符
+    """
+
+    def __init__(self, simplified, traditional, symbol, expression=None):
+        super(ChineseMath, self).__init__(simplified, traditional)
+        self.symbol = symbol
+        self.expression = expression
+        self.big_s = simplified
+        self.big_t = traditional
+
+
+CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
+
+
+class NumberSystem(object):
+    """
+    中文数字系统
+    """
+    pass
+
+
+class MathSymbol(object):
+    """
+    用于中文数字系统的数学符号 (繁/简体), e.g.
+    positive = ['正', '正']
+    negative = ['负', '負']
+    point = ['点', '點']
+    """
+
+    def __init__(self, positive, negative, point):
+        self.positive = positive
+        self.negative = negative
+        self.point = point
+
+    def __iter__(self):
+        for v in self.__dict__.values():
+            yield v
+
+
+# class OtherSymbol(object):
+#     """
+#     其他符号
+#     """
+#
+#     def __init__(self, sil):
+#         self.sil = sil
+#
+#     def __iter__(self):
+#         for v in self.__dict__.values():
+#             yield v
+
+
+# ================================================================================ #
+#                                    basic utils
+# ================================================================================ #
+def create_system(numbering_type=NUMBERING_TYPES[1]):
+    """
+    根据数字系统类型返回创建相应的数字系统,默认为 mid
+    NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型
+        low:  '兆' = '亿' * '十' = $10^{9}$,  '京' = '兆' * '十', etc.
+        mid:  '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
+        high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.
+    返回对应的数字系统
+    """
+
+    # chinese number units of '亿' and larger
+    all_larger_units = zip(
+        LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL)
+    larger_units = [CNU.create(i, v, numbering_type, False)
+                    for i, v in enumerate(all_larger_units)]
+    # chinese number units of '十, 百, 千, 万'
+    all_smaller_units = zip(
+        SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL)
+    smaller_units = [CNU.create(i, v, small_unit=True)
+                     for i, v in enumerate(all_smaller_units)]
+    # digis
+    chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS,
+                        BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL)
+    digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
+    digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
+    digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
+    digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
+
+    # symbols
+    positive_cn = CM(POSITIVE[0], POSITIVE[1], '+', lambda x: x)
+    negative_cn = CM(NEGATIVE[0], NEGATIVE[1], '-', lambda x: -x)
+    point_cn = CM(POINT[0], POINT[1], '.', lambda x,
+                                                  y: float(str(x) + '.' + str(y)))
+    # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
+    system = NumberSystem()
+    system.units = smaller_units + larger_units
+    system.digits = digits
+    system.math = MathSymbol(positive_cn, negative_cn, point_cn)
+    # system.symbols = OtherSymbol(sil_cn)
+    return system
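+
+# e.g. (matching the docstring above) the default 'mid' system gives 兆 = 10^12,
+# while create_system('low') gives 兆 = 10^9.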
+
+
+def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
+    def get_symbol(char, system):
+        for u in system.units:
+            if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
+                return u
+        for d in system.digits:
+            if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]:
+                return d
+        for m in system.math:
+            if char in [m.traditional, m.simplified]:
+                return m
+
+    def string2symbols(chinese_string, system):
+        int_string, dec_string = chinese_string, ''
+        for p in [system.math.point.simplified, system.math.point.traditional]:
+            if p in chinese_string:
+                int_string, dec_string = chinese_string.split(p)
+                break
+        return [get_symbol(c, system) for c in int_string], \
+               [get_symbol(c, system) for c in dec_string]
+
+    def correct_symbols(integer_symbols, system):
+        """
+        一百八 to 一百八十
+        一亿一千三百万 to 一亿 一千万 三百万
+        """
+
+        if integer_symbols and isinstance(integer_symbols[0], CNU):
+            if integer_symbols[0].power == 1:
+                integer_symbols = [system.digits[1]] + integer_symbols
+
+        if len(integer_symbols) > 1:
+            if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU):
+                integer_symbols.append(
+                    CNU(integer_symbols[-2].power - 1, None, None, None, None))
+
+        result = []
+        unit_count = 0
+        for s in integer_symbols:
+            if isinstance(s, CND):
+                result.append(s)
+                unit_count = 0
+            elif isinstance(s, CNU):
+                current_unit = CNU(s.power, None, None, None, None)
+                unit_count += 1
+
+            if unit_count == 1:
+                result.append(current_unit)
+            elif unit_count > 1:
+                for i in range(len(result)):
+                    if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power:
+                        result[-i - 1] = CNU(result[-i - 1].power +
+                                             current_unit.power, None, None, None, None)
+        return result
+
+    def compute_value(integer_symbols):
+        """
+        Compute the value.
+        When current unit is larger than previous unit, current unit * all previous units will be used as all previous units.
+        e.g. '两千万' = 2000 * 10000 not 2000 + 10000
+        """
+        value = [0]
+        last_power = 0
+        for s in integer_symbols:
+            if isinstance(s, CND):
+                value[-1] = s.value
+            elif isinstance(s, CNU):
+                value[-1] *= pow(10, s.power)
+                if s.power > last_power:
+                    value[:-1] = list(map(lambda v: v *
+                                                    pow(10, s.power), value[:-1]))
+                    last_power = s.power
+                value.append(0)
+        return sum(value)
+
+    system = create_system(numbering_type)
+    int_part, dec_part = string2symbols(chinese_string, system)
+    int_part = correct_symbols(int_part, system)
+    int_str = str(compute_value(int_part))
+    dec_str = ''.join([str(d.value) for d in dec_part])
+    if dec_part:
+        return '{0}.{1}'.format(int_str, dec_str)
+    else:
+        return int_str
+
+
+def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False,
+            traditional=False, alt_zero=False, alt_one=False, alt_two=True,
+            use_zeros=True, use_units=True):
+    def get_value(value_string, use_zeros=True):
+
+        stripped_string = value_string.lstrip('0')
+
+        # record nothing if all zeros
+        if not stripped_string:
+            return []
+
+        # record a single digit
+        elif len(stripped_string) == 1:
+            if use_zeros and len(value_string) != len(stripped_string):
+                return [system.digits[0], system.digits[int(stripped_string)]]
+            else:
+                return [system.digits[int(stripped_string)]]
+
+        # recursively record multiple digits
+        else:
+            result_unit = next(u for u in reversed(
+                system.units) if u.power < len(stripped_string))
+            result_string = value_string[:-result_unit.power]
+            return get_value(result_string) + [result_unit] + get_value(stripped_string[-result_unit.power:])
+
+    system = create_system(numbering_type)
+
+    int_dec = number_string.split('.')
+    if len(int_dec) == 1:
+        int_string = int_dec[0]
+        dec_string = ""
+    elif len(int_dec) == 2:
+        int_string = int_dec[0]
+        dec_string = int_dec[1]
+    else:
+        raise ValueError(
+            "invalid input num string with more than one dot: {}".format(number_string))
+
+    if use_units and len(int_string) > 1:
+        result_symbols = get_value(int_string)
+    else:
+        result_symbols = [system.digits[int(c)] for c in int_string]
+    dec_symbols = [system.digits[int(c)] for c in dec_string]
+    if dec_string:
+        result_symbols += [system.math.point] + dec_symbols
+
+    if alt_two:
+        liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t,
+                    system.digits[2].big_s, system.digits[2].big_t)
+        for i, v in enumerate(result_symbols):
+            if isinstance(v, CND) and v.value == 2:
+                next_symbol = result_symbols[i +
+                                             1] if i < len(result_symbols) - 1 else None
+                previous_symbol = result_symbols[i - 1] if i > 0 else None
+                if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))):
+                    if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)):
+                        result_symbols[i] = liang
+
+    # if big is True, '两' will not be used and `alt_two` has no impact on output
+    if big:
+        attr_name = 'big_'
+        if traditional:
+            attr_name += 't'
+        else:
+            attr_name += 's'
+    else:
+        if traditional:
+            attr_name = 'traditional'
+        else:
+            attr_name = 'simplified'
+
+    result = ''.join([getattr(s, attr_name) for s in result_symbols])
+
+    # if not use_zeros:
+    #     result = result.strip(getattr(system.digits[0], attr_name))
+
+    if alt_zero:
+        result = result.replace(
+            getattr(system.digits[0], attr_name), system.digits[0].alt_s)
+
+    if alt_one:
+        result = result.replace(
+            getattr(system.digits[1], attr_name), system.digits[1].alt_s)
+
+    for i, p in enumerate(POINT):
+        if result.startswith(p):
+            return CHINESE_DIGIS[0] + result
+
+    # ^10, 11, .., 19
+    if len(result) >= 2 and result[1] in [SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
+                                          SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]] and \
+            result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]:
+        result = result[1:]
+
+    return result
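+
+
+# Quick examples (expected outputs, mirroring the docstrings above):
+#   chn2num('一百八')  -> '180'      (internally corrected to 一百八十)
+#   num2chn('123')     -> '一百二十三'
+#   num2chn('10')      -> '十'       (the leading 一 is dropped for 10-19)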
+
+
+# ================================================================================ #
+#                          different types of rewriters
+# ================================================================================ #
+class Cardinal:
+    """
+    CARDINAL类
+    """
+
+    def __init__(self, cardinal=None, chntext=None):
+        self.cardinal = cardinal
+        self.chntext = chntext
+
+    def chntext2cardinal(self):
+        return chn2num(self.chntext)
+
+    def cardinal2chntext(self):
+        return num2chn(self.cardinal)
+
+
+class Digit:
+    """
+    DIGIT类
+    """
+
+    def __init__(self, digit=None, chntext=None):
+        self.digit = digit
+        self.chntext = chntext
+
+    # def chntext2digit(self):
+    #     return chn2num(self.chntext)
+
+    def digit2chntext(self):
+        return num2chn(self.digit, alt_two=False, use_units=False)
+
+
+class TelePhone:
+    """
+    TELEPHONE类
+    """
+
+    def __init__(self, telephone=None, raw_chntext=None, chntext=None):
+        self.telephone = telephone
+        self.raw_chntext = raw_chntext
+        self.chntext = chntext
+
+    # def chntext2telephone(self):
+    #     sil_parts = self.raw_chntext.split('<SIL>')
+    #     self.telephone = '-'.join([
+    #         str(chn2num(p)) for p in sil_parts
+    #     ])
+    #     return self.telephone
+
+    def telephone2chntext(self, fixed=False):
+
+        if fixed:
+            sil_parts = self.telephone.split('-')
+            self.raw_chntext = '<SIL>'.join([
+                num2chn(part, alt_two=False, use_units=False) for part in sil_parts
+            ])
+            self.chntext = self.raw_chntext.replace('<SIL>', '')
+        else:
+            sp_parts = self.telephone.strip('+').split()
+            self.raw_chntext = '<SP>'.join([
+                num2chn(part, alt_two=False, use_units=False) for part in sp_parts
+            ])
+            self.chntext = self.raw_chntext.replace('<SP>', '')
+        return self.chntext
+
+
+class Fraction:
+    """
+    FRACTION类
+    """
+
+    def __init__(self, fraction=None, chntext=None):
+        self.fraction = fraction
+        self.chntext = chntext
+
+    def chntext2fraction(self):
+        denominator, numerator = self.chntext.split('分之')
+        return chn2num(numerator) + '/' + chn2num(denominator)
+
+    def fraction2chntext(self):
+        numerator, denominator = self.fraction.split('/')
+        return num2chn(denominator) + '分之' + num2chn(numerator)
+
+
+class Date:
+    """
+    DATE类
+    """
+
+    def __init__(self, date=None, chntext=None):
+        self.date = date
+        self.chntext = chntext
+
+    # def chntext2date(self):
+    #     chntext = self.chntext
+    #     try:
+    #         year, other = chntext.strip().split('年', maxsplit=1)
+    #         year = Digit(chntext=year).digit2chntext() + '年'
+    #     except ValueError:
+    #         other = chntext
+    #         year = ''
+    #     if other:
+    #         try:
+    #             month, day = other.strip().split('月', maxsplit=1)
+    #             month = Cardinal(chntext=month).chntext2cardinal() + '月'
+    #         except ValueError:
+    #             day = chntext
+    #             month = ''
+    #         if day:
+    #             day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
+    #     else:
+    #         month = ''
+    #         day = ''
+    #     date = year + month + day
+    #     self.date = date
+    #     return self.date
+
+    def date2chntext(self):
+        date = self.date
+        try:
+            year, other = date.strip().split('年', 1)
+            year = Digit(digit=year).digit2chntext() + '年'
+        except ValueError:
+            other = date
+            year = ''
+        if other:
+            try:
+                month, day = other.strip().split('月', 1)
+                month = Cardinal(cardinal=month).cardinal2chntext() + '月'
+            except ValueError:
+                day = date
+                month = ''
+            if day:
+                day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
+        else:
+            month = ''
+            day = ''
+        chntext = year + month + day
+        self.chntext = chntext
+        return self.chntext
+
+
+class Money:
+    """
+    MONEY类
+    """
+
+    def __init__(self, money=None, chntext=None):
+        self.money = money
+        self.chntext = chntext
+
+    # def chntext2money(self):
+    #     return self.money
+
+    def money2chntext(self):
+        money = self.money
+        pattern = re.compile(r'(\d+(\.\d+)?)')
+        matchers = pattern.findall(money)
+        if matchers:
+            for matcher in matchers:
+                money = money.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext())
+        self.chntext = money
+        return self.chntext
+
+
+class Percentage:
+    """
+    PERCENTAGE类
+    """
+
+    def __init__(self, percentage=None, chntext=None):
+        self.percentage = percentage
+        self.chntext = chntext
+
+    def chntext2percentage(self):
+        return chn2num(self.chntext.strip().strip('百分之')) + '%'
+
+    def percentage2chntext(self):
+        return '百分之' + num2chn(self.percentage.strip().strip('%'))
+
+
+# ================================================================================ #
+#                            NSW Normalizer
+# ================================================================================ #
+class NSWNormalizer:
+    def __init__(self, raw_text):
+        self.raw_text = '^' + raw_text + '$'
+        self.norm_text = ''
+
+    def _particular(self):
+        text = self.norm_text
+        pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('particular')
+            for matcher in matchers:
+                text = text.replace(matcher[0], matcher[1] + '2' + matcher[2], 1)
+        self.norm_text = text
+        return self.norm_text
+
+    def normalize(self, remove_punc=True):
+        text = self.raw_text
+
+        # 规范化日期
+        pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('date')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
+
+        # 规范化金钱
+        pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('money')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
+
+        # 规范化固话/手机号码
+        # 手机
+        # http://www.jihaoba.com/news/show/13680
+        # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198
+        # 联通:130、131、132、156、155、186、185、176
+        # 电信:133、153、189、180、181、177
+        pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('telephone')
+            for matcher in matchers:
+                text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
+        # 固话
+        pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('fixed telephone')
+            for matcher in matchers:
+                text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)
+
+        # 规范化分数
+        pattern = re.compile(r"(\d+/\d+)")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('fraction')
+            for matcher in matchers:
+                text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
+
+        # 规范化百分数
+        text = text.replace('%', '%')
+        pattern = re.compile(r"(\d+(\.\d+)?%)")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('percentage')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)
+
+        # 规范化纯数+量词
+        pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('cardinal+quantifier')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
+
+        # 规范化数字编号
+        pattern = re.compile(r"(\d{4,32})")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('digit')
+            for matcher in matchers:
+                text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
+
+        # Normalize remaining cardinal numbers
+        pattern = re.compile(r"(\d+(\.\d+)?)")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('cardinal')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
+
+        self.norm_text = text
+        self._particular()
+
+        text = self.norm_text.lstrip('^').rstrip('$')
+        if remove_punc:
+            # Punctuation removal
+            old_chars = CHINESE_PUNC_LIST + string.punctuation  # all Chinese and English punctuation
+            new_chars = ' ' * len(old_chars)
+            del_chars = ''
+            text = text.translate(str.maketrans(old_chars, new_chars, del_chars))
+        return text
+
+
+def nsw_test_case(raw_text):
+    print('I:' + raw_text)
+    print('O:' + NSWNormalizer(raw_text).normalize())
+    print('')
+
+
+def nsw_test():
+    nsw_test_case('固话:0595-23865596或23880880。')
+    nsw_test_case('手机:+86 19859213959或15659451527。')
+    nsw_test_case('分数:32477/76391。')
+    nsw_test_case('百分数:80.03%。')
+    nsw_test_case('编号:31520181154418。')
+    nsw_test_case('纯数:2983.07克或12345.60米。')
+    nsw_test_case('日期:1999年2月20日或09年3月15号。')
+    nsw_test_case('金钱:12块5,34.5元,20.1万')
+    nsw_test_case('特殊:O2O或B2C。')
+    nsw_test_case('3456万吨')
+    nsw_test_case('2938个')
+    nsw_test_case('938')
+    nsw_test_case('今天吃了115个小笼包231个馒头')
+    nsw_test_case('有62%的概率')
+
+
+if __name__ == '__main__':
+    # nsw_test()
+
+    p = argparse.ArgumentParser()
+    p.add_argument('ifile', help='input filename, assume utf-8 encoding')
+    p.add_argument('ofile', help='output filename')
+    p.add_argument('--to_upper', action='store_true', help='convert to upper case')
+    p.add_argument('--to_lower', action='store_true', help='convert to lower case')
+    p.add_argument('--has_key', action='store_true', help="input text has Kaldi's key as first field.")
+    p.add_argument('--log_interval', type=int, default=10000, help='log interval in number of processed lines')
+    args = p.parse_args()
+
+    ifile = codecs.open(args.ifile, 'r', 'utf8')
+    ofile = codecs.open(args.ofile, 'w+', 'utf8')
+
+    n = 0
+    for l in ifile:
+        key = ''
+        text = ''
+        if args.has_key:
+            cols = l.split(maxsplit=1)
+            key = cols[0]
+            if len(cols) == 2:
+                text = cols[1]
+            else:
+                text = ''
+        else:
+            text = l
+
+        # cases
+        if args.to_upper and args.to_lower:
+            sys.stderr.write('text norm: use either --to_upper or --to_lower, not both.\n')
+            exit(1)
+        if args.to_upper:
+            text = text.upper()
+        if args.to_lower:
+            text = text.lower()
+
+        # NSW(Non-Standard-Word) normalization
+        text = NSWNormalizer(text).normalize()
+
+        #
+        if args.has_key:
+            ofile.write(key + '\t' + text)
+        else:
+            ofile.write(text)
+
+        n += 1
+        if n % args.log_interval == 0:
+            sys.stderr.write("text norm: {} lines done.\n".format(n))
+
+    sys.stderr.write("text norm: {} lines done in total.\n".format(n))
+
+    ifile.close()
+    ofile.close()
diff --git a/utils/trainer.py b/utils/trainer.py
new file mode 100755
index 0000000000000000000000000000000000000000..6821fee1a4a08174bd3f3916dbc368fe89f1ba5b
--- /dev/null
+++ b/utils/trainer.py
@@ -0,0 +1,518 @@
+import random
+from torch.cuda.amp import GradScaler, autocast
+from utils import move_to_cuda
+import subprocess
+import numpy as np
+import torch.optim
+import torch.utils.data
+import copy
+import logging
+import os
+import re
+import sys
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import tqdm
+
+from utils.ckpt_utils import get_last_checkpoint, get_all_ckpts
+from utils.ddp_utils import DDP
+from utils.hparams import hparams
+
+
+class Trainer:
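+    """Minimal training loop: handles single-GPU and DDP runs, optional mixed precision,
+    periodic validation, TensorBoard logging, and checkpoint saving/restoring."""
+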
+    def __init__(
+            self,
+            work_dir,
+            default_save_path=None,
+            accumulate_grad_batches=1,
+            max_updates=160000,
+            print_nan_grads=False,
+            val_check_interval=2000,
+            num_sanity_val_steps=5,
+            amp=False,
+            # tb logger
+            log_save_interval=100,
+            tb_log_interval=10,
+            # checkpoint
+            monitor_key='val_loss',
+            monitor_mode='min',
+            num_ckpt_keep=5,
+            save_best=True,
+            resume_from_checkpoint=0,
+            seed=1234,
+            debug=False,
+    ):
+        os.makedirs(work_dir, exist_ok=True)
+        self.work_dir = work_dir
+        self.accumulate_grad_batches = accumulate_grad_batches
+        self.max_updates = max_updates
+        self.num_sanity_val_steps = num_sanity_val_steps
+        self.print_nan_grads = print_nan_grads
+        self.default_save_path = default_save_path
+        self.resume_from_checkpoint = resume_from_checkpoint if resume_from_checkpoint > 0 else None
+        self.seed = seed
+        self.debug = debug
+        # model and optm
+        self.task = None
+        self.optimizers = []
+
+        # trainer state
+        self.testing = False
+        self.global_step = 0
+        self.current_epoch = 0
+        self.total_batches = 0
+
+        # configure checkpoint
+        self.monitor_key = monitor_key
+        self.num_ckpt_keep = num_ckpt_keep
+        self.save_best = save_best
+        self.monitor_op = np.less if monitor_mode == 'min' else np.greater
+        self.best_val_results = np.Inf if monitor_mode == 'min' else -np.Inf
+        self.mode = 'min'
+
+        # allow int, string and gpu list
+        self.all_gpu_ids = [
+            int(x) for x in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",") if x != '']
+        self.num_gpus = len(self.all_gpu_ids)
+        self.on_gpu = self.num_gpus > 0
+        self.root_gpu = 0
+        logging.info(f'GPU available: {torch.cuda.is_available()}, GPU used: {self.all_gpu_ids}')
+        self.use_ddp = self.num_gpus > 1
+        self.proc_rank = 0
+        # Tensorboard logging
+        self.log_save_interval = log_save_interval
+        self.val_check_interval = val_check_interval
+        self.tb_log_interval = tb_log_interval
+        self.amp = amp
+        self.amp_scalar = GradScaler()
+
+    def test(self, task_cls):
+        self.testing = True
+        self.fit(task_cls)
+
+    def fit(self, task_cls):
+        if len(self.all_gpu_ids) > 1:
+            mp.spawn(self.ddp_run, nprocs=self.num_gpus, args=(task_cls, copy.deepcopy(hparams)))
+        else:
+            self.task = task_cls()
+            self.task.trainer = self
+            self.run_single_process(self.task)
+        return 1
+
+    def ddp_run(self, gpu_idx, task_cls, hparams_):
+        hparams.update(hparams_)
+        task = task_cls()
+        self.ddp_init(gpu_idx, task)
+        self.run_single_process(task)
+
+    def run_single_process(self, task):
+        """Sanity check a few things before starting actual training.
+
+        :param task:
+        """
+        # build model, optm and load checkpoint
+        model = task.build_model()
+        if model is not None:
+            task.model = model
+        checkpoint, _ = get_last_checkpoint(self.work_dir, self.resume_from_checkpoint)
+        if checkpoint is not None:
+            self.restore_weights(checkpoint)
+        elif self.on_gpu:
+            task.cuda(self.root_gpu)
+        if not self.testing:
+            self.optimizers = task.configure_optimizers()
+            self.first_epoch = True
+        if checkpoint is not None:
+            self.restore_opt_state(checkpoint)
+        del checkpoint
+        # clear cache after restore
+        if self.on_gpu:
+            torch.cuda.empty_cache()
+
+        if self.use_ddp:
+            self.task = self.configure_ddp(self.task)
+            dist.barrier()
+
+        task_ref = self.get_task_ref()
+        task_ref.trainer = self
+        task_ref.testing = self.testing
+        # link up experiment object
+        if self.proc_rank == 0:
+            task_ref.build_tensorboard(save_dir=self.work_dir, name='lightning_logs', version='latest')
+        else:
+            os.makedirs('tmp', exist_ok=True)
+            task_ref.build_tensorboard(save_dir='tmp', name='tb_tmp', version='latest')
+        self.logger = task_ref.logger
+        try:
+            if self.testing:
+                self.run_evaluation(test=True)
+            else:
+                self.train()
+        except KeyboardInterrupt as e:
+            task_ref.on_keyboard_interrupt()
+
+    ####################
+    # valid and test
+    ####################
+    def run_evaluation(self, test=False):
+        eval_results = self.evaluate(self.task, test, tqdm_desc='Valid' if not test else 'test')
+        if eval_results is not None and 'tb_log' in eval_results:
+            tb_log_output = eval_results['tb_log']
+            self.log_metrics_to_tb(tb_log_output)
+        if self.proc_rank == 0 and not test:
+            self.save_checkpoint(epoch=self.current_epoch, logs=eval_results)
+
+    def evaluate(self, task, test=False, tqdm_desc='Valid', max_batches=None):
+        # enable eval mode
+        task.zero_grad()
+        task.eval()
+        torch.set_grad_enabled(False)
+
+        task_ref = self.get_task_ref()
+        if test:
+            ret = task_ref.test_start()
+            if ret == 'EXIT':
+                return
+
+        outputs = []
+        dataloader = task_ref.test_dataloader() if test else task_ref.val_dataloader()
+        pbar = tqdm.tqdm(dataloader, desc=tqdm_desc, total=max_batches, dynamic_ncols=True, unit='step',
+                         disable=self.root_gpu > 0)
+        for batch_idx, batch in enumerate(pbar):
+            if batch is None:  # pragma: no cover
+                continue
+            # stop short when on fast_dev_run (sets max_batch=1)
+            if max_batches is not None and batch_idx >= max_batches:
+                break
+
+            # make dataloader_idx arg in validation_step optional
+            if self.on_gpu:
+                batch = move_to_cuda(batch, self.root_gpu)
+            args = [batch, batch_idx]
+            if self.use_ddp:
+                output = task(*args)
+            else:
+                if test:
+                    output = task_ref.test_step(*args)
+                else:
+                    output = task_ref.validation_step(*args)
+            # track outputs for collation
+            outputs.append(output)
+        # give model a chance to do something with the outputs (and method defined)
+        if test:
+            eval_results = task_ref.test_end(outputs)
+        else:
+            eval_results = task_ref.validation_end(outputs)
+        # enable train mode again
+        task.train()
+        torch.set_grad_enabled(True)
+        return eval_results
+
+    ####################
+    # train
+    ####################
+    def train(self):
+        task_ref = self.get_task_ref()
+        task_ref.on_train_start()
+        if self.num_sanity_val_steps > 0:
+            # run tiny validation (if validation defined) to make sure program won't crash during val
+            self.evaluate(self.task, False, 'Sanity Val', max_batches=self.num_sanity_val_steps)
+        # clear cache before training
+        if self.on_gpu:
+            torch.cuda.empty_cache()
+        dataloader = task_ref.train_dataloader()
+        epoch = self.current_epoch
+        # run all epochs
+        while True:
+            # set seed for distributed sampler (enables shuffling for each epoch)
+            if self.use_ddp and hasattr(dataloader.sampler, 'set_epoch'):
+                dataloader.sampler.set_epoch(epoch)
+            # update training progress in trainer and model
+            task_ref.current_epoch = epoch
+            self.current_epoch = epoch
+            # total batches includes multiple val checks
+            self.batch_loss_value = 0  # accumulated grads
+            # before epoch hook
+            task_ref.on_epoch_start()
+
+            # run epoch
+            train_pbar = tqdm.tqdm(dataloader, initial=self.global_step, total=float('inf'),
+                                   dynamic_ncols=True, unit='step', disable=self.root_gpu > 0)
+            for batch_idx, batch in enumerate(train_pbar):
+                pbar_metrics, tb_metrics = self.run_training_batch(batch_idx, batch)
+                train_pbar.set_postfix(**pbar_metrics)
+                should_check_val = (self.global_step % self.val_check_interval == 0
+                                    and not self.first_epoch)
+                if should_check_val:
+                    self.run_evaluation()
+                self.first_epoch = False
+                # when metrics should be logged
+                if (self.global_step + 1) % self.tb_log_interval == 0:
+                    # logs user requested information to logger
+                    self.log_metrics_to_tb(tb_metrics)
+
+                self.global_step += 1
+                task_ref.global_step = self.global_step
+                if self.global_step > self.max_updates:
+                    print("| Training end..")
+                    break
+            # epoch end hook
+            task_ref.on_epoch_end()
+            epoch += 1
+            if self.global_step > self.max_updates:
+                break
+        task_ref.on_train_end()
+
+    def run_training_batch(self, batch_idx, batch):
+        if batch is None:
+            return {}, {}
+        all_progress_bar_metrics = []
+        all_log_metrics = []
+        task_ref = self.get_task_ref()
+        for opt_idx, optimizer in enumerate(self.optimizers):
+            if optimizer is None:
+                continue
+            # make sure only the gradients of the current optimizer's parameters are calculated
+            # in the training step, to prevent dangling gradients in a multiple-optimizer setup.
+            if len(self.optimizers) > 1:
+                for param in task_ref.parameters():
+                    param.requires_grad = False
+                for group in optimizer.param_groups:
+                    for param in group['params']:
+                        param.requires_grad = True
+
+            # forward pass
+            with autocast(enabled=self.amp):
+                if self.on_gpu:
+                    batch = move_to_cuda(copy.copy(batch), self.root_gpu)
+                args = [batch, batch_idx, opt_idx]
+                if self.use_ddp:
+                    output = self.task(*args)
+                else:
+                    output = task_ref.training_step(*args)
+                loss = output['loss']
+                if loss is None:
+                    continue
+                progress_bar_metrics = output['progress_bar']
+                log_metrics = output['tb_log']
+                # accumulate loss
+                loss = loss / self.accumulate_grad_batches
+
+            # backward pass
+            if loss.requires_grad:
+                if self.amp:
+                    self.amp_scalar.scale(loss).backward()
+                else:
+                    loss.backward()
+
+            # track progress bar metrics
+            all_log_metrics.append(log_metrics)
+            all_progress_bar_metrics.append(progress_bar_metrics)
+
+            if loss is None:
+                continue
+
+            # nan grads
+            if self.print_nan_grads:
+                has_nan_grad = False
+                for name, param in task_ref.named_parameters():
+                    if (param.grad is not None) and torch.isnan(param.grad.float()).any():
+                        print("| NaN params: ", name, param, param.grad)
+                        has_nan_grad = True
+                if has_nan_grad:
+                    exit(0)
+
+            # gradient update with accumulated gradients
+            if (self.global_step + 1) % self.accumulate_grad_batches == 0:
+                task_ref.on_before_optimization(opt_idx)
+                if self.amp:
+                    self.amp_scalar.step(optimizer)
+                    self.amp_scalar.update()
+                else:
+                    optimizer.step()
+                optimizer.zero_grad()
+                task_ref.on_after_optimization(self.current_epoch, batch_idx, optimizer, opt_idx)
+
+        # collapse all metrics into one dict
+        all_progress_bar_metrics = {k: v for d in all_progress_bar_metrics for k, v in d.items()}
+        all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()}
+        return all_progress_bar_metrics, all_log_metrics
+
+    ####################
+    # load and save checkpoint
+    ####################
+    def restore_weights(self, checkpoint):
+        # load model state
+        task_ref = self.get_task_ref()
+
+        if len([k for k in checkpoint['state_dict'].keys() if '.' in k]) > 0:
+            task_ref.load_state_dict(checkpoint['state_dict'])
+        else:
+            for k, v in checkpoint['state_dict'].items():
+                getattr(task_ref, k).load_state_dict(v)
+
+        if self.on_gpu:
+            task_ref.cuda(self.root_gpu)
+        # load training state (affects trainer only)
+        self.best_val_results = checkpoint['checkpoint_callback_best']
+        self.global_step = checkpoint['global_step']
+        self.current_epoch = checkpoint['epoch']
+        task_ref.global_step = self.global_step
+
+        # wait for all models to restore weights
+        if self.use_ddp:
+            # wait for all processes to catch up
+            dist.barrier()
+
+    def restore_opt_state(self, checkpoint):
+        if self.testing:
+            return
+        # restore the optimizers
+        optimizer_states = checkpoint['optimizer_states']
+        for optimizer, opt_state in zip(self.optimizers, optimizer_states):
+            if optimizer is None:
+                return
+            try:
+                optimizer.load_state_dict(opt_state)
+                # move the optimizer state to GPU one tensor at a time
+                if self.on_gpu:
+                    for state in optimizer.state.values():
+                        for k, v in state.items():
+                            if isinstance(v, torch.Tensor):
+                                state[k] = v.cuda(self.root_gpu)
+            except ValueError:
+                print("| WARMING: optimizer parameters not match !!!")
+        try:
+            if dist.is_initialized() and dist.get_rank() > 0:
+                return
+        except Exception as e:
+            print(e)
+            return
+        did_restore = True
+        return did_restore
+
+    def save_checkpoint(self, epoch, logs=None):
+        monitor_op = self.monitor_op
+        ckpt_path = f'{self.work_dir}/model_ckpt_steps_{self.global_step}.ckpt'
+        logging.info(f'Epoch {epoch:05d}@{self.global_step}: saving model to {ckpt_path}')
+        self._atomic_save(ckpt_path)
+        for old_ckpt in get_all_ckpts(self.work_dir)[self.num_ckpt_keep:]:
+            subprocess.check_call(f'rm -rf "{old_ckpt}"', shell=True)
+            logging.info(f'Delete ckpt: {os.path.basename(old_ckpt)}')
+        current = None
+        if logs is not None and self.monitor_key in logs:
+            current = logs[self.monitor_key]
+        if current is not None and self.save_best:
+            if monitor_op(current, self.best_val_results):
+                best_filepath = f'{self.work_dir}/model_ckpt_best.pt'
+                self.best_val_results = current
+                logging.info(
+                    f'Epoch {epoch:05d}@{self.global_step}: {self.monitor_key} reached {current:0.5f}. '
+                    f'Saving model to {best_filepath}')
+                self._atomic_save(best_filepath)
+
+    def _atomic_save(self, filepath):
+        checkpoint = self.dump_checkpoint()
+        tmp_path = str(filepath) + ".part"
+        torch.save(checkpoint, tmp_path, _use_new_zipfile_serialization=False)
+        os.replace(tmp_path, filepath)
+
+    def dump_checkpoint(self):
+        checkpoint = {'epoch': self.current_epoch, 'global_step': self.global_step,
+                      'checkpoint_callback_best': self.best_val_results}
+        # save optimizers
+        optimizer_states = []
+        for i, optimizer in enumerate(self.optimizers):
+            if optimizer is not None:
+                optimizer_states.append(optimizer.state_dict())
+
+        checkpoint['optimizer_states'] = optimizer_states
+        task_ref = self.get_task_ref()
+        checkpoint['state_dict'] = {
+            k: v.state_dict() for k, v in task_ref.named_children() if len(list(v.parameters())) > 0}
+        return checkpoint
+
+    ####################
+    # DDP
+    ####################
+    def ddp_init(self, gpu_idx, task):
+        # determine which process we are and world size
+        self.proc_rank = gpu_idx
+        task.trainer = self
+        self.init_ddp_connection(self.proc_rank, self.num_gpus)
+
+        # copy model to each gpu
+        torch.cuda.set_device(gpu_idx)
+        # override root GPU
+        self.root_gpu = gpu_idx
+        self.task = task
+
+    def configure_ddp(self, task):
+        task = DDP(task, device_ids=[self.root_gpu], find_unused_parameters=True)
+        if dist.get_rank() != 0 and not self.debug:
+            sys.stdout = open(os.devnull, "w")
+            sys.stderr = open(os.devnull, "w")
+        random.seed(self.seed)
+        np.random.seed(self.seed)
+        return task
+
+    def init_ddp_connection(self, proc_rank, world_size):
+        root_node = '127.0.0.1'
+        root_node = self.resolve_root_node_address(root_node)
+        os.environ['MASTER_ADDR'] = root_node
+        dist.init_process_group('nccl', rank=proc_rank, world_size=world_size)
+
+    def resolve_root_node_address(self, root_node):
+        if '[' in root_node:
+            name = root_node.split('[')[0]
+            number = root_node.split(',')[0]
+            if '-' in number:
+                number = number.split('-')[0]
+            number = re.sub('[^0-9]', '', number)
+            root_node = name + number
+        return root_node
+
+    ####################
+    # utils
+    ####################
+    def get_task_ref(self):
+        from tasks.base_task import BaseTask
+        task: BaseTask = self.task.module if isinstance(self.task, DDP) else self.task
+        return task
+
+    def log_metrics_to_tb(self, metrics, step=None):
+        """Logs the metric dict passed in.
+
+        :param metrics:
+        """
+        # add the current epoch to the logged metrics for convenience
+        metrics['epoch'] = self.current_epoch
+
+        # turn all tensors to scalars
+        scalar_metrics = self.metrics_to_scalars(metrics)
+
+        step = step if step is not None else self.global_step
+        # log actual metrics
+        if self.proc_rank == 0:
+            self.log_metrics(self.logger, scalar_metrics, step=step)
+
+    @staticmethod
+    def log_metrics(logger, metrics, step=None):
+        for k, v in metrics.items():
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+            logger.add_scalar(k, v, step)
+
+    def metrics_to_scalars(self, metrics):
+        new_metrics = {}
+        for k, v in metrics.items():
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+
+            if type(v) is dict:
+                v = self.metrics_to_scalars(v)
+
+            new_metrics[k] = v
+
+        return new_metrics
diff --git a/utils/training_utils.py b/utils/training_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..409b15388790b1aadb24632313bdd1f41b4b06ac
--- /dev/null
+++ b/utils/training_utils.py
@@ -0,0 +1,27 @@
+from utils.hparams import hparams
+
+
+class RSQRTSchedule(object):
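+    """Inverse-square-root learning-rate schedule (Transformer-style): the rate rises linearly
+    over `warmup_updates` steps and then decays as
+    base_lr * min(step / warmup, 1) * max(warmup, step) ** -0.5 * hidden_size ** -0.5,
+    floored at 1e-7."""
+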
+    def __init__(self, optimizer):
+        super().__init__()
+        self.optimizer = optimizer
+        self.constant_lr = hparams['lr']
+        self.warmup_updates = hparams['warmup_updates']
+        self.hidden_size = hparams['hidden_size']
+        self.lr = hparams['lr']
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = self.lr
+        self.step(0)
+
+    def step(self, num_updates):
+        constant_lr = self.constant_lr
+        warmup = min(num_updates / self.warmup_updates, 1.0)
+        rsqrt_decay = max(self.warmup_updates, num_updates) ** -0.5
+        rsqrt_hidden = self.hidden_size ** -0.5
+        self.lr = max(constant_lr * warmup * rsqrt_decay * rsqrt_hidden, 1e-7)
+        for param_group in self.optimizer.param_groups:
+            param_group['lr'] = self.lr
+        return self.lr
+
+    def get_lr(self):
+        return self.optimizer.param_groups[0]['lr']
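+
+# Usage sketch (illustrative names; assumes `hparams` already provides 'lr', 'warmup_updates'
+# and 'hidden_size', and `optimizer` is any torch.optim optimizer):
+#     scheduler = RSQRTSchedule(optimizer)
+#     for num_updates, batch in enumerate(train_loader):
+#         ...  # forward / backward / optimizer.step()
+#         scheduler.step(num_updates)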
diff --git a/utils/tts_utils.py b/utils/tts_utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..9da2385ba52ce735a2d3c46ad8743d4a5bb7cd5c
--- /dev/null
+++ b/utils/tts_utils.py
@@ -0,0 +1,371 @@
+from collections import defaultdict
+import torch
+import torch.nn.functional as F
+
+
+def make_positions(tensor, padding_idx):
+    """Replace non-padding symbols with their position numbers.
+
+    Position numbers begin at padding_idx+1. Padding symbols are ignored.
+    """
+    # The series of casts and type-conversions here are carefully
+    # balanced to both work with ONNX export and XLA. In particular XLA
+    # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
+    # how to handle the dtype kwarg in cumsum.
+    mask = tensor.ne(padding_idx).int()
+    return (
+                   torch.cumsum(mask, dim=1).type_as(mask) * mask
+           ).long() + padding_idx
+
+
+def softmax(x, dim):
+    return F.softmax(x, dim=dim, dtype=torch.float32)
+
+
+def sequence_mask(lengths, maxlen, dtype=torch.bool):
+    if maxlen is None:
+        maxlen = lengths.max()
+    mask = ~(torch.ones((len(lengths), maxlen)).to(lengths.device).cumsum(dim=1).t() > lengths).t()
+    return mask.type(dtype)
+
+
+INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
+
+
+def _get_full_incremental_state_key(module_instance, key):
+    module_name = module_instance.__class__.__name__
+
+    # assign a unique ID to each module instance, so that incremental state is
+    # not shared across module instances
+    if not hasattr(module_instance, '_instance_id'):
+        INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
+        module_instance._instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]
+
+    return '{}.{}.{}'.format(module_name, module_instance._instance_id, key)
+
+
+def get_incremental_state(module, incremental_state, key):
+    """Helper for getting incremental state for an nn.Module."""
+    full_key = _get_full_incremental_state_key(module, key)
+    if incremental_state is None or full_key not in incremental_state:
+        return None
+    return incremental_state[full_key]
+
+
+def set_incremental_state(module, incremental_state, key, value):
+    """Helper for setting incremental state for an nn.Module."""
+    if incremental_state is not None:
+        full_key = _get_full_incremental_state_key(module, key)
+        incremental_state[full_key] = value
+
+
+def fill_with_neg_inf(t):
+    """FP16-compatible function that fills a tensor with -inf."""
+    return t.float().fill_(float('-inf')).type_as(t)
+
+
+def fill_with_neg_inf2(t):
+    """FP16-compatible function that fills a tensor with -inf."""
+    return t.float().fill_(-1e8).type_as(t)
+
+
+def get_focus_rate(attn, src_padding_mask=None, tgt_padding_mask=None):
+    '''
+    attn: bs x L_t x L_s
+    '''
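+    # Focus rate: for each target frame take the largest attention weight, sum over frames
+    # and normalize by the total attention mass (padding positions are zeroed out first).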
+    if src_padding_mask is not None:
+        attn = attn * (1 - src_padding_mask.float())[:, None, :]
+
+    if tgt_padding_mask is not None:
+        attn = attn * (1 - tgt_padding_mask.float())[:, :, None]
+
+    focus_rate = attn.max(-1).values.sum(-1)
+    focus_rate = focus_rate / attn.sum(-1).sum(-1)
+    return focus_rate
+
+
+def get_phone_coverage_rate(attn, src_padding_mask=None, src_seg_mask=None, tgt_padding_mask=None):
+    '''
+    attn: bs x L_t x L_s
+    '''
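+    # Phone coverage rate: for each source token take its largest attention weight across
+    # target frames, sum over tokens and normalize by the number of valid source tokens.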
+    src_mask = attn.new(attn.size(0), attn.size(-1)).bool().fill_(False)
+    if src_padding_mask is not None:
+        src_mask |= src_padding_mask
+    if src_seg_mask is not None:
+        src_mask |= src_seg_mask
+
+    attn = attn * (1 - src_mask.float())[:, None, :]
+    if tgt_padding_mask is not None:
+        attn = attn * (1 - tgt_padding_mask.float())[:, :, None]
+
+    phone_coverage_rate = attn.max(1).values.sum(-1)
+    # phone_coverage_rate = phone_coverage_rate / attn.sum(-1).sum(-1)
+    phone_coverage_rate = phone_coverage_rate / (1 - src_mask.float()).sum(-1)
+    return phone_coverage_rate
+
+
+def get_diagonal_focus_rate(attn, attn_ks, target_len, src_padding_mask=None, tgt_padding_mask=None,
+                            band_mask_factor=5, band_width=50):
+    '''
+    attn: bx x L_t x L_s
+    attn_ks: shape: tensor with shape [batch_size], input_lens/output_lens
+
+    diagonal: y=k*x (k=attn_ks, x:output, y:input)
+    1 0 0
+    0 1 0
+    0 0 1
+    y>=k*(x-width) and y<=k*(x+width):1
+    else:0
+    '''
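+    # Build a band mask around the diagonal y = attn_ks * x (half-width limited by band_width)
+    # and measure the fraction of attention mass that falls inside the band.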
+    # width = min(target_len/band_mask_factor, 50)
+    width1 = target_len / band_mask_factor
+    width2 = target_len.new(target_len.size()).fill_(band_width)
+    width = torch.where(width1 < width2, width1, width2).float()
+    base = torch.ones(attn.size()).to(attn.device)
+    zero = torch.zeros(attn.size()).to(attn.device)
+    x = torch.arange(0, attn.size(1)).to(attn.device)[None, :, None].float() * base
+    y = torch.arange(0, attn.size(2)).to(attn.device)[None, None, :].float() * base
+    cond = (y - attn_ks[:, None, None] * x)
+    cond1 = cond + attn_ks[:, None, None] * width[:, None, None]
+    cond2 = cond - attn_ks[:, None, None] * width[:, None, None]
+    mask1 = torch.where(cond1 < 0, zero, base)
+    mask2 = torch.where(cond2 > 0, zero, base)
+    mask = mask1 * mask2
+
+    if src_padding_mask is not None:
+        attn = attn * (1 - src_padding_mask.float())[:, None, :]
+    if tgt_padding_mask is not None:
+        attn = attn * (1 - tgt_padding_mask.float())[:, :, None]
+
+    diagonal_attn = attn * mask
+    diagonal_focus_rate = diagonal_attn.sum(-1).sum(-1) / attn.sum(-1).sum(-1)
+    return diagonal_focus_rate, mask
+
+
+def select_attn(attn_logits, type='best'):
+    """
+
+    :param attn_logits: [n_layers, B, n_head, T_sp, T_txt]
+    :return:
+    """
+    encdec_attn = torch.stack(attn_logits, 0).transpose(1, 2)
+    # [n_layers * n_head, B, T_sp, T_txt]
+    encdec_attn = (encdec_attn.reshape([-1, *encdec_attn.shape[2:]])).softmax(-1)
+    if type == 'best':
+        indices = encdec_attn.max(-1).values.sum(-1).argmax(0)
+        encdec_attn = encdec_attn.gather(
+            0, indices[None, :, None, None].repeat(1, 1, encdec_attn.size(-2), encdec_attn.size(-1)))[0]
+        return encdec_attn
+    elif type == 'mean':
+        return encdec_attn.mean(0)
+
+
+def make_pad_mask(lengths, xs=None, length_dim=-1):
+    """Make mask tensor containing indices of padded part.
+    Args:
+        lengths (LongTensor or List): Batch of lengths (B,).
+        xs (Tensor, optional): The reference tensor.
+            If set, masks will be the same shape as this tensor.
+        length_dim (int, optional): Dimension indicator of the above tensor.
+            See the example.
+    Returns:
+        Tensor: Mask tensor containing indices of padded part.
+                dtype=torch.uint8 in PyTorch 1.2-
+                dtype=torch.bool in PyTorch 1.2+ (including 1.2)
+    Examples:
+        With only lengths.
+        >>> lengths = [5, 3, 2]
+        >>> make_pad_mask(lengths)
+        masks = [[0, 0, 0, 0 ,0],
+                 [0, 0, 0, 1, 1],
+                 [0, 0, 1, 1, 1]]
+        With the reference tensor.
+        >>> xs = torch.zeros((3, 2, 4))
+        >>> make_pad_mask(lengths, xs)
+        tensor([[[0, 0, 0, 0],
+                 [0, 0, 0, 0]],
+                [[0, 0, 0, 1],
+                 [0, 0, 0, 1]],
+                [[0, 0, 1, 1],
+                 [0, 0, 1, 1]]], dtype=torch.uint8)
+        >>> xs = torch.zeros((3, 2, 6))
+        >>> make_pad_mask(lengths, xs)
+        tensor([[[0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1]],
+                [[0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1]],
+                [[0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
+        With the reference tensor and dimension indicator.
+        >>> xs = torch.zeros((3, 6, 6))
+        >>> make_pad_mask(lengths, xs, 1)
+        tensor([[[0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [1, 1, 1, 1, 1, 1]],
+                [[0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1]],
+                [[0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
+        >>> make_pad_mask(lengths, xs, 2)
+        tensor([[[0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1]],
+                [[0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1]],
+                [[0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
+    """
+    if length_dim == 0:
+        raise ValueError("length_dim cannot be 0: {}".format(length_dim))
+
+    if not isinstance(lengths, list):
+        lengths = lengths.tolist()
+    bs = int(len(lengths))
+    if xs is None:
+        maxlen = int(max(lengths))
+    else:
+        maxlen = xs.size(length_dim)
+
+    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
+    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
+    seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
+    mask = seq_range_expand >= seq_length_expand
+
+    if xs is not None:
+        assert xs.size(0) == bs, (xs.size(0), bs)
+
+        if length_dim < 0:
+            length_dim = xs.dim() + length_dim
+        # ind = (:, None, ..., None, :, , None, ..., None)
+        ind = tuple(
+            slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
+        )
+        mask = mask[ind].expand_as(xs).to(xs.device)
+    return mask
+
+
+def make_non_pad_mask(lengths, xs=None, length_dim=-1):
+    """Make mask tensor containing indices of non-padded part.
+    Args:
+        lengths (LongTensor or List): Batch of lengths (B,).
+        xs (Tensor, optional): The reference tensor.
+            If set, masks will be the same shape as this tensor.
+        length_dim (int, optional): Dimension indicator of the above tensor.
+            See the example.
+    Returns:
+        ByteTensor: mask tensor containing indices of padded part.
+                    dtype=torch.uint8 in PyTorch 1.2-
+                    dtype=torch.bool in PyTorch 1.2+ (including 1.2)
+    Examples:
+        With only lengths.
+        >>> lengths = [5, 3, 2]
+        >>> make_non_pad_mask(lengths)
+        masks = [[1, 1, 1, 1 ,1],
+                 [1, 1, 1, 0, 0],
+                 [1, 1, 0, 0, 0]]
+        With the reference tensor.
+        >>> xs = torch.zeros((3, 2, 4))
+        >>> make_non_pad_mask(lengths, xs)
+        tensor([[[1, 1, 1, 1],
+                 [1, 1, 1, 1]],
+                [[1, 1, 1, 0],
+                 [1, 1, 1, 0]],
+                [[1, 1, 0, 0],
+                 [1, 1, 0, 0]]], dtype=torch.uint8)
+        >>> xs = torch.zeros((3, 2, 6))
+        >>> make_non_pad_mask(lengths, xs)
+        tensor([[[1, 1, 1, 1, 1, 0],
+                 [1, 1, 1, 1, 1, 0]],
+                [[1, 1, 1, 0, 0, 0],
+                 [1, 1, 1, 0, 0, 0]],
+                [[1, 1, 0, 0, 0, 0],
+                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
+        With the reference tensor and dimension indicator.
+        >>> xs = torch.zeros((3, 6, 6))
+        >>> make_non_pad_mask(lengths, xs, 1)
+        tensor([[[1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [0, 0, 0, 0, 0, 0]],
+                [[1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0]],
+                [[1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
+        >>> make_non_pad_mask(lengths, xs, 2)
+        tensor([[[1, 1, 1, 1, 1, 0],
+                 [1, 1, 1, 1, 1, 0],
+                 [1, 1, 1, 1, 1, 0],
+                 [1, 1, 1, 1, 1, 0],
+                 [1, 1, 1, 1, 1, 0],
+                 [1, 1, 1, 1, 1, 0]],
+                [[1, 1, 1, 0, 0, 0],
+                 [1, 1, 1, 0, 0, 0],
+                 [1, 1, 1, 0, 0, 0],
+                 [1, 1, 1, 0, 0, 0],
+                 [1, 1, 1, 0, 0, 0],
+                 [1, 1, 1, 0, 0, 0]],
+                [[1, 1, 0, 0, 0, 0],
+                 [1, 1, 0, 0, 0, 0],
+                 [1, 1, 0, 0, 0, 0],
+                 [1, 1, 0, 0, 0, 0],
+                 [1, 1, 0, 0, 0, 0],
+                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
+    """
+    return ~make_pad_mask(lengths, xs, length_dim)
+
+
+def get_mask_from_lengths(lengths):
+    max_len = torch.max(lengths).item()
+    ids = torch.arange(0, max_len).to(lengths.device)
+    mask = (ids < lengths.unsqueeze(1)).bool()
+    return mask
+
+
+def group_hidden_by_segs(h, seg_ids, max_len):
+    """
+
+    :param h: [B, T, H]
+    :param seg_ids: [B, T]
+    :return: h_ph: [B, T_ph, H]
+    """
+    B, T, H = h.shape
+    h_gby_segs = h.new_zeros([B, max_len + 1, H]).scatter_add_(1, seg_ids[:, :, None].repeat([1, 1, H]), h)
+    all_ones = h.new_ones(h.shape[:2])
+    cnt_gby_segs = h.new_zeros([B, max_len + 1]).scatter_add_(1, seg_ids, all_ones).contiguous()
+    h_gby_segs = h_gby_segs[:, 1:]
+    cnt_gby_segs = cnt_gby_segs[:, 1:]
+    h_gby_segs = h_gby_segs / torch.clamp(cnt_gby_segs[:, :, None], min=1)
+    return h_gby_segs, cnt_gby_segs
diff --git a/vocoders/__init__.py b/vocoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..50e4abf21d1cd113f65d353f0101e3550de3bac3
--- /dev/null
+++ b/vocoders/__init__.py
@@ -0,0 +1,2 @@
+from vocoders import hifigan
+from vocoders import fastdiff
diff --git a/vocoders/base_vocoder.py b/vocoders/base_vocoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe49a9e4f790ecdc5e76d60a23f96602b59fc48d
--- /dev/null
+++ b/vocoders/base_vocoder.py
@@ -0,0 +1,39 @@
+import importlib
+VOCODERS = {}
+
+
+def register_vocoder(cls):
+    VOCODERS[cls.__name__.lower()] = cls
+    VOCODERS[cls.__name__] = cls
+    return cls
+
+
+def get_vocoder_cls(hparams):
+    if hparams['vocoder'] in VOCODERS:
+        return VOCODERS[hparams['vocoder']]
+    else:
+        vocoder_cls = hparams['vocoder']
+        pkg = ".".join(vocoder_cls.split(".")[:-1])
+        cls_name = vocoder_cls.split(".")[-1]
+        vocoder_cls = getattr(importlib.import_module(pkg), cls_name)
+        return vocoder_cls
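+
+# Example (hypothetical hparams dict): get_vocoder_cls({'vocoder': 'HifiGAN'}) returns the
+# class registered under that name; a dotted module path such as 'vocoders.pwg.PWG' is
+# imported dynamically instead.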
+
+
+class BaseVocoder:
+    def spec2wav(self, mel):
+        """
+
+        :param mel: [T, 80]
+        :return: wav: [T']
+        """
+
+        raise NotImplementedError
+
+    @staticmethod
+    def wav2spec(wav_fn):
+        """
+
+        :param wav_fn: str
+        :return: wav, mel: [T, 80]
+        """
+        raise NotImplementedError
diff --git a/vocoders/fastdiff.py b/vocoders/fastdiff.py
new file mode 100644
index 0000000000000000000000000000000000000000..1769085832bfc902eeff0155b788141ae194e85e
--- /dev/null
+++ b/vocoders/fastdiff.py
@@ -0,0 +1,162 @@
+import glob
+import re
+import librosa
+import torch
+import yaml
+from sklearn.preprocessing import StandardScaler
+from torch import nn
+from modules.FastDiff.module.FastDiff_model import FastDiff as FastDiff_model
+from utils.hparams import hparams
+from modules.parallel_wavegan.utils import read_hdf5
+from vocoders.base_vocoder import BaseVocoder, register_vocoder
+import numpy as np
+from modules.FastDiff.module.util import theta_timestep_loss, compute_hyperparams_given_schedule, sampling_given_noise_schedule
+
+def load_fastdiff_model(config_path, checkpoint_path):
+    # load config
+    with open(config_path) as f:
+        config = yaml.load(f, Loader=yaml.Loader)
+
+    # setup
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+    model = FastDiff_model(audio_channels=config['audio_channels'],
+                 inner_channels=config['inner_channels'],
+                 cond_channels=config['cond_channels'],
+                 upsample_ratios=config['upsample_ratios'],
+                 lvc_layers_each_block=config['lvc_layers_each_block'],
+                 lvc_kernel_size=config['lvc_kernel_size'],
+                 kpnet_hidden_channels=config['kpnet_hidden_channels'],
+                 kpnet_conv_size=config['kpnet_conv_size'],
+                 dropout=config['dropout'],
+                 diffusion_step_embed_dim_in=config['diffusion_step_embed_dim_in'],
+                 diffusion_step_embed_dim_mid=config['diffusion_step_embed_dim_mid'],
+                 diffusion_step_embed_dim_out=config['diffusion_step_embed_dim_out'],
+                 use_weight_norm=config['use_weight_norm'])
+
+    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["state_dict"]["model"], strict=True)
+
+    # Init hyperparameters by linear schedule
+    noise_schedule = torch.linspace(float(config["beta_0"]), float(config["beta_T"]), int(config["T"])).cuda()
+    diffusion_hyperparams = compute_hyperparams_given_schedule(noise_schedule)
+
+    # map diffusion hyperparameters to gpu
+    for key in diffusion_hyperparams:
+        if key in ["beta", "alpha", "sigma"]:
+            diffusion_hyperparams[key] = diffusion_hyperparams[key].cuda()
+
+    if config['noise_schedule'] != '':
+        noise_schedule = config['noise_schedule']
+        if isinstance(noise_schedule, list):
+            noise_schedule = torch.FloatTensor(noise_schedule).cuda()
+    else:
+        # Select Schedule
+        try:
+            reverse_step = int(hparams.get('N'))
+        except (TypeError, ValueError):
+            print('Please specify N (the number of reverse iterations) in the config file. Falling back to 4 iterations.')
+            reverse_step = 4
+        if reverse_step == 1000:
+            noise_schedule = torch.linspace(0.000001, 0.01, 1000).cuda()
+        elif reverse_step == 200:
+            noise_schedule = torch.linspace(0.0001, 0.02, 200).cuda()
+
+        # Below are schedules derived by Noise Predictor
+        elif reverse_step == 8:
+            noise_schedule = [6.689325005027058e-07, 1.0033881153503899e-05, 0.00015496854030061513,
+                             0.002387222135439515, 0.035597629845142365, 0.3681158423423767, 0.4735414385795593, 0.5]
+        elif reverse_step == 6:
+            noise_schedule = [1.7838445955931093e-06, 2.7984189728158526e-05, 0.00043231004383414984,
+                              0.006634317338466644, 0.09357017278671265, 0.6000000238418579]
+        elif reverse_step == 4:
+            noise_schedule = [3.2176e-04, 2.5743e-03, 2.5376e-02, 7.0414e-01]
+        elif reverse_step == 3:
+            noise_schedule = [9.0000e-05, 9.0000e-03, 6.0000e-01]
+        else:
+            raise NotImplementedError
+
+    if isinstance(noise_schedule, list):
+        noise_schedule = torch.FloatTensor(noise_schedule).cuda()
+
+    model.remove_weight_norm()
+    model = model.eval().to(device)
+    print(f"| Loaded model parameters from {checkpoint_path}.")
+    print(f"| FastDiff device: {device}.")
+    return model, diffusion_hyperparams, noise_schedule, config, device
+
+
+@register_vocoder
+class FastDiff(BaseVocoder):
+    def __init__(self):
+        if hparams['vocoder_ckpt'] == '':  # load LJSpeech FastDiff pretrained model
+            base_dir = 'checkpoint/FastDiff'
+        else:
+            base_dir = hparams['vocoder_ckpt']
+            print(base_dir)
+        config_path = f'{base_dir}/config.yaml'
+        ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'),
+                      key=lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))[-1]
+        print('| load FastDiff: ', ckpt)
+        self.scaler = None
+        self.model, self.dh, self.noise_schedule, self.config, self.device = load_fastdiff_model(
+            config_path=config_path,
+            checkpoint_path=ckpt,
+        )
+
+    def spec2wav(self, mel, **kwargs):
+        # start generation
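+        # Reverse diffusion: start from Gaussian noise of length T_mel * hop_size and
+        # iteratively denoise it, conditioned on the mel spectrogram, following
+        # self.noise_schedule and the precomputed diffusion hyperparameters.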
+        device = self.device
+        with torch.no_grad():
+            c = torch.FloatTensor(mel).unsqueeze(0).transpose(2, 1).to(device)
+            audio_length = c.shape[-1] * hparams["hop_size"]
+            y = sampling_given_noise_schedule(
+                self.model, (1, 1, audio_length), self.dh, self.noise_schedule, condition=c, ddim=False, return_sequence=False)
+        wav_out = y.cpu().numpy()
+        return wav_out
+
+    @staticmethod
+    def wav2spec(wav_fn, return_linear=False):
+        from data_gen.tts.data_gen_utils import process_utterance
+        res = process_utterance(
+            wav_fn, fft_size=hparams['fft_size'],
+            hop_size=hparams['hop_size'],
+            win_length=hparams['win_size'],
+            num_mels=hparams['audio_num_mel_bins'],
+            fmin=hparams['fmin'],
+            fmax=hparams['fmax'],
+            sample_rate=hparams['audio_sample_rate'],
+            loud_norm=hparams['loud_norm'],
+            min_level_db=hparams['min_level_db'],
+            return_linear=return_linear, vocoder='fastdiff', eps=float(hparams.get('wav2spec_eps', 1e-10)))
+        if return_linear:
+            return res[0], res[1].T, res[2].T  # [T, 80], [T, n_fft]
+        else:
+            return res[0], res[1].T
+
+    @staticmethod
+    def wav2mfcc(wav_fn):
+        fft_size = hparams['fft_size']
+        hop_size = hparams['hop_size']
+        win_length = hparams['win_size']
+        sample_rate = hparams['audio_sample_rate']
+        wav, _ = librosa.core.load(wav_fn, sr=sample_rate)
+        mfcc = librosa.feature.mfcc(y=wav, sr=sample_rate, n_mfcc=13,
+                                    n_fft=fft_size, hop_length=hop_size,
+                                    win_length=win_length, pad_mode="constant", power=1.0)
+        mfcc_delta = librosa.feature.delta(mfcc, order=1)
+        mfcc_delta_delta = librosa.feature.delta(mfcc, order=2)
+        mfcc = np.concatenate([mfcc, mfcc_delta, mfcc_delta_delta]).T
+        return mfcc
diff --git a/vocoders/hifigan.py b/vocoders/hifigan.py
new file mode 100644
index 0000000000000000000000000000000000000000..810d3c931b556387f8a2e85537f4964add1e76b0
--- /dev/null
+++ b/vocoders/hifigan.py
@@ -0,0 +1,76 @@
+import glob
+import json
+import os
+import re
+
+import librosa
+import torch
+
+import utils
+from modules.hifigan.hifigan import HifiGanGenerator
+from utils.hparams import hparams, set_hparams
+from vocoders.base_vocoder import register_vocoder
+from vocoders.pwg import PWG
+from vocoders.vocoder_utils import denoise
+
+
+def load_model(config_path, checkpoint_path):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    ckpt_dict = torch.load(checkpoint_path, map_location="cpu")
+    if '.yaml' in config_path:
+        config = set_hparams(config_path, global_hparams=False)
+        state = ckpt_dict["state_dict"]["model_gen"]
+    elif '.json' in config_path:
+        config = json.load(open(config_path, 'r'))
+        state = ckpt_dict["generator"]
+    else:
+        raise ValueError(f"Unsupported vocoder config format: {config_path}")
+
+    model = HifiGanGenerator(config)
+    model.load_state_dict(state, strict=True)
+    model.remove_weight_norm()
+    model = model.eval().to(device)
+    print(f"| Loaded model parameters from {checkpoint_path}.")
+    print(f"| HifiGAN device: {device}.")
+    return model, config, device
+
+
+total_time = 0
+
+
+@register_vocoder
+class HifiGAN(PWG):
+    def __init__(self):
+        base_dir = hparams['vocoder_ckpt']
+        config_path = f'{base_dir}/config.yaml'
+        if os.path.exists(config_path):
+            ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'),
+                          key=lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))[-1]
+            print('| load HifiGAN: ', ckpt)
+            self.model, self.config, self.device = load_model(config_path=config_path, checkpoint_path=ckpt)
+        else:
+            config_path = f'{base_dir}/config.json'
+            ckpt = f'{base_dir}/generator_v1'
+            if os.path.exists(config_path):
+                self.model, self.config, self.device = load_model(config_path=config_path, checkpoint_path=ckpt)
+
+    def spec2wav(self, mel, **kwargs):
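+        # mel comes in as [T, 80] and is reshaped to [1, 80, T]; when hparams['use_nsf'] is set,
+        # the generator is additionally conditioned on the f0 contour passed via kwargs.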
+        device = self.device
+        with torch.no_grad():
+            c = torch.FloatTensor(mel).unsqueeze(0).transpose(2, 1).to(device)
+            with utils.Timer('hifigan', print_time=hparams['profile_infer']):
+                f0 = kwargs.get('f0')
+                if f0 is not None and hparams.get('use_nsf'):
+                    f0 = torch.FloatTensor(f0[None, :]).to(device)
+                    y = self.model(c, f0).view(-1)
+                else:
+                    y = self.model(c).view(-1)
+        wav_out = y.cpu().numpy()
+        if hparams.get('vocoder_denoise_c', 0.0) > 0:
+            wav_out = denoise(wav_out, v=hparams['vocoder_denoise_c'])
+        return wav_out
+
+    # @staticmethod
+    # def wav2spec(wav_fn, **kwargs):
+    #     wav, _ = librosa.core.load(wav_fn, sr=hparams['audio_sample_rate'])
+    #     wav_torch = torch.FloatTensor(wav)[None, :]
+    #     mel = mel_spectrogram(wav_torch, hparams).numpy()[0]
+    #     return wav, mel.T
diff --git a/vocoders/pwg.py b/vocoders/pwg.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca9b6891ab2ba5cb413eeca97a41534e5db129d5
--- /dev/null
+++ b/vocoders/pwg.py
@@ -0,0 +1,137 @@
+import glob
+import re
+import librosa
+import torch
+import yaml
+from sklearn.preprocessing import StandardScaler
+from torch import nn
+from modules.parallel_wavegan.models import ParallelWaveGANGenerator
+from modules.parallel_wavegan.utils import read_hdf5
+from utils.hparams import hparams
+from utils.pitch_utils import f0_to_coarse
+from vocoders.base_vocoder import BaseVocoder, register_vocoder
+import numpy as np
+
+
+def load_pwg_model(config_path, checkpoint_path, stats_path):
+    # load config
+    with open(config_path) as f:
+        config = yaml.load(f, Loader=yaml.Loader)
+
+    # setup
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+    model = ParallelWaveGANGenerator(**config["generator_params"])
+
+    ckpt_dict = torch.load(checkpoint_path, map_location="cpu")
+    if 'state_dict' not in ckpt_dict:  # official vocoder
+        model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["model"]["generator"])
+        scaler = StandardScaler()
+        if config["format"] == "hdf5":
+            scaler.mean_ = read_hdf5(stats_path, "mean")
+            scaler.scale_ = read_hdf5(stats_path, "scale")
+        elif config["format"] == "npy":
+            scaler.mean_ = np.load(stats_path)[0]
+            scaler.scale_ = np.load(stats_path)[1]
+        else:
+            raise ValueError("support only hdf5 or npy format.")
+    else:  # custom PWG vocoder
+        fake_task = nn.Module()
+        fake_task.model_gen = model
+        fake_task.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["state_dict"], strict=False)
+        scaler = None
+
+    model.remove_weight_norm()
+    model = model.eval().to(device)
+    print(f"| Loaded model parameters from {checkpoint_path}.")
+    print(f"| PWG device: {device}.")
+    return model, scaler, config, device
+
+
+@register_vocoder
+class PWG(BaseVocoder):
+    def __init__(self):
+        if hparams['vocoder_ckpt'] == '':  # load LJSpeech PWG pretrained model
+            base_dir = 'wavegan_pretrained'
+            ckpts = glob.glob(f'{base_dir}/checkpoint-*steps.pkl')
+            ckpt = sorted(ckpts,
+                          key=lambda x: int(re.findall(f'{base_dir}/checkpoint-(\d+)steps.pkl', x)[0]))[-1]
+            config_path = f'{base_dir}/config.yaml'
+            print('| load PWG: ', ckpt)
+            self.model, self.scaler, self.config, self.device = load_pwg_model(
+                config_path=config_path,
+                checkpoint_path=ckpt,
+                stats_path=f'{base_dir}/stats.h5',
+            )
+        else:
+            base_dir = hparams['vocoder_ckpt']
+            print(base_dir)
+            config_path = f'{base_dir}/config.yaml'
+            ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'),
+                          key=lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))[-1]
+            print('| load PWG: ', ckpt)
+            self.scaler = None
+            self.model, _, self.config, self.device = load_pwg_model(
+                config_path=config_path,
+                checkpoint_path=ckpt,
+                stats_path=f'{base_dir}/stats.h5',
+            )
+
+    def spec2wav(self, mel, **kwargs):
+        # start generation
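+        # ParallelWaveGAN consumes Gaussian noise z of the target waveform length plus the mel
+        # conditioning padded by aux_context_window frames on each side; an optional coarse f0
+        # sequence (from kwargs) is padded and passed the same way.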
+        config = self.config
+        device = self.device
+        pad_size = (config["generator_params"]["aux_context_window"],
+                    config["generator_params"]["aux_context_window"])
+        c = mel
+        if self.scaler is not None:
+            c = self.scaler.transform(c)
+
+        with torch.no_grad():
+            z = torch.randn(1, 1, c.shape[0] * config["hop_size"]).to(device)
+            c = np.pad(c, (pad_size, (0, 0)), "edge")
+            c = torch.FloatTensor(c).unsqueeze(0).transpose(2, 1).to(device)
+            p = kwargs.get('f0')
+            if p is not None:
+                p = f0_to_coarse(p)
+                p = np.pad(p, (pad_size,), "edge")
+                p = torch.LongTensor(p[None, :]).to(device)
+            y = self.model(z, c, p).view(-1)
+        wav_out = y.cpu().numpy()
+        return wav_out
+
+    @staticmethod
+    def wav2spec(wav_fn, return_linear=False):
+        from data_gen.tts.data_gen_utils import process_utterance
+        res = process_utterance(
+            wav_fn, fft_size=hparams['fft_size'],
+            hop_size=hparams['hop_size'],
+            win_length=hparams['win_size'],
+            num_mels=hparams['audio_num_mel_bins'],
+            fmin=hparams['fmin'],
+            fmax=hparams['fmax'],
+            sample_rate=hparams['audio_sample_rate'],
+            loud_norm=hparams['loud_norm'],
+            min_level_db=hparams['min_level_db'],
+            return_linear=return_linear, vocoder='pwg', eps=float(hparams.get('wav2spec_eps', 1e-10)))
+        if return_linear:
+            return res[0], res[1].T, res[2].T  # [T, 80], [T, n_fft]
+        else:
+            return res[0], res[1].T
+
+    @staticmethod
+    def wav2mfcc(wav_fn):
+        fft_size = hparams['fft_size']
+        hop_size = hparams['hop_size']
+        win_length = hparams['win_size']
+        sample_rate = hparams['audio_sample_rate']
+        wav, _ = librosa.core.load(wav_fn, sr=sample_rate)
+        mfcc = librosa.feature.mfcc(y=wav, sr=sample_rate, n_mfcc=13,
+                                    n_fft=fft_size, hop_length=hop_size,
+                                    win_length=win_length, pad_mode="constant", power=1.0)
+        mfcc_delta = librosa.feature.delta(mfcc, order=1)
+        mfcc_delta_delta = librosa.feature.delta(mfcc, order=2)
+        mfcc = np.concatenate([mfcc, mfcc_delta, mfcc_delta_delta]).T
+        return mfcc
diff --git a/vocoders/vocoder_utils.py b/vocoders/vocoder_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..db5d5ca1765928e4b047db04435a8a39b52592ca
--- /dev/null
+++ b/vocoders/vocoder_utils.py
@@ -0,0 +1,15 @@
+import librosa
+
+from utils.hparams import hparams
+import numpy as np
+
+
+def denoise(wav, v=0.1):
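+    # Simple spectral subtraction: subtract a constant floor v from the STFT magnitude,
+    # clip at zero, and resynthesize with the original phase.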
+    spec = librosa.stft(y=wav, n_fft=hparams['fft_size'], hop_length=hparams['hop_size'],
+                        win_length=hparams['win_size'], pad_mode='constant')
+    spec_m = np.abs(spec)
+    spec_m = np.clip(spec_m - v, a_min=0, a_max=None)
+    spec_a = np.angle(spec)
+
+    return librosa.istft(spec_m * np.exp(1j * spec_a), hop_length=hparams['hop_size'],
+                         win_length=hparams['win_size'])