diff --git a/TTS/.models.json b/TTS/.models.json new file mode 100644 index 0000000000000000000000000000000000000000..93d9f417be9313976151dc8ff7c4ee67e41418f7 --- /dev/null +++ b/TTS/.models.json @@ -0,0 +1,500 @@ +{ + "tts_models": { + "multilingual":{ + "multi-dataset":{ + "your_tts":{ + "description": "Your TTS model accompanying the paper https://arxiv.org/abs/2112.02418", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--multilingual--multi-dataset--your_tts.zip", + "default_vocoder": null, + "commit": "e9a1953e", + "license": "CC BY-NC-ND 4.0", + "contact": "egolge@coqui.ai" + } + } + }, + "en": { + "ek1": { + "tacotron2": { + "description": "EK1 en-rp tacotron2 by NMStoker", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--en--ek1--tacotron2.zip", + "default_vocoder": "vocoder_models/en/ek1/wavegrad", + "commit": "c802255", + "license": "apache 2.0" + } + }, + "ljspeech": { + "tacotron2-DDC": { + "description": "Tacotron2 with Double Decoder Consistency.", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--en--ljspeech--tacotron2-DDC.zip", + "default_vocoder": "vocoder_models/en/ljspeech/hifigan_v2", + "commit": "bae2ad0f", + "author": "Eren Gölge @erogol", + "license": "apache 2.0", + "contact": "egolge@coqui.com" + }, + "tacotron2-DDC_ph": { + "description": "Tacotron2 with Double Decoder Consistency with phonemes.", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--en--ljspeech--tacotron2-DDC_ph.zip", + "default_vocoder": "vocoder_models/en/ljspeech/univnet", + "commit": "3900448", + "author": "Eren Gölge @erogol", + "license": "apache 2.0", + "contact": "egolge@coqui.com" + }, + "glow-tts": { + "description": "", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--en--ljspeech--glow-tts.zip", + "stats_file": null, + "default_vocoder": "vocoder_models/en/ljspeech/multiband-melgan", + "commit": "", + "author": "Eren Gölge @erogol", + "license": "MPL", + "contact": "egolge@coqui.com" + }, + "speedy-speech": { + "description": "Speedy Speech model trained on LJSpeech dataset using the Alignment Network for learning the durations.", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--en--ljspeech--speedy-speech.zip", + "stats_file": null, + "default_vocoder": "vocoder_models/en/ljspeech/hifigan_v2", + "commit": "4581e3d", + "author": "Eren Gölge @erogol", + "license": "apache 2.0", + "contact": "egolge@coqui.com" + }, + "tacotron2-DCA": { + "description": "", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--en--ljspeech--tacotron2-DCA.zip", + "default_vocoder": "vocoder_models/en/ljspeech/multiband-melgan", + "commit": "", + "author": "Eren Gölge @erogol", + "license": "MPL", + "contact": "egolge@coqui.com" + }, + "vits": { + "description": "VITS is an End2End TTS model trained on LJSpeech dataset with phonemes.", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--en--ljspeech--vits.zip", + "default_vocoder": null, + "commit": "3900448", + "author": "Eren Gölge @erogol", + "license": "apache 2.0", + "contact": "egolge@coqui.com" + }, + "fast_pitch": { + "description": "FastPitch model trained on LJSpeech using the Aligner Network", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--en--ljspeech--fast_pitch.zip", + "default_vocoder": "vocoder_models/en/ljspeech/hifigan_v2", + "commit": "b27b3ba", + "author": "Eren Gölge @erogol", + "license": 
"apache 2.0", + "contact": "egolge@coqui.com" + } + }, + "vctk": { + "vits": { + "description": "VITS End2End TTS model trained on VCTK dataset with 109 different speakers with EN accent.", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--en--vctk--vits.zip", + "default_vocoder": null, + "commit": "3900448", + "author": "Eren @erogol", + "license": "apache 2.0", + "contact": "egolge@coqui.ai" + }, + "fast_pitch":{ + "description": "FastPitch model trained on VCTK dataseset.", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--en--vctk--fast_pitch.zip", + "default_vocoder": null, + "commit": "bdab788d", + "author": "Eren @erogol", + "license": "CC BY-NC-ND 4.0", + "contact": "egolge@coqui.ai" + } + }, + "sam": { + "tacotron-DDC": { + "description": "Tacotron2 with Double Decoder Consistency trained with Aceenture's Sam dataset.", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--en--sam--tacotron-DDC.zip", + "default_vocoder": "vocoder_models/en/sam/hifigan_v2", + "commit": "bae2ad0f", + "author": "Eren Gölge @erogol", + "license": "apache 2.0", + "contact": "egolge@coqui.com" + } + }, + "blizzard2013": { + "capacitron-t2-c50": { + "description": "Capacitron additions to Tacotron 2 with Capacity at 50 as in https://arxiv.org/pdf/1906.03402.pdf", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.7.0_models/tts_models--en--blizzard2013--capacitron-t2-c50.zip", + "commit": "d6284e7", + "default_vocoder": "vocoder_models/en/blizzard2013/hifigan_v2", + "author": "Adam Froghyar @a-froghyar", + "license": "apache 2.0", + "contact": "adamfroghyar@gmail.com" + }, + "capacitron-t2-c150": { + "description": "Capacitron additions to Tacotron 2 with Capacity at 150 as in https://arxiv.org/pdf/1906.03402.pdf", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.7.0_models/tts_models--en--blizzard2013--capacitron-t2-c150.zip", + "commit": "d6284e7", + "default_vocoder": "vocoder_models/en/blizzard2013/hifigan_v2", + "author": "Adam Froghyar @a-froghyar", + "license": "apache 2.0", + "contact": "adamfroghyar@gmail.com" + } + } + }, + "es": { + "mai": { + "tacotron2-DDC": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--es--mai--tacotron2-DDC.zip", + "default_vocoder": "vocoder_models/universal/libri-tts/fullband-melgan", + "commit": "", + "author": "Eren Gölge @erogol", + "license": "MPL", + "contact": "egolge@coqui.com" + } + } + }, + "fr": { + "mai": { + "tacotron2-DDC": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--fr--mai--tacotron2-DDC.zip", + "default_vocoder": "vocoder_models/universal/libri-tts/fullband-melgan", + "commit": "", + "author": "Eren Gölge @erogol", + "license": "MPL", + "contact": "egolge@coqui.com" + } + } + }, + "uk":{ + "mai": { + "glow-tts": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--uk--mai--glow-tts.zip", + "author":"@robinhad", + "commit": "bdab788d", + "license": "MIT", + "contact": "", + "default_vocoder": "vocoder_models/uk/mai/multiband-melgan" + } + } + }, + "zh-CN": { + "baker": { + "tacotron2-DDC-GST": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--zh-CN--baker--tacotron2-DDC-GST.zip", + "commit": "unknown", + "author": "@kirianguiller", + "license": "apache 2.0", + "default_vocoder": null + } + } + }, + "nl": { + "mai": { + "tacotron2-DDC": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--nl--mai--tacotron2-DDC.zip", 
+ "author": "@r-dh", + "license": "apache 2.0", + "default_vocoder": "vocoder_models/nl/mai/parallel-wavegan", + "stats_file": null, + "commit": "540d811" + } + } + }, + "de": { + "thorsten": { + "tacotron2-DCA": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--de--thorsten--tacotron2-DCA.zip", + "default_vocoder": "vocoder_models/de/thorsten/fullband-melgan", + "author": "@thorstenMueller", + "license": "apache 2.0", + "commit": "unknown" + }, + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.7.0_models/tts_models--de--thorsten--vits.zip", + "default_vocoder": null, + "author": "@thorstenMueller", + "license": "apache 2.0", + "commit": "unknown" + } + } + }, + "ja": { + "kokoro": { + "tacotron2-DDC": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--ja--kokoro--tacotron2-DDC.zip", + "default_vocoder": "vocoder_models/ja/kokoro/hifigan_v1", + "description": "Tacotron2 with Double Decoder Consistency trained with Kokoro Speech Dataset.", + "author": "@kaiidams", + "license": "apache 2.0", + "commit": "401fbd89" + } + } + }, + "tr":{ + "common-voice": { + "glow-tts":{ + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--tr--common-voice--glow-tts.zip", + "default_vocoder": "vocoder_models/tr/common-voice/hifigan", + "license": "MIT", + "description": "Turkish GlowTTS model using an unknown speaker from the Common-Voice dataset.", + "author": "Fatih Akademi", + "commit": null + } + } + }, + "it": { + "mai_female": { + "glow-tts":{ + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--it--mai_female--glow-tts.zip", + "default_vocoder": null, + "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.", + "author": "@nicolalandro", + "license": "apache 2.0", + "commit": null + }, + "vits":{ + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--it--mai_female--vits.zip", + "default_vocoder": null, + "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.", + "author": "@nicolalandro", + "license": "apache 2.0", + "commit": null + } + }, + "mai_male": { + "glow-tts":{ + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--it--mai_male--glow-tts.zip", + "default_vocoder": null, + "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.", + "author": "@nicolalandro", + "license": "apache 2.0", + "commit": null + }, + "vits":{ + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--it--mai_male--vits.zip", + "default_vocoder": null, + "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.", + "author": "@nicolalandro", + "license": "apache 2.0", + "commit": null + } + } + }, + "ewe": { + "openbible": { + "vits":{ + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.2_models/tts_models--ewe--openbible--vits.zip", + "default_vocoder": null, + "license": "CC-BY-SA 4.0", + "description": "Original work (audio and text) by Biblica available for free at www.biblica.com and open.bible.", + "author": "@coqui_ai", + "commit": "1b22f03" + } + } + }, + "hau": { + "openbible": { + "vits":{ + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.2_models/tts_models--hau--openbible--vits.zip", + "default_vocoder": null, + "license": "CC-BY-SA 4.0", + "description": "Original work (audio and text) by Biblica available for free at www.biblica.com and open.bible.", + "author": 
"@coqui_ai", + "commit": "1b22f03" + } + } + }, + "lin": { + "openbible": { + "vits":{ + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.2_models/tts_models--lin--openbible--vits.zip", + "default_vocoder": null, + "license": "CC-BY-SA 4.0", + "description": "Original work (audio and text) by Biblica available for free at www.biblica.com and open.bible.", + "author": "@coqui_ai", + "commit": "1b22f03" + } + } + }, + "tw_akuapem": { + "openbible": { + "vits":{ + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.2_models/tts_models--tw_akuapem--openbible--vits.zip", + "default_vocoder": null, + "license": "CC-BY-SA 4.0", + "description": "Original work (audio and text) by Biblica available for free at www.biblica.com and open.bible.", + "author": "@coqui_ai", + "commit": "1b22f03" + } + } + }, + "tw_asante": { + "openbible": { + "vits":{ + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.2_models/tts_models--tw_asante--openbible--vits.zip", + "default_vocoder": null, + "license": "CC-BY-SA 4.0", + "description": "Original work (audio and text) by Biblica available for free at www.biblica.com and open.bible.", + "author": "@coqui_ai", + "commit": "1b22f03" + } + } + }, + "yor": { + "openbible": { + "vits":{ + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.2_models/tts_models--yor--openbible--vits.zip", + "default_vocoder": null, + "license": "CC-BY-SA 4.0", + "description": "Original work (audio and text) by Biblica available for free at www.biblica.com and open.bible.", + "author": "@coqui_ai", + "commit": "1b22f03" + } + } + } + }, + "vocoder_models": { + "universal": { + "libri-tts": { + "wavegrad": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--universal--libri-tts--wavegrad.zip", + "commit": "ea976b0", + "author": "Eren Gölge @erogol", + "license": "MPL", + "contact": "egolge@coqui.com" + }, + "fullband-melgan": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--universal--libri-tts--fullband-melgan.zip", + "commit": "4132240", + "author": "Eren Gölge @erogol", + "license": "MPL", + "contact": "egolge@coqui.com" + } + } + }, + "en": { + "ek1": { + "wavegrad": { + "description": "EK1 en-rp wavegrad by NMStoker", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--en--ek1--wavegrad.zip", + "commit": "c802255", + "license": "apache 2.0" + } + }, + "ljspeech": { + "multiband-melgan": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--en--ljspeech--multiband-melgan.zip", + "commit": "ea976b0", + "author": "Eren Gölge @erogol", + "license": "MPL", + "contact": "egolge@coqui.com" + }, + "hifigan_v2": { + "description": "HiFiGAN_v2 LJSpeech vocoder from https://arxiv.org/abs/2010.05646.", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--en--ljspeech--hifigan_v2.zip", + "commit": "bae2ad0f", + "author": "@erogol", + "license": "apache 2.0", + "contact": "egolge@coqui.ai" + }, + "univnet": { + "description": "UnivNet model finetuned on TacotronDDC_ph spectrograms for better compatibility.", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--en--ljspeech--univnet_v2.zip", + "commit": "4581e3d", + "author": "Eren @erogol", + "license": "apache 2.0", + "contact": "egolge@coqui.ai" + } + }, + "blizzard2013": { + "hifigan_v2": { + "description": "HiFiGAN_v2 LJSpeech vocoder from https://arxiv.org/abs/2010.05646.", + "github_rls_url": 
"https://coqui.gateway.scarf.sh/v0.7.0_models/vocoder_models--en--blizzard2013--hifigan_v2.zip", + "commit": "d6284e7", + "author": "Adam Froghyar @a-froghyar", + "license": "apache 2.0", + "contact": "adamfroghyar@gmail.com" + } + }, + "vctk": { + "hifigan_v2": { + "description": "Finetuned and intended to be used with tts_models/en/vctk/sc-glow-tts", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--en--vctk--hifigan_v2.zip", + "commit": "2f07160", + "author": "Edresson Casanova", + "license": "apache 2.0", + "contact": "" + } + }, + "sam": { + "hifigan_v2": { + "description": "Finetuned and intended to be used with tts_models/en/sam/tacotron_DDC", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--en--sam--hifigan_v2.zip", + "commit": "2f07160", + "author": "Eren Gölge @erogol", + "license": "apache 2.0", + "contact": "egolge@coqui.ai" + } + } + }, + "nl": { + "mai": { + "parallel-wavegan": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--nl--mai--parallel-wavegan.zip", + "author": "@r-dh", + "license": "apache 2.0", + "commit": "unknown" + } + } + }, + "de": { + "thorsten": { + "wavegrad": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--de--thorsten--wavegrad.zip", + "author": "@thorstenMueller", + "license": "apache 2.0", + "commit": "unknown" + }, + "fullband-melgan": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--de--thorsten--fullband-melgan.zip", + "author": "@thorstenMueller", + "license": "apache 2.0", + "commit": "unknown" + } + } + }, + "ja": { + "kokoro": { + "hifigan_v1": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--ja--kokoro--hifigan_v1.zip", + "description": "HifiGAN model trained for kokoro dataset by @kaiidams", + "author": "@kaiidams", + "license": "apache 2.0", + "commit": "3900448" + } + } + }, + "uk": { + "mai": { + "multiband-melgan": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--uk--mai--multiband-melgan.zip", + "author":"@robinhad", + "commit": "bdab788d", + "license": "MIT", + "contact": "" + } + } + }, + "tr":{ + "common-voice": { + "hifigan":{ + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--tr--common-voice--hifigan.zip", + "description": "HifiGAN model using an unknown speaker from the Common-Voice dataset.", + "author": "Fatih Akademi", + "license": "MIT", + "commit": null + } + } + } + } +} \ No newline at end of file diff --git a/TTS/__init__.py b/TTS/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eaf05db1b950d82bfd7e20857e09a0fef45b430a --- /dev/null +++ b/TTS/__init__.py @@ -0,0 +1,6 @@ +import os + +with open(os.path.join(os.path.dirname(__file__), "VERSION"), "r", encoding="utf-8") as f: + version = f.read().strip() + +__version__ = version diff --git a/TTS/bin/__init__.py b/TTS/bin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/bin/collect_env_info.py b/TTS/bin/collect_env_info.py new file mode 100644 index 0000000000000000000000000000000000000000..662fcd02ece0fad387b6bfc4bad9316c7e2a0bad --- /dev/null +++ b/TTS/bin/collect_env_info.py @@ -0,0 +1,48 @@ +"""Get detailed info about the working environment.""" +import os +import platform +import sys + +import numpy +import torch + +sys.path += [os.path.abspath(".."), os.path.abspath(".")] +import json + 
+import TTS + + +def system_info(): + return { + "OS": platform.system(), + "architecture": platform.architecture(), + "version": platform.version(), + "processor": platform.processor(), + "python": platform.python_version(), + } + + +def cuda_info(): + return { + "GPU": [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())], + "available": torch.cuda.is_available(), + "version": torch.version.cuda, + } + + +def package_info(): + return { + "numpy": numpy.__version__, + "PyTorch_version": torch.__version__, + "PyTorch_debug": torch.version.debug, + "TTS": TTS.__version__, + } + + +def main(): + details = {"System": system_info(), "CUDA": cuda_info(), "Packages": package_info()} + print(json.dumps(details, indent=4, sort_keys=True)) + + +if __name__ == "__main__": + main() diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py new file mode 100644 index 0000000000000000000000000000000000000000..9ab520be7d9f41ecf4f124446400b5e1b597ae8b --- /dev/null +++ b/TTS/bin/compute_attention_masks.py @@ -0,0 +1,165 @@ +import argparse +import importlib +import os +from argparse import RawTextHelpFormatter + +import numpy as np +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from TTS.config import load_config +from TTS.tts.datasets.TTSDataset import TTSDataset +from TTS.tts.models import setup_model +from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols +from TTS.utils.audio import AudioProcessor +from TTS.utils.io import load_checkpoint + +if __name__ == "__main__": + # pylint: disable=bad-option-value + parser = argparse.ArgumentParser( + description="""Extract attention masks from trained Tacotron/Tacotron2 models. +These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n""" + """Each attention mask is written to the same path as the input wav file with ".npy" file extension. +(e.g. path/bla.wav (wav file) --> path/bla.npy (attention mask))\n""" + """ +Example run: + CUDA_VISIBLE_DEVICE="0" python TTS/bin/compute_attention_masks.py + --model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth + --config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json + --dataset_metafile metadata.csv + --data_path /root/LJSpeech-1.1/ + --batch_size 32 + --dataset ljspeech + --use_cuda True +""", + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument("--model_path", type=str, required=True, help="Path to Tacotron/Tacotron2 model file ") + parser.add_argument( + "--config_path", + type=str, + required=True, + help="Path to Tacotron/Tacotron2 config file.", + ) + parser.add_argument( + "--dataset", + type=str, + default="", + required=True, + help="Target dataset processor name from TTS.tts.dataset.preprocess.", + ) + + parser.add_argument( + "--dataset_metafile", + type=str, + default="", + required=True, + help="Dataset metafile inclusing file paths with transcripts.", + ) + parser.add_argument("--data_path", type=str, default="", help="Defines the data path. It overwrites config.json.") + parser.add_argument("--use_cuda", type=bool, default=False, help="enable/disable cuda.") + + parser.add_argument( + "--batch_size", default=16, type=int, help="Batch size for the model. Use batch_size=1 if you have no CUDA." 
+ ) + args = parser.parse_args() + + C = load_config(args.config_path) + ap = AudioProcessor(**C.audio) + + # if the vocabulary was passed, replace the default + if "characters" in C.keys(): + symbols, phonemes = make_symbols(**C.characters) + + # load the model + num_chars = len(phonemes) if C.use_phonemes else len(symbols) + # TODO: handle multi-speaker + model = setup_model(C) + model, _ = load_checkpoint(model, args.model_path, args.use_cuda, True) + + # data loader + preprocessor = importlib.import_module("TTS.tts.datasets.formatters") + preprocessor = getattr(preprocessor, args.dataset) + meta_data = preprocessor(args.data_path, args.dataset_metafile) + dataset = TTSDataset( + model.decoder.r, + C.text_cleaner, + compute_linear_spec=False, + ap=ap, + meta_data=meta_data, + characters=C.characters if "characters" in C.keys() else None, + add_blank=C["add_blank"] if "add_blank" in C.keys() else False, + use_phonemes=C.use_phonemes, + phoneme_cache_path=C.phoneme_cache_path, + phoneme_language=C.phoneme_language, + enable_eos_bos=C.enable_eos_bos_chars, + ) + + dataset.sort_and_filter_items(C.get("sort_by_audio_len", default=False)) + loader = DataLoader( + dataset, + batch_size=args.batch_size, + num_workers=4, + collate_fn=dataset.collate_fn, + shuffle=False, + drop_last=False, + ) + + # compute attentions + file_paths = [] + with torch.no_grad(): + for data in tqdm(loader): + # setup input data + text_input = data[0] + text_lengths = data[1] + linear_input = data[3] + mel_input = data[4] + mel_lengths = data[5] + stop_targets = data[6] + item_idxs = data[7] + + # dispatch data to GPU + if args.use_cuda: + text_input = text_input.cuda() + text_lengths = text_lengths.cuda() + mel_input = mel_input.cuda() + mel_lengths = mel_lengths.cuda() + + model_outputs = model.forward(text_input, text_lengths, mel_input) + + alignments = model_outputs["alignments"].detach() + for idx, alignment in enumerate(alignments): + item_idx = item_idxs[idx] + # interpolate if r > 1 + alignment = ( + torch.nn.functional.interpolate( + alignment.transpose(0, 1).unsqueeze(0), + size=None, + scale_factor=model.decoder.r, + mode="nearest", + align_corners=None, + recompute_scale_factor=None, + ) + .squeeze(0) + .transpose(0, 1) + ) + # remove paddings + alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy() + # set file paths + wav_file_name = os.path.basename(item_idx) + align_file_name = os.path.splitext(wav_file_name)[0] + "_attn.npy" + file_path = item_idx.replace(wav_file_name, align_file_name) + # save output + wav_file_abs_path = os.path.abspath(item_idx) + file_abs_path = os.path.abspath(file_path) + file_paths.append([wav_file_abs_path, file_abs_path]) + np.save(file_path, alignment) + + # ourput metafile + metafile = os.path.join(args.data_path, "metadata_attn_mask.txt") + + with open(metafile, "w", encoding="utf-8") as f: + for p in file_paths: + f.write(f"{p[0]}|{p[1]}\n") + print(f" >> Metafile created: {metafile}") diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..d7fe3c4bdedf2045ee503b669622695932942145 --- /dev/null +++ b/TTS/bin/compute_embeddings.py @@ -0,0 +1,84 @@ +import argparse +import os +from argparse import RawTextHelpFormatter + +import torch +from tqdm import tqdm + +from TTS.config import load_config +from TTS.tts.datasets import load_tts_samples +from TTS.tts.utils.managers import save_file +from TTS.tts.utils.speakers import SpeakerManager + +parser = 
argparse.ArgumentParser( + description="""Compute embedding vectors for each wav file in a dataset.\n\n""" + """ + Example runs: + python TTS/bin/compute_embeddings.py speaker_encoder_model.pth speaker_encoder_config.json dataset_config.json + """, + formatter_class=RawTextHelpFormatter, +) +parser.add_argument("model_path", type=str, help="Path to model checkpoint file.") +parser.add_argument("config_path", type=str, help="Path to model config file.") +parser.add_argument("config_dataset_path", type=str, help="Path to dataset config file.") +parser.add_argument("--output_path", type=str, help="Path for output `pth` or `json` file.", default="speakers.pth") +parser.add_argument("--old_file", type=str, help="Previous embedding file to only compute new audios.", default=None) +parser.add_argument("--disable_cuda", type=bool, help="Flag to disable cuda.", default=False) +parser.add_argument("--no_eval", type=bool, help="Do not compute eval?. Default False", default=False) + +args = parser.parse_args() + +use_cuda = torch.cuda.is_available() and not args.disable_cuda + +c_dataset = load_config(args.config_dataset_path) + +meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=not args.no_eval) + +if meta_data_eval is None: + wav_files = meta_data_train +else: + wav_files = meta_data_train + meta_data_eval + +encoder_manager = SpeakerManager( + encoder_model_path=args.model_path, + encoder_config_path=args.config_path, + d_vectors_file_path=args.old_file, + use_cuda=use_cuda, +) + +class_name_key = encoder_manager.encoder_config.class_name_key + +# compute speaker embeddings +speaker_mapping = {} +for idx, wav_file in enumerate(tqdm(wav_files)): + if isinstance(wav_file, dict): + class_name = wav_file[class_name_key] + wav_file = wav_file["audio_file"] + else: + class_name = None + + wav_file_name = os.path.basename(wav_file) + if args.old_file is not None and wav_file_name in encoder_manager.clip_ids: + # get the embedding from the old file + embedd = encoder_manager.get_embedding_by_clip(wav_file_name) + else: + # extract the embedding + embedd = encoder_manager.compute_embedding_from_clip(wav_file) + + # create speaker_mapping if target dataset is defined + speaker_mapping[wav_file_name] = {} + speaker_mapping[wav_file_name]["name"] = class_name + speaker_mapping[wav_file_name]["embedding"] = embedd + +if speaker_mapping: + # save speaker_mapping if target dataset is defined + if os.path.isdir(args.output_path): + mapping_file_path = os.path.join(args.output_path, "speakers.pth") + else: + mapping_file_path = args.output_path + + if os.path.dirname(mapping_file_path) != "": + os.makedirs(os.path.dirname(mapping_file_path), exist_ok=True) + + save_file(speaker_mapping, mapping_file_path) + print("Speaker embeddings saved at:", mapping_file_path) diff --git a/TTS/bin/compute_statistics.py b/TTS/bin/compute_statistics.py new file mode 100755 index 0000000000000000000000000000000000000000..3ab7ea7a3b10ec3cc23d8a744c7bdc79de52dbf2 --- /dev/null +++ b/TTS/bin/compute_statistics.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import argparse +import glob +import os + +import numpy as np +from tqdm import tqdm + +# from TTS.utils.io import load_config +from TTS.config import load_config +from TTS.tts.datasets import load_tts_samples +from TTS.utils.audio import AudioProcessor + + +def main(): + """Run preprocessing process.""" + parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogtram features.") + 
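+    # Example invocation (file names are hypothetical; config_path and out_path are positional, --data_path is optional):
+    #   python TTS/bin/compute_statistics.py config.json scale_stats.npy --data_path /path/to/wavs/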
parser.add_argument("config_path", type=str, help="TTS config file path to define audio processin parameters.") + parser.add_argument("out_path", type=str, help="save path (directory and filename).") + parser.add_argument( + "--data_path", + type=str, + required=False, + help="folder including the target set of wavs overriding dataset config.", + ) + args, overrides = parser.parse_known_args() + + CONFIG = load_config(args.config_path) + CONFIG.parse_known_args(overrides, relaxed_parser=True) + + # load config + CONFIG.audio.signal_norm = False # do not apply earlier normalization + CONFIG.audio.stats_path = None # discard pre-defined stats + + # load audio processor + ap = AudioProcessor(**CONFIG.audio.to_dict()) + + # load the meta data of target dataset + if args.data_path: + dataset_items = glob.glob(os.path.join(args.data_path, "**", "*.wav"), recursive=True) + else: + dataset_items = load_tts_samples(CONFIG.datasets)[0] # take only train data + print(f" > There are {len(dataset_items)} files.") + + mel_sum = 0 + mel_square_sum = 0 + linear_sum = 0 + linear_square_sum = 0 + N = 0 + for item in tqdm(dataset_items): + # compute features + wav = ap.load_wav(item if isinstance(item, str) else item["audio_file"]) + linear = ap.spectrogram(wav) + mel = ap.melspectrogram(wav) + + # compute stats + N += mel.shape[1] + mel_sum += mel.sum(1) + linear_sum += linear.sum(1) + mel_square_sum += (mel**2).sum(axis=1) + linear_square_sum += (linear**2).sum(axis=1) + + mel_mean = mel_sum / N + mel_scale = np.sqrt(mel_square_sum / N - mel_mean**2) + linear_mean = linear_sum / N + linear_scale = np.sqrt(linear_square_sum / N - linear_mean**2) + + output_file_path = args.out_path + stats = {} + stats["mel_mean"] = mel_mean + stats["mel_std"] = mel_scale + stats["linear_mean"] = linear_mean + stats["linear_std"] = linear_scale + + print(f" > Avg mel spec mean: {mel_mean.mean()}") + print(f" > Avg mel spec scale: {mel_scale.mean()}") + print(f" > Avg linear spec mean: {linear_mean.mean()}") + print(f" > Avg linear spec scale: {linear_scale.mean()}") + + # set default config values for mean-var scaling + CONFIG.audio.stats_path = output_file_path + CONFIG.audio.signal_norm = True + # remove redundant values + del CONFIG.audio.max_norm + del CONFIG.audio.min_level_db + del CONFIG.audio.symmetric_norm + del CONFIG.audio.clip_norm + stats["audio_config"] = CONFIG.audio.to_dict() + np.save(output_file_path, stats, allow_pickle=True) + print(f" > stats saved to {output_file_path}") + + +if __name__ == "__main__": + main() diff --git a/TTS/bin/eval_encoder.py b/TTS/bin/eval_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..7f9fdf937079d75a673654471871130129c13c0a --- /dev/null +++ b/TTS/bin/eval_encoder.py @@ -0,0 +1,89 @@ +import argparse +from argparse import RawTextHelpFormatter + +import torch +from tqdm import tqdm + +from TTS.config import load_config +from TTS.tts.datasets import load_tts_samples +from TTS.tts.utils.speakers import SpeakerManager + + +def compute_encoder_accuracy(dataset_items, encoder_manager): + + class_name_key = encoder_manager.encoder_config.class_name_key + map_classid_to_classname = getattr(encoder_manager.encoder_config, "map_classid_to_classname", None) + + class_acc_dict = {} + + # compute embeddings for all wav_files + for item in tqdm(dataset_items): + class_name = item[class_name_key] + wav_file = item["audio_file"] + + # extract the embedding + embedd = encoder_manager.compute_embedding_from_clip(wav_file) + if encoder_manager.encoder_criterion is 
not None and map_classid_to_classname is not None: + embedding = torch.FloatTensor(embedd).unsqueeze(0) + if encoder_manager.use_cuda: + embedding = embedding.cuda() + + class_id = encoder_manager.encoder_criterion.softmax.inference(embedding).item() + predicted_label = map_classid_to_classname[str(class_id)] + else: + predicted_label = None + + if class_name is not None and predicted_label is not None: + is_equal = int(class_name == predicted_label) + if class_name not in class_acc_dict: + class_acc_dict[class_name] = [is_equal] + else: + class_acc_dict[class_name].append(is_equal) + else: + raise RuntimeError("Error: class_name or/and predicted_label are None") + + acc_avg = 0 + for key, values in class_acc_dict.items(): + acc = sum(values) / len(values) + print("Class", key, "Accuracy:", acc) + acc_avg += acc + + print("Average Accuracy:", acc_avg / len(class_acc_dict)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="""Compute the accuracy of the encoder.\n\n""" + """ + Example runs: + python TTS/bin/eval_encoder.py emotion_encoder_model.pth emotion_encoder_config.json dataset_config.json + """, + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument("model_path", type=str, help="Path to model checkpoint file.") + parser.add_argument( + "config_path", + type=str, + help="Path to model config file.", + ) + + parser.add_argument( + "config_dataset_path", + type=str, + help="Path to dataset config file.", + ) + parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True) + parser.add_argument("--eval", type=bool, help="compute eval.", default=True) + + args = parser.parse_args() + + c_dataset = load_config(args.config_dataset_path) + + meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=args.eval) + items = meta_data_train + meta_data_eval + + enc_manager = SpeakerManager( + encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda + ) + + compute_encoder_accuracy(items, enc_manager) diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py new file mode 100755 index 0000000000000000000000000000000000000000..a0dd0549ed8e86aeb3a1aeab28bba6f78f4edd84 --- /dev/null +++ b/TTS/bin/extract_tts_spectrograms.py @@ -0,0 +1,287 @@ +#!/usr/bin/env python3 +"""Extract Mel spectrograms with teacher forcing.""" + +import argparse +import os + +import numpy as np +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from TTS.config import load_config +from TTS.tts.datasets import TTSDataset, load_tts_samples +from TTS.tts.models import setup_model +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import count_parameters + +use_cuda = torch.cuda.is_available() + + +def setup_loader(ap, r, verbose=False): + tokenizer, _ = TTSTokenizer.init_from_config(c) + dataset = TTSDataset( + outputs_per_step=r, + compute_linear_spec=False, + samples=meta_data, + tokenizer=tokenizer, + ap=ap, + batch_group_size=0, + min_text_len=c.min_text_len, + max_text_len=c.max_text_len, + min_audio_len=c.min_audio_len, + max_audio_len=c.max_audio_len, + phoneme_cache_path=c.phoneme_cache_path, + precompute_num_workers=0, + use_noise_augment=False, + verbose=verbose, + speaker_id_mapping=speaker_manager.ids if c.use_speaker_embedding else None, + d_vector_mapping=speaker_manager.embeddings if c.use_d_vector_file 
else None, + ) + + if c.use_phonemes and c.compute_input_seq_cache: + # precompute phonemes to have a better estimate of sequence lengths. + dataset.compute_input_seq(c.num_loader_workers) + dataset.preprocess_samples() + + loader = DataLoader( + dataset, + batch_size=c.batch_size, + shuffle=False, + collate_fn=dataset.collate_fn, + drop_last=False, + sampler=None, + num_workers=c.num_loader_workers, + pin_memory=False, + ) + return loader + + +def set_filename(wav_path, out_path): + wav_file = os.path.basename(wav_path) + file_name = wav_file.split(".")[0] + os.makedirs(os.path.join(out_path, "quant"), exist_ok=True) + os.makedirs(os.path.join(out_path, "mel"), exist_ok=True) + os.makedirs(os.path.join(out_path, "wav_gl"), exist_ok=True) + os.makedirs(os.path.join(out_path, "wav"), exist_ok=True) + wavq_path = os.path.join(out_path, "quant", file_name) + mel_path = os.path.join(out_path, "mel", file_name) + wav_gl_path = os.path.join(out_path, "wav_gl", file_name + ".wav") + wav_path = os.path.join(out_path, "wav", file_name + ".wav") + return file_name, wavq_path, mel_path, wav_gl_path, wav_path + + +def format_data(data): + # setup input data + text_input = data["token_id"] + text_lengths = data["token_id_lengths"] + mel_input = data["mel"] + mel_lengths = data["mel_lengths"] + item_idx = data["item_idxs"] + d_vectors = data["d_vectors"] + speaker_ids = data["speaker_ids"] + attn_mask = data["attns"] + avg_text_length = torch.mean(text_lengths.float()) + avg_spec_length = torch.mean(mel_lengths.float()) + + # dispatch data to GPU + if use_cuda: + text_input = text_input.cuda(non_blocking=True) + text_lengths = text_lengths.cuda(non_blocking=True) + mel_input = mel_input.cuda(non_blocking=True) + mel_lengths = mel_lengths.cuda(non_blocking=True) + if speaker_ids is not None: + speaker_ids = speaker_ids.cuda(non_blocking=True) + if d_vectors is not None: + d_vectors = d_vectors.cuda(non_blocking=True) + if attn_mask is not None: + attn_mask = attn_mask.cuda(non_blocking=True) + return ( + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_ids, + d_vectors, + avg_text_length, + avg_spec_length, + attn_mask, + item_idx, + ) + + +@torch.no_grad() +def inference( + model_name, + model, + ap, + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_ids=None, + d_vectors=None, +): + if model_name == "glow_tts": + speaker_c = None + if speaker_ids is not None: + speaker_c = speaker_ids + elif d_vectors is not None: + speaker_c = d_vectors + outputs = model.inference_with_MAS( + text_input, + text_lengths, + mel_input, + mel_lengths, + aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids}, + ) + model_output = outputs["model_outputs"] + model_output = model_output.detach().cpu().numpy() + + elif "tacotron" in model_name: + aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors} + outputs = model(text_input, text_lengths, mel_input, mel_lengths, aux_input) + postnet_outputs = outputs["model_outputs"] + # normalize tacotron output + if model_name == "tacotron": + mel_specs = [] + postnet_outputs = postnet_outputs.data.cpu().numpy() + for b in range(postnet_outputs.shape[0]): + postnet_output = postnet_outputs[b] + mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T)) + model_output = torch.stack(mel_specs).cpu().numpy() + + elif model_name == "tacotron2": + model_output = postnet_outputs.detach().cpu().numpy() + return model_output + + +def extract_spectrograms( + data_loader, model, ap, output_path, quantized_wav=False, 
save_audio=False, debug=False, metada_name="metada.txt" +): + model.eval() + export_metadata = [] + for _, data in tqdm(enumerate(data_loader), total=len(data_loader)): + + # format data + ( + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_ids, + d_vectors, + _, + _, + _, + item_idx, + ) = format_data(data) + + model_output = inference( + c.model.lower(), + model, + ap, + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_ids, + d_vectors, + ) + + for idx in range(text_input.shape[0]): + wav_file_path = item_idx[idx] + wav = ap.load_wav(wav_file_path) + _, wavq_path, mel_path, wav_gl_path, wav_path = set_filename(wav_file_path, output_path) + + # quantize and save wav + if quantized_wav: + wavq = ap.quantize(wav) + np.save(wavq_path, wavq) + + # save TTS mel + mel = model_output[idx] + mel_length = mel_lengths[idx] + mel = mel[:mel_length, :].T + np.save(mel_path, mel) + + export_metadata.append([wav_file_path, mel_path]) + if save_audio: + ap.save_wav(wav, wav_path) + + if debug: + print("Audio for debug saved at:", wav_gl_path) + wav = ap.inv_melspectrogram(mel) + ap.save_wav(wav, wav_gl_path) + + with open(os.path.join(output_path, metada_name), "w", encoding="utf-8") as f: + for data in export_metadata: + f.write(f"{data[0]}|{data[1]+'.npy'}\n") + + +def main(args): # pylint: disable=redefined-outer-name + # pylint: disable=global-variable-undefined + global meta_data, speaker_manager + + # Audio processor + ap = AudioProcessor(**c.audio) + + # load data instances + meta_data_train, meta_data_eval = load_tts_samples( + c.datasets, eval_split=args.eval, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size + ) + + # use eval and training partitions + meta_data = meta_data_train + meta_data_eval + + # init speaker manager + if c.use_speaker_embedding: + speaker_manager = SpeakerManager(data_items=meta_data) + elif c.use_d_vector_file: + speaker_manager = SpeakerManager(d_vectors_file_path=c.d_vector_file) + else: + speaker_manager = None + + # setup model + model = setup_model(c) + + # restore model + model.load_checkpoint(c, args.checkpoint_path, eval=True) + + if use_cuda: + model.cuda() + + num_params = count_parameters(model) + print("\n > Model has {} parameters".format(num_params), flush=True) + # set r + r = 1 if c.model.lower() == "glow_tts" else model.decoder.r + own_loader = setup_loader(ap, r, verbose=True) + + extract_spectrograms( + own_loader, + model, + ap, + args.output_path, + quantized_wav=args.quantized, + save_audio=args.save_audio, + debug=args.debug, + metada_name="metada.txt", + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True) + parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True) + parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True) + parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug") + parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files") + parser.add_argument("--quantized", action="store_true", help="Save quantized audio files") + parser.add_argument("--eval", type=bool, help="compute eval.", default=True) + args = parser.parse_args() + + c = load_config(args.config_path) + c.audio.trim_silence = False + main(args) diff --git a/TTS/bin/find_unique_chars.py b/TTS/bin/find_unique_chars.py new file mode 
100644 index 0000000000000000000000000000000000000000..ea16974839df6cf9942ef24a5535597940fde5b2 --- /dev/null +++ b/TTS/bin/find_unique_chars.py @@ -0,0 +1,45 @@ +"""Find all the unique characters in a dataset""" +import argparse +from argparse import RawTextHelpFormatter + +from TTS.config import load_config +from TTS.tts.datasets import load_tts_samples + + +def main(): + # pylint: disable=bad-option-value + parser = argparse.ArgumentParser( + description="""Find all the unique characters or phonemes in a dataset.\n\n""" + """ + Example runs: + + python TTS/bin/find_unique_chars.py --config_path config.json + """, + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument("--config_path", type=str, help="Path to dataset config file.", required=True) + args = parser.parse_args() + + c = load_config(args.config_path) + + # load all datasets + train_items, eval_items = load_tts_samples( + c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size + ) + + items = train_items + eval_items + + texts = "".join(item["text"] for item in items) + chars = set(texts) + lower_chars = filter(lambda c: c.islower(), chars) + chars_force_lower = [c.lower() for c in chars] + chars_force_lower = set(chars_force_lower) + + print(f" > Number of unique characters: {len(chars)}") + print(f" > Unique characters: {''.join(sorted(chars))}") + print(f" > Unique lower characters: {''.join(sorted(lower_chars))}") + print(f" > Unique all forced to lower characters: {''.join(sorted(chars_force_lower))}") + + +if __name__ == "__main__": + main() diff --git a/TTS/bin/find_unique_phonemes.py b/TTS/bin/find_unique_phonemes.py new file mode 100644 index 0000000000000000000000000000000000000000..0ae74bd4749711d7d96aac1341a2e117cd62bd3b --- /dev/null +++ b/TTS/bin/find_unique_phonemes.py @@ -0,0 +1,70 @@ +"""Find all the unique characters in a dataset""" +import argparse +import multiprocessing +from argparse import RawTextHelpFormatter + +from tqdm.contrib.concurrent import process_map + +from TTS.config import load_config +from TTS.tts.datasets import load_tts_samples +from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut + +phonemizer = Gruut(language="en-us") + + +def compute_phonemes(item): + try: + text = item[0] + ph = phonemizer.phonemize(text).split("|") + except: + return [] + return list(set(ph)) + + +def main(): + # pylint: disable=W0601 + global c + # pylint: disable=bad-option-value + parser = argparse.ArgumentParser( + description="""Find all the unique characters or phonemes in a dataset.\n\n""" + """ + Example runs: + + python TTS/bin/find_unique_chars.py --config_path config.json + """, + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument("--config_path", type=str, help="Path to dataset config file.", required=True) + args = parser.parse_args() + + c = load_config(args.config_path) + + # load all datasets + train_items, eval_items = load_tts_samples( + c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size + ) + items = train_items + eval_items + print("Num items:", len(items)) + + is_lang_def = all(item["language"] for item in items) + + if not c.phoneme_language or not is_lang_def: + raise ValueError("Phoneme language must be defined in config.") + + phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15) + phones = [] + for ph in phonemes: + phones.extend(ph) + phones = set(phones) + lower_phones = filter(lambda c: c.islower(), 
phones) + phones_force_lower = [c.lower() for c in phones] + phones_force_lower = set(phones_force_lower) + + print(f" > Number of unique phonemes: {len(phones)}") + print(f" > Unique phonemes: {''.join(sorted(phones))}") + print(f" > Unique lower phonemes: {''.join(sorted(lower_phones))}") + print(f" > Unique all forced to lower phonemes: {''.join(sorted(phones_force_lower))}") + + +if __name__ == "__main__": + main() diff --git a/TTS/bin/remove_silence_using_vad.py b/TTS/bin/remove_silence_using_vad.py new file mode 100755 index 0000000000000000000000000000000000000000..7d88ae914eda5053aa4a52c40e2ffaa5318a10e5 --- /dev/null +++ b/TTS/bin/remove_silence_using_vad.py @@ -0,0 +1,85 @@ +import argparse +import glob +import os +import pathlib + +from tqdm import tqdm + +from TTS.utils.vad import get_vad_model_and_utils, remove_silence + + +def adjust_path_and_remove_silence(audio_path): + output_path = audio_path.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, "")) + # ignore if the file exists + if os.path.exists(output_path) and not args.force: + return output_path + + # create all directory structure + pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True) + # remove the silence and save the audio + output_path = remove_silence( + model_and_utils, + audio_path, + output_path, + trim_just_beginning_and_end=args.trim_just_beginning_and_end, + use_cuda=args.use_cuda, + ) + + return output_path + + +def preprocess_audios(): + files = sorted(glob.glob(os.path.join(args.input_dir, args.glob), recursive=True)) + print("> Number of files: ", len(files)) + if not args.force: + print("> Ignoring files that already exist in the output directory.") + + if args.trim_just_beginning_and_end: + print("> Trimming just the beginning and the end with nonspeech parts.") + else: + print("> Trimming all nonspeech parts.") + + if files: + # create threads + # num_threads = multiprocessing.cpu_count() + # process_map(adjust_path_and_remove_silence, files, max_workers=num_threads, chunksize=15) + for f in tqdm(files): + adjust_path_and_remove_silence(f) + else: + print("> No files Found !") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="python TTS/bin/remove_silence_using_vad.py -i=VCTK-Corpus/ -o=VCTK-Corpus-removed-silence/ -g=wav48_silence_trimmed/*/*_mic1.flac --trim_just_beginning_and_end True" + ) + parser.add_argument("-i", "--input_dir", type=str, default="../VCTK-Corpus", help="Dataset root dir") + parser.add_argument( + "-o", "--output_dir", type=str, default="../VCTK-Corpus-removed-silence", help="Output Dataset dir" + ) + parser.add_argument("-f", "--force", default=False, action="store_true", help="Force the replace of exists files") + parser.add_argument( + "-g", + "--glob", + type=str, + default="**/*.wav", + help="path in glob format for acess wavs from input_dir. ex: wav48/*/*.wav", + ) + parser.add_argument( + "-t", + "--trim_just_beginning_and_end", + type=bool, + default=True, + help="If True this script will trim just the beginning and end nonspeech parts. If False all nonspeech parts will be trim. 
Default True", + ) + parser.add_argument( + "-c", + "--use_cuda", + type=bool, + default=False, + help="If True use cuda", + ) + args = parser.parse_args() + # load the model and utils + model_and_utils = get_vad_model_and_utils(use_cuda=args.use_cuda) + preprocess_audios() diff --git a/TTS/bin/resample.py b/TTS/bin/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..c9f1166a647e2e761118862c2e8ac82a131428a9 --- /dev/null +++ b/TTS/bin/resample.py @@ -0,0 +1,87 @@ +import argparse +import glob +import os +from argparse import RawTextHelpFormatter +from distutils.dir_util import copy_tree +from multiprocessing import Pool + +import librosa +import soundfile as sf +from tqdm import tqdm + + +def resample_file(func_args): + filename, output_sr = func_args + y, sr = librosa.load(filename, sr=output_sr) + sf.write(filename, y, sr) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description="""Resample a folder recusively with librosa + Can be used in place or create a copy of the folder as an output.\n\n + Example run: + python TTS/bin/resample.py + --input_dir /root/LJSpeech-1.1/ + --output_sr 22050 + --output_dir /root/resampled_LJSpeech-1.1/ + --file_ext wav + --n_jobs 24 + """, + formatter_class=RawTextHelpFormatter, + ) + + parser.add_argument( + "--input_dir", + type=str, + default=None, + required=True, + help="Path of the folder containing the audio files to resample", + ) + + parser.add_argument( + "--output_sr", + type=int, + default=22050, + required=False, + help="Samlple rate to which the audio files should be resampled", + ) + + parser.add_argument( + "--output_dir", + type=str, + default=None, + required=False, + help="Path of the destination folder. If not defined, the operation is done in place", + ) + + parser.add_argument( + "--file_ext", + type=str, + default="wav", + required=False, + help="Extension of the audio files to resample", + ) + + parser.add_argument( + "--n_jobs", type=int, default=None, help="Number of threads to use, by default it uses all cores" + ) + + args = parser.parse_args() + + if args.output_dir: + print("Recursively copying the input folder...") + copy_tree(args.input_dir, args.output_dir) + args.input_dir = args.output_dir + + print("Resampling the audio files...") + audio_files = glob.glob(os.path.join(args.input_dir, f"**/*.{args.file_ext}"), recursive=True) + print(f"Found {len(audio_files)} files...") + audio_files = list(zip(audio_files, len(audio_files) * [args.output_sr])) + with Pool(processes=args.n_jobs) as p: + with tqdm(total=len(audio_files)) as pbar: + for i, _ in enumerate(p.imap_unordered(resample_file, audio_files)): + pbar.update() + + print("Done !") diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py new file mode 100755 index 0000000000000000000000000000000000000000..787d958a18742612507f750f08aa7349efdcc051 --- /dev/null +++ b/TTS/bin/synthesize.py @@ -0,0 +1,425 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import argparse +import sys +import pandas as pd +from argparse import RawTextHelpFormatter + +# pylint: disable=redefined-outer-name, unused-argument +from pathlib import Path + +from TTS.utils.manage import ModelManager +from TTS.utils.synthesizer import Synthesizer +from tqdm.auto import tqdm + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + if v.lower() in ("no", "false", "f", "n", "0"): + return False + raise argparse.ArgumentTypeError("Boolean value expected.") + + +def main(): + 
description = """Synthesize speech on the command line. + +You can either use your trained model or choose a model from the provided list. + +If you don't specify any models, it uses an LJSpeech-based English model. + +## Example Runs + +### Single Speaker Models + +- List provided models: + + ``` + $ tts --list_models + ``` + +- Query model info by idx: + + ``` + $ tts --model_info_by_idx "<model_type>/<model_query_idx>" + ``` + +- Query model info by full name: + + ``` + $ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>" + ``` + +- Run TTS with the default models: + + ``` + $ tts --text "Text for TTS" + ``` + +- Run a TTS model with its default vocoder model: + + ``` + $ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>" + ``` + +- Run with specific TTS and vocoder models from the list: + + ``` + $ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>" --vocoder_name "<language>/<dataset>/<model_name>" --output_path + ``` + +- Run your own TTS model (using the Griffin-Lim vocoder): + + ``` + $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav + ``` + +- Run your own TTS and Vocoder models: + ``` + $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav + --vocoder_path path/to/vocoder.pth --vocoder_config_path path/to/vocoder_config.json + ``` + +### Multi-speaker Models + +- List the available speakers and choose among them: + + ``` + $ tts --model_name "<language>/<dataset>/<model_name>" --list_speaker_idxs + ``` + +- Run the multi-speaker TTS model with the target speaker ID: + + ``` + $ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --speaker_idx <speaker_id> + ``` + +- Run your own multi-speaker TTS model: + + ``` + $ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/model.pth --config_path path/to/config.json --speakers_file_path path/to/speaker.json --speaker_idx <speaker_id> + ``` + """ + # We remove Markdown code formatting programmatically here to allow us to copy-and-paste from the main README and keep + # the documentation in sync more easily. + parser = argparse.ArgumentParser( + description=description.replace(" ```\n", ""), + formatter_class=RawTextHelpFormatter, + ) + + parser.add_argument( + "--list_models", + type=str2bool, + nargs="?", + const=True, + default=False, + help="list available pre-trained TTS and vocoder models.", + ) + + parser.add_argument( + "--model_info_by_idx", + type=str, + default=None, + help="model info using query format: <model_type>/<model_query_idx>", + ) + + parser.add_argument( + "--model_info_by_name", + type=str, + default=None, + help="model info using query format: <model_type>/<language>/<dataset>/<model_name>", + ) + + parser.add_argument("--text", type=str, default=None, help="Text to generate speech.") + + #parser.add_argument("--text_file_path", type=str, default=None, help="A csv file in LJSpeech format ('|' separated id, text and speaker) to generate speech.") + #parser.add_argument("--speaker_name_filter", type=str, default=None, help="Filter texts corresponding to a specific speaker in text_file_path ") + + # Args for running pre-trained TTS models.
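+    # e.g. --model_name "tts_models/en/ljspeech/tacotron2-DDC" (the default below); the full catalog lives in TTS/.models.json.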
+ parser.add_argument( + "--model_name", + type=str, + default="tts_models/en/ljspeech/tacotron2-DDC", + help="Name of one of the pre-trained TTS models in format <language>/<dataset>/<model_name>", + ) + parser.add_argument( + "--vocoder_name", + type=str, + default=None, + help="Name of one of the pre-trained vocoder models in format <language>/<dataset>/<model_name>", + ) + + # Args for running custom models + parser.add_argument("--config_path", default=None, type=str, help="Path to model config file.") + parser.add_argument( + "--model_path", + type=str, + default=None, + help="Path to model file.", + ) + parser.add_argument( + "--out_path", + type=str, + default="tts_output.wav", + help="Output wav file path.", + ) + + # parser.add_argument( + # "--out_folder", + # type=str, + # default="tts_output", + # help="Output wav files folder.", + # ) + + parser.add_argument("--use_cuda", type=bool, help="Run model on CUDA.", default=False) + parser.add_argument( + "--vocoder_path", + type=str, + help="Path to vocoder model file. If it is not defined, the model uses Griffin-Lim as the vocoder. Please make sure that you have installed the vocoder library (WaveRNN) beforehand.", + default=None, + ) + parser.add_argument("--vocoder_config_path", type=str, help="Path to vocoder model config file.", default=None) + parser.add_argument( + "--encoder_path", + type=str, + help="Path to speaker encoder model file.", + default=None, + ) + parser.add_argument("--encoder_config_path", type=str, help="Path to speaker encoder config file.", default=None) + + # args for multi-speaker synthesis + parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None) + parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None) + parser.add_argument( + "--speaker_idx", + type=str, + help="Target speaker ID for a multi-speaker TTS model.", + default=None, + ) + parser.add_argument( + "--language_idx", + type=str, + help="Target language ID for a multi-lingual TTS model.", + default=None, + ) + parser.add_argument( + "--speaker_wav", + nargs="+", + help="wav file(s) to condition a multi-speaker TTS model with a Speaker Encoder. You can give multiple file paths.
The d_vectors is computed as their average.", + default=None, + ) + parser.add_argument("--gst_style", help="Wav path file for GST style reference.", default=None) + parser.add_argument( + "--capacitron_style_wav", type=str, help="Wav path file for Capacitron prosody reference.", default=None + ) + parser.add_argument("--capacitron_style_text", type=str, help="Transcription of the reference.", default=None) + parser.add_argument( + "--list_speaker_idxs", + help="List available speaker ids for the defined multi-speaker model.", + type=str2bool, + nargs="?", + const=True, + default=False, + ) + parser.add_argument( + "--list_language_idxs", + help="List available language ids for the defined multi-lingual model.", + type=str2bool, + nargs="?", + const=True, + default=False, + ) + # aux args + parser.add_argument( + "--save_spectogram", + type=bool, + help="If true save raw spectogram for further (vocoder) processing in out_path.", + default=False, + ) + parser.add_argument( + "--reference_wav", + type=str, + help="Reference wav file to convert in the voice of the speaker_idx or speaker_wav", + default=None, + ) + parser.add_argument( + "--reference_speaker_idx", + type=str, + help="speaker ID of the reference_wav speaker (If not provided the embedding will be computed using the Speaker Encoder).", + default=None, + ) + args = parser.parse_args() + + # print the description if either text or list_models is not set + check_args = [ + args.text, + args.list_models, + args.list_speaker_idxs, + args.list_language_idxs, + args.reference_wav, + args.model_info_by_idx, + args.model_info_by_name, + ] + if not any(check_args): + parser.parse_args(["-h"]) + + # load model manager + path = Path(__file__).parent / "../.models.json" + manager = ModelManager(path) + + model_path = None + config_path = None + speakers_file_path = None + language_ids_file_path = None + vocoder_path = None + vocoder_config_path = None + encoder_path = None + encoder_config_path = None + + # CASE1 #list : list pre-trained TTS models + if args.list_models: + manager.list_models() + sys.exit() + + # CASE2 #info : model info of pre-trained TTS models + if args.model_info_by_idx: + model_query = args.model_info_by_idx + manager.model_info_by_idx(model_query) + sys.exit() + + if args.model_info_by_name: + model_query_full_name = args.model_info_by_name + manager.model_info_by_full_name(model_query_full_name) + sys.exit() + + # CASE3: load pre-trained model paths + if args.model_name is not None and not args.model_path: + model_path, config_path, model_item = manager.download_model(args.model_name) + args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name + + if args.vocoder_name is not None and not args.vocoder_path: + vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name) + + # CASE4: set custom model paths + if args.model_path is not None: + model_path = args.model_path + config_path = args.config_path + speakers_file_path = args.speakers_file_path + language_ids_file_path = args.language_ids_file_path + + if args.vocoder_path is not None: + vocoder_path = args.vocoder_path + vocoder_config_path = args.vocoder_config_path + + if args.encoder_path is not None: + encoder_path = args.encoder_path + encoder_config_path = args.encoder_config_path + + # load models + synthesizer = Synthesizer( + model_path, + config_path, + speakers_file_path, + language_ids_file_path, + vocoder_path, + vocoder_config_path, + encoder_path, + encoder_config_path, + args.use_cuda, + 
) + + # query speaker ids of a multi-speaker model. + if args.list_speaker_idxs: + print( + " > Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model." + ) + print(synthesizer.tts_model.speaker_manager.ids) + return + + # query langauge ids of a multi-lingual model. + if args.list_language_idxs: + print( + " > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model." + ) + print(synthesizer.tts_model.language_manager.ids) + return + + # check the arguments against a multi-speaker model. + if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav): + print( + " [!] Looks like you use a multi-speaker model. Define `--speaker_idx` to " + "select the target speaker. You can list the available speakers for this model by `--list_speaker_idxs`." + ) + return + + # RUN THE SYNTHESIS + if args.text.endswith('.csv'): + df = pd.read_csv(args.text, sep='|') + num_cols = df.shape[1] + columns = ['id', 'text', 'speaker_name', 'gender', 'text_len', 'audio_len', 'speaker_wav'][:num_cols] + df = pd.read_csv(args.text, sep='|', names=columns) + df = df.head(10) + + # print(f'Number of examples before speaker filter: {len(df)}') + # if args.speaker_name_filter: + # df = df[df['speaker_name']==args.speaker_name_filter] + # print(f'Number of examples after speaker filter: {len(df)}') + + if len(df) == 0: + raise ValueError("No records found.") + + if 'speaker_wav' in df.columns: + for idx, row in tqdm(df.iterrows(), total=len(df), desc="Synthesizing"): + wav = synthesizer.tts( + text=row['text'], + speaker_name=None, + language_name=args.language_idx, + speaker_wav=row['speaker_wav'], + reference_wav=args.reference_wav, + style_wav=args.capacitron_style_wav, + style_text=args.capacitron_style_text, + reference_speaker_name=args.reference_speaker_idx, + ) + synthesizer.save_wav(wav, f'{args.out_path}/{row["id"]}.wav') + else: + for idx, row in tqdm(df.iterrows(), total=len(df), desc="Synthesizing"): + wav = synthesizer.tts( + row['text'], + row['speaker_name'] if 'speaker_name' in df.columns else args.speaker_idx, + args.language_idx, + args.speaker_wav, + reference_wav=args.reference_wav, + style_wav=args.capacitron_style_wav, + style_text=args.capacitron_style_text, + reference_speaker_name=args.reference_speaker_idx, + ) + synthesizer.save_wav(wav, f'{args.out_path}/{row["id"]}.wav') + print(" > Saved output wav files in {}".format(args.out_path)) + return True + + if args.text: + print(" > Text: {}".format(args.text)) + + + # kick it + wav = synthesizer.tts( + args.text, + args.speaker_idx, + args.language_idx, + args.speaker_wav, + reference_wav=args.reference_wav, + style_wav=args.capacitron_style_wav, + style_text=args.capacitron_style_text, + reference_speaker_name=args.reference_speaker_idx, + ) + + # save the results + print(" > Saving output to {}".format(args.out_path)) + synthesizer.save_wav(wav, args.out_path) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d28f188e752ad545ef4295fe708f0a5ee52f5bd1 --- /dev/null +++ b/TTS/bin/train_encoder.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import os +import sys +import time +import traceback + +import torch +from torch.utils.data import DataLoader +from trainer.torch import NoamLR +from trainer.trainer_utils import get_optimizer + +from TTS.encoder.dataset 
import EncoderDataset +from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_encoder_model +from TTS.encoder.utils.samplers import PerfectBatchSampler +from TTS.encoder.utils.training import init_training +from TTS.encoder.utils.visual import plot_embeddings +from TTS.tts.datasets import load_tts_samples +from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import count_parameters, remove_experiment_folder +from TTS.utils.io import copy_model_files +from TTS.utils.training import check_update + +torch.backends.cudnn.enabled = True +torch.backends.cudnn.benchmark = True +torch.manual_seed(54321) +use_cuda = torch.cuda.is_available() +num_gpus = torch.cuda.device_count() +print(" > Using CUDA: ", use_cuda) +print(" > Number of GPUs: ", num_gpus) + + +def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False): + num_utter_per_class = c.num_utter_per_class if not is_val else c.eval_num_utter_per_class + num_classes_in_batch = c.num_classes_in_batch if not is_val else c.eval_num_classes_in_batch + + dataset = EncoderDataset( + c, + ap, + meta_data_eval if is_val else meta_data_train, + voice_len=c.voice_len, + num_utter_per_class=num_utter_per_class, + num_classes_in_batch=num_classes_in_batch, + verbose=verbose, + augmentation_config=c.audio_augmentation if not is_val else None, + use_torch_spec=c.model_params.get("use_torch_spec", False), + ) + # get classes list + classes = dataset.get_class_list() + + sampler = PerfectBatchSampler( + dataset.items, + classes, + batch_size=num_classes_in_batch * num_utter_per_class, # total batch size + num_classes_in_batch=num_classes_in_batch, + num_gpus=1, + shuffle=not is_val, + drop_last=True, + ) + + if len(classes) < num_classes_in_batch: + if is_val: + raise RuntimeError( + f"config.eval_num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Eval dataset) !" + ) + raise RuntimeError( + f"config.num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Train dataset) !" + ) + + # set the classes to avoid get wrong class_id when the number of training and eval classes are not equal + if is_val: + dataset.set_classes(train_classes) + + loader = DataLoader( + dataset, + num_workers=c.num_loader_workers, + batch_sampler=sampler, + collate_fn=dataset.collate_fn, + ) + + return loader, classes, dataset.get_map_classid_to_classname() + + +def evaluation(model, criterion, data_loader, global_step): + eval_loss = 0 + for _, data in enumerate(data_loader): + with torch.no_grad(): + # setup input data + inputs, labels = data + + # agroup samples of each class in the batch. 
perfect sampler produces [3,2,1,3,2,1] we need [3,3,2,2,1,1] + labels = torch.transpose( + labels.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch), 0, 1 + ).reshape(labels.shape) + inputs = torch.transpose( + inputs.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch, -1), 0, 1 + ).reshape(inputs.shape) + + # dispatch data to GPU + if use_cuda: + inputs = inputs.cuda(non_blocking=True) + labels = labels.cuda(non_blocking=True) + + # forward pass model + outputs = model(inputs) + + # loss computation + loss = criterion( + outputs.view(c.eval_num_classes_in_batch, outputs.shape[0] // c.eval_num_classes_in_batch, -1), labels + ) + + eval_loss += loss.item() + + eval_avg_loss = eval_loss / len(data_loader) + # save stats + dashboard_logger.eval_stats(global_step, {"loss": eval_avg_loss}) + # plot the last batch in the evaluation + figures = { + "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch), + } + dashboard_logger.eval_figures(global_step, figures) + return eval_avg_loss + + +def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, global_step): + model.train() + best_loss = float("inf") + avg_loader_time = 0 + end_time = time.time() + for epoch in range(c.epochs): + tot_loss = 0 + epoch_time = 0 + for _, data in enumerate(data_loader): + start_time = time.time() + + # setup input data + inputs, labels = data + # agroup samples of each class in the batch. perfect sampler produces [3,2,1,3,2,1] we need [3,3,2,2,1,1] + labels = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape( + labels.shape + ) + inputs = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape( + inputs.shape + ) + # ToDo: move it to a unit test + # labels_converted = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape) + # inputs_converted = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape) + # idx = 0 + # for j in range(0, c.num_classes_in_batch, 1): + # for i in range(j, len(labels), c.num_classes_in_batch): + # if not torch.all(labels[i].eq(labels_converted[idx])) or not torch.all(inputs[i].eq(inputs_converted[idx])): + # print("Invalid") + # print(labels) + # exit() + # idx += 1 + # labels = labels_converted + # inputs = inputs_converted + + loader_time = time.time() - end_time + global_step += 1 + + # setup lr + if c.lr_decay: + scheduler.step() + optimizer.zero_grad() + + # dispatch data to GPU + if use_cuda: + inputs = inputs.cuda(non_blocking=True) + labels = labels.cuda(non_blocking=True) + + # forward pass model + outputs = model(inputs) + + # loss computation + loss = criterion( + outputs.view(c.num_classes_in_batch, outputs.shape[0] // c.num_classes_in_batch, -1), labels + ) + loss.backward() + grad_norm, _ = check_update(model, c.grad_clip) + optimizer.step() + + step_time = time.time() - start_time + epoch_time += step_time + + # acumulate the total epoch loss + tot_loss += loss.item() + + # Averaged Loader Time + num_loader_workers = c.num_loader_workers if c.num_loader_workers > 0 else 1 + avg_loader_time = ( + 1 / num_loader_workers * loader_time + (num_loader_workers - 1) / num_loader_workers * avg_loader_time + if avg_loader_time != 0 + else loader_time + ) + current_lr = optimizer.param_groups[0]["lr"] + + if global_step % c.steps_plot_stats == 0: + # Plot Training Epoch Stats + train_stats = { + "loss": loss.item(), + "lr": current_lr, + 
"grad_norm": grad_norm, + "step_time": step_time, + "avg_loader_time": avg_loader_time, + } + dashboard_logger.train_epoch_stats(global_step, train_stats) + figures = { + "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch), + } + dashboard_logger.train_figures(global_step, figures) + + if global_step % c.print_step == 0: + print( + " | > Step:{} Loss:{:.5f} GradNorm:{:.5f} " + "StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format( + global_step, loss.item(), grad_norm, step_time, loader_time, avg_loader_time, current_lr + ), + flush=True, + ) + + if global_step % c.save_step == 0: + # save model + save_checkpoint(model, optimizer, criterion, loss.item(), OUT_PATH, global_step, epoch) + + end_time = time.time() + + print("") + print( + ">>> Epoch:{} AvgLoss: {:.5f} GradNorm:{:.5f} " + "EpochTime:{:.2f} AvGLoaderTime:{:.2f} ".format( + epoch, tot_loss / len(data_loader), grad_norm, epoch_time, avg_loader_time + ), + flush=True, + ) + # evaluation + if c.run_eval: + model.eval() + eval_loss = evaluation(model, criterion, eval_data_loader, global_step) + print("\n\n") + print("--> EVAL PERFORMANCE") + print( + " | > Epoch:{} AvgLoss: {:.5f} ".format(epoch, eval_loss), + flush=True, + ) + # save the best checkpoint + best_loss = save_best_model(model, optimizer, criterion, eval_loss, best_loss, OUT_PATH, global_step, epoch) + model.train() + + return best_loss, global_step + + +def main(args): # pylint: disable=redefined-outer-name + # pylint: disable=global-variable-undefined + global meta_data_train + global meta_data_eval + global train_classes + + ap = AudioProcessor(**c.audio) + model = setup_encoder_model(c) + + optimizer = get_optimizer(c.optimizer, c.optimizer_params, c.lr, model) + + # pylint: disable=redefined-outer-name + meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True) + + train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False, verbose=True) + if c.run_eval: + eval_data_loader, _, _ = setup_loader(ap, is_val=True, verbose=True) + else: + eval_data_loader = None + + num_classes = len(train_classes) + criterion = model.get_criterion(c, num_classes) + + if c.loss == "softmaxproto" and c.model != "speaker_encoder": + c.map_classid_to_classname = map_classid_to_classname + copy_model_files(c, OUT_PATH) + + if args.restore_path: + criterion, args.restore_step = model.load_checkpoint( + c, args.restore_path, eval=False, use_cuda=use_cuda, criterion=criterion + ) + print(" > Model restored from step %d" % args.restore_step, flush=True) + else: + args.restore_step = 0 + + if c.lr_decay: + scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1) + else: + scheduler = None + + num_params = count_parameters(model) + print("\n > Model has {} parameters".format(num_params), flush=True) + + if use_cuda: + model = model.cuda() + criterion.cuda() + + global_step = args.restore_step + _, global_step = train(model, optimizer, scheduler, criterion, train_data_loader, eval_data_loader, global_step) + + +if __name__ == "__main__": + args, c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training() + + try: + main(args) + except KeyboardInterrupt: + remove_experiment_folder(OUT_PATH) + try: + sys.exit(0) + except SystemExit: + os._exit(0) # pylint: disable=protected-access + except Exception: # pylint: disable=broad-except + remove_experiment_folder(OUT_PATH) + traceback.print_exc() + sys.exit(1) diff --git a/TTS/bin/train_tts.py 
b/TTS/bin/train_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..bdb4f6f69122a4a9aa4e07695f1816ce9727f323 --- /dev/null +++ b/TTS/bin/train_tts.py @@ -0,0 +1,71 @@ +import os +from dataclasses import dataclass, field + +from trainer import Trainer, TrainerArgs + +from TTS.config import load_config, register_config +from TTS.tts.datasets import load_tts_samples +from TTS.tts.models import setup_model + + +@dataclass +class TrainTTSArgs(TrainerArgs): + config_path: str = field(default=None, metadata={"help": "Path to the config file."}) + + +def main(): + """Run `tts` model training directly by a `config.json` file.""" + # init trainer args + train_args = TrainTTSArgs() + parser = train_args.init_argparse(arg_prefix="") + + # override trainer args from comman-line args + args, config_overrides = parser.parse_known_args() + train_args.parse_args(args) + + # load config.json and register + if args.config_path or args.continue_path: + if args.config_path: + # init from a file + config = load_config(args.config_path) + if len(config_overrides) > 0: + config.parse_known_args(config_overrides, relaxed_parser=True) + elif args.continue_path: + # continue from a prev experiment + config = load_config(os.path.join(args.continue_path, "config.json")) + if len(config_overrides) > 0: + config.parse_known_args(config_overrides, relaxed_parser=True) + else: + # init from console args + from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel + + config_base = BaseTrainingConfig() + config_base.parse_known_args(config_overrides) + config = register_config(config_base.model)() + + # load training samples + train_samples, eval_samples = load_tts_samples( + config.datasets, + eval_split=True, + eval_split_max_size=config.eval_split_max_size, + eval_split_size=config.eval_split_size, + ) + + # init the model from config + model = setup_model(config, train_samples + eval_samples) + + # init the trainer and 🚀 + trainer = Trainer( + train_args, + model.config, + config.output_path, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + parse_command_line_args=False, + ) + trainer.fit() + + +if __name__ == "__main__": + main() diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py new file mode 100644 index 0000000000000000000000000000000000000000..32ecd7bdc3652b3683be846bdd9518e937aee904 --- /dev/null +++ b/TTS/bin/train_vocoder.py @@ -0,0 +1,77 @@ +import os +from dataclasses import dataclass, field + +from trainer import Trainer, TrainerArgs + +from TTS.config import load_config, register_config +from TTS.utils.audio import AudioProcessor +from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data +from TTS.vocoder.models import setup_model + + +@dataclass +class TrainVocoderArgs(TrainerArgs): + config_path: str = field(default=None, metadata={"help": "Path to the config file."}) + + +def main(): + """Run `tts` model training directly by a `config.json` file.""" + # init trainer args + train_args = TrainVocoderArgs() + parser = train_args.init_argparse(arg_prefix="") + + # override trainer args from comman-line args + args, config_overrides = parser.parse_known_args() + train_args.parse_args(args) + + # load config.json and register + if args.config_path or args.continue_path: + if args.config_path: + # init from a file + config = load_config(args.config_path) + if len(config_overrides) > 0: + config.parse_known_args(config_overrides, relaxed_parser=True) + elif args.continue_path: + # 
continue from a prev experiment + config = load_config(os.path.join(args.continue_path, "config.json")) + if len(config_overrides) > 0: + config.parse_known_args(config_overrides, relaxed_parser=True) + else: + # init from console args + from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel + + config_base = BaseTrainingConfig() + config_base.parse_known_args(config_overrides) + config = register_config(config_base.model)() + + # load training samples + if "feature_path" in config and config.feature_path: + # load pre-computed features + print(f" > Loading features from: {config.feature_path}") + eval_samples, train_samples = load_wav_feat_data(config.data_path, config.feature_path, config.eval_split_size) + else: + # load data raw wav files + eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size) + + # setup audio processor + ap = AudioProcessor(**config.audio) + + # init the model from config + model = setup_model(config) + + # init the trainer and 🚀 + trainer = Trainer( + train_args, + config, + config.output_path, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + training_assets={"audio_processor": ap}, + parse_command_line_args=False, + ) + trainer.fit() + + +if __name__ == "__main__": + main() diff --git a/TTS/bin/tune_wavegrad.py b/TTS/bin/tune_wavegrad.py new file mode 100644 index 0000000000000000000000000000000000000000..a31d6c4548bb0c769ca4b0bf05cf1d13c3ae39d4 --- /dev/null +++ b/TTS/bin/tune_wavegrad.py @@ -0,0 +1,100 @@ +"""Search a good noise schedule for WaveGrad for a given number of inferece iterations""" +import argparse +from itertools import product as cartesian_product + +import numpy as np +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from TTS.utils.audio import AudioProcessor +from TTS.utils.io import load_config +from TTS.vocoder.datasets.preprocess import load_wav_data +from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset +from TTS.vocoder.utils.generic_utils import setup_generator + +parser = argparse.ArgumentParser() +parser.add_argument("--model_path", type=str, help="Path to model checkpoint.") +parser.add_argument("--config_path", type=str, help="Path to model config file.") +parser.add_argument("--data_path", type=str, help="Path to data directory.") +parser.add_argument("--output_path", type=str, help="path for output file including file name and extension.") +parser.add_argument( + "--num_iter", type=int, help="Number of model inference iterations that you like to optimize noise schedule for." +) +parser.add_argument("--use_cuda", type=bool, help="enable/disable CUDA.") +parser.add_argument("--num_samples", type=int, default=1, help="Number of datasamples used for inference.") +parser.add_argument( + "--search_depth", + type=int, + default=3, + help="Search granularity. 
Increasing this increases the run-time exponentially.", +) + +# load config +args = parser.parse_args() +config = load_config(args.config_path) + +# setup audio processor +ap = AudioProcessor(**config.audio) + +# load dataset +_, train_data = load_wav_data(args.data_path, 0) +train_data = train_data[: args.num_samples] +dataset = WaveGradDataset( + ap=ap, + items=train_data, + seq_len=-1, + hop_len=ap.hop_length, + pad_short=config.pad_short, + conv_pad=config.conv_pad, + is_training=True, + return_segments=False, + use_noise_augment=False, + use_cache=False, + verbose=True, +) +loader = DataLoader( + dataset, + batch_size=1, + shuffle=False, + collate_fn=dataset.collate_full_clips, + drop_last=False, + num_workers=config.num_loader_workers, + pin_memory=False, +) + +# setup the model +model = setup_generator(config) +if args.use_cuda: + model.cuda() + +# setup optimization parameters +base_values = sorted(10 * np.random.uniform(size=args.search_depth)) +print(base_values) +exponents = 10 ** np.linspace(-6, -1, num=args.num_iter) +best_error = float("inf") +best_schedule = None +total_search_iter = len(base_values) ** args.num_iter +for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter): + beta = exponents * base + model.compute_noise_level(beta) + for data in loader: + mel, audio = data + y_hat = model.inference(mel.cuda() if args.use_cuda else mel) + + if args.use_cuda: + y_hat = y_hat.cpu() + y_hat = y_hat.numpy() + + mel_hat = [] + for i in range(y_hat.shape[0]): + m = ap.melspectrogram(y_hat[i, 0])[:, :-1] + mel_hat.append(torch.from_numpy(m)) + + mel_hat = torch.stack(mel_hat) + mse = torch.sum((mel - mel_hat) ** 2).mean() + if mse.item() < best_error: + best_error = mse.item() + best_schedule = {"beta": beta} + print(f" > Found a better schedule. - MSE: {mse.item()}") + np.save(args.output_path, best_schedule) diff --git a/TTS/config/__init__.py b/TTS/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6b0778c5a7ca05e753b54b9a64eb5ec7aa29c0eb --- /dev/null +++ b/TTS/config/__init__.py @@ -0,0 +1,132 @@ +import json +import os +import re +from typing import Dict + +import fsspec +import yaml +from coqpit import Coqpit + +from TTS.config.shared_configs import * +from TTS.utils.generic_utils import find_module + + +def read_json_with_comments(json_path): + """for backward compat.""" + # fallback to json + with fsspec.open(json_path, "r", encoding="utf-8") as f: + input_str = f.read() + # handle comments + input_str = re.sub(r"\\\n", "", input_str) + input_str = re.sub(r"//.*\n", "\n", input_str) + data = json.loads(input_str) + return data + + +def register_config(model_name: str) -> Coqpit: + """Find the right config for the given model name. + + Args: + model_name (str): Model name. + + Raises: + ModuleNotFoundError: No matching config for the model name. + + Returns: + Coqpit: config class. + """ + config_class = None + config_name = model_name + "_config" + paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder.configs"] + for path in paths: + try: + config_class = find_module(path, config_name) + except ModuleNotFoundError: + pass + if config_class is None: + raise ModuleNotFoundError(f" [!] Config for {model_name} cannot be found.") + return config_class + + +def _process_model_name(config_dict: Dict) -> str: + """Format the model name as expected. It is a band-aid for the old `vocoder` model names. + + Args: + config_dict (Dict): A dictionary including the config fields. 
+ + Returns: + str: Formatted modelname. + """ + model_name = config_dict["model"] if "model" in config_dict else config_dict["generator_model"] + model_name = model_name.replace("_generator", "").replace("_discriminator", "") + return model_name + + +def load_config(config_path: str) -> None: + """Import `json` or `yaml` files as TTS configs. First, load the input file as a `dict` and check the model name + to find the corresponding Config class. Then initialize the Config. + + Args: + config_path (str): path to the config file. + + Raises: + TypeError: given config file has an unknown type. + + Returns: + Coqpit: TTS config object. + """ + config_dict = {} + ext = os.path.splitext(config_path)[1] + if ext in (".yml", ".yaml"): + with fsspec.open(config_path, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) + elif ext == ".json": + try: + with fsspec.open(config_path, "r", encoding="utf-8") as f: + data = json.load(f) + except json.decoder.JSONDecodeError: + # backwards compat. + data = read_json_with_comments(config_path) + else: + raise TypeError(f" [!] Unknown config file type {ext}") + config_dict.update(data) + model_name = _process_model_name(config_dict) + config_class = register_config(model_name.lower()) + config = config_class() + config.from_dict(config_dict) + return config + + +def check_config_and_model_args(config, arg_name, value): + """Check the give argument in `config.model_args` if exist or in `config` for + the given value. + + Return False if the argument does not exist in `config.model_args` or `config`. + This is to patch up the compatibility between models with and without `model_args`. + + TODO: Remove this in the future with a unified approach. + """ + if hasattr(config, "model_args"): + if arg_name in config.model_args: + return config.model_args[arg_name] == value + if hasattr(config, arg_name): + return config[arg_name] == value + return False + + +def get_from_config_or_model_args(config, arg_name): + """Get the given argument from `config.model_args` if exist or in `config`.""" + if hasattr(config, "model_args"): + if arg_name in config.model_args: + return config.model_args[arg_name] + return config[arg_name] + + +def get_from_config_or_model_args_with_default(config, arg_name, def_val): + """Get the given argument from `config.model_args` if exist or in `config`.""" + if hasattr(config, "model_args"): + if arg_name in config.model_args: + return config.model_args[arg_name] + if hasattr(config, arg_name): + return config[arg_name] + return def_val diff --git a/TTS/config/shared_configs.py b/TTS/config/shared_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..3ea49796fce9b796703f285b91a16339432a2a1d --- /dev/null +++ b/TTS/config/shared_configs.py @@ -0,0 +1,260 @@ +from dataclasses import asdict, dataclass +from typing import List + +from coqpit import Coqpit, check_argument +from trainer import TrainerConfig + + +@dataclass +class BaseAudioConfig(Coqpit): + """Base config to definge audio processing parameters. It is used to initialize + ```TTS.utils.audio.AudioProcessor.``` + + Args: + fft_size (int): + Number of STFT frequency levels aka.size of the linear spectogram frame. Defaults to 1024. + + win_length (int): + Each frame of audio is windowed by window of length ```win_length``` and then padded with zeros to match + ```fft_size```. Defaults to 1024. + + hop_length (int): + Number of audio samples between adjacent STFT columns. Defaults to 1024. 
+
+        frame_shift_ms (int):
+            Set ```hop_length``` based on milliseconds and sampling rate.
+
+        frame_length_ms (int):
+            Set ```win_length``` based on milliseconds and sampling rate.
+
+        stft_pad_mode (str):
+            Padding method used in STFT. 'reflect' or 'center'. Defaults to 'reflect'.
+
+        sample_rate (int):
+            Audio sampling rate. Defaults to 22050.
+
+        resample (bool):
+            Enable / Disable resampling audio to ```sample_rate```. Defaults to ```False```.
+
+        preemphasis (float):
+            Preemphasis coefficient. Defaults to 0.0.
+
+        ref_level_db (int):
+            Reference dB level to rebase the audio signal and ignore the level below. 20 dB is assumed to be the sound of air.
+            Defaults to 20.
+
+        do_sound_norm (bool):
+            Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False.
+
+        log_func (str):
+            Numpy log function used for amplitude to dB conversion. Defaults to 'np.log10'.
+
+        do_trim_silence (bool):
+            Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```.
+
+        do_amp_to_db_linear (bool, optional):
+            enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.
+
+        do_amp_to_db_mel (bool, optional):
+            enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
+
+        pitch_fmax (float, optional):
+            Maximum frequency of the F0 frames. Defaults to ```640```.
+
+        pitch_fmin (float, optional):
+            Minimum frequency of the F0 frames. Defaults to ```0```.
+
+        trim_db (int):
+            Silence threshold used for silence trimming. Defaults to 45.
+
+        do_rms_norm (bool, optional):
+            enable/disable RMS volume normalization when loading an audio file. Defaults to False.
+
+        db_level (int, optional):
+            dB level used for rms normalization. The range is -99 to 0. Defaults to None.
+
+        power (float):
+            Exponent used for expanding spectrogram levels before running Griffin-Lim. It helps to reduce the
+            artifacts in the synthesized voice. Defaults to 1.5.
+
+        griffin_lim_iters (int):
+            Number of Griffin-Lim iterations. Defaults to 60.
+
+        num_mels (int):
+            Number of mel-basis filters, i.e. the number of frequency bands in each mel-spectrogram frame. Defaults to 80.
+
+        mel_fmin (float): Min frequency level used for the mel-basis filters. ~50 for male and ~95 for female voices.
+            It needs to be adjusted for a dataset. Defaults to 0.
+
+        mel_fmax (float):
+            Max frequency level used for the mel-basis filters. It needs to be adjusted for a dataset.
+
+        spec_gain (int):
+            Gain applied when converting amplitude to dB. Defaults to 20.
+
+        signal_norm (bool):
+            enable/disable signal normalization. Defaults to True.
+
+        min_level_db (int):
+            minimum dB threshold for the computed melspectrograms. Defaults to -100.
+
+        symmetric_norm (bool):
+            enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else
+            [0, k]. Defaults to True.
+
+        max_norm (float):
+            ```k``` defining the normalization range. Defaults to 4.0.
+
+        clip_norm (bool):
+            enable/disable clipping the out-of-range values in the normalized audio signal. Defaults to True.
+
+        stats_path (str):
+            Path to the computed stats file. Defaults to None.
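+
+        Example (illustrative values; any field can be overridden the same way):
+
+            >>> from TTS.config.shared_configs import BaseAudioConfig
+            >>> audio_config = BaseAudioConfig(sample_rate=16000, num_mels=80, do_trim_silence=False)
+            >>> audio_config.hop_length
+            256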
+ """ + + # stft parameters + fft_size: int = 1024 + win_length: int = 1024 + hop_length: int = 256 + frame_shift_ms: int = None + frame_length_ms: int = None + stft_pad_mode: str = "reflect" + # audio processing parameters + sample_rate: int = 22050 + resample: bool = False + preemphasis: float = 0.0 + ref_level_db: int = 20 + do_sound_norm: bool = False + log_func: str = "np.log10" + # silence trimming + do_trim_silence: bool = True + trim_db: int = 45 + # rms volume normalization + do_rms_norm: bool = False + db_level: float = None + # griffin-lim params + power: float = 1.5 + griffin_lim_iters: int = 60 + # mel-spec params + num_mels: int = 80 + mel_fmin: float = 0.0 + mel_fmax: float = None + spec_gain: int = 20 + do_amp_to_db_linear: bool = True + do_amp_to_db_mel: bool = True + # f0 params + pitch_fmax: float = 640.0 + pitch_fmin: float = 0.0 + # normalization params + signal_norm: bool = True + min_level_db: int = -100 + symmetric_norm: bool = True + max_norm: float = 4.0 + clip_norm: bool = True + stats_path: str = None + + def check_values( + self, + ): + """Check config fields""" + c = asdict(self) + check_argument("num_mels", c, restricted=True, min_val=10, max_val=2056) + check_argument("fft_size", c, restricted=True, min_val=128, max_val=4058) + check_argument("sample_rate", c, restricted=True, min_val=512, max_val=100000) + check_argument( + "frame_length_ms", + c, + restricted=True, + min_val=10, + max_val=1000, + alternative="win_length", + ) + check_argument("frame_shift_ms", c, restricted=True, min_val=1, max_val=1000, alternative="hop_length") + check_argument("preemphasis", c, restricted=True, min_val=0, max_val=1) + check_argument("min_level_db", c, restricted=True, min_val=-1000, max_val=10) + check_argument("ref_level_db", c, restricted=True, min_val=0, max_val=1000) + check_argument("power", c, restricted=True, min_val=1, max_val=5) + check_argument("griffin_lim_iters", c, restricted=True, min_val=10, max_val=1000) + + # normalization parameters + check_argument("signal_norm", c, restricted=True) + check_argument("symmetric_norm", c, restricted=True) + check_argument("max_norm", c, restricted=True, min_val=0.1, max_val=1000) + check_argument("clip_norm", c, restricted=True) + check_argument("mel_fmin", c, restricted=True, min_val=0.0, max_val=1000) + check_argument("mel_fmax", c, restricted=True, min_val=500.0, allow_none=True) + check_argument("spec_gain", c, restricted=True, min_val=1, max_val=100) + check_argument("do_trim_silence", c, restricted=True) + check_argument("trim_db", c, restricted=True) + + +@dataclass +class BaseDatasetConfig(Coqpit): + """Base config for TTS datasets. + + Args: + name (str): + Dataset name that defines the preprocessor in use. Defaults to None. + + path (str): + Root path to the dataset files. Defaults to None. + + meta_file_train (str): + Name of the dataset meta file. Or a list of speakers to be ignored at training for multi-speaker datasets. + Defaults to None. + + ignored_speakers (List): + List of speakers IDs that are not used at the training. Default None. + + language (str): + Language code of the dataset. If defined, it overrides `phoneme_language`. Defaults to None. + + meta_file_val (str): + Name of the dataset meta file that defines the instances used at validation. + + meta_file_attn_mask (str): + Path to the file that lists the attention mask files used with models that require attention masks to + train the duration predictor. 
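+
+        Example (illustrative values; `name` must match one of the dataset formatters, e.g. "ljspeech"):
+
+            >>> from TTS.config.shared_configs import BaseDatasetConfig
+            >>> dataset_config = BaseDatasetConfig(
+            ...     name="ljspeech", meta_file_train="metadata.csv", path="/data/LJSpeech-1.1/"
+            ... )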
+ """ + + name: str = "" + path: str = "" + meta_file_train: str = "" + ignored_speakers: List[str] = None + language: str = "" + meta_file_val: str = "" + meta_file_attn_mask: str = "" + + def check_values( + self, + ): + """Check config fields""" + c = asdict(self) + check_argument("name", c, restricted=True) + check_argument("path", c, restricted=True) + check_argument("meta_file_train", c, restricted=True) + check_argument("meta_file_val", c, restricted=False) + check_argument("meta_file_attn_mask", c, restricted=False) + + +@dataclass +class BaseTrainingConfig(TrainerConfig): + """Base config to define the basic 🐸TTS training parameters that are shared + among all the models. It is based on ```Trainer.TrainingConfig```. + + Args: + model (str): + Name of the model that is used in the training. + + num_loader_workers (int): + Number of workers for training time dataloader. + + num_eval_loader_workers (int): + Number of workers for evaluation time dataloader. + """ + + model: str = None + # dataloading + num_loader_workers: int = 0 + num_eval_loader_workers: int = 0 + use_noise_augment: bool = False diff --git a/TTS/encoder/README.md b/TTS/encoder/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b38b20052b707b0358068bc0ce58bc300a149def --- /dev/null +++ b/TTS/encoder/README.md @@ -0,0 +1,18 @@ +### Speaker Encoder + +This is an implementation of https://arxiv.org/abs/1710.10467. This model can be used for voice and speaker embedding. + +With the code here you can generate d-vectors for both multi-speaker and single-speaker TTS datasets, then visualise and explore them along with the associated audio files in an interactive chart. + +Below is an example showing embedding results of various speakers. You can generate the same plot with the provided notebook as demonstrated in [this video](https://youtu.be/KW3oO7JVa7Q). + +![](umap.png) + +Download a pretrained model from [Released Models](https://github.com/mozilla/TTS/wiki/Released-Models) page. + +To run the code, you need to follow the same flow as in TTS. + +- Define 'config.json' for your needs. Note that, audio parameters should match your TTS model. +- Example training call ```python speaker_encoder/train.py --config_path speaker_encoder/config.json --data_path ~/Data/Libri-TTS/train-clean-360``` +- Generate embedding vectors ```python speaker_encoder/compute_embeddings.py --use_cuda true /model/path/best_model.pth model/config/path/config.json dataset/path/ output_path``` . This code parses all .wav files at the given dataset path and generates the same folder structure under the output path with the generated embedding files. 
+- Watch training on Tensorboard as in TTS diff --git a/TTS/encoder/__init__.py b/TTS/encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/encoder/configs/base_encoder_config.py b/TTS/encoder/configs/base_encoder_config.py new file mode 100644 index 0000000000000000000000000000000000000000..ebbaa0457bb55aef70d54dd36fd9b2b7f7c702bb --- /dev/null +++ b/TTS/encoder/configs/base_encoder_config.py @@ -0,0 +1,61 @@ +from dataclasses import asdict, dataclass, field +from typing import Dict, List + +from coqpit import MISSING + +from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig + + +@dataclass +class BaseEncoderConfig(BaseTrainingConfig): + """Defines parameters for a Generic Encoder model.""" + + model: str = None + audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) + datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) + # model params + model_params: Dict = field( + default_factory=lambda: { + "model_name": "lstm", + "input_dim": 80, + "proj_dim": 256, + "lstm_dim": 768, + "num_lstm_layers": 3, + "use_lstm_with_projection": True, + } + ) + + audio_augmentation: Dict = field(default_factory=lambda: {}) + + # training params + epochs: int = 10000 + loss: str = "angleproto" + grad_clip: float = 3.0 + lr: float = 0.0001 + optimizer: str = "radam" + optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.9, 0.999], "weight_decay": 0}) + lr_decay: bool = False + warmup_steps: int = 4000 + + # logging params + tb_model_param_stats: bool = False + steps_plot_stats: int = 10 + save_step: int = 1000 + print_step: int = 20 + run_eval: bool = False + + # data loader + num_classes_in_batch: int = MISSING + num_utter_per_class: int = MISSING + eval_num_classes_in_batch: int = None + eval_num_utter_per_class: int = None + + num_loader_workers: int = MISSING + voice_len: float = 1.6 + + def check_values(self): + super().check_values() + c = asdict(self) + assert ( + c["model_params"]["input_dim"] == self.audio.num_mels + ), " [!] model input dimendion must be equal to melspectrogram dimension." 
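+
+
+# NOTE: the encoder data loaders build batches of `num_classes_in_batch * num_utter_per_class`
+# utterances (e.g. 64 classes x 10 utterances = 640 clips per step) through `PerfectBatchSampler`,
+# and the `eval_*` counterparts are used for the evaluation loader; see `setup_loader` in
+# `TTS/bin/train_encoder.py`.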
diff --git a/TTS/encoder/configs/emotion_encoder_config.py b/TTS/encoder/configs/emotion_encoder_config.py new file mode 100644 index 0000000000000000000000000000000000000000..5eda2671be980abce4a0506a075387b601a1596c --- /dev/null +++ b/TTS/encoder/configs/emotion_encoder_config.py @@ -0,0 +1,12 @@ +from dataclasses import asdict, dataclass + +from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig + + +@dataclass +class EmotionEncoderConfig(BaseEncoderConfig): + """Defines parameters for Emotion Encoder model.""" + + model: str = "emotion_encoder" + map_classid_to_classname: dict = None + class_name_key: str = "emotion_name" diff --git a/TTS/encoder/configs/speaker_encoder_config.py b/TTS/encoder/configs/speaker_encoder_config.py new file mode 100644 index 0000000000000000000000000000000000000000..6dceb00277ba68efe128936ff7f9456338f9753f --- /dev/null +++ b/TTS/encoder/configs/speaker_encoder_config.py @@ -0,0 +1,11 @@ +from dataclasses import asdict, dataclass + +from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig + + +@dataclass +class SpeakerEncoderConfig(BaseEncoderConfig): + """Defines parameters for Speaker Encoder model.""" + + model: str = "speaker_encoder" + class_name_key: str = "speaker_name" diff --git a/TTS/encoder/dataset.py b/TTS/encoder/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..582b1fe9ca35cb9afbc20b8f72b6173282201272 --- /dev/null +++ b/TTS/encoder/dataset.py @@ -0,0 +1,147 @@ +import random + +import torch +from torch.utils.data import Dataset + +from TTS.encoder.utils.generic_utils import AugmentWAV + + +class EncoderDataset(Dataset): + def __init__( + self, + config, + ap, + meta_data, + voice_len=1.6, + num_classes_in_batch=64, + num_utter_per_class=10, + verbose=False, + augmentation_config=None, + use_torch_spec=None, + ): + """ + Args: + ap (TTS.tts.utils.AudioProcessor): audio processor object. + meta_data (list): list of dataset instances. + seq_len (int): voice segment length in seconds. + verbose (bool): print diagnostic information. 
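+            config (Coqpit): encoder config; provides `model` and `class_name_key` used to group utterances.
+            num_classes_in_batch (int): number of distinct classes expected in a batch; the actual batching is
+                done by the sampler in `TTS/bin/train_encoder.py`.
+            num_utter_per_class (int): expected utterances per class; classes with fewer utterances are dropped.
+            augmentation_config (dict): optional additive-noise / RIR / gaussian augmentation settings passed to
+                `AugmentWAV`.
+            use_torch_spec (bool): if True, return raw waveform segments so the model can compute the
+                spectrogram itself; otherwise return mel spectrograms computed with `ap`.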
+ """ + super().__init__() + self.config = config + self.items = meta_data + self.sample_rate = ap.sample_rate + self.seq_len = int(voice_len * self.sample_rate) + self.num_utter_per_class = num_utter_per_class + self.ap = ap + self.verbose = verbose + self.use_torch_spec = use_torch_spec + self.classes, self.items = self.__parse_items() + + self.classname_to_classid = {key: i for i, key in enumerate(self.classes)} + + # Data Augmentation + self.augmentator = None + self.gaussian_augmentation_config = None + if augmentation_config: + self.data_augmentation_p = augmentation_config["p"] + if self.data_augmentation_p and ("additive" in augmentation_config or "rir" in augmentation_config): + self.augmentator = AugmentWAV(ap, augmentation_config) + + if "gaussian" in augmentation_config.keys(): + self.gaussian_augmentation_config = augmentation_config["gaussian"] + + if self.verbose: + print("\n > DataLoader initialization") + print(f" | > Classes per Batch: {num_classes_in_batch}") + print(f" | > Number of instances : {len(self.items)}") + print(f" | > Sequence length: {self.seq_len}") + print(f" | > Num Classes: {len(self.classes)}") + print(f" | > Classes: {self.classes}") + + def load_wav(self, filename): + audio = self.ap.load_wav(filename, sr=self.ap.sample_rate) + return audio + + def __parse_items(self): + class_to_utters = {} + for item in self.items: + path_ = item["audio_file"] + class_name = item[self.config.class_name_key] + if class_name in class_to_utters.keys(): + class_to_utters[class_name].append(path_) + else: + class_to_utters[class_name] = [ + path_, + ] + + # skip classes with number of samples >= self.num_utter_per_class + class_to_utters = {k: v for (k, v) in class_to_utters.items() if len(v) >= self.num_utter_per_class} + + classes = list(class_to_utters.keys()) + classes.sort() + + new_items = [] + for item in self.items: + path_ = item["audio_file"] + class_name = item["emotion_name"] if self.config.model == "emotion_encoder" else item["speaker_name"] + # ignore filtered classes + if class_name not in classes: + continue + # ignore small audios + if self.load_wav(path_).shape[0] - self.seq_len <= 0: + continue + + new_items.append({"wav_file_path": path_, "class_name": class_name}) + + return classes, new_items + + def __len__(self): + return len(self.items) + + def get_num_classes(self): + return len(self.classes) + + def get_class_list(self): + return self.classes + + def set_classes(self, classes): + self.classes = classes + self.classname_to_classid = {key: i for i, key in enumerate(self.classes)} + + def get_map_classid_to_classname(self): + return dict((c_id, c_n) for c_n, c_id in self.classname_to_classid.items()) + + def __getitem__(self, idx): + return self.items[idx] + + def collate_fn(self, batch): + # get the batch class_ids + labels = [] + feats = [] + for item in batch: + utter_path = item["wav_file_path"] + class_name = item["class_name"] + + # get classid + class_id = self.classname_to_classid[class_name] + # load wav file + wav = self.load_wav(utter_path) + offset = random.randint(0, wav.shape[0] - self.seq_len) + wav = wav[offset : offset + self.seq_len] + + if self.augmentator is not None and self.data_augmentation_p: + if random.random() < self.data_augmentation_p: + wav = self.augmentator.apply_one(wav) + + if not self.use_torch_spec: + mel = self.ap.melspectrogram(wav) + feats.append(torch.FloatTensor(mel)) + else: + feats.append(torch.FloatTensor(wav)) + + labels.append(class_id) + + feats = torch.stack(feats) + labels = torch.LongTensor(labels) 
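+        # feats: (batch, num_mels, T) mel-spectrogram frames, or (batch, seq_len) raw waveform segments
+        # when `use_torch_spec` is enabled; labels: (batch,) integer class ids aligned with `feats`.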
+ + return feats, labels diff --git a/TTS/encoder/losses.py b/TTS/encoder/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..5b5aa0fc48fe00aeedeff28ba48ed2af498ce582 --- /dev/null +++ b/TTS/encoder/losses.py @@ -0,0 +1,226 @@ +import torch +import torch.nn.functional as F +from torch import nn + + +# adapted from https://github.com/cvqluu/GE2E-Loss +class GE2ELoss(nn.Module): + def __init__(self, init_w=10.0, init_b=-5.0, loss_method="softmax"): + """ + Implementation of the Generalized End-to-End loss defined in https://arxiv.org/abs/1710.10467 [1] + Accepts an input of size (N, M, D) + where N is the number of speakers in the batch, + M is the number of utterances per speaker, + and D is the dimensionality of the embedding vector (e.g. d-vector) + Args: + - init_w (float): defines the initial value of w in Equation (5) of [1] + - init_b (float): definies the initial value of b in Equation (5) of [1] + """ + super().__init__() + # pylint: disable=E1102 + self.w = nn.Parameter(torch.tensor(init_w)) + # pylint: disable=E1102 + self.b = nn.Parameter(torch.tensor(init_b)) + self.loss_method = loss_method + + print(" > Initialized Generalized End-to-End loss") + + assert self.loss_method in ["softmax", "contrast"] + + if self.loss_method == "softmax": + self.embed_loss = self.embed_loss_softmax + if self.loss_method == "contrast": + self.embed_loss = self.embed_loss_contrast + + # pylint: disable=R0201 + def calc_new_centroids(self, dvecs, centroids, spkr, utt): + """ + Calculates the new centroids excluding the reference utterance + """ + excl = torch.cat((dvecs[spkr, :utt], dvecs[spkr, utt + 1 :])) + excl = torch.mean(excl, 0) + new_centroids = [] + for i, centroid in enumerate(centroids): + if i == spkr: + new_centroids.append(excl) + else: + new_centroids.append(centroid) + return torch.stack(new_centroids) + + def calc_cosine_sim(self, dvecs, centroids): + """ + Make the cosine similarity matrix with dims (N,M,N) + """ + cos_sim_matrix = [] + for spkr_idx, speaker in enumerate(dvecs): + cs_row = [] + for utt_idx, utterance in enumerate(speaker): + new_centroids = self.calc_new_centroids(dvecs, centroids, spkr_idx, utt_idx) + # vector based cosine similarity for speed + cs_row.append( + torch.clamp( + torch.mm( + utterance.unsqueeze(1).transpose(0, 1), + new_centroids.transpose(0, 1), + ) + / (torch.norm(utterance) * torch.norm(new_centroids, dim=1)), + 1e-6, + ) + ) + cs_row = torch.cat(cs_row, dim=0) + cos_sim_matrix.append(cs_row) + return torch.stack(cos_sim_matrix) + + # pylint: disable=R0201 + def embed_loss_softmax(self, dvecs, cos_sim_matrix): + """ + Calculates the loss on each embedding $L(e_{ji})$ by taking softmax + """ + N, M, _ = dvecs.shape + L = [] + for j in range(N): + L_row = [] + for i in range(M): + L_row.append(-F.log_softmax(cos_sim_matrix[j, i], 0)[j]) + L_row = torch.stack(L_row) + L.append(L_row) + return torch.stack(L) + + # pylint: disable=R0201 + def embed_loss_contrast(self, dvecs, cos_sim_matrix): + """ + Calculates the loss on each embedding $L(e_{ji})$ by contrast loss with closest centroid + """ + N, M, _ = dvecs.shape + L = [] + for j in range(N): + L_row = [] + for i in range(M): + centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i]) + excl_centroids_sigmoids = torch.cat((centroids_sigmoids[:j], centroids_sigmoids[j + 1 :])) + L_row.append(1.0 - torch.sigmoid(cos_sim_matrix[j, i, j]) + torch.max(excl_centroids_sigmoids)) + L_row = torch.stack(L_row) + L.append(L_row) + return torch.stack(L) + + def forward(self, x, 
_label=None): + """ + Calculates the GE2E loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) + """ + + assert x.size()[1] >= 2 + + centroids = torch.mean(x, 1) + cos_sim_matrix = self.calc_cosine_sim(x, centroids) + torch.clamp(self.w, 1e-6) + cos_sim_matrix = self.w * cos_sim_matrix + self.b + L = self.embed_loss(x, cos_sim_matrix) + return L.mean() + + +# adapted from https://github.com/clovaai/voxceleb_trainer/blob/master/loss/angleproto.py +class AngleProtoLoss(nn.Module): + """ + Implementation of the Angular Prototypical loss defined in https://arxiv.org/abs/2003.11982 + Accepts an input of size (N, M, D) + where N is the number of speakers in the batch, + M is the number of utterances per speaker, + and D is the dimensionality of the embedding vector + Args: + - init_w (float): defines the initial value of w + - init_b (float): definies the initial value of b + """ + + def __init__(self, init_w=10.0, init_b=-5.0): + super().__init__() + # pylint: disable=E1102 + self.w = nn.Parameter(torch.tensor(init_w)) + # pylint: disable=E1102 + self.b = nn.Parameter(torch.tensor(init_b)) + self.criterion = torch.nn.CrossEntropyLoss() + + print(" > Initialized Angular Prototypical loss") + + def forward(self, x, _label=None): + """ + Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) + """ + + assert x.size()[1] >= 2 + + out_anchor = torch.mean(x[:, 1:, :], 1) + out_positive = x[:, 0, :] + num_speakers = out_anchor.size()[0] + + cos_sim_matrix = F.cosine_similarity( + out_positive.unsqueeze(-1).expand(-1, -1, num_speakers), + out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2), + ) + torch.clamp(self.w, 1e-6) + cos_sim_matrix = cos_sim_matrix * self.w + self.b + label = torch.arange(num_speakers).to(cos_sim_matrix.device) + L = self.criterion(cos_sim_matrix, label) + return L + + +class SoftmaxLoss(nn.Module): + """ + Implementation of the Softmax loss as defined in https://arxiv.org/abs/2003.11982 + Args: + - embedding_dim (float): speaker embedding dim + - n_speakers (float): number of speakers + """ + + def __init__(self, embedding_dim, n_speakers): + super().__init__() + + self.criterion = torch.nn.CrossEntropyLoss() + self.fc = nn.Linear(embedding_dim, n_speakers) + + print("Initialised Softmax Loss") + + def forward(self, x, label=None): + # reshape for compatibility + x = x.reshape(-1, x.size()[-1]) + label = label.reshape(-1) + + x = self.fc(x) + L = self.criterion(x, label) + + return L + + def inference(self, embedding): + x = self.fc(embedding) + activations = torch.nn.functional.softmax(x, dim=1).squeeze(0) + class_id = torch.argmax(activations) + return class_id + + +class SoftmaxAngleProtoLoss(nn.Module): + """ + Implementation of the Softmax AnglePrototypical loss as defined in https://arxiv.org/abs/2009.14153 + Args: + - embedding_dim (float): speaker embedding dim + - n_speakers (float): number of speakers + - init_w (float): defines the initial value of w + - init_b (float): definies the initial value of b + """ + + def __init__(self, embedding_dim, n_speakers, init_w=10.0, init_b=-5.0): + super().__init__() + + self.softmax = SoftmaxLoss(embedding_dim, n_speakers) + self.angleproto = AngleProtoLoss(init_w, init_b) + + print("Initialised SoftmaxAnglePrototypical Loss") + + def forward(self, x, label=None): + """ + Calculates the SoftmaxAnglePrototypical loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) + """ + + Lp = self.angleproto(x) + + Ls = 
self.softmax(x, label) + + return Ls + Lp diff --git a/TTS/encoder/models/base_encoder.py b/TTS/encoder/models/base_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..ac7d7dd5a64cb792525e1dbc8aaaf900eaf63432 --- /dev/null +++ b/TTS/encoder/models/base_encoder.py @@ -0,0 +1,154 @@ +import numpy as np +import torch +import torchaudio +from coqpit import Coqpit +from torch import nn + +from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss +from TTS.utils.generic_utils import set_init_dict +from TTS.utils.io import load_fsspec + + +class PreEmphasis(nn.Module): + def __init__(self, coefficient=0.97): + super().__init__() + self.coefficient = coefficient + self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0)) + + def forward(self, x): + assert len(x.size()) == 2 + + x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect") + return torch.nn.functional.conv1d(x, self.filter).squeeze(1) + + +class BaseEncoder(nn.Module): + """Base `encoder` class. Every new `encoder` model must inherit this. + + It defines common `encoder` specific functions. + """ + + # pylint: disable=W0102 + def __init__(self): + super(BaseEncoder, self).__init__() + + def get_torch_mel_spectrogram_class(self, audio_config): + return torch.nn.Sequential( + PreEmphasis(audio_config["preemphasis"]), + # TorchSTFT( + # n_fft=audio_config["fft_size"], + # hop_length=audio_config["hop_length"], + # win_length=audio_config["win_length"], + # sample_rate=audio_config["sample_rate"], + # window="hamming_window", + # mel_fmin=0.0, + # mel_fmax=None, + # use_htk=True, + # do_amp_to_db=False, + # n_mels=audio_config["num_mels"], + # power=2.0, + # use_mel=True, + # mel_norm=None, + # ) + torchaudio.transforms.MelSpectrogram( + sample_rate=audio_config["sample_rate"], + n_fft=audio_config["fft_size"], + win_length=audio_config["win_length"], + hop_length=audio_config["hop_length"], + window_fn=torch.hamming_window, + n_mels=audio_config["num_mels"], + ), + ) + + @torch.no_grad() + def inference(self, x, l2_norm=True): + return self.forward(x, l2_norm) + + @torch.no_grad() + def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True): + """ + Generate embeddings for a batch of utterances + x: 1xTxD + """ + # map to the waveform size + if self.use_torch_spec: + num_frames = num_frames * self.audio_config["hop_length"] + + max_len = x.shape[1] + + if max_len < num_frames: + num_frames = max_len + + offsets = np.linspace(0, max_len - num_frames, num=num_eval) + + frames_batch = [] + for offset in offsets: + offset = int(offset) + end_offset = int(offset + num_frames) + frames = x[:, offset:end_offset] + frames_batch.append(frames) + + frames_batch = torch.cat(frames_batch, dim=0) + embeddings = self.inference(frames_batch, l2_norm=l2_norm) + + if return_mean: + embeddings = torch.mean(embeddings, dim=0, keepdim=True) + return embeddings + + def get_criterion(self, c: Coqpit, num_classes=None): + if c.loss == "ge2e": + criterion = GE2ELoss(loss_method="softmax") + elif c.loss == "angleproto": + criterion = AngleProtoLoss() + elif c.loss == "softmaxproto": + criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes) + else: + raise Exception("The %s not is a loss supported" % c.loss) + return criterion + + def load_checkpoint( + self, config: Coqpit, checkpoint_path: str, eval: bool = False, use_cuda: bool = False, criterion=None + ): + state = load_fsspec(checkpoint_path, 
map_location=torch.device("cpu")) + try: + self.load_state_dict(state["model"]) + except (KeyError, RuntimeError) as error: + # If eval raise the error + if eval: + raise error + + print(" > Partial model initialization.") + model_dict = self.state_dict() + model_dict = set_init_dict(model_dict, state["model"], c) + self.load_state_dict(model_dict) + del model_dict + + # load the criterion for restore_path + if criterion is not None and "criterion" in state: + try: + criterion.load_state_dict(state["criterion"]) + except (KeyError, RuntimeError) as error: + print(" > Criterion load ignored because of:", error) + + # instance and load the criterion for the encoder classifier in inference time + if ( + eval + and criterion is None + and "criterion" in state + and getattr(config, "map_classid_to_classname", None) is not None + ): + criterion = self.get_criterion(config, len(config.map_classid_to_classname)) + criterion.load_state_dict(state["criterion"]) + + if use_cuda: + self.cuda() + if criterion is not None: + criterion = criterion.cuda() + + if eval: + self.eval() + assert not self.training + + if not eval: + return criterion, state["step"] + return criterion diff --git a/TTS/encoder/models/lstm.py b/TTS/encoder/models/lstm.py new file mode 100644 index 0000000000000000000000000000000000000000..51852b5b820d181824b0db1a205cd5d7bd4fb20d --- /dev/null +++ b/TTS/encoder/models/lstm.py @@ -0,0 +1,99 @@ +import torch +from torch import nn + +from TTS.encoder.models.base_encoder import BaseEncoder + + +class LSTMWithProjection(nn.Module): + def __init__(self, input_size, hidden_size, proj_size): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.proj_size = proj_size + self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True) + self.linear = nn.Linear(hidden_size, proj_size, bias=False) + + def forward(self, x): + self.lstm.flatten_parameters() + o, (_, _) = self.lstm(x) + return self.linear(o) + + +class LSTMWithoutProjection(nn.Module): + def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers): + super().__init__() + self.lstm = nn.LSTM(input_size=input_dim, hidden_size=lstm_dim, num_layers=num_lstm_layers, batch_first=True) + self.linear = nn.Linear(lstm_dim, proj_dim, bias=True) + self.relu = nn.ReLU() + + def forward(self, x): + _, (hidden, _) = self.lstm(x) + return self.relu(self.linear(hidden[-1])) + + +class LSTMSpeakerEncoder(BaseEncoder): + def __init__( + self, + input_dim, + proj_dim=256, + lstm_dim=768, + num_lstm_layers=3, + use_lstm_with_projection=True, + use_torch_spec=False, + audio_config=None, + ): + super().__init__() + self.use_lstm_with_projection = use_lstm_with_projection + self.use_torch_spec = use_torch_spec + self.audio_config = audio_config + self.proj_dim = proj_dim + + layers = [] + # choise LSTM layer + if use_lstm_with_projection: + layers.append(LSTMWithProjection(input_dim, lstm_dim, proj_dim)) + for _ in range(num_lstm_layers - 1): + layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim)) + self.layers = nn.Sequential(*layers) + else: + self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers) + + self.instancenorm = nn.InstanceNorm1d(input_dim) + + if self.use_torch_spec: + self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config) + else: + self.torch_spec = None + + self._init_layers() + + def _init_layers(self): + for name, param in self.layers.named_parameters(): + if "bias" in name: + nn.init.constant_(param, 0.0) + elif "weight" in name: + 
nn.init.xavier_normal_(param) + + def forward(self, x, l2_norm=True): + """Forward pass of the model. + + Args: + x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True` + to compute the spectrogram on-the-fly. + l2_norm (bool): Whether to L2-normalize the outputs. + + Shapes: + - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})` + """ + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + if self.use_torch_spec: + x.squeeze_(1) + x = self.torch_spec(x) + x = self.instancenorm(x).transpose(1, 2) + d = self.layers(x) + if self.use_lstm_with_projection: + d = d[:, -1] + if l2_norm: + d = torch.nn.functional.normalize(d, p=2, dim=1) + return d diff --git a/TTS/encoder/models/resnet.py b/TTS/encoder/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..84e9967f84c32472d757f003728757e43072f77d --- /dev/null +++ b/TTS/encoder/models/resnet.py @@ -0,0 +1,200 @@ +import torch +from torch import nn + +# from TTS.utils.audio import TorchSTFT +from TTS.encoder.models.base_encoder import BaseEncoder + + +class SELayer(nn.Module): + def __init__(self, channel, reduction=8): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), + nn.Sigmoid(), + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +class SEBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8): + super(SEBasicBlock, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.se = SELayer(planes, reduction) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.relu(out) + out = self.bn1(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + return out + + +class ResNetSpeakerEncoder(BaseEncoder): + """Implementation of the model H/ASP without batch normalization in speaker embedding. 
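For the `LSTMSpeakerEncoder` defined above and the `BaseEncoder.compute_embedding()` helper it inherits, a minimal end-to-end sketch on a raw waveform. The `audio_config` values are placeholders; only the key names are the ones read by `get_torch_mel_spectrogram_class()` and `compute_embedding()`:

```python
import torch

from TTS.encoder.models.lstm import LSTMSpeakerEncoder

# Placeholder audio settings (illustrative values, real ones come from the encoder config).
audio_config = {
    "preemphasis": 0.97,
    "sample_rate": 16000,
    "fft_size": 512,
    "win_length": 400,
    "hop_length": 160,
    "num_mels": 64,
}

# input_dim must match num_mels when the model computes its own spectrogram.
model = LSTMSpeakerEncoder(
    input_dim=64,
    proj_dim=256,
    lstm_dim=768,
    num_lstm_layers=3,
    use_torch_spec=True,
    audio_config=audio_config,
)

wav = torch.randn(1, 16000)          # (N, T_samples) raw waveform, 1 s at 16 kHz
emb = model.compute_embedding(wav)   # sliding windows -> mean -> (1, 256) d-vector
print(emb.shape)
```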
This model was proposed in: https://arxiv.org/abs/2009.14153 + Adapted from: https://github.com/clovaai/voxceleb_trainer + """ + + # pylint: disable=W0102 + def __init__( + self, + input_dim=64, + proj_dim=512, + layers=[3, 4, 6, 3], + num_filters=[32, 64, 128, 256], + encoder_type="ASP", + log_input=False, + use_torch_spec=False, + audio_config=None, + ): + super(ResNetSpeakerEncoder, self).__init__() + + self.encoder_type = encoder_type + self.input_dim = input_dim + self.log_input = log_input + self.use_torch_spec = use_torch_spec + self.audio_config = audio_config + self.proj_dim = proj_dim + + self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1) + self.relu = nn.ReLU(inplace=True) + self.bn1 = nn.BatchNorm2d(num_filters[0]) + + self.inplanes = num_filters[0] + self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0]) + self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2)) + self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2)) + self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2)) + + self.instancenorm = nn.InstanceNorm1d(input_dim) + + if self.use_torch_spec: + self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config) + else: + self.torch_spec = None + + outmap_size = int(self.input_dim / 8) + + self.attention = nn.Sequential( + nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1), + nn.Softmax(dim=2), + ) + + if self.encoder_type == "SAP": + out_dim = num_filters[3] * outmap_size + elif self.encoder_type == "ASP": + out_dim = num_filters[3] * outmap_size * 2 + else: + raise ValueError("Undefined encoder") + + self.fc = nn.Linear(out_dim, proj_dim) + + self._init_layers() + + def _init_layers(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def create_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + # pylint: disable=R0201 + def new_parameter(self, *size): + out = nn.Parameter(torch.FloatTensor(*size)) + nn.init.xavier_normal_(out) + return out + + def forward(self, x, l2_norm=False): + """Forward pass of the model. + + Args: + x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True` + to compute the spectrogram on-the-fly. + l2_norm (bool): Whether to L2-normalize the outputs. 
+ + Shapes: + - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})` + """ + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + x.squeeze_(1) + # if you torch spec compute it otherwise use the mel spec computed by the AP + if self.use_torch_spec: + x = self.torch_spec(x) + + if self.log_input: + x = (x + 1e-6).log() + x = self.instancenorm(x).unsqueeze(1) + + x = self.conv1(x) + x = self.relu(x) + x = self.bn1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = x.reshape(x.size()[0], -1, x.size()[-1]) + + w = self.attention(x) + + if self.encoder_type == "SAP": + x = torch.sum(x * w, dim=2) + elif self.encoder_type == "ASP": + mu = torch.sum(x * w, dim=2) + sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5)) + x = torch.cat((mu, sg), 1) + + x = x.view(x.size()[0], -1) + x = self.fc(x) + + if l2_norm: + x = torch.nn.functional.normalize(x, p=2, dim=1) + return x diff --git a/TTS/encoder/requirements.txt b/TTS/encoder/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a486cc45ddb44591bd03c9c0df294fbe98c13884 --- /dev/null +++ b/TTS/encoder/requirements.txt @@ -0,0 +1,2 @@ +umap-learn +numpy>=1.17.0 diff --git a/TTS/encoder/utils/__init__.py b/TTS/encoder/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/encoder/utils/generic_utils.py b/TTS/encoder/utils/generic_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..91a896f60d272dc25cc6cfe62cf91c66b2f28e00 --- /dev/null +++ b/TTS/encoder/utils/generic_utils.py @@ -0,0 +1,184 @@ +import datetime +import glob +import os +import random +import re + +import numpy as np +from scipy import signal + +from TTS.encoder.models.lstm import LSTMSpeakerEncoder +from TTS.encoder.models.resnet import ResNetSpeakerEncoder +from TTS.utils.io import save_fsspec + + +class AugmentWAV(object): + def __init__(self, ap, augmentation_config): + + self.ap = ap + self.use_additive_noise = False + + if "additive" in augmentation_config.keys(): + self.additive_noise_config = augmentation_config["additive"] + additive_path = self.additive_noise_config["sounds_path"] + if additive_path: + self.use_additive_noise = True + # get noise types + self.additive_noise_types = [] + for key in self.additive_noise_config.keys(): + if isinstance(self.additive_noise_config[key], dict): + self.additive_noise_types.append(key) + + additive_files = glob.glob(os.path.join(additive_path, "**/*.wav"), recursive=True) + + self.noise_list = {} + + for wav_file in additive_files: + noise_dir = wav_file.replace(additive_path, "").split(os.sep)[0] + # ignore not listed directories + if noise_dir not in self.additive_noise_types: + continue + if not noise_dir in self.noise_list: + self.noise_list[noise_dir] = [] + self.noise_list[noise_dir].append(wav_file) + + print( + f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}" + ) + + self.use_rir = False + + if "rir" in augmentation_config.keys(): + self.rir_config = augmentation_config["rir"] + if self.rir_config["rir_path"]: + self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True) + self.use_rir = True + + print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audios instances") + + self.create_augmentation_global_list() + + def create_augmentation_global_list(self): + if 
self.use_additive_noise: + self.global_noise_list = self.additive_noise_types + else: + self.global_noise_list = [] + if self.use_rir: + self.global_noise_list.append("RIR_AUG") + + def additive_noise(self, noise_type, audio): + + clean_db = 10 * np.log10(np.mean(audio**2) + 1e-4) + + noise_list = random.sample( + self.noise_list[noise_type], + random.randint( + self.additive_noise_config[noise_type]["min_num_noises"], + self.additive_noise_config[noise_type]["max_num_noises"], + ), + ) + + audio_len = audio.shape[0] + noises_wav = None + for noise in noise_list: + noiseaudio = self.ap.load_wav(noise, sr=self.ap.sample_rate)[:audio_len] + + if noiseaudio.shape[0] < audio_len: + continue + + noise_snr = random.uniform( + self.additive_noise_config[noise_type]["min_snr_in_db"], + self.additive_noise_config[noise_type]["max_num_noises"], + ) + noise_db = 10 * np.log10(np.mean(noiseaudio**2) + 1e-4) + noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio + + if noises_wav is None: + noises_wav = noise_wav + else: + noises_wav += noise_wav + + # if all possible files is less than audio, choose other files + if noises_wav is None: + return self.additive_noise(noise_type, audio) + + return audio + noises_wav + + def reverberate(self, audio): + audio_len = audio.shape[0] + + rir_file = random.choice(self.rir_files) + rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate) + rir = rir / np.sqrt(np.sum(rir**2)) + return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len] + + def apply_one(self, audio): + noise_type = random.choice(self.global_noise_list) + if noise_type == "RIR_AUG": + return self.reverberate(audio) + + return self.additive_noise(noise_type, audio) + + +def to_camel(text): + text = text.capitalize() + return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) + + +def setup_encoder_model(config: "Coqpit"): + if config.model_params["model_name"].lower() == "lstm": + model = LSTMSpeakerEncoder( + config.model_params["input_dim"], + config.model_params["proj_dim"], + config.model_params["lstm_dim"], + config.model_params["num_lstm_layers"], + use_torch_spec=config.model_params.get("use_torch_spec", False), + audio_config=config.audio, + ) + elif config.model_params["model_name"].lower() == "resnet": + model = ResNetSpeakerEncoder( + input_dim=config.model_params["input_dim"], + proj_dim=config.model_params["proj_dim"], + log_input=config.model_params.get("log_input", False), + use_torch_spec=config.model_params.get("use_torch_spec", False), + audio_config=config.audio, + ) + return model + + +def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step, epoch): + checkpoint_path = "checkpoint_{}.pth".format(current_step) + checkpoint_path = os.path.join(out_path, checkpoint_path) + print(" | | > Checkpoint saving : {}".format(checkpoint_path)) + + new_state_dict = model.state_dict() + state = { + "model": new_state_dict, + "optimizer": optimizer.state_dict() if optimizer is not None else None, + "criterion": criterion.state_dict(), + "step": current_step, + "epoch": epoch, + "loss": model_loss, + "date": datetime.date.today().strftime("%B %d, %Y"), + } + save_fsspec(state, checkpoint_path) + + +def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step, epoch): + if model_loss < best_loss: + new_state_dict = model.state_dict() + state = { + "model": new_state_dict, + "optimizer": optimizer.state_dict(), + "criterion": criterion.state_dict(), + "step": current_step, 
+ "epoch": epoch, + "loss": model_loss, + "date": datetime.date.today().strftime("%B %d, %Y"), + } + best_loss = model_loss + bestmodel_path = "best_model.pth" + bestmodel_path = os.path.join(out_path, bestmodel_path) + print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path)) + save_fsspec(state, bestmodel_path) + return best_loss diff --git a/TTS/encoder/utils/io.py b/TTS/encoder/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..d1dad3e24d234cdcb9616fb14bc87919c7e20291 --- /dev/null +++ b/TTS/encoder/utils/io.py @@ -0,0 +1,38 @@ +import datetime +import os + +from TTS.utils.io import save_fsspec + + +def save_checkpoint(model, optimizer, model_loss, out_path, current_step): + checkpoint_path = "checkpoint_{}.pth".format(current_step) + checkpoint_path = os.path.join(out_path, checkpoint_path) + print(" | | > Checkpoint saving : {}".format(checkpoint_path)) + + new_state_dict = model.state_dict() + state = { + "model": new_state_dict, + "optimizer": optimizer.state_dict() if optimizer is not None else None, + "step": current_step, + "loss": model_loss, + "date": datetime.date.today().strftime("%B %d, %Y"), + } + save_fsspec(state, checkpoint_path) + + +def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step): + if model_loss < best_loss: + new_state_dict = model.state_dict() + state = { + "model": new_state_dict, + "optimizer": optimizer.state_dict(), + "step": current_step, + "loss": model_loss, + "date": datetime.date.today().strftime("%B %d, %Y"), + } + best_loss = model_loss + bestmodel_path = "best_model.pth" + bestmodel_path = os.path.join(out_path, bestmodel_path) + print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path)) + save_fsspec(state, bestmodel_path) + return best_loss diff --git a/TTS/encoder/utils/prepare_voxceleb.py b/TTS/encoder/utils/prepare_voxceleb.py new file mode 100644 index 0000000000000000000000000000000000000000..b93baf9e60f0d5c35a4e86f6746e29f6097174b5 --- /dev/null +++ b/TTS/encoder/utils/prepare_voxceleb.py @@ -0,0 +1,219 @@ +# coding=utf-8 +# Copyright (C) 2020 ATHENA AUTHORS; Yiping Peng; Ne Luo +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
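Looping back to the `AugmentWAV` helper in `TTS/encoder/utils/generic_utils.py` above: the expected `augmentation_config` layout is only implicit in the constructor, so here is a hedged sketch of it. Paths and numbers are placeholders, and noise wavs are expected under `sounds_path/<noise_type>/`:

```python
# Sketch of the augmentation_config consumed by AugmentWAV (placeholder values).
augmentation_config = {
    "additive": {
        "sounds_path": "/data/noises/",   # hypothetical path
        "noise": {                        # one sub-dict per noise type
            "min_num_noises": 1,
            "max_num_noises": 2,
            "min_snr_in_db": 0,
            "max_snr_in_db": 15,          # note: the code above currently draws the upper
                                          # SNR bound from "max_num_noises"
        },
    },
    "rir": {
        "rir_path": "/data/rirs/",        # hypothetical path
        "conv_mode": "full",              # passed straight to scipy.signal.convolve
    },
}

# aug = AugmentWAV(ap, augmentation_config)   # `ap` is the project's AudioProcessor
# noisy_wav = aug.apply_one(wav)              # wav: 1-D numpy array loaded by `ap`
```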
+# ============================================================================== +# Only support eager mode and TF>=2.0.0 +# pylint: disable=no-member, invalid-name, relative-beyond-top-level +# pylint: disable=too-many-locals, too-many-statements, too-many-arguments, too-many-instance-attributes +""" voxceleb 1 & 2 """ + +import hashlib +import os +import subprocess +import sys +import zipfile + +import pandas +import soundfile as sf +from absl import logging + +SUBSETS = { + "vox1_dev_wav": [ + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partaa", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partab", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partac", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partad", + ], + "vox1_test_wav": ["https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip"], + "vox2_dev_aac": [ + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaa", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partab", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partac", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partad", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partae", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaf", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partag", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partah", + ], + "vox2_test_aac": ["https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_test_aac.zip"], +} + +MD5SUM = { + "vox1_dev_wav": "ae63e55b951748cc486645f532ba230b", + "vox2_dev_aac": "bbc063c46078a602ca71605645c2a402", + "vox1_test_wav": "185fdc63c3c739954633d50379a3d102", + "vox2_test_aac": "0d2b3ea430a821c33263b5ea37ede312", +} + +USER = {"user": "", "password": ""} + +speaker_id_dict = {} + + +def download_and_extract(directory, subset, urls): + """Download and extract the given split of dataset. + + Args: + directory: the directory where to put the downloaded data. + subset: subset name of the corpus. + urls: the list of urls to download the data file. 
+ """ + os.makedirs(directory, exist_ok=True) + + try: + for url in urls: + zip_filepath = os.path.join(directory, url.split("/")[-1]) + if os.path.exists(zip_filepath): + continue + logging.info("Downloading %s to %s" % (url, zip_filepath)) + subprocess.call( + "wget %s --user %s --password %s -O %s" % (url, USER["user"], USER["password"], zip_filepath), + shell=True, + ) + + statinfo = os.stat(zip_filepath) + logging.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size)) + + # concatenate all parts into zip files + if ".zip" not in zip_filepath: + zip_filepath = "_".join(zip_filepath.split("_")[:-1]) + subprocess.call("cat %s* > %s.zip" % (zip_filepath, zip_filepath), shell=True) + zip_filepath += ".zip" + extract_path = zip_filepath.strip(".zip") + + # check zip file md5sum + with open(zip_filepath, "rb") as f_zip: + md5 = hashlib.md5(f_zip.read()).hexdigest() + if md5 != MD5SUM[subset]: + raise ValueError("md5sum of %s mismatch" % zip_filepath) + + with zipfile.ZipFile(zip_filepath, "r") as zfile: + zfile.extractall(directory) + extract_path_ori = os.path.join(directory, zfile.infolist()[0].filename) + subprocess.call("mv %s %s" % (extract_path_ori, extract_path), shell=True) + finally: + # os.remove(zip_filepath) + pass + + +def exec_cmd(cmd): + """Run a command in a subprocess. + Args: + cmd: command line to be executed. + Return: + int, the return code. + """ + try: + retcode = subprocess.call(cmd, shell=True) + if retcode < 0: + logging.info(f"Child was terminated by signal {retcode}") + except OSError as e: + logging.info(f"Execution failed: {e}") + retcode = -999 + return retcode + + +def decode_aac_with_ffmpeg(aac_file, wav_file): + """Decode a given AAC file into WAV using ffmpeg. + Args: + aac_file: file path to input AAC file. + wav_file: file path to output WAV file. + Return: + bool, True if success. + """ + cmd = f"ffmpeg -i {aac_file} {wav_file}" + logging.info(f"Decoding aac file using command line: {cmd}") + ret = exec_cmd(cmd) + if ret != 0: + logging.error(f"Failed to decode aac file with retcode {ret}") + logging.error("Please check your ffmpeg installation.") + return False + return True + + +def convert_audio_and_make_label(input_dir, subset, output_dir, output_file): + """Optionally convert AAC to WAV and make speaker labels. + Args: + input_dir: the directory which holds the input dataset. + subset: the name of the specified subset. e.g. vox1_dev_wav + output_dir: the directory to place the newly generated csv files. + output_file: the name of the newly generated csv file. e.g. vox1_dev_wav.csv + """ + + logging.info("Preprocessing audio and label for subset %s" % subset) + source_dir = os.path.join(input_dir, subset) + + files = [] + # Convert all AAC file into WAV format. At the same time, generate the csv + for root, _, filenames in os.walk(source_dir): + for filename in filenames: + name, ext = os.path.splitext(filename) + if ext.lower() == ".wav": + _, ext2 = os.path.splitext(name) + if ext2: + continue + wav_file = os.path.join(root, filename) + elif ext.lower() == ".m4a": + # Convert AAC to WAV. 
+ aac_file = os.path.join(root, filename) + wav_file = aac_file + ".wav" + if not os.path.exists(wav_file): + if not decode_aac_with_ffmpeg(aac_file, wav_file): + raise RuntimeError("Audio decoding failed.") + else: + continue + speaker_name = root.split(os.path.sep)[-2] + if speaker_name not in speaker_id_dict: + num = len(speaker_id_dict) + speaker_id_dict[speaker_name] = num + # wav_filesize = os.path.getsize(wav_file) + wav_length = len(sf.read(wav_file)[0]) + files.append((os.path.abspath(wav_file), wav_length, speaker_id_dict[speaker_name], speaker_name)) + + # Write to CSV file which contains four columns: + # "wav_filename", "wav_length_ms", "speaker_id", "speaker_name". + csv_file_path = os.path.join(output_dir, output_file) + df = pandas.DataFrame(data=files, columns=["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"]) + df.to_csv(csv_file_path, index=False, sep="\t") + logging.info("Successfully generated csv file {}".format(csv_file_path)) + + +def processor(directory, subset, force_process): + """download and process""" + urls = SUBSETS + if subset not in urls: + raise ValueError(subset, "is not in voxceleb") + + subset_csv = os.path.join(directory, subset + ".csv") + if not force_process and os.path.exists(subset_csv): + return subset_csv + + logging.info("Downloading and process the voxceleb in %s", directory) + logging.info("Preparing subset %s", subset) + download_and_extract(directory, subset, urls[subset]) + convert_audio_and_make_label(directory, subset, directory, subset + ".csv") + logging.info("Finished downloading and processing") + return subset_csv + + +if __name__ == "__main__": + logging.set_verbosity(logging.INFO) + if len(sys.argv) != 4: + print("Usage: python prepare_data.py save_directory user password") + sys.exit() + + DIR, USER["user"], USER["password"] = sys.argv[1], sys.argv[2], sys.argv[3] + for SUBSET in SUBSETS: + processor(DIR, SUBSET, False) diff --git a/TTS/encoder/utils/samplers.py b/TTS/encoder/utils/samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..08256b347d59368193cee1301b6b1997078d8410 --- /dev/null +++ b/TTS/encoder/utils/samplers.py @@ -0,0 +1,114 @@ +import random + +from torch.utils.data.sampler import Sampler, SubsetRandomSampler + + +class SubsetSampler(Sampler): + """ + Samples elements sequentially from a given list of indices. + + Args: + indices (list): a sequence of indices + """ + + def __init__(self, indices): + super().__init__(indices) + self.indices = indices + + def __iter__(self): + return (self.indices[i] for i in range(len(self.indices))) + + def __len__(self): + return len(self.indices) + + +class PerfectBatchSampler(Sampler): + """ + Samples a mini-batch of indices for a balanced class batching + + Args: + dataset_items(list): dataset items to sample from. + classes (list): list of classes of dataset_items to sample from. + batch_size (int): total number of samples to be sampled in a mini-batch. + num_gpus (int): number of GPU in the data parallel mode. + shuffle (bool): if True, samples randomly, otherwise samples sequentially. + drop_last (bool): if True, drops last incomplete batch. + """ + + def __init__( + self, + dataset_items, + classes, + batch_size, + num_classes_in_batch, + num_gpus=1, + shuffle=True, + drop_last=False, + label_key="class_name", + ): + super().__init__(dataset_items) + assert ( + batch_size % (num_classes_in_batch * num_gpus) == 0 + ), "Batch size must be divisible by number of classes times the number of data parallel devices (if enabled)." 
+ + label_indices = {} + for idx, item in enumerate(dataset_items): + label = item[label_key] + if label not in label_indices.keys(): + label_indices[label] = [idx] + else: + label_indices[label].append(idx) + + if shuffle: + self._samplers = [SubsetRandomSampler(label_indices[key]) for key in classes] + else: + self._samplers = [SubsetSampler(label_indices[key]) for key in classes] + + self._batch_size = batch_size + self._drop_last = drop_last + self._dp_devices = num_gpus + self._num_classes_in_batch = num_classes_in_batch + + def __iter__(self): + + batch = [] + if self._num_classes_in_batch != len(self._samplers): + valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch) + else: + valid_samplers_idx = None + + iters = [iter(s) for s in self._samplers] + done = False + + while True: + b = [] + for i, it in enumerate(iters): + if valid_samplers_idx is not None and i not in valid_samplers_idx: + continue + idx = next(it, None) + if idx is None: + done = True + break + b.append(idx) + if done: + break + batch += b + if len(batch) == self._batch_size: + yield batch + batch = [] + if valid_samplers_idx is not None: + valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch) + + if not self._drop_last: + if len(batch) > 0: + groups = len(batch) // self._num_classes_in_batch + if groups % self._dp_devices == 0: + yield batch + else: + batch = batch[: (groups // self._dp_devices) * self._dp_devices * self._num_classes_in_batch] + if len(batch) > 0: + yield batch + + def __len__(self): + class_batch_size = self._batch_size // self._num_classes_in_batch + return min(((len(s) + class_batch_size - 1) // class_batch_size) for s in self._samplers) diff --git a/TTS/encoder/utils/training.py b/TTS/encoder/utils/training.py new file mode 100644 index 0000000000000000000000000000000000000000..7c58a232e7a146bb24718700527ab80e62a1ab1a --- /dev/null +++ b/TTS/encoder/utils/training.py @@ -0,0 +1,99 @@ +import os +from dataclasses import dataclass, field + +from coqpit import Coqpit +from trainer import TrainerArgs, get_last_checkpoint +from trainer.logging import logger_factory +from trainer.logging.console_logger import ConsoleLogger + +from TTS.config import load_config, register_config +from TTS.tts.utils.text.characters import parse_symbols +from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch +from TTS.utils.io import copy_model_files + + +@dataclass +class TrainArgs(TrainerArgs): + config_path: str = field(default=None, metadata={"help": "Path to the config file."}) + + +def getarguments(): + train_config = TrainArgs() + parser = train_config.init_argparse(arg_prefix="") + return parser + + +def process_args(args, config=None): + """Process parsed comand line arguments and initialize the config if not provided. + Args: + args (argparse.Namespace or dict like): Parsed input arguments. + config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None. + Returns: + c (TTS.utils.io.AttrDict): Config paramaters. + out_path (str): Path to save models and logging. + audio_path (str): Path to save generated test audios. + c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does + logging to the console. + dashboard_logger (WandbLogger or TensorboardLogger): Class that does the dashboard Logging + TODO: + - Interactive config definition. 
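The `PerfectBatchSampler` above yields index lists with a fixed number of items per class. A small self-contained sketch (toy data, illustrative batch sizes) of how it plugs into a `DataLoader` via `batch_sampler`:

```python
from torch.utils.data import DataLoader

from TTS.encoder.utils.samplers import PerfectBatchSampler

# Toy dataset: 4 "speakers" with 6 items each; only the "class_name" key
# (the sampler's default label_key) matters here.
items = [{"class_name": f"spk{c}", "item": i} for c in range(4) for i in range(6)]
classes = sorted({it["class_name"] for it in items})

sampler = PerfectBatchSampler(
    dataset_items=items,
    classes=classes,
    batch_size=8,              # divisible by num_classes_in_batch * num_gpus
    num_classes_in_batch=4,
    shuffle=True,
    drop_last=True,
)
loader = DataLoader(items, batch_sampler=sampler, collate_fn=lambda batch: batch)

for batch in loader:
    # Each batch holds 8 items: 2 per class, 4 classes.
    print([it["class_name"] for it in batch])
```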
+ """ + if isinstance(args, tuple): + args, coqpit_overrides = args + if args.continue_path: + # continue a previous training from its output folder + experiment_path = args.continue_path + args.config_path = os.path.join(args.continue_path, "config.json") + args.restore_path, best_model = get_last_checkpoint(args.continue_path) + if not args.best_path: + args.best_path = best_model + # init config if not already defined + if config is None: + if args.config_path: + # init from a file + config = load_config(args.config_path) + else: + # init from console args + from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel + + config_base = BaseTrainingConfig() + config_base.parse_known_args(coqpit_overrides) + config = register_config(config_base.model)() + # override values from command-line args + config.parse_known_args(coqpit_overrides, relaxed_parser=True) + experiment_path = args.continue_path + if not experiment_path: + experiment_path = get_experiment_folder_path(config.output_path, config.run_name) + audio_path = os.path.join(experiment_path, "test_audios") + config.output_log_path = experiment_path + # setup rank 0 process in distributed training + dashboard_logger = None + if args.rank == 0: + new_fields = {} + if args.restore_path: + new_fields["restore_path"] = args.restore_path + new_fields["github_branch"] = get_git_branch() + # if model characters are not set in the config file + # save the default set to the config file for future + # compatibility. + if config.has("characters") and config.characters is None: + used_characters = parse_symbols() + new_fields["characters"] = used_characters + copy_model_files(config, experiment_path, new_fields) + dashboard_logger = logger_factory(config, experiment_path) + c_logger = ConsoleLogger() + return config, experiment_path, audio_path, c_logger, dashboard_logger + + +def init_arguments(): + train_config = TrainArgs() + parser = train_config.init_argparse(arg_prefix="") + return parser + + +def init_training(config: Coqpit = None): + """Initialization of a training run.""" + parser = init_arguments() + args = parser.parse_known_args() + config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = process_args(args, config) + return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger diff --git a/TTS/encoder/utils/visual.py b/TTS/encoder/utils/visual.py new file mode 100644 index 0000000000000000000000000000000000000000..f2db2f3fa3408f96a04f7932438f175c6ec19c51 --- /dev/null +++ b/TTS/encoder/utils/visual.py @@ -0,0 +1,50 @@ +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import umap + +matplotlib.use("Agg") + + +colormap = ( + np.array( + [ + [76, 255, 0], + [0, 127, 70], + [255, 0, 0], + [255, 217, 38], + [0, 135, 255], + [165, 0, 165], + [255, 167, 255], + [0, 255, 255], + [255, 96, 38], + [142, 76, 0], + [33, 0, 127], + [0, 0, 0], + [183, 183, 183], + ], + dtype=np.float, + ) + / 255 +) + + +def plot_embeddings(embeddings, num_classes_in_batch): + num_utter_per_class = embeddings.shape[0] // num_classes_in_batch + + # if necessary get just the first 10 classes + if num_classes_in_batch > 10: + num_classes_in_batch = 10 + embeddings = embeddings[: num_classes_in_batch * num_utter_per_class] + + model = umap.UMAP() + projection = model.fit_transform(embeddings) + ground_truth = np.repeat(np.arange(num_classes_in_batch), num_utter_per_class) + colors = [colormap[i] for i in ground_truth] + fig, ax = plt.subplots(figsize=(16, 10)) + _ = ax.scatter(projection[:, 
0], projection[:, 1], c=colors) + plt.gca().set_aspect("equal", "datalim") + plt.title("UMAP projection") + plt.tight_layout() + plt.savefig("umap") + return fig diff --git a/TTS/model.py b/TTS/model.py new file mode 100644 index 0000000000000000000000000000000000000000..a53b916a3f3844925ebf57ba721c3be0303985d0 --- /dev/null +++ b/TTS/model.py @@ -0,0 +1,56 @@ +from abc import abstractmethod +from typing import Dict + +import torch +from coqpit import Coqpit +from trainer import TrainerModel + +# pylint: skip-file + + +class BaseTrainerModel(TrainerModel): + """BaseTrainerModel model expanding TrainerModel with required functions by 🐸TTS. + + Every new 🐸TTS model must inherit it. + """ + + @staticmethod + @abstractmethod + def init_from_config(config: Coqpit): + """Init the model and all its attributes from the given config. + + Override this depending on your model. + """ + ... + + @abstractmethod + def inference(self, input: torch.Tensor, aux_input={}) -> Dict: + """Forward pass for inference. + + It must return a dictionary with the main model output and all the auxiliary outputs. The key ```model_outputs``` + is considered to be the main output and you can add any other auxiliary outputs as you want. + + We don't use `*kwargs` since it is problematic with the TorchScript API. + + Args: + input (torch.Tensor): [description] + aux_input (Dict): Auxiliary inputs like speaker embeddings, durations etc. + + Returns: + Dict: [description] + """ + outputs_dict = {"model_outputs": None} + ... + return outputs_dict + + @abstractmethod + def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True) -> None: + """Load a model checkpoint gile and get ready for training or inference. + + Args: + config (Coqpit): Model configuration. + checkpoint_path (str): Path to the model checkpoint file. + eval (bool, optional): If true, init model for inference else for training. Defaults to False. + strcit (bool, optional): Match all checkpoint keys to model's keys. Defaults to True. + """ + ... diff --git a/TTS/server/README.md b/TTS/server/README.md new file mode 100644 index 0000000000000000000000000000000000000000..270656c4e39dc11636efbb1ba51eba7c9b4a8f04 --- /dev/null +++ b/TTS/server/README.md @@ -0,0 +1,18 @@ +# :frog: TTS demo server +Before you use the server, make sure you [install](https://github.com/coqui-ai/TTS/tree/dev#install-tts)) :frog: TTS properly. Then, you can follow the steps below. + +**Note:** If you install :frog:TTS using ```pip```, you can also use the ```tts-server``` end point on the terminal. + +Examples runs: + +List officially released models. +```python TTS/server/server.py --list_models ``` + +Run the server with the official models. +```python TTS/server/server.py --model_name tts_models/en/ljspeech/tacotron2-DCA --vocoder_name vocoder_models/en/ljspeech/multiband-melgan``` + +Run the server with the official models on a GPU. +```CUDA_VISIBLE_DEVICES="0" python TTS/server/server.py --model_name tts_models/en/ljspeech/tacotron2-DCA --vocoder_name vocoder_models/en/ljspeech/multiband-melgan --use_cuda True``` + +Run the server with a custom models. 
+```python TTS/server/server.py --tts_checkpoint /path/to/tts/model.pth --tts_config /path/to/tts/config.json --vocoder_checkpoint /path/to/vocoder/model.pth --vocoder_config /path/to/vocoder/config.json``` diff --git a/TTS/server/__init__.py b/TTS/server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/server/conf.json b/TTS/server/conf.json new file mode 100644 index 0000000000000000000000000000000000000000..49b6c09c3848a224dfb39a1f653aa1b289a4b6e5 --- /dev/null +++ b/TTS/server/conf.json @@ -0,0 +1,12 @@ +{ + "tts_path":"/media/erogol/data_ssd/Models/libri_tts/5049/", // tts model root folder + "tts_file":"best_model.pth", // tts checkpoint file + "tts_config":"config.json", // tts config.json file + "tts_speakers": null, // json file listing speaker ids. null if no speaker embedding. + "vocoder_config":null, + "vocoder_file": null, + "is_wavernn_batched":true, + "port": 5002, + "use_cuda": true, + "debug": true +} diff --git a/TTS/server/server.py b/TTS/server/server.py new file mode 100644 index 0000000000000000000000000000000000000000..89fce493db93588c8ae69fec35bf5ce6c1a0158b --- /dev/null +++ b/TTS/server/server.py @@ -0,0 +1,190 @@ +#!flask/bin/python +import argparse +import io +import json +import os +import sys +from pathlib import Path +from typing import Union + +from flask import Flask, render_template, request, send_file + +from TTS.config import load_config +from TTS.utils.manage import ModelManager +from TTS.utils.synthesizer import Synthesizer + + +def create_argparser(): + def convert_boolean(x): + return x.lower() in ["true", "1", "yes"] + + parser = argparse.ArgumentParser() + parser.add_argument( + "--list_models", + type=convert_boolean, + nargs="?", + const=True, + default=False, + help="list available pre-trained tts and vocoder models.", + ) + parser.add_argument( + "--model_name", + type=str, + default="tts_models/en/ljspeech/tacotron2-DDC", + help="Name of one of the pre-trained tts models in format //", + ) + parser.add_argument("--vocoder_name", type=str, default=None, help="name of one of the released vocoder models.") + + # Args for running custom models + parser.add_argument("--config_path", default=None, type=str, help="Path to model config file.") + parser.add_argument( + "--model_path", + type=str, + default=None, + help="Path to model file.", + ) + parser.add_argument( + "--vocoder_path", + type=str, + help="Path to vocoder model file. If it is not defined, model uses GL as vocoder. 
Please make sure that you installed vocoder library before (WaveRNN).", + default=None, + ) + parser.add_argument("--vocoder_config_path", type=str, help="Path to vocoder model config file.", default=None) + parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None) + parser.add_argument("--port", type=int, default=5002, help="port to listen on.") + parser.add_argument("--use_cuda", type=convert_boolean, default=False, help="true to use CUDA.") + parser.add_argument("--debug", type=convert_boolean, default=False, help="true to enable Flask debug mode.") + parser.add_argument("--show_details", type=convert_boolean, default=False, help="Generate model detail page.") + return parser + + +# parse the args +args = create_argparser().parse_args() + +path = Path(__file__).parent / "../.models.json" +manager = ModelManager(path) + +if args.list_models: + manager.list_models() + sys.exit() + +# update in-use models to the specified released models. +model_path = None +config_path = None +speakers_file_path = None +vocoder_path = None +vocoder_config_path = None + +# CASE1: list pre-trained TTS models +if args.list_models: + manager.list_models() + sys.exit() + +# CASE2: load pre-trained model paths +if args.model_name is not None and not args.model_path: + model_path, config_path, model_item = manager.download_model(args.model_name) + args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name + +if args.vocoder_name is not None and not args.vocoder_path: + vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name) + +# CASE3: set custom model paths +if args.model_path is not None: + model_path = args.model_path + config_path = args.config_path + speakers_file_path = args.speakers_file_path + +if args.vocoder_path is not None: + vocoder_path = args.vocoder_path + vocoder_config_path = args.vocoder_config_path + +# load models +synthesizer = Synthesizer( + tts_checkpoint=model_path, + tts_config_path=config_path, + tts_speakers_file=speakers_file_path, + tts_languages_file=None, + vocoder_checkpoint=vocoder_path, + vocoder_config=vocoder_config_path, + encoder_checkpoint="", + encoder_config="", + use_cuda=args.use_cuda, +) + +use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and ( + synthesizer.tts_model.num_speakers > 1 or synthesizer.tts_speakers_file is not None +) + +speaker_manager = getattr(synthesizer.tts_model, "speaker_manager", None) +# TODO: set this from SpeakerManager +use_gst = synthesizer.tts_config.get("use_gst", False) +app = Flask(__name__) + + +def style_wav_uri_to_dict(style_wav: str) -> Union[str, dict]: + """Transform an uri style_wav, in either a string (path to wav file to be use for style transfer) + or a dict (gst tokens/values to be use for styling) + + Args: + style_wav (str): uri + + Returns: + Union[str, dict]: path to file (str) or gst style (dict) + """ + if style_wav: + if os.path.isfile(style_wav) and style_wav.endswith(".wav"): + return style_wav # style_wav is a .wav file located on the server + + style_wav = json.loads(style_wav) + return style_wav # style_wav is a gst dictionary with {token1_id : token1_weigth, ...} + return None + + +@app.route("/") +def index(): + return render_template( + "index.html", + show_details=args.show_details, + use_multi_speaker=use_multi_speaker, + speaker_ids=speaker_manager.ids if speaker_manager is not None else None, + use_gst=use_gst, + ) + + +@app.route("/details") +def details(): + 
model_config = load_config(args.tts_config) + if args.vocoder_config is not None and os.path.isfile(args.vocoder_config): + vocoder_config = load_config(args.vocoder_config) + else: + vocoder_config = None + + return render_template( + "details.html", + show_details=args.show_details, + model_config=model_config, + vocoder_config=vocoder_config, + args=args.__dict__, + ) + + +@app.route("/api/tts", methods=["GET"]) +def tts(): + text = request.args.get("text") + speaker_idx = request.args.get("speaker_id", "") + style_wav = request.args.get("style_wav", "") + style_wav = style_wav_uri_to_dict(style_wav) + print(" > Model input: {}".format(text)) + print(" > Speaker Idx: {}".format(speaker_idx)) + wavs = synthesizer.tts(text, speaker_name=speaker_idx, style_wav=style_wav) + out = io.BytesIO() + synthesizer.save_wav(wavs, out) + return send_file(out, mimetype="audio/wav") + + +def main(): + app.run(debug=args.debug, host="::", port=args.port) + + +if __name__ == "__main__": + main() diff --git a/TTS/server/static/coqui-log-green-TTS.png b/TTS/server/static/coqui-log-green-TTS.png new file mode 100644 index 0000000000000000000000000000000000000000..6ad188b8c03a170097c0393c6769996f03cf9054 Binary files /dev/null and b/TTS/server/static/coqui-log-green-TTS.png differ diff --git a/TTS/server/templates/details.html b/TTS/server/templates/details.html new file mode 100644 index 0000000000000000000000000000000000000000..51c9ed85a83ac0aab045623ee1e6c430fbe51b9d --- /dev/null +++ b/TTS/server/templates/details.html @@ -0,0 +1,131 @@ + + + + + + + + + + + TTS engine + + + + + + + + + + Fork me on GitHub + + {% if show_details == true %} + +
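For the `/api/tts` route defined in `TTS/server/server.py` above, a hypothetical client call. The endpoint, default port, and query parameter names are the ones in the route handler; the host and output filename are made up:

```python
import requests

# Hypothetical client for a locally running server (default --port 5002).
response = requests.get(
    "http://localhost:5002/api/tts",
    params={"text": "Hello there.", "speaker_id": "", "style_wav": ""},
)
response.raise_for_status()
with open("tts_output.wav", "wb") as f:
    f.write(response.content)   # the route returns a WAV payload via send_file
```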
+ Model details +
+ +
+
+ CLI arguments: + + + + + + + {% for key, value in args.items() %} + + + + + + + {% endfor %} +
CLI key Value
{{ key }}{{ value }}
+
+

+ +
+ + {% if model_config != None %} + +
+ Model config: + + + + + + + + + {% for key, value in model_config.items() %} + + + + + + + {% endfor %} + +
Key Value
{{ key }}{{ value }}
+
+ + {% endif %} + +

+ + + +
+ {% if vocoder_config != None %} +
+ Vocoder model config: + + + + + + + + + {% for key, value in vocoder_config.items() %} + + + + + + + {% endfor %} + + +
Key Value
{{ key }}{{ value }}
+
+ {% endif %} +

+ + {% else %} +
+ Please start server with --show_details=true to see details. +
+ Please start the server with --show_details=true to see details. +
+ + {% endif %} + + + + \ No newline at end of file diff --git a/TTS/server/templates/index.html b/TTS/server/templates/index.html new file mode 100644 index 0000000000000000000000000000000000000000..b0eab291a2c78e678709aba7dddb2b97b8e94b0f --- /dev/null +++ b/TTS/server/templates/index.html @@ -0,0 +1,143 @@ + + + + + + + + + + + TTS engine + + + + + + + + + + Fork me on GitHub + + + + + +
+
+
+ + +
    +
+ + {%if use_gst%} + + {%endif%} + + +

+ + {%if use_multi_speaker%} + Choose a speaker: +

+ {%endif%} + + {%if show_details%} +

+ {%endif%} + +

+
+
+
+ + + + + + + \ No newline at end of file diff --git a/TTS/tts/__init__.py b/TTS/tts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/configs/__init__.py b/TTS/tts/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3146ac1c116cb807a81889b7a9ab223b9a051036 --- /dev/null +++ b/TTS/tts/configs/__init__.py @@ -0,0 +1,17 @@ +import importlib +import os +from inspect import isclass + +# import all files under configs/ +# configs_dir = os.path.dirname(__file__) +# for file in os.listdir(configs_dir): +# path = os.path.join(configs_dir, file) +# if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)): +# config_name = file[: file.find(".py")] if file.endswith(".py") else file +# module = importlib.import_module("TTS.tts.configs." + config_name) +# for attribute_name in dir(module): +# attribute = getattr(module, attribute_name) + +# if isclass(attribute): +# # Add the class to this package's variables +# globals()[attribute_name] = attribute diff --git a/TTS/tts/configs/align_tts_config.py b/TTS/tts/configs/align_tts_config.py new file mode 100644 index 0000000000000000000000000000000000000000..317a01af53ce26914d83610a913eb44b5836dac2 --- /dev/null +++ b/TTS/tts/configs/align_tts_config.py @@ -0,0 +1,107 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.align_tts import AlignTTSArgs + + +@dataclass +class AlignTTSConfig(BaseTTSConfig): + """Defines parameters for AlignTTS model. + Example: + + >>> from TTS.tts.configs.align_tts_config import AlignTTSConfig + >>> config = AlignTTSConfig() + + Args: + model(str): + Model name used for selecting the right model at initialization. Defaults to `align_tts`. + positional_encoding (bool): + enable / disable positional encoding applied to the encoder output. Defaults to True. + hidden_channels (int): + Base number of hidden channels. Defines all the layers expect ones defined by the specific encoder or decoder + parameters. Defaults to 256. + hidden_channels_dp (int): + Number of hidden channels of the duration predictor's layers. Defaults to 256. + encoder_type (str): + Type of the encoder used by the model. Look at `TTS.tts.layers.feed_forward.encoder` for more details. + Defaults to `fftransformer`. + encoder_params (dict): + Parameters used to define the encoder network. Look at `TTS.tts.layers.feed_forward.encoder` for more details. + Defaults to `{"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}`. + decoder_type (str): + Type of the decoder used by the model. Look at `TTS.tts.layers.feed_forward.decoder` for more details. + Defaults to `fftransformer`. + decoder_params (dict): + Parameters used to define the decoder network. Look at `TTS.tts.layers.feed_forward.decoder` for more details. + Defaults to `{"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}`. + phase_start_steps (List[int]): + A list of number of steps required to start the next training phase. AlignTTS has 4 different training + phases. Thus you need to define 4 different values to enable phase based training. If None, it + trains the whole model together. Defaults to None. + ssim_alpha (float): + Weight for the SSIM loss. If set <= 0, disables the SSIM loss. Defaults to 1.0. + duration_loss_alpha (float): + Weight for the duration predictor's loss. 
Defaults to 1.0. + mdn_alpha (float): + Weight for the MDN loss. Defaults to 1.0. + spec_loss_alpha (float): + Weight for the MSE spectrogram loss. If set <= 0, disables the L1 loss. Defaults to 1.0. + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + noam_schedule (bool): + enable / disable the use of Noam LR scheduler. Defaults to False. + warmup_steps (int): + Number of warm-up steps for the Noam scheduler. Defaults 4000. + lr (float): + Initial learning rate. Defaults to `1e-3`. + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + min_seq_len (int): + Minimum input sequence length to be used at training. + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage.""" + + model: str = "align_tts" + # model specific params + model_args: AlignTTSArgs = field(default_factory=AlignTTSArgs) + phase_start_steps: List[int] = None + + ssim_alpha: float = 1.0 + spec_loss_alpha: float = 1.0 + dur_loss_alpha: float = 1.0 + mdn_alpha: float = 1.0 + + # multi-speaker settings + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_file: str = False + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = None + lr_scheduler_params: dict = None + lr: float = 1e-4 + grad_clip: float = 5.0 + + # overrides + min_seq_len: int = 13 + max_seq_len: int = 200 + r: int = 1 + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) diff --git a/TTS/tts/configs/fast_pitch_config.py b/TTS/tts/configs/fast_pitch_config.py new file mode 100644 index 0000000000000000000000000000000000000000..26ccfdd54037a63c4d5d638109cd30524f8f22ca --- /dev/null +++ b/TTS/tts/configs/fast_pitch_config.py @@ -0,0 +1,182 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.forward_tts import ForwardTTSArgs + + +@dataclass +class FastPitchConfig(BaseTTSConfig): + """Configure `ForwardTTS` as FastPitch model. + + Example: + + >>> from TTS.tts.configs.fast_pitch_config import FastPitchConfig + >>> config = FastPitchConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `fast_pitch`. + + base_model (str): + Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate + the base model rather than searching for the `model` implementation. Defaults to `forward_tts`. + + model_args (Coqpit): + Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`. + + data_dep_init_steps (int): + Number of steps used for computing normalization parameters at the beginning of the training. 
GlowTTS uses + Activation Normalization that pre-computes normalization stats at the beginning and use the same values + for the rest. Defaults to 10. + + speakers_file (str): + Path to the file containing the list of speakers. Needed at inference for loading matching speaker ids to + speaker names. Defaults to `None`. + + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + + d_vector_dim (int): + Dimension of the external speaker embeddings. Defaults to 0. + + optimizer (str): + Name of the model optimizer. Defaults to `Adam`. + + optimizer_params (dict): + Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`. + + lr_scheduler (str): + Name of the learning rate scheduler. Defaults to `Noam`. + + lr_scheduler_params (dict): + Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`. + + lr (float): + Initial learning rate. Defaults to `1e-3`. + + grad_clip (float): + Gradient norm clipping value. Defaults to `5.0`. + + spec_loss_type (str): + Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + duration_loss_type (str): + Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + use_ssim_loss (bool): + Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True. + + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + + ssim_loss_alpha (float): + Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0. + + dur_loss_alpha (float): + Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0. + + spec_loss_alpha (float): + Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0. + + pitch_loss_alpha (float): + Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0. + + binary_align_loss_alpha (float): + Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. + + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. + + min_seq_len (int): + Minimum input sequence length to be used at training. + + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. 
+ """ + + model: str = "fast_pitch" + base_model: str = "forward_tts" + + # model specific params + model_args: ForwardTTSArgs = ForwardTTSArgs() + + # data loader params + return_wav: bool = False + compute_linear_spec: bool = False + + # multi-speaker settings + num_speakers: int = 0 + speakers_file: str = None + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_file: str = False + d_vector_dim: int = 0 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + lr: float = 1e-4 + grad_clip: float = 5.0 + + # loss params + spec_loss_type: str = "mse" + duration_loss_type: str = "mse" + use_ssim_loss: bool = True + ssim_loss_alpha: float = 1.0 + spec_loss_alpha: float = 1.0 + aligner_loss_alpha: float = 1.0 + pitch_loss_alpha: float = 0.1 + dur_loss_alpha: float = 0.1 + binary_align_loss_alpha: float = 0.1 + spk_encoder_loss_alpha: float = 0.1 + binary_loss_warmup_epochs: int = 150 + aligner_epochs: int = 1000 + + # overrides + min_seq_len: int = 13 + max_seq_len: int = 200 + r: int = 1 # DO NOT CHANGE + + # dataset configs + compute_f0: bool = True + f0_cache_path: str = None + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) + + def __post_init__(self): + # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there. + if self.num_speakers > 0: + self.model_args.num_speakers = self.num_speakers + + # speaker embedding settings + if self.use_speaker_embedding: + self.model_args.use_speaker_embedding = True + if self.speakers_file: + self.model_args.speakers_file = self.speakers_file + + # d-vector settings + if self.use_d_vector_file: + self.model_args.use_d_vector_file = True + if self.d_vector_dim is not None and self.d_vector_dim > 0: + self.model_args.d_vector_dim = self.d_vector_dim + if self.d_vector_file: + self.model_args.d_vector_file = self.d_vector_file diff --git a/TTS/tts/configs/fast_speech_config.py b/TTS/tts/configs/fast_speech_config.py new file mode 100644 index 0000000000000000000000000000000000000000..16a76e215f4d47d086bea827d2b6ccc61524e5c1 --- /dev/null +++ b/TTS/tts/configs/fast_speech_config.py @@ -0,0 +1,177 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.forward_tts import ForwardTTSArgs + + +@dataclass +class FastSpeechConfig(BaseTTSConfig): + """Configure `ForwardTTS` as FastSpeech model. + + Example: + + >>> from TTS.tts.configs.fast_speech_config import FastSpeechConfig + >>> config = FastSpeechConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `fast_pitch`. + + base_model (str): + Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate + the base model rather than searching for the `model` implementation. Defaults to `forward_tts`. + + model_args (Coqpit): + Model class arguments. Check `FastSpeechArgs` for more details. Defaults to `FastSpeechArgs()`. 
+ + data_dep_init_steps (int): + Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses + Activation Normalization that pre-computes normalization stats at the beginning and use the same values + for the rest. Defaults to 10. + + speakers_file (str): + Path to the file containing the list of speakers. Needed at inference for loading matching speaker ids to + speaker names. Defaults to `None`. + + + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + + d_vector_dim (int): + Dimension of the external speaker embeddings. Defaults to 0. + + optimizer (str): + Name of the model optimizer. Defaults to `Adam`. + + optimizer_params (dict): + Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`. + + lr_scheduler (str): + Name of the learning rate scheduler. Defaults to `Noam`. + + lr_scheduler_params (dict): + Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`. + + lr (float): + Initial learning rate. Defaults to `1e-3`. + + grad_clip (float): + Gradient norm clipping value. Defaults to `5.0`. + + spec_loss_type (str): + Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + duration_loss_type (str): + Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + use_ssim_loss (bool): + Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True. + + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + + ssim_loss_alpha (float): + Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0. + + dur_loss_alpha (float): + Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0. + + spec_loss_alpha (float): + Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0. + + pitch_loss_alpha (float): + Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0. + + binary_loss_alpha (float): + Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. + + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. + + min_seq_len (int): + Minimum input sequence length to be used at training. + + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. 
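+
+    Example (external d-vectors):
+
+        A minimal sketch; the `d_vectors.json` path and the embedding size are placeholders. `__post_init__()`
+        forwards these fields to `model_args`.
+
+        >>> from TTS.tts.configs.fast_speech_config import FastSpeechConfig
+        >>> config = FastSpeechConfig(
+        ...     use_d_vector_file=True,
+        ...     d_vector_file="d_vectors.json",
+        ...     d_vector_dim=256,
+        ... )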
+ """ + + model: str = "fast_speech" + base_model: str = "forward_tts" + + # model specific params + model_args: ForwardTTSArgs = ForwardTTSArgs(use_pitch=False) + + # multi-speaker settings + num_speakers: int = 0 + speakers_file: str = None + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_file: str = False + d_vector_dim: int = 0 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + lr: float = 1e-4 + grad_clip: float = 5.0 + + # loss params + spec_loss_type: str = "mse" + duration_loss_type: str = "mse" + use_ssim_loss: bool = True + ssim_loss_alpha: float = 1.0 + dur_loss_alpha: float = 1.0 + spec_loss_alpha: float = 1.0 + pitch_loss_alpha: float = 0.0 + aligner_loss_alpha: float = 1.0 + binary_align_loss_alpha: float = 1.0 + binary_loss_warmup_epochs: int = 150 + + # overrides + min_seq_len: int = 13 + max_seq_len: int = 200 + r: int = 1 # DO NOT CHANGE + + # dataset configs + compute_f0: bool = False + f0_cache_path: str = None + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) + + def __post_init__(self): + # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there. + if self.num_speakers > 0: + self.model_args.num_speakers = self.num_speakers + + # speaker embedding settings + if self.use_speaker_embedding: + self.model_args.use_speaker_embedding = True + if self.speakers_file: + self.model_args.speakers_file = self.speakers_file + + # d-vector settings + if self.use_d_vector_file: + self.model_args.use_d_vector_file = True + if self.d_vector_dim is not None and self.d_vector_dim > 0: + self.model_args.d_vector_dim = self.d_vector_dim + if self.d_vector_file: + self.model_args.d_vector_file = self.d_vector_file diff --git a/TTS/tts/configs/glow_tts_config.py b/TTS/tts/configs/glow_tts_config.py new file mode 100644 index 0000000000000000000000000000000000000000..f42f3e5a510bacf1b2312ccea7d46201bbcb774f --- /dev/null +++ b/TTS/tts/configs/glow_tts_config.py @@ -0,0 +1,182 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig + + +@dataclass +class GlowTTSConfig(BaseTTSConfig): + """Defines parameters for GlowTTS model. + + Example: + + >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig + >>> config = GlowTTSConfig() + + Args: + model(str): + Model name used for selecting the right model at initialization. Defaults to `glow_tts`. + encoder_type (str): + Type of the encoder used by the model. Look at `TTS.tts.layers.glow_tts.encoder` for more details. + Defaults to `rel_pos_transformers`. + encoder_params (dict): + Parameters used to define the encoder network. Look at `TTS.tts.layers.glow_tts.encoder` for more details. + Defaults to `{"kernel_size": 3, "dropout_p": 0.1, "num_layers": 6, "num_heads": 2, "hidden_channels_ffn": 768}` + use_encoder_prenet (bool): + enable / disable the use of a prenet for the encoder. Defaults to True. 
+ hidden_channels_enc (int): + Number of base hidden channels used by the encoder network. It defines the input and the output channel sizes, + and for some encoder types internal hidden channels sizes too. Defaults to 192. + hidden_channels_dec (int): + Number of base hidden channels used by the decoder WaveNet network. Defaults to 192 as in the original work. + hidden_channels_dp (int): + Number of layer channels of the duration predictor network. Defaults to 256 as in the original work. + mean_only (bool): + If true predict only the mean values by the decoder flow. Defaults to True. + out_channels (int): + Number of channels of the model output tensor. Defaults to 80. + num_flow_blocks_dec (int): + Number of decoder blocks. Defaults to 12. + inference_noise_scale (float): + Noise scale used at inference. Defaults to 0.33. + kernel_size_dec (int): + Decoder kernel size. Defaults to 5 + dilation_rate (int): + Rate to increase dilation by each layer in a decoder block. Defaults to 1. + num_block_layers (int): + Number of decoder layers in each decoder block. Defaults to 4. + dropout_p_dec (float): + Dropout rate for decoder. Defaults to 0.1. + num_speaker (int): + Number of speaker to define the size of speaker embedding layer. Defaults to 0. + c_in_channels (int): + Number of speaker embedding channels. It is set to 512 if embeddings are learned. Defaults to 0. + num_splits (int): + Number of split levels in inversible conv1x1 operation. Defaults to 4. + num_squeeze (int): + Number of squeeze levels. When squeezing channels increases and time steps reduces by the factor + 'num_squeeze'. Defaults to 2. + sigmoid_scale (bool): + enable/disable sigmoid scaling in decoder. Defaults to False. + mean_only (bool): + If True, encoder only computes mean value and uses constant variance for each time step. Defaults to true. + encoder_type (str): + Encoder module type. Possible values are`["rel_pos_transformer", "gated_conv", "residual_conv_bn", "time_depth_separable"]` + Check `TTS.tts.layers.glow_tts.encoder` for more details. Defaults to `rel_pos_transformers` as in the original paper. + encoder_params (dict): + Encoder module parameters. Defaults to None. + d_vector_dim (int): + Channels of external speaker embedding vectors. Defaults to 0. + data_dep_init_steps (int): + Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses + Activation Normalization that pre-computes normalization stats at the beginning and use the same values + for the rest. Defaults to 10. + style_wav_for_test (str): + Path to the wav file used for changing the style of the speech. Defaults to None. + inference_noise_scale (float): + Variance used for sampling the random noise added to the decoder's input at inference. Defaults to 0.0. + length_scale (float): + Multiply the predicted durations with this value to change the speech speed. Defaults to 1. + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + noam_schedule (bool): + enable / disable the use of Noam LR scheduler. Defaults to False. + warmup_steps (int): + Number of warm-up steps for the Noam scheduler. Defaults 4000. + lr (float): + Initial learning rate. 
Defaults to `1e-3`. + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + min_seq_len (int): + Minimum input sequence length to be used at training. + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + """ + + model: str = "glow_tts" + + # model params + num_chars: int = None + encoder_type: str = "rel_pos_transformer" + encoder_params: dict = field( + default_factory=lambda: { + "kernel_size": 3, + "dropout_p": 0.1, + "num_layers": 6, + "num_heads": 2, + "hidden_channels_ffn": 768, + } + ) + use_encoder_prenet: bool = True + hidden_channels_enc: int = 192 + hidden_channels_dec: int = 192 + hidden_channels_dp: int = 256 + dropout_p_dp: float = 0.1 + dropout_p_dec: float = 0.05 + mean_only: bool = True + out_channels: int = 80 + num_flow_blocks_dec: int = 12 + inference_noise_scale: float = 0.33 + kernel_size_dec: int = 5 + dilation_rate: int = 1 + num_block_layers: int = 4 + num_speakers: int = 0 + c_in_channels: int = 0 + num_splits: int = 4 + num_squeeze: int = 2 + sigmoid_scale: bool = False + encoder_type: str = "rel_pos_transformer" + encoder_params: dict = field( + default_factory=lambda: { + "kernel_size": 3, + "dropout_p": 0.1, + "num_layers": 6, + "num_heads": 2, + "hidden_channels_ffn": 768, + "input_length": None, + } + ) + d_vector_dim: int = 0 + + # training params + data_dep_init_steps: int = 10 + + # inference params + style_wav_for_test: str = None + inference_noise_scale: float = 0.0 + length_scale: float = 1.0 + + # multi-speaker settings + use_speaker_embedding: bool = False + speakers_file: str = None + use_d_vector_file: bool = False + d_vector_file: str = False + + # optimizer parameters + optimizer: str = "RAdam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + grad_clip: float = 5.0 + lr: float = 1e-3 + + # overrides + min_seq_len: int = 3 + max_seq_len: int = 500 + r: int = 1 # DO NOT CHANGE - TODO: make this immutable once coqpit implements it. + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..4704687c268780ed518f8c6a4ca64808dbab8e65 --- /dev/null +++ b/TTS/tts/configs/shared_configs.py @@ -0,0 +1,335 @@ +from dataclasses import asdict, dataclass, field +from typing import Dict, List + +from coqpit import Coqpit, check_argument + +from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig + + +@dataclass +class GSTConfig(Coqpit): + """Defines the Global Style Token Module + + Args: + gst_style_input_wav (str): + Path to the wav file used to define the style of the output speech at inference. Defaults to None. + + gst_style_input_weights (dict): + Defines the weights for each style token used at inference. Defaults to None. + + gst_embedding_dim (int): + Defines the size of the GST embedding vector dimensions. Defaults to 256. + + gst_num_heads (int): + Number of attention heads used by the multi-head attention. Defaults to 4. 
+ + gst_num_style_tokens (int): + Number of style token vectors. Defaults to 10. + """ + + gst_style_input_wav: str = None + gst_style_input_weights: dict = None + gst_embedding_dim: int = 256 + gst_use_speaker_embedding: bool = False + gst_num_heads: int = 4 + gst_num_style_tokens: int = 10 + + def check_values( + self, + ): + """Check config fields""" + c = asdict(self) + super().check_values() + check_argument("gst_style_input_weights", c, restricted=False) + check_argument("gst_style_input_wav", c, restricted=False) + check_argument("gst_embedding_dim", c, restricted=True, min_val=0, max_val=1000) + check_argument("gst_use_speaker_embedding", c, restricted=False) + check_argument("gst_num_heads", c, restricted=True, min_val=2, max_val=10) + check_argument("gst_num_style_tokens", c, restricted=True, min_val=1, max_val=1000) + + +@dataclass +class CapacitronVAEConfig(Coqpit): + """Defines the capacitron VAE Module + Args: + capacitron_capacity (int): + Defines the variational capacity limit of the prosody embeddings. Defaults to 150. + capacitron_VAE_embedding_dim (int): + Defines the size of the Capacitron embedding vector dimension. Defaults to 128. + capacitron_use_text_summary_embeddings (bool): + If True, use a text summary embedding in Capacitron. Defaults to True. + capacitron_text_summary_embedding_dim (int): + Defines the size of the capacitron text embedding vector dimension. Defaults to 128. + capacitron_use_speaker_embedding (bool): + if True use speaker embeddings in Capacitron. Defaults to False. + capacitron_VAE_loss_alpha (float): + Weight for the VAE loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + capacitron_grad_clip (float): + Gradient clipping value for all gradients except beta. Defaults to 5.0 + """ + + capacitron_loss_alpha: int = 1 + capacitron_capacity: int = 150 + capacitron_VAE_embedding_dim: int = 128 + capacitron_use_text_summary_embeddings: bool = True + capacitron_text_summary_embedding_dim: int = 128 + capacitron_use_speaker_embedding: bool = False + capacitron_VAE_loss_alpha: float = 0.25 + capacitron_grad_clip: float = 5.0 + + def check_values( + self, + ): + """Check config fields""" + c = asdict(self) + super().check_values() + check_argument("capacitron_capacity", c, restricted=True, min_val=10, max_val=500) + check_argument("capacitron_VAE_embedding_dim", c, restricted=True, min_val=16, max_val=1024) + check_argument("capacitron_use_speaker_embedding", c, restricted=False) + check_argument("capacitron_text_summary_embedding_dim", c, restricted=False, min_val=16, max_val=512) + check_argument("capacitron_VAE_loss_alpha", c, restricted=False) + check_argument("capacitron_grad_clip", c, restricted=False) + + +@dataclass +class CharactersConfig(Coqpit): + """Defines arguments for the `BaseCharacters` or `BaseVocabulary` and their subclasses. + + Args: + characters_class (str): + Defines the class of the characters used. If None, we pick ```Phonemes``` or ```Graphemes``` based on + the configuration. Defaults to None. + + vocab_dict (dict): + Defines the vocabulary dictionary used to encode the characters. Defaults to None. + + pad (str): + characters in place of empty padding. Defaults to None. + + eos (str): + characters showing the end of a sentence. Defaults to None. + + bos (str): + characters showing the beginning of a sentence. Defaults to None. + + blank (str): + Optional character used between characters by some models for better prosody. Defaults to `_blank`. 
+ + characters (str): + character set used by the model. Characters not in this list are ignored when converting input text to + a list of sequence IDs. Defaults to None. + + punctuations (str): + characters considered as punctuation as parsing the input sentence. Defaults to None. + + phonemes (str): + characters considered as parsing phonemes. This is only for backwards compat. Use `characters` for new + models. Defaults to None. + + is_unique (bool): + remove any duplicate characters in the character lists. It is a bandaid for compatibility with the old + models trained with character lists with duplicates. Defaults to True. + + is_sorted (bool): + Sort the characters in alphabetical order. Defaults to True. + """ + + characters_class: str = None + + # using BaseVocabulary + vocab_dict: Dict = None + + # using on BaseCharacters + pad: str = None + eos: str = None + bos: str = None + blank: str = None + characters: str = None + punctuations: str = None + phonemes: str = None + is_unique: bool = True # for backwards compatibility of models trained with char sets with duplicates + is_sorted: bool = True + + +@dataclass +class BaseTTSConfig(BaseTrainingConfig): + """Shared parameters among all the tts models. + + Args: + + audio (BaseAudioConfig): + Audio processor config object instance. + + use_phonemes (bool): + enable / disable phoneme use. + + phonemizer (str): + Name of the phonemizer to use. If set None, the phonemizer will be selected by `phoneme_language`. + Defaults to None. + + phoneme_language (str): + Language code for the phonemizer. You can check the list of supported languages by running + `python TTS/tts/utils/text/phonemizers/__init__.py`. Defaults to None. + + compute_input_seq_cache (bool): + enable / disable precomputation of the phoneme sequences. At the expense of some delay at the beginning of + the training, It allows faster data loader time and precise limitation with `max_seq_len` and + `min_seq_len`. + + text_cleaner (str): + Name of the text cleaner used for cleaning and formatting transcripts. + + enable_eos_bos_chars (bool): + enable / disable the use of eos and bos characters. + + test_senteces_file (str): + Path to a txt file that has sentences used at test time. The file must have a sentence per line. + + phoneme_cache_path (str): + Path to the output folder caching the computed phonemes for each sample. + + characters (CharactersConfig): + Instance of a CharactersConfig class. + + batch_group_size (int): + Size of the batch groups used for bucketing. By default, the dataloader orders samples by the sequence + length for a more efficient and stable training. If `batch_group_size > 1` then it performs bucketing to + prevent using the same batches for each epoch. + + loss_masking (bool): + enable / disable masking loss values against padded segments of samples in a batch. + + sort_by_audio_len (bool): + If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `False`. + + min_text_len (int): + Minimum length of input text to be used. All shorter samples will be ignored. Defaults to 0. + + max_text_len (int): + Maximum length of input text to be used. All longer samples will be ignored. Defaults to float("inf"). + + min_audio_len (int): + Minimum length of input audio to be used. All shorter samples will be ignored. Defaults to 0. + + max_audio_len (int): + Maximum length of input audio to be used. All longer samples will be ignored. The maximum length in the + dataset defines the VRAM used in the training. 
Hence, pay attention to this value if you encounter an + OOM error in training. Defaults to float("inf"). + + compute_f0 (int): + (Not in use yet). + + compute_linear_spec (bool): + If True data loader computes and returns linear spectrograms alongside the other data. + + precompute_num_workers (int): + Number of workers to precompute features. Defaults to 0. + + use_noise_augment (bool): + Augment the input audio with random noise. + + start_by_longest (bool): + If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues. + Defaults to False. + + add_blank (bool): + Add blank characters between each other two characters. It improves performance for some models at expense + of slower run-time due to the longer input sequence. + + datasets (List[BaseDatasetConfig]): + List of datasets used for training. If multiple datasets are provided, they are merged and used together + for training. + + optimizer (str): + Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`. + Defaults to ``. + + optimizer_params (dict): + Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}` + + lr_scheduler (str): + Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or + `TTS.utils.training`. Defaults to ``. + + lr_scheduler_params (dict): + Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`. + + test_sentences (List[str]): + List of sentences to be used at testing. Defaults to '[]' + + eval_split_max_size (int): + Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled). + + eval_split_size (float): + If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. + If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). + + use_speaker_weighted_sampler (bool): + Enable / Disable the batch balancer by speaker. Defaults to ```False```. + + speaker_weighted_sampler_alpha (float): + Number that control the influence of the speaker sampler weights. Defaults to ```1.0```. + + use_language_weighted_sampler (bool): + Enable / Disable the batch balancer by language. Defaults to ```False```. + + language_weighted_sampler_alpha (float): + Number that control the influence of the language sampler weights. Defaults to ```1.0```. + + use_length_weighted_sampler (bool): + Enable / Disable the batch balancer by audio length. If enabled the dataset will be divided + into 10 buckets considering the min and max audio of the dataset. The sampler weights will be + computed forcing to have the same quantity of data for each bucket in each training batch. Defaults to ```False```. + + length_weighted_sampler_alpha (float): + Number that control the influence of the length sampler weights. Defaults to ```1.0```. 
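+
+    Example:
+
+        A minimal sketch of how the shared fields are typically filled in; in practice they are set through a
+        model config that inherits from this class, and the cleaner name, phoneme language and dataset path
+        below are placeholders.
+
+        >>> from TTS.config import BaseDatasetConfig
+        >>> dataset_config = BaseDatasetConfig(
+        ...     name="ljspeech", meta_file_train="metadata.csv", path="/data/LJSpeech-1.1/"
+        ... )
+        >>> config = BaseTTSConfig(
+        ...     text_cleaner="english_cleaners",
+        ...     use_phonemes=True,
+        ...     phoneme_language="en-us",
+        ...     phoneme_cache_path="phoneme_cache",
+        ...     datasets=[dataset_config],
+        ... )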
+ """ + + audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) + # phoneme settings + use_phonemes: bool = False + phonemizer: str = None + phoneme_language: str = None + compute_input_seq_cache: bool = False + text_cleaner: str = None + enable_eos_bos_chars: bool = False + test_sentences_file: str = "" + phoneme_cache_path: str = None + # vocabulary parameters + characters: CharactersConfig = None + add_blank: bool = False + # training params + batch_group_size: int = 0 + loss_masking: bool = None + # dataloading + sort_by_audio_len: bool = False + min_audio_len: int = 1 + max_audio_len: int = float("inf") + min_text_len: int = 1 + max_text_len: int = float("inf") + compute_f0: bool = False + compute_linear_spec: bool = False + precompute_num_workers: int = 0 + use_noise_augment: bool = False + start_by_longest: bool = False + # dataset + datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) + # optimizer + optimizer: str = "radam" + optimizer_params: dict = None + # scheduler + lr_scheduler: str = "" + lr_scheduler_params: dict = field(default_factory=lambda: {}) + # testing + test_sentences: List[str] = field(default_factory=lambda: []) + # evaluation + eval_split_max_size: int = None + eval_split_size: float = 0.01 + # weighted samplers + use_speaker_weighted_sampler: bool = False + speaker_weighted_sampler_alpha: float = 1.0 + use_language_weighted_sampler: bool = False + language_weighted_sampler_alpha: float = 1.0 + use_length_weighted_sampler: bool = False + length_weighted_sampler_alpha: float = 1.0 diff --git a/TTS/tts/configs/speedy_speech_config.py b/TTS/tts/configs/speedy_speech_config.py new file mode 100644 index 0000000000000000000000000000000000000000..4bf5101fcad2479e87836c827658c88addfd7cc6 --- /dev/null +++ b/TTS/tts/configs/speedy_speech_config.py @@ -0,0 +1,192 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.forward_tts import ForwardTTSArgs + + +@dataclass +class SpeedySpeechConfig(BaseTTSConfig): + """Configure `ForwardTTS` as SpeedySpeech model. + + Example: + + >>> from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig + >>> config = SpeedySpeechConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `speedy_speech`. + + base_model (str): + Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate + the base model rather than searching for the `model` implementation. Defaults to `forward_tts`. + + model_args (Coqpit): + Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`. + + data_dep_init_steps (int): + Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses + Activation Normalization that pre-computes normalization stats at the beginning and use the same values + for the rest. Defaults to 10. + + speakers_file (str): + Path to the file containing the list of speakers. Needed at inference for loading matching speaker ids to + speaker names. Defaults to `None`. + + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. 
+ + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + + d_vector_dim (int): + Dimension of the external speaker embeddings. Defaults to 0. + + optimizer (str): + Name of the model optimizer. Defaults to `RAdam`. + + optimizer_params (dict): + Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`. + + lr_scheduler (str): + Name of the learning rate scheduler. Defaults to `Noam`. + + lr_scheduler_params (dict): + Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`. + + lr (float): + Initial learning rate. Defaults to `1e-3`. + + grad_clip (float): + Gradient norm clipping value. Defaults to `5.0`. + + spec_loss_type (str): + Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `l1`. + + duration_loss_type (str): + Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `huber`. + + use_ssim_loss (bool): + Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True. + + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + + ssim_loss_alpha (float): + Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0. + + dur_loss_alpha (float): + Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0. + + spec_loss_alpha (float): + Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0. + + binary_loss_alpha (float): + Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. + + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. + + min_seq_len (int): + Minimum input sequence length to be used at training. + + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. 
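+
+    Example (adjusting the alignment losses):
+
+        A minimal sketch with illustrative weights; see `ForwardTTSLoss` for how these values are applied.
+
+        >>> from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig
+        >>> config = SpeedySpeechConfig(
+        ...     binary_align_loss_alpha=0.5,
+        ...     binary_loss_warmup_epochs=50,
+        ...     ssim_loss_alpha=0.0,  # disable the SSIM term
+        ... )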
+ """ + + model: str = "speedy_speech" + base_model: str = "forward_tts" + + # set model args as SpeedySpeech + model_args: ForwardTTSArgs = ForwardTTSArgs( + use_pitch=False, + encoder_type="residual_conv_bn", + encoder_params={ + "kernel_size": 4, + "dilations": 4 * [1, 2, 4] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 13, + }, + decoder_type="residual_conv_bn", + decoder_params={ + "kernel_size": 4, + "dilations": 4 * [1, 2, 4, 8] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 17, + }, + out_channels=80, + hidden_channels=128, + positional_encoding=True, + detach_duration_predictor=True, + ) + + # multi-speaker settings + num_speakers: int = 0 + speakers_file: str = None + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_file: str = False + d_vector_dim: int = 0 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + lr: float = 1e-4 + grad_clip: float = 5.0 + + # loss params + spec_loss_type: str = "l1" + duration_loss_type: str = "huber" + use_ssim_loss: bool = False + ssim_loss_alpha: float = 1.0 + dur_loss_alpha: float = 1.0 + spec_loss_alpha: float = 1.0 + aligner_loss_alpha: float = 1.0 + binary_align_loss_alpha: float = 0.3 + binary_loss_warmup_epochs: int = 150 + + # overrides + min_seq_len: int = 13 + max_seq_len: int = 200 + r: int = 1 # DO NOT CHANGE + + # dataset configs + compute_f0: bool = False + f0_cache_path: str = None + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) + + def __post_init__(self): + # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there. + if self.num_speakers > 0: + self.model_args.num_speakers = self.num_speakers + + # speaker embedding settings + if self.use_speaker_embedding: + self.model_args.use_speaker_embedding = True + if self.speakers_file: + self.model_args.speakers_file = self.speakers_file + + # d-vector settings + if self.use_d_vector_file: + self.model_args.use_d_vector_file = True + if self.d_vector_dim is not None and self.d_vector_dim > 0: + self.model_args.d_vector_dim = self.d_vector_dim + if self.d_vector_file: + self.model_args.d_vector_file = self.d_vector_file diff --git a/TTS/tts/configs/tacotron2_config.py b/TTS/tts/configs/tacotron2_config.py new file mode 100644 index 0000000000000000000000000000000000000000..95b65202218cf3aa0dd70c8d8cd55a3f913ed308 --- /dev/null +++ b/TTS/tts/configs/tacotron2_config.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass + +from TTS.tts.configs.tacotron_config import TacotronConfig + + +@dataclass +class Tacotron2Config(TacotronConfig): + """Defines parameters for Tacotron2 based models. + + Example: + + >>> from TTS.tts.configs.tacotron2_config import Tacotron2Config + >>> config = Tacotron2Config() + + Check `TacotronConfig` for argument descriptions. 
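+
+    Example (gradual training with Double Decoder Consistency):
+
+        A minimal sketch with illustrative values. The first `gradual_training` entry must use the same `r`
+        as the model (see `TacotronConfig.check_values`), and `ddc_r` should be a multiple of `r`.
+
+        >>> config = Tacotron2Config(
+        ...     r=6,
+        ...     gradual_training=[[0, 6, 64], [10000, 4, 32], [50000, 2, 32]],
+        ...     double_decoder_consistency=True,
+        ...     ddc_r=6,
+        ... )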
+ """ + + model: str = "tacotron2" + out_channels: int = 80 + encoder_in_features: int = 512 + decoder_in_features: int = 512 diff --git a/TTS/tts/configs/tacotron_config.py b/TTS/tts/configs/tacotron_config.py new file mode 100644 index 0000000000000000000000000000000000000000..e25609ffcf685fae91fc40eaa2201e728ebc73c4 --- /dev/null +++ b/TTS/tts/configs/tacotron_config.py @@ -0,0 +1,235 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig, CapacitronVAEConfig, GSTConfig + + +@dataclass +class TacotronConfig(BaseTTSConfig): + """Defines parameters for Tacotron based models. + + Example: + + >>> from TTS.tts.configs.tacotron_config import TacotronConfig + >>> config = TacotronConfig() + + Args: + model (str): + Model name used to select the right model class to initilize. Defaults to `Tacotron`. + use_gst (bool): + enable / disable the use of Global Style Token modules. Defaults to False. + gst (GSTConfig): + Instance of `GSTConfig` class. + gst_style_input (str): + Path to the wav file used at inference to set the speech style through GST. If `GST` is enabled and + this is not defined, the model uses a zero vector as an input. Defaults to None. + use_capacitron_vae (bool): + enable / disable the use of Capacitron modules. Defaults to False. + capacitron_vae (CapacitronConfig): + Instance of `CapacitronConfig` class. + num_chars (int): + Number of characters used by the model. It must be defined before initializing the model. Defaults to None. + num_speakers (int): + Number of speakers for multi-speaker models. Defaults to 1. + r (int): + Initial number of output frames that the decoder computed per iteration. Larger values makes training and inference + faster but reduces the quality of the output frames. This must be equal to the largest `r` value used in + `gradual_training` schedule. Defaults to 1. + gradual_training (List[List]): + Parameters for the gradual training schedule. It is in the form `[[a, b, c], [d ,e ,f] ..]` where `a` is + the step number to start using the rest of the values, `b` is the `r` value and `c` is the batch size. + If sets None, no gradual training is used. Defaults to None. + memory_size (int): + Defines the number of previous frames used by the Prenet. If set to < 0, then it uses only the last frame. + Defaults to -1. + prenet_type (str): + `original` or `bn`. `original` sets the default Prenet and `bn` uses Batch Normalization version of the + Prenet. Defaults to `original`. + prenet_dropout (bool): + enables / disables the use of dropout in the Prenet. Defaults to True. + prenet_dropout_at_inference (bool): + enable / disable the use of dropout in the Prenet at the inference time. Defaults to False. + stopnet (bool): + enable /disable the Stopnet that predicts the end of the decoder sequence. Defaults to True. + stopnet_pos_weight (float): + Weight that is applied to over-weight positive instances in the Stopnet loss. Use larger values with + datasets with longer sentences. Defaults to 10. + max_decoder_steps (int): + Max number of steps allowed for the decoder. Defaults to 50. + encoder_in_features (int): + Channels of encoder input and character embedding tensors. Defaults to 256. + decoder_in_features (int): + Channels of decoder input and encoder output tensors. Defaults to 256. + out_channels (int): + Channels of the final model output. It must match the spectragram size. Defaults to 80. 
+ separate_stopnet (bool): + Use a distinct Stopnet which is trained separately from the rest of the model. Defaults to True. + attention_type (str): + attention type. Check ```TTS.tts.layers.attentions.init_attn```. Defaults to 'original'. + attention_heads (int): + Number of attention heads for GMM attention. Defaults to 5. + windowing (bool): + It especially useful at inference to keep attention alignment diagonal. Defaults to False. + use_forward_attn (bool): + It is only valid if ```attn_type``` is ```original```. Defaults to False. + forward_attn_mask (bool): + enable/disable extra masking over forward attention. It is useful at inference to prevent + possible attention failures. Defaults to False. + transition_agent (bool): + enable/disable transition agent in forward attention. Defaults to False. + location_attn (bool): + enable/disable location sensitive attention as in the original Tacotron2 paper. + It is only valid if ```attn_type``` is ```original```. Defaults to True. + bidirectional_decoder (bool): + enable/disable bidirectional decoding. Defaults to False. + double_decoder_consistency (bool): + enable/disable double decoder consistency. Defaults to False. + ddc_r (int): + reduction rate used by the coarse decoder when `double_decoder_consistency` is in use. Set this + as a multiple of the `r` value. Defaults to 6. + speakers_file (str): + Path to the speaker mapping file for the Speaker Manager. Defaults to None. + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + optimizer (str): + Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`. + Defaults to `RAdam`. + optimizer_params (dict): + Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}` + lr_scheduler (str): + Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or + `TTS.utils.training`. Defaults to `NoamLR`. + lr_scheduler_params (dict): + Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`. + lr (float): + Initial learning rate. Defaults to `1e-4`. + wd (float): + Weight decay coefficient. Defaults to `1e-6`. + grad_clip (float): + Gradient clipping threshold. Defaults to `5`. + seq_len_norm (bool): + enable / disable the sequnce length normalization in the loss functions. If set True, loss of a sample + is divided by the sequence length. Defaults to False. + loss_masking (bool): + enable / disable masking the paddings of the samples in loss computation. Defaults to True. + decoder_loss_alpha (float): + Weight for the decoder loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + postnet_loss_alpha (float): + Weight for the postnet loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + postnet_diff_spec_alpha (float): + Weight for the postnet differential loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + decoder_diff_spec_alpha (float): + + Weight for the decoder differential loss of the Tacotron model. 
If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + decoder_ssim_alpha (float): + Weight for the decoder SSIM loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + postnet_ssim_alpha (float): + Weight for the postnet SSIM loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + ga_alpha (float): + Weight for the guided attention loss. If set less than or equal to zero, it disables the corresponding loss + function. Defaults to 5. + """ + + model: str = "tacotron" + # model_params: TacotronArgs = field(default_factory=lambda: TacotronArgs()) + use_gst: bool = False + gst: GSTConfig = None + gst_style_input: str = None + + use_capacitron_vae: bool = False + capacitron_vae: CapacitronVAEConfig = None + + # model specific params + num_speakers: int = 1 + num_chars: int = 0 + r: int = 2 + gradual_training: List[List[int]] = None + memory_size: int = -1 + prenet_type: str = "original" + prenet_dropout: bool = True + prenet_dropout_at_inference: bool = False + stopnet: bool = True + separate_stopnet: bool = True + stopnet_pos_weight: float = 10.0 + max_decoder_steps: int = 500 + encoder_in_features: int = 256 + decoder_in_features: int = 256 + decoder_output_dim: int = 80 + out_channels: int = 513 + + # attention layers + attention_type: str = "original" + attention_heads: int = None + attention_norm: str = "sigmoid" + attention_win: bool = False + windowing: bool = False + use_forward_attn: bool = False + forward_attn_mask: bool = False + transition_agent: bool = False + location_attn: bool = True + + # advance methods + bidirectional_decoder: bool = False + double_decoder_consistency: bool = False + ddc_r: int = 6 + + # multi-speaker settings + speakers_file: str = None + use_speaker_embedding: bool = False + speaker_embedding_dim: int = 512 + use_d_vector_file: bool = False + d_vector_file: str = False + d_vector_dim: int = None + + # optimizer parameters + optimizer: str = "RAdam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + lr: float = 1e-4 + grad_clip: float = 5.0 + seq_len_norm: bool = False + loss_masking: bool = True + + # loss params + decoder_loss_alpha: float = 0.25 + postnet_loss_alpha: float = 0.25 + postnet_diff_spec_alpha: float = 0.25 + decoder_diff_spec_alpha: float = 0.25 + decoder_ssim_alpha: float = 0.25 + postnet_ssim_alpha: float = 0.25 + ga_alpha: float = 5.0 + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) + + def check_values(self): + if self.gradual_training: + assert ( + self.gradual_training[0][1] == self.r + ), f"[!] the first scheduled gradual training `r` must be equal to the model's `r` value. 
{self.gradual_training[0][1]} vs {self.r}" + if self.model == "tacotron" and self.audio is not None: + assert self.out_channels == ( + self.audio.fft_size // 2 + 1 + ), f"{self.out_channels} vs {self.audio.fft_size // 2 + 1}" + if self.model == "tacotron2" and self.audio is not None: + assert self.out_channels == self.audio.num_mels diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py new file mode 100644 index 0000000000000000000000000000000000000000..a8c7f91dcd965e0d1f1e3f2b78de321e27af3b95 --- /dev/null +++ b/TTS/tts/configs/vits_config.py @@ -0,0 +1,155 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.vits import VitsArgs + + +@dataclass +class VitsConfig(BaseTTSConfig): + """Defines parameters for VITS End2End TTS model. + + Args: + model (str): + Model name. Do not change unless you know what you are doing. + + model_args (VitsArgs): + Model architecture arguments. Defaults to `VitsArgs()`. + + grad_clip (List): + Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`. + + lr_gen (float): + Initial learning rate for the generator. Defaults to 0.0002. + + lr_disc (float): + Initial learning rate for the discriminator. Defaults to 0.0002. + + lr_scheduler_gen (str): + Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to + `ExponentialLR`. + + lr_scheduler_gen_params (dict): + Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`. + + lr_scheduler_disc (str): + Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to + `ExponentialLR`. + + lr_scheduler_disc_params (dict): + Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`. + + scheduler_after_epoch (bool): + If true, step the schedulers after each epoch else after each step. Defaults to `False`. + + optimizer (str): + Name of the optimizer to use with both the generator and the discriminator networks. One of the + `torch.optim.*`. Defaults to `AdamW`. + + kl_loss_alpha (float): + Loss weight for KL loss. Defaults to 1.0. + + disc_loss_alpha (float): + Loss weight for the discriminator loss. Defaults to 1.0. + + gen_loss_alpha (float): + Loss weight for the generator loss. Defaults to 1.0. + + feat_loss_alpha (float): + Loss weight for the feature matching loss. Defaults to 1.0. + + mel_loss_alpha (float): + Loss weight for the mel loss. Defaults to 45.0. + + return_wav (bool): + If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`. + + compute_linear_spec (bool): + If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`. + + r (int): + Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`. + + add_blank (bool): + If true, a blank token is added in between every character. Defaults to `True`. + + test_sentences (List[List]): + List of sentences with speaker and language information to be used for testing. + + language_ids_file (str): + Path to the language ids file. + + use_language_embedding (bool): + If true, language embedding is used. Defaults to `False`. + + Note: + Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters. 
+
+    Example:
+
+        >>> from TTS.tts.configs.vits_config import VitsConfig
+        >>> config = VitsConfig()
+    """
+
+    model: str = "vits"
+    # model specific params
+    model_args: VitsArgs = field(default_factory=VitsArgs)
+
+    # optimizer
+    grad_clip: List[float] = field(default_factory=lambda: [1000, 1000])
+    lr_gen: float = 0.0002
+    lr_disc: float = 0.0002
+    lr_scheduler_gen: str = "ExponentialLR"
+    lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
+    lr_scheduler_disc: str = "ExponentialLR"
+    lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
+    scheduler_after_epoch: bool = True
+    optimizer: str = "AdamW"
+    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01})
+
+    # loss params
+    kl_loss_alpha: float = 1.0
+    disc_loss_alpha: float = 1.0
+    gen_loss_alpha: float = 1.0
+    feat_loss_alpha: float = 1.0
+    mel_loss_alpha: float = 45.0
+    dur_loss_alpha: float = 1.0
+    speaker_encoder_loss_alpha: float = 1.0
+
+    # data loader params
+    return_wav: bool = True
+    compute_linear_spec: bool = True
+
+    # overrides
+    r: int = 1  # DO NOT CHANGE
+    add_blank: bool = True
+
+    # testing
+    test_sentences: List[List] = field(
+        default_factory=lambda: [
+            ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."],
+            ["Be a voice, not an echo."],
+            ["I'm sorry Dave. I'm afraid I can't do that."],
+            ["This cake is great. It's so delicious and moist."],
+            ["Prior to November 22, 1963."],
+        ]
+    )
+
+    # multi-speaker settings
+    # use speaker embedding layer
+    num_speakers: int = 0
+    use_speaker_embedding: bool = False
+    speakers_file: str = None
+    speaker_embedding_channels: int = 256
+    language_ids_file: str = None
+    use_language_embedding: bool = False
+
+    # use d-vectors
+    use_d_vector_file: bool = False
+    d_vector_file: str = None
+    d_vector_dim: int = None
+
+    def __post_init__(self):
+        for key, val in self.model_args.items():
+            if hasattr(self, key):
+                self[key] = val
diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c7c9eddeae7dd68cd8e73feab9e6f7ec7f002b0
--- /dev/null
+++ b/TTS/tts/datasets/__init__.py
@@ -0,0 +1,169 @@
+import sys
+from collections import Counter
+from pathlib import Path
+from typing import Callable, Dict, List, Tuple, Union
+
+import numpy as np
+
+from TTS.tts.datasets.dataset import *
+from TTS.tts.datasets.formatters import *
+
+
+def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01):
+    """Split a dataset into train and eval. Consider speaker distribution in multi-speaker training.
+
+    Args:
+        items (List[Dict]):
+            A list of samples. Each sample is a dict produced by a dataset formatter and carries at least the
+            `text`, `audio_file` and `speaker_name` fields.
+
+        eval_split_max_size (int):
+            Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled).
+
+        eval_split_size (float):
+            If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
+            If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
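+
+    Returns:
+        Tuple[List, List]: the evaluation split and the training split, in that order.
+
+    Example:
+
+        An illustrative sketch with dummy single-speaker samples; with 100 items and `eval_split_size=0.1`
+        the split is 10 evaluation samples and 90 training samples.
+
+        >>> samples = [{"text": "hello", "audio_file": "a.wav", "speaker_name": "spk_0"} for _ in range(100)]
+        >>> eval_samples, train_samples = split_dataset(samples, eval_split_size=0.1)
+        >>> len(eval_samples), len(train_samples)
+        (10, 90)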
+    """
+    speakers = [item["speaker_name"] for item in items]
+    is_multi_speaker = len(set(speakers)) > 1
+    if eval_split_size > 1:
+        eval_split_size = int(eval_split_size)
+    else:
+        if eval_split_max_size:
+            eval_split_size = min(eval_split_max_size, int(len(items) * eval_split_size))
+        else:
+            eval_split_size = int(len(items) * eval_split_size)
+
+    assert (
+        eval_split_size > 0
+    ), " [!] You do not have enough samples for the evaluation set. You can work around this by setting the 'eval_split_size' parameter to a minimum of {}".format(
+        1 / len(items)
+    )
+    np.random.seed(0)
+    np.random.shuffle(items)
+    if is_multi_speaker:
+        items_eval = []
+        speakers = [item["speaker_name"] for item in items]
+        speaker_counter = Counter(speakers)
+        while len(items_eval) < eval_split_size:
+            item_idx = np.random.randint(0, len(items))
+            speaker_to_be_removed = items[item_idx]["speaker_name"]
+            if speaker_counter[speaker_to_be_removed] > 1:
+                items_eval.append(items[item_idx])
+                speaker_counter[speaker_to_be_removed] -= 1
+                del items[item_idx]
+        return items_eval, items
+    return items[:eval_split_size], items[eval_split_size:]
+
+
+def load_tts_samples(
+    datasets: Union[List[Dict], Dict],
+    eval_split=True,
+    formatter: Callable = None,
+    eval_split_max_size=None,
+    eval_split_size=0.01,
+) -> Tuple[List[List], List[List]]:
+    """Parse the dataset from the datasets config, load the samples as a List and load the attention alignments if provided.
+    If `formatter` is not None, apply the formatter to the samples else pick the formatter from the available ones based
+    on the dataset name.
+
+    Args:
+        datasets (List[Dict], Dict): A list of datasets or a single dataset dictionary. If multiple datasets are
+            in the list, they are all merged.
+
+        eval_split (bool, optional): If true, create an evaluation split. If no evaluation split is provided
+            explicitly (via `meta_file_val`), one is generated automatically from the training data. Defaults to True.
+
+        formatter (Callable, optional): The preprocessing function to be applied to create the list of samples. It
+            must take the root_path and the meta_file name and return a list of samples in the format of
+            `[[text, audio_path, speaker_id], ...]`. See the available formatters in `TTS.tts.datasets.formatters` as
+            example. Defaults to None.
+
+        eval_split_max_size (int):
+            Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled).
+
+        eval_split_size (float):
+            If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
+            If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).
+
+    Returns:
+        Tuple[List[List], List[List]]: training and evaluation splits of the dataset.
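+
+    Example:
+
+        A minimal sketch; the dataset path is a placeholder and `name` must match one of the formatters in
+        `TTS.tts.datasets.formatters`.
+
+        >>> from TTS.config import BaseDatasetConfig
+        >>> dataset_config = BaseDatasetConfig(
+        ...     name="ljspeech", meta_file_train="metadata.csv", path="/data/LJSpeech-1.1/"
+        ... )
+        >>> train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)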
+ """ + meta_data_train_all = [] + meta_data_eval_all = [] if eval_split else None + if not isinstance(datasets, list): + datasets = [datasets] + for dataset in datasets: + name = dataset["name"] + root_path = dataset["path"] + meta_file_train = dataset["meta_file_train"] + meta_file_val = dataset["meta_file_val"] + ignored_speakers = dataset["ignored_speakers"] + language = dataset["language"] + + # setup the right data processor + if formatter is None: + formatter = _get_formatter_by_name(name) + # load train set + meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers) + meta_data_train = [{**item, **{"language": language}} for item in meta_data_train] + + print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}") + # load evaluation split if set + if eval_split: + if meta_file_val: + meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers) + meta_data_eval = [{**item, **{"language": language}} for item in meta_data_eval] + else: + meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_split_max_size, eval_split_size) + meta_data_eval_all += meta_data_eval + meta_data_train_all += meta_data_train + # load attention masks for the duration predictor training + if dataset.meta_file_attn_mask: + meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"])) + for idx, ins in enumerate(meta_data_train_all): + attn_file = meta_data[ins["audio_file"]].strip() + meta_data_train_all[idx].update({"alignment_file": attn_file}) + if meta_data_eval_all: + for idx, ins in enumerate(meta_data_eval_all): + attn_file = meta_data[ins["audio_file"]].strip() + meta_data_eval_all[idx].update({"alignment_file": attn_file}) + # set none for the next iter + formatter = None + return meta_data_train_all, meta_data_eval_all + + +def load_attention_mask_meta_data(metafile_path): + """Load meta data file created by compute_attention_masks.py""" + with open(metafile_path, "r", encoding="utf-8") as f: + lines = f.readlines() + + meta_data = [] + for line in lines: + wav_file, attn_file = line.split("|") + meta_data.append([wav_file, attn_file]) + return meta_data + + +def _get_formatter_by_name(name): + """Returns the respective preprocessing function.""" + thismodule = sys.modules[__name__] + return getattr(thismodule, name.lower()) + + +def find_unique_chars(data_samples, verbose=True): + texts = "".join(item[0] for item in data_samples) + chars = set(texts) + lower_chars = filter(lambda c: c.islower(), chars) + chars_force_lower = [c.lower() for c in chars] + chars_force_lower = set(chars_force_lower) + + if verbose: + print(f" > Number of unique characters: {len(chars)}") + print(f" > Unique characters: {''.join(sorted(chars))}") + print(f" > Unique lower characters: {''.join(sorted(lower_chars))}") + print(f" > Unique all forced to lower characters: {''.join(sorted(chars_force_lower))}") + return chars_force_lower diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d8f16e4efe390d8ebaf63eb681ec3d5646e6be3e --- /dev/null +++ b/TTS/tts/datasets/dataset.py @@ -0,0 +1,772 @@ +import collections +import os +import random +from typing import Dict, List, Union + +import numpy as np +import torch +import tqdm +from torch.utils.data import Dataset + +from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor +from TTS.utils.audio import AudioProcessor + +# to prevent too many open files error as 
suggested here +# https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 +torch.multiprocessing.set_sharing_strategy("file_system") + + +def _parse_sample(item): + language_name = None + attn_file = None + if len(item) == 5: + text, wav_file, speaker_name, language_name, attn_file = item + elif len(item) == 4: + text, wav_file, speaker_name, language_name = item + elif len(item) == 3: + text, wav_file, speaker_name = item + else: + raise ValueError(" [!] Dataset cannot parse the sample.") + return text, wav_file, speaker_name, language_name, attn_file + + +def noise_augment_audio(wav): + return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape) + + +class TTSDataset(Dataset): + def __init__( + self, + outputs_per_step: int = 1, + compute_linear_spec: bool = False, + ap: AudioProcessor = None, + samples: List[Dict] = None, + tokenizer: "TTSTokenizer" = None, + compute_f0: bool = False, + f0_cache_path: str = None, + return_wav: bool = False, + batch_group_size: int = 0, + min_text_len: int = 0, + max_text_len: int = float("inf"), + min_audio_len: int = 0, + max_audio_len: int = float("inf"), + phoneme_cache_path: str = None, + precompute_num_workers: int = 0, + speaker_id_mapping: Dict = None, + d_vector_mapping: Dict = None, + language_id_mapping: Dict = None, + use_noise_augment: bool = False, + start_by_longest: bool = False, + verbose: bool = False, + ): + """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs. + + If you need something different, you can subclass and override. + + Args: + outputs_per_step (int): Number of time frames predicted per step. + + compute_linear_spec (bool): compute linear spectrogram if True. + + ap (TTS.tts.utils.AudioProcessor): Audio processor object. + + samples (list): List of dataset samples. + + tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else + use the given. Defaults to None. + + compute_f0 (bool): compute f0 if True. Defaults to False. + + f0_cache_path (str): Path to store f0 cache. Defaults to None. + + return_wav (bool): Return the waveform of the sample. Defaults to False. + + batch_group_size (int): Range of batch randomization after sorting + sequences by length. It shuffles each batch with bucketing to gather similar lenght sequences in a + batch. Set 0 to disable. Defaults to 0. + + min_text_len (int): Minimum length of input text to be used. All shorter samples will be ignored. + Defaults to 0. + + max_text_len (int): Maximum length of input text to be used. All longer samples will be ignored. + Defaults to float("inf"). + + min_audio_len (int): Minimum length of input audio to be used. All shorter samples will be ignored. + Defaults to 0. + + max_audio_len (int): Maximum length of input audio to be used. All longer samples will be ignored. + The maximum length in the dataset defines the VRAM used in the training. Hence, pay attention to + this value if you encounter an OOM error in training. Defaults to float("inf"). + + phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a + separate file. Defaults to None. + + precompute_num_workers (int): Number of workers to precompute features. Defaults to 0. + + speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the + embedding layer. Defaults to None. + + d_vector_mapping (dict): Mapping of wav files to computed d-vectors. Defaults to None. 
+ + use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False. + + start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False. + + verbose (bool): Print diagnostic information. Defaults to false. + """ + super().__init__() + self.batch_group_size = batch_group_size + self._samples = samples + self.outputs_per_step = outputs_per_step + self.compute_linear_spec = compute_linear_spec + self.return_wav = return_wav + self.compute_f0 = compute_f0 + self.f0_cache_path = f0_cache_path + self.min_audio_len = min_audio_len + self.max_audio_len = max_audio_len + self.min_text_len = min_text_len + self.max_text_len = max_text_len + self.ap = ap + self.phoneme_cache_path = phoneme_cache_path + self.speaker_id_mapping = speaker_id_mapping + self.d_vector_mapping = d_vector_mapping + self.language_id_mapping = language_id_mapping + self.use_noise_augment = use_noise_augment + self.start_by_longest = start_by_longest + + self.verbose = verbose + self.rescue_item_idx = 1 + self.pitch_computed = False + self.tokenizer = tokenizer + + if self.tokenizer.use_phonemes: + self.phoneme_dataset = PhonemeDataset( + self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers + ) + + if compute_f0: + self.f0_dataset = F0Dataset( + self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers + ) + + if self.verbose: + self.print_logs() + + @property + def lengths(self): + lens = [] + for item in self.samples: + _, wav_file, *_ = _parse_sample(item) + audio_len = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio + lens.append(audio_len) + return lens + + @property + def samples(self): + return self._samples + + @samples.setter + def samples(self, new_samples): + self._samples = new_samples + if hasattr(self, "f0_dataset"): + self.f0_dataset.samples = new_samples + if hasattr(self, "phoneme_dataset"): + self.phoneme_dataset.samples = new_samples + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + return self.load_data(idx) + + def print_logs(self, level: int = 0) -> None: + indent = "\t" * level + print("\n") + print(f"{indent}> DataLoader initialization") + print(f"{indent}| > Tokenizer:") + self.tokenizer.print_logs(level + 1) + print(f"{indent}| > Number of instances : {len(self.samples)}") + + def load_wav(self, filename): + waveform = self.ap.load_wav(filename) + assert waveform.size > 0 + return waveform + + def get_phonemes(self, idx, text): + out_dict = self.phoneme_dataset[idx] + assert text == out_dict["text"], f"{text} != {out_dict['text']}" + assert len(out_dict["token_ids"]) > 0 + return out_dict + + def get_f0(self, idx): + out_dict = self.f0_dataset[idx] + item = self.samples[idx] + assert item["audio_file"] == out_dict["audio_file"] + return out_dict + + @staticmethod + def get_attn_mask(attn_file): + return np.load(attn_file) + + def get_token_ids(self, idx, text): + if self.tokenizer.use_phonemes: + token_ids = self.get_phonemes(idx, text)["token_ids"] + else: + token_ids = self.tokenizer.text_to_ids(text) + return np.array(token_ids, dtype=np.int32) + + def load_data(self, idx): + item = self.samples[idx] + + raw_text = item["text"] + + wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32) + + # apply noise for augmentation + if self.use_noise_augment: + wav = noise_augment_audio(wav) + + # get token ids + token_ids = self.get_token_ids(idx, item["text"]) + + # get pre-computed attention 
maps + attn = None + if "alignment_file" in item: + attn = self.get_attn_mask(item["alignment_file"]) + + # after phonemization the text length may change + # this is a shareful 🤭 hack to prevent longer phonemes + # TODO: find a better fix + if len(token_ids) > self.max_text_len or len(wav) < self.min_audio_len: + self.rescue_item_idx += 1 + return self.load_data(self.rescue_item_idx) + + # get f0 values + f0 = None + if self.compute_f0: + f0 = self.get_f0(idx)["f0"] + + sample = { + "raw_text": raw_text, + "token_ids": token_ids, + "wav": wav, + "pitch": f0, + "attn": attn, + "item_idx": item["audio_file"], + "speaker_name": item["speaker_name"], + "language_name": item["language"], + "wav_file_name": os.path.basename(item["audio_file"]), + } + return sample + + @staticmethod + def _compute_lengths(samples): + new_samples = [] + for item in samples: + audio_length = os.path.getsize(item["audio_file"]) / 16 * 8 # assuming 16bit audio + text_lenght = len(item["text"]) + item["audio_length"] = audio_length + item["text_length"] = text_lenght + new_samples += [item] + return new_samples + + @staticmethod + def filter_by_length(lengths: List[int], min_len: int, max_len: int): + idxs = np.argsort(lengths) # ascending order + ignore_idx = [] + keep_idx = [] + for idx in idxs: + length = lengths[idx] + if length < min_len or length > max_len: + ignore_idx.append(idx) + else: + keep_idx.append(idx) + return ignore_idx, keep_idx + + @staticmethod + def sort_by_length(samples: List[List]): + audio_lengths = [s["audio_length"] for s in samples] + idxs = np.argsort(audio_lengths) # ascending order + return idxs + + @staticmethod + def create_buckets(samples, batch_group_size: int): + assert batch_group_size > 0 + for i in range(len(samples) // batch_group_size): + offset = i * batch_group_size + end_offset = offset + batch_group_size + temp_items = samples[offset:end_offset] + random.shuffle(temp_items) + samples[offset:end_offset] = temp_items + return samples + + @staticmethod + def _select_samples_by_idx(idxs, samples): + samples_new = [] + for idx in idxs: + samples_new.append(samples[idx]) + return samples_new + + def preprocess_samples(self): + r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length + range. + """ + samples = self._compute_lengths(self.samples) + + # sort items based on the sequence length in ascending order + text_lengths = [i["text_length"] for i in samples] + audio_lengths = [i["audio_length"] for i in samples] + text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len) + audio_ignore_idx, audio_keep_idx = self.filter_by_length(audio_lengths, self.min_audio_len, self.max_audio_len) + keep_idx = list(set(audio_keep_idx) & set(text_keep_idx)) + ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx)) + + samples = self._select_samples_by_idx(keep_idx, samples) + + sorted_idxs = self.sort_by_length(samples) + + if self.start_by_longest: + longest_idxs = sorted_idxs[-1] + sorted_idxs[-1] = sorted_idxs[0] + sorted_idxs[0] = longest_idxs + + samples = self._select_samples_by_idx(sorted_idxs, samples) + + if len(samples) == 0: + raise RuntimeError(" [!] No samples left") + + # shuffle batch groups + # create batches with similar length items + # the larger the `batch_group_size`, the higher the length variety in a batch. 
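+ # e.g. with batch_group_size=64 the samples stay roughly length-sorted overall, while indices [0:64], [64:128], ... are shuffled within their own group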
+ if self.batch_group_size > 0: + samples = self.create_buckets(samples, self.batch_group_size) + + # update items to the new sorted items + audio_lengths = [s["audio_length"] for s in samples] + text_lengths = [s["text_length"] for s in samples] + self.samples = samples + + if self.verbose: + print(" | > Preprocessing samples") + print(" | > Max text length: {}".format(np.max(text_lengths))) + print(" | > Min text length: {}".format(np.min(text_lengths))) + print(" | > Avg text length: {}".format(np.mean(text_lengths))) + print(" | ") + print(" | > Max audio length: {}".format(np.max(audio_lengths))) + print(" | > Min audio length: {}".format(np.min(audio_lengths))) + print(" | > Avg audio length: {}".format(np.mean(audio_lengths))) + print(f" | > Num. instances discarded samples: {len(ignore_idx)}") + print(" | > Batch group size: {}.".format(self.batch_group_size)) + + @staticmethod + def _sort_batch(batch, text_lengths): + """Sort the batch by the input text length for RNN efficiency. + + Args: + batch (Dict): Batch returned by `__getitem__`. + text_lengths (List[int]): Lengths of the input character sequences. + """ + text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True) + batch = [batch[idx] for idx in ids_sorted_decreasing] + return batch, text_lengths, ids_sorted_decreasing + + def collate_fn(self, batch): + r""" + Perform preprocessing and create a final data batch: + 1. Sort batch instances by text-length + 2. Convert Audio signal to features. + 3. PAD sequences wrt r. + 4. Load to Torch. + """ + + # Puts each data field into a tensor with outer dimension batch size + if isinstance(batch[0], collections.abc.Mapping): + + token_ids_lengths = np.array([len(d["token_ids"]) for d in batch]) + + # sort items with text input length for RNN efficiency + batch, token_ids_lengths, ids_sorted_decreasing = self._sort_batch(batch, token_ids_lengths) + + # convert list of dicts to dict of lists + batch = {k: [dic[k] for dic in batch] for k in batch[0]} + + # get language ids from language names + if self.language_id_mapping is not None: + language_ids = [self.language_id_mapping[ln] for ln in batch["language_name"]] + else: + language_ids = None + # get pre-computed d-vectors + if self.d_vector_mapping is not None: + wav_files_names = list(batch["wav_file_name"]) + d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names] + else: + d_vectors = None + + # get numerical speaker ids from speaker names + if self.speaker_id_mapping: + speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]] + else: + speaker_ids = None + # compute features + mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]] + + mel_lengths = [m.shape[1] for m in mel] + + # lengths adjusted by the reduction factor + mel_lengths_adjusted = [ + m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step)) + if m.shape[1] % self.outputs_per_step + else m.shape[1] + for m in mel + ] + + # compute 'stop token' targets + stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths] + + # PAD stop targets + stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step) + + # PAD sequences with longest instance in the batch + token_ids = prepare_data(batch["token_ids"]).astype(np.int32) + + # PAD features with longest instance + mel = prepare_tensor(mel, self.outputs_per_step) + + # B x D x T --> B x T x D + mel = mel.transpose(0, 2, 1) + + # convert things to pytorch + 
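# token_ids/mel/stop_targets are still numpy arrays at this point; mel was transposed above, hence the .contiguous() call below +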
token_ids_lengths = torch.LongTensor(token_ids_lengths) + token_ids = torch.LongTensor(token_ids) + mel = torch.FloatTensor(mel).contiguous() + mel_lengths = torch.LongTensor(mel_lengths) + stop_targets = torch.FloatTensor(stop_targets) + + # speaker vectors + if d_vectors is not None: + d_vectors = torch.FloatTensor(d_vectors) + + if speaker_ids is not None: + speaker_ids = torch.LongTensor(speaker_ids) + + if language_ids is not None: + language_ids = torch.LongTensor(language_ids) + + # compute linear spectrogram + linear = None + if self.compute_linear_spec: + linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]] + linear = prepare_tensor(linear, self.outputs_per_step) + linear = linear.transpose(0, 2, 1) + assert mel.shape[1] == linear.shape[1] + linear = torch.FloatTensor(linear).contiguous() + + # format waveforms + wav_padded = None + if self.return_wav: + wav_lengths = [w.shape[0] for w in batch["wav"]] + max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length + wav_lengths = torch.LongTensor(wav_lengths) + wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len) + for i, w in enumerate(batch["wav"]): + mel_length = mel_lengths_adjusted[i] + w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge") + w = w[: mel_length * self.ap.hop_length] + wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w) + wav_padded.transpose_(1, 2) + + # format F0 + if self.compute_f0: + pitch = prepare_data(batch["pitch"]) + assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}" + pitch = torch.FloatTensor(pitch)[:, None, :].contiguous() # B x 1 xT + else: + pitch = None + + # format attention masks + attns = None + if batch["attn"][0] is not None: + attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing] + for idx, attn in enumerate(attns): + pad2 = mel.shape[1] - attn.shape[1] + pad1 = token_ids.shape[1] - attn.shape[0] + assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}" + attn = np.pad(attn, [[0, pad1], [0, pad2]]) + attns[idx] = attn + attns = prepare_tensor(attns, self.outputs_per_step) + attns = torch.FloatTensor(attns).unsqueeze(1) + + return { + "token_id": token_ids, + "token_id_lengths": token_ids_lengths, + "speaker_names": batch["speaker_name"], + "linear": linear, + "mel": mel, + "mel_lengths": mel_lengths, + "stop_targets": stop_targets, + "item_idxs": batch["item_idx"], + "d_vectors": d_vectors, + "speaker_ids": speaker_ids, + "attns": attns, + "waveform": wav_padded, + "raw_text": batch["raw_text"], + "pitch": pitch, + "language_ids": language_ids, + } + + raise TypeError( + ( + "batch must contain tensors, numbers, dicts or lists;\ + found {}".format( + type(batch[0]) + ) + ) + ) + + +class PhonemeDataset(Dataset): + """Phoneme Dataset for converting input text to phonemes and then token IDs + + At initialization, it pre-computes the phonemes under `cache_path` and loads them in training to reduce data + loading latency. If `cache_path` is already present, it skips the pre-computation. + + Args: + samples (Union[List[List], List[Dict]]): + List of samples. Each sample is a list or a dict. + + tokenizer (TTSTokenizer): + Tokenizer to convert input text to phonemes. + + cache_path (str): + Path to cache phonemes. If `cache_path` is already present or None, it skips the pre-computation. + + precompute_num_workers (int): + Number of workers used for pre-computing the phonemes. Defaults to 0. 
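+ + Example (a minimal sketch; assumes `samples` and a configured `tokenizer` already exist):: + + >>> ds = PhonemeDataset(samples, tokenizer, cache_path="phoneme_cache", precompute_num_workers=4) + >>> ds[0]["token_ids"]  # phoneme token IDs, computed once and afterwards read from the cache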
+ """ + + def __init__( + self, + samples: Union[List[Dict], List[List]], + tokenizer: "TTSTokenizer", + cache_path: str, + precompute_num_workers=0, + ): + self.samples = samples + self.tokenizer = tokenizer + self.cache_path = cache_path + if cache_path is not None and not os.path.exists(cache_path): + os.makedirs(cache_path) + self.precompute(precompute_num_workers) + + def __getitem__(self, index): + item = self.samples[index] + ids = self.compute_or_load(item["audio_file"], item["text"]) + ph_hat = self.tokenizer.ids_to_text(ids) + return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)} + + def __len__(self): + return len(self.samples) + + def compute_or_load(self, wav_file, text): + """Compute phonemes for the given text. + + If the phonemes are already cached, load them from cache. + """ + file_name = os.path.splitext(os.path.basename(wav_file))[0] + file_ext = "_phoneme.npy" + cache_path = os.path.join(self.cache_path, file_name + file_ext) + try: + ids = np.load(cache_path) + except FileNotFoundError: + ids = self.tokenizer.text_to_ids(text) + np.save(cache_path, ids) + return ids + + def get_pad_id(self): + """Get pad token ID for sequence padding""" + return self.tokenizer.pad_id + + def precompute(self, num_workers=1): + """Precompute phonemes for all samples. + + We use pytorch dataloader because we are lazy. + """ + print("[*] Pre-computing phonemes...") + with tqdm.tqdm(total=len(self)) as pbar: + batch_size = num_workers if num_workers > 0 else 1 + dataloder = torch.utils.data.DataLoader( + batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn + ) + for _ in dataloder: + pbar.update(batch_size) + + def collate_fn(self, batch): + ids = [item["token_ids"] for item in batch] + ids_lens = [item["token_ids_len"] for item in batch] + texts = [item["text"] for item in batch] + texts_hat = [item["ph_hat"] for item in batch] + ids_lens_max = max(ids_lens) + ids_torch = torch.LongTensor(len(ids), ids_lens_max).fill_(self.get_pad_id()) + for i, ids_len in enumerate(ids_lens): + ids_torch[i, :ids_len] = torch.LongTensor(ids[i]) + return {"text": texts, "ph_hat": texts_hat, "token_ids": ids_torch} + + def print_logs(self, level: int = 0) -> None: + indent = "\t" * level + print("\n") + print(f"{indent}> PhonemeDataset ") + print(f"{indent}| > Tokenizer:") + self.tokenizer.print_logs(level + 1) + print(f"{indent}| > Number of instances : {len(self.samples)}") + + +class F0Dataset: + """F0 Dataset for computing F0 from wav files in CPU + + Pre-compute F0 values for all the samples at initialization if `cache_path` is not None or already present. It + also computes the mean and std of F0 values if `normalize_f0` is True. + + Args: + samples (Union[List[List], List[Dict]]): + List of samples. Each sample is a list or a dict. + + ap (AudioProcessor): + AudioProcessor to compute F0 from wav files. + + cache_path (str): + Path to cache F0 values. If `cache_path` is already present or None, it skips the pre-computation. + Defaults to None. + + precompute_num_workers (int): + Number of workers used for pre-computing the F0 values. Defaults to 0. + + normalize_f0 (bool): + Whether to normalize F0 values by mean and std. Defaults to True. 
+ """ + + def __init__( + self, + samples: Union[List[List], List[Dict]], + ap: "AudioProcessor", + verbose=False, + cache_path: str = None, + precompute_num_workers=0, + normalize_f0=True, + ): + self.samples = samples + self.ap = ap + self.verbose = verbose + self.cache_path = cache_path + self.normalize_f0 = normalize_f0 + self.pad_id = 0.0 + self.mean = None + self.std = None + if cache_path is not None and not os.path.exists(cache_path): + os.makedirs(cache_path) + self.precompute(precompute_num_workers) + if normalize_f0: + self.load_stats(cache_path) + + def __getitem__(self, idx): + item = self.samples[idx] + f0 = self.compute_or_load(item["audio_file"]) + if self.normalize_f0: + assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available" + f0 = self.normalize(f0) + return {"audio_file": item["audio_file"], "f0": f0} + + def __len__(self): + return len(self.samples) + + def precompute(self, num_workers=0): + print("[*] Pre-computing F0s...") + with tqdm.tqdm(total=len(self)) as pbar: + batch_size = num_workers if num_workers > 0 else 1 + # we do not normalize at preproessing + normalize_f0 = self.normalize_f0 + self.normalize_f0 = False + dataloder = torch.utils.data.DataLoader( + batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn + ) + computed_data = [] + for batch in dataloder: + f0 = batch["f0"] + computed_data.append(f for f in f0) + pbar.update(batch_size) + self.normalize_f0 = normalize_f0 + + if self.normalize_f0: + computed_data = [tensor for batch in computed_data for tensor in batch] # flatten + pitch_mean, pitch_std = self.compute_pitch_stats(computed_data) + pitch_stats = {"mean": pitch_mean, "std": pitch_std} + np.save(os.path.join(self.cache_path, "pitch_stats"), pitch_stats, allow_pickle=True) + + def get_pad_id(self): + return self.pad_id + + @staticmethod + def create_pitch_file_path(wav_file, cache_path): + file_name = os.path.splitext(os.path.basename(wav_file))[0] + pitch_file = os.path.join(cache_path, file_name + "_pitch.npy") + return pitch_file + + @staticmethod + def _compute_and_save_pitch(ap, wav_file, pitch_file=None): + wav = ap.load_wav(wav_file) + pitch = ap.compute_f0(wav) + if pitch_file: + np.save(pitch_file, pitch) + return pitch + + @staticmethod + def compute_pitch_stats(pitch_vecs): + nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs]) + mean, std = np.mean(nonzeros), np.std(nonzeros) + return mean, std + + def load_stats(self, cache_path): + stats_path = os.path.join(cache_path, "pitch_stats.npy") + stats = np.load(stats_path, allow_pickle=True).item() + self.mean = stats["mean"].astype(np.float32) + self.std = stats["std"].astype(np.float32) + + def normalize(self, pitch): + zero_idxs = np.where(pitch == 0.0)[0] + pitch = pitch - self.mean + pitch = pitch / self.std + pitch[zero_idxs] = 0.0 + return pitch + + def denormalize(self, pitch): + zero_idxs = np.where(pitch == 0.0)[0] + pitch *= self.std + pitch += self.mean + pitch[zero_idxs] = 0.0 + return pitch + + def compute_or_load(self, wav_file): + """ + compute pitch and return a numpy array of pitch values + """ + pitch_file = self.create_pitch_file_path(wav_file, self.cache_path) + if not os.path.exists(pitch_file): + pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file) + else: + pitch = np.load(pitch_file) + return pitch.astype(np.float32) + + def collate_fn(self, batch): + audio_file = [item["audio_file"] for item in batch] + f0s = [item["f0"] for item in batch] + 
f0_lens = [len(item["f0"]) for item in batch] + f0_lens_max = max(f0_lens) + f0s_torch = torch.LongTensor(len(f0s), f0_lens_max).fill_(self.get_pad_id()) + for i, f0_len in enumerate(f0_lens): + f0s_torch[i, :f0_len] = torch.LongTensor(f0s[i]) + return {"audio_file": audio_file, "f0": f0s_torch, "f0_lens": f0_lens} + + def print_logs(self, level: int = 0) -> None: + indent = "\t" * level + print("\n") + print(f"{indent}> F0Dataset ") + print(f"{indent}| > Number of instances : {len(self.samples)}") diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py new file mode 100644 index 0000000000000000000000000000000000000000..ef05ea7c7ada5b614240bd0733529e530c137d10 --- /dev/null +++ b/TTS/tts/datasets/formatters.py @@ -0,0 +1,558 @@ +import os +import re +import xml.etree.ElementTree as ET +from glob import glob +from pathlib import Path +from typing import List + +import pandas as pd +from tqdm import tqdm + +######################## +# DATASETS +######################## + + +def coqui(root_path, meta_file, ignored_speakers=None): + """Interal dataset formatter.""" + metadata = pd.read_csv(os.path.join(root_path, meta_file), sep="|") + assert all(x in metadata.columns for x in ["audio_file", "text"]) + speaker_name = None if "speaker_name" in metadata.columns else "coqui" + emotion_name = None if "emotion_name" in metadata.columns else "neutral" + items = [] + not_found_counter = 0 + for row in metadata.itertuples(): + if speaker_name is None and ignored_speakers is not None and row.speaker_name in ignored_speakers: + continue + audio_path = os.path.join(root_path, row.audio_file) + if not os.path.exists(audio_path): + not_found_counter += 1 + continue + items.append( + { + "text": row.text, + "audio_file": audio_path, + "speaker_name": speaker_name if speaker_name is not None else row.speaker_name, + "emotion_name": emotion_name if emotion_name is not None else row.emotion_name, + } + ) + if not_found_counter > 0: + print(f" | > [!] {not_found_counter} files not found") + return items + + +def tweb(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalize TWEB dataset. 
+ https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset + """ + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "tweb" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("\t") + wav_file = os.path.join(root_path, cols[0] + ".wav") + text = cols[1] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) + return items + + +def mozilla(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes Mozilla meta data files to TTS format""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "mozilla" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = cols[1].strip() + text = cols[0].strip() + wav_file = os.path.join(root_path, "wavs", wav_file) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) + return items + + +def mozilla_de(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes Mozilla meta data files to TTS format""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "mozilla" + with open(txt_file, "r", encoding="ISO 8859-1") as ttf: + for line in ttf: + cols = line.strip().split("|") + wav_file = cols[0].strip() + text = cols[1].strip() + folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL" + wav_file = os.path.join(root_path, folder_name, wav_file) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) + return items + + +def mailabs(root_path, meta_files=None, ignored_speakers=None): + """Normalizes M-AI-Labs meta data files to TTS format + + Args: + root_path (str): root folder of the MAILAB language folder. + meta_files (str): list of meta files to be used in the training. If None, finds all the csv files + recursively. Defaults to None + """ + speaker_regex = re.compile("by_book/(male|female)/(?P[^/]+)/") + if not meta_files: + csv_files = glob(root_path + "/**/metadata.csv", recursive=True) + else: + csv_files = meta_files + + # meta_files = [f.strip() for f in meta_files.split(",")] + items = [] + for csv_file in csv_files: + if os.path.isfile(csv_file): + txt_file = csv_file + else: + txt_file = os.path.join(root_path, csv_file) + + folder = os.path.dirname(txt_file) + # determine speaker based on folder structure... + speaker_name_match = speaker_regex.search(txt_file) + if speaker_name_match is None: + continue + speaker_name = speaker_name_match.group("speaker_name") + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_name in ignored_speakers: + continue + print(" | > {}".format(csv_file)) + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + if not meta_files: + wav_file = os.path.join(folder, "wavs", cols[0] + ".wav") + else: + wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav") + if os.path.isfile(wav_file): + text = cols[1].strip() + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) + else: + # M-AI-Labs have some missing samples, so just print the warning + print("> File %s does not exist!" 
% (wav_file)) + return items + + +def ljspeech(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the LJSpeech meta data file to TTS format + https://keithito.com/LJ-Speech-Dataset/""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "ljspeech" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") + text = cols[2] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) + return items + + +def ljspeech_test(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the LJSpeech meta data file for TTS testing + https://keithito.com/LJ-Speech-Dataset/""" + txt_file = os.path.join(root_path, meta_file) + items = [] + with open(txt_file, "r", encoding="utf-8") as ttf: + speaker_id = 0 + for idx, line in enumerate(ttf): + # 2 samples per speaker to avoid eval split issues + if idx % 2 == 0: + speaker_id += 1 + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") + text = cols[2] + items.append({"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}"}) + return items + + +def thorsten(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the thorsten meta data file to TTS format + https://github.com/thorstenMueller/deep-learning-german-tts/""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "thorsten" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") + text = cols[1] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) + return items + + +def sam_accenture(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the sam-accenture meta data file to TTS format + https://github.com/Sam-Accenture-Non-Binary-Voice/non-binary-voice-files""" + xml_file = os.path.join(root_path, "voice_over_recordings", meta_file) + xml_root = ET.parse(xml_file).getroot() + items = [] + speaker_name = "sam_accenture" + for item in xml_root.findall("./fileid"): + text = item.text + wav_file = os.path.join(root_path, "vo_voice_quality_transformation", item.get("id") + ".wav") + if not os.path.exists(wav_file): + print(f" [!] {wav_file} in metafile does not exist. 
Skipping...") + continue + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) + return items + + +def ruslan(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the RUSLAN meta data file to TTS format + https://ruslan-corpus.github.io/""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "ruslan" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav") + text = cols[1] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) + return items + + +def css10(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the CSS10 dataset file to TTS format""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "css10" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, cols[0]) + text = cols[1] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) + return items + + +def nancy(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the Nancy meta data file to TTS format""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "nancy" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + utt_id = line.split()[1] + text = line[line.find('"') + 1 : line.rfind('"') - 1] + wav_file = os.path.join(root_path, "wavn", utt_id + ".wav") + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) + return items + + +def common_voice(root_path, meta_file, ignored_speakers=None): + """Normalize the common voice meta data file to TTS format.""" + txt_file = os.path.join(root_path, meta_file) + items = [] + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + if line.startswith("client_id"): + continue + cols = line.split("\t") + text = cols[2] + speaker_name = cols[0] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_name in ignored_speakers: + continue + wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav")) + items.append({"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name}) + return items + + +def libri_tts(root_path, meta_files=None, ignored_speakers=None): + """https://ai.google/tools/datasets/libri-tts/""" + items = [] + if not meta_files: + meta_files = glob(f"{root_path}/**/*trans.tsv", recursive=True) + else: + if isinstance(meta_files, str): + meta_files = [os.path.join(root_path, meta_files)] + + for meta_file in meta_files: + _meta_file = os.path.basename(meta_file).split(".")[0] + with open(meta_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("\t") + file_name = cols[0] + speaker_name, chapter_id, *_ = cols[0].split("_") + _root_path = os.path.join(root_path, f"{speaker_name}/{chapter_id}") + wav_file = os.path.join(_root_path, file_name + ".wav") + text = cols[2] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_name in ignored_speakers: + continue + items.append({"text": text, "audio_file": wav_file, "speaker_name": f"LTTS_{speaker_name}"}) + for item in items: + assert os.path.exists(item["audio_file"]), f" [!] 
wav files don't exist - {item['audio_file']}" + return items + + +def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "turkish-female" + skipped_files = [] + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0].strip() + ".wav") + if not os.path.exists(wav_file): + skipped_files.append(wav_file) + continue + text = cols[1].strip() + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) + print(f" [!] {len(skipped_files)} files skipped. They don't exist...") + return items + + +# ToDo: add the dataset link when the dataset is released publicly +def brspeech(root_path, meta_file, ignored_speakers=None): + """BRSpeech 3.0 beta""" + txt_file = os.path.join(root_path, meta_file) + items = [] + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + if line.startswith("wav_filename"): + continue + cols = line.split("|") + wav_file = os.path.join(root_path, cols[0]) + text = cols[2] + speaker_id = cols[3] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: + continue + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id}) + return items + + +def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic1", ignored_speakers=None): + """VCTK dataset v0.92. + + URL: + https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip + + This dataset has 2 recordings per speaker that are annotated with ```mic1``` and ```mic2```. + It is believed that (😄 ) ```mic1``` files are the same as the previous version of the dataset. + + mic1: + Audio recorded using an omni-directional microphone (DPA 4035). + Contains very low frequency noises. + This is the same audio released in previous versions of VCTK: + https://doi.org/10.7488/ds/1994 + + mic2: + Audio recorded using a small diaphragm condenser microphone with + very wide bandwidth (Sennheiser MKH 800). + Two speakers, p280 and p315 had technical issues of the audio + recordings using MKH 800. + """ + file_ext = "flac" + items = [] + meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) + for meta_file in meta_files: + _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep) + file_id = txt_file.split(".")[0] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: + continue + with open(meta_file, "r", encoding="utf-8") as file_text: + text = file_text.readlines()[0] + # p280 has no mic2 recordings + if speaker_id == "p280": + wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_mic1.{file_ext}") + else: + wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_{mic}.{file_ext}") + if os.path.exists(wav_file): + items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id}) + else: + print(f" [!] 
wav files don't exist - {wav_file}") + return items + + +def vctk_old(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None): + """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz""" + items = [] + meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) + for meta_file in meta_files: + _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep) + file_id = txt_file.split(".")[0] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: + continue + with open(meta_file, "r", encoding="utf-8") as file_text: + text = file_text.readlines()[0] + wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav") + items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id}) + return items + + +def synpaflex(root_path, metafiles=None, **kwargs): # pylint: disable=unused-argument + items = [] + speaker_name = "synpaflex" + root_path = os.path.join(root_path, "") + wav_files = glob(f"{root_path}**/*.wav", recursive=True) + for wav_file in wav_files: + if os.sep + "wav" + os.sep in wav_file: + txt_file = wav_file.replace("wav", "txt") + else: + txt_file = os.path.join( + os.path.dirname(wav_file), "txt", os.path.basename(wav_file).replace(".wav", ".txt") + ) + if os.path.exists(txt_file) and os.path.exists(wav_file): + with open(txt_file, "r", encoding="utf-8") as file_text: + text = file_text.readlines()[0] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) + return items + + +def open_bible(root_path, meta_files="train", ignore_digits_sentences=True, ignored_speakers=None): + """ToDo: Refer the paper when available""" + items = [] + split_dir = meta_files + meta_files = glob(f"{os.path.join(root_path, split_dir)}/**/*.txt", recursive=True) + for meta_file in meta_files: + _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep) + file_id = txt_file.split(".")[0] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: + continue + with open(meta_file, "r", encoding="utf-8") as file_text: + text = file_text.readline().replace("\n", "") + # ignore sentences that contains digits + if ignore_digits_sentences and any(map(str.isdigit, text)): + continue + wav_file = os.path.join(root_path, split_dir, speaker_id, file_id + ".flac") + items.append({"text": text, "audio_file": wav_file, "speaker_name": "OB_" + speaker_id}) + return items + + +def mls(root_path, meta_files=None, ignored_speakers=None): + """http://www.openslr.org/94/""" + items = [] + with open(os.path.join(root_path, meta_files), "r", encoding="utf-8") as meta: + for line in meta: + file, text = line.split("\t") + text = text[:-1] + speaker, book, *_ = file.split("_") + wav_file = os.path.join(root_path, os.path.dirname(meta_files), "audio", speaker, book, file + ".wav") + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker in ignored_speakers: + continue + items.append({"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker}) + return items + + +# ======================================== VOX CELEB =========================================== +def voxceleb2(root_path, meta_file=None, **kwargs): # pylint: disable=unused-argument + """ + :param meta_file Used only for consistency with load_tts_samples api + """ + return _voxcel_x(root_path, meta_file, voxcel_idx="2") + + +def voxceleb1(root_path, meta_file=None, **kwargs): # pylint: disable=unused-argument + """ + 
:param meta_file Used only for consistency with load_tts_samples api + """ + return _voxcel_x(root_path, meta_file, voxcel_idx="1") + + +def _voxcel_x(root_path, meta_file, voxcel_idx): + assert voxcel_idx in ["1", "2"] + expected_count = 148_000 if voxcel_idx == "1" else 1_000_000 + voxceleb_path = Path(root_path) + cache_to = voxceleb_path / f"metafile_voxceleb{voxcel_idx}.csv" + cache_to.parent.mkdir(exist_ok=True) + + # if not exists meta file, crawl recursively for 'wav' files + if meta_file is not None: + with open(str(meta_file), "r", encoding="utf-8") as f: + return [x.strip().split("|") for x in f.readlines()] + + elif not cache_to.exists(): + cnt = 0 + meta_data = [] + wav_files = voxceleb_path.rglob("**/*.wav") + for path in tqdm( + wav_files, + desc=f"Building VoxCeleb {voxcel_idx} Meta file ... this needs to be done only once.", + total=expected_count, + ): + speaker_id = str(Path(path).parent.parent.stem) + assert speaker_id.startswith("id") + text = None # VoxCel does not provide transciptions, and they are not needed for training the SE + meta_data.append(f"{text}|{path}|voxcel{voxcel_idx}_{speaker_id}\n") + cnt += 1 + with open(str(cache_to), "w", encoding="utf-8") as f: + f.write("".join(meta_data)) + if cnt < expected_count: + raise ValueError(f"Found too few instances for Voxceleb. Should be around {expected_count}, is: {cnt}") + + with open(str(cache_to), "r", encoding="utf-8") as f: + return [x.strip().split("|") for x in f.readlines()] + + +def emotion(root_path, meta_file, ignored_speakers=None): + """Generic emotion dataset""" + txt_file = os.path.join(root_path, meta_file) + items = [] + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + if line.startswith("file_path"): + continue + cols = line.split(",") + wav_file = os.path.join(root_path, cols[0]) + speaker_id = cols[1] + emotion_id = cols[2].replace("\n", "") + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: + continue + items.append({"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id}) + return items + + +def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylint: disable=unused-argument + """Normalizes the Baker meta data file to TTS format + + Args: + root_path (str): path to the baker dataset + meta_file (str): name of the meta dataset containing names of wav to select and the transcript of the sentence + Returns: + List[List[str]]: List of (text, wav_path, speaker_name) associated with each sentences + """ + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "baker" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + wav_name, text = line.rstrip("\n").split("|") + wav_path = os.path.join(root_path, "clips_22", wav_name) + items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name}) + return items + + +def kokoro(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Japanese single-speaker dataset from https://github.com/kaiidams/Kokoro-Speech-Dataset""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "kokoro" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") + text = cols[2].replace(" ", "") + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) + return items diff --git a/TTS/tts/layers/__init__.py b/TTS/tts/layers/__init__.py new file 
mode 100644 index 0000000000000000000000000000000000000000..f93efdb7fc41109ec3497d8e5e37ba05b0a4315e --- /dev/null +++ b/TTS/tts/layers/__init__.py @@ -0,0 +1 @@ +from TTS.tts.layers.losses import * diff --git a/TTS/tts/layers/align_tts/__init__.py b/TTS/tts/layers/align_tts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/layers/align_tts/duration_predictor.py b/TTS/tts/layers/align_tts/duration_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..b2b83894cc3f87575a89ea8fd7bf4a584ca22c28 --- /dev/null +++ b/TTS/tts/layers/align_tts/duration_predictor.py @@ -0,0 +1,21 @@ +from torch import nn + +from TTS.tts.layers.generic.pos_encoding import PositionalEncoding +from TTS.tts.layers.generic.transformer import FFTransformerBlock + + +class DurationPredictor(nn.Module): + def __init__(self, num_chars, hidden_channels, hidden_channels_ffn, num_heads): + super().__init__() + self.embed = nn.Embedding(num_chars, hidden_channels) + self.pos_enc = PositionalEncoding(hidden_channels, dropout_p=0.1) + self.FFT = FFTransformerBlock(hidden_channels, num_heads, hidden_channels_ffn, 2, 0.1) + self.out_layer = nn.Conv1d(hidden_channels, 1, 1) + + def forward(self, text, text_lengths): + # B, L -> B, L + emb = self.embed(text) + emb = self.pos_enc(emb.transpose(1, 2)) + x = self.FFT(emb, text_lengths) + x = self.out_layer(x).squeeze(-1) + return x diff --git a/TTS/tts/layers/align_tts/mdn.py b/TTS/tts/layers/align_tts/mdn.py new file mode 100644 index 0000000000000000000000000000000000000000..cdb332524bf7a5fec6a23da9e7977de6325a0324 --- /dev/null +++ b/TTS/tts/layers/align_tts/mdn.py @@ -0,0 +1,30 @@ +from torch import nn + + +class MDNBlock(nn.Module): + """Mixture of Density Network implementation + https://arxiv.org/pdf/2003.01950.pdf + """ + + def __init__(self, in_channels, out_channels): + super().__init__() + self.out_channels = out_channels + self.conv1 = nn.Conv1d(in_channels, in_channels, 1) + self.norm = nn.LayerNorm(in_channels) + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.1) + self.conv2 = nn.Conv1d(in_channels, out_channels, 1) + + def forward(self, x): + o = self.conv1(x) + o = o.transpose(1, 2) + o = self.norm(o) + o = o.transpose(1, 2) + o = self.relu(o) + o = self.dropout(o) + mu_sigma = self.conv2(o) + # TODO: check this sigmoid + # mu = torch.sigmoid(mu_sigma[:, :self.out_channels//2, :]) + mu = mu_sigma[:, : self.out_channels // 2, :] + log_sigma = mu_sigma[:, self.out_channels // 2 :, :] + return mu, log_sigma diff --git a/TTS/tts/layers/feed_forward/__init__.py b/TTS/tts/layers/feed_forward/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/layers/feed_forward/decoder.py b/TTS/tts/layers/feed_forward/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..34c586aab24e014ce99d5806a975585a242b81bd --- /dev/null +++ b/TTS/tts/layers/feed_forward/decoder.py @@ -0,0 +1,230 @@ +import torch +from torch import nn + +from TTS.tts.layers.generic.res_conv_bn import Conv1dBN, Conv1dBNBlock, ResidualConv1dBNBlock +from TTS.tts.layers.generic.transformer import FFTransformerBlock +from TTS.tts.layers.generic.wavenet import WNBlocks +from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer + + +class WaveNetDecoder(nn.Module): + """WaveNet based decoder with a prenet and a postnet. 
+ + prenet: conv1d_1x1 + postnet: 3 x [conv1d_1x1 -> relu] -> conv1d_1x1 + + TODO: Integrate speaker conditioning vector. + + Note: + default wavenet parameters; + params = { + "num_blocks": 12, + "hidden_channels":192, + "kernel_size": 5, + "dilation_rate": 1, + "num_layers": 4, + "dropout_p": 0.05 + } + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + hidden_channels (int): number of hidden channels for prenet and postnet. + params (dict): dictionary for residual convolutional blocks. + """ + + def __init__(self, in_channels, out_channels, hidden_channels, c_in_channels, params): + super().__init__() + # prenet + self.prenet = torch.nn.Conv1d(in_channels, params["hidden_channels"], 1) + # wavenet layers + self.wn = WNBlocks(params["hidden_channels"], c_in_channels=c_in_channels, **params) + # postnet + self.postnet = [ + torch.nn.Conv1d(params["hidden_channels"], hidden_channels, 1), + torch.nn.ReLU(), + torch.nn.Conv1d(hidden_channels, hidden_channels, 1), + torch.nn.ReLU(), + torch.nn.Conv1d(hidden_channels, hidden_channels, 1), + torch.nn.ReLU(), + torch.nn.Conv1d(hidden_channels, out_channels, 1), + ] + self.postnet = nn.Sequential(*self.postnet) + + def forward(self, x, x_mask=None, g=None): + x = self.prenet(x) * x_mask + x = self.wn(x, x_mask, g) + o = self.postnet(x) * x_mask + return o + + +class RelativePositionTransformerDecoder(nn.Module): + """Decoder with Relative Positional Transformer. + + Note: + Default params + params={ + 'hidden_channels_ffn': 128, + 'num_heads': 2, + "kernel_size": 3, + "dropout_p": 0.1, + "num_layers": 8, + "rel_attn_window_size": 4, + "input_length": None + } + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + hidden_channels (int): number of hidden channels including Transformer layers. + params (dict): dictionary for residual convolutional blocks. + """ + + def __init__(self, in_channels, out_channels, hidden_channels, params): + + super().__init__() + self.prenet = Conv1dBN(in_channels, hidden_channels, 1, 1) + self.rel_pos_transformer = RelativePositionTransformer(in_channels, out_channels, hidden_channels, **params) + + def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument + o = self.prenet(x) * x_mask + o = self.rel_pos_transformer(o, x_mask) + return o + + +class FFTransformerDecoder(nn.Module): + """Decoder with FeedForwardTransformer. + + Default params + params={ + 'hidden_channels_ffn': 1024, + 'num_heads': 2, + "dropout_p": 0.1, + "num_layers": 6, + } + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + hidden_channels (int): number of hidden channels including Transformer layers. + params (dict): dictionary for residual convolutional blocks. + """ + + def __init__(self, in_channels, out_channels, params): + + super().__init__() + self.transformer_block = FFTransformerBlock(in_channels, **params) + self.postnet = nn.Conv1d(in_channels, out_channels, 1) + + def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument + # TODO: handle multi-speaker + x_mask = 1 if x_mask is None else x_mask + o = self.transformer_block(x) * x_mask + o = self.postnet(o) * x_mask + return o + + +class ResidualConv1dBNDecoder(nn.Module): + """Residual Convolutional Decoder as in the original Speedy Speech paper + + TODO: Integrate speaker conditioning vector. 
+ + Note: + Default params + params = { + "kernel_size": 4, + "dilations": 4 * [1, 2, 4, 8] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 17 + } + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + hidden_channels (int): number of hidden channels including ResidualConv1dBNBlock layers. + params (dict): dictionary for residual convolutional blocks. + """ + + def __init__(self, in_channels, out_channels, hidden_channels, params): + super().__init__() + self.res_conv_block = ResidualConv1dBNBlock(in_channels, hidden_channels, hidden_channels, **params) + self.post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1) + self.postnet = nn.Sequential( + Conv1dBNBlock( + hidden_channels, hidden_channels, hidden_channels, params["kernel_size"], 1, num_conv_blocks=2 + ), + nn.Conv1d(hidden_channels, out_channels, 1), + ) + + def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument + o = self.res_conv_block(x, x_mask) + o = self.post_conv(o) + x + return self.postnet(o) * x_mask + + +class Decoder(nn.Module): + """Decodes the expanded phoneme encoding into spectrograms + Args: + out_channels (int): number of output channels. + in_hidden_channels (int): input and hidden channels. Model keeps the input channels for the intermediate layers. + decoder_type (str): decoder layer types. 'transformers' or 'residual_conv_bn'. Default 'residual_conv_bn'. + decoder_params (dict): model parameters for specified decoder type. + c_in_channels (int): number of channels for conditional input. + + Shapes: + - input: (B, C, T) + """ + + # pylint: disable=dangerous-default-value + def __init__( + self, + out_channels, + in_hidden_channels, + decoder_type="residual_conv_bn", + decoder_params={ + "kernel_size": 4, + "dilations": 4 * [1, 2, 4, 8] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 17, + }, + c_in_channels=0, + ): + super().__init__() + + if decoder_type.lower() == "relative_position_transformer": + self.decoder = RelativePositionTransformerDecoder( + in_channels=in_hidden_channels, + out_channels=out_channels, + hidden_channels=in_hidden_channels, + params=decoder_params, + ) + elif decoder_type.lower() == "residual_conv_bn": + self.decoder = ResidualConv1dBNDecoder( + in_channels=in_hidden_channels, + out_channels=out_channels, + hidden_channels=in_hidden_channels, + params=decoder_params, + ) + elif decoder_type.lower() == "wavenet": + self.decoder = WaveNetDecoder( + in_channels=in_hidden_channels, + out_channels=out_channels, + hidden_channels=in_hidden_channels, + c_in_channels=c_in_channels, + params=decoder_params, + ) + elif decoder_type.lower() == "fftransformer": + self.decoder = FFTransformerDecoder(in_hidden_channels, out_channels, decoder_params) + else: + raise ValueError(f"[!] Unknown decoder type - {decoder_type}") + + def forward(self, x, x_mask, g=None): # pylint: disable=unused-argument + """ + Args: + x: [B, C, T] + x_mask: [B, 1, T] + g: [B, C_g, 1] + """ + # TODO: implement multi-speaker + o = self.decoder(x, x_mask, g) + return o diff --git a/TTS/tts/layers/feed_forward/duration_predictor.py b/TTS/tts/layers/feed_forward/duration_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..5392aeca3cd4eed08daeb2a3c34c735baec18364 --- /dev/null +++ b/TTS/tts/layers/feed_forward/duration_predictor.py @@ -0,0 +1,42 @@ +from torch import nn + +from TTS.tts.layers.generic.res_conv_bn import Conv1dBN + + +class DurationPredictor(nn.Module): + """Speedy Speech duration predictor model. 
+ Predicts phoneme durations from encoder outputs. + + Note: + Outputs interpreted as log(durations) + To get actual durations, do exp transformation + + conv_BN_4x1 -> conv_BN_3x1 -> conv_BN_1x1 -> conv_1x1 + + Args: + hidden_channels (int): number of channels in the inner layers. + """ + + def __init__(self, hidden_channels): + + super().__init__() + + self.layers = nn.ModuleList( + [ + Conv1dBN(hidden_channels, hidden_channels, 4, 1), + Conv1dBN(hidden_channels, hidden_channels, 3, 1), + Conv1dBN(hidden_channels, hidden_channels, 1, 1), + nn.Conv1d(hidden_channels, 1, 1), + ] + ) + + def forward(self, x, x_mask): + """ + Shapes: + x: [B, C, T] + x_mask: [B, 1, T] + """ + o = x + for layer in self.layers: + o = layer(o) * x_mask + return o diff --git a/TTS/tts/layers/feed_forward/encoder.py b/TTS/tts/layers/feed_forward/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..caf939ffc73fedac299228e090b2df3bb4cc553c --- /dev/null +++ b/TTS/tts/layers/feed_forward/encoder.py @@ -0,0 +1,162 @@ +from torch import nn + +from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock +from TTS.tts.layers.generic.transformer import FFTransformerBlock +from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer + + +class RelativePositionTransformerEncoder(nn.Module): + """Speedy speech encoder built on Transformer with Relative Position encoding. + + TODO: Integrate speaker conditioning vector. + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + hidden_channels (int): number of hidden channels + params (dict): dictionary for residual convolutional blocks. + """ + + def __init__(self, in_channels, out_channels, hidden_channels, params): + super().__init__() + self.prenet = ResidualConv1dBNBlock( + in_channels, + hidden_channels, + hidden_channels, + kernel_size=5, + num_res_blocks=3, + num_conv_blocks=1, + dilations=[1, 1, 1], + ) + self.rel_pos_transformer = RelativePositionTransformer(hidden_channels, out_channels, hidden_channels, **params) + + def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument + if x_mask is None: + x_mask = 1 + o = self.prenet(x) * x_mask + o = self.rel_pos_transformer(o, x_mask) + return o + + +class ResidualConv1dBNEncoder(nn.Module): + """Residual Convolutional Encoder as in the original Speedy Speech paper + + TODO: Integrate speaker conditioning vector. + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + hidden_channels (int): number of hidden channels + params (dict): dictionary for residual convolutional blocks. + """ + + def __init__(self, in_channels, out_channels, hidden_channels, params): + super().__init__() + self.prenet = nn.Sequential(nn.Conv1d(in_channels, hidden_channels, 1), nn.ReLU()) + self.res_conv_block = ResidualConv1dBNBlock(hidden_channels, hidden_channels, hidden_channels, **params) + + self.postnet = nn.Sequential( + *[ + nn.Conv1d(hidden_channels, hidden_channels, 1), + nn.ReLU(), + nn.BatchNorm1d(hidden_channels), + nn.Conv1d(hidden_channels, out_channels, 1), + ] + ) + + def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument + if x_mask is None: + x_mask = 1 + o = self.prenet(x) * x_mask + o = self.res_conv_block(o, x_mask) + o = self.postnet(o + x) * x_mask + return o * x_mask + + +class Encoder(nn.Module): + # pylint: disable=dangerous-default-value + """Factory class for Speedy Speech encoder enables different encoder types internally. 
+ + Args: + num_chars (int): number of characters. + out_channels (int): number of output channels. + in_hidden_channels (int): input and hidden channels. Model keeps the input channels for the intermediate layers. + encoder_type (str): encoder layer types. 'transformers' or 'residual_conv_bn'. Default 'residual_conv_bn'. + encoder_params (dict): model parameters for specified encoder type. + c_in_channels (int): number of channels for conditional input. + + Note: + Default encoder_params to be set in config.json... + + ```python + # for 'relative_position_transformer' + encoder_params={ + 'hidden_channels_ffn': 128, + 'num_heads': 2, + "kernel_size": 3, + "dropout_p": 0.1, + "num_layers": 6, + "rel_attn_window_size": 4, + "input_length": None + }, + + # for 'residual_conv_bn' + encoder_params = { + "kernel_size": 4, + "dilations": 4 * [1, 2, 4] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 13 + } + + # for 'fftransformer' + encoder_params = { + "hidden_channels_ffn": 1024 , + "num_heads": 2, + "num_layers": 6, + "dropout_p": 0.1 + } + ``` + """ + + def __init__( + self, + in_hidden_channels, + out_channels, + encoder_type="residual_conv_bn", + encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13}, + c_in_channels=0, + ): + super().__init__() + self.out_channels = out_channels + self.in_channels = in_hidden_channels + self.hidden_channels = in_hidden_channels + self.encoder_type = encoder_type + self.c_in_channels = c_in_channels + + # init encoder + if encoder_type.lower() == "relative_position_transformer": + # text encoder + # pylint: disable=unexpected-keyword-arg + self.encoder = RelativePositionTransformerEncoder( + in_hidden_channels, out_channels, in_hidden_channels, encoder_params + ) + elif encoder_type.lower() == "residual_conv_bn": + self.encoder = ResidualConv1dBNEncoder(in_hidden_channels, out_channels, in_hidden_channels, encoder_params) + elif encoder_type.lower() == "fftransformer": + assert ( + in_hidden_channels == out_channels + ), "[!] must be `in_channels` == `out_channels` when encoder type is 'fftransformer'" + # pylint: disable=unexpected-keyword-arg + self.encoder = FFTransformerBlock(in_hidden_channels, **encoder_params) + else: + raise NotImplementedError(" [!] unknown encoder type.") + + def forward(self, x, x_mask, g=None): # pylint: disable=unused-argument + """ + Shapes: + x: [B, C, T] + x_mask: [B, 1, T] + g: [B, C, 1] + """ + o = self.encoder(x, x_mask) + return o * x_mask diff --git a/TTS/tts/layers/generic/__init__.py b/TTS/tts/layers/generic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/layers/generic/aligner.py b/TTS/tts/layers/generic/aligner.py new file mode 100644 index 0000000000000000000000000000000000000000..eef4c4b66d80f9bab83ddf81427e5b48d2a43b4b --- /dev/null +++ b/TTS/tts/layers/generic/aligner.py @@ -0,0 +1,81 @@ +from typing import Tuple + +import torch +from torch import nn + + +class AlignmentNetwork(torch.nn.Module): + """Aligner Network for learning alignment between the input text and the model output with Gaussian Attention. + + :: + + query -> conv1d -> relu -> conv1d -> relu -> conv1d -> L2_dist -> softmax -> alignment + key -> conv1d -> relu -> conv1d -----------------------^ + + Args: + in_query_channels (int): Number of channels in the query network. Defaults to 80. + in_key_channels (int): Number of channels in the key network. Defaults to 512. 
+ attn_channels (int): Number of inner channels in the attention layers. Defaults to 80. + temperature (float): Temperature for the softmax. Defaults to 0.0005. + """ + + def __init__( + self, + in_query_channels=80, + in_key_channels=512, + attn_channels=80, + temperature=0.0005, + ): + super().__init__() + self.temperature = temperature + self.softmax = torch.nn.Softmax(dim=3) + self.log_softmax = torch.nn.LogSoftmax(dim=3) + + self.key_layer = nn.Sequential( + nn.Conv1d( + in_key_channels, + in_key_channels * 2, + kernel_size=3, + padding=1, + bias=True, + ), + torch.nn.ReLU(), + nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True), + ) + + self.query_layer = nn.Sequential( + nn.Conv1d( + in_query_channels, + in_query_channels * 2, + kernel_size=3, + padding=1, + bias=True, + ), + torch.nn.ReLU(), + nn.Conv1d(in_query_channels * 2, in_query_channels, kernel_size=1, padding=0, bias=True), + torch.nn.ReLU(), + nn.Conv1d(in_query_channels, attn_channels, kernel_size=1, padding=0, bias=True), + ) + + def forward( + self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None + ) -> Tuple[torch.tensor, torch.tensor]: + """Forward pass of the aligner encoder. + Shapes: + - queries: :math:`[B, C, T_de]` + - keys: :math:`[B, C_emb, T_en]` + - mask: :math:`[B, T_de]` + Output: + attn (torch.tensor): :math:`[B, 1, T_en, T_de]` soft attention mask. + attn_logp (torch.tensor): :math:`[ßB, 1, T_en , T_de]` log probabilities. + """ + key_out = self.key_layer(keys) + query_out = self.query_layer(queries) + attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2 + attn_logp = -self.temperature * attn_factor.sum(1, keepdim=True) + if attn_prior is not None: + attn_logp = self.log_softmax(attn_logp) + torch.log(attn_prior[:, None] + 1e-8) + if mask is not None: + attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf")) + attn = self.softmax(attn_logp) + return attn, attn_logp diff --git a/TTS/tts/layers/generic/gated_conv.py b/TTS/tts/layers/generic/gated_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..9a29c4499f970db538a4b99c3c05cba22576195f --- /dev/null +++ b/TTS/tts/layers/generic/gated_conv.py @@ -0,0 +1,37 @@ +from torch import nn + +from .normalization import LayerNorm + + +class GatedConvBlock(nn.Module): + """Gated convolutional block as in https://arxiv.org/pdf/1612.08083.pdf + Args: + in_out_channels (int): number of input/output channels. + kernel_size (int): convolution kernel size. + dropout_p (float): dropout rate. 
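A minimal shape check for the alignment network above; batch size, channel sizes and lengths are illustrative, and the mask and prior are omitted for brevity:

```python
import torch

from TTS.tts.layers.generic.aligner import AlignmentNetwork

aligner = AlignmentNetwork(in_query_channels=80, in_key_channels=512, attn_channels=80)

queries = torch.randn(2, 80, 100)  # e.g. mel frames, [B, C_query, T_query]
keys = torch.randn(2, 512, 33)     # e.g. text encoder outputs, [B, C_key, T_key]

attn, attn_logp = aligner(queries, keys, mask=None)
print(attn.shape)             # torch.Size([2, 1, 100, 33]), i.e. [B, 1, T_query, T_key]
print(attn.sum(-1)[0, 0, 0])  # ~1.0: the softmax normalizes over the key (text) axis
```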
+ """ + + def __init__(self, in_out_channels, kernel_size, dropout_p, num_layers): + super().__init__() + # class arguments + self.dropout_p = dropout_p + self.num_layers = num_layers + # define layers + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.layers = nn.ModuleList() + for _ in range(num_layers): + self.conv_layers += [nn.Conv1d(in_out_channels, 2 * in_out_channels, kernel_size, padding=kernel_size // 2)] + self.norm_layers += [LayerNorm(2 * in_out_channels)] + + def forward(self, x, x_mask): + o = x + res = x + for idx in range(self.num_layers): + o = nn.functional.dropout(o, p=self.dropout_p, training=self.training) + o = self.conv_layers[idx](o * x_mask) + o = self.norm_layers[idx](o) + o = nn.functional.glu(o, dim=1) + o = res + o + res = o + return o diff --git a/TTS/tts/layers/generic/normalization.py b/TTS/tts/layers/generic/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..c0270e405e4246e47b7bc0787e4cd4b069533f92 --- /dev/null +++ b/TTS/tts/layers/generic/normalization.py @@ -0,0 +1,123 @@ +import torch +from torch import nn + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-4): + """Layer norm for the 2nd dimension of the input. + Args: + channels (int): number of channels (2nd dimension) of the input. + eps (float): to prevent 0 division + + Shapes: + - input: (B, C, T) + - output: (B, C, T) + """ + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(1, channels, 1) * 0.1) + self.beta = nn.Parameter(torch.zeros(1, channels, 1)) + + def forward(self, x): + mean = torch.mean(x, 1, keepdim=True) + variance = torch.mean((x - mean) ** 2, 1, keepdim=True) + x = (x - mean) * torch.rsqrt(variance + self.eps) + x = x * self.gamma + self.beta + return x + + +class LayerNorm2(nn.Module): + """Layer norm for the 2nd dimension of the input using torch primitive. + Args: + channels (int): number of channels (2nd dimension) of the input. + eps (float): to prevent 0 division + + Shapes: + - input: (B, C, T) + - output: (B, C, T) + """ + + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = torch.nn.functional.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class TemporalBatchNorm1d(nn.BatchNorm1d): + """Normalize each channel separately over time and batch.""" + + def __init__(self, channels, affine=True, track_running_stats=True, momentum=0.1): + super().__init__(channels, affine=affine, track_running_stats=track_running_stats, momentum=momentum) + + def forward(self, x): + return super().forward(x.transpose(2, 1)).transpose(2, 1) + + +class ActNorm(nn.Module): + """Activation Normalization bijector as an alternative to Batch Norm. It computes + mean and std from a sample data in advance and it uses these values + for normalization at training. + + Args: + channels (int): input channels. + ddi (False): data depended initialization flag. 
+ + Shapes: + - inputs: (B, C, T) + - outputs: (B, C, T) + """ + + def __init__(self, channels, ddi=False, **kwargs): # pylint: disable=unused-argument + super().__init__() + self.channels = channels + self.initialized = not ddi + + self.logs = nn.Parameter(torch.zeros(1, channels, 1)) + self.bias = nn.Parameter(torch.zeros(1, channels, 1)) + + def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument + if x_mask is None: + x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype) + x_len = torch.sum(x_mask, [1, 2]) + if not self.initialized: + self.initialize(x, x_mask) + self.initialized = True + + if reverse: + z = (x - self.bias) * torch.exp(-self.logs) * x_mask + logdet = None + else: + z = (self.bias + torch.exp(self.logs) * x) * x_mask + logdet = torch.sum(self.logs) * x_len # [b] + + return z, logdet + + def store_inverse(self): + pass + + def set_ddi(self, ddi): + self.initialized = not ddi + + def initialize(self, x, x_mask): + with torch.no_grad(): + denom = torch.sum(x_mask, [0, 2]) + m = torch.sum(x * x_mask, [0, 2]) / denom + m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom + v = m_sq - (m**2) + logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6)) + + bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype) + logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype) + + self.bias.data.copy_(bias_init) + self.logs.data.copy_(logs_init) diff --git a/TTS/tts/layers/generic/pos_encoding.py b/TTS/tts/layers/generic/pos_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..2ed8d56786b5442c0038dc2af5eef9c68cdace56 --- /dev/null +++ b/TTS/tts/layers/generic/pos_encoding.py @@ -0,0 +1,70 @@ +import math + +import torch +from torch import nn + + +class PositionalEncoding(nn.Module): + """Sinusoidal positional encoding for non-recurrent neural networks. + Implementation based on "Attention Is All You Need" + + Args: + channels (int): embedding size + dropout_p (float): dropout rate applied to the output. + max_len (int): maximum sequence length. + use_scale (bool): whether to use a learnable scaling coefficient. + """ + + def __init__(self, channels, dropout_p=0.0, max_len=5000, use_scale=False): + super().__init__() + if channels % 2 != 0: + raise ValueError( + "Cannot use sin/cos positional encoding with " "odd channels (got channels={:d})".format(channels) + ) + self.max_len = max_len + self.use_scale = use_scale + if use_scale: + self.scale = torch.nn.Parameter(torch.ones(1)) + pe = torch.zeros(max_len, channels) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.pow(10000, torch.arange(0, channels, 2).float() / channels) + pe[:, 0::2] = torch.sin(position.float() * div_term) + pe[:, 1::2] = torch.cos(position.float() * div_term) + pe = pe.unsqueeze(0).transpose(1, 2) + self.register_buffer("pe", pe) + if dropout_p > 0: + self.dropout = nn.Dropout(p=dropout_p) + self.channels = channels + + def forward(self, x, mask=None, first_idx=None, last_idx=None): + """ + Shapes: + x: [B, C, T] + mask: [B, 1, T] + first_idx: int + last_idx: int + """ + + x = x * math.sqrt(self.channels) + if first_idx is None: + if self.pe.size(2) < x.size(2): + raise RuntimeError( + f"Sequence is {x.size(2)} but PositionalEncoding is" + f" limited to {self.pe.size(2)}. See max_len argument." 
+ ) + if mask is not None: + pos_enc = self.pe[:, :, : x.size(2)] * mask + else: + pos_enc = self.pe[:, :, : x.size(2)] + if self.use_scale: + x = x + self.scale * pos_enc + else: + x = x + pos_enc + else: + if self.use_scale: + x = x + self.scale * self.pe[:, :, first_idx:last_idx] + else: + x = x + self.pe[:, :, first_idx:last_idx] + if hasattr(self, "dropout"): + x = self.dropout(x) + return x diff --git a/TTS/tts/layers/generic/res_conv_bn.py b/TTS/tts/layers/generic/res_conv_bn.py new file mode 100644 index 0000000000000000000000000000000000000000..30c134cd70018197950fb9fb4d7f5fa1a7198b5e --- /dev/null +++ b/TTS/tts/layers/generic/res_conv_bn.py @@ -0,0 +1,128 @@ +from torch import nn + + +class ZeroTemporalPad(nn.Module): + """Pad sequences to equal lentgh in the temporal dimension""" + + def __init__(self, kernel_size, dilation): + super().__init__() + total_pad = dilation * (kernel_size - 1) + begin = total_pad // 2 + end = total_pad - begin + self.pad_layer = nn.ZeroPad2d((0, 0, begin, end)) + + def forward(self, x): + return self.pad_layer(x) + + +class Conv1dBN(nn.Module): + """1d convolutional with batch norm. + conv1d -> relu -> BN blocks. + + Note: + Batch normalization is applied after ReLU regarding the original implementation. + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + kernel_size (int): kernel size for convolutional filters. + dilation (int): dilation for convolution layers. + """ + + def __init__(self, in_channels, out_channels, kernel_size, dilation): + super().__init__() + padding = dilation * (kernel_size - 1) + pad_s = padding // 2 + pad_e = padding - pad_s + self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation) + self.pad = nn.ZeroPad2d((pad_s, pad_e, 0, 0)) # uneven left and right padding + self.norm = nn.BatchNorm1d(out_channels) + + def forward(self, x): + o = self.conv1d(x) + o = self.pad(o) + o = nn.functional.relu(o) + o = self.norm(o) + return o + + +class Conv1dBNBlock(nn.Module): + """1d convolutional block with batch norm. It is a set of conv1d -> relu -> BN blocks. + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + hidden_channels (int): number of inner convolution channels. + kernel_size (int): kernel size for convolutional filters. + dilation (int): dilation for convolution layers. + num_conv_blocks (int, optional): number of convolutional blocks. Defaults to 2. + """ + + def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation, num_conv_blocks=2): + super().__init__() + self.conv_bn_blocks = [] + for idx in range(num_conv_blocks): + layer = Conv1dBN( + in_channels if idx == 0 else hidden_channels, + out_channels if idx == (num_conv_blocks - 1) else hidden_channels, + kernel_size, + dilation, + ) + self.conv_bn_blocks.append(layer) + self.conv_bn_blocks = nn.Sequential(*self.conv_bn_blocks) + + def forward(self, x): + """ + Shapes: + x: (B, D, T) + """ + return self.conv_bn_blocks(x) + + +class ResidualConv1dBNBlock(nn.Module): + """Residual Convolutional Blocks with BN + Each block has 'num_conv_block' conv layers and 'num_res_blocks' such blocks are connected + with residual connections. + + conv_block = (conv1d -> relu -> bn) x 'num_conv_blocks' + residuak_conv_block = (x -> conv_block -> + ->) x 'num_res_blocks' + ' - - - - - - - - - ^ + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. 
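A minimal sketch of the positional-encoding module defined above; the channel count must be even, and the shapes below are illustrative:

```python
import torch

from TTS.tts.layers.generic.pos_encoding import PositionalEncoding

pos_enc = PositionalEncoding(channels=128, dropout_p=0.1, max_len=5000)

x = torch.randn(2, 128, 60)  # [B, C, T]
mask = torch.ones(2, 1, 60)  # [B, 1, T]

y = pos_enc(x, mask=mask)  # x is scaled by sqrt(C) and summed with pe[:, :, :T]
print(y.shape)             # torch.Size([2, 128, 60])
```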
+ hidden_channels (int): number of inner convolution channels. + kernel_size (int): kernel size for convolutional filters. + dilations (list): dilations for each convolution layer. + num_res_blocks (int, optional): number of residual blocks. Defaults to 13. + num_conv_blocks (int, optional): number of convolutional blocks in each residual block. Defaults to 2. + """ + + def __init__( + self, in_channels, out_channels, hidden_channels, kernel_size, dilations, num_res_blocks=13, num_conv_blocks=2 + ): + + super().__init__() + assert len(dilations) == num_res_blocks + self.res_blocks = nn.ModuleList() + for idx, dilation in enumerate(dilations): + block = Conv1dBNBlock( + in_channels if idx == 0 else hidden_channels, + out_channels if (idx + 1) == len(dilations) else hidden_channels, + hidden_channels, + kernel_size, + dilation, + num_conv_blocks, + ) + self.res_blocks.append(block) + + def forward(self, x, x_mask=None): + if x_mask is None: + x_mask = 1.0 + o = x * x_mask + for block in self.res_blocks: + res = o + o = block(o) + o = o + res + if x_mask is not None: + o = o * x_mask + return o diff --git a/TTS/tts/layers/generic/time_depth_sep_conv.py b/TTS/tts/layers/generic/time_depth_sep_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..186cea02e75e156c40923de91086c369a9ea02ee --- /dev/null +++ b/TTS/tts/layers/generic/time_depth_sep_conv.py @@ -0,0 +1,84 @@ +import torch +from torch import nn + + +class TimeDepthSeparableConv(nn.Module): + """Time depth separable convolution as in https://arxiv.org/pdf/1904.02619.pdf + It shows competative results with less computation and memory footprint.""" + + def __init__(self, in_channels, hid_channels, out_channels, kernel_size, bias=True): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.hid_channels = hid_channels + self.kernel_size = kernel_size + + self.time_conv = nn.Conv1d( + in_channels, + 2 * hid_channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.norm1 = nn.BatchNorm1d(2 * hid_channels) + self.depth_conv = nn.Conv1d( + hid_channels, + hid_channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=hid_channels, + bias=bias, + ) + self.norm2 = nn.BatchNorm1d(hid_channels) + self.time_conv2 = nn.Conv1d( + hid_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.norm3 = nn.BatchNorm1d(out_channels) + + def forward(self, x): + x_res = x + x = self.time_conv(x) + x = self.norm1(x) + x = nn.functional.glu(x, dim=1) + x = self.depth_conv(x) + x = self.norm2(x) + x = x * torch.sigmoid(x) + x = self.time_conv2(x) + x = self.norm3(x) + x = x_res + x + return x + + +class TimeDepthSeparableConvBlock(nn.Module): + def __init__(self, in_channels, hid_channels, out_channels, num_layers, kernel_size, bias=True): + super().__init__() + assert (kernel_size - 1) % 2 == 0 + assert num_layers > 1 + + self.layers = nn.ModuleList() + layer = TimeDepthSeparableConv( + in_channels, hid_channels, out_channels if num_layers == 1 else hid_channels, kernel_size, bias + ) + self.layers.append(layer) + for idx in range(num_layers - 1): + layer = TimeDepthSeparableConv( + hid_channels, + hid_channels, + out_channels if (idx + 1) == (num_layers - 1) else hid_channels, + kernel_size, + bias, + ) + self.layers.append(layer) + + def forward(self, x, mask): + for layer in self.layers: + x = layer(x * mask) + return x diff --git a/TTS/tts/layers/generic/transformer.py b/TTS/tts/layers/generic/transformer.py new file 
mode 100644 index 0000000000000000000000000000000000000000..9b7ecee2bacb68cd330e18630531c97bc6f2e6a3 --- /dev/null +++ b/TTS/tts/layers/generic/transformer.py @@ -0,0 +1,89 @@ +import torch +import torch.nn.functional as F +from torch import nn + + +class FFTransformer(nn.Module): + def __init__(self, in_out_channels, num_heads, hidden_channels_ffn=1024, kernel_size_fft=3, dropout_p=0.1): + super().__init__() + self.self_attn = nn.MultiheadAttention(in_out_channels, num_heads, dropout=dropout_p) + + padding = (kernel_size_fft - 1) // 2 + self.conv1 = nn.Conv1d(in_out_channels, hidden_channels_ffn, kernel_size=kernel_size_fft, padding=padding) + self.conv2 = nn.Conv1d(hidden_channels_ffn, in_out_channels, kernel_size=kernel_size_fft, padding=padding) + + self.norm1 = nn.LayerNorm(in_out_channels) + self.norm2 = nn.LayerNorm(in_out_channels) + + self.dropout1 = nn.Dropout(dropout_p) + self.dropout2 = nn.Dropout(dropout_p) + + def forward(self, src, src_mask=None, src_key_padding_mask=None): + """😦 ugly looking with all the transposing""" + src = src.permute(2, 0, 1) + src2, enc_align = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask) + src = src + self.dropout1(src2) + src = self.norm1(src + src2) + # T x B x D -> B x D x T + src = src.permute(1, 2, 0) + src2 = self.conv2(F.relu(self.conv1(src))) + src2 = self.dropout2(src2) + src = src + src2 + src = src.transpose(1, 2) + src = self.norm2(src) + src = src.transpose(1, 2) + return src, enc_align + + +class FFTransformerBlock(nn.Module): + def __init__(self, in_out_channels, num_heads, hidden_channels_ffn, num_layers, dropout_p): + super().__init__() + self.fft_layers = nn.ModuleList( + [ + FFTransformer( + in_out_channels=in_out_channels, + num_heads=num_heads, + hidden_channels_ffn=hidden_channels_ffn, + dropout_p=dropout_p, + ) + for _ in range(num_layers) + ] + ) + + def forward(self, x, mask=None, g=None): # pylint: disable=unused-argument + """ + TODO: handle multi-speaker + Shapes: + - x: :math:`[B, C, T]` + - mask: :math:`[B, 1, T] or [B, T]` + """ + if mask is not None and mask.ndim == 3: + mask = mask.squeeze(1) + # mask is negated, torch uses 1s and 0s reversely. 
+ mask = ~mask.bool() + alignments = [] + for layer in self.fft_layers: + x, align = layer(x, src_key_padding_mask=mask) + alignments.append(align.unsqueeze(1)) + alignments = torch.cat(alignments, 1) + return x + + +class FFTDurationPredictor: + def __init__( + self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None + ): # pylint: disable=unused-argument + self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p) + self.proj = nn.Linear(in_channels, 1) + + def forward(self, x, mask=None, g=None): # pylint: disable=unused-argument + """ + Shapes: + - x: :math:`[B, C, T]` + - mask: :math:`[B, 1, T]` + + TODO: Handle the cond input + """ + x = self.fft(x, mask=mask) + x = self.proj(x) + return x diff --git a/TTS/tts/layers/generic/wavenet.py b/TTS/tts/layers/generic/wavenet.py new file mode 100644 index 0000000000000000000000000000000000000000..aeb45c7bcd455d29499848446faaca8036a8c0f9 --- /dev/null +++ b/TTS/tts/layers/generic/wavenet.py @@ -0,0 +1,171 @@ +import torch +from torch import nn + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +class WN(torch.nn.Module): + """Wavenet layers with weight norm and no input conditioning. + + |-----------------------------------------------------------------------------| + | |-> tanh -| | + res -|- conv1d(dilation) -> dropout -> + -| * -> conv1d1x1 -> split -|- + -> res + g -------------------------------------| |-> sigmoid -| | + o --------------------------------------------------------------------------- + --------- o + + Args: + in_channels (int): number of input channels. + hidden_channes (int): number of hidden channels. + kernel_size (int): filter kernel size for the first conv layer. + dilation_rate (int): dilations rate to increase dilation per layer. + If it is 2, dilations are 1, 2, 4, 8 for the next 4 layers. + num_layers (int): number of wavenet layers. + c_in_channels (int): number of channels of conditioning input. + dropout_p (float): dropout rate. + weight_norm (bool): enable/disable weight norm for convolution layers. 
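The FFT block above expects a mask with 1s on valid frames and negates it internally for torch's `key_padding_mask` convention. A small usage sketch with illustrative sizes:

```python
import torch

from TTS.tts.layers.generic.transformer import FFTransformerBlock

fft = FFTransformerBlock(in_out_channels=256, num_heads=2, hidden_channels_ffn=1024, num_layers=6, dropout_p=0.1)

x = torch.randn(2, 256, 40)  # [B, C, T]
mask = torch.ones(2, 1, 40)  # 1 = valid frame
mask[1, :, 30:] = 0          # second sample is only 30 frames long

o = fft(x, mask=mask)
print(o.shape)  # torch.Size([2, 256, 40])
```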
+ """ + + def __init__( + self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + c_in_channels=0, + dropout_p=0, + weight_norm=True, + ): + super().__init__() + assert kernel_size % 2 == 1 + assert hidden_channels % 2 == 0 + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.num_layers = num_layers + self.c_in_channels = c_in_channels + self.dropout_p = dropout_p + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.dropout = nn.Dropout(dropout_p) + + # init conditioning layer + if c_in_channels > 0: + cond_layer = torch.nn.Conv1d(c_in_channels, 2 * hidden_channels * num_layers, 1) + self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") + # intermediate layers + for i in range(num_layers): + dilation = dilation_rate**i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = torch.nn.Conv1d( + hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding + ) + in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") + self.in_layers.append(in_layer) + + if i < num_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") + self.res_skip_layers.append(res_skip_layer) + # setup weight norm + if not weight_norm: + self.remove_weight_norm() + + def forward(self, x, x_mask=None, g=None, **kwargs): # pylint: disable=unused-argument + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + x_mask = 1.0 if x_mask is None else x_mask + if g is not None: + g = self.cond_layer(g) + for i in range(self.num_layers): + x_in = self.in_layers[i](x) + x_in = self.dropout(x_in) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] + else: + g_l = torch.zeros_like(x_in) + acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.num_layers - 1: + x = (x + res_skip_acts[:, : self.hidden_channels, :]) * x_mask + output = output + res_skip_acts[:, self.hidden_channels :, :] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.c_in_channels != 0: + torch.nn.utils.remove_weight_norm(self.cond_layer) + for l in self.in_layers: + torch.nn.utils.remove_weight_norm(l) + for l in self.res_skip_layers: + torch.nn.utils.remove_weight_norm(l) + + +class WNBlocks(nn.Module): + """Wavenet blocks. + + Note: After each block dilation resets to 1 and it increases in each block + along the dilation rate. + + Args: + in_channels (int): number of input channels. + hidden_channes (int): number of hidden channels. + kernel_size (int): filter kernel size for the first conv layer. + dilation_rate (int): dilations rate to increase dilation per layer. + If it is 2, dilations are 1, 2, 4, 8 for the next 4 layers. + num_blocks (int): number of wavenet blocks. + num_layers (int): number of wavenet layers. + c_in_channels (int): number of channels of conditioning input. + dropout_p (float): dropout rate. + weight_norm (bool): enable/disable weight norm for convolution layers. 
+ """ + + def __init__( + self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_blocks, + num_layers, + c_in_channels=0, + dropout_p=0, + weight_norm=True, + ): + + super().__init__() + self.wn_blocks = nn.ModuleList() + for idx in range(num_blocks): + layer = WN( + in_channels=in_channels if idx == 0 else hidden_channels, + hidden_channels=hidden_channels, + kernel_size=kernel_size, + dilation_rate=dilation_rate, + num_layers=num_layers, + c_in_channels=c_in_channels, + dropout_p=dropout_p, + weight_norm=weight_norm, + ) + self.wn_blocks.append(layer) + + def forward(self, x, x_mask=None, g=None): + o = x + for layer in self.wn_blocks: + o = layer(o, x_mask, g) + return o diff --git a/TTS/tts/layers/glow_tts/__init__.py b/TTS/tts/layers/glow_tts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/layers/glow_tts/decoder.py b/TTS/tts/layers/glow_tts/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f57c37314c7a2ff10f56f3367f577b31f6dd9821 --- /dev/null +++ b/TTS/tts/layers/glow_tts/decoder.py @@ -0,0 +1,141 @@ +import torch +from torch import nn + +from TTS.tts.layers.generic.normalization import ActNorm +from TTS.tts.layers.glow_tts.glow import CouplingBlock, InvConvNear + + +def squeeze(x, x_mask=None, num_sqz=2): + """GlowTTS squeeze operation + Increase number of channels and reduce number of time steps + by the same factor. + + Note: + each 's' is a n-dimensional vector. + ``[s1,s2,s3,s4,s5,s6] --> [[s1, s3, s5], [s2, s4, s6]]`` + """ + b, c, t = x.size() + + t = (t // num_sqz) * num_sqz + x = x[:, :, :t] + x_sqz = x.view(b, c, t // num_sqz, num_sqz) + x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * num_sqz, t // num_sqz) + + if x_mask is not None: + x_mask = x_mask[:, :, num_sqz - 1 :: num_sqz] + else: + x_mask = torch.ones(b, 1, t // num_sqz).to(device=x.device, dtype=x.dtype) + return x_sqz * x_mask, x_mask + + +def unsqueeze(x, x_mask=None, num_sqz=2): + """GlowTTS unsqueeze operation + + Note: + each 's' is a n-dimensional vector. + ``[[s1, s3, s5], [s2, s4, s6]] --> [[s1, s3, s5], [s2, s4, s6]]`` + """ + b, c, t = x.size() + + x_unsqz = x.view(b, num_sqz, c // num_sqz, t) + x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // num_sqz, t * num_sqz) + + if x_mask is not None: + x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, num_sqz).view(b, 1, t * num_sqz) + else: + x_mask = torch.ones(b, 1, t * num_sqz).to(device=x.device, dtype=x.dtype) + return x_unsqz * x_mask, x_mask + + +class Decoder(nn.Module): + """Stack of Glow Decoder Modules. + + :: + + Squeeze -> ActNorm -> InvertibleConv1x1 -> AffineCoupling -> Unsqueeze + + Args: + in_channels (int): channels of input tensor. + hidden_channels (int): hidden decoder channels. + kernel_size (int): Coupling block kernel size. (Wavenet filter kernel size.) + dilation_rate (int): rate to increase dilation by each layer in a decoder block. + num_flow_blocks (int): number of decoder blocks. + num_coupling_layers (int): number coupling layers. (number of wavenet layers.) + dropout_p (float): wavenet dropout rate. + sigmoid_scale (bool): enable/disable sigmoid scaling in coupling layer. 
+ """ + + def __init__( + self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_flow_blocks, + num_coupling_layers, + dropout_p=0.0, + num_splits=4, + num_squeeze=2, + sigmoid_scale=False, + c_in_channels=0, + ): + super().__init__() + + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.num_flow_blocks = num_flow_blocks + self.num_coupling_layers = num_coupling_layers + self.dropout_p = dropout_p + self.num_splits = num_splits + self.num_squeeze = num_squeeze + self.sigmoid_scale = sigmoid_scale + self.c_in_channels = c_in_channels + + self.flows = nn.ModuleList() + for _ in range(num_flow_blocks): + self.flows.append(ActNorm(channels=in_channels * num_squeeze)) + self.flows.append(InvConvNear(channels=in_channels * num_squeeze, num_splits=num_splits)) + self.flows.append( + CouplingBlock( + in_channels * num_squeeze, + hidden_channels, + kernel_size=kernel_size, + dilation_rate=dilation_rate, + num_layers=num_coupling_layers, + c_in_channels=c_in_channels, + dropout_p=dropout_p, + sigmoid_scale=sigmoid_scale, + ) + ) + + def forward(self, x, x_mask, g=None, reverse=False): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1 ,T]` + - g: :math:`[B, C]` + """ + if not reverse: + flows = self.flows + logdet_tot = 0 + else: + flows = reversed(self.flows) + logdet_tot = None + + if self.num_squeeze > 1: + x, x_mask = squeeze(x, x_mask, self.num_squeeze) + for f in flows: + if not reverse: + x, logdet = f(x, x_mask, g=g, reverse=reverse) + logdet_tot += logdet + else: + x, logdet = f(x, x_mask, g=g, reverse=reverse) + if self.num_squeeze > 1: + x, x_mask = unsqueeze(x, x_mask, self.num_squeeze) + return x, logdet_tot + + def store_inverse(self): + for f in self.flows: + f.store_inverse() diff --git a/TTS/tts/layers/glow_tts/duration_predictor.py b/TTS/tts/layers/glow_tts/duration_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..e766ed6ab5a0348eaca8d1482be124003d8b8c68 --- /dev/null +++ b/TTS/tts/layers/glow_tts/duration_predictor.py @@ -0,0 +1,69 @@ +import torch +from torch import nn + +from ..generic.normalization import LayerNorm + + +class DurationPredictor(nn.Module): + """Glow-TTS duration prediction model. + + :: + + [2 x (conv1d_kxk -> relu -> layer_norm -> dropout)] -> conv1d_1x1 -> durs + + Args: + in_channels (int): Number of channels of the input tensor. + hidden_channels (int): Number of hidden channels of the network. + kernel_size (int): Kernel size for the conv layers. + dropout_p (float): Dropout rate used after each conv layer. 
+ """ + + def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None, language_emb_dim=None): + super().__init__() + + # add language embedding dim in the input + if language_emb_dim: + in_channels += language_emb_dim + + # class arguments + self.in_channels = in_channels + self.filter_channels = hidden_channels + self.kernel_size = kernel_size + self.dropout_p = dropout_p + # layers + self.drop = nn.Dropout(dropout_p) + self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2) + self.norm_1 = LayerNorm(hidden_channels) + self.conv_2 = nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2) + self.norm_2 = LayerNorm(hidden_channels) + # output layer + self.proj = nn.Conv1d(hidden_channels, 1, 1) + if cond_channels is not None and cond_channels != 0: + self.cond = nn.Conv1d(cond_channels, in_channels, 1) + + if language_emb_dim != 0 and language_emb_dim is not None: + self.cond_lang = nn.Conv1d(language_emb_dim, in_channels, 1) + + def forward(self, x, x_mask, g=None, lang_emb=None): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + - g: :math:`[B, C, 1]` + """ + if g is not None: + x = x + self.cond(g) + + if lang_emb is not None: + x = x + self.cond_lang(lang_emb) + + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask diff --git a/TTS/tts/layers/glow_tts/encoder.py b/TTS/tts/layers/glow_tts/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..3b43e527f5e9ca2bd0880bf204e04a1526bc8dfb --- /dev/null +++ b/TTS/tts/layers/glow_tts/encoder.py @@ -0,0 +1,179 @@ +import math + +import torch +from torch import nn + +from TTS.tts.layers.generic.gated_conv import GatedConvBlock +from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock +from TTS.tts.layers.generic.time_depth_sep_conv import TimeDepthSeparableConvBlock +from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor +from TTS.tts.layers.glow_tts.glow import ResidualConv1dLayerNormBlock +from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer +from TTS.tts.utils.helpers import sequence_mask + + +class Encoder(nn.Module): + """Glow-TTS encoder module. + + :: + + embedding -> -> encoder_module -> --> proj_mean + | + |-> proj_var + | + |-> concat -> duration_predictor + ↑ + speaker_embed + + Args: + num_chars (int): number of characters. + out_channels (int): number of output channels. + hidden_channels (int): encoder's embedding size. + hidden_channels_ffn (int): transformer's feed-forward channels. + kernel_size (int): kernel size for conv layers and duration predictor. + dropout_p (float): dropout rate for any dropout layer. + mean_only (bool): if True, output only mean values and use constant std. + use_prenet (bool): if True, use pre-convolutional layers before transformer layers. + c_in_channels (int): number of channels in conditional input. + + Shapes: + - input: (B, T, C) + + :: + + suggested encoder params... 
+ + for encoder_type == 'rel_pos_transformer' + encoder_params={ + 'kernel_size':3, + 'dropout_p': 0.1, + 'num_layers': 6, + 'num_heads': 2, + 'hidden_channels_ffn': 768, # 4 times the hidden_channels + 'input_length': None + } + + for encoder_type == 'gated_conv' + encoder_params={ + 'kernel_size':5, + 'dropout_p': 0.1, + 'num_layers': 9, + } + + for encoder_type == 'residual_conv_bn' + encoder_params={ + "kernel_size": 4, + "dilations": [1, 2, 4, 1, 2, 4, 1, 2, 4, 1, 2, 4, 1], + "num_conv_blocks": 2, + "num_res_blocks": 13 + } + + for encoder_type == 'time_depth_separable' + encoder_params={ + "kernel_size": 5, + 'num_layers': 9, + } + """ + + def __init__( + self, + num_chars, + out_channels, + hidden_channels, + hidden_channels_dp, + encoder_type, + encoder_params, + dropout_p_dp=0.1, + mean_only=False, + use_prenet=True, + c_in_channels=0, + ): + super().__init__() + # class arguments + self.num_chars = num_chars + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.hidden_channels_dp = hidden_channels_dp + self.dropout_p_dp = dropout_p_dp + self.mean_only = mean_only + self.use_prenet = use_prenet + self.c_in_channels = c_in_channels + self.encoder_type = encoder_type + # embedding layer + self.emb = nn.Embedding(num_chars, hidden_channels) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) + # init encoder module + if encoder_type.lower() == "rel_pos_transformer": + if use_prenet: + self.prenet = ResidualConv1dLayerNormBlock( + hidden_channels, hidden_channels, hidden_channels, kernel_size=5, num_layers=3, dropout_p=0.5 + ) + self.encoder = RelativePositionTransformer( + hidden_channels, hidden_channels, hidden_channels, **encoder_params + ) + elif encoder_type.lower() == "gated_conv": + self.encoder = GatedConvBlock(hidden_channels, **encoder_params) + elif encoder_type.lower() == "residual_conv_bn": + if use_prenet: + self.prenet = nn.Sequential(nn.Conv1d(hidden_channels, hidden_channels, 1), nn.ReLU()) + self.encoder = ResidualConv1dBNBlock(hidden_channels, hidden_channels, hidden_channels, **encoder_params) + self.postnet = nn.Sequential( + nn.Conv1d(self.hidden_channels, self.hidden_channels, 1), nn.BatchNorm1d(self.hidden_channels) + ) + elif encoder_type.lower() == "time_depth_separable": + if use_prenet: + self.prenet = ResidualConv1dLayerNormBlock( + hidden_channels, hidden_channels, hidden_channels, kernel_size=5, num_layers=3, dropout_p=0.5 + ) + self.encoder = TimeDepthSeparableConvBlock( + hidden_channels, hidden_channels, hidden_channels, **encoder_params + ) + else: + raise ValueError(" [!] 
Unkown encoder type.") + + # final projection layers + self.proj_m = nn.Conv1d(hidden_channels, out_channels, 1) + if not mean_only: + self.proj_s = nn.Conv1d(hidden_channels, out_channels, 1) + # duration predictor + self.duration_predictor = DurationPredictor( + hidden_channels + c_in_channels, hidden_channels_dp, 3, dropout_p_dp + ) + + def forward(self, x, x_lengths, g=None): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_lengths: :math:`[B]` + - g (optional): :math:`[B, 1, T]` + """ + # embedding layer + # [B ,T, D] + x = self.emb(x) * math.sqrt(self.hidden_channels) + # [B, D, T] + x = torch.transpose(x, 1, -1) + # compute input sequence mask + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + # prenet + if hasattr(self, "prenet") and self.use_prenet: + x = self.prenet(x, x_mask) + # encoder + x = self.encoder(x, x_mask) + # postnet + if hasattr(self, "postnet"): + x = self.postnet(x) * x_mask + # set duration predictor input + if g is not None: + g_exp = g.expand(-1, -1, x.size(-1)) + x_dp = torch.cat([x.detach(), g_exp], 1) + else: + x_dp = x.detach() + # final projection layer + x_m = self.proj_m(x) * x_mask + if not self.mean_only: + x_logs = self.proj_s(x) * x_mask + else: + x_logs = torch.zeros_like(x_m) + # duration predictor + logw = self.duration_predictor(x_dp, x_mask) + return x_m, x_logs, logw, x_mask diff --git a/TTS/tts/layers/glow_tts/glow.py b/TTS/tts/layers/glow_tts/glow.py new file mode 100644 index 0000000000000000000000000000000000000000..ff1b99e8ecc4de8fffd40011532e801e13f99c0c --- /dev/null +++ b/TTS/tts/layers/glow_tts/glow.py @@ -0,0 +1,234 @@ +from distutils.version import LooseVersion + +import torch +from torch import nn +from torch.nn import functional as F + +from TTS.tts.layers.generic.wavenet import WN + +from ..generic.normalization import LayerNorm + + +class ResidualConv1dLayerNormBlock(nn.Module): + """Conv1d with Layer Normalization and residual connection as in GlowTTS paper. + https://arxiv.org/pdf/1811.00002.pdf + + :: + + x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o + |---------------> conv1d_1x1 ------------------| + + Args: + in_channels (int): number of input tensor channels. + hidden_channels (int): number of inner layer channels. + out_channels (int): number of output tensor channels. + kernel_size (int): kernel size of conv1d filter. + num_layers (int): number of blocks. + dropout_p (float): dropout rate for each block. + """ + + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, num_layers, dropout_p): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.num_layers = num_layers + self.dropout_p = dropout_p + assert num_layers > 1, " [!] number of layers should be > 0." + assert kernel_size % 2 == 1, " [!] kernel size should be odd number." 
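A minimal sketch of the Glow-TTS encoder above with the 'rel_pos_transformer' parameters suggested in its docstring; `num_chars`, batch shapes and the duration-predictor width are illustrative:

```python
import torch

from TTS.tts.layers.glow_tts.encoder import Encoder

encoder = Encoder(
    num_chars=120,
    out_channels=80,
    hidden_channels=192,
    hidden_channels_dp=256,
    encoder_type="rel_pos_transformer",
    encoder_params={
        "kernel_size": 3, "dropout_p": 0.1, "num_layers": 6,
        "num_heads": 2, "hidden_channels_ffn": 768, "input_length": None,
    },
)

x = torch.randint(0, 120, (2, 37))  # [B, T] token ids
x_lengths = torch.tensor([37, 30])

x_m, x_logs, logw, x_mask = encoder(x, x_lengths)
print(x_m.shape, x_logs.shape, logw.shape)  # [2, 80, 37] [2, 80, 37] [2, 1, 37]
```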
+ + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + + for idx in range(num_layers): + self.conv_layers.append( + nn.Conv1d( + in_channels if idx == 0 else hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2 + ) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + """ + x_res = x + for i in range(self.num_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x * x_mask) + x = F.dropout(F.relu(x), self.dropout_p, training=self.training) + x = x_res + self.proj(x) + return x * x_mask + + +class InvConvNear(nn.Module): + """Invertible Convolution with input splitting as in GlowTTS paper. + https://arxiv.org/pdf/1811.00002.pdf + + Args: + channels (int): input and output channels. + num_splits (int): number of splits, also H and W of conv layer. + no_jacobian (bool): enable/disable jacobian computations. + + Note: + Split the input into groups of size self.num_splits and + perform 1x1 convolution separately. Cast 1x1 conv operation + to 2d by reshaping the input for efficiency. + """ + + def __init__(self, channels, num_splits=4, no_jacobian=False, **kwargs): # pylint: disable=unused-argument + super().__init__() + assert num_splits % 2 == 0 + self.channels = channels + self.num_splits = num_splits + self.no_jacobian = no_jacobian + self.weight_inv = None + + if LooseVersion(torch.__version__) < LooseVersion("1.9"): + w_init = torch.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0] + else: + w_init = torch.linalg.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_(), "complete")[0] + + if torch.det(w_init) < 0: + w_init[:, 0] = -1 * w_init[:, 0] + self.weight = nn.Parameter(w_init) + + def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + """ + b, c, t = x.size() + assert c % self.num_splits == 0 + if x_mask is None: + x_mask = 1 + x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t + else: + x_len = torch.sum(x_mask, [1, 2]) + + x = x.view(b, 2, c // self.num_splits, self.num_splits // 2, t) + x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.num_splits, c // self.num_splits, t) + + if reverse: + if self.weight_inv is not None: + weight = self.weight_inv + else: + weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype) + logdet = None + else: + weight = self.weight + if self.no_jacobian: + logdet = 0 + else: + logdet = torch.logdet(self.weight) * (c / self.num_splits) * x_len # [b] + + weight = weight.view(self.num_splits, self.num_splits, 1, 1) + z = F.conv2d(x, weight) + + z = z.view(b, 2, self.num_splits // 2, c // self.num_splits, t) + z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask + return z, logdet + + def store_inverse(self): + weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype) + self.weight_inv = nn.Parameter(weight_inv, requires_grad=False) + + +class CouplingBlock(nn.Module): + """Glow Affine Coupling block as in GlowTTS paper. + https://arxiv.org/pdf/1811.00002.pdf + + :: + + x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o + '-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^ + + Args: + in_channels (int): number of input tensor channels. 
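A quick invertibility check for the 1x1 convolution above; channel counts are illustrative and must be divisible by `num_splits`:

```python
import torch

from TTS.tts.layers.glow_tts.glow import InvConvNear

conv = InvConvNear(channels=160, num_splits=4)

x = torch.randn(2, 160, 50)
x_mask = torch.ones(2, 1, 50)

z, logdet = conv(x, x_mask)                 # forward pass returns a per-sample log-determinant
conv.store_inverse()                        # cache the inverse weight for inference
x_hat, _ = conv(z, x_mask, reverse=True)
print(logdet.shape)                         # torch.Size([2])
print(torch.allclose(x, x_hat, atol=1e-4))  # True: the 1x1 convolution is invertible
```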
+ hidden_channels (int): number of hidden channels. + kernel_size (int): WaveNet filter kernel size. + dilation_rate (int): rate to increase dilation by each layer in a decoder block. + num_layers (int): number of WaveNet layers. + c_in_channels (int): number of conditioning input channels. + dropout_p (int): wavenet dropout rate. + sigmoid_scale (bool): enable/disable sigmoid scaling for output scale. + + Note: + It does not use the conditional inputs differently from WaveGlow. + """ + + def __init__( + self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + c_in_channels=0, + dropout_p=0, + sigmoid_scale=False, + ): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.num_layers = num_layers + self.c_in_channels = c_in_channels + self.dropout_p = dropout_p + self.sigmoid_scale = sigmoid_scale + # input layer + start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1) + start = torch.nn.utils.weight_norm(start) + self.start = start + # output layer + # Initializing last layer to 0 makes the affine coupling layers + # do nothing at first. This helps with training stability + end = torch.nn.Conv1d(hidden_channels, in_channels, 1) + end.weight.data.zero_() + end.bias.data.zero_() + self.end = end + # coupling layers + self.wn = WN(in_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels, dropout_p) + + def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs): # pylint: disable=unused-argument + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + - g: :math:`[B, C, 1]` + """ + if x_mask is None: + x_mask = 1 + x_0, x_1 = x[:, : self.in_channels // 2], x[:, self.in_channels // 2 :] + + x = self.start(x_0) * x_mask + x = self.wn(x, x_mask, g) + out = self.end(x) + + z_0 = x_0 + t = out[:, : self.in_channels // 2, :] + s = out[:, self.in_channels // 2 :, :] + if self.sigmoid_scale: + s = torch.log(1e-6 + torch.sigmoid(s + 2)) + + if reverse: + z_1 = (x_1 - t) * torch.exp(-s) * x_mask + logdet = None + else: + z_1 = (t + torch.exp(s) * x_1) * x_mask + logdet = torch.sum(s * x_mask, [1, 2]) + + z = torch.cat([z_0, z_1], 1) + return z, logdet + + def store_inverse(self): + self.wn.remove_weight_norm() diff --git a/TTS/tts/layers/glow_tts/transformer.py b/TTS/tts/layers/glow_tts/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..0f837abfeb441477de419f6cf4c9a05730a351c8 --- /dev/null +++ b/TTS/tts/layers/glow_tts/transformer.py @@ -0,0 +1,434 @@ +import math + +import torch +from torch import nn +from torch.nn import functional as F + +from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2 + + +class RelativePositionMultiHeadAttention(nn.Module): + """Multi-head attention with Relative Positional embedding. + https://arxiv.org/pdf/1809.04281.pdf + + It learns positional embeddings for a window of neighbours. For keys and values, + it learns different set of embeddings. Key embeddings are agregated with the attention + scores and value embeddings are aggregated with the output. + + Note: + Example with relative attention window size 2 + + - input = [a, b, c, d, e] + - rel_attn_embeddings = [e(t-2), e(t-1), e(t+1), e(t+2)] + + So it learns 4 embedding vectors (in total 8) separately for key and value vectors. 
+ + Considering the input c + + - e(t-2) corresponds to c -> a + - e(t-2) corresponds to c -> b + - e(t-2) corresponds to c -> d + - e(t-2) corresponds to c -> e + + These embeddings are shared among different time steps. So input a, b, d and e also uses + the same embeddings. + + Embeddings are ignored when the relative window is out of limit for the first and the last + n items. + + Args: + channels (int): input and inner layer channels. + out_channels (int): output channels. + num_heads (int): number of attention heads. + rel_attn_window_size (int, optional): relation attention window size. + If 4, for each time step next and previous 4 time steps are attended. + If default, relative encoding is disabled and it is a regular transformer. + Defaults to None. + heads_share (bool, optional): [description]. Defaults to True. + dropout_p (float, optional): dropout rate. Defaults to 0.. + input_length (int, optional): intput length for positional encoding. Defaults to None. + proximal_bias (bool, optional): enable/disable proximal bias as in the paper. Defaults to False. + proximal_init (bool, optional): enable/disable poximal init as in the paper. + Init key and query layer weights the same. Defaults to False. + """ + + def __init__( + self, + channels, + out_channels, + num_heads, + rel_attn_window_size=None, + heads_share=True, + dropout_p=0.0, + input_length=None, + proximal_bias=False, + proximal_init=False, + ): + + super().__init__() + assert channels % num_heads == 0, " [!] channels should be divisible by num_heads." + # class attributes + self.channels = channels + self.out_channels = out_channels + self.num_heads = num_heads + self.rel_attn_window_size = rel_attn_window_size + self.heads_share = heads_share + self.input_length = input_length + self.proximal_bias = proximal_bias + self.dropout_p = dropout_p + self.attn = None + # query, key, value layers + self.k_channels = channels // num_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + # output layers + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.dropout = nn.Dropout(dropout_p) + # relative positional encoding layers + if rel_attn_window_size is not None: + n_heads_rel = 1 if heads_share else num_heads + rel_stddev = self.k_channels**-0.5 + emb_rel_k = nn.Parameter( + torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, self.k_channels) * rel_stddev + ) + emb_rel_v = nn.Parameter( + torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, self.k_channels) * rel_stddev + ) + self.register_parameter("emb_rel_k", emb_rel_k) + self.register_parameter("emb_rel_v", emb_rel_v) + + # init layers + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + # proximal bias + if proximal_init: + self.conv_k.weight.data.copy_(self.conv_q.weight.data) + self.conv_k.bias.data.copy_(self.conv_q.bias.data) + nn.init.xavier_uniform_(self.conv_v.weight) + + def forward(self, x, c, attn_mask=None): + """ + Shapes: + - x: :math:`[B, C, T]` + - c: :math:`[B, C, T]` + - attn_mask: :math:`[B, 1, T, T]` + """ + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + x, self.attn = self.attention(q, k, v, mask=attn_mask) + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.num_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, 
self.num_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.num_heads, self.k_channels, t_s).transpose(2, 3) + # compute raw attention scores + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) + # relative positional encoding for scores + if self.rel_attn_window_size is not None: + assert t_s == t_t, "Relative attention is only available for self-attention." + # get relative key embeddings + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings) + rel_logits = self._relative_position_to_absolute_position(rel_logits) + scores_local = rel_logits / math.sqrt(self.k_channels) + scores = scores + scores_local + # proximan bias + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attn_proximity_bias(t_s).to(device=scores.device, dtype=scores.dtype) + # attention score masking + if mask is not None: + # add small value to prevent oor error. + scores = scores.masked_fill(mask == 0, -1e4) + if self.input_length is not None: + block_mask = torch.ones_like(scores).triu(-1 * self.input_length).tril(self.input_length) + scores = scores * block_mask + -1e4 * (1 - block_mask) + # attention score normalization + p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] + # apply dropout to attention weights + p_attn = self.dropout(p_attn) + # compute output + output = torch.matmul(p_attn, value) + # relative positional encoding for values + if self.rel_attn_window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn + + @staticmethod + def _matmul_with_relative_values(p_attn, re): + """ + Args: + p_attn (Tensor): attention weights. + re (Tensor): relative value embedding vector. (a_(i,j)^V) + + Shapes: + -p_attn: :math:`[B, H, T, V]` + -re: :math:`[H or 1, V, D]` + -logits: :math:`[B, H, T, D]` + """ + logits = torch.matmul(p_attn, re.unsqueeze(0)) + return logits + + @staticmethod + def _matmul_with_relative_keys(query, re): + """ + Args: + query (Tensor): batch of query vectors. (x*W^Q) + re (Tensor): relative key embedding vector. (a_(i,j)^K) + + Shapes: + - query: :math:`[B, H, T, D]` + - re: :math:`[H or 1, V, D]` + - logits: :math:`[B, H, T, V]` + """ + # logits = torch.einsum('bhld, kmd -> bhlm', [query, re.to(query.dtype)]) + logits = torch.matmul(query, re.unsqueeze(0).transpose(-2, -1)) + return logits + + def _get_relative_embeddings(self, relative_embeddings, length): + """Convert embedding vestors to a tensor of embeddings""" + # Pad first before slice to avoid using cond ops. 
+ pad_length = max(length - (self.rel_attn_window_size + 1), 0) + slice_start_position = max((self.rel_attn_window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0]) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position] + return used_relative_embeddings + + @staticmethod + def _relative_position_to_absolute_position(x): + """Converts tensor from relative to absolute indexing for local attention. + Shapes: + x: :math:`[B, C, T, 2 * T - 1]` + Returns: + A Tensor of shape :math:`[B, C, T, T]` + """ + batch, heads, length, _ = x.size() + # Pad to shift from relative to absolute indexing. + x = F.pad(x, [0, 1, 0, 0, 0, 0, 0, 0]) + # Pad extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad(x_flat, [0, length - 1, 0, 0, 0, 0]) + # Reshape and slice out the padded elements. + x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :] + return x_final + + @staticmethod + def _absolute_position_to_relative_position(x): + """ + Shapes: + - x: :math:`[B, C, T, T]` + - ret: :math:`[B, C, T, 2*T-1]` + """ + batch, heads, length, _ = x.size() + # padd along column + x = F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0]) + x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0]) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + @staticmethod + def _attn_proximity_bias(length): + """Produce an attention mask that discourages distant + attention values. + Args: + length (int): an integer scalar. + Returns: + a Tensor with shape :math:`[1, 1, T, T]` + """ + # L + r = torch.arange(length, dtype=torch.float32) + # L x L + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + # scale mask values + diff = -torch.log1p(torch.abs(diff)) + # 1 x 1 x L x L + return diff.unsqueeze(0).unsqueeze(0) + + +class FeedForwardNetwork(nn.Module): + """Feed Forward Inner layers for Transformer. + + Args: + in_channels (int): input tensor channels. + out_channels (int): output tensor channels. + hidden_channels (int): inner layers hidden channels. + kernel_size (int): conv1d filter kernel size. + dropout_p (float, optional): dropout rate. Defaults to 0. 
+ """ + + def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dropout_p=0.0, causal=False): + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dropout_p = dropout_p + + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size) + self.conv_2 = nn.Conv1d(hidden_channels, out_channels, kernel_size) + self.dropout = nn.Dropout(dropout_p) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + x = torch.relu(x) + x = self.dropout(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask + + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, self._pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, self._pad_shape(padding)) + return x + + @staticmethod + def _pad_shape(padding): + l = padding[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +class RelativePositionTransformer(nn.Module): + """Transformer with Relative Potional Encoding. + https://arxiv.org/abs/1803.02155 + + Args: + in_channels (int): number of channels of the input tensor. + out_chanels (int): number of channels of the output tensor. + hidden_channels (int): model hidden channels. + hidden_channels_ffn (int): hidden channels of FeedForwardNetwork. + num_heads (int): number of attention heads. + num_layers (int): number of transformer layers. + kernel_size (int, optional): kernel size of feed-forward inner layers. Defaults to 1. + dropout_p (float, optional): dropout rate for self-attention and feed-forward inner layers_per_stack. Defaults to 0. + rel_attn_window_size (int, optional): relation attention window size. + If 4, for each time step next and previous 4 time steps are attended. + If default, relative encoding is disabled and it is a regular transformer. + Defaults to None. + input_length (int, optional): input lenght to limit position encoding. Defaults to None. + layer_norm_type (str, optional): type "1" uses torch tensor operations and type "2" uses torch layer_norm + primitive. Use type "2", type "1: is for backward compat. Defaults to "1". 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int, + hidden_channels_ffn: int, + num_heads: int, + num_layers: int, + kernel_size=1, + dropout_p=0.0, + rel_attn_window_size: int = None, + input_length: int = None, + layer_norm_type: str = "1", + ): + super().__init__() + self.hidden_channels = hidden_channels + self.hidden_channels_ffn = hidden_channels_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.kernel_size = kernel_size + self.dropout_p = dropout_p + self.rel_attn_window_size = rel_attn_window_size + + self.dropout = nn.Dropout(dropout_p) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + + for idx in range(self.num_layers): + self.attn_layers.append( + RelativePositionMultiHeadAttention( + hidden_channels if idx != 0 else in_channels, + hidden_channels, + num_heads, + rel_attn_window_size=rel_attn_window_size, + dropout_p=dropout_p, + input_length=input_length, + ) + ) + if layer_norm_type == "1": + self.norm_layers_1.append(LayerNorm(hidden_channels)) + elif layer_norm_type == "2": + self.norm_layers_1.append(LayerNorm2(hidden_channels)) + else: + raise ValueError(" [!] Unknown layer norm type") + + if hidden_channels != out_channels and (idx + 1) == self.num_layers: + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + + self.ffn_layers.append( + FeedForwardNetwork( + hidden_channels, + hidden_channels if (idx + 1) != self.num_layers else out_channels, + hidden_channels_ffn, + kernel_size, + dropout_p=dropout_p, + ) + ) + + if layer_norm_type == "1": + self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels)) + elif layer_norm_type == "2": + self.norm_layers_2.append(LayerNorm2(hidden_channels if (idx + 1) != self.num_layers else out_channels)) + else: + raise ValueError(" [!] Unknown layer norm type") + + def forward(self, x, x_mask): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + """ + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + for i in range(self.num_layers): + x = x * x_mask + y = self.attn_layers[i](x, x, attn_mask) + y = self.dropout(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.dropout(y) + + if (i + 1) == self.num_layers and hasattr(self, "proj"): + x = self.proj(x) + + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..fea52d3d2d206ff786263fa38554cf07ec3d55a0 --- /dev/null +++ b/TTS/tts/layers/losses.py @@ -0,0 +1,859 @@ +import math + +import numpy as np +import torch +from coqpit import Coqpit +from torch import nn +from torch.nn import functional + +from TTS.tts.utils.helpers import sequence_mask +from TTS.tts.utils.ssim import ssim +from TTS.utils.audio import TorchSTFT + + +# pylint: disable=abstract-method +# relates https://github.com/pytorch/pytorch/issues/42305 +class L1LossMasked(nn.Module): + def __init__(self, seq_len_norm): + super().__init__() + self.seq_len_norm = seq_len_norm + + def forward(self, x, target, length): + """ + Args: + x: A Variable containing a FloatTensor of size + (batch, max_len, dim) which contains the + unnormalized probability for each class. + target: A Variable containing a LongTensor of size + (batch, max_len, dim) which contains the index of the true + class for each corresponding step. 
+ length: A Variable containing a LongTensor of size (batch,) + which contains the length of each data in a batch. + Shapes: + x: B x T X D + target: B x T x D + length: B + Returns: + loss: An average loss value in range [0, 1] masked by the length. + """ + # mask: (batch, max_len, 1) + target.requires_grad = False + mask = sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2).float() + if self.seq_len_norm: + norm_w = mask / mask.sum(dim=1, keepdim=True) + out_weights = norm_w.div(target.shape[0] * target.shape[2]) + mask = mask.expand_as(x) + loss = functional.l1_loss(x * mask, target * mask, reduction="none") + loss = loss.mul(out_weights.to(loss.device)).sum() + else: + mask = mask.expand_as(x) + loss = functional.l1_loss(x * mask, target * mask, reduction="sum") + loss = loss / mask.sum() + return loss + + +class MSELossMasked(nn.Module): + def __init__(self, seq_len_norm): + super().__init__() + self.seq_len_norm = seq_len_norm + + def forward(self, x, target, length): + """ + Args: + x: A Variable containing a FloatTensor of size + (batch, max_len, dim) which contains the + unnormalized probability for each class. + target: A Variable containing a LongTensor of size + (batch, max_len, dim) which contains the index of the true + class for each corresponding step. + length: A Variable containing a LongTensor of size (batch,) + which contains the length of each data in a batch. + Shapes: + - x: :math:`[B, T, D]` + - target: :math:`[B, T, D]` + - length: :math:`B` + Returns: + loss: An average loss value in range [0, 1] masked by the length. + """ + # mask: (batch, max_len, 1) + target.requires_grad = False + mask = sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2).float() + if self.seq_len_norm: + norm_w = mask / mask.sum(dim=1, keepdim=True) + out_weights = norm_w.div(target.shape[0] * target.shape[2]) + mask = mask.expand_as(x) + loss = functional.mse_loss(x * mask, target * mask, reduction="none") + loss = loss.mul(out_weights.to(loss.device)).sum() + else: + mask = mask.expand_as(x) + loss = functional.mse_loss(x * mask, target * mask, reduction="sum") + loss = loss / mask.sum() + return loss + + +class SSIMLoss(torch.nn.Module): + """SSIM loss as explained here https://en.wikipedia.org/wiki/Structural_similarity""" + + def __init__(self): + super().__init__() + self.loss_func = ssim + + def forward(self, y_hat, y, length=None): + """ + Args: + y_hat (tensor): model prediction values. + y (tensor): target values. + length (tensor): length of each sample in a batch. + Shapes: + y_hat: B x T X D + y: B x T x D + length: B + Returns: + loss: An average loss value in range [0, 1] masked by the length. 
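+
+            Example (illustrative batch/feature sizes)::
+
+                >>> import torch
+                >>> criterion = SSIMLoss()
+                >>> y_hat = torch.rand(4, 100, 80)            # [B, T, D]
+                >>> y = torch.rand(4, 100, 80)
+                >>> lens = torch.tensor([100, 90, 80, 70])
+                >>> loss = criterion(y_hat, y, lens)          # scalar loss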
+ """ + if length is not None: + m = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2).float().to(y_hat.device) + y_hat, y = y_hat * m, y * m + return 1 - self.loss_func(y_hat.unsqueeze(1), y.unsqueeze(1)) + + +class AttentionEntropyLoss(nn.Module): + # pylint: disable=R0201 + def forward(self, align): + """ + Forces attention to be more decisive by penalizing + soft attention weights + + TODO: arguments + TODO: unit_test + """ + entropy = torch.distributions.Categorical(probs=align).entropy() + loss = (entropy / np.log(align.shape[1])).mean() + return loss + + +class BCELossMasked(nn.Module): + def __init__(self, pos_weight): + super().__init__() + self.pos_weight = pos_weight + + def forward(self, x, target, length): + """ + Args: + x: A Variable containing a FloatTensor of size + (batch, max_len) which contains the + unnormalized probability for each class. + target: A Variable containing a LongTensor of size + (batch, max_len) which contains the index of the true + class for each corresponding step. + length: A Variable containing a LongTensor of size (batch,) + which contains the length of each data in a batch. + Shapes: + x: B x T + target: B x T + length: B + Returns: + loss: An average loss value in range [0, 1] masked by the length. + """ + # mask: (batch, max_len, 1) + target.requires_grad = False + if length is not None: + mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float() + x = x * mask + target = target * mask + num_items = mask.sum() + else: + num_items = torch.numel(x) + loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum") + loss = loss / num_items + return loss + + +class DifferentailSpectralLoss(nn.Module): + """Differential Spectral Loss + https://arxiv.org/ftp/arxiv/papers/1909/1909.10302.pdf""" + + def __init__(self, loss_func): + super().__init__() + self.loss_func = loss_func + + def forward(self, x, target, length=None): + """ + Shapes: + x: B x T + target: B x T + length: B + Returns: + loss: An average loss value in range [0, 1] masked by the length. 
+ """ + x_diff = x[:, 1:] - x[:, :-1] + target_diff = target[:, 1:] - target[:, :-1] + if length is None: + return self.loss_func(x_diff, target_diff) + return self.loss_func(x_diff, target_diff, length - 1) + + +class GuidedAttentionLoss(torch.nn.Module): + def __init__(self, sigma=0.4): + super().__init__() + self.sigma = sigma + + def _make_ga_masks(self, ilens, olens): + B = len(ilens) + max_ilen = max(ilens) + max_olen = max(olens) + ga_masks = torch.zeros((B, max_olen, max_ilen)) + for idx, (ilen, olen) in enumerate(zip(ilens, olens)): + ga_masks[idx, :olen, :ilen] = self._make_ga_mask(ilen, olen, self.sigma) + return ga_masks + + def forward(self, att_ws, ilens, olens): + ga_masks = self._make_ga_masks(ilens, olens).to(att_ws.device) + seq_masks = self._make_masks(ilens, olens).to(att_ws.device) + losses = ga_masks * att_ws + loss = torch.mean(losses.masked_select(seq_masks)) + return loss + + @staticmethod + def _make_ga_mask(ilen, olen, sigma): + grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen)) + grid_x, grid_y = grid_x.float(), grid_y.float() + return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2))) + + @staticmethod + def _make_masks(ilens, olens): + in_masks = sequence_mask(ilens) + out_masks = sequence_mask(olens) + return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2) + + +class Huber(nn.Module): + # pylint: disable=R0201 + def forward(self, x, y, length=None): + """ + Shapes: + x: B x T + y: B x T + length: B + """ + mask = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2).float() + return torch.nn.functional.smooth_l1_loss(x * mask, y * mask, reduction="sum") / mask.sum() + + +class ForwardSumLoss(nn.Module): + def __init__(self, blank_logprob=-1): + super().__init__() + self.log_softmax = torch.nn.LogSoftmax(dim=3) + self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True) + self.blank_logprob = blank_logprob + + def forward(self, attn_logprob, in_lens, out_lens): + key_lens = in_lens + query_lens = out_lens + attn_logprob_padded = torch.nn.functional.pad(input=attn_logprob, pad=(1, 0), value=self.blank_logprob) + + total_loss = 0.0 + for bid in range(attn_logprob.shape[0]): + target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0) + curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[: query_lens[bid], :, : key_lens[bid] + 1] + + curr_logprob = self.log_softmax(curr_logprob[None])[0] + loss = self.ctc_loss( + curr_logprob, + target_seq, + input_lengths=query_lens[bid : bid + 1], + target_lengths=key_lens[bid : bid + 1], + ) + total_loss = total_loss + loss + + total_loss = total_loss / attn_logprob.shape[0] + return total_loss + + +######################## +# MODEL LOSS LAYERS +######################## + + +class TacotronLoss(torch.nn.Module): + """Collection of Tacotron set-up based on provided config.""" + + def __init__(self, c, ga_sigma=0.4): + super().__init__() + self.stopnet_pos_weight = c.stopnet_pos_weight + self.use_capacitron_vae = c.use_capacitron_vae + if self.use_capacitron_vae: + self.capacitron_capacity = c.capacitron_vae.capacitron_capacity + self.capacitron_vae_loss_alpha = c.capacitron_vae.capacitron_VAE_loss_alpha + self.ga_alpha = c.ga_alpha + self.decoder_diff_spec_alpha = c.decoder_diff_spec_alpha + self.postnet_diff_spec_alpha = c.postnet_diff_spec_alpha + self.decoder_alpha = c.decoder_loss_alpha + self.postnet_alpha = c.postnet_loss_alpha + self.decoder_ssim_alpha = c.decoder_ssim_alpha + self.postnet_ssim_alpha = c.postnet_ssim_alpha + self.config = c 
+ + # postnet and decoder loss + if c.loss_masking: + self.criterion = L1LossMasked(c.seq_len_norm) if c.model in ["Tacotron"] else MSELossMasked(c.seq_len_norm) + else: + self.criterion = nn.L1Loss() if c.model in ["Tacotron"] else nn.MSELoss() + # guided attention loss + if c.ga_alpha > 0: + self.criterion_ga = GuidedAttentionLoss(sigma=ga_sigma) + # differential spectral loss + if c.postnet_diff_spec_alpha > 0 or c.decoder_diff_spec_alpha > 0: + self.criterion_diff_spec = DifferentailSpectralLoss(loss_func=self.criterion) + # ssim loss + if c.postnet_ssim_alpha > 0 or c.decoder_ssim_alpha > 0: + self.criterion_ssim = SSIMLoss() + # stopnet loss + # pylint: disable=not-callable + self.criterion_st = BCELossMasked(pos_weight=torch.tensor(self.stopnet_pos_weight)) if c.stopnet else None + + # For dev pruposes only + self.criterion_capacitron_reconstruction_loss = nn.L1Loss(reduction="sum") + + def forward( + self, + postnet_output, + decoder_output, + mel_input, + linear_input, + stopnet_output, + stopnet_target, + stop_target_length, + capacitron_vae_outputs, + output_lens, + decoder_b_output, + alignments, + alignment_lens, + alignments_backwards, + input_lens, + ): + + # decoder outputs linear or mel spectrograms for Tacotron and Tacotron2 + # the target should be set acccordingly + postnet_target = linear_input if self.config.model.lower() in ["tacotron"] else mel_input + + return_dict = {} + # remove lengths if no masking is applied + if not self.config.loss_masking: + output_lens = None + # decoder and postnet losses + if self.config.loss_masking: + if self.decoder_alpha > 0: + decoder_loss = self.criterion(decoder_output, mel_input, output_lens) + if self.postnet_alpha > 0: + postnet_loss = self.criterion(postnet_output, postnet_target, output_lens) + else: + if self.decoder_alpha > 0: + decoder_loss = self.criterion(decoder_output, mel_input) + if self.postnet_alpha > 0: + postnet_loss = self.criterion(postnet_output, postnet_target) + loss = self.decoder_alpha * decoder_loss + self.postnet_alpha * postnet_loss + return_dict["decoder_loss"] = decoder_loss + return_dict["postnet_loss"] = postnet_loss + + if self.use_capacitron_vae: + # extract capacitron vae infos + posterior_distribution, prior_distribution, beta = capacitron_vae_outputs + + # KL divergence term between the posterior and the prior + kl_term = torch.mean(torch.distributions.kl_divergence(posterior_distribution, prior_distribution)) + + # Limit the mutual information between the data and latent space by the variational capacity limit + kl_capacity = kl_term - self.capacitron_capacity + + # pass beta through softplus to keep it positive + beta = torch.nn.functional.softplus(beta)[0] + + # This is the term going to the main ADAM optimiser, we detach beta because + # beta is optimised by a separate, SGD optimiser below + capacitron_vae_loss = beta.detach() * kl_capacity + + # normalize the capacitron_vae_loss as in L1Loss or MSELoss. + # After this, both the standard loss and capacitron_vae_loss will be in the same scale. + # For this reason we don't need use L1Loss and MSELoss in "sum" reduction mode. + # Note: the batch is not considered because the L1Loss was calculated in "sum" mode + # divided by the batch size, So not dividing the capacitron_vae_loss by B is legitimate. 
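+
+            # Illustrative summary of the terms computed below (not from the original comments):
+            #   capacitron_vae_loss = capacitron_vae_loss_alpha * beta.detach() * (KL - capacity) / (T * D)
+            #   beta_loss           = -beta * (KL - capacity).detach()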
+ + # get B T D dimension from input + B, T, D = mel_input.size() + # normalize + if self.config.loss_masking: + # if mask loss get T using the mask + T = output_lens.sum() / B + + # Only for dev purposes to be able to compare the reconstruction loss with the values in the + # original Capacitron paper + return_dict["capaciton_reconstruction_loss"] = ( + self.criterion_capacitron_reconstruction_loss(decoder_output, mel_input) / decoder_output.size(0) + ) + kl_capacity + + capacitron_vae_loss = capacitron_vae_loss / (T * D) + capacitron_vae_loss = capacitron_vae_loss * self.capacitron_vae_loss_alpha + + # This is the term to purely optimise beta and to pass into the SGD optimizer + beta_loss = torch.negative(beta) * kl_capacity.detach() + + loss += capacitron_vae_loss + + return_dict["capacitron_vae_loss"] = capacitron_vae_loss + return_dict["capacitron_vae_beta_loss"] = beta_loss + return_dict["capacitron_vae_kl_term"] = kl_term + return_dict["capacitron_beta"] = beta + + stop_loss = ( + self.criterion_st(stopnet_output, stopnet_target, stop_target_length) + if self.config.stopnet + else torch.zeros(1) + ) + loss += stop_loss + return_dict["stopnet_loss"] = stop_loss + + # backward decoder loss (if enabled) + if self.config.bidirectional_decoder: + if self.config.loss_masking: + decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1,)), mel_input, output_lens) + else: + decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1,)), mel_input) + decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_b_output, dims=(1,)), decoder_output) + loss += self.decoder_alpha * (decoder_b_loss + decoder_c_loss) + return_dict["decoder_b_loss"] = decoder_b_loss + return_dict["decoder_c_loss"] = decoder_c_loss + + # double decoder consistency loss (if enabled) + if self.config.double_decoder_consistency: + if self.config.loss_masking: + decoder_b_loss = self.criterion(decoder_b_output, mel_input, output_lens) + else: + decoder_b_loss = self.criterion(decoder_b_output, mel_input) + # decoder_c_loss = torch.nn.functional.l1_loss(decoder_b_output, decoder_output) + attention_c_loss = torch.nn.functional.l1_loss(alignments, alignments_backwards) + loss += self.decoder_alpha * (decoder_b_loss + attention_c_loss) + return_dict["decoder_coarse_loss"] = decoder_b_loss + return_dict["decoder_ddc_loss"] = attention_c_loss + + # guided attention loss (if enabled) + if self.config.ga_alpha > 0: + ga_loss = self.criterion_ga(alignments, input_lens, alignment_lens) + loss += ga_loss * self.ga_alpha + return_dict["ga_loss"] = ga_loss + + # decoder differential spectral loss + if self.config.decoder_diff_spec_alpha > 0: + decoder_diff_spec_loss = self.criterion_diff_spec(decoder_output, mel_input, output_lens) + loss += decoder_diff_spec_loss * self.decoder_diff_spec_alpha + return_dict["decoder_diff_spec_loss"] = decoder_diff_spec_loss + + # postnet differential spectral loss + if self.config.postnet_diff_spec_alpha > 0: + postnet_diff_spec_loss = self.criterion_diff_spec(postnet_output, postnet_target, output_lens) + loss += postnet_diff_spec_loss * self.postnet_diff_spec_alpha + return_dict["postnet_diff_spec_loss"] = postnet_diff_spec_loss + + # decoder ssim loss + if self.config.decoder_ssim_alpha > 0: + decoder_ssim_loss = self.criterion_ssim(decoder_output, mel_input, output_lens) + loss += decoder_ssim_loss * self.postnet_ssim_alpha + return_dict["decoder_ssim_loss"] = decoder_ssim_loss + + # postnet ssim loss + if self.config.postnet_ssim_alpha > 0: + postnet_ssim_loss = 
self.criterion_ssim(postnet_output, postnet_target, output_lens) + loss += postnet_ssim_loss * self.postnet_ssim_alpha + return_dict["postnet_ssim_loss"] = postnet_ssim_loss + + return_dict["loss"] = loss + return return_dict + + +class GlowTTSLoss(torch.nn.Module): + def __init__(self): + super().__init__() + self.constant_factor = 0.5 * math.log(2 * math.pi) + + def forward(self, z, means, scales, log_det, y_lengths, o_dur_log, o_attn_dur, x_lengths): + return_dict = {} + # flow loss - neg log likelihood + pz = torch.sum(scales) + 0.5 * torch.sum(torch.exp(-2 * scales) * (z - means) ** 2) + log_mle = self.constant_factor + (pz - torch.sum(log_det)) / (torch.sum(y_lengths) * z.shape[2]) + # duration loss - MSE + loss_dur = torch.sum((o_dur_log - o_attn_dur) ** 2) / torch.sum(x_lengths) + # duration loss - huber loss + # loss_dur = torch.nn.functional.smooth_l1_loss(o_dur_log, o_attn_dur, reduction="sum") / torch.sum(x_lengths) + return_dict["loss"] = log_mle + loss_dur + return_dict["log_mle"] = log_mle + return_dict["loss_dur"] = loss_dur + + # check if any loss is NaN + for key, loss in return_dict.items(): + if torch.isnan(loss): + raise RuntimeError(f" [!] NaN loss with {key}.") + return return_dict + + +def mse_loss_custom(x, y): + """MSE loss using the torch back-end without reduction. + It uses less VRAM than the raw code""" + expanded_x, expanded_y = torch.broadcast_tensors(x, y) + return torch._C._nn.mse_loss(expanded_x, expanded_y, 0) # pylint: disable=protected-access, c-extension-no-member + + +class MDNLoss(nn.Module): + """Mixture of Density Network Loss as described in https://arxiv.org/pdf/2003.01950.pdf.""" + + def forward(self, logp, text_lengths, mel_lengths): # pylint: disable=no-self-use + """ + Shapes: + mu: [B, D, T] + log_sigma: [B, D, T] + mel_spec: [B, D, T] + """ + B, T_seq, T_mel = logp.shape + log_alpha = logp.new_ones(B, T_seq, T_mel) * (-1e4) + log_alpha[:, 0, 0] = logp[:, 0, 0] + for t in range(1, T_mel): + prev_step = torch.cat( + [log_alpha[:, :, t - 1 : t], functional.pad(log_alpha[:, :, t - 1 : t], (0, 0, 1, -1), value=-1e4)], + dim=-1, + ) + log_alpha[:, :, t] = torch.logsumexp(prev_step + 1e-4, dim=-1) + logp[:, :, t] + alpha_last = log_alpha[torch.arange(B), text_lengths - 1, mel_lengths - 1] + mdn_loss = -alpha_last.mean() / T_seq + return mdn_loss # , log_prob_matrix + + +class AlignTTSLoss(nn.Module): + """Modified AlignTTS Loss. + Computes + - L1 and SSIM losses from output spectrograms. + - Huber loss for duration predictor. + - MDNLoss for Mixture of Density Network. + + All loss values are aggregated by a weighted sum of the alpha values. + + Args: + c (dict): TTS model configuration. 
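+
+        Note:
+            Which terms are active depends on the ``phase`` argument of ``forward`` (summarised
+            from the implementation below): phase 0 uses only the MDN loss, phase 1 the spectral
+            and SSIM losses, phase 2 adds the MDN loss to them, phase 3 trains only the duration
+            predictor, and any later phase enables all terms.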
+ """ + + def __init__(self, c): + super().__init__() + self.mdn_loss = MDNLoss() + self.spec_loss = MSELossMasked(False) + self.ssim = SSIMLoss() + self.dur_loss = MSELossMasked(False) + + self.ssim_alpha = c.ssim_alpha + self.dur_loss_alpha = c.dur_loss_alpha + self.spec_loss_alpha = c.spec_loss_alpha + self.mdn_alpha = c.mdn_alpha + + def forward( + self, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens, phase + ): + # ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(step) + spec_loss, ssim_loss, dur_loss, mdn_loss = 0, 0, 0, 0 + if phase == 0: + mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens) + elif phase == 1: + spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens) + ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) + elif phase == 2: + mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens) + spec_loss = self.spec_lossX(decoder_output, decoder_target, decoder_output_lens) + ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) + elif phase == 3: + dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens) + else: + mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens) + spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens) + ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) + dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens) + loss = ( + self.spec_loss_alpha * spec_loss + + self.ssim_alpha * ssim_loss + + self.dur_loss_alpha * dur_loss + + self.mdn_alpha * mdn_loss + ) + return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss} + + +class VitsGeneratorLoss(nn.Module): + def __init__(self, c: Coqpit): + super().__init__() + self.kl_loss_alpha = c.kl_loss_alpha + self.gen_loss_alpha = c.gen_loss_alpha + self.feat_loss_alpha = c.feat_loss_alpha + self.dur_loss_alpha = c.dur_loss_alpha + self.mel_loss_alpha = c.mel_loss_alpha + self.spk_encoder_loss_alpha = c.speaker_encoder_loss_alpha + self.stft = TorchSTFT( + c.audio.fft_size, + c.audio.hop_length, + c.audio.win_length, + sample_rate=c.audio.sample_rate, + mel_fmin=c.audio.mel_fmin, + mel_fmax=c.audio.mel_fmax, + n_mels=c.audio.num_mels, + use_mel=True, + do_amp_to_db=True, + ) + + @staticmethod + def feature_loss(feats_real, feats_generated): + loss = 0 + for dr, dg in zip(feats_real, feats_generated): + for rl, gl in zip(dr, dg): + rl = rl.float().detach() + gl = gl.float() + loss += torch.mean(torch.abs(rl - gl)) + return loss * 2 + + @staticmethod + def generator_loss(scores_fake): + loss = 0 + gen_losses = [] + for dg in scores_fake: + dg = dg.float() + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + @staticmethod + def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): + """ + z_p, logs_q: [b, h, t_t] + m_p, logs_p: [b, h, t_t] + """ + z_p = z_p.float() + logs_q = logs_q.float() + m_p = m_p.float() + logs_p = logs_p.float() + z_mask = z_mask.float() + + kl = logs_p - logs_q - 0.5 + kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) + kl = torch.sum(kl * z_mask) + l = kl / torch.sum(z_mask) + return l + + @staticmethod + def cosine_similarity_loss(gt_spk_emb, syn_spk_emb): + return -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() + + def forward( + self, + mel_slice, + mel_slice_hat, + z_p, + logs_q, + m_p, + logs_p, + 
z_len, + scores_disc_fake, + feats_disc_fake, + feats_disc_real, + loss_duration, + use_speaker_encoder_as_loss=False, + gt_spk_emb=None, + syn_spk_emb=None, + ): + """ + Shapes: + - mel_slice : :math:`[B, 1, T]` + - mel_slice_hat: :math:`[B, 1, T]` + - z_p: :math:`[B, C, T]` + - logs_q: :math:`[B, C, T]` + - m_p: :math:`[B, C, T]` + - logs_p: :math:`[B, C, T]` + - z_len: :math:`[B]` + - scores_disc_fake[i]: :math:`[B, C]` + - feats_disc_fake[i][j]: :math:`[B, C, T', P]` + - feats_disc_real[i][j]: :math:`[B, C, T', P]` + """ + loss = 0.0 + return_dict = {} + z_mask = sequence_mask(z_len).float() + # compute losses + loss_kl = ( + self.kl_loss(z_p=z_p, logs_q=logs_q, m_p=m_p, logs_p=logs_p, z_mask=z_mask.unsqueeze(1)) + * self.kl_loss_alpha + ) + loss_feat = ( + self.feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake) * self.feat_loss_alpha + ) + loss_gen = self.generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha + loss_mel = torch.nn.functional.l1_loss(mel_slice, mel_slice_hat) * self.mel_loss_alpha + loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha + loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration + + if use_speaker_encoder_as_loss: + loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha + loss = loss + loss_se + return_dict["loss_spk_encoder"] = loss_se + # pass losses to the dict + return_dict["loss_gen"] = loss_gen + return_dict["loss_kl"] = loss_kl + return_dict["loss_feat"] = loss_feat + return_dict["loss_mel"] = loss_mel + return_dict["loss_duration"] = loss_duration + return_dict["loss"] = loss + return return_dict + + +class VitsDiscriminatorLoss(nn.Module): + def __init__(self, c: Coqpit): + super().__init__() + self.disc_loss_alpha = c.disc_loss_alpha + + @staticmethod + def discriminator_loss(scores_real, scores_fake): + loss = 0 + real_losses = [] + fake_losses = [] + for dr, dg in zip(scores_real, scores_fake): + dr = dr.float() + dg = dg.float() + real_loss = torch.mean((1 - dr) ** 2) + fake_loss = torch.mean(dg**2) + loss += real_loss + fake_loss + real_losses.append(real_loss.item()) + fake_losses.append(fake_loss.item()) + return loss, real_losses, fake_losses + + def forward(self, scores_disc_real, scores_disc_fake): + loss = 0.0 + return_dict = {} + loss_disc, loss_disc_real, _ = self.discriminator_loss( + scores_real=scores_disc_real, scores_fake=scores_disc_fake + ) + return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha + loss = loss + return_dict["loss_disc"] + return_dict["loss"] = loss + + for i, ldr in enumerate(loss_disc_real): + return_dict[f"loss_disc_real_{i}"] = ldr + return return_dict + + +class ForwardTTSLoss(nn.Module): + """Generic configurable ForwardTTS loss.""" + + def __init__(self, c): + super().__init__() + if c.spec_loss_type == "mse": + self.spec_loss = MSELossMasked(False) + elif c.spec_loss_type == "l1": + self.spec_loss = L1LossMasked(False) + else: + raise ValueError(" [!] Unknown spec_loss_type {}".format(c.spec_loss_type)) + + if c.duration_loss_type == "mse": + self.dur_loss = MSELossMasked(False) + elif c.duration_loss_type == "l1": + self.dur_loss = L1LossMasked(False) + elif c.duration_loss_type == "huber": + self.dur_loss = Huber() + else: + raise ValueError(" [!] 
Unknown duration_loss_type {}".format(c.duration_loss_type)) + + if c.model_args.use_aligner: + self.aligner_loss = ForwardSumLoss() + self.aligner_loss_alpha = c.aligner_loss_alpha + + if c.model_args.use_pitch: + self.pitch_loss = MSELossMasked(False) + self.pitch_loss_alpha = c.pitch_loss_alpha + + if c.use_ssim_loss: + self.ssim = SSIMLoss() if c.use_ssim_loss else None + self.ssim_loss_alpha = c.ssim_loss_alpha + + self.spec_loss_alpha = c.spec_loss_alpha + self.dur_loss_alpha = c.dur_loss_alpha + self.binary_alignment_loss_alpha = c.binary_align_loss_alpha + self.spk_encoder_loss_alpha = c.spk_encoder_loss_alpha + + @staticmethod + def _binary_alignment_loss(alignment_hard, alignment_soft): + """Binary loss that forces soft alignments to match the hard alignments as + explained in `https://arxiv.org/pdf/2108.10447.pdf`. + """ + log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum() + return -log_sum / alignment_hard.sum() + + @staticmethod + def cosine_similarity_loss(gt_spk_emb, syn_spk_emb): + return -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() + + def forward( + self, + decoder_output, + decoder_target, + decoder_output_lens, + dur_output, + dur_target, + pitch_output, + pitch_target, + input_lens, + alignment_logprob=None, + alignment_hard=None, + alignment_soft=None, + binary_loss_weight=None, + train_aligner=True, + use_speaker_encoder_as_loss=False, + gt_spk_emb=None, + syn_spk_emb=None + ): + loss = 0 + return_dict = {} + if hasattr(self, "ssim") and self.ssim_loss_alpha > 0: + ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) + loss = loss + self.ssim_loss_alpha * ssim_loss + return_dict["loss_ssim"] = self.ssim_loss_alpha * ssim_loss + + if self.spec_loss_alpha > 0: + spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens) + loss = loss + self.spec_loss_alpha * spec_loss + return_dict["loss_spec"] = self.spec_loss_alpha * spec_loss + + if self.dur_loss_alpha > 0: + log_dur_tgt = torch.log(dur_target.float() + 1) + dur_loss = self.dur_loss(dur_output[:, :, None], log_dur_tgt[:, :, None], input_lens) + loss = loss + self.dur_loss_alpha * dur_loss + return_dict["loss_dur"] = self.dur_loss_alpha * dur_loss + + if hasattr(self, "pitch_loss") and self.pitch_loss_alpha > 0: + pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens) + loss = loss + self.pitch_loss_alpha * pitch_loss + return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss + + if train_aligner: + if hasattr(self, "aligner_loss") and self.aligner_loss_alpha > 0: + aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens) + loss = loss + self.aligner_loss_alpha * aligner_loss + return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss + + if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None: + binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft) + loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss + if binary_loss_weight: + return_dict["loss_binary_alignment"] = ( + self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight + ) + else: + return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss + + + if use_speaker_encoder_as_loss: + loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha + loss = loss + loss_se + return_dict["loss_spk_encoder"] = loss_se + + 
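+
+        # `loss` is the weighted sum of every enabled term; the individual entries stored in
+        # `return_dict` are already scaled by their respective alpha weights.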
return_dict["loss"] = loss + return return_dict diff --git a/TTS/tts/layers/tacotron/__init__.py b/TTS/tts/layers/tacotron/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/layers/tacotron/attentions.py b/TTS/tts/layers/tacotron/attentions.py new file mode 100644 index 0000000000000000000000000000000000000000..d8a90d72010066c1e3e09fd195c25954282e7526 --- /dev/null +++ b/TTS/tts/layers/tacotron/attentions.py @@ -0,0 +1,487 @@ +import torch +from scipy.stats import betabinom +from torch import nn +from torch.nn import functional as F + +from TTS.tts.layers.tacotron.common_layers import Linear + + +class LocationLayer(nn.Module): + """Layers for Location Sensitive Attention + + Args: + attention_dim (int): number of channels in the input tensor. + attention_n_filters (int, optional): number of filters in convolution. Defaults to 32. + attention_kernel_size (int, optional): kernel size of convolution filter. Defaults to 31. + """ + + def __init__(self, attention_dim, attention_n_filters=32, attention_kernel_size=31): + super().__init__() + self.location_conv1d = nn.Conv1d( + in_channels=2, + out_channels=attention_n_filters, + kernel_size=attention_kernel_size, + stride=1, + padding=(attention_kernel_size - 1) // 2, + bias=False, + ) + self.location_dense = Linear(attention_n_filters, attention_dim, bias=False, init_gain="tanh") + + def forward(self, attention_cat): + """ + Shapes: + attention_cat: [B, 2, C] + """ + processed_attention = self.location_conv1d(attention_cat) + processed_attention = self.location_dense(processed_attention.transpose(1, 2)) + return processed_attention + + +class GravesAttention(nn.Module): + """Graves Attention as is ref1 with updates from ref2. + ref1: https://arxiv.org/abs/1910.10288 + ref2: https://arxiv.org/pdf/1906.01083.pdf + + Args: + query_dim (int): number of channels in query tensor. + K (int): number of Gaussian heads to be used for computing attention. 
+ """ + + COEF = 0.3989422917366028 # numpy.sqrt(1/(2*numpy.pi)) + + def __init__(self, query_dim, K): + + super().__init__() + self._mask_value = 1e-8 + self.K = K + # self.attention_alignment = 0.05 + self.eps = 1e-5 + self.J = None + self.N_a = nn.Sequential( + nn.Linear(query_dim, query_dim, bias=True), nn.ReLU(), nn.Linear(query_dim, 3 * K, bias=True) + ) + self.attention_weights = None + self.mu_prev = None + self.init_layers() + + def init_layers(self): + torch.nn.init.constant_(self.N_a[2].bias[(2 * self.K) : (3 * self.K)], 1.0) # bias mean + torch.nn.init.constant_(self.N_a[2].bias[self.K : (2 * self.K)], 10) # bias std + + def init_states(self, inputs): + if self.J is None or inputs.shape[1] + 1 > self.J.shape[-1]: + self.J = torch.arange(0, inputs.shape[1] + 2.0).to(inputs.device) + 0.5 + self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device) + self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device) + + # pylint: disable=R0201 + # pylint: disable=unused-argument + def preprocess_inputs(self, inputs): + return None + + def forward(self, query, inputs, processed_inputs, mask): + """ + Shapes: + query: [B, C_attention_rnn] + inputs: [B, T_in, C_encoder] + processed_inputs: place_holder + mask: [B, T_in] + """ + gbk_t = self.N_a(query) + gbk_t = gbk_t.view(gbk_t.size(0), -1, self.K) + + # attention model parameters + # each B x K + g_t = gbk_t[:, 0, :] + b_t = gbk_t[:, 1, :] + k_t = gbk_t[:, 2, :] + + # dropout to decorrelate attention heads + g_t = torch.nn.functional.dropout(g_t, p=0.5, training=self.training) + + # attention GMM parameters + sig_t = torch.nn.functional.softplus(b_t) + self.eps + + mu_t = self.mu_prev + torch.nn.functional.softplus(k_t) + g_t = torch.softmax(g_t, dim=-1) + self.eps + + j = self.J[: inputs.size(1) + 1] + + # attention weights + phi_t = g_t.unsqueeze(-1) * (1 / (1 + torch.sigmoid((mu_t.unsqueeze(-1) - j) / sig_t.unsqueeze(-1)))) + + # discritize attention weights + alpha_t = torch.sum(phi_t, 1) + alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1] + alpha_t[alpha_t == 0] = 1e-8 + + # apply masking + if mask is not None: + alpha_t.data.masked_fill_(~mask, self._mask_value) + + context = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1) + self.attention_weights = alpha_t + self.mu_prev = mu_t + return context + + +class OriginalAttention(nn.Module): + """Bahdanau Attention with various optional modifications. + - Location sensitive attnetion: https://arxiv.org/abs/1712.05884 + - Forward Attention: https://arxiv.org/abs/1807.06736 + state masking at inference + - Using sigmoid instead of softmax normalization + - Attention windowing at inference time + + Note: + Location Sensitive Attention extends the additive attention mechanism + to use cumulative attention weights from previous decoder time steps with the current time step features. + + Forward attention computes most probable monotonic alignment. The modified attention probabilities at each + timestep are computed recursively by the forward algorithm. + + Transition agent in the forward attention explicitly gates the attention mechanism whether to move forward or + stay at each decoder timestep. + + Attention windowing is a inductive prior that prevents the model from attending to previous and future timesteps + beyond a certain window. + + Args: + query_dim (int): number of channels in the query tensor. + embedding_dim (int): number of channels in the vakue tensor. In general, the value tensor is the output of the encoder layer. 
+ attention_dim (int): number of channels of the inner attention layers. + location_attention (bool): enable/disable location sensitive attention. + attention_location_n_filters (int): number of location attention filters. + attention_location_kernel_size (int): filter size of location attention convolution layer. + windowing (int): window size for attention windowing. if it is 5, for computing the attention, it only considers the time steps [(t-5), ..., (t+5)] of the input. + norm (str): normalization method applied to the attention weights. 'softmax' or 'sigmoid' + forward_attn (bool): enable/disable forward attention. + trans_agent (bool): enable/disable transition agent in the forward attention. + forward_attn_mask (int): enable/disable an explicit masking in forward attention. It is useful to set at especially inference time. + """ + + # Pylint gets confused by PyTorch conventions here + # pylint: disable=attribute-defined-outside-init + def __init__( + self, + query_dim, + embedding_dim, + attention_dim, + location_attention, + attention_location_n_filters, + attention_location_kernel_size, + windowing, + norm, + forward_attn, + trans_agent, + forward_attn_mask, + ): + super().__init__() + self.query_layer = Linear(query_dim, attention_dim, bias=False, init_gain="tanh") + self.inputs_layer = Linear(embedding_dim, attention_dim, bias=False, init_gain="tanh") + self.v = Linear(attention_dim, 1, bias=True) + if trans_agent: + self.ta = nn.Linear(query_dim + embedding_dim, 1, bias=True) + if location_attention: + self.location_layer = LocationLayer( + attention_dim, + attention_location_n_filters, + attention_location_kernel_size, + ) + self._mask_value = -float("inf") + self.windowing = windowing + self.win_idx = None + self.norm = norm + self.forward_attn = forward_attn + self.trans_agent = trans_agent + self.forward_attn_mask = forward_attn_mask + self.location_attention = location_attention + + def init_win_idx(self): + self.win_idx = -1 + self.win_back = 2 + self.win_front = 6 + + def init_forward_attn(self, inputs): + B = inputs.shape[0] + T = inputs.shape[1] + self.alpha = torch.cat([torch.ones([B, 1]), torch.zeros([B, T])[:, :-1] + 1e-7], dim=1).to(inputs.device) + self.u = (0.5 * torch.ones([B, 1])).to(inputs.device) + + def init_location_attention(self, inputs): + B = inputs.size(0) + T = inputs.size(1) + self.attention_weights_cum = torch.zeros([B, T], device=inputs.device) + + def init_states(self, inputs): + B = inputs.size(0) + T = inputs.size(1) + self.attention_weights = torch.zeros([B, T], device=inputs.device) + if self.location_attention: + self.init_location_attention(inputs) + if self.forward_attn: + self.init_forward_attn(inputs) + if self.windowing: + self.init_win_idx() + + def preprocess_inputs(self, inputs): + return self.inputs_layer(inputs) + + def update_location_attention(self, alignments): + self.attention_weights_cum += alignments + + def get_location_attention(self, query, processed_inputs): + attention_cat = torch.cat((self.attention_weights.unsqueeze(1), self.attention_weights_cum.unsqueeze(1)), dim=1) + processed_query = self.query_layer(query.unsqueeze(1)) + processed_attention_weights = self.location_layer(attention_cat) + energies = self.v(torch.tanh(processed_query + processed_attention_weights + processed_inputs)) + energies = energies.squeeze(-1) + return energies, processed_query + + def get_attention(self, query, processed_inputs): + processed_query = self.query_layer(query.unsqueeze(1)) + energies = self.v(torch.tanh(processed_query + 
processed_inputs)) + energies = energies.squeeze(-1) + return energies, processed_query + + def apply_windowing(self, attention, inputs): + back_win = self.win_idx - self.win_back + front_win = self.win_idx + self.win_front + if back_win > 0: + attention[:, :back_win] = -float("inf") + if front_win < inputs.shape[1]: + attention[:, front_win:] = -float("inf") + # this is a trick to solve a special problem. + # but it does not hurt. + if self.win_idx == -1: + attention[:, 0] = attention.max() + # Update the window + self.win_idx = torch.argmax(attention, 1).long()[0].item() + return attention + + def apply_forward_attention(self, alignment): + # forward attention + fwd_shifted_alpha = F.pad(self.alpha[:, :-1].clone().to(alignment.device), (1, 0, 0, 0)) + # compute transition potentials + alpha = ((1 - self.u) * self.alpha + self.u * fwd_shifted_alpha + 1e-8) * alignment + # force incremental alignment + if not self.training and self.forward_attn_mask: + _, n = fwd_shifted_alpha.max(1) + val, _ = alpha.max(1) + for b in range(alignment.shape[0]): + alpha[b, n[b] + 3 :] = 0 + alpha[b, : (n[b] - 1)] = 0 # ignore all previous states to prevent repetition. + alpha[b, (n[b] - 2)] = 0.01 * val[b] # smoothing factor for the prev step + # renormalize attention weights + alpha = alpha / alpha.sum(dim=1, keepdim=True) + return alpha + + def forward(self, query, inputs, processed_inputs, mask): + """ + shapes: + query: [B, C_attn_rnn] + inputs: [B, T_en, D_en] + processed_inputs: [B, T_en, D_attn] + mask: [B, T_en] + """ + if self.location_attention: + attention, _ = self.get_location_attention(query, processed_inputs) + else: + attention, _ = self.get_attention(query, processed_inputs) + # apply masking + if mask is not None: + attention.data.masked_fill_(~mask, self._mask_value) + # apply windowing - only in eval mode + if not self.training and self.windowing: + attention = self.apply_windowing(attention, inputs) + + # normalize attention values + if self.norm == "softmax": + alignment = torch.softmax(attention, dim=-1) + elif self.norm == "sigmoid": + alignment = torch.sigmoid(attention) / torch.sigmoid(attention).sum(dim=1, keepdim=True) + else: + raise ValueError("Unknown value for attention norm type") + + if self.location_attention: + self.update_location_attention(alignment) + + # apply forward attention if enabled + if self.forward_attn: + alignment = self.apply_forward_attention(alignment) + self.alpha = alignment + + context = torch.bmm(alignment.unsqueeze(1), inputs) + context = context.squeeze(1) + self.attention_weights = alignment + + # compute transition agent + if self.forward_attn and self.trans_agent: + ta_input = torch.cat([context, query.squeeze(1)], dim=-1) + self.u = torch.sigmoid(self.ta(ta_input)) + return context + + +class MonotonicDynamicConvolutionAttention(nn.Module): + """Dynamic convolution attention from + https://arxiv.org/pdf/1910.10288.pdf + + + query -> linear -> tanh -> linear ->| + | mask values + v | | + atten_w(t-1) -|-> conv1d_dynamic -> linear -|-> tanh -> + -> softmax -> * -> * -> context + |-> conv1d_static -> linear -| | + |-> conv1d_prior -> log ----------------| + + query: attention rnn output. + + Note: + Dynamic convolution attention is an alternation of the location senstive attention with + dynamically computed convolution filters from the previous attention scores and a set of + constraints to keep the attention alignment diagonal. + DCA is sensitive to mixed precision training and might cause instable training. 
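+
+        The diagonal prior is a beta-binomial distribution of length ``prior_filter_len`` with
+        parameters ``alpha`` and ``beta``; at every decoder step the previous attention weights
+        are convolved with this prior and the log of the result is added to the alignment
+        energies, biasing the alignment to keep moving forward (summarised from the
+        implementation below).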
+ + Args: + query_dim (int): number of channels in the query tensor. + embedding_dim (int): number of channels in the value tensor. + static_filter_dim (int): number of channels in the convolution layer computing the static filters. + static_kernel_size (int): kernel size for the convolution layer computing the static filters. + dynamic_filter_dim (int): number of channels in the convolution layer computing the dynamic filters. + dynamic_kernel_size (int): kernel size for the convolution layer computing the dynamic filters. + prior_filter_len (int, optional): [description]. Defaults to 11 from the paper. + alpha (float, optional): [description]. Defaults to 0.1 from the paper. + beta (float, optional): [description]. Defaults to 0.9 from the paper. + """ + + def __init__( + self, + query_dim, + embedding_dim, # pylint: disable=unused-argument + attention_dim, + static_filter_dim, + static_kernel_size, + dynamic_filter_dim, + dynamic_kernel_size, + prior_filter_len=11, + alpha=0.1, + beta=0.9, + ): + super().__init__() + self._mask_value = 1e-8 + self.dynamic_filter_dim = dynamic_filter_dim + self.dynamic_kernel_size = dynamic_kernel_size + self.prior_filter_len = prior_filter_len + self.attention_weights = None + # setup key and query layers + self.query_layer = nn.Linear(query_dim, attention_dim) + self.key_layer = nn.Linear(attention_dim, dynamic_filter_dim * dynamic_kernel_size, bias=False) + self.static_filter_conv = nn.Conv1d( + 1, + static_filter_dim, + static_kernel_size, + padding=(static_kernel_size - 1) // 2, + bias=False, + ) + self.static_filter_layer = nn.Linear(static_filter_dim, attention_dim, bias=False) + self.dynamic_filter_layer = nn.Linear(dynamic_filter_dim, attention_dim) + self.v = nn.Linear(attention_dim, 1, bias=False) + + prior = betabinom.pmf(range(prior_filter_len), prior_filter_len - 1, alpha, beta) + self.register_buffer("prior", torch.FloatTensor(prior).flip(0)) + + # pylint: disable=unused-argument + def forward(self, query, inputs, processed_inputs, mask): + """ + query: [B, C_attn_rnn] + inputs: [B, T_en, D_en] + processed_inputs: place holder. 
+ mask: [B, T_en] + """ + # compute prior filters + prior_filter = F.conv1d( + F.pad(self.attention_weights.unsqueeze(1), (self.prior_filter_len - 1, 0)), self.prior.view(1, 1, -1) + ) + prior_filter = torch.log(prior_filter.clamp_min_(1e-6)).squeeze(1) + G = self.key_layer(torch.tanh(self.query_layer(query))) + # compute dynamic filters + dynamic_filter = F.conv1d( + self.attention_weights.unsqueeze(0), + G.view(-1, 1, self.dynamic_kernel_size), + padding=(self.dynamic_kernel_size - 1) // 2, + groups=query.size(0), + ) + dynamic_filter = dynamic_filter.view(query.size(0), self.dynamic_filter_dim, -1).transpose(1, 2) + # compute static filters + static_filter = self.static_filter_conv(self.attention_weights.unsqueeze(1)).transpose(1, 2) + alignment = ( + self.v( + torch.tanh(self.static_filter_layer(static_filter) + self.dynamic_filter_layer(dynamic_filter)) + ).squeeze(-1) + + prior_filter + ) + # compute attention weights + attention_weights = F.softmax(alignment, dim=-1) + # apply masking + if mask is not None: + attention_weights.data.masked_fill_(~mask, self._mask_value) + self.attention_weights = attention_weights + # compute context + context = torch.bmm(attention_weights.unsqueeze(1), inputs).squeeze(1) + return context + + def preprocess_inputs(self, inputs): # pylint: disable=no-self-use + return None + + def init_states(self, inputs): + B = inputs.size(0) + T = inputs.size(1) + self.attention_weights = torch.zeros([B, T], device=inputs.device) + self.attention_weights[:, 0] = 1.0 + + +def init_attn( + attn_type, + query_dim, + embedding_dim, + attention_dim, + location_attention, + attention_location_n_filters, + attention_location_kernel_size, + windowing, + norm, + forward_attn, + trans_agent, + forward_attn_mask, + attn_K, +): + if attn_type == "original": + return OriginalAttention( + query_dim, + embedding_dim, + attention_dim, + location_attention, + attention_location_n_filters, + attention_location_kernel_size, + windowing, + norm, + forward_attn, + trans_agent, + forward_attn_mask, + ) + if attn_type == "graves": + return GravesAttention(query_dim, attn_K) + if attn_type == "dynamic_convolution": + return MonotonicDynamicConvolutionAttention( + query_dim, + embedding_dim, + attention_dim, + static_filter_dim=8, + static_kernel_size=21, + dynamic_filter_dim=8, + dynamic_kernel_size=21, + prior_filter_len=11, + alpha=0.1, + beta=0.9, + ) + + raise RuntimeError(f" [!] Given Attention Type '{attn_type}' is not exist.") diff --git a/TTS/tts/layers/tacotron/capacitron_layers.py b/TTS/tts/layers/tacotron/capacitron_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..56fe44bc333f17a4361f16f3cc75334a69c58ac6 --- /dev/null +++ b/TTS/tts/layers/tacotron/capacitron_layers.py @@ -0,0 +1,205 @@ +import torch +from torch import nn +from torch.distributions.multivariate_normal import MultivariateNormal as MVN +from torch.nn import functional as F + + +class CapacitronVAE(nn.Module): + """Effective Use of Variational Embedding Capacity for prosody transfer. 
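+
+    At training time it encodes a reference mel spectrogram (optionally concatenated with a
+    text summary and a speaker embedding) into the parameters of an approximate posterior and
+    samples the VAE embedding from it; at inference time, without a reference, the embedding is
+    sampled from the prior.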
+ + See https://arxiv.org/abs/1906.03402""" + + def __init__( + self, + num_mel, + capacitron_VAE_embedding_dim, + encoder_output_dim=256, + reference_encoder_out_dim=128, + speaker_embedding_dim=None, + text_summary_embedding_dim=None, + ): + super().__init__() + # Init distributions + self.prior_distribution = MVN( + torch.zeros(capacitron_VAE_embedding_dim), torch.eye(capacitron_VAE_embedding_dim) + ) + self.approximate_posterior_distribution = None + # define output ReferenceEncoder dim to the capacitron_VAE_embedding_dim + self.encoder = ReferenceEncoder(num_mel, out_dim=reference_encoder_out_dim) + + # Init beta, the lagrange-like term for the KL distribution + self.beta = torch.nn.Parameter(torch.log(torch.exp(torch.Tensor([1.0])) - 1), requires_grad=True) + mlp_input_dimension = reference_encoder_out_dim + + if text_summary_embedding_dim is not None: + self.text_summary_net = TextSummary(text_summary_embedding_dim, encoder_output_dim=encoder_output_dim) + mlp_input_dimension += text_summary_embedding_dim + if speaker_embedding_dim is not None: + # TODO: Test a multispeaker model! + mlp_input_dimension += speaker_embedding_dim + self.post_encoder_mlp = PostEncoderMLP(mlp_input_dimension, capacitron_VAE_embedding_dim) + + def forward(self, reference_mel_info=None, text_info=None, speaker_embedding=None): + # Use reference + if reference_mel_info is not None: + reference_mels = reference_mel_info[0] # [batch_size, num_frames, num_mels] + mel_lengths = reference_mel_info[1] # [batch_size] + enc_out = self.encoder(reference_mels, mel_lengths) + + # concat speaker_embedding and/or text summary embedding + if text_info is not None: + text_inputs = text_info[0] # [batch_size, num_characters, num_embedding] + input_lengths = text_info[1] + text_summary_out = self.text_summary_net(text_inputs, input_lengths).to(reference_mels.device) + enc_out = torch.cat([enc_out, text_summary_out], dim=-1) + if speaker_embedding is not None: + enc_out = torch.cat([enc_out, speaker_embedding], dim=-1) + + # Feed the output of the ref encoder and information about text/speaker into + # an MLP to produce the parameteres for the approximate poterior distributions + mu, sigma = self.post_encoder_mlp(enc_out) + # convert to cpu because prior_distribution was created on cpu + mu = mu.cpu() + sigma = sigma.cpu() + + # Sample from the posterior: z ~ q(z|x) + self.approximate_posterior_distribution = MVN(mu, torch.diag_embed(sigma)) + VAE_embedding = self.approximate_posterior_distribution.rsample() + # Infer from the model, bypasses encoding + else: + # Sample from the prior: z ~ p(z) + VAE_embedding = self.prior_distribution.sample().unsqueeze(0) + + # reshape to [batch_size, 1, capacitron_VAE_embedding_dim] + return VAE_embedding.unsqueeze(1), self.approximate_posterior_distribution, self.prior_distribution, self.beta + + +class ReferenceEncoder(nn.Module): + """NN module creating a fixed size prosody embedding from a spectrogram. 
+ + inputs: mel spectrograms [batch_size, num_spec_frames, num_mel] + outputs: [batch_size, embedding_dim] + """ + + def __init__(self, num_mel, out_dim): + + super().__init__() + self.num_mel = num_mel + filters = [1] + [32, 32, 64, 64, 128, 128] + num_layers = len(filters) - 1 + convs = [ + nn.Conv2d( + in_channels=filters[i], out_channels=filters[i + 1], kernel_size=(3, 3), stride=(2, 2), padding=(2, 2) + ) + for i in range(num_layers) + ] + self.convs = nn.ModuleList(convs) + self.training = False + self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]]) + + post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 2, num_layers) + self.recurrence = nn.LSTM( + input_size=filters[-1] * post_conv_height, hidden_size=out_dim, batch_first=True, bidirectional=False + ) + + def forward(self, inputs, input_lengths): + batch_size = inputs.size(0) + x = inputs.view(batch_size, 1, -1, self.num_mel) # [batch_size, num_channels==1, num_frames, num_mel] + valid_lengths = input_lengths.float() # [batch_size] + for conv, bn in zip(self.convs, self.bns): + x = conv(x) + x = bn(x) + x = F.relu(x) + + # Create the post conv width mask based on the valid lengths of the output of the convolution. + # The valid lengths for the output of a convolution on varying length inputs is + # ceil(input_length/stride) + 1 for stride=3 and padding=2 + # For example (kernel_size=3, stride=2, padding=2): + # 0 0 x x x x x 0 0 -> Input = 5, 0 is zero padding, x is valid values coming from padding=2 in conv2d + # _____ + # x _____ + # x _____ + # x ____ + # x + # x x x x -> Output valid length = 4 + # Since every example in te batch is zero padded and therefore have separate valid_lengths, + # we need to mask off all the values AFTER the valid length for each example in the batch. 
+ # Otherwise, the convolutions create noise and a lot of not real information + valid_lengths = (valid_lengths / 2).float() + valid_lengths = torch.ceil(valid_lengths).to(dtype=torch.int64) + 1 # 2 is stride -- size: [batch_size] + post_conv_max_width = x.size(2) + + mask = torch.arange(post_conv_max_width).to(inputs.device).expand( + len(valid_lengths), post_conv_max_width + ) < valid_lengths.unsqueeze(1) + mask = mask.expand(1, 1, -1, -1).transpose(2, 0).transpose(-1, 2) # [batch_size, 1, post_conv_max_width, 1] + x = x * mask + + x = x.transpose(1, 2) + # x: 4D tensor [batch_size, post_conv_width, + # num_channels==128, post_conv_height] + + post_conv_width = x.size(1) + x = x.contiguous().view(batch_size, post_conv_width, -1) + # x: 3D tensor [batch_size, post_conv_width, + # num_channels*post_conv_height] + + # Routine for fetching the last valid output of a dynamic LSTM with varying input lengths and padding + post_conv_input_lengths = valid_lengths + packed_seqs = nn.utils.rnn.pack_padded_sequence( + x, post_conv_input_lengths.tolist(), batch_first=True, enforce_sorted=False + ) # dynamic rnn sequence padding + self.recurrence.flatten_parameters() + _, (ht, _) = self.recurrence(packed_seqs) + last_output = ht[-1] + + return last_output.to(inputs.device) # [B, 128] + + @staticmethod + def calculate_post_conv_height(height, kernel_size, stride, pad, n_convs): + """Height of spec after n convolutions with fixed kernel/stride/pad.""" + for _ in range(n_convs): + height = (height - kernel_size + 2 * pad) // stride + 1 + return height + + +class TextSummary(nn.Module): + def __init__(self, embedding_dim, encoder_output_dim): + super().__init__() + self.lstm = nn.LSTM( + encoder_output_dim, # text embedding dimension from the text encoder + embedding_dim, # fixed length output summary the lstm creates from the input + batch_first=True, + bidirectional=False, + ) + + def forward(self, inputs, input_lengths): + # Routine for fetching the last valid output of a dynamic LSTM with varying input lengths and padding + packed_seqs = nn.utils.rnn.pack_padded_sequence( + inputs, input_lengths.tolist(), batch_first=True, enforce_sorted=False + ) # dynamic rnn sequence padding + self.lstm.flatten_parameters() + _, (ht, _) = self.lstm(packed_seqs) + last_output = ht[-1] + return last_output + + +class PostEncoderMLP(nn.Module): + def __init__(self, input_size, hidden_size): + super().__init__() + self.hidden_size = hidden_size + modules = [ + nn.Linear(input_size, hidden_size), # Hidden Layer + nn.Tanh(), + nn.Linear(hidden_size, hidden_size * 2), + ] # Output layer twice the size for mean and variance + self.net = nn.Sequential(*modules) + self.softplus = nn.Softplus() + + def forward(self, _input): + mlp_output = self.net(_input) + # The mean parameter is unconstrained + mu = mlp_output[:, : self.hidden_size] + # The standard deviation must be positive. Parameterise with a softplus + sigma = self.softplus(mlp_output[:, self.hidden_size :]) + return mu, sigma diff --git a/TTS/tts/layers/tacotron/common_layers.py b/TTS/tts/layers/tacotron/common_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..f78ff1e75f6c23eb1a0fe827247a1127bc8f9958 --- /dev/null +++ b/TTS/tts/layers/tacotron/common_layers.py @@ -0,0 +1,119 @@ +import torch +from torch import nn +from torch.nn import functional as F + + +class Linear(nn.Module): + """Linear layer with a specific initialization. + + Args: + in_features (int): number of channels in the input tensor. 
+ out_features (int): number of channels in the output tensor. + bias (bool, optional): enable/disable bias in the layer. Defaults to True. + init_gain (str, optional): method to compute the gain in the weight initializtion based on the nonlinear activation used afterwards. Defaults to 'linear'. + """ + + def __init__(self, in_features, out_features, bias=True, init_gain="linear"): + super().__init__() + self.linear_layer = torch.nn.Linear(in_features, out_features, bias=bias) + self._init_w(init_gain) + + def _init_w(self, init_gain): + torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(init_gain)) + + def forward(self, x): + return self.linear_layer(x) + + +class LinearBN(nn.Module): + """Linear layer with Batch Normalization. + + x -> linear -> BN -> o + + Args: + in_features (int): number of channels in the input tensor. + out_features (int ): number of channels in the output tensor. + bias (bool, optional): enable/disable bias in the linear layer. Defaults to True. + init_gain (str, optional): method to set the gain for weight initialization. Defaults to 'linear'. + """ + + def __init__(self, in_features, out_features, bias=True, init_gain="linear"): + super().__init__() + self.linear_layer = torch.nn.Linear(in_features, out_features, bias=bias) + self.batch_normalization = nn.BatchNorm1d(out_features, momentum=0.1, eps=1e-5) + self._init_w(init_gain) + + def _init_w(self, init_gain): + torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(init_gain)) + + def forward(self, x): + """ + Shapes: + x: [T, B, C] or [B, C] + """ + out = self.linear_layer(x) + if len(out.shape) == 3: + out = out.permute(1, 2, 0) + out = self.batch_normalization(out) + if len(out.shape) == 3: + out = out.permute(2, 0, 1) + return out + + +class Prenet(nn.Module): + """Tacotron specific Prenet with an optional Batch Normalization. + + Note: + Prenet with BN improves the model performance significantly especially + if it is enabled after learning a diagonal attention alignment with the original + prenet. However, if the target dataset is high quality then it also works from + the start. It is also suggested to disable dropout if BN is in use. + + prenet_type == "original" + x -> [linear -> ReLU -> Dropout]xN -> o + + prenet_type == "bn" + x -> [linear -> BN -> ReLU -> Dropout]xN -> o + + Args: + in_features (int): number of channels in the input tensor and the inner layers. + prenet_type (str, optional): prenet type "original" or "bn". Defaults to "original". + prenet_dropout (bool, optional): dropout rate. Defaults to True. + dropout_at_inference (bool, optional): use dropout at inference. It leads to a better quality for some models. + out_features (list, optional): List of output channels for each prenet block. + It also defines number of the prenet blocks based on the length of argument list. + Defaults to [256, 256]. + bias (bool, optional): enable/disable bias in prenet linear layers. Defaults to True. 
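+
+    Example (editor's sketch, not part of the original code; the 80-band mel input and the
+    batch/time sizes are assumed)::
+
+        import torch
+        prenet = Prenet(80, prenet_type="original", out_features=[256, 256])
+        x = torch.rand(2, 50, 80)   # (B, T, in_features)
+        o = prenet(x)               # (B, T, 256)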
+ """ + + # pylint: disable=dangerous-default-value + def __init__( + self, + in_features, + prenet_type="original", + prenet_dropout=True, + dropout_at_inference=False, + out_features=[256, 256], + bias=True, + ): + super().__init__() + self.prenet_type = prenet_type + self.prenet_dropout = prenet_dropout + self.dropout_at_inference = dropout_at_inference + in_features = [in_features] + out_features[:-1] + if prenet_type == "bn": + self.linear_layers = nn.ModuleList( + [LinearBN(in_size, out_size, bias=bias) for (in_size, out_size) in zip(in_features, out_features)] + ) + elif prenet_type == "original": + self.linear_layers = nn.ModuleList( + [Linear(in_size, out_size, bias=bias) for (in_size, out_size) in zip(in_features, out_features)] + ) + + def forward(self, x): + for linear in self.linear_layers: + if self.prenet_dropout: + x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training or self.dropout_at_inference) + else: + x = F.relu(linear(x)) + return x diff --git a/TTS/tts/layers/tacotron/gst_layers.py b/TTS/tts/layers/tacotron/gst_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..ec622e4db80eb7f0e319bc11df950086b9562f41 --- /dev/null +++ b/TTS/tts/layers/tacotron/gst_layers.py @@ -0,0 +1,151 @@ +import torch +import torch.nn.functional as F +from torch import nn + + +class GST(nn.Module): + """Global Style Token Module for factorizing prosody in speech. + + See https://arxiv.org/pdf/1803.09017""" + + def __init__(self, num_mel, num_heads, num_style_tokens, gst_embedding_dim, embedded_speaker_dim=None): + super().__init__() + self.encoder = ReferenceEncoder(num_mel, gst_embedding_dim) + self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens, gst_embedding_dim, embedded_speaker_dim) + + def forward(self, inputs, speaker_embedding=None): + enc_out = self.encoder(inputs) + # concat speaker_embedding + if speaker_embedding is not None: + enc_out = torch.cat([enc_out, speaker_embedding], dim=-1) + style_embed = self.style_token_layer(enc_out) + + return style_embed + + +class ReferenceEncoder(nn.Module): + """NN module creating a fixed size prosody embedding from a spectrogram. 
+ + inputs: mel spectrograms [batch_size, num_spec_frames, num_mel] + outputs: [batch_size, embedding_dim] + """ + + def __init__(self, num_mel, embedding_dim): + + super().__init__() + self.num_mel = num_mel + filters = [1] + [32, 32, 64, 64, 128, 128] + num_layers = len(filters) - 1 + convs = [ + nn.Conv2d( + in_channels=filters[i], out_channels=filters[i + 1], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1) + ) + for i in range(num_layers) + ] + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]]) + + post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 1, num_layers) + self.recurrence = nn.GRU( + input_size=filters[-1] * post_conv_height, hidden_size=embedding_dim // 2, batch_first=True + ) + + def forward(self, inputs): + batch_size = inputs.size(0) + x = inputs.view(batch_size, 1, -1, self.num_mel) + # x: 4D tensor [batch_size, num_channels==1, num_frames, num_mel] + for conv, bn in zip(self.convs, self.bns): + x = conv(x) + x = bn(x) + x = F.relu(x) + + x = x.transpose(1, 2) + # x: 4D tensor [batch_size, post_conv_width, + # num_channels==128, post_conv_height] + post_conv_width = x.size(1) + x = x.contiguous().view(batch_size, post_conv_width, -1) + # x: 3D tensor [batch_size, post_conv_width, + # num_channels*post_conv_height] + self.recurrence.flatten_parameters() + _, out = self.recurrence(x) + # out: 3D tensor [seq_len==1, batch_size, encoding_size=128] + + return out.squeeze(0) + + @staticmethod + def calculate_post_conv_height(height, kernel_size, stride, pad, n_convs): + """Height of spec after n convolutions with fixed kernel/stride/pad.""" + for _ in range(n_convs): + height = (height - kernel_size + 2 * pad) // stride + 1 + return height + + +class StyleTokenLayer(nn.Module): + """NN Module attending to style tokens based on prosody encodings.""" + + def __init__(self, num_heads, num_style_tokens, gst_embedding_dim, d_vector_dim=None): + super().__init__() + + self.query_dim = gst_embedding_dim // 2 + + if d_vector_dim: + self.query_dim += d_vector_dim + + self.key_dim = gst_embedding_dim // num_heads + self.style_tokens = nn.Parameter(torch.FloatTensor(num_style_tokens, self.key_dim)) + nn.init.normal_(self.style_tokens, mean=0, std=0.5) + self.attention = MultiHeadAttention( + query_dim=self.query_dim, key_dim=self.key_dim, num_units=gst_embedding_dim, num_heads=num_heads + ) + + def forward(self, inputs): + batch_size = inputs.size(0) + prosody_encoding = inputs.unsqueeze(1) + # prosody_encoding: 3D tensor [batch_size, 1, encoding_size==128] + tokens = torch.tanh(self.style_tokens).unsqueeze(0).expand(batch_size, -1, -1) + # tokens: 3D tensor [batch_size, num tokens, token embedding size] + style_embed = self.attention(prosody_encoding, tokens) + + return style_embed + + +class MultiHeadAttention(nn.Module): + """ + input: + query --- [N, T_q, query_dim] + key --- [N, T_k, key_dim] + output: + out --- [N, T_q, num_units] + """ + + def __init__(self, query_dim, key_dim, num_units, num_heads): + + super().__init__() + self.num_units = num_units + self.num_heads = num_heads + self.key_dim = key_dim + + self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False) + self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) + self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) + + def forward(self, query, key): + queries = self.W_query(query) # [N, T_q, num_units] + keys = self.W_key(key) # [N, T_k, 
num_units] + values = self.W_value(key) + + split_size = self.num_units // self.num_heads + queries = torch.stack(torch.split(queries, split_size, dim=2), dim=0) # [h, N, T_q, num_units/h] + keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] + values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] + + # score = softmax(QK^T / (d_k**0.5)) + scores = torch.matmul(queries, keys.transpose(2, 3)) # [h, N, T_q, T_k] + scores = scores / (self.key_dim**0.5) + scores = F.softmax(scores, dim=3) + + # out = score * V + out = torch.matmul(scores, values) # [h, N, T_q, num_units/h] + out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units] + + return out diff --git a/TTS/tts/layers/tacotron/tacotron.py b/TTS/tts/layers/tacotron/tacotron.py new file mode 100644 index 0000000000000000000000000000000000000000..bddaf449c112a99458c9047c5c07df592e935972 --- /dev/null +++ b/TTS/tts/layers/tacotron/tacotron.py @@ -0,0 +1,504 @@ +# coding: utf-8 +# adapted from https://github.com/r9y9/tacotron_pytorch + +import torch +from torch import nn + +from .attentions import init_attn +from .common_layers import Prenet + + +class BatchNormConv1d(nn.Module): + r"""A wrapper for Conv1d with BatchNorm. It sets the activation + function between Conv and BatchNorm layers. BatchNorm layer + is initialized with the TF default values for momentum and eps. + + Args: + in_channels: size of each input sample + out_channels: size of each output samples + kernel_size: kernel size of conv filters + stride: stride of conv filters + padding: padding of conv filters + activation: activation function set b/w Conv1d and BatchNorm + + Shapes: + - input: (B, D) + - output: (B, D) + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, activation=None): + + super().__init__() + self.padding = padding + self.padder = nn.ConstantPad1d(padding, 0) + self.conv1d = nn.Conv1d( + in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=0, bias=False + ) + # Following tensorflow's default parameters + self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3) + self.activation = activation + # self.init_layers() + + def init_layers(self): + if isinstance(self.activation, torch.nn.ReLU): + w_gain = "relu" + elif isinstance(self.activation, torch.nn.Tanh): + w_gain = "tanh" + elif self.activation is None: + w_gain = "linear" + else: + raise RuntimeError("Unknown activation function") + torch.nn.init.xavier_uniform_(self.conv1d.weight, gain=torch.nn.init.calculate_gain(w_gain)) + + def forward(self, x): + x = self.padder(x) + x = self.conv1d(x) + x = self.bn(x) + if self.activation is not None: + x = self.activation(x) + return x + + +class Highway(nn.Module): + r"""Highway layers as explained in https://arxiv.org/abs/1505.00387 + + Args: + in_features (int): size of each input sample + out_feature (int): size of each output sample + + Shapes: + - input: (B, *, H_in) + - output: (B, *, H_out) + """ + + # TODO: Try GLU layer + def __init__(self, in_features, out_feature): + super().__init__() + self.H = nn.Linear(in_features, out_feature) + self.H.bias.data.zero_() + self.T = nn.Linear(in_features, out_feature) + self.T.bias.data.fill_(-1) + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + # self.init_layers() + + def init_layers(self): + torch.nn.init.xavier_uniform_(self.H.weight, gain=torch.nn.init.calculate_gain("relu")) + torch.nn.init.xavier_uniform_(self.T.weight, 
gain=torch.nn.init.calculate_gain("sigmoid")) + + def forward(self, inputs): + H = self.relu(self.H(inputs)) + T = self.sigmoid(self.T(inputs)) + return H * T + inputs * (1.0 - T) + + +class CBHG(nn.Module): + """CBHG module: a recurrent neural network composed of: + - 1-d convolution banks + - Highway networks + residual connections + - Bidirectional gated recurrent units + + Args: + in_features (int): sample size + K (int): max filter size in conv bank + projections (list): conv channel sizes for conv projections + num_highways (int): number of highways layers + + Shapes: + - input: (B, C, T_in) + - output: (B, T_in, C*2) + """ + + # pylint: disable=dangerous-default-value + def __init__( + self, + in_features, + K=16, + conv_bank_features=128, + conv_projections=[128, 128], + highway_features=128, + gru_features=128, + num_highways=4, + ): + super().__init__() + self.in_features = in_features + self.conv_bank_features = conv_bank_features + self.highway_features = highway_features + self.gru_features = gru_features + self.conv_projections = conv_projections + self.relu = nn.ReLU() + # list of conv1d bank with filter size k=1...K + # TODO: try dilational layers instead + self.conv1d_banks = nn.ModuleList( + [ + BatchNormConv1d( + in_features, + conv_bank_features, + kernel_size=k, + stride=1, + padding=[(k - 1) // 2, k // 2], + activation=self.relu, + ) + for k in range(1, K + 1) + ] + ) + # max pooling of conv bank, with padding + # TODO: try average pooling OR larger kernel size + out_features = [K * conv_bank_features] + conv_projections[:-1] + activations = [self.relu] * (len(conv_projections) - 1) + activations += [None] + # setup conv1d projection layers + layer_set = [] + for (in_size, out_size, ac) in zip(out_features, conv_projections, activations): + layer = BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1, padding=[1, 1], activation=ac) + layer_set.append(layer) + self.conv1d_projections = nn.ModuleList(layer_set) + # setup Highway layers + if self.highway_features != conv_projections[-1]: + self.pre_highway = nn.Linear(conv_projections[-1], highway_features, bias=False) + self.highways = nn.ModuleList([Highway(highway_features, highway_features) for _ in range(num_highways)]) + # bi-directional GPU layer + self.gru = nn.GRU(gru_features, gru_features, 1, batch_first=True, bidirectional=True) + + def forward(self, inputs): + # (B, in_features, T_in) + x = inputs + # (B, hid_features*K, T_in) + # Concat conv1d bank outputs + outs = [] + for conv1d in self.conv1d_banks: + out = conv1d(x) + outs.append(out) + x = torch.cat(outs, dim=1) + assert x.size(1) == self.conv_bank_features * len(self.conv1d_banks) + for conv1d in self.conv1d_projections: + x = conv1d(x) + x += inputs + x = x.transpose(1, 2) + if self.highway_features != self.conv_projections[-1]: + x = self.pre_highway(x) + # Residual connection + # TODO: try residual scaling as in Deep Voice 3 + # TODO: try plain residual layers + for highway in self.highways: + x = highway(x) + # (B, T_in, hid_features*2) + # TODO: replace GRU with convolution as in Deep Voice 3 + self.gru.flatten_parameters() + outputs, _ = self.gru(x) + return outputs + + +class EncoderCBHG(nn.Module): + r"""CBHG module with Encoder specific arguments""" + + def __init__(self): + super().__init__() + self.cbhg = CBHG( + 128, + K=16, + conv_bank_features=128, + conv_projections=[128, 128], + highway_features=128, + gru_features=128, + num_highways=4, + ) + + def forward(self, x): + return self.cbhg(x) + + +class Encoder(nn.Module): + 
r"""Stack Prenet and CBHG module for encoder + Args: + inputs (FloatTensor): embedding features + + Shapes: + - inputs: (B, T, D_in) + - outputs: (B, T, 128 * 2) + """ + + def __init__(self, in_features): + super().__init__() + self.prenet = Prenet(in_features, out_features=[256, 128]) + self.cbhg = EncoderCBHG() + + def forward(self, inputs): + # B x T x prenet_dim + outputs = self.prenet(inputs) + outputs = self.cbhg(outputs.transpose(1, 2)) + return outputs + + +class PostCBHG(nn.Module): + def __init__(self, mel_dim): + super().__init__() + self.cbhg = CBHG( + mel_dim, + K=8, + conv_bank_features=128, + conv_projections=[256, mel_dim], + highway_features=128, + gru_features=128, + num_highways=4, + ) + + def forward(self, x): + return self.cbhg(x) + + +class Decoder(nn.Module): + """Tacotron decoder. + + Args: + in_channels (int): number of input channels. + frame_channels (int): number of feature frame channels. + r (int): number of outputs per time step (reduction rate). + memory_size (int): size of the past window. if <= 0 memory_size = r + attn_type (string): type of attention used in decoder. + attn_windowing (bool): if true, define an attention window centered to maximum + attention response. It provides more robust attention alignment especially + at interence time. + attn_norm (string): attention normalization function. 'sigmoid' or 'softmax'. + prenet_type (string): 'original' or 'bn'. + prenet_dropout (float): prenet dropout rate. + forward_attn (bool): if true, use forward attention method. https://arxiv.org/abs/1807.06736 + trans_agent (bool): if true, use transition agent. https://arxiv.org/abs/1807.06736 + forward_attn_mask (bool): if true, mask attention values smaller than a threshold. + location_attn (bool): if true, use location sensitive attention. + attn_K (int): number of attention heads for GravesAttention. + separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow. + d_vector_dim (int): size of speaker embedding vector, for multi-speaker training. + max_decoder_steps (int): Maximum number of steps allowed for the decoder. Defaults to 500. 
+ """ + + # Pylint gets confused by PyTorch conventions here + # pylint: disable=attribute-defined-outside-init + + def __init__( + self, + in_channels, + frame_channels, + r, + memory_size, + attn_type, + attn_windowing, + attn_norm, + prenet_type, + prenet_dropout, + forward_attn, + trans_agent, + forward_attn_mask, + location_attn, + attn_K, + separate_stopnet, + max_decoder_steps, + ): + super().__init__() + self.r_init = r + self.r = r + self.in_channels = in_channels + self.max_decoder_steps = max_decoder_steps + self.use_memory_queue = memory_size > 0 + self.memory_size = memory_size if memory_size > 0 else r + self.frame_channels = frame_channels + self.separate_stopnet = separate_stopnet + self.query_dim = 256 + # memory -> |Prenet| -> processed_memory + prenet_dim = frame_channels * self.memory_size if self.use_memory_queue else frame_channels + self.prenet = Prenet(prenet_dim, prenet_type, prenet_dropout, out_features=[256, 128]) + # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State + # attention_rnn generates queries for the attention mechanism + self.attention_rnn = nn.GRUCell(in_channels + 128, self.query_dim) + self.attention = init_attn( + attn_type=attn_type, + query_dim=self.query_dim, + embedding_dim=in_channels, + attention_dim=128, + location_attention=location_attn, + attention_location_n_filters=32, + attention_location_kernel_size=31, + windowing=attn_windowing, + norm=attn_norm, + forward_attn=forward_attn, + trans_agent=trans_agent, + forward_attn_mask=forward_attn_mask, + attn_K=attn_K, + ) + # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input + self.project_to_decoder_in = nn.Linear(256 + in_channels, 256) + # decoder_RNN_input -> |RNN| -> RNN_state + self.decoder_rnns = nn.ModuleList([nn.GRUCell(256, 256) for _ in range(2)]) + # RNN_state -> |Linear| -> mel_spec + self.proj_to_mel = nn.Linear(256, frame_channels * self.r_init) + # learn init values instead of zero init. 
+ self.stopnet = StopNet(256 + frame_channels * self.r_init) + + def set_r(self, new_r): + self.r = new_r + + def _reshape_memory(self, memory): + """ + Reshape the spectrograms for given 'r' + """ + # Grouping multiple frames if necessary + if memory.size(-1) == self.frame_channels: + memory = memory.view(memory.shape[0], memory.size(1) // self.r, -1) + # Time first (T_decoder, B, frame_channels) + memory = memory.transpose(0, 1) + return memory + + def _init_states(self, inputs): + """ + Initialization of decoder states + """ + B = inputs.size(0) + # go frame as zeros matrix + if self.use_memory_queue: + self.memory_input = torch.zeros(1, device=inputs.device).repeat(B, self.frame_channels * self.memory_size) + else: + self.memory_input = torch.zeros(1, device=inputs.device).repeat(B, self.frame_channels) + # decoder states + self.attention_rnn_hidden = torch.zeros(1, device=inputs.device).repeat(B, 256) + self.decoder_rnn_hiddens = [ + torch.zeros(1, device=inputs.device).repeat(B, 256) for idx in range(len(self.decoder_rnns)) + ] + self.context_vec = inputs.data.new(B, self.in_channels).zero_() + # cache attention inputs + self.processed_inputs = self.attention.preprocess_inputs(inputs) + + def _parse_outputs(self, outputs, attentions, stop_tokens): + # Back to batch first + attentions = torch.stack(attentions).transpose(0, 1) + stop_tokens = torch.stack(stop_tokens).transpose(0, 1) + outputs = torch.stack(outputs).transpose(0, 1).contiguous() + outputs = outputs.view(outputs.size(0), -1, self.frame_channels) + outputs = outputs.transpose(1, 2) + return outputs, attentions, stop_tokens + + def decode(self, inputs, mask=None): + # Prenet + processed_memory = self.prenet(self.memory_input) + # Attention RNN + self.attention_rnn_hidden = self.attention_rnn( + torch.cat((processed_memory, self.context_vec), -1), self.attention_rnn_hidden + ) + self.context_vec = self.attention(self.attention_rnn_hidden, inputs, self.processed_inputs, mask) + # Concat RNN output and attention context vector + decoder_input = self.project_to_decoder_in(torch.cat((self.attention_rnn_hidden, self.context_vec), -1)) + + # Pass through the decoder RNNs + for idx, decoder_rnn in enumerate(self.decoder_rnns): + self.decoder_rnn_hiddens[idx] = decoder_rnn(decoder_input, self.decoder_rnn_hiddens[idx]) + # Residual connection + decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input + decoder_output = decoder_input + + # predict mel vectors from decoder vectors + output = self.proj_to_mel(decoder_output) + # output = torch.sigmoid(output) + # predict stop token + stopnet_input = torch.cat([decoder_output, output], -1) + if self.separate_stopnet: + stop_token = self.stopnet(stopnet_input.detach()) + else: + stop_token = self.stopnet(stopnet_input) + output = output[:, : self.r * self.frame_channels] + return output, stop_token, self.attention.attention_weights + + def _update_memory_input(self, new_memory): + if self.use_memory_queue: + if self.memory_size > self.r: + # memory queue size is larger than number of frames per decoder iter + self.memory_input = torch.cat( + [new_memory, self.memory_input[:, : (self.memory_size - self.r) * self.frame_channels].clone()], + dim=-1, + ) + else: + # memory queue size smaller than number of frames per decoder iter + self.memory_input = new_memory[:, : self.memory_size * self.frame_channels] + else: + # use only the last frame prediction + # assert new_memory.shape[-1] == self.r * self.frame_channels + self.memory_input = new_memory[:, self.frame_channels * (self.r - 1) :] 
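+
+    # Editor's note on `_update_memory_input` above (illustrative numbers): with use_memory_queue=True,
+    # r=7, frame_channels=80 and memory_size=10, the first branch keeps the r newest frames plus the
+    # (memory_size - r) = 3 most recent frames of the previous queue, i.e. a
+    # (B, memory_size * frame_channels) = (B, 800) tensor that the prenet consumes at the next step.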
+ + def forward(self, inputs, memory, mask): + """ + Args: + inputs: Encoder outputs. + memory: Decoder memory (autoregression. If None (at eval-time), + decoder outputs are used as decoder inputs. If None, it uses the last + output as the input. + mask: Attention mask for sequence padding. + + Shapes: + - inputs: (B, T, D_out_enc) + - memory: (B, T_mel, D_mel) + """ + # Run greedy decoding if memory is None + memory = self._reshape_memory(memory) + outputs = [] + attentions = [] + stop_tokens = [] + t = 0 + self._init_states(inputs) + self.attention.init_states(inputs) + while len(outputs) < memory.size(0): + if t > 0: + new_memory = memory[t - 1] + self._update_memory_input(new_memory) + + output, stop_token, attention = self.decode(inputs, mask) + outputs += [output] + attentions += [attention] + stop_tokens += [stop_token.squeeze(1)] + t += 1 + return self._parse_outputs(outputs, attentions, stop_tokens) + + def inference(self, inputs): + """ + Args: + inputs: encoder outputs. + Shapes: + - inputs: batch x time x encoder_out_dim + """ + outputs = [] + attentions = [] + stop_tokens = [] + t = 0 + self._init_states(inputs) + self.attention.init_states(inputs) + while True: + if t > 0: + new_memory = outputs[-1] + self._update_memory_input(new_memory) + output, stop_token, attention = self.decode(inputs, None) + stop_token = torch.sigmoid(stop_token.data) + outputs += [output] + attentions += [attention] + stop_tokens += [stop_token] + t += 1 + if t > inputs.shape[1] / 4 and (stop_token > 0.6 or attention[:, -1].item() > 0.6): + break + if t > self.max_decoder_steps: + print(" | > Decoder stopped with 'max_decoder_steps") + break + return self._parse_outputs(outputs, attentions, stop_tokens) + + +class StopNet(nn.Module): + r"""Stopnet signalling decoder to stop inference. + Args: + in_features (int): feature dimension of input. + """ + + def __init__(self, in_features): + super().__init__() + self.dropout = nn.Dropout(0.1) + self.linear = nn.Linear(in_features, 1) + torch.nn.init.xavier_uniform_(self.linear.weight, gain=torch.nn.init.calculate_gain("linear")) + + def forward(self, inputs): + outputs = self.dropout(inputs) + outputs = self.linear(outputs) + return outputs diff --git a/TTS/tts/layers/tacotron/tacotron2.py b/TTS/tts/layers/tacotron/tacotron2.py new file mode 100644 index 0000000000000000000000000000000000000000..c79b70997249efc94cbac630bcc7d6c571f5743e --- /dev/null +++ b/TTS/tts/layers/tacotron/tacotron2.py @@ -0,0 +1,414 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from .attentions import init_attn +from .common_layers import Linear, Prenet + + +# pylint: disable=no-value-for-parameter +# pylint: disable=unexpected-keyword-arg +class ConvBNBlock(nn.Module): + r"""Convolutions with Batch Normalization and non-linear activation. + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + kernel_size (int): convolution kernel size. + activation (str): 'relu', 'tanh', None (linear). 
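+
+    Example (editor's sketch; channel and length values are illustrative)::
+
+        import torch
+        block = ConvBNBlock(80, 512, kernel_size=5, activation="relu")
+        o = block(torch.rand(2, 80, 100))   # -> (2, 512, 100); T is preserved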
+ + Shapes: + - input: (B, C_in, T) + - output: (B, C_out, T) + """ + + def __init__(self, in_channels, out_channels, kernel_size, activation=None): + super().__init__() + assert (kernel_size - 1) % 2 == 0 + padding = (kernel_size - 1) // 2 + self.convolution1d = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding) + self.batch_normalization = nn.BatchNorm1d(out_channels, momentum=0.1, eps=1e-5) + self.dropout = nn.Dropout(p=0.5) + if activation == "relu": + self.activation = nn.ReLU() + elif activation == "tanh": + self.activation = nn.Tanh() + else: + self.activation = nn.Identity() + + def forward(self, x): + o = self.convolution1d(x) + o = self.batch_normalization(o) + o = self.activation(o) + o = self.dropout(o) + return o + + +class Postnet(nn.Module): + r"""Tacotron2 Postnet + + Args: + in_out_channels (int): number of output channels. + + Shapes: + - input: (B, C_in, T) + - output: (B, C_in, T) + """ + + def __init__(self, in_out_channels, num_convs=5): + super().__init__() + self.convolutions = nn.ModuleList() + self.convolutions.append(ConvBNBlock(in_out_channels, 512, kernel_size=5, activation="tanh")) + for _ in range(1, num_convs - 1): + self.convolutions.append(ConvBNBlock(512, 512, kernel_size=5, activation="tanh")) + self.convolutions.append(ConvBNBlock(512, in_out_channels, kernel_size=5, activation=None)) + + def forward(self, x): + o = x + for layer in self.convolutions: + o = layer(o) + return o + + +class Encoder(nn.Module): + r"""Tacotron2 Encoder + + Args: + in_out_channels (int): number of input and output channels. + + Shapes: + - input: (B, C_in, T) + - output: (B, C_in, T) + """ + + def __init__(self, in_out_channels=512): + super().__init__() + self.convolutions = nn.ModuleList() + for _ in range(3): + self.convolutions.append(ConvBNBlock(in_out_channels, in_out_channels, 5, "relu")) + self.lstm = nn.LSTM( + in_out_channels, int(in_out_channels / 2), num_layers=1, batch_first=True, bias=True, bidirectional=True + ) + self.rnn_state = None + + def forward(self, x, input_lengths): + o = x + for layer in self.convolutions: + o = layer(o) + o = o.transpose(1, 2) + o = nn.utils.rnn.pack_padded_sequence(o, input_lengths.cpu(), batch_first=True) + self.lstm.flatten_parameters() + o, _ = self.lstm(o) + o, _ = nn.utils.rnn.pad_packed_sequence(o, batch_first=True) + return o + + def inference(self, x): + o = x + for layer in self.convolutions: + o = layer(o) + o = o.transpose(1, 2) + # self.lstm.flatten_parameters() + o, _ = self.lstm(o) + return o + + +# adapted from https://github.com/NVIDIA/tacotron2/ +class Decoder(nn.Module): + """Tacotron2 decoder. We don't use Zoneout but Dropout between RNN layers. + + Args: + in_channels (int): number of input channels. + frame_channels (int): number of feature frame channels. + r (int): number of outputs per time step (reduction rate). + memory_size (int): size of the past window. if <= 0 memory_size = r + attn_type (string): type of attention used in decoder. + attn_win (bool): if true, define an attention window centered to maximum + attention response. It provides more robust attention alignment especially + at interence time. + attn_norm (string): attention normalization function. 'sigmoid' or 'softmax'. + prenet_type (string): 'original' or 'bn'. + prenet_dropout (float): prenet dropout rate. + forward_attn (bool): if true, use forward attention method. https://arxiv.org/abs/1807.06736 + trans_agent (bool): if true, use transition agent. 
https://arxiv.org/abs/1807.06736 + forward_attn_mask (bool): if true, mask attention values smaller than a threshold. + location_attn (bool): if true, use location sensitive attention. + attn_K (int): number of attention heads for GravesAttention. + separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow. + max_decoder_steps (int): Maximum number of steps allowed for the decoder. Defaults to 10000. + """ + + # Pylint gets confused by PyTorch conventions here + # pylint: disable=attribute-defined-outside-init + def __init__( + self, + in_channels, + frame_channels, + r, + attn_type, + attn_win, + attn_norm, + prenet_type, + prenet_dropout, + forward_attn, + trans_agent, + forward_attn_mask, + location_attn, + attn_K, + separate_stopnet, + max_decoder_steps, + ): + super().__init__() + self.frame_channels = frame_channels + self.r_init = r + self.r = r + self.encoder_embedding_dim = in_channels + self.separate_stopnet = separate_stopnet + self.max_decoder_steps = max_decoder_steps + self.stop_threshold = 0.5 + + # model dimensions + self.query_dim = 1024 + self.decoder_rnn_dim = 1024 + self.prenet_dim = 256 + self.attn_dim = 128 + self.p_attention_dropout = 0.1 + self.p_decoder_dropout = 0.1 + + # memory -> |Prenet| -> processed_memory + prenet_dim = self.frame_channels + self.prenet = Prenet( + prenet_dim, prenet_type, prenet_dropout, out_features=[self.prenet_dim, self.prenet_dim], bias=False + ) + + self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_channels, self.query_dim, bias=True) + + self.attention = init_attn( + attn_type=attn_type, + query_dim=self.query_dim, + embedding_dim=in_channels, + attention_dim=128, + location_attention=location_attn, + attention_location_n_filters=32, + attention_location_kernel_size=31, + windowing=attn_win, + norm=attn_norm, + forward_attn=forward_attn, + trans_agent=trans_agent, + forward_attn_mask=forward_attn_mask, + attn_K=attn_K, + ) + + self.decoder_rnn = nn.LSTMCell(self.query_dim + in_channels, self.decoder_rnn_dim, bias=True) + + self.linear_projection = Linear(self.decoder_rnn_dim + in_channels, self.frame_channels * self.r_init) + + self.stopnet = nn.Sequential( + nn.Dropout(0.1), + Linear(self.decoder_rnn_dim + self.frame_channels * self.r_init, 1, bias=True, init_gain="sigmoid"), + ) + self.memory_truncated = None + + def set_r(self, new_r): + self.r = new_r + + def get_go_frame(self, inputs): + B = inputs.size(0) + memory = torch.zeros(1, device=inputs.device).repeat(B, self.frame_channels * self.r) + return memory + + def _init_states(self, inputs, mask, keep_states=False): + B = inputs.size(0) + # T = inputs.size(1) + if not keep_states: + self.query = torch.zeros(1, device=inputs.device).repeat(B, self.query_dim) + self.attention_rnn_cell_state = torch.zeros(1, device=inputs.device).repeat(B, self.query_dim) + self.decoder_hidden = torch.zeros(1, device=inputs.device).repeat(B, self.decoder_rnn_dim) + self.decoder_cell = torch.zeros(1, device=inputs.device).repeat(B, self.decoder_rnn_dim) + self.context = torch.zeros(1, device=inputs.device).repeat(B, self.encoder_embedding_dim) + self.inputs = inputs + self.processed_inputs = self.attention.preprocess_inputs(inputs) + self.mask = mask + + def _reshape_memory(self, memory): + """ + Reshape the spectrograms for given 'r' + """ + # Grouping multiple frames if necessary + if memory.size(-1) == self.frame_channels: + memory = memory.view(memory.shape[0], memory.size(1) // self.r, -1) + # Time first (T_decoder, B, frame_channels) + memory = memory.transpose(0, 1) 
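+        # Editor's note (illustrative numbers): with B=2, T_mel=63, frame_channels=80 and r=3 the
+        # (2, 63, 80) input is grouped into (2, 21, 240) and transposed to time-first (21, 2, 240),
+        # so each decoder step is trained on r frames at once.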
+ return memory + + def _parse_outputs(self, outputs, stop_tokens, alignments): + alignments = torch.stack(alignments).transpose(0, 1) + stop_tokens = torch.stack(stop_tokens).transpose(0, 1) + outputs = torch.stack(outputs).transpose(0, 1).contiguous() + outputs = outputs.view(outputs.size(0), -1, self.frame_channels) + outputs = outputs.transpose(1, 2) + return outputs, stop_tokens, alignments + + def _update_memory(self, memory): + if len(memory.shape) == 2: + return memory[:, self.frame_channels * (self.r - 1) :] + return memory[:, :, self.frame_channels * (self.r - 1) :] + + def decode(self, memory): + """ + shapes: + - memory: B x r * self.frame_channels + """ + # self.context: B x D_en + # query_input: B x D_en + (r * self.frame_channels) + query_input = torch.cat((memory, self.context), -1) + # self.query and self.attention_rnn_cell_state : B x D_attn_rnn + self.query, self.attention_rnn_cell_state = self.attention_rnn( + query_input, (self.query, self.attention_rnn_cell_state) + ) + self.query = F.dropout(self.query, self.p_attention_dropout, self.training) + self.attention_rnn_cell_state = F.dropout( + self.attention_rnn_cell_state, self.p_attention_dropout, self.training + ) + # B x D_en + self.context = self.attention(self.query, self.inputs, self.processed_inputs, self.mask) + # B x (D_en + D_attn_rnn) + decoder_rnn_input = torch.cat((self.query, self.context), -1) + # self.decoder_hidden and self.decoder_cell: B x D_decoder_rnn + self.decoder_hidden, self.decoder_cell = self.decoder_rnn( + decoder_rnn_input, (self.decoder_hidden, self.decoder_cell) + ) + self.decoder_hidden = F.dropout(self.decoder_hidden, self.p_decoder_dropout, self.training) + # B x (D_decoder_rnn + D_en) + decoder_hidden_context = torch.cat((self.decoder_hidden, self.context), dim=1) + # B x (self.r * self.frame_channels) + decoder_output = self.linear_projection(decoder_hidden_context) + # B x (D_decoder_rnn + (self.r * self.frame_channels)) + stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1) + if self.separate_stopnet: + stop_token = self.stopnet(stopnet_input.detach()) + else: + stop_token = self.stopnet(stopnet_input) + # select outputs for the reduction rate self.r + decoder_output = decoder_output[:, : self.r * self.frame_channels] + return decoder_output, self.attention.attention_weights, stop_token + + def forward(self, inputs, memories, mask): + r"""Train Decoder with teacher forcing. + Args: + inputs: Encoder outputs. + memories: Feature frames for teacher-forcing. + mask: Attention mask for sequence padding. 
+ + Shapes: + - inputs: (B, T, D_out_enc) + - memory: (B, T_mel, D_mel) + - outputs: (B, T_mel, D_mel) + - alignments: (B, T_in, T_out) + - stop_tokens: (B, T_out) + """ + memory = self.get_go_frame(inputs).unsqueeze(0) + memories = self._reshape_memory(memories) + memories = torch.cat((memory, memories), dim=0) + memories = self._update_memory(memories) + memories = self.prenet(memories) + + self._init_states(inputs, mask=mask) + self.attention.init_states(inputs) + + outputs, stop_tokens, alignments = [], [], [] + while len(outputs) < memories.size(0) - 1: + memory = memories[len(outputs)] + decoder_output, attention_weights, stop_token = self.decode(memory) + outputs += [decoder_output.squeeze(1)] + stop_tokens += [stop_token.squeeze(1)] + alignments += [attention_weights] + + outputs, stop_tokens, alignments = self._parse_outputs(outputs, stop_tokens, alignments) + return outputs, alignments, stop_tokens + + def inference(self, inputs): + r"""Decoder inference without teacher forcing and use + Stopnet to stop decoder. + Args: + inputs: Encoder outputs. + + Shapes: + - inputs: (B, T, D_out_enc) + - outputs: (B, T_mel, D_mel) + - alignments: (B, T_in, T_out) + - stop_tokens: (B, T_out) + """ + memory = self.get_go_frame(inputs) + memory = self._update_memory(memory) + + self._init_states(inputs, mask=None) + self.attention.init_states(inputs) + + outputs, stop_tokens, alignments, t = [], [], [], 0 + while True: + memory = self.prenet(memory) + decoder_output, alignment, stop_token = self.decode(memory) + stop_token = torch.sigmoid(stop_token.data) + outputs += [decoder_output.squeeze(1)] + stop_tokens += [stop_token] + alignments += [alignment] + + if stop_token > self.stop_threshold and t > inputs.shape[0] // 2: + break + if len(outputs) == self.max_decoder_steps: + print(f" > Decoder stopped with `max_decoder_steps` {self.max_decoder_steps}") + break + + memory = self._update_memory(decoder_output) + t += 1 + + outputs, stop_tokens, alignments = self._parse_outputs(outputs, stop_tokens, alignments) + + return outputs, alignments, stop_tokens + + def inference_truncated(self, inputs): + """ + Preserve decoder states for continuous inference + """ + if self.memory_truncated is None: + self.memory_truncated = self.get_go_frame(inputs) + self._init_states(inputs, mask=None, keep_states=False) + else: + self._init_states(inputs, mask=None, keep_states=True) + + self.attention.init_states(inputs) + outputs, stop_tokens, alignments, t = [], [], [], 0 + while True: + memory = self.prenet(self.memory_truncated) + decoder_output, alignment, stop_token = self.decode(memory) + stop_token = torch.sigmoid(stop_token.data) + outputs += [decoder_output.squeeze(1)] + stop_tokens += [stop_token] + alignments += [alignment] + + if stop_token > 0.7: + break + if len(outputs) == self.max_decoder_steps: + print(" | > Decoder stopped with 'max_decoder_steps") + break + + self.memory_truncated = decoder_output + t += 1 + + outputs, stop_tokens, alignments = self._parse_outputs(outputs, stop_tokens, alignments) + + return outputs, alignments, stop_tokens + + def inference_step(self, inputs, t, memory=None): + """ + For debug purposes + """ + if t == 0: + memory = self.get_go_frame(inputs) + self._init_states(inputs, mask=None) + + memory = self.prenet(memory) + decoder_output, stop_token, alignment = self.decode(memory) + stop_token = torch.sigmoid(stop_token.data) + memory = decoder_output + return decoder_output, stop_token, alignment diff --git a/TTS/tts/layers/vits/discriminator.py 
b/TTS/tts/layers/vits/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..148f283c9010e522c49ad2595860ab859ba6aa48 --- /dev/null +++ b/TTS/tts/layers/vits/discriminator.py @@ -0,0 +1,89 @@ +import torch +from torch import nn +from torch.nn.modules.conv import Conv1d + +from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP, MultiPeriodDiscriminator + + +class DiscriminatorS(torch.nn.Module): + """HiFiGAN Scale Discriminator. Channel sizes are different from the original HiFiGAN. + + Args: + use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm. + """ + + def __init__(self, use_spectral_norm=False): + super().__init__() + norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + Tensor: discriminator scores. + List[Tensor]: list of features from the convolutiona layers. + """ + feat = [] + for l in self.convs: + x = l(x) + x = torch.nn.functional.leaky_relu(x, 0.1) + feat.append(x) + x = self.conv_post(x) + feat.append(x) + x = torch.flatten(x, 1, -1) + return x, feat + + +class VitsDiscriminator(nn.Module): + """VITS discriminator wrapping one Scale Discriminator and a stack of Period Discriminator. + + :: + waveform -> ScaleDiscriminator() -> scores_sd, feats_sd --> append() -> scores, feats + |--> MultiPeriodDiscriminator() -> scores_mpd, feats_mpd ^ + + Args: + use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm. + """ + + def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False): + super().__init__() + self.nets = nn.ModuleList() + self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm)) + self.nets.extend([DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]) + + def forward(self, x, x_hat=None): + """ + Args: + x (Tensor): ground truth waveform. + x_hat (Tensor): predicted waveform. + + Returns: + List[Tensor]: discriminator scores. + List[List[Tensor]]: list of list of features from each layers of each discriminator. 
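+
+        Example (editor's sketch; the one-second 22050 Hz batch is assumed)::
+
+            import torch
+            disc = VitsDiscriminator()
+            wav = torch.rand(4, 1, 22050)
+            wav_hat = torch.rand(4, 1, 22050)
+            scores, feats, scores_hat, feats_hat = disc(wav, wav_hat)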
+ """ + x_scores = [] + x_hat_scores = [] if x_hat is not None else None + x_feats = [] + x_hat_feats = [] if x_hat is not None else None + for net in self.nets: + x_score, x_feat = net(x) + x_scores.append(x_score) + x_feats.append(x_feat) + if x_hat is not None: + x_hat_score, x_hat_feat = net(x_hat) + x_hat_scores.append(x_hat_score) + x_hat_feats.append(x_hat_feat) + return x_scores, x_feats, x_hat_scores, x_hat_feats diff --git a/TTS/tts/layers/vits/networks.py b/TTS/tts/layers/vits/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..f97b584fe6ed311127a8c01a089b159946219cb2 --- /dev/null +++ b/TTS/tts/layers/vits/networks.py @@ -0,0 +1,288 @@ +import math + +import torch +from torch import nn + +from TTS.tts.layers.glow_tts.glow import WN +from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer +from TTS.tts.utils.helpers import sequence_mask + +LRELU_SLOPE = 0.1 + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class TextEncoder(nn.Module): + def __init__( + self, + n_vocab: int, + out_channels: int, + hidden_channels: int, + hidden_channels_ffn: int, + num_heads: int, + num_layers: int, + kernel_size: int, + dropout_p: float, + language_emb_dim: int = None, + ): + """Text Encoder for VITS model. + + Args: + n_vocab (int): Number of characters for the embedding layer. + out_channels (int): Number of channels for the output. + hidden_channels (int): Number of channels for the hidden layers. + hidden_channels_ffn (int): Number of channels for the convolutional layers. + num_heads (int): Number of attention heads for the Transformer layers. + num_layers (int): Number of Transformer layers. + kernel_size (int): Kernel size for the FFN layers in Transformer network. + dropout_p (float): Dropout rate for the Transformer layers. 
+ """ + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + + self.emb = nn.Embedding(n_vocab, hidden_channels) + + nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) + + if language_emb_dim: + hidden_channels += language_emb_dim + + self.encoder = RelativePositionTransformer( + in_channels=hidden_channels, + out_channels=hidden_channels, + hidden_channels=hidden_channels, + hidden_channels_ffn=hidden_channels_ffn, + num_heads=num_heads, + num_layers=num_layers, + kernel_size=kernel_size, + dropout_p=dropout_p, + layer_norm_type="2", + rel_attn_window_size=4, + ) + + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, lang_emb=None): + """ + Shapes: + - x: :math:`[B, T]` + - x_length: :math:`[B]` + """ + assert x.shape[0] == x_lengths.shape[0] + x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] + + # concat the lang emb in embedding chars + if lang_emb is not None: + x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1) + + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) # [b, 1, t] + + x = self.encoder(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + + m, logs = torch.split(stats, self.out_channels, dim=1) + return x, m, logs, x_mask + + +class ResidualCouplingBlock(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + dropout_p=0, + cond_channels=0, + mean_only=False, + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.half_channels = channels // 2 + self.mean_only = mean_only + # input layer + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + # coupling layers + self.enc = WN( + hidden_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + dropout_p=dropout_p, + c_in_channels=cond_channels, + ) + # output layer + # Initializing last layer to 0 makes the affine coupling layers + # do nothing at first. This helps with training stability + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + """ + Note: + Set `reverse` to True for inference. + + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + - g: :math:`[B, C, 1]` + """ + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, log_scale = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + log_scale = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(log_scale) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(log_scale, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-log_scale) * x_mask + x = torch.cat([x0, x1], 1) + return x + + +class ResidualCouplingBlocks(nn.Module): + def __init__( + self, + channels: int, + hidden_channels: int, + kernel_size: int, + dilation_rate: int, + num_layers: int, + num_flows=4, + cond_channels=0, + ): + """Redisual Coupling blocks for VITS flow layers. + + Args: + channels (int): Number of input and output tensor channels. + hidden_channels (int): Number of hidden network channels. + kernel_size (int): Kernel size of the WaveNet layers. + dilation_rate (int): Dilation rate of the WaveNet layers. 
+ num_layers (int): Number of the WaveNet layers. + num_flows (int, optional): Number of Residual Coupling blocks. Defaults to 4. + cond_channels (int, optional): Number of channels of the conditioning tensor. Defaults to 0. + """ + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.num_layers = num_layers + self.num_flows = num_flows + self.cond_channels = cond_channels + + self.flows = nn.ModuleList() + for _ in range(num_flows): + self.flows.append( + ResidualCouplingBlock( + channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + cond_channels=cond_channels, + mean_only=True, + ) + ) + + def forward(self, x, x_mask, g=None, reverse=False): + """ + Note: + Set `reverse` to True for inference. + + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + - g: :math:`[B, C, 1]` + """ + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + x = torch.flip(x, [1]) + else: + for flow in reversed(self.flows): + x = torch.flip(x, [1]) + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class PosteriorEncoder(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int, + kernel_size: int, + dilation_rate: int, + num_layers: int, + cond_channels=0, + ): + """Posterior Encoder of VITS model. + + :: + x -> conv1x1() -> WaveNet() (non-causal) -> conv1x1() -> split() -> [m, s] -> sample(m, s) -> z + + Args: + in_channels (int): Number of input tensor channels. + out_channels (int): Number of output tensor channels. + hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size of the WaveNet convolution layers. + dilation_rate (int): Dilation rate of the WaveNet layers. + num_layers (int): Number of the WaveNet layers. + cond_channels (int, optional): Number of conditioning tensor channels. Defaults to 0. 
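+
+        Example (editor's sketch; the 513-bin linear spectrogram input is assumed)::
+
+            import torch
+            encoder = PosteriorEncoder(in_channels=513, out_channels=192, hidden_channels=192,
+                                       kernel_size=5, dilation_rate=1, num_layers=16)
+            spec = torch.rand(2, 513, 100)
+            lengths = torch.tensor([100, 80])
+            z, mean, log_scale, mask = encoder(spec, lengths)   # z: (2, 192, 100)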
+ """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.num_layers = num_layers + self.cond_channels = cond_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = WN( + hidden_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels=cond_channels + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_lengths: :math:`[B, 1]` + - g: :math:`[B, C, 1]` + """ + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + mean, log_scale = torch.split(stats, self.out_channels, dim=1) + z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask + return z, mean, log_scale, x_mask diff --git a/TTS/tts/layers/vits/stochastic_duration_predictor.py b/TTS/tts/layers/vits/stochastic_duration_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..738ee341e649dfaf62059735c2620cb6ae1a2b1f --- /dev/null +++ b/TTS/tts/layers/vits/stochastic_duration_predictor.py @@ -0,0 +1,294 @@ +import math + +import torch +from torch import nn +from torch.nn import functional as F + +from TTS.tts.layers.generic.normalization import LayerNorm2 +from TTS.tts.layers.vits.transforms import piecewise_rational_quadratic_transform + + +class DilatedDepthSeparableConv(nn.Module): + def __init__(self, channels, kernel_size, num_layers, dropout_p=0.0) -> torch.tensor: + """Dilated Depth-wise Separable Convolution module. + + :: + x |-> DDSConv(x) -> LayerNorm(x) -> GeLU(x) -> Conv1x1(x) -> LayerNorm(x) -> GeLU(x) -> + -> o + |-------------------------------------------------------------------------------------^ + + Args: + channels ([type]): [description] + kernel_size ([type]): [description] + num_layers ([type]): [description] + dropout_p (float, optional): [description]. Defaults to 0.0. + + Returns: + torch.tensor: Network output masked by the input sequence mask. + """ + super().__init__() + self.num_layers = num_layers + + self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(num_layers): + dilation = kernel_size**i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append( + nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding) + ) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm2(channels)) + self.norms_2.append(LayerNorm2(channels)) + self.dropout = nn.Dropout(dropout_p) + + def forward(self, x, x_mask, g=None): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + """ + if g is not None: + x = x + g + for i in range(self.num_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.dropout(y) + x = x + y + return x * x_mask + + +class ElementwiseAffine(nn.Module): + """Element-wise affine transform like no-population stats BatchNorm alternative. + + Args: + channels (int): Number of input tensor channels. 
+ """ + + def __init__(self, channels): + super().__init__() + self.translation = nn.Parameter(torch.zeros(channels, 1)) + self.log_scale = nn.Parameter(torch.zeros(channels, 1)) + + def forward(self, x, x_mask, reverse=False, **kwargs): # pylint: disable=unused-argument + if not reverse: + y = (x * torch.exp(self.log_scale) + self.translation) * x_mask + logdet = torch.sum(self.log_scale * x_mask, [1, 2]) + return y, logdet + x = (x - self.translation) * torch.exp(-self.log_scale) * x_mask + return x + + +class ConvFlow(nn.Module): + """Dilated depth separable convolutional based spline flow. + + Args: + in_channels (int): Number of input tensor channels. + hidden_channels (int): Number of in network channels. + kernel_size (int): Convolutional kernel size. + num_layers (int): Number of convolutional layers. + num_bins (int, optional): Number of spline bins. Defaults to 10. + tail_bound (float, optional): Tail bound for PRQT. Defaults to 5.0. + """ + + def __init__( + self, + in_channels: int, + hidden_channels: int, + kernel_size: int, + num_layers: int, + num_bins=10, + tail_bound=5.0, + ): + super().__init__() + self.num_bins = num_bins + self.tail_bound = tail_bound + self.hidden_channels = hidden_channels + self.half_channels = in_channels // 2 + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers, dropout_p=0.0) + self.proj = nn.Conv1d(hidden_channels, self.half_channels * (num_bins * 3 - 1), 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) + h = self.convs(h, x_mask, g=g) + h = self.proj(h) * x_mask + + b, c, t = x0.shape + h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] + + unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.hidden_channels) + unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.hidden_channels) + unnormalized_derivatives = h[..., 2 * self.num_bins :] + + x1, logabsdet = piecewise_rational_quadratic_transform( + x1, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=reverse, + tails="linear", + tail_bound=self.tail_bound, + ) + + x = torch.cat([x0, x1], 1) * x_mask + logdet = torch.sum(logabsdet * x_mask, [1, 2]) + if not reverse: + return x, logdet + return x + + +class StochasticDurationPredictor(nn.Module): + """Stochastic duration predictor with Spline Flows. + + It applies Variational Dequantization and Variationsl Data Augmentation. + + Paper: + SDP: https://arxiv.org/pdf/2106.06103.pdf + Spline Flow: https://arxiv.org/abs/1906.04032 + + :: + ## Inference + + x -> TextCondEncoder() -> Flow() -> dr_hat + noise ----------------------^ + + ## Training + |---------------------| + x -> TextCondEncoder() -> + -> PosteriorEncoder() -> split() -> z_u, z_v -> (d - z_u) -> concat() -> Flow() -> noise + d -> DurCondEncoder() -> ^ | + |------------------------------------------------------------------------------| + + Args: + in_channels (int): Number of input tensor channels. + hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size of convolutional layers. + dropout_p (float): Dropout rate. + num_flows (int, optional): Number of flow blocks. Defaults to 4. + cond_channels (int, optional): Number of channels of conditioning tensor. Defaults to 0. 
+ """ + + def __init__( + self, + in_channels: int, + hidden_channels: int, + kernel_size: int, + dropout_p: float, + num_flows=4, + cond_channels=0, + language_emb_dim=0, + ): + super().__init__() + + # add language embedding dim in the input + if language_emb_dim: + in_channels += language_emb_dim + + # condition encoder text + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p) + self.proj = nn.Conv1d(hidden_channels, hidden_channels, 1) + + # posterior encoder + self.flows = nn.ModuleList() + self.flows.append(ElementwiseAffine(2)) + self.flows += [ConvFlow(2, hidden_channels, kernel_size, num_layers=3) for _ in range(num_flows)] + + # condition encoder duration + self.post_pre = nn.Conv1d(1, hidden_channels, 1) + self.post_convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p) + self.post_proj = nn.Conv1d(hidden_channels, hidden_channels, 1) + + # flow layers + self.post_flows = nn.ModuleList() + self.post_flows.append(ElementwiseAffine(2)) + self.post_flows += [ConvFlow(2, hidden_channels, kernel_size, num_layers=3) for _ in range(num_flows)] + + if cond_channels != 0 and cond_channels is not None: + self.cond = nn.Conv1d(cond_channels, hidden_channels, 1) + + if language_emb_dim != 0 and language_emb_dim is not None: + self.cond_lang = nn.Conv1d(language_emb_dim, hidden_channels, 1) + + def forward(self, x, x_mask, dr=None, g=None, lang_emb=None, reverse=False, noise_scale=1.0): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + - dr: :math:`[B, 1, T]` + - g: :math:`[B, C]` + """ + # condition encoder text + x = self.pre(x) + if g is not None: + x = x + self.cond(g) + + if lang_emb is not None: + x = x + self.cond_lang(lang_emb) + + x = self.convs(x, x_mask) + x = self.proj(x) * x_mask + + if not reverse: + flows = self.flows + assert dr is not None + + # condition encoder duration + h = self.post_pre(dr) + h = self.post_convs(h, x_mask) + h = self.post_proj(h) * x_mask + noise = torch.randn(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask + z_q = noise + + # posterior encoder + logdet_tot_q = 0.0 + for idx, flow in enumerate(self.post_flows): + z_q, logdet_q = flow(z_q, x_mask, g=(x + h)) + logdet_tot_q = logdet_tot_q + logdet_q + if idx > 0: + z_q = torch.flip(z_q, [1]) + + z_u, z_v = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (dr - u) * x_mask + + # posterior encoder - neg log likelihood + logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]) + nll_posterior_encoder = ( + torch.sum(-0.5 * (math.log(2 * math.pi) + (noise**2)) * x_mask, [1, 2]) - logdet_tot_q + ) + + z0 = torch.log(torch.clamp_min(z0, 1e-5)) * x_mask + logdet_tot = torch.sum(-z0, [1, 2]) + z = torch.cat([z0, z_v], 1) + + # flow layers + for idx, flow in enumerate(flows): + z, logdet = flow(z, x_mask, g=x, reverse=reverse) + logdet_tot = logdet_tot + logdet + if idx > 0: + z = torch.flip(z, [1]) + + # flow layers - neg log likelihood + nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot + return nll_flow_layers + nll_posterior_encoder + + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale + for flow in flows: + z = torch.flip(z, [1]) + z = flow(z, x_mask, g=x, reverse=reverse) + + z0, _ = 
torch.split(z, [1, 1], 1) + logw = z0 + return logw diff --git a/TTS/tts/layers/vits/transforms.py b/TTS/tts/layers/vits/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..c1505554488fb18010b82bd97c88b28c7d4547e1 --- /dev/null +++ b/TTS/tts/layers/vits/transforms.py @@ -0,0 +1,203 @@ +# adopted from https://github.com/bayesiains/nflows + +import numpy as np +import torch +from torch.nn import functional as F + +DEFAULT_MIN_BIN_WIDTH = 1e-3 +DEFAULT_MIN_BIN_HEIGHT = 1e-3 +DEFAULT_MIN_DERIVATIVE = 1e-3 + + +def piecewise_rational_quadratic_transform( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails=None, + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + + if tails is None: + spline_fn = rational_quadratic_spline + spline_kwargs = {} + else: + spline_fn = unconstrained_rational_quadratic_spline + spline_kwargs = {"tails": tails, "tail_bound": tail_bound} + + outputs, logabsdet = spline_fn( + inputs=inputs, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + **spline_kwargs, + ) + return outputs, logabsdet + + +def searchsorted(bin_locations, inputs, eps=1e-6): + bin_locations[..., -1] += eps + return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 + + +def unconstrained_rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails="linear", + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = torch.zeros_like(inputs) + logabsdet = torch.zeros_like(inputs) + + if tails == "linear": + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) + constant = np.log(np.exp(1 - min_derivative) - 1) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + outputs[outside_interval_mask] = inputs[outside_interval_mask] + logabsdet[outside_interval_mask] = 0 + else: + raise RuntimeError("{} tails are not implemented.".format(tails)) + + outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + unnormalized_widths=unnormalized_widths[inside_interval_mask, :], + unnormalized_heights=unnormalized_heights[inside_interval_mask, :], + unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], + inverse=inverse, + left=-tail_bound, + right=tail_bound, + bottom=-tail_bound, + top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + ) + + return outputs, logabsdet + + +def rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + left=0.0, + right=1.0, + bottom=0.0, + top=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if torch.min(inputs) < left or torch.max(inputs) > right: + raise ValueError("Input to a transform is not within its domain") + + num_bins = 
unnormalized_widths.shape[-1] + + if min_bin_width * num_bins > 1.0: + raise ValueError("Minimal bin width too large for the number of bins") + if min_bin_height * num_bins > 1.0: + raise ValueError("Minimal bin height too large for the number of bins") + + widths = F.softmax(unnormalized_widths, dim=-1) + widths = min_bin_width + (1 - min_bin_width * num_bins) * widths + cumwidths = torch.cumsum(widths, dim=-1) + cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) + cumwidths = (right - left) * cumwidths + left + cumwidths[..., 0] = left + cumwidths[..., -1] = right + widths = cumwidths[..., 1:] - cumwidths[..., :-1] + + derivatives = min_derivative + F.softplus(unnormalized_derivatives) + + heights = F.softmax(unnormalized_heights, dim=-1) + heights = min_bin_height + (1 - min_bin_height * num_bins) * heights + cumheights = torch.cumsum(heights, dim=-1) + cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) + cumheights = (top - bottom) * cumheights + bottom + cumheights[..., 0] = bottom + cumheights[..., -1] = top + heights = cumheights[..., 1:] - cumheights[..., :-1] + + if inverse: + bin_idx = searchsorted(cumheights, inputs)[..., None] + else: + bin_idx = searchsorted(cumwidths, inputs)[..., None] + + input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] + input_bin_widths = widths.gather(-1, bin_idx)[..., 0] + + input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] + delta = heights / widths + input_delta = delta.gather(-1, bin_idx)[..., 0] + + input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] + input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] + + input_heights = heights.gather(-1, bin_idx)[..., 0] + + if inverse: + a = (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + input_heights * (input_delta - input_derivatives) + b = input_heights * input_derivatives - (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + c = -input_delta * (inputs - input_cumheights) + + discriminant = b.pow(2) - 4 * a * c + assert (discriminant >= 0).all() + + root = (2 * c) / (-b - torch.sqrt(discriminant)) + outputs = root * input_bin_widths + input_cumwidths + + theta_one_minus_theta = root * (1 - root) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta + ) + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, -logabsdet + else: + theta = (inputs - input_cumwidths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + + numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta + ) + outputs = input_cumheights + numerator / denominator + + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, logabsdet diff --git a/TTS/tts/models/__init__.py b/TTS/tts/models/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..d76a3bebee652f44a65f4a3d919ae2c3971d82f8 --- /dev/null +++ b/TTS/tts/models/__init__.py @@ -0,0 +1,14 @@ +from typing import Dict, List, Union + +from TTS.utils.generic_utils import find_module + + +def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseTTS": + print(" > Using model: {}".format(config.model)) + # fetch the right model implementation. + if "base_model" in config and config["base_model"] is not None: + MyModel = find_module("TTS.tts.models", config.base_model.lower()) + else: + MyModel = find_module("TTS.tts.models", config.model.lower()) + model = MyModel.init_from_config(config, samples) + return model diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..0eef18aefe7aee00f502e5f3a87f7d0a3020392b --- /dev/null +++ b/TTS/tts/models/align_tts.py @@ -0,0 +1,453 @@ +from dataclasses import dataclass, field +from typing import Dict, List, Union + +import torch +from coqpit import Coqpit +from torch import nn + +from TTS.tts.layers.align_tts.mdn import MDNBlock +from TTS.tts.layers.feed_forward.decoder import Decoder +from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor +from TTS.tts.layers.feed_forward.encoder import Encoder +from TTS.tts.layers.generic.pos_encoding import PositionalEncoding +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.io import load_fsspec + + +@dataclass +class AlignTTSArgs(Coqpit): + """ + Args: + num_chars (int): + number of unique input to characters + out_channels (int): + number of output tensor channels. It is equal to the expected spectrogram size. + hidden_channels (int): + number of channels in all the model layers. + hidden_channels_ffn (int): + number of channels in transformer's conv layers. + hidden_channels_dp (int): + number of channels in duration predictor network. + num_heads (int): + number of attention heads in transformer networks. + num_transformer_layers (int): + number of layers in encoder and decoder transformer blocks. + dropout_p (int): + dropout rate in transformer layers. + length_scale (int, optional): + coefficient to set the speech speed. <1 slower, >1 faster. Defaults to 1. + num_speakers (int, optional): + number of speakers for multi-speaker training. Defaults to 0. + external_c (bool, optional): + enable external speaker embeddings. Defaults to False. + c_in_channels (int, optional): + number of channels in speaker embedding vectors. Defaults to 0. + """ + + num_chars: int = None + out_channels: int = 80 + hidden_channels: int = 256 + hidden_channels_dp: int = 256 + encoder_type: str = "fftransformer" + encoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1} + ) + decoder_type: str = "fftransformer" + decoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1} + ) + length_scale: float = 1.0 + num_speakers: int = 0 + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_dim: int = 0 + + +class AlignTTS(BaseTTS): + """AlignTTS with modified duration predictor. 
+ https://arxiv.org/pdf/2003.01950.pdf + + Encoder -> DurationPredictor -> Decoder + + Check :class:`AlignTTSArgs` for the class arguments. + + Paper Abstract: + Targeting at both high efficiency and performance, we propose AlignTTS to predict the + mel-spectrum in parallel. AlignTTS is based on a Feed-Forward Transformer which generates mel-spectrum from a + sequence of characters, and the duration of each character is determined by a duration predictor.Instead of + adopting the attention mechanism in Transformer TTS to align text to mel-spectrum, the alignment loss is presented + to consider all possible alignments in training by use of dynamic programming. Experiments on the LJSpeech dataset s + how that our model achieves not only state-of-the-art performance which outperforms Transformer TTS by 0.03 in mean + option score (MOS), but also a high efficiency which is more than 50 times faster than real-time. + + Note: + Original model uses a separate character embedding layer for duration predictor. However, it causes the + duration predictor to overfit and prevents learning higher level interactions among characters. Therefore, + we predict durations based on encoder outputs which has higher level information about input characters. This + enables training without phases as in the original paper. + + Original model uses Transormers in encoder and decoder layers. However, here you can set the architecture + differently based on your requirements using ```encoder_type``` and ```decoder_type``` parameters. + + Examples: + >>> from TTS.tts.configs.align_tts_config import AlignTTSConfig + >>> config = AlignTTSConfig() + >>> model = AlignTTS(config) + + """ + + # pylint: disable=dangerous-default-value + + def __init__( + self, + config: "AlignTTSConfig", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + + super().__init__(config, ap, tokenizer, speaker_manager) + self.speaker_manager = speaker_manager + self.phase = -1 + self.length_scale = ( + float(config.model_args.length_scale) + if isinstance(config.model_args.length_scale, int) + else config.model_args.length_scale + ) + + self.emb = nn.Embedding(self.config.model_args.num_chars, self.config.model_args.hidden_channels) + + self.embedded_speaker_dim = 0 + self.init_multispeaker(config) + + self.pos_encoder = PositionalEncoding(config.model_args.hidden_channels) + self.encoder = Encoder( + config.model_args.hidden_channels, + config.model_args.hidden_channels, + config.model_args.encoder_type, + config.model_args.encoder_params, + self.embedded_speaker_dim, + ) + self.decoder = Decoder( + config.model_args.out_channels, + config.model_args.hidden_channels, + config.model_args.decoder_type, + config.model_args.decoder_params, + ) + self.duration_predictor = DurationPredictor(config.model_args.hidden_channels_dp) + + self.mod_layer = nn.Conv1d(config.model_args.hidden_channels, config.model_args.hidden_channels, 1) + + self.mdn_block = MDNBlock(config.model_args.hidden_channels, 2 * config.model_args.out_channels) + + if self.embedded_speaker_dim > 0 and self.embedded_speaker_dim != config.model_args.hidden_channels: + self.proj_g = nn.Conv1d(self.embedded_speaker_dim, config.model_args.hidden_channels, 1) + + @staticmethod + def compute_log_probs(mu, log_sigma, y): + # pylint: disable=protected-access, c-extension-no-member + y = y.transpose(1, 2).unsqueeze(1) # [B, 1, T1, D] + mu = mu.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D] + log_sigma = log_sigma.transpose(1, 
2).unsqueeze(2) # [B, T2, 1, D] + expanded_y, expanded_mu = torch.broadcast_tensors(y, mu) + exponential = -0.5 * torch.mean( + torch._C._nn.mse_loss(expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2), dim=-1 + ) # B, L, T + logp = exponential - 0.5 * log_sigma.mean(dim=-1) + return logp + + def compute_align_path(self, mu, log_sigma, y, x_mask, y_mask): + # find the max alignment path + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + log_p = self.compute_log_probs(mu, log_sigma, y) + # [B, T_en, T_dec] + attn = maximum_path(log_p, attn_mask.squeeze(1)).unsqueeze(1) + dr_mas = torch.sum(attn, -1) + return dr_mas.squeeze(1), log_p + + @staticmethod + def generate_attn(dr, x_mask, y_mask=None): + # compute decode mask from the durations + if y_mask is None: + y_lengths = dr.sum(1).long() + y_lengths[y_lengths < 1] = 1 + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype) + return attn + + def expand_encoder_outputs(self, en, dr, x_mask, y_mask): + """Generate attention alignment map from durations and + expand encoder outputs + + Examples:: + - encoder output: [a,b,c,d] + - durations: [1, 3, 2, 1] + + - expanded: [a, b, b, b, c, c, d] + - attention map: [[0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 1, 1, 0], + [0, 1, 1, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0]] + """ + attn = self.generate_attn(dr, x_mask, y_mask) + o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2) + return o_en_ex, attn + + def format_durations(self, o_dr_log, x_mask): + o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale + o_dr[o_dr < 1] = 1.0 + o_dr = torch.round(o_dr) + return o_dr + + @staticmethod + def _concat_speaker_embedding(o_en, g): + g_exp = g.expand(-1, -1, o_en.size(-1)) # [B, C, T_en] + o_en = torch.cat([o_en, g_exp], 1) + return o_en + + def _sum_speaker_embedding(self, x, g): + # project g to decoder dim. 
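+ # Note: `proj_g` is only created in `__init__` when speaker conditioning is
+ # enabled and the speaker embedding size differs from `hidden_channels`;
+ # here `g` is `[B, C, 1]` and is broadcast-added over the time axis of `x`.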
+ if hasattr(self, "proj_g"): + g = self.proj_g(g) + + return x + g + + def _forward_encoder(self, x, x_lengths, g=None): + if hasattr(self, "emb_g"): + g = nn.functional.normalize(self.speaker_embedding(g)) # [B, C, 1] + + if g is not None: + g = g.unsqueeze(-1) + + # [B, T, C] + x_emb = self.emb(x) + # [B, C, T] + x_emb = torch.transpose(x_emb, 1, -1) + + # compute sequence masks + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype) + + # encoder pass + o_en = self.encoder(x_emb, x_mask) + + # speaker conditioning for duration predictor + if g is not None: + o_en_dp = self._concat_speaker_embedding(o_en, g) + else: + o_en_dp = o_en + return o_en, o_en_dp, x_mask, g + + def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g): + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) + # fix extreme predictions #MYEDITS + if hasattr(self, "pos_encoder"): + if dr.sum() > self.pos_encoder.max_len: + dr = torch.floor(dr * torch.div(self.pos_encoder.max_len, dr.sum())) + # expand o_en with durations + o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) + # positional encoding + if hasattr(self, "pos_encoder"): + o_en_ex = self.pos_encoder(o_en_ex, y_mask) + # speaker embedding + if g is not None: + o_en_ex = self._sum_speaker_embedding(o_en_ex, g) + # decoder pass + o_de = self.decoder(o_en_ex, y_mask, g=g) + return o_de, attn.transpose(1, 2) + + def _forward_mdn(self, o_en, y, y_lengths, x_mask): + # MAS potentials and alignment + mu, log_sigma = self.mdn_block(o_en) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) + dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask, y_mask) + return dr_mas, mu, log_sigma, logp + + def forward( + self, x, x_lengths, y, y_lengths, aux_input={"d_vectors": None}, phase=None + ): # pylint: disable=unused-argument + """ + Shapes: + - x: :math:`[B, T_max]` + - x_lengths: :math:`[B]` + - y_lengths: :math:`[B]` + - dr: :math:`[B, T_max]` + - g: :math:`[B, C]` + """ + y = y.transpose(1, 2) + g = aux_input["d_vectors"] if "d_vectors" in aux_input else None + o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp = None, None, None, None, None, None, None + if phase == 0: + # train encoder and MDN + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) + attn = self.generate_attn(dr_mas, x_mask, y_mask) + elif phase == 1: + # train decoder + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + dr_mas, _, _, _ = self._forward_mdn(o_en, y, y_lengths, x_mask) + o_de, attn = self._forward_decoder(o_en.detach(), o_en_dp.detach(), dr_mas.detach(), x_mask, y_lengths, g=g) + elif phase == 2: + # train the whole except duration predictor + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) + o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g) + elif phase == 3: + # train duration predictor + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + o_dr_log = self.duration_predictor(x, x_mask) + dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) + o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g) + o_dr_log = o_dr_log.squeeze(1) + else: + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, 
x_lengths, g) + o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) + dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) + o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g) + o_dr_log = o_dr_log.squeeze(1) + dr_mas_log = torch.log(dr_mas + 1).squeeze(1) + outputs = { + "model_outputs": o_de.transpose(1, 2), + "alignments": attn, + "durations_log": o_dr_log, + "durations_mas_log": dr_mas_log, + "mu": mu, + "log_sigma": log_sigma, + "logp": logp, + } + return outputs + + @torch.no_grad() + def inference(self, x, aux_input={"d_vectors": None}): # pylint: disable=unused-argument + """ + Shapes: + - x: :math:`[B, T_max]` + - x_lengths: :math:`[B]` + - g: :math:`[B, C]` + """ + g = aux_input["d_vectors"] if "d_vectors" in aux_input else None + x_lengths = torch.tensor(x.shape[1:2]).to(x.device) + # pad input to prevent dropping the last word + # x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0) + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + # o_dr_log = self.duration_predictor(x, x_mask) + o_dr_log = self.duration_predictor(o_en_dp, x_mask) + # duration predictor pass + o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) + y_lengths = o_dr.sum(1) + o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g) + outputs = {"model_outputs": o_de.transpose(1, 2), "alignments": attn} + return outputs + + def train_step(self, batch: dict, criterion: nn.Module): + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + + aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids} + outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input, self.phase) + loss_dict = criterion( + outputs["logp"], + outputs["model_outputs"], + mel_input, + mel_lengths, + outputs["durations_log"], + outputs["durations_mas_log"], + text_lengths, + phase=self.phase, + ) + + return outputs, loss_dict + + def _create_logs(self, batch, outputs, ap): # pylint: disable=no-self-use + model_outputs = outputs["model_outputs"] + alignments = outputs["alignments"] + mel_input = batch["mel_input"] + + pred_spec = model_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + # Sample audio + train_audio = ap.inv_melspectrogram(pred_spec.T) + return figures, {"audio": train_audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ) -> None: # pylint: disable=no-self-use + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_step(self, batch: dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, 
redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def get_criterion(self): + from TTS.tts.layers.losses import AlignTTSLoss # pylint: disable=import-outside-toplevel + + return AlignTTSLoss(self.config) + + @staticmethod + def _set_phase(config, global_step): + """Decide AlignTTS training phase""" + if isinstance(config.phase_start_steps, list): + vals = [i < global_step for i in config.phase_start_steps] + if not True in vals: + phase = 0 + else: + phase = ( + len(config.phase_start_steps) + - [i < global_step for i in config.phase_start_steps][::-1].index(True) + - 1 + ) + else: + phase = None + return phase + + def on_epoch_start(self, trainer): + """Set AlignTTS training phase on epoch start.""" + self.phase = self._set_phase(trainer.config, trainer.total_steps_done) + + @staticmethod + def init_from_config(config: "AlignTTSConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (AlignTTSConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return AlignTTS(new_config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/models/base_tacotron.py b/TTS/tts/models/base_tacotron.py new file mode 100644 index 0000000000000000000000000000000000000000..c0f4c3392deedddaf0fa133cc751c45d52fd908a --- /dev/null +++ b/TTS/tts/models/base_tacotron.py @@ -0,0 +1,299 @@ +import copy +from abc import abstractmethod +from typing import Dict, Tuple + +import torch +from coqpit import Coqpit +from torch import nn + +from TTS.tts.layers.losses import TacotronLoss +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.helpers import sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.generic_utils import format_aux_input +from TTS.utils.io import load_fsspec +from TTS.utils.training import gradual_training_scheduler + + +class BaseTacotron(BaseTTS): + """Base class shared by Tacotron and Tacotron2""" + + def __init__( + self, + config: "TacotronConfig", + ap: "AudioProcessor", + tokenizer: "TTSTokenizer", + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + + # pass all config fields as class attributes + for key in config: + setattr(self, key, config[key]) + + # layers + self.embedding = None + self.encoder = None + self.decoder = None + self.postnet = None + + # init tensors + self.embedded_speakers = None + self.embedded_speakers_projected = None + + # global style token + if self.gst and self.use_gst: + self.decoder_in_features += self.gst.gst_embedding_dim # add gst embedding dim + self.gst_layer = None + + # Capacitron + if self.capacitron_vae and self.use_capacitron_vae: + self.decoder_in_features += self.capacitron_vae.capacitron_VAE_embedding_dim # add capacitron embedding dim + self.capacitron_vae_layer = None + + # additional layers + self.decoder_backward = None + self.coarse_decoder = None + + @staticmethod + def 
_format_aux_input(aux_input: Dict) -> Dict: + """Set missing fields to their default values""" + if aux_input: + return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input) + return None + + ############################# + # INIT FUNCTIONS + ############################# + + def _init_backward_decoder(self): + """Init the backward decoder for Forward-Backward decoding.""" + self.decoder_backward = copy.deepcopy(self.decoder) + + def _init_coarse_decoder(self): + """Init the coarse decoder for Double-Decoder Consistency.""" + self.coarse_decoder = copy.deepcopy(self.decoder) + self.coarse_decoder.r_init = self.ddc_r + self.coarse_decoder.set_r(self.ddc_r) + + ############################# + # CORE FUNCTIONS + ############################# + + @abstractmethod + def forward(self): + pass + + @abstractmethod + def inference(self): + pass + + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + """Load model checkpoint and set up internals. + + Args: + config (Coqpi): model configuration. + checkpoint_path (str): path to checkpoint file. + eval (bool): whether to load model for evaluation. + """ + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + # TODO: set r in run-time by taking it from the new config + if "r" in state: + # set r from the state (for compatibility with older checkpoints) + self.decoder.set_r(state["r"]) + elif "config" in state: + # set r from config used at training time (for inference) + self.decoder.set_r(state["config"]["r"]) + else: + # set r from the new config (for new-models) + self.decoder.set_r(config.r) + if eval: + self.eval() + print(f" > Model's reduction rate `r` is set to: {self.decoder.r}") + assert not self.training + + def get_criterion(self) -> nn.Module: + """Get the model criterion used in training.""" + return TacotronLoss(self.config) + + @staticmethod + def init_from_config(config: Coqpit): + """Initialize model from config.""" + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config) + return BaseTacotron(config, ap, tokenizer, speaker_manager) + + ########################## + # TEST AND LOG FUNCTIONS # + ########################## + + def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: + """Generic test run for `tts` models used by `Trainer`. + + You can override this for a different behaviour. + + Args: + assets (dict): A dict of training assets. For `tts` models, it must include `{'audio_processor': ap}`. + + Returns: + Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. 
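+
+ Example::
+
+     # A rough sketch of how trainer-side code is expected to consume the
+     # output together with `test_log` below; `logger` and `steps` are
+     # illustrative names and not part of this class.
+     outputs = model.test_run(assets)
+     logger.test_audios(steps, outputs["audios"], model.ap.sample_rate)
+     logger.test_figures(steps, outputs["figures"])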
+ """ + print(" | > Synthesizing test sentences.") + test_audios = {} + test_figures = {} + test_sentences = self.config.test_sentences + aux_inputs = self._get_test_aux_input() + for idx, sen in enumerate(test_sentences): + outputs_dict = synthesis( + self, + sen, + self.config, + "cuda" in str(next(self.parameters()).device), + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + style_wav=aux_inputs["style_wav"], + use_griffin_lim=True, + do_trim_silence=False, + ) + test_audios["{}-audio".format(idx)] = outputs_dict["wav"] + test_figures["{}-prediction".format(idx)] = plot_spectrogram( + outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False + ) + test_figures["{}-alignment".format(idx)] = plot_alignment( + outputs_dict["outputs"]["alignments"], output_fig=False + ) + return {"figures": test_figures, "audios": test_audios} + + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: + logger.test_audios(steps, outputs["audios"], self.ap.sample_rate) + logger.test_figures(steps, outputs["figures"]) + + ############################# + # COMMON COMPUTE FUNCTIONS + ############################# + + def compute_masks(self, text_lengths, mel_lengths): + """Compute masks against sequence paddings.""" + # B x T_in_max (boolean) + input_mask = sequence_mask(text_lengths) + output_mask = None + if mel_lengths is not None: + max_len = mel_lengths.max() + r = self.decoder.r + max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len + output_mask = sequence_mask(mel_lengths, max_len=max_len) + return input_mask, output_mask + + def _backward_pass(self, mel_specs, encoder_outputs, mask): + """Run backwards decoder""" + decoder_outputs_b, alignments_b, _ = self.decoder_backward( + encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask + ) + decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous() + return decoder_outputs_b, alignments_b + + def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments, input_mask): + """Double Decoder Consistency""" + T = mel_specs.shape[1] + if T % self.coarse_decoder.r > 0: + padding_size = self.coarse_decoder.r - (T % self.coarse_decoder.r) + mel_specs = torch.nn.functional.pad(mel_specs, (0, 0, 0, padding_size, 0, 0)) + decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder( + encoder_outputs.detach(), mel_specs, input_mask + ) + # scale_factor = self.decoder.r_init / self.decoder.r + alignments_backward = torch.nn.functional.interpolate( + alignments_backward.transpose(1, 2), + size=alignments.shape[1], + mode="nearest", + ).transpose(1, 2) + decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2) + decoder_outputs_backward = decoder_outputs_backward[:, :T, :] + return decoder_outputs_backward, alignments_backward + + ############################# + # EMBEDDING FUNCTIONS + ############################# + + def compute_gst(self, inputs, style_input, speaker_embedding=None): + """Compute global style token""" + if isinstance(style_input, dict): + # multiply each style token with a weight + query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs) + if speaker_embedding is not None: + query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1) + + _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens) + gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs) + for k_token, v_amplifier in style_input.items(): + key = 
_GST[int(k_token)].unsqueeze(0).expand(1, -1, -1) + gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key) + gst_outputs = gst_outputs + gst_outputs_att * v_amplifier + elif style_input is None: + # ignore style token and return zero tensor + gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs) + else: + # compute style tokens + gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable + inputs = self._concat_speaker_embedding(inputs, gst_outputs) + return inputs + + def compute_capacitron_VAE_embedding(self, inputs, reference_mel_info, text_info=None, speaker_embedding=None): + """Capacitron Variational Autoencoder""" + (VAE_outputs, posterior_distribution, prior_distribution, capacitron_beta,) = self.capacitron_vae_layer( + reference_mel_info, + text_info, + speaker_embedding, # pylint: disable=not-callable + ) + + VAE_outputs = VAE_outputs.to(inputs.device) + encoder_output = self._concat_speaker_embedding( + inputs, VAE_outputs + ) # concatenate to the output of the basic tacotron encoder + return ( + encoder_output, + posterior_distribution, + prior_distribution, + capacitron_beta, + ) + + @staticmethod + def _add_speaker_embedding(outputs, embedded_speakers): + embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1) + outputs = outputs + embedded_speakers_ + return outputs + + @staticmethod + def _concat_speaker_embedding(outputs, embedded_speakers): + embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1) + outputs = torch.cat([outputs, embedded_speakers_], dim=-1) + return outputs + + ############################# + # CALLBACKS + ############################# + + def on_epoch_start(self, trainer): + """Callback for setting values wrt gradual training schedule. + + Args: + trainer (TrainerTTS): TTS trainer object that is used to train this model. + """ + if self.gradual_training: + r, trainer.config.batch_size = gradual_training_scheduler(trainer.total_steps_done, trainer.config) + trainer.config.r = r + self.decoder.set_r(r) + if trainer.config.bidirectional_decoder: + trainer.model.decoder_backward.set_r(r) + print(f"\n > Number of output frames: {self.decoder.r}") diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..c86bd391b4281195478c79d332cb6dbbd33cdfb4 --- /dev/null +++ b/TTS/tts/models/base_tts.py @@ -0,0 +1,430 @@ +import os +import random +from typing import Dict, List, Tuple, Union + +import torch +import torch.distributed as dist +from coqpit import Coqpit +from torch import nn +from torch.utils.data import DataLoader +from torch.utils.data.sampler import WeightedRandomSampler +from trainer.torch import DistributedSampler, DistributedSamplerWrapper + +from TTS.model import BaseTrainerModel +from TTS.tts.datasets.dataset import TTSDataset +from TTS.tts.utils.data import get_length_balancer_weights +from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights +from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram + +# pylint: skip-file + + +class BaseTTS(BaseTrainerModel): + """Base `tts` class. Every new `tts` model must inherit this. + + It defines common `tts` specific functions on top of `Model` implementation. 
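+
+ Example::
+
+     # Minimal sketch of how a concrete model builds on this base class (see
+     # `AlignTTS` in this package for a full case); `MyTTS` is a hypothetical
+     # name used only for illustration.
+     class MyTTS(BaseTTS):
+         def __init__(self, config, ap, tokenizer, speaker_manager=None):
+             super().__init__(config, ap, tokenizer, speaker_manager)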
+ """ + + def __init__( + self, + config: Coqpit, + ap: "AudioProcessor", + tokenizer: "TTSTokenizer", + speaker_manager: SpeakerManager = None, + language_manager: LanguageManager = None, + ): + super().__init__() + self.config = config + self.ap = ap + self.tokenizer = tokenizer + self.speaker_manager = speaker_manager + self.language_manager = language_manager + self._set_model_args(config) + + def _set_model_args(self, config: Coqpit): + """Setup model args based on the config type (`ModelConfig` or `ModelArgs`). + + `ModelArgs` has all the fields reuqired to initialize the model architecture. + + `ModelConfig` has all the fields required for training, inference and containes `ModelArgs`. + + If the config is for training with a name like "*Config", then the model args are embeded in the + config.model_args + + If the config is for the model with a name like "*Args", then we assign the directly. + """ + # don't use isintance not to import recursively + if "Config" in config.__class__.__name__: + config_num_chars = ( + self.config.model_args.num_chars if hasattr(self.config, "model_args") else self.config.num_chars + ) + num_chars = config_num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars + if "characters" in config: + self.config.num_chars = num_chars + if hasattr(self.config, "model_args"): + config.model_args.num_chars = num_chars + self.args = self.config.model_args + else: + self.config = config + self.args = config.model_args + elif "Args" in config.__class__.__name__: + self.args = config + else: + raise ValueError("config must be either a *Config or *Args") + + def init_multispeaker(self, config: Coqpit, data: List = None): + """Initialize a speaker embedding layer if needen and define expected embedding channel size for defining + `in_channels` size of the connected layers. + + This implementation yields 3 possible outcomes: + + 1. If `config.use_speaker_embedding` and `config.use_d_vector_file are False, do nothing. + 2. If `config.use_d_vector_file` is True, set expected embedding channel size to `config.d_vector_dim` or 512. + 3. If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of + `config.d_vector_dim` or 512. + + You can override this function for new models. + + Args: + config (Coqpit): Model configuration. 
+ """ + # set number of speakers + if self.speaker_manager is not None: + self.num_speakers = self.speaker_manager.num_speakers + elif hasattr(config, "num_speakers"): + self.num_speakers = config.num_speakers + + # set ultimate speaker embedding size + if config.use_speaker_embedding or config.use_d_vector_file: + self.embedded_speaker_dim = ( + config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512 + ) + # init speaker embedding layer + if config.use_speaker_embedding and not config.use_d_vector_file: + print(" > Init speaker_embedding layer.") + self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) + self.speaker_embedding.weight.data.normal_(0, 0.3) + + def get_aux_input(self, **kwargs) -> Dict: + """Prepare and return `aux_input` used by `forward()`""" + return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None} + + def get_aux_input_from_test_setences(self, sentence_info): + if hasattr(self.config, "model_args"): + config = self.config.model_args + else: + config = self.config + + # extract speaker and language info + text, speaker_name, style_wav, language_name = None, None, None, None + + if isinstance(sentence_info, list): + if len(sentence_info) == 1: + text = sentence_info[0] + elif len(sentence_info) == 2: + text, speaker_name = sentence_info + elif len(sentence_info) == 3: + text, speaker_name, style_wav = sentence_info + elif len(sentence_info) == 4: + text, speaker_name, style_wav, language_name = sentence_info + else: + text = sentence_info + + # get speaker id/d_vector + speaker_id, d_vector, language_id = None, None, None + if hasattr(self, "speaker_manager"): + if config.use_d_vector_file: + if speaker_name is None: + d_vector = self.speaker_manager.get_random_embeddings() + else: + d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name) + elif config.use_speaker_embedding: + if speaker_name is None: + speaker_id = self.speaker_manager.get_random_id() + else: + speaker_id = self.speaker_manager.ids[speaker_name] + + # get language id + if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None: + language_id = self.language_manager.ids[language_name] + + return { + "text": text, + "speaker_id": speaker_id, + "style_wav": style_wav, + "d_vector": d_vector, + "language_id": language_id, + } + + def format_batch(self, batch: Dict) -> Dict: + """Generic batch formatting for `TTSDataset`. + + You must override this if you use a custom dataset. 
+ + Args: + batch (Dict): [description] + + Returns: + Dict: [description] + """ + # setup input batch + text_input = batch["token_id"] + text_lengths = batch["token_id_lengths"] + speaker_names = batch["speaker_names"] + linear_input = batch["linear"] + mel_input = batch["mel"] + mel_lengths = batch["mel_lengths"] + stop_targets = batch["stop_targets"] + item_idx = batch["item_idxs"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + attn_mask = batch["attns"] + waveform = batch["waveform"] + pitch = batch["pitch"] + language_ids = batch["language_ids"] + max_text_length = torch.max(text_lengths.float()) + max_spec_length = torch.max(mel_lengths.float()) + + # compute durations from attention masks + durations = None + if attn_mask is not None: + durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2]) + for idx, am in enumerate(attn_mask): + # compute raw durations + c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1] + # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True) + c_idxs, counts = torch.unique(c_idxs, return_counts=True) + dur = torch.ones([text_lengths[idx]]).to(counts.dtype) + dur[c_idxs] = counts + # smooth the durations and set any 0 duration to 1 + # by cutting off from the largest duration indeces. + extra_frames = dur.sum() - mel_lengths[idx] + largest_idxs = torch.argsort(-dur)[:extra_frames] + dur[largest_idxs] -= 1 + assert ( + dur.sum() == mel_lengths[idx] + ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}" + durations[idx, : text_lengths[idx]] = dur + + # set stop targets wrt reduction factor + stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2) + stop_target_lengths = torch.divide(mel_lengths, self.config.r).ceil_() + + return { + "text_input": text_input, + "text_lengths": text_lengths, + "speaker_names": speaker_names, + "mel_input": mel_input, + "mel_lengths": mel_lengths, + "linear_input": linear_input, + "stop_targets": stop_targets, + "stop_target_lengths": stop_target_lengths, + "attn_mask": attn_mask, + "durations": durations, + "speaker_ids": speaker_ids, + "d_vectors": d_vectors, + "max_text_length": float(max_text_length), + "max_spec_length": float(max_spec_length), + "item_idx": item_idx, + "waveform": waveform, + "pitch": pitch, + "language_ids": language_ids, + } + + def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1): + weights = None + data_items = dataset.samples + + if getattr(config, "use_language_weighted_sampler", False): + alpha = getattr(config, "language_weighted_sampler_alpha", 1.0) + print(" > Using Language weighted sampler with alpha:", alpha) + weights = get_language_balancer_weights(data_items) * alpha + + if getattr(config, "use_speaker_weighted_sampler", False): + alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0) + print(" > Using Speaker weighted sampler with alpha:", alpha) + if weights is not None: + weights += get_speaker_balancer_weights(data_items) * alpha + else: + weights = get_speaker_balancer_weights(data_items) * alpha + + if getattr(config, "use_length_weighted_sampler", False): + alpha = getattr(config, "length_weighted_sampler_alpha", 1.0) + print(" > Using Length weighted sampler with alpha:", alpha) + if weights is not None: + weights += get_length_balancer_weights(data_items) * alpha + else: + weights = get_length_balancer_weights(data_items) * alpha + + if weights is not 
None: + sampler = WeightedRandomSampler(weights, len(weights)) + else: + sampler = None + + # sampler for DDP + if sampler is None: + sampler = DistributedSampler(dataset) if num_gpus > 1 else None + else: # If a sampler is already defined use this sampler and DDP sampler together + sampler = DistributedSamplerWrapper(sampler) if num_gpus > 1 else sampler + + return sampler + + def get_data_loader( + self, + config: Coqpit, + assets: Dict, + is_eval: bool, + samples: Union[List[Dict], List[List]], + verbose: bool, + num_gpus: int, + rank: int = None, + ) -> "DataLoader": + if is_eval and not config.run_eval: + loader = None + else: + # setup multi-speaker attributes + if hasattr(self, "speaker_manager") and self.speaker_manager is not None: + if hasattr(config, "model_args"): + speaker_id_mapping = self.speaker_manager.ids if config.model_args.use_speaker_embedding else None + d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None + config.use_d_vector_file = config.model_args.use_d_vector_file + else: + speaker_id_mapping = self.speaker_manager.ids if config.use_speaker_embedding else None + d_vector_mapping = self.speaker_manager.embeddings if config.use_d_vector_file else None + else: + speaker_id_mapping = None + d_vector_mapping = None + + # setup multi-lingual attributes + if hasattr(self, "language_manager") and self.language_manager is not None: + language_id_mapping = self.language_manager.ids if self.args.use_language_embedding else None + else: + language_id_mapping = None + + # init dataloader + dataset = TTSDataset( + outputs_per_step=config.r if "r" in config else 1, + compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec, + compute_f0=config.get("compute_f0", False), + f0_cache_path=config.get("f0_cache_path", None), + samples=samples, + ap=self.ap, + return_wav=config.return_wav if "return_wav" in config else False, + batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, + min_text_len=config.min_text_len, + max_text_len=config.max_text_len, + min_audio_len=config.min_audio_len, + max_audio_len=config.max_audio_len, + phoneme_cache_path=config.phoneme_cache_path, + precompute_num_workers=config.precompute_num_workers, + use_noise_augment=False if is_eval else config.use_noise_augment, + verbose=verbose, + speaker_id_mapping=speaker_id_mapping, + d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, + tokenizer=self.tokenizer, + start_by_longest=config.start_by_longest, + language_id_mapping=language_id_mapping, + ) + + # wait all the DDP process to be ready + if num_gpus > 1: + dist.barrier() + + # sort input sequences from short to long + dataset.preprocess_samples() + + # get samplers + sampler = self.get_sampler(config, dataset, num_gpus) + + loader = DataLoader( + dataset, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + shuffle=False, # shuffle is done in the dataset. + collate_fn=dataset.collate_fn, + drop_last=False, # setting this False might cause issues in AMP training. 
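+ # `sampler` comes from `get_sampler()` above: a WeightedRandomSampler when any
+ # balancer weights are enabled (wrapped with DistributedSamplerWrapper under DDP),
+ # a plain DistributedSampler for DDP without weights, or None otherwise.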
+ sampler=sampler, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + return loader + + def _get_test_aux_input( + self, + ) -> Dict: + + d_vector = None + if self.config.use_d_vector_file: + d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings] + d_vector = (random.sample(sorted(d_vector), 1),) + + aux_inputs = { + "speaker_id": None + if not self.config.use_speaker_embedding + else random.sample(sorted(self.speaker_manager.ids.values()), 1), + "d_vector": d_vector, + "style_wav": None, # TODO: handle GST style input + } + return aux_inputs + + def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: + """Generic test run for `tts` models used by `Trainer`. + + You can override this for a different behaviour. + + Args: + assets (dict): A dict of training assets. For `tts` models, it must include `{'audio_processor': ap}`. + + Returns: + Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. + """ + print(" | > Synthesizing test sentences.") + test_audios = {} + test_figures = {} + test_sentences = self.config.test_sentences + aux_inputs = self._get_test_aux_input() + for idx, sen in enumerate(test_sentences): + outputs_dict = synthesis( + self, + sen, + self.config, + "cuda" in str(next(self.parameters()).device), + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + style_wav=aux_inputs["style_wav"], + use_griffin_lim=True, + do_trim_silence=False, + ) + test_audios["{}-audio".format(idx)] = outputs_dict["wav"] + test_figures["{}-prediction".format(idx)] = plot_spectrogram( + outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False + ) + test_figures["{}-alignment".format(idx)] = plot_alignment( + outputs_dict["outputs"]["alignments"], output_fig=False + ) + return test_figures, test_audios + + def on_init_start(self, trainer): + """Save the speaker.pth and language_ids.json at the beginning of the training. 
Also update both paths.""" + if self.speaker_manager is not None: + output_path = os.path.join(trainer.output_path, "speakers.pth") + self.speaker_manager.save_ids_to_file(output_path) + trainer.config.speakers_file = output_path + # some models don't have `model_args` set + if hasattr(trainer.config, "model_args"): + trainer.config.model_args.speakers_file = output_path + trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) + print(f" > `speakers.pth` is saved to {output_path}.") + print(" > `speakers_file` is updated in the config.json.") + + if hasattr(self, "language_manager") and self.language_manager is not None: + output_path = os.path.join(trainer.output_path, "language_ids.json") + self.language_manager.save_ids_to_file(output_path) + trainer.config.language_ids_file = output_path + if hasattr(trainer.config, "model_args"): + trainer.config.model_args.language_ids_file = output_path + trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) + print(f" > `language_ids.json` is saved to {output_path}.") + print(" > `language_ids_file` is updated in the config.json.") diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..a6568a49627f8479a1122056d8f802275dc174ab --- /dev/null +++ b/TTS/tts/models/forward_tts.py @@ -0,0 +1,887 @@ +from dataclasses import dataclass, field +from typing import Dict, List, Tuple, Union + +import torch +import torchaudio +from coqpit import Coqpit +from torch import nn +from torch.cuda.amp.autocast_mode import autocast + +from TTS.config import load_config +from TTS.tts.layers.feed_forward.decoder import Decoder +from TTS.tts.layers.feed_forward.encoder import Encoder +from TTS.tts.layers.generic.aligner import AlignmentNetwork +from TTS.tts.layers.generic.pos_encoding import PositionalEncoding +from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor +from TTS.tts.models.base_tts import BaseTTS +from TTS.utils.audio import AudioProcessor +from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram +from TTS.vocoder.models import setup_model as setup_vocoder_model +from trainer.trainer_utils import get_optimizer, get_scheduler + +@dataclass +class ForwardTTSArgs(Coqpit): + """ForwardTTS Model arguments. + + Args: + + num_chars (int): + Number of characters in the vocabulary. Defaults to 100. + + out_channels (int): + Number of output channels. Defaults to 80. + + hidden_channels (int): + Number of base hidden channels of the model. Defaults to 512. + + use_aligner (bool): + Whether to use aligner network to learn the text to speech alignment or use pre-computed durations. + If set False, durations should be computed by `TTS/bin/compute_attention_masks.py` and path to the + pre-computed durations must be provided to `config.datasets[0].meta_file_attn_mask`. Defaults to True. + + use_pitch (bool): + Use pitch predictor to learn the pitch. Defaults to True. + + duration_predictor_hidden_channels (int): + Number of hidden channels in the duration predictor. Defaults to 256. + + duration_predictor_dropout_p (float): + Dropout rate for the duration predictor. Defaults to 0.1. + + duration_predictor_kernel_size (int): + Kernel size of conv layers in the duration predictor. Defaults to 3. 
+ + pitch_predictor_hidden_channels (int): + Number of hidden channels in the pitch predictor. Defaults to 256. + + pitch_predictor_dropout_p (float): + Dropout rate for the pitch predictor. Defaults to 0.1. + + pitch_predictor_kernel_size (int): + Kernel size of conv layers in the pitch predictor. Defaults to 3. + + pitch_embedding_kernel_size (int): + Kernel size of the projection layer in the pitch predictor. Defaults to 3. + + positional_encoding (bool): + Whether to use positional encoding. Defaults to True. + + positional_encoding_use_scale (bool): + Whether to use a learnable scale coeff in the positional encoding. Defaults to True. + + length_scale (int): + Length scale that multiplies the predicted durations. Larger values result slower speech. Defaults to 1.0. + + encoder_type (str): + Type of the encoder module. One of the encoders available in :class:`TTS.tts.layers.feed_forward.encoder`. + Defaults to `fftransformer` as in the paper. + + encoder_params (dict): + Parameters of the encoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}``` + + decoder_type (str): + Type of the decoder module. One of the decoders available in :class:`TTS.tts.layers.feed_forward.decoder`. + Defaults to `fftransformer` as in the paper. + + decoder_params (str): + Parameters of the decoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}``` + + detach_duration_predictor (bool): + Detach the input to the duration predictor from the earlier computation graph so that the duraiton loss + does not pass to the earlier layers. Defaults to True. + + max_duration (int): + Maximum duration accepted by the model. Defaults to 75. + + num_speakers (int): + Number of speakers for the speaker embedding layer. Defaults to 0. + + speakers_file (str): + Path to the speaker mapping file for the Speaker Manager. Defaults to None. + + speaker_embedding_channels (int): + Number of speaker embedding channels. Defaults to 256. + + use_d_vector_file (bool): + Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False. + + d_vector_dim (int): + Number of d-vector channels. Defaults to 0. + + use_speaker_encoder_as_loss (bool): + Enable/Disable Speaker Consistency Loss (SCL). Defaults to False. + + speaker_encoder_config_path (str): + Path to the file speaker encoder config file, to use for SCL. Defaults to "". + + speaker_encoder_model_path (str): + Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "". 
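+
+ Example::
+
+     # Minimal construction sketch; the argument values are illustrative and
+     # not the documented defaults.
+     >>> from TTS.tts.models.forward_tts import ForwardTTSArgs
+     >>> args = ForwardTTSArgs(num_chars=120, use_pitch=True, use_aligner=True)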
+ + """ + + num_chars: int = None + out_channels: int = 80 + hidden_channels: int = 384 + use_aligner: bool = True + use_pitch: bool = True + pitch_predictor_hidden_channels: int = 256 + pitch_predictor_kernel_size: int = 3 + pitch_predictor_dropout_p: float = 0.1 + pitch_embedding_kernel_size: int = 3 + duration_predictor_hidden_channels: int = 256 + duration_predictor_kernel_size: int = 3 + duration_predictor_dropout_p: float = 0.1 + positional_encoding: bool = True + poisitonal_encoding_use_scale: bool = True + length_scale: int = 1 + encoder_type: str = "fftransformer" + encoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1} + ) + decoder_type: str = "fftransformer" + decoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1} + ) + detach_duration_predictor: bool = False + max_duration: int = 75 + num_speakers: int = 1 + use_speaker_embedding: bool = False + speakers_file: str = None + use_d_vector_file: bool = False + d_vector_dim: int = None + d_vector_file: str = None + use_speaker_encoder_as_loss: bool = False + speaker_encoder_config_path: str = "" + speaker_encoder_model_path: str = "" + # external vocoder for speaker encoder loss + vocoder_path: str = None + vocoder_config_path: str = None + use_separate_optimizers: bool = False + + +class ForwardTTS(BaseTTS): + """General forward TTS model implementation that uses an encoder-decoder architecture with an optional alignment + network and a pitch predictor. + + If the alignment network is used, the model learns the text-to-speech alignment + from the data instead of using pre-computed durations. + + If the pitch predictor is used, the model trains a pitch predictor that predicts average pitch value for each + input character as in the FastPitch model. + + `ForwardTTS` can be configured to one of these architectures, + + - FastPitch + - SpeedySpeech + - FastSpeech + - TODO: FastSpeech2 (requires average speech energy predictor) + + Args: + config (Coqpit): Model coqpit class. + speaker_manager (SpeakerManager): Speaker manager for multi-speaker training. Only used for multi-speaker models. + Defaults to None. 
+ + Examples: + >>> from TTS.tts.models.fast_pitch import ForwardTTS, ForwardTTSArgs + >>> config = ForwardTTSArgs() + >>> model = ForwardTTS(config) + """ + + # pylint: disable=dangerous-default-value + def __init__( + self, + config: Coqpit, + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + self._set_model_args(config) + + self.init_multispeaker(config) + + self.max_duration = self.args.max_duration + self.use_aligner = self.args.use_aligner + self.use_pitch = self.args.use_pitch + self.binary_loss_weight = 0.0 + self.train_aligner = True + + self.length_scale = ( + float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale + ) + + self.emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels) + + self.encoder = Encoder( + self.args.hidden_channels, + self.args.hidden_channels, + self.args.encoder_type, + self.args.encoder_params, + self.embedded_speaker_dim, + ) + + if self.args.positional_encoding: + self.pos_encoder = PositionalEncoding(self.args.hidden_channels) + + self.decoder = Decoder( + self.args.out_channels, + self.args.hidden_channels, + self.args.decoder_type, + self.args.decoder_params, + ) + + self.duration_predictor = DurationPredictor( + self.args.hidden_channels + self.embedded_speaker_dim, + self.args.duration_predictor_hidden_channels, + self.args.duration_predictor_kernel_size, + self.args.duration_predictor_dropout_p, + ) + + if self.args.use_pitch: + self.pitch_predictor = DurationPredictor( + self.args.hidden_channels + self.embedded_speaker_dim, + self.args.pitch_predictor_hidden_channels, + self.args.pitch_predictor_kernel_size, + self.args.pitch_predictor_dropout_p, + ) + self.pitch_emb = nn.Conv1d( + 1, + self.args.hidden_channels, + kernel_size=self.args.pitch_embedding_kernel_size, + padding=int((self.args.pitch_embedding_kernel_size - 1) / 2), + ) + + if self.args.use_aligner: + self.aligner = AlignmentNetwork( + in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels + ) + + if self.args.vocoder_path and self.args.vocoder_config_path: + self.vocoder_config = load_config(self.args.vocoder_config_path) + self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config.audio) + self.vocoder_model = setup_vocoder_model(self.vocoder_config) + self.vocoder_model.load_checkpoint(self.vocoder_config, self.args.vocoder_path, eval=False) + self.vocoder_model.cuda() + print("> Vocoder loaded for speaker_encoder_loss") + + + def init_multispeaker(self, config: Coqpit): + """Init for multi-speaker training. + + Args: + config (Coqpit): Model configuration. + """ + self.embedded_speaker_dim = 0 + # init speaker manager + if self.speaker_manager is None and (config.use_d_vector_file or config.use_speaker_embedding): + raise ValueError( + " > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model." 
+ ) + # set number of speakers + if self.speaker_manager is not None: + self.num_speakers = self.speaker_manager.num_speakers + # init d-vector embedding + if config.use_d_vector_file: + #self.embedded_speaker_dim = config.d_vector_dim + if self.args.d_vector_dim != self.args.hidden_channels: + self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1) + # init speaker embedding layer + if config.use_speaker_embedding and not config.use_d_vector_file: + print(" > Init speaker_embedding layer.") + self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels) + nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) + + if self.args.use_speaker_encoder_as_loss: + if self.speaker_manager.encoder is None and ( + not self.args.speaker_encoder_model_path or not self.args.speaker_encoder_config_path + ): + raise RuntimeError( + " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!" + ) + + self.speaker_manager.encoder.eval() + print(" > External Speaker Encoder Loaded !!") + + # pylint: disable=W0101,W0105 + self.audio_transform = torchaudio.transforms.Resample( + orig_freq=self.config.audio.sample_rate, + new_freq=self.speaker_manager.encoder.audio_config["sample_rate"], + ) + + # as we are loading spectograms directly + # self.speaker_manager.encoder.use_torch_spec = False + # print(" > External Speaker Encoder use_torch_spec is set to False !!") + # if self.args.out_channels != self.speaker_manager.encoder.input_dim: + # self.pre_speaker_encoder = nn.Conv1d(self.args.out_channels, self.speaker_manager.encoder.input_dim, 1) + + @staticmethod + def generate_attn(dr, x_mask, y_mask=None): + """Generate an attention mask from the durations. + + Shapes + - dr: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_mask: :math:`(B, T_{de})` + """ + # compute decode mask from the durations + if y_mask is None: + y_lengths = dr.sum(1).long() + y_lengths[y_lengths < 1] = 1 + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype) + return attn + + def expand_encoder_outputs(self, en, dr, x_mask, y_mask): + """Generate attention alignment map from durations and + expand encoder outputs + + Shapes: + - en: :math:`(B, D_{en}, T_{en})` + - dr: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_mask: :math:`(B, T_{de})` + + Examples:: + + encoder output: [a,b,c,d] + durations: [1, 3, 2, 1] + + expanded: [a, b, b, b, c, c, d] + attention map: [[0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 1, 1, 0], + [0, 1, 1, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0]] + """ + attn = self.generate_attn(dr, x_mask, y_mask) + o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2).to(en.dtype), en.transpose(1, 2)).transpose(1, 2) + return o_en_ex, attn + + def format_durations(self, o_dr_log, x_mask): + """Format predicted durations. + 1. Convert to linear scale from log scale + 2. Apply the length scale for speed adjustment + 3. Apply masking. + 4. Cast 0 durations to 1. + 5. Round the duration values. + + Args: + o_dr_log: Log scale durations. + x_mask: Input text mask. 
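+
+        Example (illustrative numbers, assuming ``length_scale=1.0`` and an all-ones mask)::
+
+            o_dr_log = [0.0, 1.1, 0.7]
+            exp(o_dr_log) - 1            -> [0.00, 2.00, 1.01]
+            clamp to >= 1, then round    -> [1.0, 2.0, 1.0]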
+ + Shapes: + - o_dr_log: :math:`(B, T_{de})` + - x_mask: :math:`(B, T_{en})` + """ + o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale + o_dr[o_dr < 1] = 1.0 + o_dr = torch.round(o_dr) + return o_dr + + def _forward_encoder( + self, x: torch.LongTensor, x_mask: torch.FloatTensor, g: torch.FloatTensor = None + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Encoding forward pass. + + 1. Embed speaker IDs if multi-speaker mode. + 2. Embed character sequences. + 3. Run the encoder network. + 4. Sum encoder outputs and speaker embeddings + + Args: + x (torch.LongTensor): Input sequence IDs. + x_mask (torch.FloatTensor): Input squence mask. + g (torch.FloatTensor, optional): Conditioning vectors. In general speaker embeddings. Defaults to None. + + Returns: + Tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor, torch.tensor]: + encoder output, encoder output for the duration predictor, input sequence mask, speaker embeddings, + character embeddings + + Shapes: + - x: :math:`(B, T_{en})` + - x_mask: :math:`(B, 1, T_{en})` + - g: :math:`(B, C)` + """ + if hasattr(self, "emb_g"): + g = self.emb_g(g) # [] -> [C] for single input; [B] -> [B, C] + if g is not None: + g = g.unsqueeze(-1) # [C] -> [C, 1] for single input; [B, C] -> [B, C, 1] + x_emb = self.emb(x) # [T] -> [T, C] for single input; [B, T] -> [B, T, C] + # encoder pass + o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask) # [C, T] for single input; [B, C, T] + # speaker conditioning + # TODO: try different ways of conditioning + if g is not None: + o_en = o_en + g # [C, T] for single input; [B, C, T] + return o_en, x_mask, g, x_emb + + def _forward_decoder( + self, + o_en: torch.FloatTensor, + dr: torch.IntTensor, + x_mask: torch.FloatTensor, + y_lengths: torch.IntTensor, + g: torch.FloatTensor, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """Decoding forward pass. + + 1. Compute the decoder output mask + 2. Expand encoder output with the durations. + 3. Apply position encoding. + 4. Add speaker embeddings if multi-speaker mode. + 5. Run the decoder. + + Args: + o_en (torch.FloatTensor): Encoder output. + dr (torch.IntTensor): Ground truth durations or alignment network durations. + x_mask (torch.IntTensor): Input sequence mask. + y_lengths (torch.IntTensor): Output sequence lengths. + g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings. + + Returns: + Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations. + """ + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) + # expand o_en with durations + o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) + # positional encoding + if hasattr(self, "pos_encoder"): + o_en_ex = self.pos_encoder(o_en_ex, y_mask) + # decoder pass + o_de = self.decoder(o_en_ex, y_mask, g=g) + return o_de.transpose(1, 2), attn.transpose(1, 2) + + def _forward_pitch_predictor( + self, + o_en: torch.FloatTensor, + x_mask: torch.IntTensor, + pitch: torch.FloatTensor = None, + dr: torch.IntTensor = None, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """Pitch predictor forward pass. + + 1. Predict pitch from encoder outputs. + 2. In training - Compute average pitch values for each input character from the ground truth pitch values. + 3. Embed average pitch values. + + Args: + o_en (torch.FloatTensor): Encoder output. + x_mask (torch.IntTensor): Input sequence mask. 
+ pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None. + dr (torch.IntTensor, optional): Ground truth durations. Defaults to None. + + Returns: + Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction. + + Shapes: + - o_en: :math:`(B, C, T_{en})` + - x_mask: :math:`(B, 1, T_{en})` + - pitch: :math:`(B, 1, T_{de})` + - dr: :math:`(B, T_{en})` + """ + o_pitch = self.pitch_predictor(o_en, x_mask) + if pitch is not None: + avg_pitch = average_over_durations(pitch, dr) + o_pitch_emb = self.pitch_emb(avg_pitch) + return o_pitch_emb, o_pitch, avg_pitch + o_pitch_emb = self.pitch_emb(o_pitch) + return o_pitch_emb, o_pitch + + def _forward_aligner( + self, x: torch.FloatTensor, y: torch.FloatTensor, x_mask: torch.IntTensor, y_mask: torch.IntTensor + ) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Aligner forward pass. + + 1. Compute a mask to apply to the attention map. + 2. Run the alignment network. + 3. Apply MAS to compute the hard alignment map. + 4. Compute the durations from the hard alignment map. + + Args: + x (torch.FloatTensor): Input sequence. + y (torch.FloatTensor): Output sequence. + x_mask (torch.IntTensor): Input sequence mask. + y_mask (torch.IntTensor): Output sequence mask. + + Returns: + Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + Durations from the hard alignment map, soft alignment potentials, log scale alignment potentials, + hard alignment map. + + Shapes: + - x: :math:`[B, T_en, C_en]` + - y: :math:`[B, T_de, C_de]` + - x_mask: :math:`[B, 1, T_en]` + - y_mask: :math:`[B, 1, T_de]` + + - o_alignment_dur: :math:`[B, T_en]` + - alignment_soft: :math:`[B, T_en, T_de]` + - alignment_logprob: :math:`[B, 1, T_de, T_en]` + - alignment_mas: :math:`[B, T_en, T_de]` + """ + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, None) + alignment_mas = maximum_path( + alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous() + ) + o_alignment_dur = torch.sum(alignment_mas, -1).int() + alignment_soft = alignment_soft.squeeze(1).transpose(1, 2) + return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas + + def _set_speaker_input(self, aux_input: Dict): + d_vectors = aux_input.get("d_vectors", None) + speaker_ids = aux_input.get("speaker_ids", None) + + if d_vectors is not None and speaker_ids is not None: + raise ValueError("[!] Cannot use d-vectors and speaker-ids together.") + + if speaker_ids is not None and not hasattr(self, "emb_g"): + raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.") + + g = speaker_ids if speaker_ids is not None else d_vectors + return g + + def forward( + self, + x: torch.LongTensor, + x_lengths: torch.LongTensor, + y_lengths: torch.LongTensor, + y: torch.FloatTensor = None, + dr: torch.IntTensor = None, + pitch: torch.FloatTensor = None, + aux_input: Dict = {"d_vectors": None, "speaker_ids": None}, # pylint: disable=unused-argument + waveform: torch.tensor = None, + ) -> Dict: + """Model's forward pass. + + Args: + x (torch.LongTensor): Input character sequences. + x_lengths (torch.LongTensor): Input sequence lengths. + y_lengths (torch.LongTensor): Output sequnce lengths. Defaults to None. + y (torch.FloatTensor): Spectrogram frames. Only used when the alignment network is on. Defaults to None. 
+ dr (torch.IntTensor): Character durations over the spectrogram frames. Only used when the alignment network is off. Defaults to None. + pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Only used when the pitch predictor is on. Defaults to None. + aux_input (Dict): Auxiliary model inputs for multi-speaker training. Defaults to `{"d_vectors": 0, "speaker_ids": None}`. + + Shapes: + - x: :math:`[B, T_max]` + - x_lengths: :math:`[B]` + - y_lengths: :math:`[B]` + - y: :math:`[B, T_max2]` + - dr: :math:`[B, T_max]` + - g: :math:`[B, C]` + - pitch: :math:`[B, 1, T]` + """ + g = self._set_speaker_input(aux_input) + # compute sequence masks + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float() + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float() + # encoder pass + o_en, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g) + # duration predictor pass + if self.args.detach_duration_predictor: + o_dr_log = self.duration_predictor(o_en.detach(), x_mask) + else: + o_dr_log = self.duration_predictor(o_en, x_mask) + o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration) + # generate attn mask from predicted durations + o_attn = self.generate_attn(o_dr.squeeze(1), x_mask) + # aligner + o_alignment_dur = None + alignment_soft = None + alignment_logprob = None + alignment_mas = None + if self.use_aligner: + o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas = self._forward_aligner( + x_emb, y, x_mask, y_mask + ) + alignment_soft = alignment_soft.transpose(1, 2) + alignment_mas = alignment_mas.transpose(1, 2) + dr = o_alignment_dur + # pitch predictor pass + o_pitch = None + avg_pitch = None + if self.args.use_pitch: + o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en, x_mask, pitch, dr) + o_en = o_en + o_pitch_emb + # decoder pass + o_de, attn = self._forward_decoder( + o_en, dr, x_mask, y_lengths, g=None + ) # TODO: maybe pass speaker embedding (g) too + + if self.args.use_speaker_encoder_as_loss and self.speaker_manager.encoder is not None: + # ensure tss config and vocoder config are same + waveform_pred = self.vocoder_model.forward(o_de.transpose(1, 2)) + + # concate generated and GT waveforms + wavs_batch = torch.cat((waveform.squeeze(dim=2), waveform_pred.squeeze(dim=1)), dim=0) + + # resample audio to speaker encoder sample_rate + # pylint: disable=W0105 + if self.audio_transform is not None: + wavs_batch = self.audio_transform(wavs_batch) + pred_embs = self.speaker_manager.encoder.forward(wavs_batch.float(), l2_norm=True) + + # specs_batch = torch.cat((y, o_de), dim=0) + # specs_batch = specs_batch.transpose(1, 2) # swapping time and freq dimensions # [B, F, T] + # if self.pre_speaker_encoder: # specs_batch.size(1) != self.speaker_manager.encoder.input_dim: + # specs_batch = self.pre_speaker_encoder(specs_batch) + # specs_batch = torch.nn.functional.relu(specs_batch) + # pred_embs = self.speaker_manager.encoder.forward(specs_batch, l2_norm=True) + + # split generated and GT speaker embeddings + gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0) + else: + gt_spk_emb, syn_spk_emb = None, None + + outputs = { + "model_outputs": o_de, # [B, T, C] + "durations_log": o_dr_log.squeeze(1), # [B, T] + "durations": o_dr.squeeze(1), # [B, T] + "attn_durations": o_attn, # for visualization [B, T_en, T_de'] + "pitch_avg": o_pitch, + "pitch_avg_gt": avg_pitch, + "alignments": attn, # [B, T_de, T_en] + "alignment_soft": alignment_soft, + "alignment_mas": alignment_mas, + "o_alignment_dur": 
o_alignment_dur, + "alignment_logprob": alignment_logprob, + "x_mask": x_mask, + "y_mask": y_mask, + "gt_spk_emb": gt_spk_emb, + "syn_spk_emb": syn_spk_emb, + } + return outputs + + @torch.no_grad() + def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument + """Model's inference pass. + + Args: + x (torch.LongTensor): Input character sequence. + aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": None, "speaker_ids": None}`. + + Shapes: + - x: [B, T_max] + - x_lengths: [B] + - g: [B, C] + """ + g = self._set_speaker_input(aux_input) + x_lengths = torch.tensor(x.shape[1:2]).to(x.device) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float() + # encoder pass + o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g) + # duration predictor pass + o_dr_log = self.duration_predictor(o_en, x_mask) + o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) + y_lengths = o_dr.sum(1) + # pitch predictor pass + o_pitch = None + if self.args.use_pitch: + o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask) + o_en = o_en + o_pitch_emb + # decoder pass + o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=None) + outputs = { + "model_outputs": o_de, + "alignments": attn, + "pitch": o_pitch, + "durations_log": o_dr_log, + } + return outputs + + @torch.no_grad() + def inference2(self, x, x_lengths, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument + """Model's inference pass. + + Args: + x (torch.LongTensor): Input character sequence. + aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": None, "speaker_ids": None}`. + + Shapes: + - x: [B, T_max] + - x_lengths: [B] + - g: [B, C] + """ + g = self._set_speaker_input(aux_input) + #x_lengths = torch.tensor(x.shape[1:2]).to(x.device) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float() + # encoder pass + o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g) + # duration predictor pass + o_dr_log = self.duration_predictor(o_en, x_mask) + o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) + y_lengths = o_dr.sum(1) + # pitch predictor pass + o_pitch = None + if self.args.use_pitch: + o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask) + o_en = o_en + o_pitch_emb + # decoder pass + o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=None) + outputs = { + "model_outputs": o_de, + "alignments": attn, + "pitch": o_pitch, + "durations_log": o_dr_log, + } + return outputs + + def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx=None): + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + waveform = batch["waveform"] + pitch = batch["pitch"] if self.args.use_pitch else None + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + durations = batch["durations"] + aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids} + + + # forward pass + outputs = self.forward( + text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input, waveform=waveform + ) + # use aligner's output as the duration target + if self.use_aligner: + durations = outputs["o_alignment_dur"] + # use float32 in AMP + with autocast(enabled=False): + # compute loss + loss_dict = criterion( + decoder_output=outputs["model_outputs"], + decoder_target=mel_input, + 
decoder_output_lens=mel_lengths, + dur_output=outputs["durations_log"], + dur_target=durations, + pitch_output=outputs["pitch_avg"] if self.use_pitch else None, + pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None, + input_lens=text_lengths, + alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None, + alignment_soft=outputs["alignment_soft"], + alignment_hard=outputs["alignment_mas"], + binary_loss_weight=self.binary_loss_weight, + train_aligner=self.train_aligner, + use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss, + gt_spk_emb=outputs['gt_spk_emb'], + syn_spk_emb=outputs['syn_spk_emb'], + ) + # compute duration error + durations_pred = outputs["durations"] + duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum() + loss_dict["duration_error"] = duration_error + + return outputs, loss_dict + + def _create_logs(self, batch, outputs, ap): + """Create common logger outputs.""" + if isinstance(outputs, list): + outputs = outputs[0] + model_outputs = outputs["model_outputs"] + alignments = outputs["alignments"] + mel_input = batch["mel_input"] + + pred_spec = model_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + # plot pitch figures + if self.args.use_pitch: + pitch_avg = abs(outputs["pitch_avg_gt"][0, 0].data.cpu().numpy()) + pitch_avg_hat = abs(outputs["pitch_avg"][0, 0].data.cpu().numpy()) + chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy()) + pitch_figures = { + "pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False), + "pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False), + } + figures.update(pitch_figures) + + # plot the attention mask computed from the predicted durations + if "attn_durations" in outputs: + alignments_hat = outputs["attn_durations"][0].data.cpu().numpy() + figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False) + + # Sample audio + train_audio = ap.inv_melspectrogram(pred_spec.T) + return figures, {"audio": train_audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ) -> None: # pylint: disable=no-self-use + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx=None): + return self.train_step(batch, criterion) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def get_criterion(self): + from TTS.tts.layers.losses import ForwardTTSLoss # pylint: disable=import-outside-toplevel + + return ForwardTTSLoss(self.config) + + def on_train_step_start(self, trainer): + """Schedule binary loss weight.""" + 
self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0 + if trainer.epochs_done >= self.config.aligner_epochs: + self.train_aligner = False + + @staticmethod + def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (ForwardTTSConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + if config.model_args.speaker_encoder_model_path: # use_speaker_encoder_as_loss + speaker_manager.init_encoder( + config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path + ) + # as we are loading spectograms directly + speaker_manager.encoder.use_torch_spec = False + return ForwardTTS(new_config, ap, tokenizer, speaker_manager) + + def get_optimizer(self): + if self.args.use_separate_optimizers: + parameters = (value for key, value in self.named_parameters() if not key.startswith('vocoder_model.') and not key.startswith('aligner.')) + parameters_aligner = (value for key, value in self.named_parameters() if key.startswith('aligner.')) + optimizer = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, parameters=parameters) + optimizer_aligner = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, parameters=parameters_aligner) + return [optimizer, optimizer_aligner] + else: + parameters = (value for key, value in self.named_parameters() if not key.startswith('vocoder_model.')) + optimizer = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, parameters=parameters) + return optimizer \ No newline at end of file diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..7c0f95e151cf6b468d55988002d691b176faa98c --- /dev/null +++ b/TTS/tts/models/glow_tts.py @@ -0,0 +1,558 @@ +import math +from typing import Dict, List, Tuple, Union + +import torch +from coqpit import Coqpit +from torch import nn +from torch.cuda.amp.autocast_mode import autocast +from torch.nn import functional as F + +from TTS.tts.configs.glow_tts_config import GlowTTSConfig +from TTS.tts.layers.glow_tts.decoder import Decoder +from TTS.tts.layers.glow_tts.encoder import Encoder +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.io import load_fsspec + + +class GlowTTS(BaseTTS): + """GlowTTS model. + + Paper:: + https://arxiv.org/abs/2005.11129 + + Paper abstract:: + Recently, text-to-speech (TTS) models such as FastSpeech and ParaNet have been proposed to generate + mel-spectrograms from text in parallel. Despite the advantage, the parallel TTS models cannot be trained + without guidance from autoregressive TTS models as their external aligners. In this work, we propose Glow-TTS, + a flow-based generative model for parallel TTS that does not require any external aligner. 
By combining the + properties of flows and dynamic programming, the proposed model searches for the most probable monotonic + alignment between text and the latent representation of speech on its own. We demonstrate that enforcing hard + monotonic alignments enables robust TTS, which generalizes to long utterances, and employing generative flows + enables fast, diverse, and controllable speech synthesis. Glow-TTS obtains an order-of-magnitude speed-up over + the autoregressive model, Tacotron 2, at synthesis with comparable speech quality. We further show that our + model can be easily extended to a multi-speaker setting. + + Check :class:`TTS.tts.configs.glow_tts_config.GlowTTSConfig` for class arguments. + + Examples: + Init only model layers. + + >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig + >>> from TTS.tts.models.glow_tts import GlowTTS + >>> config = GlowTTSConfig(num_chars=2) + >>> model = GlowTTS(config) + + Fully init a model ready for action. All the class attributes and class members + (e.g Tokenizer, AudioProcessor, etc.). are initialized internally based on config values. + + >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig + >>> from TTS.tts.models.glow_tts import GlowTTS + >>> config = GlowTTSConfig() + >>> model = GlowTTS.init_from_config(config, verbose=False) + """ + + def __init__( + self, + config: GlowTTSConfig, + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + + super().__init__(config, ap, tokenizer, speaker_manager) + + # pass all config fields to `self` + # for fewer code change + self.config = config + for key in config: + setattr(self, key, config[key]) + + self.decoder_output_dim = config.out_channels + + # init multi-speaker layers if necessary + self.init_multispeaker(config) + + self.run_data_dep_init = config.data_dep_init_steps > 0 + self.encoder = Encoder( + self.num_chars, + out_channels=self.out_channels, + hidden_channels=self.hidden_channels_enc, + hidden_channels_dp=self.hidden_channels_dp, + encoder_type=self.encoder_type, + encoder_params=self.encoder_params, + mean_only=self.mean_only, + use_prenet=self.use_encoder_prenet, + dropout_p_dp=self.dropout_p_dp, + c_in_channels=self.c_in_channels, + ) + + self.decoder = Decoder( + self.out_channels, + self.hidden_channels_dec, + self.kernel_size_dec, + self.dilation_rate, + self.num_flow_blocks_dec, + self.num_block_layers, + dropout_p=self.dropout_p_dec, + num_splits=self.num_splits, + num_squeeze=self.num_squeeze, + sigmoid_scale=self.sigmoid_scale, + c_in_channels=self.c_in_channels, + ) + + def init_multispeaker(self, config: Coqpit): + """Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding + vector dimension to the encoder layer channel size. If model uses d-vectors, then it only sets + speaker embedding vector dimension to the d-vector dimension from the config. + + Args: + config (Coqpit): Model configuration. 
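+
+        Note:
+            If `use_d_vector_file` is enabled but `d_vector_dim` is not set in the config,
+            the embedded speaker dimension falls back to 512 (see the body below).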
+ """ + self.embedded_speaker_dim = 0 + # set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager + if self.speaker_manager is not None: + self.num_speakers = self.speaker_manager.num_speakers + # set ultimate speaker embedding size + if config.use_d_vector_file: + self.embedded_speaker_dim = ( + config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512 + ) + if self.speaker_manager is not None: + assert ( + config.d_vector_dim == self.speaker_manager.embedding_dim + ), " [!] d-vector dimension mismatch b/w config and speaker manager." + # init speaker embedding layer + if config.use_speaker_embedding and not config.use_d_vector_file: + print(" > Init speaker_embedding layer.") + self.embedded_speaker_dim = self.hidden_channels_enc + self.emb_g = nn.Embedding(self.num_speakers, self.hidden_channels_enc) + nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) + # set conditioning dimensions + self.c_in_channels = self.embedded_speaker_dim + + @staticmethod + def compute_outputs(attn, o_mean, o_log_scale, x_mask): + """Compute and format the mode outputs with the given alignment map""" + y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose( + 1, 2 + ) # [b, t', t], [b, t, d] -> [b, d, t'] + y_log_scale = torch.matmul(attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(1, 2)).transpose( + 1, 2 + ) # [b, t', t], [b, t, d] -> [b, d, t'] + # compute total duration with adjustment + o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask + return y_mean, y_log_scale, o_attn_dur + + def unlock_act_norm_layers(self): + """Unlock activation normalization layers for data depended initalization.""" + for f in self.decoder.flows: + if getattr(f, "set_ddi", False): + f.set_ddi(True) + + def lock_act_norm_layers(self): + """Lock activation normalization layers.""" + for f in self.decoder.flows: + if getattr(f, "set_ddi", False): + f.set_ddi(False) + + def _set_speaker_input(self, aux_input: Dict): + if aux_input is None: + d_vectors = None + speaker_ids = None + else: + d_vectors = aux_input.get("d_vectors", None) + speaker_ids = aux_input.get("speaker_ids", None) + + if d_vectors is not None and speaker_ids is not None: + raise ValueError("[!] Cannot use d-vectors and speaker-ids together.") + + if speaker_ids is not None and not hasattr(self, "emb_g"): + raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.") + + g = speaker_ids if speaker_ids is not None else d_vectors + return g + + def _speaker_embedding(self, aux_input: Dict) -> Union[torch.tensor, None]: + g = self._set_speaker_input(aux_input) + # speaker embedding + if g is not None: + if hasattr(self, "emb_g"): + # use speaker embedding layer + if not g.size(): # if is a scalar + g = g.unsqueeze(0) # unsqueeze + g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1] + else: + # use d-vector + g = F.normalize(g).unsqueeze(-1) # [b, h, 1] + return g + + def forward( + self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None} + ): # pylint: disable=dangerous-default-value + """ + Args: + x (torch.Tensor): + Input text sequence ids. :math:`[B, T_en]` + + x_lengths (torch.Tensor): + Lengths of input text sequences. :math:`[B]` + + y (torch.Tensor): + Target mel-spectrogram frames. :math:`[B, T_de, C_mel]` + + y_lengths (torch.Tensor): + Lengths of target mel-spectrogram frames. :math:`[B]` + + aux_input (Dict): + Auxiliary inputs. 
`d_vectors` is speaker embedding vectors for a multi-speaker model. + :math:`[B, D_vec]`. `speaker_ids` is speaker ids for a multi-speaker model usind speaker-embedding + layer. :math:`B` + + Returns: + Dict: + - z: :math: `[B, T_de, C]` + - logdet: :math:`B` + - y_mean: :math:`[B, T_de, C]` + - y_log_scale: :math:`[B, T_de, C]` + - alignments: :math:`[B, T_en, T_de]` + - durations_log: :math:`[B, T_en, 1]` + - total_durations_log: :math:`[B, T_en, 1]` + """ + # [B, T, C] -> [B, C, T] + y = y.transpose(1, 2) + y_max_length = y.size(2) + # norm speaker embeddings + g = self._speaker_embedding(aux_input) + # embedding pass + o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) + # drop redisual frames wrt num_squeeze and set y_lengths. + y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None) + # create masks + y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype) + # [B, 1, T_en, T_de] + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + # decoder pass + z, logdet = self.decoder(y, y_mask, g=g, reverse=False) + # find the alignment path + with torch.no_grad(): + o_scale = torch.exp(-2 * o_log_scale) + logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2)) # [b, t, d] x [b, d, t'] = [b, t, t'] + logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t'] + logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp = logp1 + logp2 + logp3 + logp4 # [b, t, t'] + attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() + y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask) + attn = attn.squeeze(1).permute(0, 2, 1) + outputs = { + "z": z.transpose(1, 2), + "logdet": logdet, + "y_mean": y_mean.transpose(1, 2), + "y_log_scale": y_log_scale.transpose(1, 2), + "alignments": attn, + "durations_log": o_dur_log.transpose(1, 2), + "total_durations_log": o_attn_dur.transpose(1, 2), + } + return outputs + + @torch.no_grad() + def inference_with_MAS( + self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None} + ): # pylint: disable=dangerous-default-value + """ + It's similar to the teacher forcing in Tacotron. + It was proposed in: https://arxiv.org/abs/2104.05557 + + Shapes: + - x: :math:`[B, T]` + - x_lenghts: :math:`B` + - y: :math:`[B, T, C]` + - y_lengths: :math:`B` + - g: :math:`[B, C] or B` + """ + y = y.transpose(1, 2) + y_max_length = y.size(2) + # norm speaker embeddings + g = self._speaker_embedding(aux_input) + # embedding pass + o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) + # drop redisual frames wrt num_squeeze and set y_lengths. 
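+        # (the trimming below keeps the number of frames divisible by `num_squeeze`,
+        #  so it lines up with the decoder's squeeze operation)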
+ y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None) + # create masks + y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + # decoder pass + z, logdet = self.decoder(y, y_mask, g=g, reverse=False) + # find the alignment path between z and encoder output + o_scale = torch.exp(-2 * o_log_scale) + logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2)) # [b, t, d] x [b, d, t'] = [b, t, t'] + logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t'] + logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp = logp1 + logp2 + logp3 + logp4 # [b, t, t'] + attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() + + y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask) + attn = attn.squeeze(1).permute(0, 2, 1) + + # get predited aligned distribution + z = y_mean * y_mask + + # reverse the decoder and predict using the aligned distribution + y, logdet = self.decoder(z, y_mask, g=g, reverse=True) + outputs = { + "model_outputs": z.transpose(1, 2), + "logdet": logdet, + "y_mean": y_mean.transpose(1, 2), + "y_log_scale": y_log_scale.transpose(1, 2), + "alignments": attn, + "durations_log": o_dur_log.transpose(1, 2), + "total_durations_log": o_attn_dur.transpose(1, 2), + } + return outputs + + @torch.no_grad() + def decoder_inference( + self, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None} + ): # pylint: disable=dangerous-default-value + """ + Shapes: + - y: :math:`[B, T, C]` + - y_lengths: :math:`B` + - g: :math:`[B, C] or B` + """ + y = y.transpose(1, 2) + y_max_length = y.size(2) + g = self._speaker_embedding(aux_input) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(y.dtype) + # decoder pass + z, logdet = self.decoder(y, y_mask, g=g, reverse=False) + # reverse decoder and predict + y, logdet = self.decoder(z, y_mask, g=g, reverse=True) + outputs = {} + outputs["model_outputs"] = y.transpose(1, 2) + outputs["logdet"] = logdet + return outputs + + @torch.no_grad() + def inference( + self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None} + ): # pylint: disable=dangerous-default-value + x_lengths = aux_input["x_lengths"] + g = self._speaker_embedding(aux_input) + # embedding pass + o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) + # compute output durations + w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale + w_ceil = torch.clamp_min(torch.ceil(w), 1) + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_max_length = None + # compute masks + y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + # compute attention mask + attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) + y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask) + + z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * self.inference_noise_scale) * y_mask + # decoder pass + y, logdet = self.decoder(z, y_mask, g=g, reverse=True) + attn = attn.squeeze(1).permute(0, 2, 1) + outputs = { + "model_outputs": y.transpose(1, 2), + "logdet": logdet, + "y_mean": y_mean.transpose(1, 2), + "y_log_scale": 
y_log_scale.transpose(1, 2), + "alignments": attn, + "durations_log": o_dur_log.transpose(1, 2), + "total_durations_log": o_attn_dur.transpose(1, 2), + } + return outputs + + def train_step(self, batch: dict, criterion: nn.Module): + """A single training step. Forward pass and loss computation. Run data depended initialization for the + first `config.data_dep_init_steps` steps. + + Args: + batch (dict): [description] + criterion (nn.Module): [description] + """ + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + + if self.run_data_dep_init and self.training: + # compute data-dependent initialization of activation norm layers + self.unlock_act_norm_layers() + with torch.no_grad(): + _ = self.forward( + text_input, + text_lengths, + mel_input, + mel_lengths, + aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids}, + ) + outputs = None + loss_dict = None + self.lock_act_norm_layers() + else: + # normal training step + outputs = self.forward( + text_input, + text_lengths, + mel_input, + mel_lengths, + aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids}, + ) + + with autocast(enabled=False): # avoid mixed_precision in criterion + loss_dict = criterion( + outputs["z"].float(), + outputs["y_mean"].float(), + outputs["y_log_scale"].float(), + outputs["logdet"].float(), + mel_lengths, + outputs["durations_log"].float(), + outputs["total_durations_log"].float(), + text_lengths, + ) + return outputs, loss_dict + + def _create_logs(self, batch, outputs, ap): + alignments = outputs["alignments"] + text_input = batch["text_input"][:1] if batch["text_input"] is not None else None + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + d_vectors = batch["d_vectors"][:1] if batch["d_vectors"] is not None else None + speaker_ids = batch["speaker_ids"][:1] if batch["speaker_ids"] is not None else None + + # model runs reverse flow to predict spectrograms + pred_outputs = self.inference( + text_input, + aux_input={"x_lengths": text_lengths[:1], "d_vectors": d_vectors, "speaker_ids": speaker_ids}, + ) + model_outputs = pred_outputs["model_outputs"] + + pred_spec = model_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + # Sample audio + train_audio = ap.inv_melspectrogram(pred_spec.T) + return figures, {"audio": train_audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ) -> None: # pylint: disable=no-self-use + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + @torch.no_grad() + def eval_step(self, batch: dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + @torch.no_grad() + def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: + """Generic test run for `tts` models used by `Trainer`. 
+
+        You can override this for a different behaviour.
+
+        Returns:
+            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
+        """
+        print(" | > Synthesizing test sentences.")
+        test_audios = {}
+        test_figures = {}
+        test_sentences = self.config.test_sentences
+        aux_inputs = self._get_test_aux_input()
+        if len(test_sentences) == 0:
+            print(" | [!] No test sentences provided.")
+        else:
+            for idx, sen in enumerate(test_sentences):
+                outputs = synthesis(
+                    self,
+                    sen,
+                    self.config,
+                    "cuda" in str(next(self.parameters()).device),
+                    speaker_id=aux_inputs["speaker_id"],
+                    d_vector=aux_inputs["d_vector"],
+                    style_wav=aux_inputs["style_wav"],
+                    use_griffin_lim=True,
+                    do_trim_silence=False,
+                )
+
+                test_audios["{}-audio".format(idx)] = outputs["wav"]
+                test_figures["{}-prediction".format(idx)] = plot_spectrogram(
+                    outputs["outputs"]["model_outputs"], self.ap, output_fig=False
+                )
+                test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False)
+        return test_figures, test_audios
+
+    def preprocess(self, y, y_lengths, y_max_length, attn=None):
+        if y_max_length is not None:
+            y_max_length = (y_max_length // self.num_squeeze) * self.num_squeeze
+            y = y[:, :, :y_max_length]
+            if attn is not None:
+                attn = attn[:, :, :, :y_max_length]
+        y_lengths = (y_lengths // self.num_squeeze) * self.num_squeeze
+        return y, y_lengths, y_max_length, attn
+
+    def store_inverse(self):
+        self.decoder.store_inverse()
+
+    def load_checkpoint(
+        self, config, checkpoint_path, eval=False
+    ):  # pylint: disable=unused-argument, redefined-builtin
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        self.load_state_dict(state["model"])
+        if eval:
+            self.eval()
+            self.store_inverse()
+            assert not self.training
+
+    @staticmethod
+    def get_criterion():
+        from TTS.tts.layers.losses import GlowTTSLoss  # pylint: disable=import-outside-toplevel
+
+        return GlowTTSLoss()
+
+    def on_train_step_start(self, trainer):
+        """Decide on every training step whether to enable/disable data-dependent initialization."""
+        self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps
+
+    @staticmethod
+    def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
+        """Initiate model from config
+
+        Args:
+            config (GlowTTSConfig): Model config.
+            samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
+                Defaults to None.
+            verbose (bool): If True, print init messages. Defaults to True.
+ """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config, verbose) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return GlowTTS(new_config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py new file mode 100644 index 0000000000000000000000000000000000000000..7bfa6ba5e4d2ae502a89fb04621e8be0ac771e2f --- /dev/null +++ b/TTS/tts/models/tacotron.py @@ -0,0 +1,410 @@ +# coding: utf-8 + +from typing import Dict, List, Tuple, Union + +import torch +from torch import nn +from torch.cuda.amp.autocast_mode import autocast +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE +from TTS.tts.layers.tacotron.gst_layers import GST +from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG +from TTS.tts.models.base_tacotron import BaseTacotron +from TTS.tts.utils.measures import alignment_diagonal_score +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.capacitron_optimizer import CapacitronOptimizer + + +class Tacotron(BaseTacotron): + """Tacotron as in https://arxiv.org/abs/1703.10135 + It's an autoregressive encoder-attention-decoder-postnet architecture. + Check `TacotronConfig` for the arguments. + + Args: + config (TacotronConfig): Configuration for the Tacotron model. + speaker_manager (SpeakerManager): Speaker manager to handle multi-speaker settings. Only use if the model is + a multi-speaker model. Defaults to None. + """ + + def __init__( + self, + config: "TacotronConfig", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + + super().__init__(config, ap, tokenizer, speaker_manager) + + # pass all config fields to `self` + # for fewer code change + for key in config: + setattr(self, key, config[key]) + + # set speaker embedding channel size for determining `in_channels` for the connected layers. + # `init_multispeaker` needs to be called once more in training to initialize the speaker embedding layer based + # on the number of speakers infered from the dataset. 
+ if self.use_speaker_embedding or self.use_d_vector_file: + self.init_multispeaker(config) + self.decoder_in_features += self.embedded_speaker_dim # add speaker embedding dim + + if self.use_gst: + self.decoder_in_features += self.gst.gst_embedding_dim + + if self.use_capacitron_vae: + self.decoder_in_features += self.capacitron_vae.capacitron_VAE_embedding_dim + + # embedding layer + self.embedding = nn.Embedding(self.num_chars, 256, padding_idx=0) + self.embedding.weight.data.normal_(0, 0.3) + + # base model layers + self.encoder = Encoder(self.encoder_in_features) + self.decoder = Decoder( + self.decoder_in_features, + self.decoder_output_dim, + self.r, + self.memory_size, + self.attention_type, + self.windowing, + self.attention_norm, + self.prenet_type, + self.prenet_dropout, + self.use_forward_attn, + self.transition_agent, + self.forward_attn_mask, + self.location_attn, + self.attention_heads, + self.separate_stopnet, + self.max_decoder_steps, + ) + self.postnet = PostCBHG(self.decoder_output_dim) + self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, self.out_channels) + + # setup prenet dropout + self.decoder.prenet.dropout_at_inference = self.prenet_dropout_at_inference + + # global style token layers + if self.gst and self.use_gst: + self.gst_layer = GST( + num_mel=self.decoder_output_dim, + num_heads=self.gst.gst_num_heads, + num_style_tokens=self.gst.gst_num_style_tokens, + gst_embedding_dim=self.gst.gst_embedding_dim, + ) + + # Capacitron layers + if self.capacitron_vae and self.use_capacitron_vae: + self.capacitron_vae_layer = CapacitronVAE( + num_mel=self.decoder_output_dim, + encoder_output_dim=self.encoder_in_features, + capacitron_VAE_embedding_dim=self.capacitron_vae.capacitron_VAE_embedding_dim, + speaker_embedding_dim=self.embedded_speaker_dim + if self.use_speaker_embedding and self.capacitron_vae.capacitron_use_speaker_embedding + else None, + text_summary_embedding_dim=self.capacitron_vae.capacitron_text_summary_embedding_dim + if self.capacitron_vae.capacitron_use_text_summary_embeddings + else None, + ) + + # backward pass decoder + if self.bidirectional_decoder: + self._init_backward_decoder() + # setup DDC + if self.double_decoder_consistency: + self.coarse_decoder = Decoder( + self.decoder_in_features, + self.decoder_output_dim, + self.ddc_r, + self.memory_size, + self.attention_type, + self.windowing, + self.attention_norm, + self.prenet_type, + self.prenet_dropout, + self.use_forward_attn, + self.transition_agent, + self.forward_attn_mask, + self.location_attn, + self.attention_heads, + self.separate_stopnet, + self.max_decoder_steps, + ) + + def forward( # pylint: disable=dangerous-default-value + self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input={"speaker_ids": None, "d_vectors": None} + ): + """ + Shapes: + text: [B, T_in] + text_lengths: [B] + mel_specs: [B, T_out, C] + mel_lengths: [B] + aux_input: 'speaker_ids': [B, 1] and 'd_vectors':[B, C] + """ + aux_input = self._format_aux_input(aux_input) + outputs = {"alignments_backward": None, "decoder_outputs_backward": None} + inputs = self.embedding(text) + input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths) + # B x T_in x encoder_in_features + encoder_outputs = self.encoder(inputs) + # sequence masking + encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs) + # global style token + if self.gst and self.use_gst: + # B x gst_dim + encoder_outputs = self.compute_gst(encoder_outputs, mel_specs) + # speaker embedding + if 
self.use_speaker_embedding or self.use_d_vector_file: + if not self.use_d_vector_file: + # B x 1 x speaker_embed_dim + embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None] + else: + # B x 1 x speaker_embed_dim + embedded_speakers = torch.unsqueeze(aux_input["d_vectors"], 1) + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers) + # Capacitron + if self.capacitron_vae and self.use_capacitron_vae: + # B x capacitron_VAE_embedding_dim + encoder_outputs, *capacitron_vae_outputs = self.compute_capacitron_VAE_embedding( + encoder_outputs, + reference_mel_info=[mel_specs, mel_lengths], + text_info=[inputs, text_lengths] + if self.capacitron_vae.capacitron_use_text_summary_embeddings + else None, + speaker_embedding=embedded_speakers if self.capacitron_vae.capacitron_use_speaker_embedding else None, + ) + else: + capacitron_vae_outputs = None + # decoder_outputs: B x decoder_in_features x T_out + # alignments: B x T_in x encoder_in_features + # stop_tokens: B x T_in + decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask) + # sequence masking + if output_mask is not None: + decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs) + # B x T_out x decoder_in_features + postnet_outputs = self.postnet(decoder_outputs) + # sequence masking + if output_mask is not None: + postnet_outputs = postnet_outputs * output_mask.unsqueeze(2).expand_as(postnet_outputs) + # B x T_out x posnet_dim + postnet_outputs = self.last_linear(postnet_outputs) + # B x T_out x decoder_in_features + decoder_outputs = decoder_outputs.transpose(1, 2).contiguous() + if self.bidirectional_decoder: + decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask) + outputs["alignments_backward"] = alignments_backward + outputs["decoder_outputs_backward"] = decoder_outputs_backward + if self.double_decoder_consistency: + decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass( + mel_specs, encoder_outputs, alignments, input_mask + ) + outputs["alignments_backward"] = alignments_backward + outputs["decoder_outputs_backward"] = decoder_outputs_backward + outputs.update( + { + "model_outputs": postnet_outputs, + "decoder_outputs": decoder_outputs, + "alignments": alignments, + "stop_tokens": stop_tokens, + "capacitron_vae_outputs": capacitron_vae_outputs, + } + ) + return outputs + + @torch.no_grad() + def inference(self, text_input, aux_input=None): + aux_input = self._format_aux_input(aux_input) + inputs = self.embedding(text_input) + encoder_outputs = self.encoder(inputs) + if self.gst and self.use_gst: + # B x gst_dim + encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"]) + if self.capacitron_vae and self.use_capacitron_vae: + if aux_input["style_text"] is not None: + style_text_embedding = self.embedding(aux_input["style_text"]) + style_text_length = torch.tensor([style_text_embedding.size(1)], dtype=torch.int64).to( + encoder_outputs.device + ) # pylint: disable=not-callable + reference_mel_length = ( + torch.tensor([aux_input["style_mel"].size(1)], dtype=torch.int64).to(encoder_outputs.device) + if aux_input["style_mel"] is not None + else None + ) # pylint: disable=not-callable + # B x capacitron_VAE_embedding_dim + encoder_outputs, *_ = self.compute_capacitron_VAE_embedding( + encoder_outputs, + reference_mel_info=[aux_input["style_mel"], reference_mel_length] + if aux_input["style_mel"] is not 
None + else None, + text_info=[style_text_embedding, style_text_length] if aux_input["style_text"] is not None else None, + speaker_embedding=aux_input["d_vectors"] + if self.capacitron_vae.capacitron_use_speaker_embedding + else None, + ) + if self.num_speakers > 1: + if not self.use_d_vector_file: + # B x 1 x speaker_embed_dim + embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"]) + # reshape embedded_speakers + if embedded_speakers.ndim == 1: + embedded_speakers = embedded_speakers[None, None, :] + elif embedded_speakers.ndim == 2: + embedded_speakers = embedded_speakers[None, :] + else: + # B x 1 x speaker_embed_dim + embedded_speakers = torch.unsqueeze(aux_input["d_vectors"], 1) + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers) + decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs) + postnet_outputs = self.postnet(decoder_outputs) + postnet_outputs = self.last_linear(postnet_outputs) + decoder_outputs = decoder_outputs.transpose(1, 2) + outputs = { + "model_outputs": postnet_outputs, + "decoder_outputs": decoder_outputs, + "alignments": alignments, + "stop_tokens": stop_tokens, + } + return outputs + + def before_backward_pass(self, loss_dict, optimizer) -> None: + # Extracting custom training specific operations for capacitron + # from the trainer + if self.use_capacitron_vae: + loss_dict["capacitron_vae_beta_loss"].backward() + optimizer.first_step() + + def train_step(self, batch: Dict, criterion: torch.nn.Module) -> Tuple[Dict, Dict]: + """Perform a single training step by fetching the right set of samples from the batch. + + Args: + batch ([Dict]): A dictionary of input tensors. + criterion ([torch.nn.Module]): Callable criterion to compute model loss. + """ + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + linear_input = batch["linear_input"] + stop_targets = batch["stop_targets"] + stop_target_lengths = batch["stop_target_lengths"] + speaker_ids = batch["speaker_ids"] + d_vectors = batch["d_vectors"] + + aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors} + outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input) + + # set the [alignment] lengths wrt reduction factor for guided attention + if mel_lengths.max() % self.decoder.r != 0: + alignment_lengths = ( + mel_lengths + (self.decoder.r - (mel_lengths.max() % self.decoder.r)) + ) // self.decoder.r + else: + alignment_lengths = mel_lengths // self.decoder.r + + # compute loss + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion( + outputs["model_outputs"].float(), + outputs["decoder_outputs"].float(), + mel_input.float(), + linear_input.float(), + outputs["stop_tokens"].float(), + stop_targets.float(), + stop_target_lengths, + outputs["capacitron_vae_outputs"] if self.capacitron_vae else None, + mel_lengths, + None if outputs["decoder_outputs_backward"] is None else outputs["decoder_outputs_backward"].float(), + outputs["alignments"].float(), + alignment_lengths, + None if outputs["alignments_backward"] is None else outputs["alignments_backward"].float(), + text_lengths, + ) + + # compute alignment error (the lower the better ) + align_error = 1 - alignment_diagonal_score(outputs["alignments"]) + loss_dict["align_error"] = align_error + return outputs, loss_dict + + def get_optimizer(self) -> List: + if self.use_capacitron_vae: + return CapacitronOptimizer(self.config, 
self.named_parameters()) + return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self) + + def get_scheduler(self, optimizer: object): + opt = optimizer.primary_optimizer if self.use_capacitron_vae else optimizer + return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, opt) + + def before_gradient_clipping(self): + if self.use_capacitron_vae: + # Capacitron model specific gradient clipping + model_params_to_clip = [] + for name, param in self.named_parameters(): + if param.requires_grad: + if name != "capacitron_vae_layer.beta": + model_params_to_clip.append(param) + torch.nn.utils.clip_grad_norm_(model_params_to_clip, self.capacitron_vae.capacitron_grad_clip) + + def _create_logs(self, batch, outputs, ap): + postnet_outputs = outputs["model_outputs"] + decoder_outputs = outputs["decoder_outputs"] + alignments = outputs["alignments"] + alignments_backward = outputs["alignments_backward"] + mel_input = batch["mel_input"] + linear_input = batch["linear_input"] + + pred_linear_spec = postnet_outputs[0].data.cpu().numpy() + pred_mel_spec = decoder_outputs[0].data.cpu().numpy() + gt_linear_spec = linear_input[0].data.cpu().numpy() + gt_mel_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "pred_linear_spec": plot_spectrogram(pred_linear_spec, ap, output_fig=False), + "real_linear_spec": plot_spectrogram(gt_linear_spec, ap, output_fig=False), + "pred_mel_spec": plot_spectrogram(pred_mel_spec, ap, output_fig=False), + "real_mel_spec": plot_spectrogram(gt_mel_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + if self.bidirectional_decoder or self.double_decoder_consistency: + figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False) + + # Sample audio + audio = ap.inv_spectrogram(pred_linear_spec.T) + return figures, {"audio": audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ) -> None: # pylint: disable=no-self-use + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_step(self, batch: dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + @staticmethod + def init_from_config(config: "TacotronConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (TacotronConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. 
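+
+        Example:
+            A minimal sketch, assuming the default `TacotronConfig` (single speaker,
+            default characters and audio settings):
+
+            >>> from TTS.tts.configs.tacotron_config import TacotronConfig
+            >>> from TTS.tts.models.tacotron import Tacotron
+            >>> config = TacotronConfig()
+            >>> model = Tacotron.init_from_config(config, samples=None)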
+ """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return Tacotron(new_config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py new file mode 100644 index 0000000000000000000000000000000000000000..95d339f17d54f7130e2bd5d435620df41d100b6d --- /dev/null +++ b/TTS/tts/models/tacotron2.py @@ -0,0 +1,434 @@ +# coding: utf-8 + +from typing import Dict, List, Union + +import torch +from torch import nn +from torch.cuda.amp.autocast_mode import autocast +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE +from TTS.tts.layers.tacotron.gst_layers import GST +from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet +from TTS.tts.models.base_tacotron import BaseTacotron +from TTS.tts.utils.measures import alignment_diagonal_score +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.capacitron_optimizer import CapacitronOptimizer + + +class Tacotron2(BaseTacotron): + """Tacotron2 model implementation inherited from :class:`TTS.tts.models.base_tacotron.BaseTacotron`. + + Paper:: + https://arxiv.org/abs/1712.05884 + + Paper abstract:: + This paper describes Tacotron 2, a neural network architecture for speech synthesis directly from text. + The system is composed of a recurrent sequence-to-sequence feature prediction network that maps character + embeddings to mel-scale spectrograms, followed by a modified WaveNet model acting as a vocoder to synthesize + timedomain waveforms from those spectrograms. Our model achieves a mean opinion score (MOS) of 4.53 comparable + to a MOS of 4.58 for professionally recorded speech. To validate our design choices, we present ablation + studies of key components of our system and evaluate the impact of using mel spectrograms as the input to + WaveNet instead of linguistic, duration, and F0 features. We further demonstrate that using a compact acoustic + intermediate representation enables significant simplification of the WaveNet architecture. + + Check :class:`TTS.tts.configs.tacotron2_config.Tacotron2Config` for model arguments. + + Args: + config (TacotronConfig): + Configuration for the Tacotron2 model. + speaker_manager (SpeakerManager): + Speaker manager for multi-speaker training. Uuse only for multi-speaker training. Defaults to None. 
+ """ + + def __init__( + self, + config: "Tacotron2Config", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + + super().__init__(config, ap, tokenizer, speaker_manager) + + self.decoder_output_dim = config.out_channels + + # pass all config fields to `self` + # for fewer code change + for key in config: + setattr(self, key, config[key]) + + # init multi-speaker layers + if self.use_speaker_embedding or self.use_d_vector_file: + self.init_multispeaker(config) + self.decoder_in_features += self.embedded_speaker_dim # add speaker embedding dim + + if self.use_gst: + self.decoder_in_features += self.gst.gst_embedding_dim + + if self.use_capacitron_vae: + self.decoder_in_features += self.capacitron_vae.capacitron_VAE_embedding_dim + + # embedding layer + self.embedding = nn.Embedding(self.num_chars, 512, padding_idx=0) + + # base model layers + self.encoder = Encoder(self.encoder_in_features) + + self.decoder = Decoder( + self.decoder_in_features, + self.decoder_output_dim, + self.r, + self.attention_type, + self.attention_win, + self.attention_norm, + self.prenet_type, + self.prenet_dropout, + self.use_forward_attn, + self.transition_agent, + self.forward_attn_mask, + self.location_attn, + self.attention_heads, + self.separate_stopnet, + self.max_decoder_steps, + ) + self.postnet = Postnet(self.out_channels) + + # setup prenet dropout + self.decoder.prenet.dropout_at_inference = self.prenet_dropout_at_inference + + # global style token layers + if self.gst and self.use_gst: + self.gst_layer = GST( + num_mel=self.decoder_output_dim, + num_heads=self.gst.gst_num_heads, + num_style_tokens=self.gst.gst_num_style_tokens, + gst_embedding_dim=self.gst.gst_embedding_dim, + ) + + # Capacitron VAE Layers + if self.capacitron_vae and self.use_capacitron_vae: + self.capacitron_vae_layer = CapacitronVAE( + num_mel=self.decoder_output_dim, + encoder_output_dim=self.encoder_in_features, + capacitron_VAE_embedding_dim=self.capacitron_vae.capacitron_VAE_embedding_dim, + speaker_embedding_dim=self.embedded_speaker_dim + if self.capacitron_vae.capacitron_use_speaker_embedding + else None, + text_summary_embedding_dim=self.capacitron_vae.capacitron_text_summary_embedding_dim + if self.capacitron_vae.capacitron_use_text_summary_embeddings + else None, + ) + + # backward pass decoder + if self.bidirectional_decoder: + self._init_backward_decoder() + # setup DDC + if self.double_decoder_consistency: + self.coarse_decoder = Decoder( + self.decoder_in_features, + self.decoder_output_dim, + self.ddc_r, + self.attention_type, + self.attention_win, + self.attention_norm, + self.prenet_type, + self.prenet_dropout, + self.use_forward_attn, + self.transition_agent, + self.forward_attn_mask, + self.location_attn, + self.attention_heads, + self.separate_stopnet, + self.max_decoder_steps, + ) + + @staticmethod + def shape_outputs(mel_outputs, mel_outputs_postnet, alignments): + """Final reshape of the model output tensors.""" + mel_outputs = mel_outputs.transpose(1, 2) + mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2) + return mel_outputs, mel_outputs_postnet, alignments + + def forward( # pylint: disable=dangerous-default-value + self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input={"speaker_ids": None, "d_vectors": None} + ): + """Forward pass for training with Teacher Forcing. 
+ + Shapes: + text: :math:`[B, T_in]` + text_lengths: :math:`[B]` + mel_specs: :math:`[B, T_out, C]` + mel_lengths: :math:`[B]` + aux_input: 'speaker_ids': :math:`[B, 1]` and 'd_vectors': :math:`[B, C]` + """ + aux_input = self._format_aux_input(aux_input) + outputs = {"alignments_backward": None, "decoder_outputs_backward": None} + # compute mask for padding + # B x T_in_max (boolean) + input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths) + # B x D_embed x T_in_max + embedded_inputs = self.embedding(text).transpose(1, 2) + # B x T_in_max x D_en + encoder_outputs = self.encoder(embedded_inputs, text_lengths) + if self.gst and self.use_gst: + # B x gst_dim + encoder_outputs = self.compute_gst(encoder_outputs, mel_specs) + + if self.use_speaker_embedding or self.use_d_vector_file: + if not self.use_d_vector_file: + # B x 1 x speaker_embed_dim + embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None] + else: + # B x 1 x speaker_embed_dim + embedded_speakers = torch.unsqueeze(aux_input["d_vectors"], 1) + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers) + + # capacitron + if self.capacitron_vae and self.use_capacitron_vae: + # B x capacitron_VAE_embedding_dim + encoder_outputs, *capacitron_vae_outputs = self.compute_capacitron_VAE_embedding( + encoder_outputs, + reference_mel_info=[mel_specs, mel_lengths], + text_info=[embedded_inputs.transpose(1, 2), text_lengths] + if self.capacitron_vae.capacitron_use_text_summary_embeddings + else None, + speaker_embedding=embedded_speakers if self.capacitron_vae.capacitron_use_speaker_embedding else None, + ) + else: + capacitron_vae_outputs = None + + encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs) + + # B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r + decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask) + # sequence masking + if mel_lengths is not None: + decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs) + # B x mel_dim x T_out + postnet_outputs = self.postnet(decoder_outputs) + postnet_outputs = decoder_outputs + postnet_outputs + # sequence masking + if output_mask is not None: + postnet_outputs = postnet_outputs * output_mask.unsqueeze(1).expand_as(postnet_outputs) + # B x T_out x mel_dim -- B x T_out x mel_dim -- B x T_out//r x T_in + decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments) + if self.bidirectional_decoder: + decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask) + outputs["alignments_backward"] = alignments_backward + outputs["decoder_outputs_backward"] = decoder_outputs_backward + if self.double_decoder_consistency: + decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass( + mel_specs, encoder_outputs, alignments, input_mask + ) + outputs["alignments_backward"] = alignments_backward + outputs["decoder_outputs_backward"] = decoder_outputs_backward + outputs.update( + { + "model_outputs": postnet_outputs, + "decoder_outputs": decoder_outputs, + "alignments": alignments, + "stop_tokens": stop_tokens, + "capacitron_vae_outputs": capacitron_vae_outputs, + } + ) + return outputs + + @torch.no_grad() + def inference(self, text, aux_input=None): + """Forward pass for inference with no Teacher-Forcing. 
+ + Shapes: + text: :math:`[B, T_in]` + text_lengths: :math:`[B]` + """ + aux_input = self._format_aux_input(aux_input) + embedded_inputs = self.embedding(text).transpose(1, 2) + encoder_outputs = self.encoder.inference(embedded_inputs) + + if self.gst and self.use_gst: + # B x gst_dim + encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"]) + + if self.capacitron_vae and self.use_capacitron_vae: + if aux_input["style_text"] is not None: + style_text_embedding = self.embedding(aux_input["style_text"]) + style_text_length = torch.tensor([style_text_embedding.size(1)], dtype=torch.int64).to( + encoder_outputs.device + ) # pylint: disable=not-callable + reference_mel_length = ( + torch.tensor([aux_input["style_mel"].size(1)], dtype=torch.int64).to(encoder_outputs.device) + if aux_input["style_mel"] is not None + else None + ) # pylint: disable=not-callable + # B x capacitron_VAE_embedding_dim + encoder_outputs, *_ = self.compute_capacitron_VAE_embedding( + encoder_outputs, + reference_mel_info=[aux_input["style_mel"], reference_mel_length] + if aux_input["style_mel"] is not None + else None, + text_info=[style_text_embedding, style_text_length] if aux_input["style_text"] is not None else None, + speaker_embedding=aux_input["d_vectors"] + if self.capacitron_vae.capacitron_use_speaker_embedding + else None, + ) + + if self.num_speakers > 1: + if not self.use_d_vector_file: + embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[None] + # reshape embedded_speakers + if embedded_speakers.ndim == 1: + embedded_speakers = embedded_speakers[None, None, :] + elif embedded_speakers.ndim == 2: + embedded_speakers = embedded_speakers[None, :] + else: + embedded_speakers = aux_input["d_vectors"] + + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers) + + decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs) + postnet_outputs = self.postnet(decoder_outputs) + postnet_outputs = decoder_outputs + postnet_outputs + decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments) + outputs = { + "model_outputs": postnet_outputs, + "decoder_outputs": decoder_outputs, + "alignments": alignments, + "stop_tokens": stop_tokens, + } + return outputs + + def before_backward_pass(self, loss_dict, optimizer) -> None: + # Extracting custom training specific operations for capacitron + # from the trainer + if self.use_capacitron_vae: + loss_dict["capacitron_vae_beta_loss"].backward() + optimizer.first_step() + + def train_step(self, batch: Dict, criterion: torch.nn.Module): + """A single training step. Forward pass and loss computation. + + Args: + batch ([Dict]): A dictionary of input tensors. + criterion ([type]): Callable criterion to compute model loss. 
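+
+        Example:
+            A small sketch of the alignment-length bookkeeping done below for guided
+            attention (illustrative values; `r` is the decoder reduction factor):
+
+            >>> import torch
+            >>> r, mel_lengths = 3, torch.tensor([10, 11])
+            >>> (mel_lengths + (r - mel_lengths.max() % r)) // r  # max length padded up to a multiple of r
+            tensor([3, 4])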
+ """ + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + stop_targets = batch["stop_targets"] + stop_target_lengths = batch["stop_target_lengths"] + speaker_ids = batch["speaker_ids"] + d_vectors = batch["d_vectors"] + + aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors} + outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input) + + # set the [alignment] lengths wrt reduction factor for guided attention + if mel_lengths.max() % self.decoder.r != 0: + alignment_lengths = ( + mel_lengths + (self.decoder.r - (mel_lengths.max() % self.decoder.r)) + ) // self.decoder.r + else: + alignment_lengths = mel_lengths // self.decoder.r + + # compute loss + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion( + outputs["model_outputs"].float(), + outputs["decoder_outputs"].float(), + mel_input.float(), + None, + outputs["stop_tokens"].float(), + stop_targets.float(), + stop_target_lengths, + outputs["capacitron_vae_outputs"] if self.capacitron_vae else None, + mel_lengths, + None if outputs["decoder_outputs_backward"] is None else outputs["decoder_outputs_backward"].float(), + outputs["alignments"].float(), + alignment_lengths, + None if outputs["alignments_backward"] is None else outputs["alignments_backward"].float(), + text_lengths, + ) + + # compute alignment error (the lower the better ) + align_error = 1 - alignment_diagonal_score(outputs["alignments"]) + loss_dict["align_error"] = align_error + return outputs, loss_dict + + def get_optimizer(self) -> List: + if self.use_capacitron_vae: + return CapacitronOptimizer(self.config, self.named_parameters()) + return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self) + + def get_scheduler(self, optimizer: object): + opt = optimizer.primary_optimizer if self.use_capacitron_vae else optimizer + return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, opt) + + def before_gradient_clipping(self): + if self.use_capacitron_vae: + # Capacitron model specific gradient clipping + model_params_to_clip = [] + for name, param in self.named_parameters(): + if param.requires_grad: + if name != "capacitron_vae_layer.beta": + model_params_to_clip.append(param) + torch.nn.utils.clip_grad_norm_(model_params_to_clip, self.capacitron_vae.capacitron_grad_clip) + + def _create_logs(self, batch, outputs, ap): + """Create dashboard log information.""" + postnet_outputs = outputs["model_outputs"] + alignments = outputs["alignments"] + alignments_backward = outputs["alignments_backward"] + mel_input = batch["mel_input"] + + pred_spec = postnet_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + if self.bidirectional_decoder or self.double_decoder_consistency: + figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False) + + # Sample audio + audio = ap.inv_melspectrogram(pred_spec.T) + return figures, {"audio": audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ) -> None: # pylint: disable=no-self-use + """Log training progress.""" + figures, audios = 
self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_step(self, batch: dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + @staticmethod + def init_from_config(config: "Tacotron2Config", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (Tacotron2Config): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(new_config, samples) + return Tacotron2(new_config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py new file mode 100644 index 0000000000000000000000000000000000000000..a6b1c74332a89757fe28c143acd18c938bd5f274 --- /dev/null +++ b/TTS/tts/models/vits.py @@ -0,0 +1,1704 @@ +import math +import os +from dataclasses import dataclass, field, replace +from itertools import chain +from typing import Dict, List, Tuple, Union + +import torch +import torch.distributed as dist +import torchaudio +from coqpit import Coqpit +from librosa.filters import mel as librosa_mel_fn +from torch import nn +from torch.cuda.amp.autocast_mode import autocast +from torch.nn import functional as F +from torch.utils.data import DataLoader +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.tts.configs.shared_configs import CharactersConfig +from TTS.tts.datasets.dataset import TTSDataset, _parse_sample +from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor +from TTS.tts.layers.vits.discriminator import VitsDiscriminator +from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder +from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask +from TTS.tts.utils.languages import LanguageManager +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment +from TTS.vocoder.models.hifigan_generator import HifiganGenerator +from TTS.vocoder.utils.generic_utils import plot_results + +############################## +# IO / Feature extraction +############################## + +# pylint: disable=global-statement +hann_window = {} +mel_basis = {} + + +@torch.no_grad() +def weights_reset(m: nn.Module): + # check if the current module has reset_parameters and if it is reset the weight + reset_parameters = getattr(m, "reset_parameters", None) + if callable(reset_parameters): + m.reset_parameters() + + +def get_module_weights_sum(mdl: nn.Module): + dict_sums = {} + for name, w in mdl.named_parameters(): + if "weight" in name: + value = w.data.sum().item() + dict_sums[name] = value + 
return dict_sums + + +def load_audio(file_path): + """Load the audio file normalized in [-1, 1] + + Return Shapes: + - x: :math:`[1, T]` + """ + x, sr = torchaudio.load(file_path) + assert (x > 1).sum() + (x < -1).sum() == 0 + return x, sr + + +def _amp_to_db(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def _db_to_amp(x, C=1): + return torch.exp(x) / C + + +def amp_to_db(magnitudes): + output = _amp_to_db(magnitudes) + return output + + +def db_to_amp(magnitudes): + output = _db_to_amp(magnitudes) + return output + + +def wav_to_spec(y, n_fft, hop_length, win_length, center=False): + """ + Args Shapes: + - y : :math:`[B, 1, T]` + + Return Shapes: + - spec : :math:`[B,C,T]` + """ + y = y.squeeze(1) + + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_length) + "_" + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax): + """ + Args Shapes: + - spec : :math:`[B,C,T]` + + Return Shapes: + - mel : :math:`[B,C,T]` + """ + global mel_basis + dtype_device = str(spec.dtype) + "_" + str(spec.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sample_rate, n_fft, num_mels, fmin, fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + mel = torch.matmul(mel_basis[fmax_dtype_device], spec) + mel = amp_to_db(mel) + return mel + + +def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False): + """ + Args Shapes: + - y : :math:`[B, 1, T]` + + Return Shapes: + - spec : :math:`[B,C,T]` + """ + y = y.squeeze(1) + + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + wnsize_dtype_device = str(win_length) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sample_rate, n_fft, num_mels, fmin, fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + spec = torch.matmul(mel_basis[fmax_dtype_device], 
spec) + spec = amp_to_db(spec) + return spec + + +############################## +# DATASET +############################## + + +class VitsDataset(TTSDataset): + def __init__(self, model_args, *args, **kwargs): + super().__init__(*args, **kwargs) + self.pad_id = self.tokenizer.characters.pad_id + self.model_args = model_args + + def __getitem__(self, idx): + item = self.samples[idx] + raw_text = item["text"] + + wav, _ = load_audio(item["audio_file"]) + if self.model_args.encoder_sample_rate is not None: + if wav.size(1) % self.model_args.encoder_sample_rate != 0: + wav = wav[:, : -int(wav.size(1) % self.model_args.encoder_sample_rate)] + + wav_filename = os.path.basename(item["audio_file"]) + + token_ids = self.get_token_ids(idx, item["text"]) + + # after phonemization the text length may change + # this is a shameful 🤭 hack to prevent longer phonemes + # TODO: find a better fix + if len(token_ids) > self.max_text_len or wav.shape[1] < self.min_audio_len: + self.rescue_item_idx += 1 + return self.__getitem__(self.rescue_item_idx) + + return { + "raw_text": raw_text, + "token_ids": token_ids, + "token_len": len(token_ids), + "wav": wav, + "wav_file": wav_filename, + "speaker_name": item["speaker_name"], + "language_name": item["language"], + } + + @property + def lengths(self): + lens = [] + for item in self.samples: + _, wav_file, *_ = _parse_sample(item) + audio_len = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio + lens.append(audio_len) + return lens + + def collate_fn(self, batch): + """ + Return Shapes: + - tokens: :math:`[B, T]` + - token_lens :math:`[B]` + - token_rel_lens :math:`[B]` + - waveform: :math:`[B, 1, T]` + - waveform_lens: :math:`[B]` + - waveform_rel_lens: :math:`[B]` + - speaker_names: :math:`[B]` + - language_names: :math:`[B]` + - audiofile_paths: :math:`[B]` + - raw_texts: :math:`[B]` + """ + # convert list of dicts to dict of lists + B = len(batch) + batch = {k: [dic[k] for dic in batch] for k in batch[0]} + + _, ids_sorted_decreasing = torch.sort( + torch.LongTensor([x.size(1) for x in batch["wav"]]), dim=0, descending=True + ) + + max_text_len = max([len(x) for x in batch["token_ids"]]) + token_lens = torch.LongTensor(batch["token_len"]) + token_rel_lens = token_lens / token_lens.max() + + wav_lens = [w.shape[1] for w in batch["wav"]] + wav_lens = torch.LongTensor(wav_lens) + wav_lens_max = torch.max(wav_lens) + wav_rel_lens = wav_lens / wav_lens_max + + token_padded = torch.LongTensor(B, max_text_len) + wav_padded = torch.FloatTensor(B, 1, wav_lens_max) + token_padded = token_padded.zero_() + self.pad_id + wav_padded = wav_padded.zero_() + self.pad_id + for i in range(len(ids_sorted_decreasing)): + token_ids = batch["token_ids"][i] + token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids) + + wav = batch["wav"][i] + wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav) + + return { + "tokens": token_padded, + "token_lens": token_lens, + "token_rel_lens": token_rel_lens, + "waveform": wav_padded, # (B x T) + "waveform_lens": wav_lens, # (B) + "waveform_rel_lens": wav_rel_lens, + "speaker_names": batch["speaker_name"], + "language_names": batch["language_name"], + "audio_files": batch["wav_file"], + "raw_text": batch["raw_text"], + } + + +############################## +# MODEL DEFINITION +############################## + + +@dataclass +class VitsArgs(Coqpit): + """VITS model arguments. + + Args: + + num_chars (int): + Number of characters in the vocabulary. Defaults to 100. 
+
+        out_channels (int):
+            Number of output channels of the decoder. Defaults to 513.
+
+        spec_segment_size (int):
+            Decoder input segment size. Defaults to 32 `(32 * hop_length = waveform length)`.
+
+        hidden_channels (int):
+            Number of hidden channels of the model. Defaults to 192.
+
+        hidden_channels_ffn_text_encoder (int):
+            Number of hidden channels of the feed-forward layers of the text encoder transformer. Defaults to 768.
+
+        num_heads_text_encoder (int):
+            Number of attention heads of the text encoder transformer. Defaults to 2.
+
+        num_layers_text_encoder (int):
+            Number of transformer layers in the text encoder. Defaults to 6.
+
+        kernel_size_text_encoder (int):
+            Kernel size of the text encoder transformer FFN layers. Defaults to 3.
+
+        dropout_p_text_encoder (float):
+            Dropout rate of the text encoder. Defaults to 0.1.
+
+        dropout_p_duration_predictor (float):
+            Dropout rate of the duration predictor. Defaults to 0.5.
+
+        kernel_size_posterior_encoder (int):
+            Kernel size of the posterior encoder's WaveNet layers. Defaults to 5.
+
+        dilation_rate_posterior_encoder (int):
+            Dilation rate of the posterior encoder's WaveNet layers. Defaults to 1.
+
+        num_layers_posterior_encoder (int):
+            Number of posterior encoder's WaveNet layers. Defaults to 16.
+
+        kernel_size_flow (int):
+            Kernel size of the Residual Coupling layers of the flow network. Defaults to 5.
+
+        dilation_rate_flow (int):
+            Dilation rate of the Residual Coupling WaveNet layers of the flow network. Defaults to 1.
+
+        num_layers_flow (int):
+            Number of Residual Coupling WaveNet layers of the flow network. Defaults to 4.
+
+        resblock_type_decoder (str):
+            Type of the residual block in the decoder network. Defaults to "1".
+
+        resblock_kernel_sizes_decoder (List[int]):
+            Kernel sizes of the residual blocks in the decoder network. Defaults to `[3, 7, 11]`.
+
+        resblock_dilation_sizes_decoder (List[List[int]]):
+            Dilation sizes of the residual blocks in the decoder network. Defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`.
+
+        upsample_rates_decoder (List[int]):
+            Upsampling rates for each consecutive upsampling layer in the decoder network. The product of these
+            values must equal the hop length used for computing spectrograms. Defaults to `[8, 8, 2, 2]`.
+
+        upsample_initial_channel_decoder (int):
+            Number of hidden channels of the first upsampling convolution layer of the decoder network. Defaults to 512.
+
+        upsample_kernel_sizes_decoder (List[int]):
+            Kernel sizes for each upsampling layer of the decoder network. Defaults to `[16, 16, 4, 4]`.
+
+        periods_multi_period_discriminator (List[int]):
+            Period values for the VITS Multi-Period Discriminator. Defaults to `[2, 3, 5, 7, 11]`.
+
+        use_sdp (bool):
+            Use the Stochastic Duration Predictor. Defaults to True.
+
+        noise_scale (float):
+            Noise scale used for the sample noise tensor in training. Defaults to 1.0.
+
+        inference_noise_scale (float):
+            Noise scale used for the sample noise tensor in inference. Defaults to 0.667.
+
+        length_scale (float):
+            Scale factor for the predicted duration values. Smaller values result in faster speech. Defaults to 1.
+
+        noise_scale_dp (float):
+            Noise scale used by the Stochastic Duration Predictor sample noise in training. Defaults to 1.0.
+
+        inference_noise_scale_dp (float):
+            Noise scale for the Stochastic Duration Predictor in inference. Defaults to 1.0.
+
+        max_inference_len (int):
+            Maximum inference length to limit the memory use. Defaults to None.
+
+        init_discriminator (bool):
+            Initialize the discriminator network if set True. Set False for inference. Defaults to True.
+
+        use_spectral_norm_disriminator (bool):
+            Use spectral normalization over weight norm in the discriminator. Defaults to False.
+
+        use_speaker_embedding (bool):
+            Enable/Disable speaker embedding for multi-speaker models. Defaults to False.
+
+        num_speakers (int):
+            Number of speakers for the speaker embedding layer. Defaults to 0.
+
+        speakers_file (str):
+            Path to the speaker mapping file for the Speaker Manager. Defaults to None.
+
+        speaker_embedding_channels (int):
+            Number of speaker embedding channels. Defaults to 256.
+
+        use_d_vector_file (bool):
+            Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False.
+
+        d_vector_file (str):
+            Path to the file including pre-computed speaker embeddings. Defaults to None.
+
+        d_vector_dim (int):
+            Number of d-vector channels. Defaults to 0.
+
+        detach_dp_input (bool):
+            Detach the duration predictor's input from the network to stop the gradients. Defaults to True.
+
+        use_language_embedding (bool):
+            Enable/Disable language embedding for multilingual models. Defaults to False.
+
+        embedded_language_dim (int):
+            Number of language embedding channels. Defaults to 4.
+
+        num_languages (int):
+            Number of languages for the language embedding layer. Defaults to 0.
+
+        language_ids_file (str):
+            Path to the language mapping file for the Language Manager. Defaults to None.
+
+        use_speaker_encoder_as_loss (bool):
+            Enable/Disable Speaker Consistency Loss (SCL). Defaults to False.
+
+        speaker_encoder_config_path (str):
+            Path to the speaker encoder config file, used for SCL. Defaults to "".
+
+        speaker_encoder_model_path (str):
+            Path to the speaker encoder checkpoint file, used for SCL. Defaults to "".
+
+        condition_dp_on_speaker (bool):
+            Condition the duration predictor on the speaker embedding. Defaults to True.
+
+        freeze_encoder (bool):
+            Freeze the text encoder weights during training. Defaults to False.
+
+        freeze_DP (bool):
+            Freeze the duration predictor weights during training. Defaults to False.
+
+        freeze_PE (bool):
+            Freeze the posterior encoder weights during training. Defaults to False.
+
+        freeze_flow_decoder (bool):
+            Freeze the flow network weights during training. Defaults to False.
+
+        freeze_waveform_decoder (bool):
+            Freeze the waveform decoder weights during training. Defaults to False.
+
+        encoder_sample_rate (int):
+            If not None, this sample rate is used for training the posterior encoder,
+            flow, text encoder and duration predictor. The decoder part (vocoder) is
+            trained with `config.audio.sample_rate`. Defaults to None.
+
+        interpolate_z (bool):
+            If `encoder_sample_rate` is not None and this parameter is True, interpolation
+            is used to upsample the latent variable z from `encoder_sample_rate`
+            to `config.audio.sample_rate`. If it is False, you need to add extra
+            `upsample_rates_decoder` entries to match the shape. Defaults to True.
+
+        reinit_DP (bool):
+            Reinitialize the duration predictor weights at init time. Defaults to False.
+
+        reinit_text_encoder (bool):
+            Reinitialize the text encoder weights at init time. Defaults to False.
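+
+    Example:
+        A minimal sketch of overriding a few fields (illustrative values, not tuned settings):
+
+        >>> from TTS.tts.models.vits import VitsArgs
+        >>> args = VitsArgs(use_speaker_embedding=True, num_speakers=4)
+        >>> args.use_sdp
+        True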
+ + """ + + num_chars: int = 100 + out_channels: int = 513 + spec_segment_size: int = 32 + hidden_channels: int = 192 + hidden_channels_ffn_text_encoder: int = 768 + num_heads_text_encoder: int = 2 + num_layers_text_encoder: int = 6 + kernel_size_text_encoder: int = 3 + dropout_p_text_encoder: float = 0.1 + dropout_p_duration_predictor: float = 0.5 + kernel_size_posterior_encoder: int = 5 + dilation_rate_posterior_encoder: int = 1 + num_layers_posterior_encoder: int = 16 + kernel_size_flow: int = 5 + dilation_rate_flow: int = 1 + num_layers_flow: int = 4 + resblock_type_decoder: str = "1" + resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11]) + resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) + upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2]) + upsample_initial_channel_decoder: int = 512 + upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4]) + periods_multi_period_discriminator: List[int] = field(default_factory=lambda: [2, 3, 5, 7, 11]) + use_sdp: bool = True + noise_scale: float = 1.0 + inference_noise_scale: float = 0.667 + length_scale: float = 1 + noise_scale_dp: float = 1.0 + inference_noise_scale_dp: float = 1.0 + max_inference_len: int = None + init_discriminator: bool = True + use_spectral_norm_disriminator: bool = False + use_speaker_embedding: bool = False + num_speakers: int = 0 + speakers_file: str = None + d_vector_file: str = None + speaker_embedding_channels: int = 256 + use_d_vector_file: bool = False + d_vector_dim: int = 0 + detach_dp_input: bool = True + use_language_embedding: bool = False + embedded_language_dim: int = 4 + num_languages: int = 0 + language_ids_file: str = None + use_speaker_encoder_as_loss: bool = False + speaker_encoder_config_path: str = "" + speaker_encoder_model_path: str = "" + condition_dp_on_speaker: bool = True + freeze_encoder: bool = False + freeze_DP: bool = False + freeze_PE: bool = False + freeze_flow_decoder: bool = False + freeze_waveform_decoder: bool = False + encoder_sample_rate: int = None + interpolate_z: bool = True + reinit_DP: bool = False + reinit_text_encoder: bool = False + + +class Vits(BaseTTS): + """VITS TTS model + + Paper:: + https://arxiv.org/pdf/2106.06103.pdf + + Paper Abstract:: + Several recent end-to-end text-to-speech (TTS) models enabling single-stage training and parallel + sampling have been proposed, but their sample quality does not match that of two-stage TTS systems. + In this work, we present a parallel endto-end TTS method that generates more natural sounding audio than + current two-stage models. Our method adopts variational inference augmented with normalizing flows and + an adversarial training process, which improves the expressive power of generative modeling. We also propose a + stochastic duration predictor to synthesize speech with diverse rhythms from input text. With the + uncertainty modeling over latent variables and the stochastic duration predictor, our method expresses the + natural one-to-many relationship in which a text input can be spoken in multiple ways + with different pitches and rhythms. A subjective human evaluation (mean opinion score, or MOS) + on the LJ Speech, a single speaker dataset, shows that our method outperforms the best publicly + available TTS systems and achieves a MOS comparable to ground truth. + + Check :class:`TTS.tts.configs.vits_config.VitsConfig` for class arguments. 
+ + Examples: + >>> from TTS.tts.configs.vits_config import VitsConfig + >>> from TTS.tts.models.vits import Vits + >>> config = VitsConfig() + >>> model = Vits(config) + """ + + def __init__( + self, + config: Coqpit, + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + language_manager: LanguageManager = None, + ): + + super().__init__(config, ap, tokenizer, speaker_manager, language_manager) + + self.init_multispeaker(config) + self.init_multilingual(config) + self.init_upsampling() + + self.length_scale = self.args.length_scale + self.noise_scale = self.args.noise_scale + self.inference_noise_scale = self.args.inference_noise_scale + self.inference_noise_scale_dp = self.args.inference_noise_scale_dp + self.noise_scale_dp = self.args.noise_scale_dp + self.max_inference_len = self.args.max_inference_len + self.spec_segment_size = self.args.spec_segment_size + + self.text_encoder = TextEncoder( + self.args.num_chars, + self.args.hidden_channels, + self.args.hidden_channels, + self.args.hidden_channels_ffn_text_encoder, + self.args.num_heads_text_encoder, + self.args.num_layers_text_encoder, + self.args.kernel_size_text_encoder, + self.args.dropout_p_text_encoder, + language_emb_dim=self.embedded_language_dim, + ) + + self.posterior_encoder = PosteriorEncoder( + self.args.out_channels, + self.args.hidden_channels, + self.args.hidden_channels, + kernel_size=self.args.kernel_size_posterior_encoder, + dilation_rate=self.args.dilation_rate_posterior_encoder, + num_layers=self.args.num_layers_posterior_encoder, + cond_channels=self.embedded_speaker_dim, + ) + + self.flow = ResidualCouplingBlocks( + self.args.hidden_channels, + self.args.hidden_channels, + kernel_size=self.args.kernel_size_flow, + dilation_rate=self.args.dilation_rate_flow, + num_layers=self.args.num_layers_flow, + cond_channels=self.embedded_speaker_dim, + ) + + if self.args.use_sdp: + self.duration_predictor = StochasticDurationPredictor( + self.args.hidden_channels, + 192, + 3, + self.args.dropout_p_duration_predictor, + 4, + cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0, + language_emb_dim=self.embedded_language_dim, + ) + else: + self.duration_predictor = DurationPredictor( + self.args.hidden_channels, + 256, + 3, + self.args.dropout_p_duration_predictor, + cond_channels=self.embedded_speaker_dim, + language_emb_dim=self.embedded_language_dim, + ) + + self.waveform_decoder = HifiganGenerator( + self.args.hidden_channels, + 1, + self.args.resblock_type_decoder, + self.args.resblock_dilation_sizes_decoder, + self.args.resblock_kernel_sizes_decoder, + self.args.upsample_kernel_sizes_decoder, + self.args.upsample_initial_channel_decoder, + self.args.upsample_rates_decoder, + inference_padding=0, + cond_channels=self.embedded_speaker_dim, + conv_pre_weight_norm=False, + conv_post_weight_norm=False, + conv_post_bias=False, + ) + + if self.args.init_discriminator: + self.disc = VitsDiscriminator( + periods=self.args.periods_multi_period_discriminator, + use_spectral_norm=self.args.use_spectral_norm_disriminator, + ) + + def init_multispeaker(self, config: Coqpit): + """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer + or with external `d_vectors` computed from a speaker encoder model. + + You must provide a `speaker_manager` at initialization to set up the multi-speaker modules. + + Args: + config (Coqpit): Model configuration. 
+ data (List, optional): Dataset items to infer number of speakers. Defaults to None. + """ + self.embedded_speaker_dim = 0 + self.num_speakers = self.args.num_speakers + self.audio_transform = None + + if self.speaker_manager: + self.num_speakers = self.speaker_manager.num_speakers + + if self.args.use_speaker_embedding: + self._init_speaker_embedding() + + if self.args.use_d_vector_file: + self._init_d_vector() + + # TODO: make this a function + if self.args.use_speaker_encoder_as_loss: + if self.speaker_manager.encoder is None and ( + not self.args.speaker_encoder_model_path or not self.args.speaker_encoder_config_path + ): + raise RuntimeError( + " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!" + ) + + self.speaker_manager.encoder.eval() + print(" > External Speaker Encoder Loaded !!") + + if ( + hasattr(self.speaker_manager.encoder, "audio_config") + and self.config.audio["sample_rate"] != self.speaker_manager.encoder.audio_config["sample_rate"] + ): + self.audio_transform = torchaudio.transforms.Resample( + orig_freq=self.audio_config["sample_rate"], + new_freq=self.speaker_manager.encoder.audio_config["sample_rate"], + ) + # pylint: disable=W0101,W0105 + self.audio_transform = torchaudio.transforms.Resample( + orig_freq=self.config.audio.sample_rate, + new_freq=self.speaker_manager.encoder.audio_config["sample_rate"], + ) + + def _init_speaker_embedding(self): + # pylint: disable=attribute-defined-outside-init + if self.num_speakers > 0: + print(" > initialization of speaker-embedding layers.") + self.embedded_speaker_dim = self.args.speaker_embedding_channels + self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) + + def _init_d_vector(self): + # pylint: disable=attribute-defined-outside-init + if hasattr(self, "emb_g"): + raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.") + self.embedded_speaker_dim = self.args.d_vector_dim + + def init_multilingual(self, config: Coqpit): + """Initialize multilingual modules of a model. + + Args: + config (Coqpit): Model configuration. + """ + if self.args.language_ids_file is not None: + self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file) + + if self.args.use_language_embedding and self.language_manager: + print(" > initialization of language-embedding layers.") + self.num_languages = self.language_manager.num_languages + self.embedded_language_dim = self.args.embedded_language_dim + self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim) + torch.nn.init.xavier_uniform_(self.emb_l.weight) + else: + self.embedded_language_dim = 0 + + def init_upsampling(self): + """ + Initialize upsampling modules of a model. 
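+
+        Example:
+            With the assumed values `config.audio.sample_rate = 22050` and
+            `args.encoder_sample_rate = 11025`, the factor computed below is
+            `22050 / 11025 = 2.0`, i.e. the latent `z` is upsampled 2x before the
+            waveform decoder.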
+ """ + if self.args.encoder_sample_rate: + self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate + self.audio_resampler = torchaudio.transforms.Resample( + orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate + ) # pylint: disable=W0201 + + def on_init_end(self, trainer): # pylint: disable=W0613 + """Reinit layes if needed""" + if self.args.reinit_DP: + before_dict = get_module_weights_sum(self.duration_predictor) + # Applies weights_reset recursively to every submodule of the duration predictor + self.duration_predictor.apply(fn=weights_reset) + after_dict = get_module_weights_sum(self.duration_predictor) + for key, value in after_dict.items(): + if value == before_dict[key]: + raise RuntimeError(" [!] The weights of Duration Predictor was not reinit check it !") + print(" > Duration Predictor was reinit.") + + if self.args.reinit_text_encoder: + before_dict = get_module_weights_sum(self.text_encoder) + # Applies weights_reset recursively to every submodule of the duration predictor + self.text_encoder.apply(fn=weights_reset) + after_dict = get_module_weights_sum(self.text_encoder) + for key, value in after_dict.items(): + if value == before_dict[key]: + raise RuntimeError(" [!] The weights of Text Encoder was not reinit check it !") + print(" > Text Encoder was reinit.") + + def get_aux_input(self, aux_input: Dict): + sid, g, lid = self._set_cond_input(aux_input) + return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid} + + def _freeze_layers(self): + if self.args.freeze_encoder: + for param in self.text_encoder.parameters(): + param.requires_grad = False + + if hasattr(self, "emb_l"): + for param in self.emb_l.parameters(): + param.requires_grad = False + + if self.args.freeze_PE: + for param in self.posterior_encoder.parameters(): + param.requires_grad = False + + if self.args.freeze_DP: + for param in self.duration_predictor.parameters(): + param.requires_grad = False + + if self.args.freeze_flow_decoder: + for param in self.flow.parameters(): + param.requires_grad = False + + if self.args.freeze_waveform_decoder: + for param in self.waveform_decoder.parameters(): + param.requires_grad = False + + @staticmethod + def _set_cond_input(aux_input: Dict): + """Set the speaker conditioning input based on the multi-speaker mode.""" + sid, g, lid = None, None, None + if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None: + sid = aux_input["speaker_ids"] + if sid.ndim == 0: + sid = sid.unsqueeze_(0) + if "d_vectors" in aux_input and aux_input["d_vectors"] is not None: + g = F.normalize(aux_input["d_vectors"]).unsqueeze(-1) + if g.ndim == 2: + g = g.unsqueeze_(0) + + if "language_ids" in aux_input and aux_input["language_ids"] is not None: + lid = aux_input["language_ids"] + if lid.ndim == 0: + lid = lid.unsqueeze_(0) + + return sid, g, lid + + def _set_speaker_input(self, aux_input: Dict): + d_vectors = aux_input.get("d_vectors", None) + speaker_ids = aux_input.get("speaker_ids", None) + + if d_vectors is not None and speaker_ids is not None: + raise ValueError("[!] Cannot use d-vectors and speaker-ids together.") + + if speaker_ids is not None and not hasattr(self, "emb_g"): + raise ValueError("[!] 
Cannot use speaker-ids without enabling speaker embedding.") + + g = speaker_ids if speaker_ids is not None else d_vectors + return g + + def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb): + # find the alignment path + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + with torch.no_grad(): + o_scale = torch.exp(-2 * logs_p) + logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1] + logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)]) + logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p]) + logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp = logp2 + logp3 + logp1 + logp4 + attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t, t'] + + # duration predictor + attn_durations = attn.sum(3) + if self.args.use_sdp: + loss_duration = self.duration_predictor( + x.detach() if self.args.detach_dp_input else x, + x_mask, + attn_durations, + g=g.detach() if self.args.detach_dp_input and g is not None else g, + lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, + ) + loss_duration = loss_duration / torch.sum(x_mask) + else: + attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask + log_durations = self.duration_predictor( + x.detach() if self.args.detach_dp_input else x, + x_mask, + g=g.detach() if self.args.detach_dp_input and g is not None else g, + lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, + ) + loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask) + outputs["loss_duration"] = loss_duration + return outputs, attn + + def upsampling_z(self, z, slice_ids=None, y_lengths=None, y_mask=None): + spec_segment_size = self.spec_segment_size + if self.args.encoder_sample_rate: + # recompute the slices and spec_segment_size if needed + slice_ids = slice_ids * int(self.interpolate_factor) if slice_ids is not None else slice_ids + spec_segment_size = spec_segment_size * int(self.interpolate_factor) + # interpolate z if needed + if self.args.interpolate_z: + z = torch.nn.functional.interpolate(z, scale_factor=[self.interpolate_factor], mode="linear").squeeze(0) + # recompute the mask if needed + if y_lengths is not None and y_mask is not None: + y_mask = ( + sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1) + ) # [B, 1, T_dec_resampled] + + return z, spec_segment_size, slice_ids, y_mask + + def forward( # pylint: disable=dangerous-default-value + self, + x: torch.tensor, + x_lengths: torch.tensor, + y: torch.tensor, + y_lengths: torch.tensor, + waveform: torch.tensor, + aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}, + ) -> Dict: + """Forward pass of the model. + + Args: + x (torch.tensor): Batch of input character sequence IDs. + x_lengths (torch.tensor): Batch of input character sequence lengths. + y (torch.tensor): Batch of input spectrograms. + y_lengths (torch.tensor): Batch of input spectrogram lengths. + waveform (torch.tensor): Batch of ground truth waveforms per sample. + aux_input (dict, optional): Auxiliary inputs for multi-speaker and multi-lingual training. + Defaults to {"d_vectors": None, "speaker_ids": None, "language_ids": None}. + + Returns: + Dict: model outputs keyed by the output name. 
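+
+        Example:
+            A rough smoke-test sketch with random tensors (single speaker, default
+            `VitsConfig` with hop length 256; `model` as in the class-level example,
+            shapes follow the conventions listed below):
+
+            >>> import torch
+            >>> x = torch.randint(0, 100, (1, 20))
+            >>> x_lengths = torch.tensor([20])
+            >>> y = torch.randn(1, 513, 120)
+            >>> y_lengths = torch.tensor([120])
+            >>> waveform = torch.randn(1, 1, 120 * 256)
+            >>> outputs = model.forward(x, x_lengths, y, y_lengths, waveform)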
+ + Shapes: + - x: :math:`[B, T_seq]` + - x_lengths: :math:`[B]` + - y: :math:`[B, C, T_spec]` + - y_lengths: :math:`[B]` + - waveform: :math:`[B, 1, T_wav]` + - d_vectors: :math:`[B, C, 1]` + - speaker_ids: :math:`[B]` + - language_ids: :math:`[B]` + + Return Shapes: + - model_outputs: :math:`[B, 1, T_wav]` + - alignments: :math:`[B, T_seq, T_dec]` + - z: :math:`[B, C, T_dec]` + - z_p: :math:`[B, C, T_dec]` + - m_p: :math:`[B, C, T_dec]` + - logs_p: :math:`[B, C, T_dec]` + - m_q: :math:`[B, C, T_dec]` + - logs_q: :math:`[B, C, T_dec]` + - waveform_seg: :math:`[B, 1, spec_seg_size * hop_length]` + - gt_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]` + - syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]` + """ + outputs = {} + sid, g, lid = self._set_cond_input(aux_input) + # speaker embedding + if self.args.use_speaker_embedding and sid is not None: + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + + # language embedding + lang_emb = None + if self.args.use_language_embedding and lid is not None: + lang_emb = self.emb_l(lid).unsqueeze(-1) + + x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb) + + # posterior encoder + z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g) + + # flow layers + z_p = self.flow(z, y_mask, g=g) + + # duration predictor + outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb) + + # expand prior + m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p]) + logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p]) + + # select a random feature segment for the waveform decoder + z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True) + + # interpolate z if needed + z_slice, spec_segment_size, slice_ids, _ = self.upsampling_z(z_slice, slice_ids=slice_ids) + + o = self.waveform_decoder(z_slice, g=g) + + wav_seg = segment( + waveform, + slice_ids * self.config.audio.hop_length, + spec_segment_size * self.config.audio.hop_length, + pad_short=True, + ) + + if self.args.use_speaker_encoder_as_loss and self.speaker_manager.encoder is not None: + # concate generated and GT waveforms + wavs_batch = torch.cat((wav_seg, o), dim=0) + + # resample audio to speaker encoder sample_rate + # pylint: disable=W0105 + if self.audio_transform is not None: + wavs_batch = self.audio_transform(wavs_batch) + + pred_embs = self.speaker_manager.encoder.forward(wavs_batch, l2_norm=True) + + # split generated and GT speaker embeddings + gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0) + else: + gt_spk_emb, syn_spk_emb = None, None + + outputs.update( + { + "model_outputs": o, + "alignments": attn.squeeze(1), + "m_p": m_p, + "logs_p": logs_p, + "z": z, + "z_p": z_p, + "m_q": m_q, + "logs_q": logs_q, + "waveform_seg": wav_seg, + "gt_spk_emb": gt_spk_emb, + "syn_spk_emb": syn_spk_emb, + "slice_ids": slice_ids, + } + ) + return outputs + + @staticmethod + def _set_x_lengths(x, aux_input): + if "x_lengths" in aux_input and aux_input["x_lengths"] is not None: + return aux_input["x_lengths"] + return torch.tensor(x.shape[1:2]).to(x.device) + + @torch.no_grad() + def inference( + self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None} + ): # pylint: disable=dangerous-default-value + """ + Note: + To run in batch mode, provide `x_lengths` else model assumes that the batch size is 1. 
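+
+        Example:
+            A minimal sketch (batch of one, default settings, `model` as in the
+            class-level example):
+
+            >>> import torch
+            >>> x = torch.randint(0, 100, (1, 20))
+            >>> outputs = model.inference(x)  # "model_outputs" holds the waveform, shaped [1, 1, T_wav]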
+ + Shapes: + - x: :math:`[B, T_seq]` + - x_lengths: :math:`[B]` + - d_vectors: :math:`[B, C]` + - speaker_ids: :math:`[B]` + + Return Shapes: + - model_outputs: :math:`[B, 1, T_wav]` + - alignments: :math:`[B, T_seq, T_dec]` + - z: :math:`[B, C, T_dec]` + - z_p: :math:`[B, C, T_dec]` + - m_p: :math:`[B, C, T_dec]` + - logs_p: :math:`[B, C, T_dec]` + """ + sid, g, lid = self._set_cond_input(aux_input) + x_lengths = self._set_x_lengths(x, aux_input) + + # speaker embedding + if self.args.use_speaker_embedding and sid is not None: + g = self.emb_g(sid).unsqueeze(-1) + + # language embedding + lang_emb = None + if self.args.use_language_embedding and lid is not None: + lang_emb = self.emb_l(lid).unsqueeze(-1) + + x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb) + + if self.args.use_sdp: + logw = self.duration_predictor( + x, + x_mask, + g=g if self.args.condition_dp_on_speaker else None, + reverse=True, + noise_scale=self.inference_noise_scale_dp, + lang_emb=lang_emb, + ) + else: + logw = self.duration_predictor( + x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb + ) + + w = torch.exp(logw) * x_mask * self.length_scale + w_ceil = torch.ceil(w) + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype).unsqueeze(1) # [B, 1, T_dec] + + attn_mask = x_mask * y_mask.transpose(1, 2) # [B, 1, T_enc] * [B, T_dec, 1] + attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2)) + + m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2) + logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2) + + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale + z = self.flow(z_p, y_mask, g=g, reverse=True) + + # upsampling if needed + z, _, _, y_mask = self.upsampling_z(z, y_lengths=y_lengths, y_mask=y_mask) + + o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g) + + outputs = { + "model_outputs": o, + "alignments": attn.squeeze(1), + "durations": w_ceil, + "z": z, + "z_p": z_p, + "m_p": m_p, + "logs_p": logs_p, + "y_mask": y_mask, + } + return outputs + + @torch.no_grad() + def inference_voice_conversion( + self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None + ): + """Inference for voice conversion + + Args: + reference_wav (Tensor): Reference wavform. Tensor of shape [B, T] + speaker_id (Tensor): speaker_id of the target speaker. Tensor of shape [B] + d_vector (Tensor): d_vector embedding of target speaker. Tensor of shape `[B, C]` + reference_speaker_id (Tensor): speaker_id of the reference_wav speaker. Tensor of shape [B] + reference_d_vector (Tensor): d_vector embedding of the reference_wav speaker. 
Tensor of shape `[B, C]` + """ + # compute spectrograms + y = wav_to_spec( + reference_wav, + self.config.audio.fft_size, + self.config.audio.hop_length, + self.config.audio.win_length, + center=False, + ) + y_lengths = torch.tensor([y.size(-1)]).to(y.device) + speaker_cond_src = reference_speaker_id if reference_speaker_id is not None else reference_d_vector + speaker_cond_tgt = speaker_id if speaker_id is not None else d_vector + # print(y.shape, y_lengths.shape) + wav, _, _ = self.voice_conversion(y, y_lengths, speaker_cond_src, speaker_cond_tgt) + return wav + + def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt): + """Forward pass for voice conversion + + TODO: create an end-point for voice conversion + + Args: + y (Tensor): Reference spectrograms. Tensor of shape [B, T, C] + y_lengths (Tensor): Length of each reference spectrogram. Tensor of shape [B] + speaker_cond_src (Tensor): Reference speaker ID. Tensor of shape [B,] + speaker_cond_tgt (Tensor): Target speaker ID. Tensor of shape [B,] + """ + assert self.num_speakers > 0, "num_speakers have to be larger than 0." + # speaker embedding + if self.args.use_speaker_embedding and not self.args.use_d_vector_file: + g_src = self.emb_g(speaker_cond_src).unsqueeze(-1) + g_tgt = self.emb_g(speaker_cond_tgt).unsqueeze(-1) + elif not self.args.use_speaker_embedding and self.args.use_d_vector_file: + g_src = F.normalize(speaker_cond_src).unsqueeze(-1) + g_tgt = F.normalize(speaker_cond_tgt).unsqueeze(-1) + else: + raise RuntimeError(" [!] Voice conversion is only supported on multi-speaker models.") + + z, _, _, y_mask = self.posterior_encoder(y, y_lengths, g=g_src) + z_p = self.flow(z, y_mask, g=g_src) + z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) + o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt) + return o_hat, y_mask, (z, z_p, z_hat) + + def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]: + """Perform a single training step. Run the model forward pass and compute losses. + + Args: + batch (Dict): Input tensors. + criterion (nn.Module): Loss layer designed for the model. + optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks. + + Returns: + Tuple[Dict, Dict]: Model ouputs and computed losses. 
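+
+        Example:
+            Schematic usage; `batch` and `criterion` are assumed to be prepared by the `Trainer` through
+            `format_batch()`/`format_batch_on_device()` and `get_criterion()`. The `optimizer_idx=1` pass
+            reuses tensors cached during the `optimizer_idx=0` pass, so the indices must be run in order.
+
+            >>> outputs0, losses0 = model.train_step(batch, criterion, optimizer_idx=0)
+            >>> outputs1, losses1 = model.train_step(batch, criterion, optimizer_idx=1)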
+ """ + + self._freeze_layers() + + spec_lens = batch["spec_lens"] + + if optimizer_idx == 0: + tokens = batch["tokens"] + token_lenghts = batch["token_lens"] + spec = batch["spec"] + + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + language_ids = batch["language_ids"] + waveform = batch["waveform"] + + # generator pass + outputs = self.forward( + tokens, + token_lenghts, + spec, + spec_lens, + waveform, + aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids}, + ) + + # cache tensors for the generator pass + self.model_outputs_cache = outputs # pylint: disable=attribute-defined-outside-init + + # compute scores and features + scores_disc_fake, _, scores_disc_real, _ = self.disc( + outputs["model_outputs"].detach(), outputs["waveform_seg"] + ) + + # compute loss + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion[optimizer_idx]( + scores_disc_real, + scores_disc_fake, + ) + return outputs, loss_dict + + if optimizer_idx == 1: + mel = batch["mel"] + + # compute melspec segment + with autocast(enabled=False): + + if self.args.encoder_sample_rate: + spec_segment_size = self.spec_segment_size * int(self.interpolate_factor) + else: + spec_segment_size = self.spec_segment_size + + mel_slice = segment( + mel.float(), self.model_outputs_cache["slice_ids"], spec_segment_size, pad_short=True + ) + mel_slice_hat = wav_to_mel( + y=self.model_outputs_cache["model_outputs"].float(), + n_fft=self.config.audio.fft_size, + sample_rate=self.config.audio.sample_rate, + num_mels=self.config.audio.num_mels, + hop_length=self.config.audio.hop_length, + win_length=self.config.audio.win_length, + fmin=self.config.audio.mel_fmin, + fmax=self.config.audio.mel_fmax, + center=False, + ) + + # compute discriminator scores and features + scores_disc_fake, feats_disc_fake, _, feats_disc_real = self.disc( + self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"] + ) + + # compute losses + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion[optimizer_idx]( + mel_slice_hat=mel_slice.float(), + mel_slice=mel_slice_hat.float(), + z_p=self.model_outputs_cache["z_p"].float(), + logs_q=self.model_outputs_cache["logs_q"].float(), + m_p=self.model_outputs_cache["m_p"].float(), + logs_p=self.model_outputs_cache["logs_p"].float(), + z_len=spec_lens, + scores_disc_fake=scores_disc_fake, + feats_disc_fake=feats_disc_fake, + feats_disc_real=feats_disc_real, + loss_duration=self.model_outputs_cache["loss_duration"], + use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss, + gt_spk_emb=self.model_outputs_cache["gt_spk_emb"], + syn_spk_emb=self.model_outputs_cache["syn_spk_emb"], + ) + + return self.model_outputs_cache, loss_dict + + raise ValueError(" [!] 
Unexpected `optimizer_idx`.") + + def _log(self, ap, batch, outputs, name_prefix="train"): # pylint: disable=unused-argument,no-self-use + y_hat = outputs[1]["model_outputs"] + y = outputs[1]["waveform_seg"] + figures = plot_results(y_hat, y, ap, name_prefix) + sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() + audios = {f"{name_prefix}/audio": sample_voice} + + alignments = outputs[1]["alignments"] + align_img = alignments[0].data.cpu().numpy().T + + figures.update( + { + "alignment": plot_alignment(align_img, output_fig=False), + } + ) + return figures, audios + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ): # pylint: disable=no-self-use + """Create visualizations and waveform examples. + + For example, here you can plot spectrograms and generate sample sample waveforms from these spectrograms to + be projected onto Tensorboard. + + Args: + ap (AudioProcessor): audio processor used at training. + batch (Dict): Model inputs used at the previous training step. + outputs (Dict): Model outputs generated at the previoud training step. + + Returns: + Tuple[Dict, np.ndarray]: training plots and output waveform. + """ + figures, audios = self._log(self.ap, batch, outputs, "train") + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + @torch.no_grad() + def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): + return self.train_step(batch, criterion, optimizer_idx) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._log(self.ap, batch, outputs, "eval") + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def get_aux_input_from_test_sentences(self, sentence_info): + if hasattr(self.config, "model_args"): + config = self.config.model_args + else: + config = self.config + + # extract speaker and language info + text, speaker_name, style_wav, language_name = None, None, None, None + + if isinstance(sentence_info, list): + if len(sentence_info) == 1: + text = sentence_info[0] + elif len(sentence_info) == 2: + text, speaker_name = sentence_info + elif len(sentence_info) == 3: + text, speaker_name, style_wav = sentence_info + elif len(sentence_info) == 4: + text, speaker_name, style_wav, language_name = sentence_info + else: + text = sentence_info + + # get speaker id/d_vector + speaker_id, d_vector, language_id = None, None, None + if hasattr(self, "speaker_manager"): + if config.use_d_vector_file: + if speaker_name is None: + d_vector = self.speaker_manager.get_random_embeddings() + else: + d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False) + elif config.use_speaker_embedding: + if speaker_name is None: + speaker_id = self.speaker_manager.get_random_id() + else: + speaker_id = self.speaker_manager.ids[speaker_name] + + # get language id + if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None: + language_id = self.language_manager.ids[language_name] + + return { + "text": text, + "speaker_id": speaker_id, + "style_wav": style_wav, + "d_vector": d_vector, + "language_id": language_id, + "language_name": language_name, + } + + @torch.no_grad() + def test_run(self, assets) -> Tuple[Dict, Dict]: + """Generic test run for `tts` models used by `Trainer`. + + You can override this for a different behaviour. 
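+
+        Example:
+            Illustrative sketch; assumes `self.config.test_sentences` is set, e.g.
+            `[["This is a test.", "speakerA", None, "en"]]` (text, speaker name, style wav, language name).
+
+            >>> outputs = model.test_run(assets={})
+            >>> figures, audios = outputs["figures"], outputs["audios"]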
+ + Returns: + Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. + """ + print(" | > Synthesizing test sentences.") + test_audios = {} + test_figures = {} + test_sentences = self.config.test_sentences + for idx, s_info in enumerate(test_sentences): + aux_inputs = self.get_aux_input_from_test_sentences(s_info) + wav, alignment, _, _ = synthesis( + self, + aux_inputs["text"], + self.config, + "cuda" in str(next(self.parameters()).device), + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + style_wav=aux_inputs["style_wav"], + language_id=aux_inputs["language_id"], + use_griffin_lim=True, + do_trim_silence=False, + ).values() + test_audios["{}-audio".format(idx)] = wav + test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False) + return {"figures": test_figures, "audios": test_audios} + + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: + logger.test_audios(steps, outputs["audios"], self.ap.sample_rate) + logger.test_figures(steps, outputs["figures"]) + + def format_batch(self, batch: Dict) -> Dict: + """Compute speaker, langugage IDs and d_vector for the batch if necessary.""" + speaker_ids = None + language_ids = None + d_vectors = None + + # get numerical speaker ids from speaker names + if self.speaker_manager is not None and self.speaker_manager.ids and self.args.use_speaker_embedding: + speaker_ids = [self.speaker_manager.ids[sn] for sn in batch["speaker_names"]] + + if speaker_ids is not None: + speaker_ids = torch.LongTensor(speaker_ids) + batch["speaker_ids"] = speaker_ids + + # get d_vectors from audio file names + if self.speaker_manager is not None and self.speaker_manager.embeddings and self.args.use_d_vector_file: + d_vector_mapping = self.speaker_manager.embeddings + d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_files"]] + d_vectors = torch.FloatTensor(d_vectors) + + # get language ids from language names + if self.language_manager is not None and self.language_manager.ids and self.args.use_language_embedding: + language_ids = [self.language_manager.ids[ln] for ln in batch["language_names"]] + + if language_ids is not None: + language_ids = torch.LongTensor(language_ids) + + batch["language_ids"] = language_ids + batch["d_vectors"] = d_vectors + batch["speaker_ids"] = speaker_ids + return batch + + def format_batch_on_device(self, batch): + """Compute spectrograms on the device.""" + ac = self.config.audio + + if self.args.encoder_sample_rate: + wav = self.audio_resampler(batch["waveform"]) + else: + wav = batch["waveform"] + + # compute spectrograms + batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False) + + if self.args.encoder_sample_rate: + # recompute spec with high sampling rate to the loss + spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False) + # remove extra stft frames if needed + if spec_mel.size(2) > int(batch["spec"].size(2) * self.interpolate_factor): + spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)] + else: + batch["spec"] = batch["spec"][:, :, : int(spec_mel.size(2) / self.interpolate_factor)] + else: + spec_mel = batch["spec"] + + batch["mel"] = spec_to_mel( + spec=spec_mel, + n_fft=ac.fft_size, + num_mels=ac.num_mels, + sample_rate=ac.sample_rate, + fmin=ac.mel_fmin, + fmax=ac.mel_fmax, + ) + + if self.args.encoder_sample_rate: + assert batch["spec"].shape[2] == int( + 
batch["mel"].shape[2] / self.interpolate_factor + ), f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}" + else: + assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}" + + # compute spectrogram frame lengths + batch["spec_lens"] = (batch["spec"].shape[2] * batch["waveform_rel_lens"]).int() + batch["mel_lens"] = (batch["mel"].shape[2] * batch["waveform_rel_lens"]).int() + + if self.args.encoder_sample_rate: + assert (batch["spec_lens"] - (batch["mel_lens"] / self.interpolate_factor).int()).sum() == 0 + else: + assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0 + + # zero the padding frames + batch["spec"] = batch["spec"] * sequence_mask(batch["spec_lens"]).unsqueeze(1) + batch["mel"] = batch["mel"] * sequence_mask(batch["mel_lens"]).unsqueeze(1) + return batch + + def get_data_loader( + self, + config: Coqpit, + assets: Dict, + is_eval: bool, + samples: Union[List[Dict], List[List]], + verbose: bool, + num_gpus: int, + rank: int = None, + ) -> "DataLoader": + if is_eval and not config.run_eval: + loader = None + else: + # init dataloader + dataset = VitsDataset( + model_args=self.args, + samples=samples, + batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, + min_text_len=config.min_text_len, + max_text_len=config.max_text_len, + min_audio_len=config.min_audio_len, + max_audio_len=config.max_audio_len, + phoneme_cache_path=config.phoneme_cache_path, + precompute_num_workers=config.precompute_num_workers, + verbose=verbose, + tokenizer=self.tokenizer, + start_by_longest=config.start_by_longest, + ) + + # wait all the DDP process to be ready + if num_gpus > 1: + dist.barrier() + + # sort input sequences from short to long + dataset.preprocess_samples() + + # get samplers + sampler = self.get_sampler(config, dataset, num_gpus) + + loader = DataLoader( + dataset, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + shuffle=False, # shuffle is done in the dataset. + drop_last=False, # setting this False might cause issues in AMP training. + sampler=sampler, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + return loader + + def get_optimizer(self) -> List: + """Initiate and return the GAN optimizers based on the config parameters. + It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator. + Returns: + List: optimizers. + """ + # select generator parameters + optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc) + + gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc.")) + optimizer1 = get_optimizer( + self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters + ) + return [optimizer0, optimizer1] + + def get_lr(self) -> List: + """Set the initial learning rates for each optimizer. + + Returns: + List: learning rates for each optimizer. + """ + return [self.config.lr_disc, self.config.lr_gen] + + def get_scheduler(self, optimizer) -> List: + """Set the schedulers for each optimizer. + + Args: + optimizer (List[`torch.optim.Optimizer`]): List of optimizers. + + Returns: + List: Schedulers, one for each optimizer. 
+ """ + scheduler_G = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0]) + scheduler_D = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1]) + return [scheduler_D, scheduler_G] + + def get_criterion(self): + """Get criterions for each optimizer. The index in the output list matches the optimizer idx used in + `train_step()`""" + from TTS.tts.layers.losses import ( # pylint: disable=import-outside-toplevel + VitsDiscriminatorLoss, + VitsGeneratorLoss, + ) + + return [VitsDiscriminatorLoss(self.config), VitsGeneratorLoss(self.config)] + + def load_checkpoint( + self, + config, + checkpoint_path, + eval=False, + strict=True, + ): # pylint: disable=unused-argument, redefined-builtin + """Load the model checkpoint and setup for training or inference""" + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + # compat band-aid for the pre-trained models to not use the encoder baked into the model + # TODO: consider baking the speaker encoder into the model and call it from there. + # as it is probably easier for model distribution. + state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k} + + if self.args.encoder_sample_rate is not None and eval: + # audio resampler is not used in inference time + self.audio_resampler = None + + # handle fine-tuning from a checkpoint with additional speakers + if hasattr(self, "emb_g") and state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape: + num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0] + print(f" > Loading checkpoint with {num_new_speakers} additional speakers.") + emb_g = state["model"]["emb_g.weight"] + new_row = torch.randn(num_new_speakers, emb_g.shape[1]) + emb_g = torch.cat([emb_g, new_row], axis=0) + state["model"]["emb_g.weight"] = emb_g + # load the model weights + self.load_state_dict(state["model"], strict=strict) + + if eval: + self.eval() + assert not self.training + + @staticmethod + def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + """Initiate model from config + + Args: + config (VitsConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio import AudioProcessor + + upsample_rate = torch.prod(torch.as_tensor(config.model_args.upsample_rates_decoder)).item() + + if not config.model_args.encoder_sample_rate: + assert ( + upsample_rate == config.audio.hop_length + ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}" + else: + encoder_to_vocoder_upsampling_factor = config.audio.sample_rate / config.model_args.encoder_sample_rate + effective_hop_length = config.audio.hop_length * encoder_to_vocoder_upsampling_factor + assert ( + upsample_rate == effective_hop_length + ), f" [!] 
Product of upsample rates must be equal to the hop length - {upsample_rate} vs {effective_hop_length}" + + ap = AudioProcessor.init_from_config(config, verbose=verbose) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + language_manager = LanguageManager.init_from_config(config) + + if config.model_args.speaker_encoder_model_path: + speaker_manager.init_encoder( + config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path + ) + return Vits(new_config, ap, tokenizer, speaker_manager, language_manager) + + +################################## +# VITS CHARACTERS +################################## + + +class VitsCharacters(BaseCharacters): + """Characters class for VITs model for compatibility with pre-trained models""" + + def __init__( + self, + graphemes: str = _characters, + punctuations: str = _punctuations, + pad: str = _pad, + ipa_characters: str = _phonemes, + ) -> None: + if ipa_characters is not None: + graphemes += ipa_characters + super().__init__(graphemes, punctuations, pad, None, None, "", is_unique=False, is_sorted=True) + + def _create_vocab(self): + self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank] + self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)} + # pylint: disable=unnecessary-comprehension + self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)} + + @staticmethod + def init_from_config(config: Coqpit): + if config.characters is not None: + _pad = config.characters["pad"] + _punctuations = config.characters["punctuations"] + _letters = config.characters["characters"] + _letters_ipa = config.characters["phonemes"] + return ( + VitsCharacters(graphemes=_letters, ipa_characters=_letters_ipa, punctuations=_punctuations, pad=_pad), + config, + ) + characters = VitsCharacters() + new_config = replace(config, characters=characters.to_config()) + return characters, new_config + + def to_config(self) -> "CharactersConfig": + return CharactersConfig( + characters=self._characters, + punctuations=self._punctuations, + pad=self._pad, + eos=None, + bos=None, + blank=self._blank, + is_unique=False, + is_sorted=True, + ) diff --git a/TTS/tts/utils/__init__.py b/TTS/tts/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/utils/data.py b/TTS/tts/utils/data.py new file mode 100644 index 0000000000000000000000000000000000000000..22e46b683adfc7f6c7c8a57fb5b697e422cd915c --- /dev/null +++ b/TTS/tts/utils/data.py @@ -0,0 +1,79 @@ +import bisect + +import numpy as np +import torch + + +def _pad_data(x, length): + _pad = 0 + assert x.ndim == 1 + return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=_pad) + + +def prepare_data(inputs): + max_len = max((len(x) for x in inputs)) + return np.stack([_pad_data(x, max_len) for x in inputs]) + + +def _pad_tensor(x, length): + _pad = 0.0 + assert x.ndim == 2 + x = np.pad(x, [[0, 0], [0, length - x.shape[1]]], mode="constant", constant_values=_pad) + return x + + +def prepare_tensor(inputs, out_steps): + max_len = max((x.shape[1] for x in inputs)) + remainder = max_len % out_steps + pad_len = max_len + (out_steps - remainder) if remainder > 0 else max_len + return np.stack([_pad_tensor(x, pad_len) for x in inputs]) + + +def _pad_stop_target(x: np.ndarray, length: int, pad_val=1) -> np.ndarray: + """Pad stop target array. 
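+
+    Example:
+        Illustrative call; positions past the true length are padded with `1` (the stop value).
+
+        >>> import numpy as np
+        >>> _pad_stop_target(np.array([0, 0, 1]), 5)
+        array([0, 0, 1, 1, 1])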
+ + Args: + x (np.ndarray): Stop target array. + length (int): Length after padding. + pad_val (int, optional): Padding value. Defaults to 1. + + Returns: + np.ndarray: Padded stop target array. + """ + assert x.ndim == 1 + return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=pad_val) + + +def prepare_stop_target(inputs, out_steps): + """Pad row vectors with 1.""" + max_len = max((x.shape[0] for x in inputs)) + remainder = max_len % out_steps + pad_len = max_len + (out_steps - remainder) if remainder > 0 else max_len + return np.stack([_pad_stop_target(x, pad_len) for x in inputs]) + + +def pad_per_step(inputs, pad_len): + return np.pad(inputs, [[0, 0], [0, 0], [0, pad_len]], mode="constant", constant_values=0.0) + + +def get_length_balancer_weights(items: list, num_buckets=10): + # get all durations + audio_lengths = np.array([item["audio_length"] for item in items]) + # create the $num_buckets buckets classes based in the dataset max and min length + max_length = int(max(audio_lengths)) + min_length = int(min(audio_lengths)) + step = int((max_length - min_length) / num_buckets) + 1 + buckets_classes = [i + step for i in range(min_length, (max_length - step) + num_buckets + 1, step)] + # add each sample in their respective length bucket + buckets_names = np.array( + [buckets_classes[bisect.bisect_left(buckets_classes, item["audio_length"])] for item in items] + ) + # count and compute the weights_bucket for each sample + unique_buckets_names = np.unique(buckets_names).tolist() + bucket_ids = [unique_buckets_names.index(l) for l in buckets_names] + bucket_count = np.array([len(np.where(buckets_names == l)[0]) for l in unique_buckets_names]) + weight_bucket = 1.0 / bucket_count + dataset_samples_weight = np.array([weight_bucket[l] for l in bucket_ids]) + # normalize + dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) + return torch.from_numpy(dataset_samples_weight).float() diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..c2e7f56146aa95b52824307189573bb32a17eaac --- /dev/null +++ b/TTS/tts/utils/helpers.py @@ -0,0 +1,238 @@ +import numpy as np +import torch +from torch.nn import functional as F + +try: + from TTS.tts.utils.monotonic_align.core import maximum_path_c + + CYTHON = True +except ModuleNotFoundError: + CYTHON = False + + +class StandardScaler: + """StandardScaler for mean-scale normalization with the given mean and scale values.""" + + def __init__(self, mean: np.ndarray = None, scale: np.ndarray = None) -> None: + self.mean_ = mean + self.scale_ = scale + + def set_stats(self, mean, scale): + self.mean_ = mean + self.scale_ = scale + + def reset_stats(self): + delattr(self, "mean_") + delattr(self, "scale_") + + def transform(self, X): + X = np.asarray(X) + X -= self.mean_ + X /= self.scale_ + return X + + def inverse_transform(self, X): + X = np.asarray(X) + X *= self.scale_ + X += self.mean_ + return X + + +# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1 +def sequence_mask(sequence_length, max_len=None): + """Create a sequence mask for filtering padding in a sequence tensor. + + Args: + sequence_length (torch.tensor): Sequence lengths. + max_len (int, Optional): Maximum sequence length. Defaults to None. 
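+
+    Example:
+        Illustrative call; each row masks the valid positions for the corresponding length.
+
+        >>> import torch
+        >>> sequence_mask(torch.tensor([1, 3]), max_len=4)  # [[True, False, False, False], [True, True, True, False]]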
+ + Shapes: + - mask: :math:`[B, T_max]` + """ + if max_len is None: + max_len = sequence_length.data.max() + seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device) + # B x T_max + mask = seq_range.unsqueeze(0) < sequence_length.unsqueeze(1) + return mask + + +def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_short=False): + """Segment each sample in a batch based on the provided segment indices + + Args: + x (torch.tensor): Input tensor. + segment_indices (torch.tensor): Segment indices. + segment_size (int): Expected output segment size. + pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size. + """ + # pad the input tensor if it is shorter than the segment size + if pad_short and x.shape[-1] < segment_size: + x = torch.nn.functional.pad(x, (0, segment_size - x.size(2))) + + segments = torch.zeros_like(x[:, :, :segment_size]) + + for i in range(x.size(0)): + index_start = segment_indices[i] + index_end = index_start + segment_size + x_i = x[i] + if pad_short and index_end > x.size(2): + # pad the sample if it is shorter than the segment size + x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2))) + segments[i] = x_i[:, index_start:index_end] + return segments + + +def rand_segments( + x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False +): + """Create random segments based on the input lengths. + + Args: + x (torch.tensor): Input tensor. + x_lengths (torch.tensor): Input lengths. + segment_size (int): Expected output segment size. + let_short_samples (bool): Allow shorter samples than the segment size. + pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size. + + Shapes: + - x: :math:`[B, C, T]` + - x_lengths: :math:`[B]` + """ + _x_lenghts = x_lengths.clone() + B, _, T = x.size() + if pad_short: + if T < segment_size: + x = torch.nn.functional.pad(x, (0, segment_size - T)) + T = segment_size + if _x_lenghts is None: + _x_lenghts = T + len_diff = _x_lenghts - segment_size + 1 + if let_short_samples: + _x_lenghts[len_diff < 0] = segment_size + len_diff = _x_lenghts - segment_size + 1 + else: + assert all( + len_diff > 0 + ), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}" + segment_indices = (torch.rand([B]).type_as(x) * len_diff).long() + ret = segment(x, segment_indices, segment_size) + return ret, segment_indices + + +def average_over_durations(values, durs): + """Average values over durations. 
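+
+    Example:
+        Illustrative call; with durations `[2, 3]`, the first two and the last three frames of `values` are
+        averaged into one value per input token.
+
+        >>> import torch
+        >>> values = torch.tensor([[[1.0, 1.0, 2.0, 2.0, 2.0]]])
+        >>> durs = torch.tensor([[2, 3]])
+        >>> average_over_durations(values, durs)  # tensor([[[1., 2.]]])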
+ + Shapes: + - values: :math:`[B, 1, T_de]` + - durs: :math:`[B, T_en]` + - avg: :math:`[B, 1, T_en]` + """ + durs_cums_ends = torch.cumsum(durs, dim=1).long() + durs_cums_starts = torch.nn.functional.pad(durs_cums_ends[:, :-1], (1, 0)) + values_nonzero_cums = torch.nn.functional.pad(torch.cumsum(values != 0.0, dim=2), (1, 0)) + values_cums = torch.nn.functional.pad(torch.cumsum(values, dim=2), (1, 0)) + + bs, l = durs_cums_ends.size() + n_formants = values.size(1) + dcs = durs_cums_starts[:, None, :].expand(bs, n_formants, l) + dce = durs_cums_ends[:, None, :].expand(bs, n_formants, l) + + values_sums = (torch.gather(values_cums, 2, dce) - torch.gather(values_cums, 2, dcs)).float() + values_nelems = (torch.gather(values_nonzero_cums, 2, dce) - torch.gather(values_nonzero_cums, 2, dcs)).float() + + avg = torch.where(values_nelems == 0.0, values_nelems, values_sums / values_nelems) + return avg + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def generate_path(duration, mask): + """ + Shapes: + - duration: :math:`[B, T_en]` + - mask: :math:'[B, T_en, T_de]` + - path: :math:`[B, T_en, T_de]` + """ + device = duration.device + b, t_x, t_y = mask.shape + cum_duration = torch.cumsum(duration, 1) + path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path * mask + return path + + +def maximum_path(value, mask): + if CYTHON: + return maximum_path_cython(value, mask) + return maximum_path_numpy(value, mask) + + +def maximum_path_cython(value, mask): + """Cython optimised version. + Shapes: + - value: :math:`[B, T_en, T_de]` + - mask: :math:`[B, T_en, T_de]` + """ + value = value * mask + device = value.device + dtype = value.dtype + value = value.data.cpu().numpy().astype(np.float32) + path = np.zeros_like(value).astype(np.int32) + mask = mask.data.cpu().numpy() + + t_x_max = mask.sum(1)[:, 0].astype(np.int32) + t_y_max = mask.sum(2)[:, 0].astype(np.int32) + maximum_path_c(path, value, t_x_max, t_y_max) + return torch.from_numpy(path).to(device=device, dtype=dtype) + + +def maximum_path_numpy(value, mask, max_neg_val=None): + """ + Monotonic alignment search algorithm + Numpy-friendly version. It's about 4 times faster than torch version. 
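+    Usually reached through `maximum_path()`, which falls back to this implementation only when the Cython
+    extension is unavailable. Illustrative call (hypothetical inputs; `value` holds alignment scores, `mask`
+    marks the valid region):
+
+        >>> import torch
+        >>> path = maximum_path(torch.rand(1, 4, 6), torch.ones(1, 4, 6))  # binary monotonic path, [1, 4, 6]
+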
+ value: [b, t_x, t_y] + mask: [b, t_x, t_y] + """ + if max_neg_val is None: + max_neg_val = -np.inf # Patch for Sphinx complaint + value = value * mask + + device = value.device + dtype = value.dtype + value = value.cpu().detach().numpy() + mask = mask.cpu().detach().numpy().astype(np.bool) + + b, t_x, t_y = value.shape + direction = np.zeros(value.shape, dtype=np.int64) + v = np.zeros((b, t_x), dtype=np.float32) + x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1) + for j in range(t_y): + v0 = np.pad(v, [[0, 0], [1, 0]], mode="constant", constant_values=max_neg_val)[:, :-1] + v1 = v + max_mask = v1 >= v0 + v_max = np.where(max_mask, v1, v0) + direction[:, :, j] = max_mask + + index_mask = x_range <= j + v = np.where(index_mask, v_max + value[:, :, j], max_neg_val) + direction = np.where(mask, direction, 1) + + path = np.zeros(value.shape, dtype=np.float32) + index = mask[:, :, 0].sum(1).astype(np.int64) - 1 + index_range = np.arange(b) + for j in reversed(range(t_y)): + path[index_range, index, j] = 1 + index = index + direction[index_range, index, j] - 1 + path = path * mask.astype(np.float32) + path = torch.from_numpy(path).to(device=device, dtype=dtype) + return path diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py new file mode 100644 index 0000000000000000000000000000000000000000..9b5e2007a1b2ee012028f0c094640ec4a114b6b0 --- /dev/null +++ b/TTS/tts/utils/languages.py @@ -0,0 +1,125 @@ +import os +from typing import Any, Dict, List + +import fsspec +import numpy as np +import torch +from coqpit import Coqpit + +from TTS.config import check_config_and_model_args +from TTS.tts.utils.managers import BaseIDManager + + +class LanguageManager(BaseIDManager): + """Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information + in a way that can be queried by language. + + Args: + language_ids_file_path (str, optional): Path to the metafile that maps language names to ids used by + TTS models. Defaults to "". + config (Coqpit, optional): Coqpit config that contains the language information in the datasets filed. + Defaults to None. + + Examples: + >>> manager = LanguageManager(language_ids_file_path=language_ids_file_path) + >>> language_id_mapper = manager.language_ids + """ + + def __init__( + self, + language_ids_file_path: str = "", + config: Coqpit = None, + ): + super().__init__(id_file_path=language_ids_file_path) + + if config: + self.set_language_ids_from_config(config) + + @property + def num_languages(self) -> int: + return len(list(self.ids.keys())) + + @property + def language_names(self) -> List: + return list(self.ids.keys()) + + @staticmethod + def parse_language_ids_from_config(c: Coqpit) -> Dict: + """Set language id from config. + + Args: + c (Coqpit): Config + + Returns: + Tuple[Dict, int]: Language ID mapping and the number of languages. + """ + languages = set({}) + for dataset in c.datasets: + if "language" in dataset: + languages.add(dataset["language"]) + else: + raise ValueError(f"Dataset {dataset['name']} has no language specified.") + return {name: i for i, name in enumerate(sorted(list(languages)))} + + def set_language_ids_from_config(self, c: Coqpit) -> None: + """Set language IDs from config samples. + + Args: + c (Coqpit): Config. 
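+
+        Example:
+            Illustrative sketch; assumes every entry in `c.datasets` defines a `language` field.
+
+            >>> manager = LanguageManager()
+            >>> manager.set_language_ids_from_config(c)
+            >>> manager.ids  # e.g. {"en": 0, "pt-br": 1}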
+ """ + self.ids = self.parse_language_ids_from_config(c) + + @staticmethod + def parse_ids_from_data(items: List, parse_key: str) -> Any: + raise NotImplementedError + + def set_ids_from_data(self, items: List, parse_key: str) -> Any: + raise NotImplementedError + + def save_ids_to_file(self, file_path: str) -> None: + """Save language IDs to a json file. + + Args: + file_path (str): Path to the output file. + """ + self._save_json(file_path, self.ids) + + @staticmethod + def init_from_config(config: Coqpit) -> "LanguageManager": + """Initialize the language manager from a Coqpit config. + + Args: + config (Coqpit): Coqpit config. + """ + language_manager = None + if check_config_and_model_args(config, "use_language_embedding", True): + if config.get("language_ids_file", None): + language_manager = LanguageManager(language_ids_file_path=config.language_ids_file) + language_manager = LanguageManager(config=config) + return language_manager + + +def _set_file_path(path): + """Find the language_ids.json under the given path or the above it. + Intended to band aid the different paths returned in restored and continued training.""" + path_restore = os.path.join(os.path.dirname(path), "language_ids.json") + path_continue = os.path.join(path, "language_ids.json") + fs = fsspec.get_mapper(path).fs + if fs.exists(path_restore): + return path_restore + if fs.exists(path_continue): + return path_continue + return None + + +def get_language_balancer_weights(items: list): + language_names = np.array([item["language"] for item in items]) + unique_language_names = np.unique(language_names).tolist() + language_ids = [unique_language_names.index(l) for l in language_names] + language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names]) + weight_language = 1.0 / language_count + # get weight for each sample + dataset_samples_weight = np.array([weight_language[l] for l in language_ids]) + # normalize + dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) + return torch.from_numpy(dataset_samples_weight).float() diff --git a/TTS/tts/utils/managers.py b/TTS/tts/utils/managers.py new file mode 100644 index 0000000000000000000000000000000000000000..0243d3b4bc0df6ebc78f4799b37f5718e7712a18 --- /dev/null +++ b/TTS/tts/utils/managers.py @@ -0,0 +1,309 @@ +import json +import random +from typing import Any, Dict, List, Tuple, Union + +import fsspec +import numpy as np +import torch + +from TTS.config import load_config +from TTS.encoder.utils.generic_utils import setup_encoder_model +from TTS.utils.audio import AudioProcessor + + +def load_file(path: str): + if path.endswith(".json"): + with fsspec.open(path, "r") as f: + return json.load(f) + elif path.endswith(".pth"): + with fsspec.open(path, "rb") as f: + return torch.load(f, map_location="cpu") + else: + raise ValueError("Unsupported file type") + + +def save_file(obj: Any, path: str): + if path.endswith(".json"): + with fsspec.open(path, "w") as f: + json.dump(obj, f, indent=4) + elif path.endswith(".pth"): + with fsspec.open(path, "wb") as f: + torch.save(obj, f) + else: + raise ValueError("Unsupported file type") + + +class BaseIDManager: + """Base `ID` Manager class. Every new `ID` manager must inherit this. + It defines common `ID` manager specific functions. 
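+
+    Example:
+        Illustrative sketch; `items` is assumed to be the sample list returned by `load_tts_samples()`.
+
+        >>> manager = BaseIDManager()
+        >>> manager.set_ids_from_data(items, parse_key="speaker_name")
+        >>> manager.get_random_id()  # a random value from `manager.ids`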
+ """ + + def __init__(self, id_file_path: str = ""): + self.ids = {} + + if id_file_path: + self.load_ids_from_file(id_file_path) + + @staticmethod + def _load_json(json_file_path: str) -> Dict: + with fsspec.open(json_file_path, "r") as f: + return json.load(f) + + @staticmethod + def _save_json(json_file_path: str, data: dict) -> None: + with fsspec.open(json_file_path, "w") as f: + json.dump(data, f, indent=4) + + def set_ids_from_data(self, items: List, parse_key: str) -> None: + """Set IDs from data samples. + + Args: + items (List): Data sampled returned by `load_tts_samples()`. + """ + self.ids = self.parse_ids_from_data(items, parse_key=parse_key) + + def load_ids_from_file(self, file_path: str) -> None: + """Set IDs from a file. + + Args: + file_path (str): Path to the file. + """ + self.ids = load_file(file_path) + + def save_ids_to_file(self, file_path: str) -> None: + """Save IDs to a json file. + + Args: + file_path (str): Path to the output file. + """ + save_file(self.ids, file_path) + + def get_random_id(self) -> Any: + """Get a random embedding. + + Args: + + Returns: + np.ndarray: embedding. + """ + if self.ids: + return self.ids[random.choices(list(self.ids.keys()))[0]] + + return None + + @staticmethod + def parse_ids_from_data(items: List, parse_key: str) -> Tuple[Dict]: + """Parse IDs from data samples retured by `load_tts_samples()`. + + Args: + items (list): Data sampled returned by `load_tts_samples()`. + parse_key (str): The key to being used to parse the data. + Returns: + Tuple[Dict]: speaker IDs. + """ + classes = sorted({item[parse_key] for item in items}) + ids = {name: i for i, name in enumerate(classes)} + return ids + + +class EmbeddingManager(BaseIDManager): + """Base `Embedding` Manager class. Every new `Embedding` manager must inherit this. + It defines common `Embedding` manager specific functions. + """ + + def __init__( + self, + embedding_file_path: str = "", + id_file_path: str = "", + encoder_model_path: str = "", + encoder_config_path: str = "", + use_cuda: bool = False, + ): + super().__init__(id_file_path=id_file_path) + + self.embeddings = {} + self.embeddings_by_names = {} + self.clip_ids = [] + self.encoder = None + self.encoder_ap = None + self.use_cuda = use_cuda + + if embedding_file_path: + self.load_embeddings_from_file(embedding_file_path) + + if encoder_model_path and encoder_config_path: + self.init_encoder(encoder_model_path, encoder_config_path, use_cuda) + + @property + def embedding_dim(self): + """Dimensionality of embeddings. If embeddings are not loaded, returns zero.""" + if self.embeddings: + return len(self.embeddings[list(self.embeddings.keys())[0]]["embedding"]) + return 0 + + def save_embeddings_to_file(self, file_path: str) -> None: + """Save embeddings to a json file. + + Args: + file_path (str): Path to the output file. + """ + save_file(self.embeddings, file_path) + + def load_embeddings_from_file(self, file_path: str) -> None: + """Load embeddings from a json file. + + Args: + file_path (str): Path to the target json file. + """ + self.embeddings = load_file(file_path) + + speakers = sorted({x["name"] for x in self.embeddings.values()}) + self.ids = {name: i for i, name in enumerate(speakers)} + + self.clip_ids = list(set(sorted(clip_name for clip_name in self.embeddings.keys()))) + # cache embeddings_by_names for fast inference using a bigger speakers.json + self.embeddings_by_names = self.get_embeddings_by_names() + + def get_embedding_by_clip(self, clip_idx: str) -> List: + """Get embedding by clip ID. 
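+
+        Example:
+            Illustrative call; assumes embeddings were loaded with `load_embeddings_from_file()` and
+            `"clip_name.wav"` is one of their keys.
+
+            >>> embedding = manager.get_embedding_by_clip("clip_name.wav")  # list of `embedding_dim` floats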
+ + Args: + clip_idx (str): Target clip ID. + + Returns: + List: embedding as a list. + """ + return self.embeddings[clip_idx]["embedding"] + + def get_embeddings_by_name(self, idx: str) -> List[List]: + """Get all embeddings of a speaker. + + Args: + idx (str): Target name. + + Returns: + List[List]: all the embeddings of the given speaker. + """ + return self.embeddings_by_names[idx] + + def get_embeddings_by_names(self) -> Dict: + """Get all embeddings by names. + + Returns: + Dict: all the embeddings of each speaker. + """ + embeddings_by_names = {} + for x in self.embeddings.values(): + if x["name"] not in embeddings_by_names.keys(): + embeddings_by_names[x["name"]] = [x["embedding"]] + else: + embeddings_by_names[x["name"]].append(x["embedding"]) + return embeddings_by_names + + def get_mean_embedding(self, idx: str, num_samples: int = None, randomize: bool = False) -> np.ndarray: + """Get mean embedding of a idx. + + Args: + idx (str): Target name. + num_samples (int, optional): Number of samples to be averaged. Defaults to None. + randomize (bool, optional): Pick random `num_samples` of embeddings. Defaults to False. + + Returns: + np.ndarray: Mean embedding. + """ + embeddings = self.get_embeddings_by_name(idx) + if num_samples is None: + embeddings = np.stack(embeddings).mean(0) + else: + assert len(embeddings) >= num_samples, f" [!] {idx} has number of samples < {num_samples}" + if randomize: + embeddings = np.stack(random.choices(embeddings, k=num_samples)).mean(0) + else: + embeddings = np.stack(embeddings[:num_samples]).mean(0) + return embeddings + + def get_random_embedding(self) -> Any: + """Get a random embedding. + + Args: + + Returns: + np.ndarray: embedding. + """ + if self.embeddings: + return self.embeddings[random.choices(list(self.embeddings.keys()))[0]]["embedding"] + + return None + + def get_clips(self) -> List: + return sorted(self.embeddings.keys()) + + def init_encoder(self, model_path: str, config_path: str, use_cuda=False) -> None: + """Initialize a speaker encoder model. + + Args: + model_path (str): Model file path. + config_path (str): Model config file path. + use_cuda (bool, optional): Use CUDA. Defaults to False. + """ + self.use_cuda = use_cuda + self.encoder_config = load_config(config_path) + self.encoder = setup_encoder_model(self.encoder_config) + self.encoder_criterion = self.encoder.load_checkpoint( + self.encoder_config, model_path, eval=True, use_cuda=use_cuda + ) + self.encoder_ap = AudioProcessor(**self.encoder_config.audio) + + def compute_embedding_from_clip(self, wav_file: Union[str, List[str]]) -> list: + """Compute a embedding from a given audio file. + + Args: + wav_file (Union[str, List[str]]): Target file path. + + Returns: + list: Computed embedding. 
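+
+        Example:
+            Illustrative sketch; assumes `init_encoder()` was called with a valid speaker encoder checkpoint
+            and config, and that the given wav paths exist.
+
+            >>> emb = manager.compute_embedding_from_clip("sample.wav")
+            >>> mean_emb = manager.compute_embedding_from_clip(["a.wav", "b.wav"])  # mean over the clips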
+ """ + + def _compute(wav_file: str): + waveform = self.encoder_ap.load_wav(wav_file, sr=self.encoder_ap.sample_rate) + if not self.encoder_config.model_params.get("use_torch_spec", False): + m_input = self.encoder_ap.melspectrogram(waveform) + m_input = torch.from_numpy(m_input) + else: + m_input = torch.from_numpy(waveform) + + if self.use_cuda: + m_input = m_input.cuda() + m_input = m_input.unsqueeze(0) + embedding = self.encoder.compute_embedding(m_input) + return embedding + + if isinstance(wav_file, list): + # compute the mean embedding + embeddings = None + for wf in wav_file: + embedding = _compute(wf) + if embeddings is None: + embeddings = embedding + else: + embeddings += embedding + return (embeddings / len(wav_file))[0].tolist() + embedding = _compute(wav_file) + return embedding[0].tolist() + + def compute_embeddings(self, feats: Union[torch.Tensor, np.ndarray]) -> List: + """Compute embedding from features. + + Args: + feats (Union[torch.Tensor, np.ndarray]): Input features. + + Returns: + List: computed embedding. + """ + if isinstance(feats, np.ndarray): + feats = torch.from_numpy(feats) + if feats.ndim == 2: + feats = feats.unsqueeze(0) + if self.use_cuda: + feats = feats.cuda() + return self.encoder.compute_embedding(feats) diff --git a/TTS/tts/utils/measures.py b/TTS/tts/utils/measures.py new file mode 100644 index 0000000000000000000000000000000000000000..90e862e1190bdb8443933580b3ff47321f70cecd --- /dev/null +++ b/TTS/tts/utils/measures.py @@ -0,0 +1,15 @@ +def alignment_diagonal_score(alignments, binary=False): + """ + Compute how diagonal alignment predictions are. It is useful + to measure the alignment consistency of a model + Args: + alignments (torch.Tensor): batch of alignments. + binary (bool): if True, ignore scores and consider attention + as a binary mask. + Shape: + - alignments : :math:`[B, T_de, T_en]` + """ + maxs = alignments.max(dim=1)[0] + if binary: + maxs[maxs > 0] = 1 + return maxs.mean(dim=1).mean(dim=0).item() diff --git a/TTS/tts/utils/monotonic_align/__init__.py b/TTS/tts/utils/monotonic_align/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/utils/monotonic_align/core.pyx b/TTS/tts/utils/monotonic_align/core.pyx new file mode 100644 index 0000000000000000000000000000000000000000..091fcc3a50a51f3d3fee47a70825260757e6d885 --- /dev/null +++ b/TTS/tts/utils/monotonic_align/core.pyx @@ -0,0 +1,47 @@ +import numpy as np + +cimport cython +cimport numpy as np + +from cython.parallel import prange + + +@cython.boundscheck(False) +@cython.wraparound(False) +cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil: + cdef int x + cdef int y + cdef float v_prev + cdef float v_cur + cdef float tmp + cdef int index = t_x - 1 + + for y in range(t_y): + for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + if x == y: + v_cur = max_neg_val + else: + v_cur = value[x, y-1] + if x == 0: + if y == 0: + v_prev = 0. 
+ else: + v_prev = max_neg_val + else: + v_prev = value[x-1, y-1] + value[x, y] = max(v_cur, v_prev) + value[x, y] + + for y in range(t_y - 1, -1, -1): + path[index, y] = 1 + if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): + index = index - 1 + + +@cython.boundscheck(False) +@cython.wraparound(False) +cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: + cdef int b = values.shape[0] + + cdef int i + for i in prange(b, nogil=True): + maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) diff --git a/TTS/tts/utils/monotonic_align/setup.py b/TTS/tts/utils/monotonic_align/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..f22bc6a35a5a04c9e6d7b82040973722c9b770c9 --- /dev/null +++ b/TTS/tts/utils/monotonic_align/setup.py @@ -0,0 +1,7 @@ +# from distutils.core import setup +# from Cython.Build import cythonize +# import numpy + +# setup(name='monotonic_align', +# ext_modules=cythonize("core.pyx"), +# include_dirs=[numpy.get_include()]) diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py new file mode 100644 index 0000000000000000000000000000000000000000..77b61a8d11e5f981cb1d56720df3f597742e3614 --- /dev/null +++ b/TTS/tts/utils/speakers.py @@ -0,0 +1,226 @@ +import json +import os +from typing import Any, Dict, List, Union + +import fsspec +import numpy as np +import torch +from coqpit import Coqpit + +from TTS.config import get_from_config_or_model_args_with_default +from TTS.tts.utils.managers import EmbeddingManager + + +class SpeakerManager(EmbeddingManager): + """Manage the speakers for multi-speaker 🐸TTS models. Load a datafile and parse the information + in a way that can be queried by speaker or clip. + + There are 3 different scenarios considered: + + 1. Models using speaker embedding layers. The datafile only maps speaker names to ids used by the embedding layer. + 2. Models using d-vectors. The datafile includes a dictionary in the following format. + + :: + + { + 'clip_name.wav':{ + 'name': 'speakerA', + 'embedding'[] + }, + ... + } + + + 3. Computing the d-vectors by the speaker encoder. It loads the speaker encoder model and + computes the d-vectors for a given clip or speaker. + + Args: + d_vectors_file_path (str, optional): Path to the metafile including x vectors. Defaults to "". + speaker_id_file_path (str, optional): Path to the metafile that maps speaker names to ids used by + TTS models. Defaults to "". + encoder_model_path (str, optional): Path to the speaker encoder model file. Defaults to "". + encoder_config_path (str, optional): Path to the spealer encoder config file. Defaults to "". 
+ + Examples: + >>> # load audio processor and speaker encoder + >>> ap = AudioProcessor(**config.audio) + >>> manager = SpeakerManager(encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path) + >>> # load a sample audio and compute embedding + >>> waveform = ap.load_wav(sample_wav_path) + >>> mel = ap.melspectrogram(waveform) + >>> d_vector = manager.compute_embeddings(mel.T) + """ + + def __init__( + self, + data_items: List[List[Any]] = None, + d_vectors_file_path: str = "", + speaker_id_file_path: str = "", + encoder_model_path: str = "", + encoder_config_path: str = "", + use_cuda: bool = False, + ): + super().__init__( + embedding_file_path=d_vectors_file_path, + id_file_path=speaker_id_file_path, + encoder_model_path=encoder_model_path, + encoder_config_path=encoder_config_path, + use_cuda=use_cuda, + ) + + if data_items: + self.set_ids_from_data(data_items, parse_key="speaker_name") + + @property + def num_speakers(self): + return len(self.ids) + + @property + def speaker_names(self): + return list(self.ids.keys()) + + def get_speakers(self) -> List: + return self.ids + + @staticmethod + def init_from_config(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "SpeakerManager": + """Initialize a speaker manager from config + + Args: + config (Coqpit): Config object. + samples (Union[List[List], List[Dict]], optional): List of data samples to parse out the speaker names. + Defaults to None. + + Returns: + SpeakerEncoder: Speaker encoder object. + """ + speaker_manager = None + if get_from_config_or_model_args_with_default(config, "use_speaker_embedding", False): + if samples: + speaker_manager = SpeakerManager(data_items=samples) + if get_from_config_or_model_args_with_default(config, "speaker_file", None): + speaker_manager = SpeakerManager( + speaker_id_file_path=get_from_config_or_model_args_with_default(config, "speaker_file", None) + ) + if get_from_config_or_model_args_with_default(config, "speakers_file", None): + speaker_manager = SpeakerManager( + speaker_id_file_path=get_from_config_or_model_args_with_default(config, "speakers_file", None) + ) + + if get_from_config_or_model_args_with_default(config, "use_d_vector_file", False): + speaker_manager = SpeakerManager() + if get_from_config_or_model_args_with_default(config, "speakers_file", None): + speaker_manager = SpeakerManager( + d_vectors_file_path=get_from_config_or_model_args_with_default(config, "speaker_file", None) + ) + if get_from_config_or_model_args_with_default(config, "d_vector_file", None): + speaker_manager = SpeakerManager( + d_vectors_file_path=get_from_config_or_model_args_with_default(config, "d_vector_file", None) + ) + return speaker_manager + + +def _set_file_path(path): + """Find the speakers.json under the given path or the above it. + Intended to band aid the different paths returned in restored and continued training.""" + path_restore = os.path.join(os.path.dirname(path), "speakers.json") + path_continue = os.path.join(path, "speakers.json") + fs = fsspec.get_mapper(path).fs + if fs.exists(path_restore): + return path_restore + if fs.exists(path_continue): + return path_continue + raise FileNotFoundError(f" [!] 
`speakers.json` not found in {path}") + + +def load_speaker_mapping(out_path): + """Loads speaker mapping if already present.""" + if os.path.splitext(out_path)[1] == ".json": + json_file = out_path + else: + json_file = _set_file_path(out_path) + with fsspec.open(json_file, "r") as f: + return json.load(f) + + +def save_speaker_mapping(out_path, speaker_mapping): + """Saves speaker mapping if not yet present.""" + if out_path is not None: + speakers_json_path = _set_file_path(out_path) + with fsspec.open(speakers_json_path, "w") as f: + json.dump(speaker_mapping, f, indent=4) + + +def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None, out_path: str = None) -> SpeakerManager: + """Initiate a `SpeakerManager` instance by the provided config. + + Args: + c (Coqpit): Model configuration. + restore_path (str): Path to a previous training folder. + data (List): Data samples used in training to infer speakers from. It must be provided if speaker embedding + layers is used. Defaults to None. + out_path (str, optional): Save the generated speaker IDs to a output path. Defaults to None. + + Returns: + SpeakerManager: initialized and ready to use instance. + """ + speaker_manager = SpeakerManager() + if c.use_speaker_embedding: + if data is not None: + speaker_manager.set_ids_from_data(data, parse_key="speaker_name") + if restore_path: + speakers_file = _set_file_path(restore_path) + # restoring speaker manager from a previous run. + if c.use_d_vector_file: + # restore speaker manager with the embedding file + if not os.path.exists(speakers_file): + print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.d_vector_file") + if not os.path.exists(c.d_vector_file): + raise RuntimeError( + "You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.d_vector_file" + ) + speaker_manager.load_embeddings_from_file(c.d_vector_file) + speaker_manager.load_embeddings_from_file(speakers_file) + elif not c.use_d_vector_file: # restor speaker manager with speaker ID file. + speaker_ids_from_data = speaker_manager.ids + speaker_manager.load_ids_from_file(speakers_file) + assert all( + speaker in speaker_manager.ids for speaker in speaker_ids_from_data + ), " [!] You cannot introduce new speakers to a pre-trained model." + elif c.use_d_vector_file and c.d_vector_file: + # new speaker manager with external speaker embeddings. + speaker_manager.load_embeddings_from_file(c.d_vector_file) + elif c.use_d_vector_file and not c.d_vector_file: + raise "use_d_vector_file is True, so you need pass a external speaker embedding file." + elif c.use_speaker_embedding and "speakers_file" in c and c.speakers_file: + # new speaker manager with speaker IDs file. 
+ speaker_manager.load_ids_from_file(c.speakers_file) + + if speaker_manager.num_speakers > 0: + print( + " > Speaker manager is loaded with {} speakers: {}".format( + speaker_manager.num_speakers, ", ".join(speaker_manager.ids) + ) + ) + + # save file if path is defined + if out_path: + out_file_path = os.path.join(out_path, "speakers.json") + print(f" > Saving `speakers.json` to {out_file_path}.") + if c.use_d_vector_file and c.d_vector_file: + speaker_manager.save_embeddings_to_file(out_file_path) + else: + speaker_manager.save_ids_to_file(out_file_path) + return speaker_manager + + +def get_speaker_balancer_weights(items: list): + speaker_names = np.array([item["speaker_name"] for item in items]) + unique_speaker_names = np.unique(speaker_names).tolist() + speaker_ids = [unique_speaker_names.index(l) for l in speaker_names] + speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names]) + weight_speaker = 1.0 / speaker_count + dataset_samples_weight = np.array([weight_speaker[l] for l in speaker_ids]) + # normalize + dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) + return torch.from_numpy(dataset_samples_weight).float() diff --git a/TTS/tts/utils/ssim.py b/TTS/tts/utils/ssim.py new file mode 100644 index 0000000000000000000000000000000000000000..ab2c69914e70a5321b998ad6587b3190d925890d --- /dev/null +++ b/TTS/tts/utils/ssim.py @@ -0,0 +1,73 @@ +# taken from https://github.com/Po-Hsun-Su/pytorch-ssim + +from math import exp + +import torch +import torch.nn.functional as F +from torch.autograd import Variable + + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)]) + return gauss / gauss.sum() + + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + + +def _ssim(img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + # TODO: check if you need AMP disabled + # with torch.cuda.amp.autocast(enabled=False): + mu1_sq = mu1.float().pow(2) + mu2_sq = mu2.float().pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 + + C1 = 0.01**2 + C2 = 0.03**2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + return ssim_map.mean(1).mean(1).mean(1) + + +class SSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True): + super().__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = create_window(window_size, self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.data.type() == img1.data.type(): + window = self.window + else: + window = create_window(self.window_size, channel) + window = window.type_as(img1) + + self.window = window + 
self.channel = channel + + return _ssim(img1, img2, window, self.window_size, channel, self.size_average) + + +def ssim(img1, img2, window_size=11, size_average=True): + (_, channel, _, _) = img1.size() + window = create_window(window_size, channel).type_as(img1) + window = window.type_as(img1) + return _ssim(img1, img2, window, window_size, channel, size_average) diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py new file mode 100644 index 0000000000000000000000000000000000000000..a74300dc948a4eeb88cdf9e936825203d68da9ca --- /dev/null +++ b/TTS/tts/utils/synthesis.py @@ -0,0 +1,319 @@ +from typing import Dict + +import numpy as np +import torch +from torch import nn + + +def numpy_to_torch(np_array, dtype, cuda=False): + if np_array is None: + return None + tensor = torch.as_tensor(np_array, dtype=dtype) + if cuda: + return tensor.cuda() + return tensor + + +def compute_style_mel(style_wav, ap, cuda=False): + style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0) + if cuda: + return style_mel.cuda() + return style_mel + + +def run_model_torch( + model: nn.Module, + inputs: torch.Tensor, + speaker_id: int = None, + style_mel: torch.Tensor = None, + style_text: str = None, + d_vector: torch.Tensor = None, + language_id: torch.Tensor = None, +) -> Dict: + """Run a torch model for inference. It does not support batch inference. + + Args: + model (nn.Module): The model to run inference. + inputs (torch.Tensor): Input tensor with character ids. + speaker_id (int, optional): Input speaker ids for multi-speaker models. Defaults to None. + style_mel (torch.Tensor, optional): Spectrograms used for voice styling . Defaults to None. + d_vector (torch.Tensor, optional): d-vector for multi-speaker models . Defaults to None. + + Returns: + Dict: model outputs. + """ + input_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) + if hasattr(model, "module"): + _func = model.module.inference + else: + _func = model.inference + outputs = _func( + inputs, + aux_input={ + "x_lengths": input_lengths, + "speaker_ids": speaker_id, + "d_vectors": d_vector, + "style_mel": style_mel, + "style_text": style_text, + "language_ids": language_id, + }, + ) + return outputs + + +def trim_silence(wav, ap): + return wav[: ap.find_endpoint(wav)] + + +def inv_spectrogram(postnet_output, ap, CONFIG): + if CONFIG.model.lower() in ["tacotron"]: + wav = ap.inv_spectrogram(postnet_output.T) + else: + wav = ap.inv_melspectrogram(postnet_output.T) + return wav + + +def id_to_torch(aux_id, cuda=False): + if aux_id is not None: + aux_id = np.asarray(aux_id) + aux_id = torch.from_numpy(aux_id) + if cuda: + return aux_id.cuda() + return aux_id + + +def embedding_to_torch(d_vector, cuda=False): + if d_vector is not None: + d_vector = np.asarray(d_vector) + d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor) + d_vector = d_vector.squeeze().unsqueeze(0) + if cuda: + return d_vector.cuda() + return d_vector + + +# TODO: perform GL with pytorch for batching +def apply_griffin_lim(inputs, input_lens, CONFIG, ap): + """Apply griffin-lim to each sample iterating throught the first dimension. + Args: + inputs (Tensor or np.Array): Features to be converted by GL. First dimension is the batch size. + input_lens (Tensor or np.Array): 1D array of sample lengths. + CONFIG (Dict): TTS config. + ap (AudioProcessor): TTS audio processor. 
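+
+    Example (an illustrative sketch; `specs`, `spec_lens`, `CONFIG` and `ap` are assumed to come from the caller)::
+
+        >>> wavs = apply_griffin_lim(specs, spec_lens, CONFIG, ap)
+        >>> assert len(wavs) == len(spec_lens)  # one waveform per batch item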
+ """ + wavs = [] + for idx, spec in enumerate(inputs): + wav_len = (input_lens[idx] * ap.hop_length) - ap.hop_length # inverse librosa padding + wav = inv_spectrogram(spec, ap, CONFIG) + # assert len(wav) == wav_len, f" [!] wav lenght: {len(wav)} vs expected: {wav_len}" + wavs.append(wav[:wav_len]) + return wavs + + +def synthesis( + model, + text, + CONFIG, + use_cuda, + speaker_id=None, + style_wav=None, + style_text=None, + use_griffin_lim=False, + do_trim_silence=False, + d_vector=None, + language_id=None, +): + """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to + the vocoder model. + + Args: + model (TTS.tts.models): + The TTS model to synthesize audio with. + + text (str): + The input text to convert to speech. + + CONFIG (Coqpit): + Model configuration. + + use_cuda (bool): + Enable/disable CUDA. + + speaker_id (int): + Speaker ID passed to the speaker embedding layer in multi-speaker model. Defaults to None. + + style_wav (str | Dict[str, float]): + Path or tensor to/of a waveform used for computing the style embedding based on GST or Capacitron. + Defaults to None, meaning that Capacitron models will sample from the prior distribution to + generate random but realistic prosody. + + style_text (str): + Transcription of style_wav for Capacitron models. Defaults to None. + + enable_eos_bos_chars (bool): + enable special chars for end of sentence and start of sentence. Defaults to False. + + do_trim_silence (bool): + trim silence after synthesis. Defaults to False. + + d_vector (torch.Tensor): + d-vector for multi-speaker models in share :math:`[1, D]`. Defaults to None. + + language_id (int): + Language ID passed to the language embedding layer in multi-langual model. Defaults to None. 
+ """ + # GST or Capacitron processing + # TODO: need to handle the case of setting both gst and capacitron to true somewhere + style_mel = None + if CONFIG.has("gst") and CONFIG.gst and style_wav is not None: + if isinstance(style_wav, dict): + style_mel = style_wav + else: + style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda) + + if CONFIG.has("capacitron_vae") and CONFIG.use_capacitron_vae and style_wav is not None: + style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda) + style_mel = style_mel.transpose(1, 2) # [1, time, depth] + + # convert text to sequence of token IDs + text_inputs = np.asarray( + model.tokenizer.text_to_ids(text, language=language_id), + dtype=np.int32, + ) + # pass tensors to backend + if speaker_id is not None: + speaker_id = id_to_torch(speaker_id, cuda=use_cuda) + + if d_vector is not None: + d_vector = embedding_to_torch(d_vector, cuda=use_cuda) + + if language_id is not None: + language_id = id_to_torch(language_id, cuda=use_cuda) + + if not isinstance(style_mel, dict): + # GST or Capacitron style mel + style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda) + if style_text is not None: + style_text = np.asarray( + model.tokenizer.text_to_ids(style_text, language=language_id), + dtype=np.int32, + ) + style_text = numpy_to_torch(style_text, torch.long, cuda=use_cuda) + style_text = style_text.unsqueeze(0) + + text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda) + text_inputs = text_inputs.unsqueeze(0) + # synthesize voice + outputs = run_model_torch( + model, + text_inputs, + speaker_id, + style_mel, + style_text, + d_vector=d_vector, + language_id=language_id, + ) + model_outputs = outputs["model_outputs"] + model_outputs = model_outputs[0].data.cpu().numpy() + alignments = outputs["alignments"] + + # convert outputs to numpy + # plot results + wav = None + model_outputs = model_outputs.squeeze() + if model_outputs.ndim == 2: # [T, C_spec] + if use_griffin_lim: + wav = inv_spectrogram(model_outputs, model.ap, CONFIG) + # trim silence + if do_trim_silence: + wav = trim_silence(wav, model.ap) + else: # [T,] + wav = model_outputs + return_dict = { + "wav": wav, + "alignments": alignments, + "text_inputs": text_inputs, + "outputs": outputs, + } + return return_dict + + +def transfer_voice( + model, + CONFIG, + use_cuda, + reference_wav, + speaker_id=None, + d_vector=None, + reference_speaker_id=None, + reference_d_vector=None, + do_trim_silence=False, + use_griffin_lim=False, +): + """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to + the vocoder model. + + Args: + model (TTS.tts.models): + The TTS model to synthesize audio with. + + CONFIG (Coqpit): + Model configuration. + + use_cuda (bool): + Enable/disable CUDA. + + reference_wav (str): + Path of reference_wav to be used to voice conversion. + + speaker_id (int): + Speaker ID passed to the speaker embedding layer in multi-speaker model. Defaults to None. + + d_vector (torch.Tensor): + d-vector for multi-speaker models in share :math:`[1, D]`. Defaults to None. + + reference_speaker_id (int): + Reference Speaker ID passed to the speaker embedding layer in multi-speaker model. Defaults to None. + + reference_d_vector (torch.Tensor): + Reference d-vector for multi-speaker models in share :math:`[1, D]`. Defaults to None. + + enable_eos_bos_chars (bool): + enable special chars for end of sentence and start of sentence. Defaults to False. + + do_trim_silence (bool): + trim silence after synthesis. 
Defaults to False. + """ + # pass tensors to backend + if speaker_id is not None: + speaker_id = id_to_torch(speaker_id, cuda=use_cuda) + + if d_vector is not None: + d_vector = embedding_to_torch(d_vector, cuda=use_cuda) + + if reference_d_vector is not None: + reference_d_vector = embedding_to_torch(reference_d_vector, cuda=use_cuda) + + # load reference_wav audio + reference_wav = embedding_to_torch(model.ap.load_wav(reference_wav, sr=model.ap.sample_rate), cuda=use_cuda) + + if hasattr(model, "module"): + _func = model.module.inference_voice_conversion + else: + _func = model.inference_voice_conversion + model_outputs = _func(reference_wav, speaker_id, d_vector, reference_speaker_id, reference_d_vector) + + # convert outputs to numpy + # plot results + wav = None + model_outputs = model_outputs.squeeze() + if model_outputs.ndim == 2: # [T, C_spec] + if use_griffin_lim: + wav = inv_spectrogram(model_outputs, model.ap, CONFIG) + # trim silence + if do_trim_silence: + wav = trim_silence(wav, model.ap) + else: # [T,] + wav = model_outputs + + return wav diff --git a/TTS/tts/utils/text/__init__.py b/TTS/tts/utils/text/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..593372dc7cb2fba240eb5f08e8e2cfae5a4b4e45 --- /dev/null +++ b/TTS/tts/utils/text/__init__.py @@ -0,0 +1 @@ +from TTS.tts.utils.text.tokenizer import TTSTokenizer diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py new file mode 100644 index 0000000000000000000000000000000000000000..1b375e4fca38c29d7929ddf65a3c2932e2168992 --- /dev/null +++ b/TTS/tts/utils/text/characters.py @@ -0,0 +1,468 @@ +from dataclasses import replace +from typing import Dict + +from TTS.tts.configs.shared_configs import CharactersConfig + + +def parse_symbols(): + return { + "pad": _pad, + "eos": _eos, + "bos": _bos, + "characters": _characters, + "punctuations": _punctuations, + "phonemes": _phonemes, + } + + +# DEFAULT SET OF GRAPHEMES +_pad = "" +_eos = "" +_bos = "" +_blank = "" # TODO: check if we need this alongside with PAD +_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +_punctuations = "!'(),-.:;? " + + +# DEFAULT SET OF IPA PHONEMES +# Phonemes definition (All IPA characters) +_vowels = "iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ" +_non_pulmonic_consonants = "ʘɓǀɗǃʄǂɠǁʛ" +_pulmonic_consonants = "pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ" +_suprasegmentals = "ˈˌːˑ" +_other_symbols = "ʍwɥʜʢʡɕʑɺɧʲ" +_diacrilics = "ɚ˞ɫ" +_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics + + +class BaseVocabulary: + """Base Vocabulary class. + + This class only needs a vocabulary dictionary without specifying the characters. + + Args: + vocab (Dict): A dictionary of characters and their corresponding indices. + """ + + def __init__(self, vocab: Dict, pad: str = None, blank: str = None, bos: str = None, eos: str = None): + self.vocab = vocab + self.pad = pad + self.blank = blank + self.bos = bos + self.eos = eos + + @property + def pad_id(self) -> int: + """Return the index of the padding character. If the padding character is not specified, return the length + of the vocabulary.""" + return self.char_to_id(self.pad) if self.pad else len(self.vocab) + + @property + def blank_id(self) -> int: + """Return the index of the blank character. 
If the blank character is not specified, return the length of + the vocabulary.""" + return self.char_to_id(self.blank) if self.blank else len(self.vocab) + + @property + def vocab(self): + """Return the vocabulary dictionary.""" + return self._vocab + + @vocab.setter + def vocab(self, vocab): + """Set the vocabulary dictionary and character mapping dictionaries.""" + self._vocab = vocab + self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)} + self._id_to_char = { + idx: char for idx, char in enumerate(self._vocab) # pylint: disable=unnecessary-comprehension + } + + @staticmethod + def init_from_config(config, **kwargs): + """Initialize from the given config.""" + if config.characters is not None and "vocab_dict" in config.characters and config.characters.vocab_dict: + return ( + BaseVocabulary( + config.characters.vocab_dict, + config.characters.pad, + config.characters.blank, + config.characters.bos, + config.characters.eos, + ), + config, + ) + return BaseVocabulary(**kwargs), config + + @property + def num_chars(self): + """Return number of tokens in the vocabulary.""" + return len(self._vocab) + + def char_to_id(self, char: str) -> int: + """Map a character to an token ID.""" + try: + return self._char_to_id[char] + except KeyError as e: + raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e + + def id_to_char(self, idx: int) -> str: + """Map an token ID to a character.""" + return self._id_to_char[idx] + + +class BaseCharacters: + """🐸BaseCharacters class + + Every new character class should inherit from this. + + Characters are oredered as follows ```[PAD, EOS, BOS, BLANK, CHARACTERS, PUNCTUATIONS]```. + + If you need a custom order, you need to define inherit from this class and override the ```_create_vocab``` method. + + Args: + characters (str): + Main set of characters to be used in the vocabulary. + + punctuations (str): + Characters to be treated as punctuation. + + pad (str): + Special padding character that would be ignored by the model. + + eos (str): + End of the sentence character. + + bos (str): + Beginning of the sentence character. + + blank (str): + Optional character used between characters by some models for better prosody. + + is_unique (bool): + Remove duplicates from the provided characters. Defaults to True. + el + is_sorted (bool): + Sort the characters in alphabetical order. Only applies to `self.characters`. Defaults to True. 
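+
+    Example (a minimal sketch using the default symbol sets defined in this module)::
+
+        >>> chars = BaseCharacters(characters=_characters, punctuations=_punctuations, pad=_pad, eos=_eos, bos=_bos, blank=_blank)
+        >>> idx = chars.char_to_id("a")
+        >>> chars.id_to_char(idx)
+        'a'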
+ """ + + def __init__( + self, + characters: str = None, + punctuations: str = None, + pad: str = None, + eos: str = None, + bos: str = None, + blank: str = None, + is_unique: bool = False, + is_sorted: bool = True, + ) -> None: + self._characters = characters + self._punctuations = punctuations + self._pad = pad + self._eos = eos + self._bos = bos + self._blank = blank + self.is_unique = is_unique + self.is_sorted = is_sorted + self._create_vocab() + + @property + def pad_id(self) -> int: + return self.char_to_id(self.pad) if self.pad else len(self.vocab) + + @property + def blank_id(self) -> int: + return self.char_to_id(self.blank) if self.blank else len(self.vocab) + + @property + def characters(self): + return self._characters + + @characters.setter + def characters(self, characters): + self._characters = characters + self._create_vocab() + + @property + def punctuations(self): + return self._punctuations + + @punctuations.setter + def punctuations(self, punctuations): + self._punctuations = punctuations + self._create_vocab() + + @property + def pad(self): + return self._pad + + @pad.setter + def pad(self, pad): + self._pad = pad + self._create_vocab() + + @property + def eos(self): + return self._eos + + @eos.setter + def eos(self, eos): + self._eos = eos + self._create_vocab() + + @property + def bos(self): + return self._bos + + @bos.setter + def bos(self, bos): + self._bos = bos + self._create_vocab() + + @property + def blank(self): + return self._blank + + @blank.setter + def blank(self, blank): + self._blank = blank + self._create_vocab() + + @property + def vocab(self): + return self._vocab + + @vocab.setter + def vocab(self, vocab): + self._vocab = vocab + self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)} + self._id_to_char = { + idx: char for idx, char in enumerate(self.vocab) # pylint: disable=unnecessary-comprehension + } + + @property + def num_chars(self): + return len(self._vocab) + + def _create_vocab(self): + _vocab = self._characters + if self.is_unique: + _vocab = list(set(_vocab)) + if self.is_sorted: + _vocab = sorted(_vocab) + _vocab = list(_vocab) + _vocab = [self._blank] + _vocab if self._blank is not None and len(self._blank) > 0 else _vocab + _vocab = [self._bos] + _vocab if self._bos is not None and len(self._bos) > 0 else _vocab + _vocab = [self._eos] + _vocab if self._eos is not None and len(self._eos) > 0 else _vocab + _vocab = [self._pad] + _vocab if self._pad is not None and len(self._pad) > 0 else _vocab + self.vocab = _vocab + list(self._punctuations) + if self.is_unique: + duplicates = {x for x in self.vocab if self.vocab.count(x) > 1} + assert ( + len(self.vocab) == len(self._char_to_id) == len(self._id_to_char) + ), f" [!] There are duplicate characters in the character set. {duplicates}" + + def char_to_id(self, char: str) -> int: + try: + return self._char_to_id[char] + except KeyError as e: + raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e + + def id_to_char(self, idx: int) -> str: + return self._id_to_char[idx] + + def print_log(self, level: int = 0): + """ + Prints the vocabulary in a nice format. 
+ """ + indent = "\t" * level + print(f"{indent}| > Characters: {self._characters}") + print(f"{indent}| > Punctuations: {self._punctuations}") + print(f"{indent}| > Pad: {self._pad}") + print(f"{indent}| > EOS: {self._eos}") + print(f"{indent}| > BOS: {self._bos}") + print(f"{indent}| > Blank: {self._blank}") + print(f"{indent}| > Vocab: {self.vocab}") + print(f"{indent}| > Num chars: {self.num_chars}") + + @staticmethod + def init_from_config(config: "Coqpit"): # pylint: disable=unused-argument + """Init your character class from a config. + + Implement this method for your subclass. + """ + # use character set from config + if config.characters is not None: + return BaseCharacters(**config.characters), config + # return default character set + characters = BaseCharacters() + new_config = replace(config, characters=characters.to_config()) + return characters, new_config + + def to_config(self) -> "CharactersConfig": + return CharactersConfig( + characters=self._characters, + punctuations=self._punctuations, + pad=self._pad, + eos=self._eos, + bos=self._bos, + blank=self._blank, + is_unique=self.is_unique, + is_sorted=self.is_sorted, + ) + + +class IPAPhonemes(BaseCharacters): + """🐸IPAPhonemes class to manage `TTS.tts` model vocabulary + + Intended to be used with models using IPAPhonemes as input. + It uses system defaults for the undefined class arguments. + + Args: + characters (str): + Main set of case-sensitive characters to be used in the vocabulary. Defaults to `_phonemes`. + + punctuations (str): + Characters to be treated as punctuation. Defaults to `_punctuations`. + + pad (str): + Special padding character that would be ignored by the model. Defaults to `_pad`. + + eos (str): + End of the sentence character. Defaults to `_eos`. + + bos (str): + Beginning of the sentence character. Defaults to `_bos`. + + blank (str): + Optional character used between characters by some models for better prosody. Defaults to `_blank`. + + is_unique (bool): + Remove duplicates from the provided characters. Defaults to True. + + is_sorted (bool): + Sort the characters in alphabetical order. Defaults to True. + """ + + def __init__( + self, + characters: str = _phonemes, + punctuations: str = _punctuations, + pad: str = _pad, + eos: str = _eos, + bos: str = _bos, + blank: str = _blank, + is_unique: bool = False, + is_sorted: bool = True, + ) -> None: + super().__init__(characters, punctuations, pad, eos, bos, blank, is_unique, is_sorted) + + @staticmethod + def init_from_config(config: "Coqpit"): + """Init a IPAPhonemes object from a model config + + If characters are not defined in the config, it will be set to the default characters and the config + will be updated. 
+ """ + # band-aid for compatibility with old models + if "characters" in config and config.characters is not None: + if "phonemes" in config.characters and config.characters.phonemes is not None: + config.characters["characters"] = config.characters["phonemes"] + return ( + IPAPhonemes( + characters=config.characters["characters"], + punctuations=config.characters["punctuations"], + pad=config.characters["pad"], + eos=config.characters["eos"], + bos=config.characters["bos"], + blank=config.characters["blank"], + is_unique=config.characters["is_unique"], + is_sorted=config.characters["is_sorted"], + ), + config, + ) + # use character set from config + if config.characters is not None: + return IPAPhonemes(**config.characters), config + # return default character set + characters = IPAPhonemes() + new_config = replace(config, characters=characters.to_config()) + return characters, new_config + + +class Graphemes(BaseCharacters): + """🐸Graphemes class to manage `TTS.tts` model vocabulary + + Intended to be used with models using graphemes as input. + It uses system defaults for the undefined class arguments. + + Args: + characters (str): + Main set of case-sensitive characters to be used in the vocabulary. Defaults to `_characters`. + + punctuations (str): + Characters to be treated as punctuation. Defaults to `_punctuations`. + + pad (str): + Special padding character that would be ignored by the model. Defaults to `_pad`. + + eos (str): + End of the sentence character. Defaults to `_eos`. + + bos (str): + Beginning of the sentence character. Defaults to `_bos`. + + is_unique (bool): + Remove duplicates from the provided characters. Defaults to True. + + is_sorted (bool): + Sort the characters in alphabetical order. Defaults to True. + """ + + def __init__( + self, + characters: str = _characters, + punctuations: str = _punctuations, + pad: str = _pad, + eos: str = _eos, + bos: str = _bos, + blank: str = _blank, + is_unique: bool = False, + is_sorted: bool = True, + ) -> None: + super().__init__(characters, punctuations, pad, eos, bos, blank, is_unique, is_sorted) + + @staticmethod + def init_from_config(config: "Coqpit"): + """Init a Graphemes object from a model config + + If characters are not defined in the config, it will be set to the default characters and the config + will be updated. 
+ """ + if config.characters is not None: + # band-aid for compatibility with old models + if "phonemes" in config.characters: + return ( + Graphemes( + characters=config.characters["characters"], + punctuations=config.characters["punctuations"], + pad=config.characters["pad"], + eos=config.characters["eos"], + bos=config.characters["bos"], + blank=config.characters["blank"], + is_unique=config.characters["is_unique"], + is_sorted=config.characters["is_sorted"], + ), + config, + ) + return Graphemes(**config.characters), config + characters = Graphemes() + new_config = replace(config, characters=characters.to_config()) + return characters, new_config + + +if __name__ == "__main__": + gr = Graphemes() + ph = IPAPhonemes() + gr.print_log() + ph.print_log() diff --git a/TTS/tts/utils/text/chinese_mandarin/__init__.py b/TTS/tts/utils/text/chinese_mandarin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/utils/text/chinese_mandarin/numbers.py b/TTS/tts/utils/text/chinese_mandarin/numbers.py new file mode 100644 index 0000000000000000000000000000000000000000..4787ea61007656819eb57d52d5865b38c7afa915 --- /dev/null +++ b/TTS/tts/utils/text/chinese_mandarin/numbers.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Licensed under WTFPL or the Unlicense or CC0. +# This uses Python 3, but it's easy to port to Python 2 by changing +# strings to u'xx'. + +import itertools +import re + + +def _num2chinese(num: str, big=False, simp=True, o=False, twoalt=False) -> str: + """Convert numerical arabic numbers (0->9) to chinese hanzi numbers (〇 -> 九) + + Args: + num (str): arabic number to convert + big (bool, optional): use financial characters. Defaults to False. + simp (bool, optional): use simplified characters instead of tradictional characters. Defaults to True. + o (bool, optional): use 〇 for 'zero'. Defaults to False. + twoalt (bool, optional): use 两/兩 for 'two' when appropriate. Defaults to False. + + Raises: + ValueError: if number is more than 1e48 + ValueError: if 'e' exposent in number + + Returns: + str: converted number as hanzi characters + """ + + # check num first + nd = str(num) + if abs(float(nd)) >= 1e48: + raise ValueError("number out of range") + if "e" in nd: + raise ValueError("scientific notation is not supported") + c_symbol = "正负点" if simp else "正負點" + if o: # formal + twoalt = False + if big: + c_basic = "零壹贰叁肆伍陆柒捌玖" if simp else "零壹貳參肆伍陸柒捌玖" + c_unit1 = "拾佰仟" + c_twoalt = "贰" if simp else "貳" + else: + c_basic = "〇一二三四五六七八九" if o else "零一二三四五六七八九" + c_unit1 = "十百千" + if twoalt: + c_twoalt = "两" if simp else "兩" + else: + c_twoalt = "二" + c_unit2 = "万亿兆京垓秭穰沟涧正载" if simp else "萬億兆京垓秭穰溝澗正載" + revuniq = lambda l: "".join(k for k, g in itertools.groupby(reversed(l))) + nd = str(num) + result = [] + if nd[0] == "+": + result.append(c_symbol[0]) + elif nd[0] == "-": + result.append(c_symbol[1]) + if "." 
in nd: + integer, remainder = nd.lstrip("+-").split(".") + else: + integer, remainder = nd.lstrip("+-"), None + if int(integer): + splitted = [integer[max(i - 4, 0) : i] for i in range(len(integer), 0, -4)] + intresult = [] + for nu, unit in enumerate(splitted): + # special cases + if int(unit) == 0: # 0000 + intresult.append(c_basic[0]) + continue + if nu > 0 and int(unit) == 2: # 0002 + intresult.append(c_twoalt + c_unit2[nu - 1]) + continue + ulist = [] + unit = unit.zfill(4) + for nc, ch in enumerate(reversed(unit)): + if ch == "0": + if ulist: # ???0 + ulist.append(c_basic[0]) + elif nc == 0: + ulist.append(c_basic[int(ch)]) + elif nc == 1 and ch == "1" and unit[1] == "0": + # special case for tens + # edit the 'elif' if you don't like + # 十四, 三千零十四, 三千三百一十四 + ulist.append(c_unit1[0]) + elif nc > 1 and ch == "2": + ulist.append(c_twoalt + c_unit1[nc - 1]) + else: + ulist.append(c_basic[int(ch)] + c_unit1[nc - 1]) + ustr = revuniq(ulist) + if nu == 0: + intresult.append(ustr) + else: + intresult.append(ustr + c_unit2[nu - 1]) + result.append(revuniq(intresult).strip(c_basic[0])) + else: + result.append(c_basic[0]) + if remainder: + result.append(c_symbol[2]) + result.append("".join(c_basic[int(ch)] for ch in remainder)) + return "".join(result) + + +def _number_replace(match) -> str: + """function to apply in a match, transform all numbers in a match by chinese characters + + Args: + match (re.Match): numbers regex matches + + Returns: + str: replaced characters for the numbers + """ + match_str: str = match.group() + return _num2chinese(match_str) + + +def replace_numbers_to_characters_in_text(text: str) -> str: + """Replace all arabic numbers in a text by their equivalent in chinese characters (simplified) + + Args: + text (str): input text to transform + + Returns: + str: output text + """ + text = re.sub(r"[0-9]+", _number_replace, text) + return text diff --git a/TTS/tts/utils/text/chinese_mandarin/phonemizer.py b/TTS/tts/utils/text/chinese_mandarin/phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..727c881e1062badc57df7418aa07e7434d57335c --- /dev/null +++ b/TTS/tts/utils/text/chinese_mandarin/phonemizer.py @@ -0,0 +1,37 @@ +from typing import List + +import jieba +import pypinyin + +from .pinyinToPhonemes import PINYIN_DICT + + +def _chinese_character_to_pinyin(text: str) -> List[str]: + pinyins = pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True) + pinyins_flat_list = [item for sublist in pinyins for item in sublist] + return pinyins_flat_list + + +def _chinese_pinyin_to_phoneme(pinyin: str) -> str: + segment = pinyin[:-1] + tone = pinyin[-1] + phoneme = PINYIN_DICT.get(segment, [""])[0] + return phoneme + tone + + +def chinese_text_to_phonemes(text: str, seperator: str = "|") -> str: + tokenized_text = jieba.cut(text, HMM=False) + tokenized_text = " ".join(tokenized_text) + pinyined_text: List[str] = _chinese_character_to_pinyin(tokenized_text) + + results: List[str] = [] + + for token in pinyined_text: + if token[-1] in "12345": # TODO transform to is_pinyin() + pinyin_phonemes = _chinese_pinyin_to_phoneme(token) + + results += list(pinyin_phonemes) + else: # is ponctuation or other + results += list(token) + + return seperator.join(results) diff --git a/TTS/tts/utils/text/chinese_mandarin/pinyinToPhonemes.py b/TTS/tts/utils/text/chinese_mandarin/pinyinToPhonemes.py new file mode 100644 index 0000000000000000000000000000000000000000..4e25c3a4c91cddd0bf0e5d6e273262e3dbd3a2dd --- /dev/null +++ 
b/TTS/tts/utils/text/chinese_mandarin/pinyinToPhonemes.py @@ -0,0 +1,419 @@ +PINYIN_DICT = { + "a": ["a"], + "ai": ["ai"], + "an": ["an"], + "ang": ["ɑŋ"], + "ao": ["aʌ"], + "ba": ["ba"], + "bai": ["bai"], + "ban": ["ban"], + "bang": ["bɑŋ"], + "bao": ["baʌ"], + # "be": ["be"], doesnt exist + "bei": ["bɛi"], + "ben": ["bœn"], + "beng": ["bɵŋ"], + "bi": ["bi"], + "bian": ["biɛn"], + "biao": ["biaʌ"], + "bie": ["bie"], + "bin": ["bin"], + "bing": ["bɨŋ"], + "bo": ["bo"], + "bu": ["bu"], + "ca": ["tsa"], + "cai": ["tsai"], + "can": ["tsan"], + "cang": ["tsɑŋ"], + "cao": ["tsaʌ"], + "ce": ["tsø"], + "cen": ["tsœn"], + "ceng": ["tsɵŋ"], + "cha": ["ʈʂa"], + "chai": ["ʈʂai"], + "chan": ["ʈʂan"], + "chang": ["ʈʂɑŋ"], + "chao": ["ʈʂaʌ"], + "che": ["ʈʂø"], + "chen": ["ʈʂœn"], + "cheng": ["ʈʂɵŋ"], + "chi": ["ʈʂʏ"], + "chong": ["ʈʂoŋ"], + "chou": ["ʈʂou"], + "chu": ["ʈʂu"], + "chua": ["ʈʂua"], + "chuai": ["ʈʂuai"], + "chuan": ["ʈʂuan"], + "chuang": ["ʈʂuɑŋ"], + "chui": ["ʈʂuei"], + "chun": ["ʈʂun"], + "chuo": ["ʈʂuo"], + "ci": ["tsɪ"], + "cong": ["tsoŋ"], + "cou": ["tsou"], + "cu": ["tsu"], + "cuan": ["tsuan"], + "cui": ["tsuei"], + "cun": ["tsun"], + "cuo": ["tsuo"], + "da": ["da"], + "dai": ["dai"], + "dan": ["dan"], + "dang": ["dɑŋ"], + "dao": ["daʌ"], + "de": ["dø"], + "dei": ["dei"], + # "den": ["dœn"], + "deng": ["dɵŋ"], + "di": ["di"], + "dia": ["dia"], + "dian": ["diɛn"], + "diao": ["diaʌ"], + "die": ["die"], + "ding": ["dɨŋ"], + "diu": ["dio"], + "dong": ["doŋ"], + "dou": ["dou"], + "du": ["du"], + "duan": ["duan"], + "dui": ["duei"], + "dun": ["dun"], + "duo": ["duo"], + "e": ["ø"], + "ei": ["ei"], + "en": ["œn"], + # "ng": ["œn"], + # "eng": ["ɵŋ"], + "er": ["er"], + "fa": ["fa"], + "fan": ["fan"], + "fang": ["fɑŋ"], + "fei": ["fei"], + "fen": ["fœn"], + "feng": ["fɵŋ"], + "fo": ["fo"], + "fou": ["fou"], + "fu": ["fu"], + "ga": ["ga"], + "gai": ["gai"], + "gan": ["gan"], + "gang": ["gɑŋ"], + "gao": ["gaʌ"], + "ge": ["gø"], + "gei": ["gei"], + "gen": ["gœn"], + "geng": ["gɵŋ"], + "gong": ["goŋ"], + "gou": ["gou"], + "gu": ["gu"], + "gua": ["gua"], + "guai": ["guai"], + "guan": ["guan"], + "guang": ["guɑŋ"], + "gui": ["guei"], + "gun": ["gun"], + "guo": ["guo"], + "ha": ["xa"], + "hai": ["xai"], + "han": ["xan"], + "hang": ["xɑŋ"], + "hao": ["xaʌ"], + "he": ["xø"], + "hei": ["xei"], + "hen": ["xœn"], + "heng": ["xɵŋ"], + "hong": ["xoŋ"], + "hou": ["xou"], + "hu": ["xu"], + "hua": ["xua"], + "huai": ["xuai"], + "huan": ["xuan"], + "huang": ["xuɑŋ"], + "hui": ["xuei"], + "hun": ["xun"], + "huo": ["xuo"], + "ji": ["dʑi"], + "jia": ["dʑia"], + "jian": ["dʑiɛn"], + "jiang": ["dʑiɑŋ"], + "jiao": ["dʑiaʌ"], + "jie": ["dʑie"], + "jin": ["dʑin"], + "jing": ["dʑɨŋ"], + "jiong": ["dʑioŋ"], + "jiu": ["dʑio"], + "ju": ["dʑy"], + "juan": ["dʑyɛn"], + "jue": ["dʑye"], + "jun": ["dʑyn"], + "ka": ["ka"], + "kai": ["kai"], + "kan": ["kan"], + "kang": ["kɑŋ"], + "kao": ["kaʌ"], + "ke": ["kø"], + "kei": ["kei"], + "ken": ["kœn"], + "keng": ["kɵŋ"], + "kong": ["koŋ"], + "kou": ["kou"], + "ku": ["ku"], + "kua": ["kua"], + "kuai": ["kuai"], + "kuan": ["kuan"], + "kuang": ["kuɑŋ"], + "kui": ["kuei"], + "kun": ["kun"], + "kuo": ["kuo"], + "la": ["la"], + "lai": ["lai"], + "lan": ["lan"], + "lang": ["lɑŋ"], + "lao": ["laʌ"], + "le": ["lø"], + "lei": ["lei"], + "leng": ["lɵŋ"], + "li": ["li"], + "lia": ["lia"], + "lian": ["liɛn"], + "liang": ["liɑŋ"], + "liao": ["liaʌ"], + "lie": ["lie"], + "lin": ["lin"], + "ling": ["lɨŋ"], + "liu": ["lio"], + "lo": ["lo"], + "long": ["loŋ"], + "lou": ["lou"], + "lu": ["lu"], + "lv": 
["ly"], + "luan": ["luan"], + "lve": ["lye"], + "lue": ["lue"], + "lun": ["lun"], + "luo": ["luo"], + "ma": ["ma"], + "mai": ["mai"], + "man": ["man"], + "mang": ["mɑŋ"], + "mao": ["maʌ"], + "me": ["mø"], + "mei": ["mei"], + "men": ["mœn"], + "meng": ["mɵŋ"], + "mi": ["mi"], + "mian": ["miɛn"], + "miao": ["miaʌ"], + "mie": ["mie"], + "min": ["min"], + "ming": ["mɨŋ"], + "miu": ["mio"], + "mo": ["mo"], + "mou": ["mou"], + "mu": ["mu"], + "na": ["na"], + "nai": ["nai"], + "nan": ["nan"], + "nang": ["nɑŋ"], + "nao": ["naʌ"], + "ne": ["nø"], + "nei": ["nei"], + "nen": ["nœn"], + "neng": ["nɵŋ"], + "ni": ["ni"], + "nia": ["nia"], + "nian": ["niɛn"], + "niang": ["niɑŋ"], + "niao": ["niaʌ"], + "nie": ["nie"], + "nin": ["nin"], + "ning": ["nɨŋ"], + "niu": ["nio"], + "nong": ["noŋ"], + "nou": ["nou"], + "nu": ["nu"], + "nv": ["ny"], + "nuan": ["nuan"], + "nve": ["nye"], + "nue": ["nye"], + "nuo": ["nuo"], + "o": ["o"], + "ou": ["ou"], + "pa": ["pa"], + "pai": ["pai"], + "pan": ["pan"], + "pang": ["pɑŋ"], + "pao": ["paʌ"], + "pe": ["pø"], + "pei": ["pei"], + "pen": ["pœn"], + "peng": ["pɵŋ"], + "pi": ["pi"], + "pian": ["piɛn"], + "piao": ["piaʌ"], + "pie": ["pie"], + "pin": ["pin"], + "ping": ["pɨŋ"], + "po": ["po"], + "pou": ["pou"], + "pu": ["pu"], + "qi": ["tɕi"], + "qia": ["tɕia"], + "qian": ["tɕiɛn"], + "qiang": ["tɕiɑŋ"], + "qiao": ["tɕiaʌ"], + "qie": ["tɕie"], + "qin": ["tɕin"], + "qing": ["tɕɨŋ"], + "qiong": ["tɕioŋ"], + "qiu": ["tɕio"], + "qu": ["tɕy"], + "quan": ["tɕyɛn"], + "que": ["tɕye"], + "qun": ["tɕyn"], + "ran": ["ʐan"], + "rang": ["ʐɑŋ"], + "rao": ["ʐaʌ"], + "re": ["ʐø"], + "ren": ["ʐœn"], + "reng": ["ʐɵŋ"], + "ri": ["ʐʏ"], + "rong": ["ʐoŋ"], + "rou": ["ʐou"], + "ru": ["ʐu"], + "rua": ["ʐua"], + "ruan": ["ʐuan"], + "rui": ["ʐuei"], + "run": ["ʐun"], + "ruo": ["ʐuo"], + "sa": ["sa"], + "sai": ["sai"], + "san": ["san"], + "sang": ["sɑŋ"], + "sao": ["saʌ"], + "se": ["sø"], + "sen": ["sœn"], + "seng": ["sɵŋ"], + "sha": ["ʂa"], + "shai": ["ʂai"], + "shan": ["ʂan"], + "shang": ["ʂɑŋ"], + "shao": ["ʂaʌ"], + "she": ["ʂø"], + "shei": ["ʂei"], + "shen": ["ʂœn"], + "sheng": ["ʂɵŋ"], + "shi": ["ʂʏ"], + "shou": ["ʂou"], + "shu": ["ʂu"], + "shua": ["ʂua"], + "shuai": ["ʂuai"], + "shuan": ["ʂuan"], + "shuang": ["ʂuɑŋ"], + "shui": ["ʂuei"], + "shun": ["ʂun"], + "shuo": ["ʂuo"], + "si": ["sɪ"], + "song": ["soŋ"], + "sou": ["sou"], + "su": ["su"], + "suan": ["suan"], + "sui": ["suei"], + "sun": ["sun"], + "suo": ["suo"], + "ta": ["ta"], + "tai": ["tai"], + "tan": ["tan"], + "tang": ["tɑŋ"], + "tao": ["taʌ"], + "te": ["tø"], + "tei": ["tei"], + "teng": ["tɵŋ"], + "ti": ["ti"], + "tian": ["tiɛn"], + "tiao": ["tiaʌ"], + "tie": ["tie"], + "ting": ["tɨŋ"], + "tong": ["toŋ"], + "tou": ["tou"], + "tu": ["tu"], + "tuan": ["tuan"], + "tui": ["tuei"], + "tun": ["tun"], + "tuo": ["tuo"], + "wa": ["wa"], + "wai": ["wai"], + "wan": ["wan"], + "wang": ["wɑŋ"], + "wei": ["wei"], + "wen": ["wœn"], + "weng": ["wɵŋ"], + "wo": ["wo"], + "wu": ["wu"], + "xi": ["ɕi"], + "xia": ["ɕia"], + "xian": ["ɕiɛn"], + "xiang": ["ɕiɑŋ"], + "xiao": ["ɕiaʌ"], + "xie": ["ɕie"], + "xin": ["ɕin"], + "xing": ["ɕɨŋ"], + "xiong": ["ɕioŋ"], + "xiu": ["ɕio"], + "xu": ["ɕy"], + "xuan": ["ɕyɛn"], + "xue": ["ɕye"], + "xun": ["ɕyn"], + "ya": ["ia"], + "yan": ["iɛn"], + "yang": ["iɑŋ"], + "yao": ["iaʌ"], + "ye": ["ie"], + "yi": ["i"], + "yin": ["in"], + "ying": ["ɨŋ"], + "yo": ["io"], + "yong": ["ioŋ"], + "you": ["io"], + "yu": ["y"], + "yuan": ["yɛn"], + "yue": ["ye"], + "yun": ["yn"], + "za": ["dza"], + "zai": ["dzai"], + "zan": ["dzan"], + 
"zang": ["dzɑŋ"], + "zao": ["dzaʌ"], + "ze": ["dzø"], + "zei": ["dzei"], + "zen": ["dzœn"], + "zeng": ["dzɵŋ"], + "zha": ["dʒa"], + "zhai": ["dʒai"], + "zhan": ["dʒan"], + "zhang": ["dʒɑŋ"], + "zhao": ["dʒaʌ"], + "zhe": ["dʒø"], + # "zhei": ["dʒei"], it doesn't exist + "zhen": ["dʒœn"], + "zheng": ["dʒɵŋ"], + "zhi": ["dʒʏ"], + "zhong": ["dʒoŋ"], + "zhou": ["dʒou"], + "zhu": ["dʒu"], + "zhua": ["dʒua"], + "zhuai": ["dʒuai"], + "zhuan": ["dʒuan"], + "zhuang": ["dʒuɑŋ"], + "zhui": ["dʒuei"], + "zhun": ["dʒun"], + "zhuo": ["dʒuo"], + "zi": ["dzɪ"], + "zong": ["dzoŋ"], + "zou": ["dzou"], + "zu": ["dzu"], + "zuan": ["dzuan"], + "zui": ["dzuei"], + "zun": ["dzun"], + "zuo": ["dzuo"], +} diff --git a/TTS/tts/utils/text/cleaners.py b/TTS/tts/utils/text/cleaners.py new file mode 100644 index 0000000000000000000000000000000000000000..f02f8fb48e23cce5ca604c0c86d3e13abeb42654 --- /dev/null +++ b/TTS/tts/utils/text/cleaners.py @@ -0,0 +1,145 @@ +"""Set of default text cleaners""" +# TODO: pick the cleaner for languages dynamically + +import re + +from anyascii import anyascii + +from TTS.tts.utils.text.chinese_mandarin.numbers import replace_numbers_to_characters_in_text + +from .english.abbreviations import abbreviations_en +from .english.number_norm import normalize_numbers as en_normalize_numbers +from .english.time_norm import expand_time_english +from .french.abbreviations import abbreviations_fr + +# Regular expression matching whitespace: +_whitespace_re = re.compile(r"\s+") + + +def expand_abbreviations(text, lang="en"): + if lang == "en": + _abbreviations = abbreviations_en + elif lang == "fr": + _abbreviations = abbreviations_fr + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, " ", text).strip() + + +def convert_to_ascii(text): + return anyascii(text) + + +def remove_aux_symbols(text): + text = re.sub(r"[\<\>\(\)\[\]\"]+", "", text) + return text + + +def replace_symbols(text, lang="en"): + text = text.replace(";", ",") + text = text.replace("-", " ") + text = text.replace(":", ",") + if lang == "en": + text = text.replace("&", " and ") + elif lang == "fr": + text = text.replace("&", " et ") + elif lang == "pt": + text = text.replace("&", " e ") + return text + + +def basic_cleaners(text): + """Basic pipeline that lowercases and collapses whitespace without transliteration.""" + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def transliteration_cleaners(text): + """Pipeline for non-English text that transliterates to ASCII.""" + # text = convert_to_ascii(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def basic_german_cleaners(text): + """Pipeline for German text""" + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +# TODO: elaborate it +def basic_turkish_cleaners(text): + """Pipeline for Turkish text""" + text = text.replace("I", "ı") + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def english_cleaners(text): + """Pipeline for English text, including number and abbreviation expansion.""" + # text = convert_to_ascii(text) + text = lowercase(text) + text = expand_time_english(text) + text = en_normalize_numbers(text) + text = expand_abbreviations(text) + text = replace_symbols(text) + text = remove_aux_symbols(text) + text = collapse_whitespace(text) + return text + + +def phoneme_cleaners(text): + 
"""Pipeline for phonemes mode, including number and abbreviation expansion.""" + text = en_normalize_numbers(text) + text = expand_abbreviations(text) + text = replace_symbols(text) + text = remove_aux_symbols(text) + text = collapse_whitespace(text) + return text + + +def french_cleaners(text): + """Pipeline for French text. There is no need to expand numbers, phonemizer already does that""" + text = expand_abbreviations(text, lang="fr") + text = lowercase(text) + text = replace_symbols(text, lang="fr") + text = remove_aux_symbols(text) + text = collapse_whitespace(text) + return text + + +def portuguese_cleaners(text): + """Basic pipeline for Portuguese text. There is no need to expand abbreviation and + numbers, phonemizer already does that""" + text = lowercase(text) + text = replace_symbols(text, lang="pt") + text = remove_aux_symbols(text) + text = collapse_whitespace(text) + return text + + +def chinese_mandarin_cleaners(text: str) -> str: + """Basic pipeline for chinese""" + text = replace_numbers_to_characters_in_text(text) + return text + + +def multilingual_cleaners(text): + """Pipeline for multilingual text""" + text = lowercase(text) + text = replace_symbols(text, lang=None) + text = remove_aux_symbols(text) + text = collapse_whitespace(text) + return text diff --git a/TTS/tts/utils/text/cmudict.py b/TTS/tts/utils/text/cmudict.py new file mode 100644 index 0000000000000000000000000000000000000000..f206fb043be1d478fa6ace36fefdefa30b0acb02 --- /dev/null +++ b/TTS/tts/utils/text/cmudict.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- + +import re + +VALID_SYMBOLS = [ + "AA", + "AA0", + "AA1", + "AA2", + "AE", + "AE0", + "AE1", + "AE2", + "AH", + "AH0", + "AH1", + "AH2", + "AO", + "AO0", + "AO1", + "AO2", + "AW", + "AW0", + "AW1", + "AW2", + "AY", + "AY0", + "AY1", + "AY2", + "B", + "CH", + "D", + "DH", + "EH", + "EH0", + "EH1", + "EH2", + "ER", + "ER0", + "ER1", + "ER2", + "EY", + "EY0", + "EY1", + "EY2", + "F", + "G", + "HH", + "IH", + "IH0", + "IH1", + "IH2", + "IY", + "IY0", + "IY1", + "IY2", + "JH", + "K", + "L", + "M", + "N", + "NG", + "OW", + "OW0", + "OW1", + "OW2", + "OY", + "OY0", + "OY1", + "OY2", + "P", + "R", + "S", + "SH", + "T", + "TH", + "UH", + "UH0", + "UH1", + "UH2", + "UW", + "UW0", + "UW1", + "UW2", + "V", + "W", + "Y", + "Z", + "ZH", +] + + +class CMUDict: + """Thin wrapper around CMUDict data. 
http://www.speech.cs.cmu.edu/cgi-bin/cmudict""" + + def __init__(self, file_or_path, keep_ambiguous=True): + if isinstance(file_or_path, str): + with open(file_or_path, encoding="latin-1") as f: + entries = _parse_cmudict(f) + else: + entries = _parse_cmudict(file_or_path) + if not keep_ambiguous: + entries = {word: pron for word, pron in entries.items() if len(pron) == 1} + self._entries = entries + + def __len__(self): + return len(self._entries) + + def lookup(self, word): + """Returns list of ARPAbet pronunciations of the given word.""" + return self._entries.get(word.upper()) + + @staticmethod + def get_arpabet(word, cmudict, punctuation_symbols): + first_symbol, last_symbol = "", "" + if word and word[0] in punctuation_symbols: + first_symbol = word[0] + word = word[1:] + if word and word[-1] in punctuation_symbols: + last_symbol = word[-1] + word = word[:-1] + arpabet = cmudict.lookup(word) + if arpabet is not None: + return first_symbol + "{%s}" % arpabet[0] + last_symbol + return first_symbol + word + last_symbol + + +_alt_re = re.compile(r"\([0-9]+\)") + + +def _parse_cmudict(file): + cmudict = {} + for line in file: + if line and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"): + parts = line.split(" ") + word = re.sub(_alt_re, "", parts[0]) + pronunciation = _get_pronunciation(parts[1]) + if pronunciation: + if word in cmudict: + cmudict[word].append(pronunciation) + else: + cmudict[word] = [pronunciation] + return cmudict + + +def _get_pronunciation(s): + parts = s.strip().split(" ") + for part in parts: + if part not in VALID_SYMBOLS: + return None + return " ".join(parts) diff --git a/TTS/tts/utils/text/english/__init__.py b/TTS/tts/utils/text/english/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/utils/text/english/abbreviations.py b/TTS/tts/utils/text/english/abbreviations.py new file mode 100644 index 0000000000000000000000000000000000000000..cd93c13c8ecfbc0df2d0c6d2fa348388940c213a --- /dev/null +++ b/TTS/tts/utils/text/english/abbreviations.py @@ -0,0 +1,26 @@ +import re + +# List of (regular expression, replacement) pairs for abbreviations in english: +abbreviations_en = [ + (re.compile("\\b%s\\." 
% x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] +] diff --git a/TTS/tts/utils/text/english/number_norm.py b/TTS/tts/utils/text/english/number_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..e8377ede87ebc9d1bb9cffbbb290aa7787caea4f --- /dev/null +++ b/TTS/tts/utils/text/english/number_norm.py @@ -0,0 +1,97 @@ +""" from https://github.com/keithito/tacotron """ + +import re +from typing import Dict + +import inflect + +_inflect = inflect.engine() +_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") +_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") +_currency_re = re.compile(r"(£|\$|¥)([0-9\,\.]*[0-9]+)") +_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") +_number_re = re.compile(r"-?[0-9]+") + + +def _remove_commas(m): + return m.group(1).replace(",", "") + + +def _expand_decimal_point(m): + return m.group(1).replace(".", " point ") + + +def __expand_currency(value: str, inflection: Dict[float, str]) -> str: + parts = value.replace(",", "").split(".") + if len(parts) > 2: + return f"{value} {inflection[2]}" # Unexpected format + text = [] + integer = int(parts[0]) if parts[0] else 0 + if integer > 0: + integer_unit = inflection.get(integer, inflection[2]) + text.append(f"{integer} {integer_unit}") + fraction = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if fraction > 0: + fraction_unit = inflection.get(fraction / 100, inflection[0.02]) + text.append(f"{fraction} {fraction_unit}") + if len(text) == 0: + return f"zero {inflection[2]}" + return " ".join(text) + + +def _expand_currency(m: "re.Match") -> str: + currencies = { + "$": { + 0.01: "cent", + 0.02: "cents", + 1: "dollar", + 2: "dollars", + }, + "€": { + 0.01: "cent", + 0.02: "cents", + 1: "euro", + 2: "euros", + }, + "£": { + 0.01: "penny", + 0.02: "pence", + 1: "pound sterling", + 2: "pounds sterling", + }, + "¥": { + # TODO rin + 0.02: "sen", + 2: "yen", + }, + } + unit = m.group(1) + currency = currencies[unit] + value = m.group(2) + return __expand_currency(value, currency) + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if 1000 < num < 3000: + if num == 2000: + return "two thousand" + if 2000 < num < 2010: + return "two thousand " + _inflect.number_to_words(num % 100) + if num % 100 == 0: + return _inflect.number_to_words(num // 100) + " hundred" + return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") + return _inflect.number_to_words(num, andword="") + + +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_currency_re, _expand_currency, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text diff --git a/TTS/tts/utils/text/english/time_norm.py b/TTS/tts/utils/text/english/time_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..c8ac09e79db4a239a7f72f101503dbf0d6feb3ae --- /dev/null +++ b/TTS/tts/utils/text/english/time_norm.py @@ -0,0 +1,47 @@ +import re + +import inflect + +_inflect = 
inflect.engine() + +_time_re = re.compile( + r"""\b + ((0?[0-9])|(1[0-1])|(1[2-9])|(2[0-3])) # hours + : + ([0-5][0-9]) # minutes + \s*(a\\.m\\.|am|pm|p\\.m\\.|a\\.m|p\\.m)? # am/pm + \b""", + re.IGNORECASE | re.X, +) + + +def _expand_num(n: int) -> str: + return _inflect.number_to_words(n) + + +def _expand_time_english(match: "re.Match") -> str: + hour = int(match.group(1)) + past_noon = hour >= 12 + time = [] + if hour > 12: + hour -= 12 + elif hour == 0: + hour = 12 + past_noon = True + time.append(_expand_num(hour)) + + minute = int(match.group(6)) + if minute > 0: + if minute < 10: + time.append("oh") + time.append(_expand_num(minute)) + am_pm = match.group(7) + if am_pm is None: + time.append("p m" if past_noon else "a m") + else: + time.extend(list(am_pm.replace(".", ""))) + return " ".join(time) + + +def expand_time_english(text: str) -> str: + return re.sub(_time_re, _expand_time_english, text) diff --git a/TTS/tts/utils/text/french/__init__.py b/TTS/tts/utils/text/french/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/utils/text/french/abbreviations.py b/TTS/tts/utils/text/french/abbreviations.py new file mode 100644 index 0000000000000000000000000000000000000000..f580dfed7b4576a9f87b0a4145cb729e70050d50 --- /dev/null +++ b/TTS/tts/utils/text/french/abbreviations.py @@ -0,0 +1,48 @@ +import re + +# List of (regular expression, replacement) pairs for abbreviations in french: +abbreviations_fr = [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("M", "monsieur"), + ("Mlle", "mademoiselle"), + ("Mlles", "mesdemoiselles"), + ("Mme", "Madame"), + ("Mmes", "Mesdames"), + ("N.B", "nota bene"), + ("M", "monsieur"), + ("p.c.q", "parce que"), + ("Pr", "professeur"), + ("qqch", "quelque chose"), + ("rdv", "rendez-vous"), + ("max", "maximum"), + ("min", "minimum"), + ("no", "numéro"), + ("adr", "adresse"), + ("dr", "docteur"), + ("st", "saint"), + ("co", "companie"), + ("jr", "junior"), + ("sgt", "sergent"), + ("capt", "capitain"), + ("col", "colonel"), + ("av", "avenue"), + ("av. J.-C", "avant Jésus-Christ"), + ("apr. 
J.-C", "après Jésus-Christ"), + ("art", "article"), + ("boul", "boulevard"), + ("c.-à-d", "c’est-à-dire"), + ("etc", "et cetera"), + ("ex", "exemple"), + ("excl", "exclusivement"), + ("boul", "boulevard"), + ] +] + [ + (re.compile("\\b%s" % x[0]), x[1]) + for x in [ + ("Mlle", "mademoiselle"), + ("Mlles", "mesdemoiselles"), + ("Mme", "Madame"), + ("Mmes", "Mesdames"), + ] +] diff --git a/TTS/tts/utils/text/japanese/__init__.py b/TTS/tts/utils/text/japanese/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/utils/text/japanese/phonemizer.py b/TTS/tts/utils/text/japanese/phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..969becfdcabdff2da68cb6f9e7d098d363298faf --- /dev/null +++ b/TTS/tts/utils/text/japanese/phonemizer.py @@ -0,0 +1,467 @@ +# Convert Japanese text to phonemes which is +# compatible with Julius https://github.com/julius-speech/segmentation-kit + +import re +import unicodedata + +import MeCab +from num2words import num2words + +_CONVRULES = [ + # Conversion of 2 letters + "アァ/ a a", + "イィ/ i i", + "イェ/ i e", + "イャ/ y a", + "ウゥ/ u:", + "エェ/ e e", + "オォ/ o:", + "カァ/ k a:", + "キィ/ k i:", + "クゥ/ k u:", + "クャ/ ky a", + "クュ/ ky u", + "クョ/ ky o", + "ケェ/ k e:", + "コォ/ k o:", + "ガァ/ g a:", + "ギィ/ g i:", + "グゥ/ g u:", + "グャ/ gy a", + "グュ/ gy u", + "グョ/ gy o", + "ゲェ/ g e:", + "ゴォ/ g o:", + "サァ/ s a:", + "シィ/ sh i:", + "スゥ/ s u:", + "スャ/ sh a", + "スュ/ sh u", + "スョ/ sh o", + "セェ/ s e:", + "ソォ/ s o:", + "ザァ/ z a:", + "ジィ/ j i:", + "ズゥ/ z u:", + "ズャ/ zy a", + "ズュ/ zy u", + "ズョ/ zy o", + "ゼェ/ z e:", + "ゾォ/ z o:", + "タァ/ t a:", + "チィ/ ch i:", + "ツァ/ ts a", + "ツィ/ ts i", + "ツゥ/ ts u:", + "ツャ/ ch a", + "ツュ/ ch u", + "ツョ/ ch o", + "ツェ/ ts e", + "ツォ/ ts o", + "テェ/ t e:", + "トォ/ t o:", + "ダァ/ d a:", + "ヂィ/ j i:", + "ヅゥ/ d u:", + "ヅャ/ zy a", + "ヅュ/ zy u", + "ヅョ/ zy o", + "デェ/ d e:", + "ドォ/ d o:", + "ナァ/ n a:", + "ニィ/ n i:", + "ヌゥ/ n u:", + "ヌャ/ ny a", + "ヌュ/ ny u", + "ヌョ/ ny o", + "ネェ/ n e:", + "ノォ/ n o:", + "ハァ/ h a:", + "ヒィ/ h i:", + "フゥ/ f u:", + "フャ/ hy a", + "フュ/ hy u", + "フョ/ hy o", + "ヘェ/ h e:", + "ホォ/ h o:", + "バァ/ b a:", + "ビィ/ b i:", + "ブゥ/ b u:", + "フャ/ hy a", + "ブュ/ by u", + "フョ/ hy o", + "ベェ/ b e:", + "ボォ/ b o:", + "パァ/ p a:", + "ピィ/ p i:", + "プゥ/ p u:", + "プャ/ py a", + "プュ/ py u", + "プョ/ py o", + "ペェ/ p e:", + "ポォ/ p o:", + "マァ/ m a:", + "ミィ/ m i:", + "ムゥ/ m u:", + "ムャ/ my a", + "ムュ/ my u", + "ムョ/ my o", + "メェ/ m e:", + "モォ/ m o:", + "ヤァ/ y a:", + "ユゥ/ y u:", + "ユャ/ y a:", + "ユュ/ y u:", + "ユョ/ y o:", + "ヨォ/ y o:", + "ラァ/ r a:", + "リィ/ r i:", + "ルゥ/ r u:", + "ルャ/ ry a", + "ルュ/ ry u", + "ルョ/ ry o", + "レェ/ r e:", + "ロォ/ r o:", + "ワァ/ w a:", + "ヲォ/ o:", + "ディ/ d i", + "デェ/ d e:", + "デャ/ dy a", + "デュ/ dy u", + "デョ/ dy o", + "ティ/ t i", + "テェ/ t e:", + "テャ/ ty a", + "テュ/ ty u", + "テョ/ ty o", + "スィ/ s i", + "ズァ/ z u a", + "ズィ/ z i", + "ズゥ/ z u", + "ズャ/ zy a", + "ズュ/ zy u", + "ズョ/ zy o", + "ズェ/ z e", + "ズォ/ z o", + "キャ/ ky a", + "キュ/ ky u", + "キョ/ ky o", + "シャ/ sh a", + "シュ/ sh u", + "シェ/ sh e", + "ショ/ sh o", + "チャ/ ch a", + "チュ/ ch u", + "チェ/ ch e", + "チョ/ ch o", + "トゥ/ t u", + "トャ/ ty a", + "トュ/ ty u", + "トョ/ ty o", + "ドァ/ d o a", + "ドゥ/ d u", + "ドャ/ dy a", + "ドュ/ dy u", + "ドョ/ dy o", + "ドォ/ d o:", + "ニャ/ ny a", + "ニュ/ ny u", + "ニョ/ ny o", + "ヒャ/ hy a", + "ヒュ/ hy u", + "ヒョ/ hy o", + "ミャ/ my a", + "ミュ/ my u", + "ミョ/ my o", + "リャ/ ry a", + "リュ/ ry u", + "リョ/ ry o", + "ギャ/ gy a", + "ギュ/ gy u", + "ギョ/ gy o", + "ヂェ/ j e", + "ヂャ/ j a", + "ヂュ/ j u", + "ヂョ/ j o", + "ジェ/ j e", + 
"ジャ/ j a", + "ジュ/ j u", + "ジョ/ j o", + "ビャ/ by a", + "ビュ/ by u", + "ビョ/ by o", + "ピャ/ py a", + "ピュ/ py u", + "ピョ/ py o", + "ウァ/ u a", + "ウィ/ w i", + "ウェ/ w e", + "ウォ/ w o", + "ファ/ f a", + "フィ/ f i", + "フゥ/ f u", + "フャ/ hy a", + "フュ/ hy u", + "フョ/ hy o", + "フェ/ f e", + "フォ/ f o", + "ヴァ/ b a", + "ヴィ/ b i", + "ヴェ/ b e", + "ヴォ/ b o", + "ヴュ/ by u", + # Conversion of 1 letter + "ア/ a", + "イ/ i", + "ウ/ u", + "エ/ e", + "オ/ o", + "カ/ k a", + "キ/ k i", + "ク/ k u", + "ケ/ k e", + "コ/ k o", + "サ/ s a", + "シ/ sh i", + "ス/ s u", + "セ/ s e", + "ソ/ s o", + "タ/ t a", + "チ/ ch i", + "ツ/ ts u", + "テ/ t e", + "ト/ t o", + "ナ/ n a", + "ニ/ n i", + "ヌ/ n u", + "ネ/ n e", + "ノ/ n o", + "ハ/ h a", + "ヒ/ h i", + "フ/ f u", + "ヘ/ h e", + "ホ/ h o", + "マ/ m a", + "ミ/ m i", + "ム/ m u", + "メ/ m e", + "モ/ m o", + "ラ/ r a", + "リ/ r i", + "ル/ r u", + "レ/ r e", + "ロ/ r o", + "ガ/ g a", + "ギ/ g i", + "グ/ g u", + "ゲ/ g e", + "ゴ/ g o", + "ザ/ z a", + "ジ/ j i", + "ズ/ z u", + "ゼ/ z e", + "ゾ/ z o", + "ダ/ d a", + "ヂ/ j i", + "ヅ/ z u", + "デ/ d e", + "ド/ d o", + "バ/ b a", + "ビ/ b i", + "ブ/ b u", + "ベ/ b e", + "ボ/ b o", + "パ/ p a", + "ピ/ p i", + "プ/ p u", + "ペ/ p e", + "ポ/ p o", + "ヤ/ y a", + "ユ/ y u", + "ヨ/ y o", + "ワ/ w a", + "ヰ/ i", + "ヱ/ e", + "ヲ/ o", + "ン/ N", + "ッ/ q", + "ヴ/ b u", + "ー/:", + # Try converting broken text + "ァ/ a", + "ィ/ i", + "ゥ/ u", + "ェ/ e", + "ォ/ o", + "ヮ/ w a", + "ォ/ o", + # Symbols + "、/ ,", + "。/ .", + "!/ !", + "?/ ?", + "・/ ,", +] + +_COLON_RX = re.compile(":+") +_REJECT_RX = re.compile("[^ a-zA-Z:,.?]") + + +def _makerulemap(): + l = [tuple(x.split("/")) for x in _CONVRULES] + return tuple({k: v for k, v in l if len(k) == i} for i in (1, 2)) + + +_RULEMAP1, _RULEMAP2 = _makerulemap() + + +def kata2phoneme(text: str) -> str: + """Convert katakana text to phonemes.""" + text = text.strip() + res = "" + while text: + if len(text) >= 2: + x = _RULEMAP2.get(text[:2]) + if x is not None: + text = text[2:] + res += x + continue + x = _RULEMAP1.get(text[0]) + if x is not None: + text = text[1:] + res += x + continue + res += " " + text[0] + text = text[1:] + res = _COLON_RX.sub(":", res) + return res[1:] + + +_KATAKANA = "".join(chr(ch) for ch in range(ord("ァ"), ord("ン") + 1)) +_HIRAGANA = "".join(chr(ch) for ch in range(ord("ぁ"), ord("ん") + 1)) +_HIRA2KATATRANS = str.maketrans(_HIRAGANA, _KATAKANA) + + +def hira2kata(text: str) -> str: + text = text.translate(_HIRA2KATATRANS) + return text.replace("う゛", "ヴ") + + +_SYMBOL_TOKENS = set(list("・、。?!")) +_NO_YOMI_TOKENS = set(list("「」『』―()[][] …")) +_TAGGER = MeCab.Tagger() + + +def text2kata(text: str) -> str: + parsed = _TAGGER.parse(text) + res = [] + for line in parsed.split("\n"): + if line == "EOS": + break + parts = line.split("\t") + + word, yomi = parts[0], parts[1] + if yomi: + res.append(yomi) + else: + if word in _SYMBOL_TOKENS: + res.append(word) + elif word in ("っ", "ッ"): + res.append("ッ") + elif word in _NO_YOMI_TOKENS: + pass + else: + res.append(word) + return hira2kata("".join(res)) + + +_ALPHASYMBOL_YOMI = { + "#": "シャープ", + "%": "パーセント", + "&": "アンド", + "+": "プラス", + "-": "マイナス", + ":": "コロン", + ";": "セミコロン", + "<": "小なり", + "=": "イコール", + ">": "大なり", + "@": "アット", + "a": "エー", + "b": "ビー", + "c": "シー", + "d": "ディー", + "e": "イー", + "f": "エフ", + "g": "ジー", + "h": "エイチ", + "i": "アイ", + "j": "ジェー", + "k": "ケー", + "l": "エル", + "m": "エム", + "n": "エヌ", + "o": "オー", + "p": "ピー", + "q": "キュー", + "r": "アール", + "s": "エス", + "t": "ティー", + "u": "ユー", + "v": "ブイ", + "w": "ダブリュー", + "x": "エックス", + "y": "ワイ", + "z": "ゼット", + "α": "アルファ", + "β": "ベータ", + "γ": "ガンマ", + 
"δ": "デルタ", + "ε": "イプシロン", + "ζ": "ゼータ", + "η": "イータ", + "θ": "シータ", + "ι": "イオタ", + "κ": "カッパ", + "λ": "ラムダ", + "μ": "ミュー", + "ν": "ニュー", + "ξ": "クサイ", + "ο": "オミクロン", + "π": "パイ", + "ρ": "ロー", + "σ": "シグマ", + "τ": "タウ", + "υ": "ウプシロン", + "φ": "ファイ", + "χ": "カイ", + "ψ": "プサイ", + "ω": "オメガ", +} + + +_NUMBER_WITH_SEPARATOR_RX = re.compile("[0-9]{1,3}(,[0-9]{3})+") +_CURRENCY_MAP = {"$": "ドル", "¥": "円", "£": "ポンド", "€": "ユーロ"} +_CURRENCY_RX = re.compile(r"([$¥£€])([0-9.]*[0-9])") +_NUMBER_RX = re.compile(r"[0-9]+(\.[0-9]+)?") + + +def japanese_convert_numbers_to_words(text: str) -> str: + res = _NUMBER_WITH_SEPARATOR_RX.sub(lambda m: m[0].replace(",", ""), text) + res = _CURRENCY_RX.sub(lambda m: m[2] + _CURRENCY_MAP.get(m[1], m[1]), res) + res = _NUMBER_RX.sub(lambda m: num2words(m[0], lang="ja"), res) + return res + + +def japanese_convert_alpha_symbols_to_words(text: str) -> str: + return "".join([_ALPHASYMBOL_YOMI.get(ch, ch) for ch in text.lower()]) + + +def japanese_text_to_phonemes(text: str) -> str: + """Convert Japanese text to phonemes.""" + res = unicodedata.normalize("NFKC", text) + res = japanese_convert_numbers_to_words(res) + res = japanese_convert_alpha_symbols_to_words(res) + res = text2kata(res) + res = kata2phoneme(res) + return res.replace(" ", "") diff --git a/TTS/tts/utils/text/phonemizers/__init__.py b/TTS/tts/utils/text/phonemizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..374d0c8aa9b3e34bf66ab9d86a1dc5cf38e6aa1e --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/__init__.py @@ -0,0 +1,53 @@ +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer +from TTS.tts.utils.text.phonemizers.espeak_wrapper import ESpeak +from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut +from TTS.tts.utils.text.phonemizers.ja_jp_phonemizer import JA_JP_Phonemizer +from TTS.tts.utils.text.phonemizers.zh_cn_phonemizer import ZH_CN_Phonemizer + +PHONEMIZERS = {b.name(): b for b in (ESpeak, Gruut, JA_JP_Phonemizer)} + + +ESPEAK_LANGS = list(ESpeak.supported_languages().keys()) +GRUUT_LANGS = list(Gruut.supported_languages()) + + +# Dict setting default phonemizers for each language +# Add Gruut languages +_ = [Gruut.name()] * len(GRUUT_LANGS) +DEF_LANG_TO_PHONEMIZER = dict(list(zip(GRUUT_LANGS, _))) + + +# Add ESpeak languages and override any existing ones +_ = [ESpeak.name()] * len(ESPEAK_LANGS) +_new_dict = dict(list(zip(list(ESPEAK_LANGS), _))) +DEF_LANG_TO_PHONEMIZER.update(_new_dict) + +# Force default for some languages +DEF_LANG_TO_PHONEMIZER["en"] = DEF_LANG_TO_PHONEMIZER["en-us"] +DEF_LANG_TO_PHONEMIZER["ja-jp"] = JA_JP_Phonemizer.name() +DEF_LANG_TO_PHONEMIZER["zh-cn"] = ZH_CN_Phonemizer.name() + + +def get_phonemizer_by_name(name: str, **kwargs) -> BasePhonemizer: + """Initiate a phonemizer by name + + Args: + name (str): + Name of the phonemizer that should match `phonemizer.name()`. + + kwargs (dict): + Extra keyword arguments that should be passed to the phonemizer. 
+ """ + if name == "espeak": + return ESpeak(**kwargs) + if name == "gruut": + return Gruut(**kwargs) + if name == "zh_cn_phonemizer": + return ZH_CN_Phonemizer(**kwargs) + if name == "ja_jp_phonemizer": + return JA_JP_Phonemizer(**kwargs) + raise ValueError(f"Phonemizer {name} not found") + + +if __name__ == "__main__": + print(DEF_LANG_TO_PHONEMIZER) diff --git a/TTS/tts/utils/text/phonemizers/base.py b/TTS/tts/utils/text/phonemizers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..08fa8e130a1324f9052a53dfb03f5918a24d3ec6 --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/base.py @@ -0,0 +1,141 @@ +import abc +from typing import List, Tuple + +from TTS.tts.utils.text.punctuation import Punctuation + + +class BasePhonemizer(abc.ABC): + """Base phonemizer class + + Phonemization follows the following steps: + 1. Preprocessing: + - remove empty lines + - remove punctuation + - keep track of punctuation marks + + 2. Phonemization: + - convert text to phonemes + + 3. Postprocessing: + - join phonemes + - restore punctuation marks + + Args: + language (str): + Language used by the phonemizer. + + punctuations (List[str]): + List of punctuation marks to be preserved. + + keep_puncs (bool): + Whether to preserve punctuation marks or not. + """ + + def __init__(self, language, punctuations=Punctuation.default_puncs(), keep_puncs=False): + + # ensure the backend is installed on the system + if not self.is_available(): + raise RuntimeError("{} not installed on your system".format(self.name())) # pragma: nocover + + # ensure the backend support the requested language + self._language = self._init_language(language) + + # setup punctuation processing + self._keep_puncs = keep_puncs + self._punctuator = Punctuation(punctuations) + + def _init_language(self, language): + """Language initialization + + This method may be overloaded in child classes (see Segments backend) + + """ + if not self.is_supported_language(language): + raise RuntimeError(f'language "{language}" is not supported by the ' f"{self.name()} backend") + return language + + @property + def language(self): + """The language code configured to be used for phonemization""" + return self._language + + @staticmethod + @abc.abstractmethod + def name(): + """The name of the backend""" + ... + + @classmethod + @abc.abstractmethod + def is_available(cls): + """Returns True if the backend is installed, False otherwise""" + ... + + @classmethod + @abc.abstractmethod + def version(cls): + """Return the backend version as a tuple (major, minor, patch)""" + ... + + @staticmethod + @abc.abstractmethod + def supported_languages(): + """Return a dict of language codes -> name supported by the backend""" + ... + + def is_supported_language(self, language): + """Returns True if `language` is supported by the backend""" + return language in self.supported_languages() + + @abc.abstractmethod + def _phonemize(self, text, separator): + """The main phonemization method""" + + def _phonemize_preprocess(self, text) -> Tuple[List[str], List]: + """Preprocess the text before phonemization + + 1. remove spaces + 2. 
remove punctuation + + Override this if you need a different behaviour + """ + text = text.strip() + if self._keep_puncs: + # a tuple (text, punctuation marks) + return self._punctuator.strip_to_restore(text) + return [self._punctuator.strip(text)], [] + + def _phonemize_postprocess(self, phonemized, punctuations) -> str: + """Postprocess the raw phonemized output + + Override this if you need a different behaviour + """ + if self._keep_puncs: + return self._punctuator.restore(phonemized, punctuations)[0] + return phonemized[0] + + def phonemize(self, text: str, separator="|") -> str: + """Returns the `text` phonemized for the given language + + Args: + text (str): + Text to be phonemized. + + separator (str): + string separator used between phonemes. Default to '_'. + + Returns: + (str): Phonemized text + """ + text, punctuations = self._phonemize_preprocess(text) + phonemized = [] + for t in text: + p = self._phonemize(t, separator) + phonemized.append(p) + phonemized = self._phonemize_postprocess(phonemized, punctuations) + return phonemized + + def print_logs(self, level: int = 0): + indent = "\t" * level + print(f"{indent}| > phoneme language: {self.language}") + print(f"{indent}| > phoneme backend: {self.name()}") diff --git a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..024f79c6bed5c01e9fc826e3efa11865eeacb6b3 --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py @@ -0,0 +1,225 @@ +import logging +import subprocess +from typing import Dict, List + +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer +from TTS.tts.utils.text.punctuation import Punctuation + + +def is_tool(name): + from shutil import which + + return which(name) is not None + + +# priority: espeakng > espeak +if is_tool("espeak-ng"): + _DEF_ESPEAK_LIB = "espeak-ng" +elif is_tool("espeak"): + _DEF_ESPEAK_LIB = "espeak" +else: + _DEF_ESPEAK_LIB = None + + +def _espeak_exe(espeak_lib: str, args: List, sync=False) -> List[str]: + """Run espeak with the given arguments.""" + cmd = [ + espeak_lib, + "-q", + "-b", + "1", # UTF8 text encoding + ] + cmd.extend(args) + logging.debug("espeakng: executing %s", repr(cmd)) + + with subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) as p: + res = iter(p.stdout.readline, b"") + if not sync: + p.stdout.close() + if p.stderr: + p.stderr.close() + if p.stdin: + p.stdin.close() + return res + res2 = [] + for line in res: + res2.append(line) + p.stdout.close() + if p.stderr: + p.stderr.close() + if p.stdin: + p.stdin.close() + p.wait() + return res2 + + +class ESpeak(BasePhonemizer): + """ESpeak wrapper calling `espeak` or `espeak-ng` from the command-line the perform G2P + + Args: + language (str): + Valid language code for the used backend. + + backend (str): + Name of the backend library to use. `espeak` or `espeak-ng`. If None, set automatically + prefering `espeak-ng` over `espeak`. Defaults to None. + + punctuations (str): + Characters to be treated as punctuation. Defaults to Punctuation.default_puncs(). + + keep_puncs (bool): + If True, keep the punctuations after phonemization. Defaults to True. + + Example: + + >>> from TTS.tts.utils.text.phonemizers import ESpeak + >>> phonemizer = ESpeak("tr") + >>> phonemizer.phonemize("Bu Türkçe, bir örnektir.", separator="|") + 'b|ʊ t|ˈø|r|k|tʃ|ɛ, b|ɪ|r œ|r|n|ˈɛ|c|t|ɪ|r.' 
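To make the `BasePhonemizer` flow defined in `base.py` above concrete (strip punctuation, phonemize each chunk, restore the marks), here is a hedged sketch with a toy subclass that simply upper-cases its input. `UpperCase` is not part of the codebase; the snippet only assumes the TTS package is importable.

```python
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer

class UpperCase(BasePhonemizer):
    """Toy backend: 'phonemization' is just upper-casing."""

    @staticmethod
    def name():
        return "uppercase"

    @classmethod
    def is_available(cls):
        return True

    @classmethod
    def version(cls):
        return (0, 0, 1)

    @staticmethod
    def supported_languages():
        return {"en-us": "English"}

    def _phonemize(self, text, separator):
        return text.upper()

ph = UpperCase("en-us", keep_puncs=True)
print(ph.phonemize("hello, world!"))  # punctuation stripped, phonemized, then restored: "HELLO, WORLD!"
```

With `keep_puncs=False` the same call would return `HELLO WORLD`, since the comma and exclamation mark are stripped in preprocessing and never restored.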
+ + """ + + _ESPEAK_LIB = _DEF_ESPEAK_LIB + + def __init__(self, language: str, backend=None, punctuations=Punctuation.default_puncs(), keep_puncs=True): + if self._ESPEAK_LIB is None: + raise Exception(" [!] No espeak backend found. Install espeak-ng or espeak to your system.") + self.backend = self._ESPEAK_LIB + + # band-aid for backwards compatibility + if language == "en": + language = "en-us" + + super().__init__(language, punctuations=punctuations, keep_puncs=keep_puncs) + if backend is not None: + self.backend = backend + + @property + def backend(self): + return self._ESPEAK_LIB + + @backend.setter + def backend(self, backend): + if backend not in ["espeak", "espeak-ng"]: + raise Exception("Unknown backend: %s" % backend) + self._ESPEAK_LIB = backend + + def auto_set_espeak_lib(self) -> None: + if is_tool("espeak-ng"): + self._ESPEAK_LIB = "espeak-ng" + elif is_tool("espeak"): + self._ESPEAK_LIB = "espeak" + else: + raise Exception("Cannot set backend automatically. espeak-ng or espeak not found") + + @staticmethod + def name(): + return "espeak" + + def phonemize_espeak(self, text: str, separator: str = "|", tie=False) -> str: + """Convert input text to phonemes. + + Args: + text (str): + Text to be converted to phonemes. + + tie (bool, optional) : When True use a '͡' character between + consecutive characters of a single phoneme. Else separate phoneme + with '_'. This option requires espeak>=1.49. Default to False. + """ + # set arguments + args = ["-v", f"{self._language}"] + # espeak and espeak-ng parses `ipa` differently + if tie: + # use '͡' between phonemes + if self.backend == "espeak": + args.append("--ipa=1") + else: + args.append("--ipa=3") + else: + # split with '_' + if self.backend == "espeak": + args.append("--ipa=3") + else: + args.append("--ipa=1") + if tie: + args.append("--tie=%s" % tie) + + args.append('"' + text + '"') + # compute phonemes + phonemes = "" + for line in _espeak_exe(self._ESPEAK_LIB, args, sync=True): + logging.debug("line: %s", repr(line)) + ph_decoded = line.decode("utf8").strip() + # espeak need to skip first two characters of the retuned text: + # version 1.48.03: "_ p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n" + # version 1.48.15: " p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n" + # espeak-ng need to skip the first character of the retuned text: + # "_p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n" + + # dealing with the conditions descrived above + ph_decoded = ph_decoded[:1].replace("_", "") + ph_decoded[1:] + phonemes += ph_decoded.strip() + return phonemes.replace("_", separator) + + def _phonemize(self, text, separator=None): + return self.phonemize_espeak(text, separator, tie=False) + + @staticmethod + def supported_languages() -> Dict: + """Get a dictionary of supported languages. + + Returns: + Dict: Dictionary of language codes. + """ + if _DEF_ESPEAK_LIB is None: + return {} + args = ["--voices"] + langs = {} + count = 0 + for line in _espeak_exe(_DEF_ESPEAK_LIB, args, sync=True): + line = line.decode("utf8").strip() + if count > 0: + cols = line.split() + lang_code = cols[1] + lang_name = cols[3] + langs[lang_code] = lang_name + logging.debug("line: %s", repr(line)) + count += 1 + return langs + + def version(self) -> str: + """Get the version of the used backend. + + Returns: + str: Version of the used backend. 
+ """ + args = ["--version"] + for line in _espeak_exe(self.backend, args, sync=True): + version = line.decode("utf8").strip().split()[2] + logging.debug("line: %s", repr(line)) + return version + + @classmethod + def is_available(cls): + """Return true if ESpeak is available else false""" + return is_tool("espeak") or is_tool("espeak-ng") + + +if __name__ == "__main__": + e = ESpeak(language="en-us") + print(e.supported_languages()) + print(e.version()) + print(e.language) + print(e.name()) + print(e.is_available()) + + e = ESpeak(language="en-us", keep_puncs=False) + print("`" + e.phonemize("hello how are you today?") + "`") + + e = ESpeak(language="en-us", keep_puncs=True) + print("`" + e.phonemize("hello how are you today?") + "`") diff --git a/TTS/tts/utils/text/phonemizers/gruut_wrapper.py b/TTS/tts/utils/text/phonemizers/gruut_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..f3e9c9abd4c41935ed07ec10ed883d75b42a6bc8 --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/gruut_wrapper.py @@ -0,0 +1,151 @@ +import importlib +from typing import List + +import gruut +from gruut_ipa import IPA + +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer +from TTS.tts.utils.text.punctuation import Punctuation + +# Table for str.translate to fix gruut/TTS phoneme mismatch +GRUUT_TRANS_TABLE = str.maketrans("g", "ɡ") + + +class Gruut(BasePhonemizer): + """Gruut wrapper for G2P + + Args: + language (str): + Valid language code for the used backend. + + punctuations (str): + Characters to be treated as punctuation. Defaults to `Punctuation.default_puncs()`. + + keep_puncs (bool): + If true, keep the punctuations after phonemization. Defaults to True. + + use_espeak_phonemes (bool): + If true, use espeak lexicons instead of default Gruut lexicons. Defaults to False. + + keep_stress (bool): + If true, keep the stress characters after phonemization. Defaults to False. + + Example: + + >>> from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut + >>> phonemizer = Gruut('en-us') + >>> phonemizer.phonemize("Be a voice, not an! echo?", separator="|") + 'b|i| ə| v|ɔ|ɪ|s, n|ɑ|t| ə|n! ɛ|k|o|ʊ?' + """ + + def __init__( + self, + language: str, + punctuations=Punctuation.default_puncs(), + keep_puncs=True, + use_espeak_phonemes=False, + keep_stress=False, + ): + super().__init__(language, punctuations=punctuations, keep_puncs=keep_puncs) + self.use_espeak_phonemes = use_espeak_phonemes + self.keep_stress = keep_stress + + @staticmethod + def name(): + return "gruut" + + def phonemize_gruut(self, text: str, separator: str = "|", tie=False) -> str: # pylint: disable=unused-argument + """Convert input text to phonemes. + + Gruut phonemizes the given `str` by seperating each phoneme character with `separator`, even for characters + that constitude a single sound. + + It doesn't affect 🐸TTS since it individually converts each character to token IDs. + + Examples:: + "hello how are you today?" -> `h|ɛ|l|o|ʊ| h|a|ʊ| ɑ|ɹ| j|u| t|ə|d|e|ɪ` + + Args: + text (str): + Text to be converted to phonemes. + + tie (bool, optional) : When True use a '͡' character between + consecutive characters of a single phoneme. Else separate phoneme + with '_'. This option requires espeak>=1.49. Default to False. 
+ """ + ph_list = [] + for sentence in gruut.sentences(text, lang=self.language, espeak=self.use_espeak_phonemes): + for word in sentence: + if word.is_break: + # Use actual character for break phoneme (e.g., comma) + if ph_list: + # Join with previous word + ph_list[-1].append(word.text) + else: + # First word is punctuation + ph_list.append([word.text]) + elif word.phonemes: + # Add phonemes for word + word_phonemes = [] + + for word_phoneme in word.phonemes: + if not self.keep_stress: + # Remove primary/secondary stress + word_phoneme = IPA.without_stress(word_phoneme) + + word_phoneme = word_phoneme.translate(GRUUT_TRANS_TABLE) + + if word_phoneme: + # Flatten phonemes + word_phonemes.extend(word_phoneme) + + if word_phonemes: + ph_list.append(word_phonemes) + + ph_words = [separator.join(word_phonemes) for word_phonemes in ph_list] + ph = f"{separator} ".join(ph_words) + return ph + + def _phonemize(self, text, separator): + return self.phonemize_gruut(text, separator, tie=False) + + def is_supported_language(self, language): + """Returns True if `language` is supported by the backend""" + return gruut.is_language_supported(language) + + @staticmethod + def supported_languages() -> List: + """Get a dictionary of supported languages. + + Returns: + List: List of language codes. + """ + return list(gruut.get_supported_languages()) + + def version(self): + """Get the version of the used backend. + + Returns: + str: Version of the used backend. + """ + return gruut.__version__ + + @classmethod + def is_available(cls): + """Return true if ESpeak is available else false""" + return importlib.util.find_spec("gruut") is not None + + +if __name__ == "__main__": + e = Gruut(language="en-us") + print(e.supported_languages()) + print(e.version()) + print(e.language) + print(e.name()) + print(e.is_available()) + + e = Gruut(language="en-us", keep_puncs=False) + print("`" + e.phonemize("hello how are you today?") + "`") + + e = Gruut(language="en-us", keep_puncs=True) + print("`" + e.phonemize("hello how, are you today?") + "`") diff --git a/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py b/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..60b965f9d8f16327a5b6da41729601a96debfdc6 --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py @@ -0,0 +1,72 @@ +from typing import Dict + +from TTS.tts.utils.text.japanese.phonemizer import japanese_text_to_phonemes +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer + +_DEF_JA_PUNCS = "、.,[]()?!〽~『』「」【】" + +_TRANS_TABLE = {"、": ","} + + +def trans(text): + for i, j in _TRANS_TABLE.items(): + text = text.replace(i, j) + return text + + +class JA_JP_Phonemizer(BasePhonemizer): + """🐸TTS Ja-Jp phonemizer using functions in `TTS.tts.utils.text.japanese.phonemizer` + + TODO: someone with JA knowledge should check this implementation + + Example: + + >>> from TTS.tts.utils.text.phonemizers import JA_JP_Phonemizer + >>> phonemizer = JA_JP_Phonemizer() + >>> phonemizer.phonemize("どちらに行きますか?", separator="|") + 'd|o|c|h|i|r|a|n|i|i|k|i|m|a|s|u|k|a|?' 
+ + """ + + language = "ja-jp" + + def __init__(self, punctuations=_DEF_JA_PUNCS, keep_puncs=True, **kwargs): # pylint: disable=unused-argument + super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs) + + @staticmethod + def name(): + return "ja_jp_phonemizer" + + def _phonemize(self, text: str, separator: str = "|") -> str: + ph = japanese_text_to_phonemes(text) + if separator is not None or separator != "": + return separator.join(ph) + return ph + + def phonemize(self, text: str, separator="|") -> str: + """Custom phonemize for JP_JA + + Skip pre-post processing steps used by the other phonemizers. + """ + return self._phonemize(text, separator) + + @staticmethod + def supported_languages() -> Dict: + return {"ja-jp": "Japanese (Japan)"} + + def version(self) -> str: + return "0.0.1" + + def is_available(self) -> bool: + return True + + +# if __name__ == "__main__": +# text = "これは、電話をかけるための私の日本語の例のテキストです。" +# e = JA_JP_Phonemizer() +# print(e.supported_languages()) +# print(e.version()) +# print(e.language) +# print(e.name()) +# print(e.is_available()) +# print("`" + e.phonemize(text) + "`") diff --git a/TTS/tts/utils/text/phonemizers/multi_phonemizer.py b/TTS/tts/utils/text/phonemizers/multi_phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..e36b0a2a1f98aae72be017a3b0a956d6300afb61 --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/multi_phonemizer.py @@ -0,0 +1,55 @@ +from typing import Dict, List + +from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name + + +class MultiPhonemizer: + """🐸TTS multi-phonemizer that operates phonemizers for multiple langugages + + Args: + custom_lang_to_phonemizer (Dict): + Custom phonemizer mapping if you want to change the defaults. In the format of + `{"lang_code", "phonemizer_name"}`. When it is None, `DEF_LANG_TO_PHONEMIZER` is used. Defaults to `{}`. 
+ + TODO: find a way to pass custom kwargs to the phonemizers + """ + + lang_to_phonemizer_name = DEF_LANG_TO_PHONEMIZER + language = "multi-lingual" + + def __init__(self, custom_lang_to_phonemizer: Dict = {}) -> None: # pylint: disable=dangerous-default-value + self.lang_to_phonemizer_name.update(custom_lang_to_phonemizer) + self.lang_to_phonemizer = self.init_phonemizers(self.lang_to_phonemizer_name) + + @staticmethod + def init_phonemizers(lang_to_phonemizer_name: Dict) -> Dict: + lang_to_phonemizer = {} + for k, v in lang_to_phonemizer_name.items(): + phonemizer = get_phonemizer_by_name(v, language=k) + lang_to_phonemizer[k] = phonemizer + return lang_to_phonemizer + + @staticmethod + def name(): + return "multi-phonemizer" + + def phonemize(self, text, language, separator="|"): + return self.lang_to_phonemizer[language].phonemize(text, separator) + + def supported_languages(self) -> List: + return list(self.lang_to_phonemizer_name.keys()) + + +# if __name__ == "__main__": +# texts = { +# "tr": "Merhaba, bu Türkçe bit örnek!", +# "en-us": "Hello, this is English example!", +# "de": "Hallo, das ist ein Deutches Beipiel!", +# "zh-cn": "这是中国的例子", +# } +# phonemes = {} +# ph = MultiPhonemizer() +# for lang, text in texts.items(): +# phoneme = ph.phonemize(text, lang) +# phonemes[lang] = phoneme +# print(phonemes) diff --git a/TTS/tts/utils/text/phonemizers/zh_cn_phonemizer.py b/TTS/tts/utils/text/phonemizers/zh_cn_phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5a4a55911d84eaaa043e4b6724c4de8a5f249ad4 --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/zh_cn_phonemizer.py @@ -0,0 +1,62 @@ +from typing import Dict + +from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer + +_DEF_ZH_PUNCS = "、.,[]()?!〽~『』「」【】" + + +class ZH_CN_Phonemizer(BasePhonemizer): + """🐸TTS Zh-Cn phonemizer using functions in `TTS.tts.utils.text.chinese_mandarin.phonemizer` + + Args: + punctuations (str): + Set of characters to be treated as punctuation. Defaults to `_DEF_ZH_PUNCS`. + + keep_puncs (bool): + If True, keep the punctuations after phonemization. Defaults to False. 
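A hedged usage sketch for `MultiPhonemizer`. Note that the constructor instantiates a phonemizer for every entry in `DEF_LANG_TO_PHONEMIZER`, so this only runs when all backends are available (espeak-ng or espeak, gruut, and the MeCab/jieba/pypinyin dependencies used for `ja-jp`/`zh-cn`):

```python
from TTS.tts.utils.text.phonemizers.multi_phonemizer import MultiPhonemizer

# Override the default phonemizer for one language; the rest keep their defaults.
ph = MultiPhonemizer(custom_lang_to_phonemizer={"en-us": "gruut"})
print(len(ph.supported_languages()), "languages configured")
print(ph.phonemize("Hello, this is an English example!", "en-us"))
```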
+ + Example :: + + "这是,样本中文。" -> `d|ʒ|ø|4| |ʂ|ʏ|4| |,| |i|ɑ|ŋ|4|b|œ|n|3| |d|ʒ|o|ŋ|1|w|œ|n|2| |。` + + TODO: someone with Mandarin knowledge should check this implementation + """ + + language = "zh-cn" + + def __init__(self, punctuations=_DEF_ZH_PUNCS, keep_puncs=False, **kwargs): # pylint: disable=unused-argument + super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs) + + @staticmethod + def name(): + return "zh_cn_phonemizer" + + @staticmethod + def phonemize_zh_cn(text: str, separator: str = "|") -> str: + ph = chinese_text_to_phonemes(text, separator) + return ph + + def _phonemize(self, text, separator): + return self.phonemize_zh_cn(text, separator) + + @staticmethod + def supported_languages() -> Dict: + return {"zh-cn": "Japanese (Japan)"} + + def version(self) -> str: + return "0.0.1" + + def is_available(self) -> bool: + return True + + +# if __name__ == "__main__": +# text = "这是,样本中文。" +# e = ZH_CN_Phonemizer() +# print(e.supported_languages()) +# print(e.version()) +# print(e.language) +# print(e.name()) +# print(e.is_available()) +# print("`" + e.phonemize(text) + "`") diff --git a/TTS/tts/utils/text/punctuation.py b/TTS/tts/utils/text/punctuation.py new file mode 100644 index 0000000000000000000000000000000000000000..b2a058bb07407b4994a0af2eebb6489f0e91ee05 --- /dev/null +++ b/TTS/tts/utils/text/punctuation.py @@ -0,0 +1,172 @@ +import collections +import re +from enum import Enum + +import six + +_DEF_PUNCS = ';:,.!?¡¿—…"«»“”' + +_PUNC_IDX = collections.namedtuple("_punc_index", ["punc", "position"]) + + +class PuncPosition(Enum): + """Enum for the punctuations positions""" + + BEGIN = 0 + END = 1 + MIDDLE = 2 + ALONE = 3 + + +class Punctuation: + """Handle punctuations in text. + + Just strip punctuations from text or strip and restore them later. + + Args: + puncs (str): The punctuations to be processed. Defaults to `_DEF_PUNCS`. + + Example: + >>> punc = Punctuation() + >>> punc.strip("This is. example !") + 'This is example' + + >>> text_striped, punc_map = punc.strip_to_restore("This is. example !") + >>> ' '.join(text_striped) + 'This is example' + + >>> text_restored = punc.restore(text_striped, punc_map) + >>> text_restored[0] + 'This is. example !' + """ + + def __init__(self, puncs: str = _DEF_PUNCS): + self.puncs = puncs + + @staticmethod + def default_puncs(): + """Return default set of punctuations.""" + return _DEF_PUNCS + + @property + def puncs(self): + return self._puncs + + @puncs.setter + def puncs(self, value): + if not isinstance(value, six.string_types): + raise ValueError("[!] Punctuations must be of type str.") + self._puncs = "".join(list(dict.fromkeys(list(value)))) # remove duplicates without changing the oreder + self.puncs_regular_exp = re.compile(rf"(\s*[{re.escape(self._puncs)}]+\s*)+") + + def strip(self, text): + """Remove all the punctuations by replacing with `space`. + + Args: + text (str): The text to be processed. + + Example:: + + "This is. example !" -> "This is example " + """ + return re.sub(self.puncs_regular_exp, " ", text).rstrip().lstrip() + + def strip_to_restore(self, text): + """Remove punctuations from text to restore them later. + + Args: + text (str): The text to be processed. + + Examples :: + + "This is. example !" 
-> [["This is", "example"], [".", "!"]] + + """ + text, puncs = self._strip_to_restore(text) + return text, puncs + + def _strip_to_restore(self, text): + """Auxiliary method for Punctuation.preserve()""" + matches = list(re.finditer(self.puncs_regular_exp, text)) + if not matches: + return [text], [] + # the text is only punctuations + if len(matches) == 1 and matches[0].group() == text: + return [], [_PUNC_IDX(text, PuncPosition.ALONE)] + # build a punctuation map to be used later to restore punctuations + puncs = [] + for match in matches: + position = PuncPosition.MIDDLE + if match == matches[0] and text.startswith(match.group()): + position = PuncPosition.BEGIN + elif match == matches[-1] and text.endswith(match.group()): + position = PuncPosition.END + puncs.append(_PUNC_IDX(match.group(), position)) + # convert str text to a List[str], each item is separated by a punctuation + splitted_text = [] + for idx, punc in enumerate(puncs): + split = text.split(punc.punc) + prefix, suffix = split[0], punc.punc.join(split[1:]) + splitted_text.append(prefix) + # if the text does not end with a punctuation, add it to the last item + if idx == len(puncs) - 1 and len(suffix) > 0: + splitted_text.append(suffix) + text = suffix + return splitted_text, puncs + + @classmethod + def restore(cls, text, puncs): + """Restore punctuation in a text. + + Args: + text (str): The text to be processed. + puncs (List[str]): The list of punctuations map to be used for restoring. + + Examples :: + + ['This is', 'example'], ['.', '!'] -> "This is. example!" + + """ + return cls._restore(text, puncs, 0) + + @classmethod + def _restore(cls, text, puncs, num): # pylint: disable=too-many-return-statements + """Auxiliary method for Punctuation.restore()""" + if not puncs: + return text + + # nothing have been phonemized, returns the puncs alone + if not text: + return ["".join(m.mark for m in puncs)] + + current = puncs[0] + + if current.position == PuncPosition.BEGIN: + return cls._restore([current.punc + text[0]] + text[1:], puncs[1:], num) + + if current.position == PuncPosition.END: + return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:], num + 1) + + if current.position == PuncPosition.ALONE: + return [current.mark] + cls._restore(text, puncs[1:], num + 1) + + # POSITION == MIDDLE + if len(text) == 1: # pragma: nocover + # a corner case where the final part of an intermediate + # mark (I) has not been phonemized + return cls._restore([text[0] + current.punc], puncs[1:], num) + + return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:], num) + + +# if __name__ == "__main__": +# punc = Punctuation() +# text = "This is. This is, example!" 
+ +# print(punc.strip(text)) + +# split_text, puncs = punc.strip_to_restore(text) +# print(split_text, " ---- ", puncs) + +# restored_text = punc.restore(split_text, puncs) +# print(restored_text) diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..1569c634fb583a13a4040901d1e16b53703aa3dd --- /dev/null +++ b/TTS/tts/utils/text/tokenizer.py @@ -0,0 +1,206 @@ +from typing import Callable, Dict, List, Union + +from TTS.tts.utils.text import cleaners +from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes +from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name +from TTS.utils.generic_utils import get_import_path, import_class + + +class TTSTokenizer: + """🐸TTS tokenizer to convert input characters to token IDs and back. + + Token IDs for OOV chars are discarded but those are stored in `self.not_found_characters` for later. + + Args: + use_phonemes (bool): + Whether to use phonemes instead of characters. Defaults to False. + + characters (Characters): + A Characters object to use for character-to-ID and ID-to-character mappings. + + text_cleaner (callable): + A function to pre-process the text before tokenization and phonemization. Defaults to None. + + phonemizer (Phonemizer): + A phonemizer object or a dict that maps language codes to phonemizer objects. Defaults to None. + + Example: + + >>> from TTS.tts.utils.text.tokenizer import TTSTokenizer + >>> tokenizer = TTSTokenizer(use_phonemes=False, characters=Graphemes()) + >>> text = "Hello world!" + >>> ids = tokenizer.text_to_ids(text) + >>> text_hat = tokenizer.ids_to_text(ids) + >>> assert text == text_hat + """ + + def __init__( + self, + use_phonemes=False, + text_cleaner: Callable = None, + characters: "BaseCharacters" = None, + phonemizer: Union["Phonemizer", Dict] = None, + add_blank: bool = False, + use_eos_bos=False, + ): + self.text_cleaner = text_cleaner + self.use_phonemes = use_phonemes + self.add_blank = add_blank + self.use_eos_bos = use_eos_bos + self.characters = characters + self.not_found_characters = [] + self.phonemizer = phonemizer + + @property + def characters(self): + return self._characters + + @characters.setter + def characters(self, new_characters): + self._characters = new_characters + self.pad_id = self.characters.char_to_id(self.characters.pad) if self.characters.pad else None + self.blank_id = self.characters.char_to_id(self.characters.blank) if self.characters.blank else None + + def encode(self, text: str) -> List[int]: + """Encodes a string of text as a sequence of IDs.""" + token_ids = [] + for char in text: + try: + idx = self.characters.char_to_id(char) + token_ids.append(idx) + except KeyError: + # discard but store not found characters + if char not in self.not_found_characters: + self.not_found_characters.append(char) + print(text) + print(f" [!] Character {repr(char)} not found in the vocabulary. Discarding it.") + return token_ids + + def decode(self, token_ids: List[int]) -> str: + """Decodes a sequence of IDs to a string of text.""" + text = "" + for token_id in token_ids: + text += self.characters.id_to_char(token_id) + return text + + def text_to_ids(self, text: str, language: str = None) -> List[int]: # pylint: disable=unused-argument + """Converts a string of text to a sequence of token IDs. + + Args: + text(str): + The text to convert to token IDs. + + language(str): + The language code of the text. Defaults to None. 
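The grapheme round trip from the `TTSTokenizer` docstring above, spelled out (TTS package assumed installed). Characters missing from the vocabulary are dropped and collected in `tokenizer.not_found_characters` rather than raising:

```python
from TTS.tts.utils.text.characters import Graphemes
from TTS.tts.utils.text.tokenizer import TTSTokenizer

tokenizer = TTSTokenizer(use_phonemes=False, characters=Graphemes())
ids = tokenizer.text_to_ids("Hello world!")
print(ids)                         # list of integer token IDs
print(tokenizer.ids_to_text(ids))  # "Hello world!"
```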
+ + TODO: + - Add support for language-specific processing. + + 1. Text normalizatin + 2. Phonemization (if use_phonemes is True) + 3. Add blank char between characters + 4. Add BOS and EOS characters + 5. Text to token IDs + """ + # TODO: text cleaner should pick the right routine based on the language + if self.text_cleaner is not None: + text = self.text_cleaner(text) + if self.use_phonemes: + text = self.phonemizer.phonemize(text, separator="") + if self.add_blank: + text = self.intersperse_blank_char(text, True) + if self.use_eos_bos: + text = self.pad_with_bos_eos(text) + return self.encode(text) + + def ids_to_text(self, id_sequence: List[int]) -> str: + """Converts a sequence of token IDs to a string of text.""" + return self.decode(id_sequence) + + def pad_with_bos_eos(self, char_sequence: List[str]): + """Pads a sequence with the special BOS and EOS characters.""" + return [self.characters.bos] + list(char_sequence) + [self.characters.eos] + + def intersperse_blank_char(self, char_sequence: List[str], use_blank_char: bool = False): + """Intersperses the blank character between characters in a sequence. + + Use the ```blank``` character if defined else use the ```pad``` character. + """ + char_to_use = self.characters.blank if use_blank_char else self.characters.pad + result = [char_to_use] * (len(char_sequence) * 2 + 1) + result[1::2] = char_sequence + return result + + def print_logs(self, level: int = 0): + indent = "\t" * level + print(f"{indent}| > add_blank: {self.add_blank}") + print(f"{indent}| > use_eos_bos: {self.use_eos_bos}") + print(f"{indent}| > use_phonemes: {self.use_phonemes}") + if self.use_phonemes: + print(f"{indent}| > phonemizer:") + self.phonemizer.print_logs(level + 1) + if len(self.not_found_characters) > 0: + print(f"{indent}| > {len(self.not_found_characters)} not found characters:") + for char in self.not_found_characters: + print(f"{indent}| > {char}") + + @staticmethod + def init_from_config(config: "Coqpit", characters: "BaseCharacters" = None): + """Init Tokenizer object from config + + Args: + config (Coqpit): Coqpit model config. + characters (BaseCharacters): Defines the model character set. If not set, use the default options based on + the config values. Defaults to None. 
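One step of `text_to_ids()` worth spelling out is `intersperse_blank_char()`: a sequence of N tokens becomes 2N + 1 tokens with the blank (or pad) token in every other slot, which some models (e.g. VITS-style architectures) expect between characters. A standalone sketch:

```python
def intersperse(char_sequence, blank="<BLNK>"):
    # length 2N + 1: blanks at even indices, original tokens at odd indices
    result = [blank] * (len(char_sequence) * 2 + 1)
    result[1::2] = char_sequence
    return result

print(intersperse(list("abc")))
# ['<BLNK>', 'a', '<BLNK>', 'b', '<BLNK>', 'c', '<BLNK>']
```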
+ """ + # init cleaners + text_cleaner = None + if isinstance(config.text_cleaner, (str, list)): + text_cleaner = getattr(cleaners, config.text_cleaner) + + # init characters + if characters is None: + # set characters based on defined characters class + if config.characters and config.characters.characters_class: + CharactersClass = import_class(config.characters.characters_class) + characters, new_config = CharactersClass.init_from_config(config) + # set characters based on config + else: + if config.use_phonemes: + # init phoneme set + characters, new_config = IPAPhonemes().init_from_config(config) + else: + # init character set + characters, new_config = Graphemes().init_from_config(config) + + else: + characters, new_config = characters.init_from_config(config) + + # set characters class + new_config.characters.characters_class = get_import_path(characters) + + # init phonemizer + phonemizer = None + if config.use_phonemes: + phonemizer_kwargs = {"language": config.phoneme_language} + + if "phonemizer" in config and config.phonemizer: + phonemizer = get_phonemizer_by_name(config.phonemizer, **phonemizer_kwargs) + else: + try: + phonemizer = get_phonemizer_by_name( + DEF_LANG_TO_PHONEMIZER[config.phoneme_language], **phonemizer_kwargs + ) + new_config.phonemizer = phonemizer.name() + except KeyError as e: + raise ValueError( + f"""No phonemizer found for language {config.phoneme_language}. + You may need to install a third party library for this language.""" + ) from e + + return ( + TTSTokenizer( + config.use_phonemes, text_cleaner, characters, phonemizer, config.add_blank, config.enable_eos_bos_chars + ), + new_config, + ) diff --git a/TTS/tts/utils/visual.py b/TTS/tts/utils/visual.py new file mode 100644 index 0000000000000000000000000000000000000000..78c12981098ed1870ad799a72e7f7b80e4aafc17 --- /dev/null +++ b/TTS/tts/utils/visual.py @@ -0,0 +1,202 @@ +import librosa +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import torch + +matplotlib.use("Agg") + + +def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None, output_fig=False): + if isinstance(alignment, torch.Tensor): + alignment_ = alignment.detach().cpu().numpy().squeeze() + else: + alignment_ = alignment + alignment_ = alignment_.astype(np.float32) if alignment_.dtype == np.float16 else alignment_ + fig, ax = plt.subplots(figsize=fig_size) + im = ax.imshow(alignment_.T, aspect="auto", origin="lower", interpolation="none") + fig.colorbar(im, ax=ax) + xlabel = "Decoder timestep" + if info is not None: + xlabel += "\n\n" + info + plt.xlabel(xlabel) + plt.ylabel("Encoder timestep") + # plt.yticks(range(len(text)), list(text)) + plt.tight_layout() + if title is not None: + plt.title(title) + if not output_fig: + plt.close() + return fig + + +def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False): + if isinstance(spectrogram, torch.Tensor): + spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T + else: + spectrogram_ = spectrogram.T + spectrogram_ = spectrogram_.astype(np.float32) if spectrogram_.dtype == np.float16 else spectrogram_ + if ap is not None: + spectrogram_ = ap.denormalize(spectrogram_) # pylint: disable=protected-access + fig = plt.figure(figsize=fig_size) + plt.imshow(spectrogram_, aspect="auto", origin="lower") + plt.colorbar() + plt.tight_layout() + if not output_fig: + plt.close() + return fig + + +def plot_pitch(pitch, spectrogram, ap=None, fig_size=(30, 10), output_fig=False): + """Plot pitch curves on top of the spectrogram. 
+ + Args: + pitch (np.array): Pitch values. + spectrogram (np.array): Spectrogram values. + + Shapes: + pitch: :math:`(T,)` + spec: :math:`(C, T)` + """ + + if isinstance(spectrogram, torch.Tensor): + spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T + else: + spectrogram_ = spectrogram.T + spectrogram_ = spectrogram_.astype(np.float32) if spectrogram_.dtype == np.float16 else spectrogram_ + if ap is not None: + spectrogram_ = ap.denormalize(spectrogram_) # pylint: disable=protected-access + + old_fig_size = plt.rcParams["figure.figsize"] + if fig_size is not None: + plt.rcParams["figure.figsize"] = fig_size + + fig, ax = plt.subplots() + + ax.imshow(spectrogram_, aspect="auto", origin="lower") + ax.set_xlabel("time") + ax.set_ylabel("spec_freq") + + ax2 = ax.twinx() + ax2.plot(pitch, linewidth=5.0, color="red") + ax2.set_ylabel("F0") + + plt.rcParams["figure.figsize"] = old_fig_size + if not output_fig: + plt.close() + return fig + + +def plot_avg_pitch(pitch, chars, fig_size=(30, 10), output_fig=False): + """Plot pitch curves on top of the input characters. + + Args: + pitch (np.array): Pitch values. + chars (str): Characters to place to the x-axis. + + Shapes: + pitch: :math:`(T,)` + """ + old_fig_size = plt.rcParams["figure.figsize"] + if fig_size is not None: + plt.rcParams["figure.figsize"] = fig_size + + fig, ax = plt.subplots() + + x = np.array(range(len(chars))) + my_xticks = chars + plt.xticks(x, my_xticks) + + ax.set_xlabel("characters") + ax.set_ylabel("freq") + + ax2 = ax.twinx() + ax2.plot(pitch, linewidth=5.0, color="red") + ax2.set_ylabel("F0") + + plt.rcParams["figure.figsize"] = old_fig_size + if not output_fig: + plt.close() + return fig + + +def visualize( + alignment, + postnet_output, + text, + hop_length, + CONFIG, + tokenizer, + stop_tokens=None, + decoder_output=None, + output_path=None, + figsize=(8, 24), + output_fig=False, +): + """Intended to be used in Notebooks.""" + + if decoder_output is not None: + num_plot = 4 + else: + num_plot = 3 + + label_fontsize = 16 + fig = plt.figure(figsize=figsize) + + plt.subplot(num_plot, 1, 1) + plt.imshow(alignment.T, aspect="auto", origin="lower", interpolation=None) + plt.xlabel("Decoder timestamp", fontsize=label_fontsize) + plt.ylabel("Encoder timestamp", fontsize=label_fontsize) + # compute phoneme representation and back + if CONFIG.use_phonemes: + seq = tokenizer.text_to_ids(text) + text = tokenizer.ids_to_text(seq) + print(text) + plt.yticks(range(len(text)), list(text)) + plt.colorbar() + + if stop_tokens is not None: + # plot stopnet predictions + plt.subplot(num_plot, 1, 2) + plt.plot(range(len(stop_tokens)), list(stop_tokens)) + + # plot postnet spectrogram + plt.subplot(num_plot, 1, 3) + librosa.display.specshow( + postnet_output.T, + sr=CONFIG.audio["sample_rate"], + hop_length=hop_length, + x_axis="time", + y_axis="linear", + fmin=CONFIG.audio["mel_fmin"], + fmax=CONFIG.audio["mel_fmax"], + ) + + plt.xlabel("Time", fontsize=label_fontsize) + plt.ylabel("Hz", fontsize=label_fontsize) + plt.tight_layout() + plt.colorbar() + + if decoder_output is not None: + plt.subplot(num_plot, 1, 4) + librosa.display.specshow( + decoder_output.T, + sr=CONFIG.audio["sample_rate"], + hop_length=hop_length, + x_axis="time", + y_axis="linear", + fmin=CONFIG.audio["mel_fmin"], + fmax=CONFIG.audio["mel_fmax"], + ) + plt.xlabel("Time", fontsize=label_fontsize) + plt.ylabel("Hz", fontsize=label_fontsize) + plt.tight_layout() + plt.colorbar() + + if output_path: + print(output_path) + fig.savefig(output_path) + plt.close() 
+ + if not output_fig: + plt.close() diff --git a/TTS/utils/__init__.py b/TTS/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..fc9d194201a9e3d8de3bdc4e3103e4a0254a6ea4 --- /dev/null +++ b/TTS/utils/audio.py @@ -0,0 +1,927 @@ +from typing import Dict, Tuple + +import librosa +import numpy as np +import pyworld as pw +import scipy.io.wavfile +import scipy.signal +import soundfile as sf +import torch +from torch import nn + +from TTS.tts.utils.helpers import StandardScaler + + +class TorchSTFT(nn.Module): # pylint: disable=abstract-method + """Some of the audio processing funtions using Torch for faster batch processing. + + TODO: Merge this with audio.py + + Args: + + n_fft (int): + FFT window size for STFT. + + hop_length (int): + number of frames between STFT columns. + + win_length (int, optional): + STFT window length. + + pad_wav (bool, optional): + If True pad the audio with (n_fft - hop_length) / 2). Defaults to False. + + window (str, optional): + The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window" + + sample_rate (int, optional): + target audio sampling rate. Defaults to None. + + mel_fmin (int, optional): + minimum filter frequency for computing melspectrograms. Defaults to None. + + mel_fmax (int, optional): + maximum filter frequency for computing melspectrograms. Defaults to None. + + n_mels (int, optional): + number of melspectrogram dimensions. Defaults to None. + + use_mel (bool, optional): + If True compute the melspectrograms otherwise. Defaults to False. + + do_amp_to_db_linear (bool, optional): + enable/disable amplitude to dB conversion of linear spectrograms. Defaults to False. + + spec_gain (float, optional): + gain applied when converting amplitude to DB. Defaults to 1.0. + + power (float, optional): + Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None. + + use_htk (bool, optional): + Use HTK formula in mel filter instead of Slaney. + + mel_norm (None, 'slaney', or number, optional): + If 'slaney', divide the triangular mel weights by the width of the mel band + (area normalization). + + If numeric, use `librosa.util.normalize` to normalize each filter by to unit l_p norm. + See `librosa.util.normalize` for a full description of supported norm values + (including `+-np.inf`). + + Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney". 
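A hedged usage sketch for `TorchSTFT` as documented above: batched mel-spectrogram extraction on tensors. The parameter values are illustrative, and the snippet assumes the package and the dependency versions it pins are installed:

```python
import torch
from TTS.utils.audio import TorchSTFT

stft = TorchSTFT(
    n_fft=1024, hop_length=256, win_length=1024,
    sample_rate=22050, n_mels=80, use_mel=True, do_amp_to_db=True,
)
wav = torch.randn(4, 22050)  # a batch of four 1-second waveforms
mel = stft(wav)              # -> (batch, n_mels, frames)
print(mel.shape)
```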
+ """ + + def __init__( + self, + n_fft, + hop_length, + win_length, + pad_wav=False, + window="hann_window", + sample_rate=None, + mel_fmin=0, + mel_fmax=None, + n_mels=80, + use_mel=False, + do_amp_to_db=False, + spec_gain=1.0, + power=None, + use_htk=False, + mel_norm="slaney", + ): + super().__init__() + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.pad_wav = pad_wav + self.sample_rate = sample_rate + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.n_mels = n_mels + self.use_mel = use_mel + self.do_amp_to_db = do_amp_to_db + self.spec_gain = spec_gain + self.power = power + self.use_htk = use_htk + self.mel_norm = mel_norm + self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False) + self.mel_basis = None + if use_mel: + self._build_mel_basis() + + def __call__(self, x): + """Compute spectrogram frames by torch based stft. + + Args: + x (Tensor): input waveform + + Returns: + Tensor: spectrogram frames. + + Shapes: + x: [B x T] or [:math:`[B, 1, T]`] + """ + if x.ndim == 2: + x = x.unsqueeze(1) + if self.pad_wav: + padding = int((self.n_fft - self.hop_length) / 2) + x = torch.nn.functional.pad(x, (padding, padding), mode="reflect") + # B x D x T x 2 + o = torch.stft( + x.squeeze(1), + self.n_fft, + self.hop_length, + self.win_length, + self.window, + center=True, + pad_mode="reflect", # compatible with audio.py + normalized=False, + onesided=True, + return_complex=False, + ) + M = o[:, :, :, 0] + P = o[:, :, :, 1] + S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8)) + + if self.power is not None: + S = S**self.power + + if self.use_mel: + S = torch.matmul(self.mel_basis.to(x), S) + if self.do_amp_to_db: + S = self._amp_to_db(S, spec_gain=self.spec_gain) + return S + + def _build_mel_basis(self): + mel_basis = librosa.filters.mel( + self.sample_rate, + self.n_fft, + n_mels=self.n_mels, + fmin=self.mel_fmin, + fmax=self.mel_fmax, + htk=self.use_htk, + norm=self.mel_norm, + ) + self.mel_basis = torch.from_numpy(mel_basis).float() + + @staticmethod + def _amp_to_db(x, spec_gain=1.0): + return torch.log(torch.clamp(x, min=1e-5) * spec_gain) + + @staticmethod + def _db_to_amp(x, spec_gain=1.0): + return torch.exp(x) / spec_gain + + +# pylint: disable=too-many-public-methods +class AudioProcessor(object): + """Audio Processor for TTS used by all the data pipelines. + + TODO: Make this a dataclass to replace `BaseAudioConfig`. + + Note: + All the class arguments are set to default values to enable a flexible initialization + of the class with the model config. They are not meaningful for all the arguments. + + Args: + sample_rate (int, optional): + target audio sampling rate. Defaults to None. + + resample (bool, optional): + enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False. + + num_mels (int, optional): + number of melspectrogram dimensions. Defaults to None. + + log_func (int, optional): + log exponent used for converting spectrogram aplitude to DB. + + min_level_db (int, optional): + minimum db threshold for the computed melspectrograms. Defaults to None. + + frame_shift_ms (int, optional): + milliseconds of frames between STFT columns. Defaults to None. + + frame_length_ms (int, optional): + milliseconds of STFT window length. Defaults to None. + + hop_length (int, optional): + number of frames between STFT columns. Used if ```frame_shift_ms``` is None. Defaults to None. + + win_length (int, optional): + STFT window length. 
Used if ```frame_length_ms``` is None. Defaults to None. + + ref_level_db (int, optional): + reference DB level to avoid background noise. In general <20DB corresponds to the air noise. Defaults to None. + + fft_size (int, optional): + FFT window size for STFT. Defaults to 1024. + + power (int, optional): + Exponent value applied to the spectrogram before GriffinLim. Defaults to None. + + preemphasis (float, optional): + Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0. + + signal_norm (bool, optional): + enable/disable signal normalization. Defaults to None. + + symmetric_norm (bool, optional): + enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else [0, k], Defaults to None. + + max_norm (float, optional): + ```k``` defining the normalization range. Defaults to None. + + mel_fmin (int, optional): + minimum filter frequency for computing melspectrograms. Defaults to None. + + mel_fmax (int, optional): + maximum filter frequency for computing melspectrograms. Defaults to None. + + pitch_fmin (int, optional): + minimum filter frequency for computing pitch. Defaults to None. + + pitch_fmax (int, optional): + maximum filter frequency for computing pitch. Defaults to None. + + spec_gain (int, optional): + gain applied when converting amplitude to DB. Defaults to 20. + + stft_pad_mode (str, optional): + Padding mode for STFT. Defaults to 'reflect'. + + clip_norm (bool, optional): + enable/disable clipping the our of range values in the normalized audio signal. Defaults to True. + + griffin_lim_iters (int, optional): + Number of GriffinLim iterations. Defaults to None. + + do_trim_silence (bool, optional): + enable/disable silence trimming when loading the audio signal. Defaults to False. + + trim_db (int, optional): + DB threshold used for silence trimming. Defaults to 60. + + do_sound_norm (bool, optional): + enable/disable signal normalization. Defaults to False. + + do_amp_to_db_linear (bool, optional): + enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True. + + do_amp_to_db_mel (bool, optional): + enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True. + + do_rms_norm (bool, optional): + enable/disable RMS volume normalization when loading an audio file. Defaults to False. + + db_level (int, optional): + dB level used for rms normalization. The range is -99 to 0. Defaults to None. + + stats_path (str, optional): + Path to the computed stats file. Defaults to None. + + verbose (bool, optional): + enable/disable logging. Defaults to True. 
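A hedged sketch of constructing `AudioProcessor` directly with a handful of the parameters documented above and computing a mel spectrogram from a dummy waveform. The values are illustrative; in practice they come from the model config via `init_from_config()`, and the pinned audio dependencies are assumed to be installed:

```python
import numpy as np
from TTS.utils.audio import AudioProcessor

ap = AudioProcessor(
    sample_rate=22050, num_mels=80, fft_size=1024,
    hop_length=256, win_length=1024,
    min_level_db=-100, ref_level_db=20, verbose=False,
)
wav = np.random.uniform(-1.0, 1.0, size=22050).astype(np.float32)
mel = ap.melspectrogram(wav)  # (num_mels, frames)
print(mel.shape)
```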
+ + """ + + def __init__( + self, + sample_rate=None, + resample=False, + num_mels=None, + log_func="np.log10", + min_level_db=None, + frame_shift_ms=None, + frame_length_ms=None, + hop_length=None, + win_length=None, + ref_level_db=None, + fft_size=1024, + power=None, + preemphasis=0.0, + signal_norm=None, + symmetric_norm=None, + max_norm=None, + mel_fmin=None, + mel_fmax=None, + pitch_fmax=None, + pitch_fmin=None, + spec_gain=20, + stft_pad_mode="reflect", + clip_norm=True, + griffin_lim_iters=None, + do_trim_silence=False, + trim_db=60, + do_sound_norm=False, + do_amp_to_db_linear=True, + do_amp_to_db_mel=True, + do_rms_norm=False, + db_level=None, + stats_path=None, + verbose=True, + **_, + ): + + # setup class attributed + self.sample_rate = sample_rate + self.resample = resample + self.num_mels = num_mels + self.log_func = log_func + self.min_level_db = min_level_db or 0 + self.frame_shift_ms = frame_shift_ms + self.frame_length_ms = frame_length_ms + self.ref_level_db = ref_level_db + self.fft_size = fft_size + self.power = power + self.preemphasis = preemphasis + self.griffin_lim_iters = griffin_lim_iters + self.signal_norm = signal_norm + self.symmetric_norm = symmetric_norm + self.mel_fmin = mel_fmin or 0 + self.mel_fmax = mel_fmax + self.pitch_fmin = pitch_fmin + self.pitch_fmax = pitch_fmax + self.spec_gain = float(spec_gain) + self.stft_pad_mode = stft_pad_mode + self.max_norm = 1.0 if max_norm is None else float(max_norm) + self.clip_norm = clip_norm + self.do_trim_silence = do_trim_silence + self.trim_db = trim_db + self.do_sound_norm = do_sound_norm + self.do_amp_to_db_linear = do_amp_to_db_linear + self.do_amp_to_db_mel = do_amp_to_db_mel + self.do_rms_norm = do_rms_norm + self.db_level = db_level + self.stats_path = stats_path + # setup exp_func for db to amp conversion + if log_func == "np.log": + self.base = np.e + elif log_func == "np.log10": + self.base = 10 + else: + raise ValueError(" [!] unknown `log_func` value.") + # setup stft parameters + if hop_length is None: + # compute stft parameters from given time values + self.hop_length, self.win_length = self._stft_parameters() + else: + # use stft parameters from config file + self.hop_length = hop_length + self.win_length = win_length + assert min_level_db != 0.0, " [!] min_level_db is 0" + assert ( + self.win_length <= self.fft_size + ), f" [!] win_length cannot be larger than fft_size - {self.win_length} vs {self.fft_size}" + members = vars(self) + if verbose: + print(" > Setting up Audio Processor...") + for key, value in members.items(): + print(" | > {}:{}".format(key, value)) + # create spectrogram utils + self.mel_basis = self._build_mel_basis() + self.inv_mel_basis = np.linalg.pinv(self._build_mel_basis()) + # setup scaler + if stats_path and signal_norm: + mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path) + self.setup_scaler(mel_mean, mel_std, linear_mean, linear_std) + self.signal_norm = True + self.max_norm = None + self.clip_norm = None + self.symmetric_norm = None + + @staticmethod + def init_from_config(config: "Coqpit", verbose=True): + if "audio" in config: + return AudioProcessor(verbose=verbose, **config.audio) + return AudioProcessor(verbose=verbose, **config) + + ### setting up the parameters ### + def _build_mel_basis( + self, + ) -> np.ndarray: + """Build melspectrogram basis. + + Returns: + np.ndarray: melspectrogram basis. 
+ """ + if self.mel_fmax is not None: + assert self.mel_fmax <= self.sample_rate // 2 + return librosa.filters.mel( + self.sample_rate, self.fft_size, n_mels=self.num_mels, fmin=self.mel_fmin, fmax=self.mel_fmax + ) + + def _stft_parameters( + self, + ) -> Tuple[int, int]: + """Compute the real STFT parameters from the time values. + + Returns: + Tuple[int, int]: hop length and window length for STFT. + """ + factor = self.frame_length_ms / self.frame_shift_ms + assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms" + hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate) + win_length = int(hop_length * factor) + return hop_length, win_length + + ### normalization ### + def normalize(self, S: np.ndarray) -> np.ndarray: + """Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]` + + Args: + S (np.ndarray): Spectrogram to normalize. + + Raises: + RuntimeError: Mean and variance is computed from incompatible parameters. + + Returns: + np.ndarray: Normalized spectrogram. + """ + # pylint: disable=no-else-return + S = S.copy() + if self.signal_norm: + # mean-var scaling + if hasattr(self, "mel_scaler"): + if S.shape[0] == self.num_mels: + return self.mel_scaler.transform(S.T).T + elif S.shape[0] == self.fft_size / 2: + return self.linear_scaler.transform(S.T).T + else: + raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.") + # range normalization + S -= self.ref_level_db # discard certain range of DB assuming it is air noise + S_norm = (S - self.min_level_db) / (-self.min_level_db) + if self.symmetric_norm: + S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm + if self.clip_norm: + S_norm = np.clip( + S_norm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type + ) + return S_norm + else: + S_norm = self.max_norm * S_norm + if self.clip_norm: + S_norm = np.clip(S_norm, 0, self.max_norm) + return S_norm + else: + return S + + def denormalize(self, S: np.ndarray) -> np.ndarray: + """Denormalize spectrogram values. + + Args: + S (np.ndarray): Spectrogram to denormalize. + + Raises: + RuntimeError: Mean and variance are incompatible. + + Returns: + np.ndarray: Denormalized spectrogram. + """ + # pylint: disable=no-else-return + S_denorm = S.copy() + if self.signal_norm: + # mean-var scaling + if hasattr(self, "mel_scaler"): + if S_denorm.shape[0] == self.num_mels: + return self.mel_scaler.inverse_transform(S_denorm.T).T + elif S_denorm.shape[0] == self.fft_size / 2: + return self.linear_scaler.inverse_transform(S_denorm.T).T + else: + raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.") + if self.symmetric_norm: + if self.clip_norm: + S_denorm = np.clip( + S_denorm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type + ) + S_denorm = ((S_denorm + self.max_norm) * -self.min_level_db / (2 * self.max_norm)) + self.min_level_db + return S_denorm + self.ref_level_db + else: + if self.clip_norm: + S_denorm = np.clip(S_denorm, 0, self.max_norm) + S_denorm = (S_denorm * -self.min_level_db / self.max_norm) + self.min_level_db + return S_denorm + self.ref_level_db + else: + return S_denorm + + ### Mean-STD scaling ### + def load_stats(self, stats_path: str) -> Tuple[np.array, np.array, np.array, np.array, Dict]: + """Loading mean and variance statistics from a `npy` file. 
+ + Args: + stats_path (str): Path to the `npy` file containing + + Returns: + Tuple[np.array, np.array, np.array, np.array, Dict]: loaded statistics and the config used to + compute them. + """ + stats = np.load(stats_path, allow_pickle=True).item() # pylint: disable=unexpected-keyword-arg + mel_mean = stats["mel_mean"] + mel_std = stats["mel_std"] + linear_mean = stats["linear_mean"] + linear_std = stats["linear_std"] + stats_config = stats["audio_config"] + # check all audio parameters used for computing stats + skip_parameters = ["griffin_lim_iters", "stats_path", "do_trim_silence", "ref_level_db", "power"] + for key in stats_config.keys(): + if key in skip_parameters: + continue + if key not in ["sample_rate", "trim_db"]: + assert ( + stats_config[key] == self.__dict__[key] + ), f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}" + return mel_mean, mel_std, linear_mean, linear_std, stats_config + + # pylint: disable=attribute-defined-outside-init + def setup_scaler( + self, mel_mean: np.ndarray, mel_std: np.ndarray, linear_mean: np.ndarray, linear_std: np.ndarray + ) -> None: + """Initialize scaler objects used in mean-std normalization. + + Args: + mel_mean (np.ndarray): Mean for melspectrograms. + mel_std (np.ndarray): STD for melspectrograms. + linear_mean (np.ndarray): Mean for full scale spectrograms. + linear_std (np.ndarray): STD for full scale spectrograms. + """ + self.mel_scaler = StandardScaler() + self.mel_scaler.set_stats(mel_mean, mel_std) + self.linear_scaler = StandardScaler() + self.linear_scaler.set_stats(linear_mean, linear_std) + + ### DB and AMP conversion ### + # pylint: disable=no-self-use + def _amp_to_db(self, x: np.ndarray) -> np.ndarray: + """Convert amplitude values to decibels. + + Args: + x (np.ndarray): Amplitude spectrogram. + + Returns: + np.ndarray: Decibels spectrogram. + """ + return self.spec_gain * _log(np.maximum(1e-5, x), self.base) + + # pylint: disable=no-self-use + def _db_to_amp(self, x: np.ndarray) -> np.ndarray: + """Convert decibels spectrogram to amplitude spectrogram. + + Args: + x (np.ndarray): Decibels spectrogram. + + Returns: + np.ndarray: Amplitude spectrogram. + """ + return _exp(x / self.spec_gain, self.base) + + ### Preemphasis ### + def apply_preemphasis(self, x: np.ndarray) -> np.ndarray: + """Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values. + + Args: + x (np.ndarray): Audio signal. + + Raises: + RuntimeError: Preemphasis coeff is set to 0. + + Returns: + np.ndarray: Decorrelated audio signal. + """ + if self.preemphasis == 0: + raise RuntimeError(" [!] Preemphasis is set 0.0.") + return scipy.signal.lfilter([1, -self.preemphasis], [1], x) + + def apply_inv_preemphasis(self, x: np.ndarray) -> np.ndarray: + """Reverse pre-emphasis.""" + if self.preemphasis == 0: + raise RuntimeError(" [!] Preemphasis is set 0.0.") + return scipy.signal.lfilter([1], [1, -self.preemphasis], x) + + ### SPECTROGRAMs ### + def _linear_to_mel(self, spectrogram: np.ndarray) -> np.ndarray: + """Project a full scale spectrogram to a melspectrogram. + + Args: + spectrogram (np.ndarray): Full scale spectrogram. 
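The dB helpers above form an exact pair for amplitudes at or above the `1e-5` floor; with the default `spec_gain=20` and `log_func="np.log10"` this is the familiar `x_db = 20 * log10(x)`. A tiny check, again with the illustrative `ap`:

import numpy as np

amp = np.array([1e-3, 0.1, 0.5, 1.0])
db = ap._amp_to_db(amp)                    # 1.0 -> 0 dB, 0.1 -> -20 dB, ...
assert np.allclose(ap._db_to_amp(db), amp)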
+
+        Returns:
+            np.ndarray: Melspectrogram.
+        """
+        return np.dot(self.mel_basis, spectrogram)
+
+    def _mel_to_linear(self, mel_spec: np.ndarray) -> np.ndarray:
+        """Convert a melspectrogram to a full scale spectrogram."""
+        return np.maximum(1e-10, np.dot(self.inv_mel_basis, mel_spec))
+
+    def spectrogram(self, y: np.ndarray) -> np.ndarray:
+        """Compute a spectrogram from a waveform.
+
+        Args:
+            y (np.ndarray): Waveform.
+
+        Returns:
+            np.ndarray: Spectrogram.
+        """
+        if self.preemphasis != 0:
+            D = self._stft(self.apply_preemphasis(y))
+        else:
+            D = self._stft(y)
+        if self.do_amp_to_db_linear:
+            S = self._amp_to_db(np.abs(D))
+        else:
+            S = np.abs(D)
+        return self.normalize(S).astype(np.float32)
+
+    def melspectrogram(self, y: np.ndarray) -> np.ndarray:
+        """Compute a melspectrogram from a waveform."""
+        if self.preemphasis != 0:
+            D = self._stft(self.apply_preemphasis(y))
+        else:
+            D = self._stft(y)
+        if self.do_amp_to_db_mel:
+            S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
+        else:
+            S = self._linear_to_mel(np.abs(D))
+        return self.normalize(S).astype(np.float32)
+
+    def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray:
+        """Convert a spectrogram to a waveform using the Griffin-Lim vocoder."""
+        S = self.denormalize(spectrogram)
+        S = self._db_to_amp(S)
+        # Reconstruct phase
+        if self.preemphasis != 0:
+            return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
+        return self._griffin_lim(S**self.power)
+
+    def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray:
+        """Convert a melspectrogram to a waveform using the Griffin-Lim vocoder."""
+        D = self.denormalize(mel_spectrogram)
+        S = self._db_to_amp(D)
+        S = self._mel_to_linear(S)  # Convert back to linear
+        if self.preemphasis != 0:
+            return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
+        return self._griffin_lim(S**self.power)
+
+    def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray:
+        """Convert a full scale linear spectrogram output of a network to a melspectrogram.
+
+        Args:
+            linear_spec (np.ndarray): Normalized full scale linear spectrogram.
+
+        Returns:
+            np.ndarray: Normalized melspectrogram.
+        """
+        S = self.denormalize(linear_spec)
+        S = self._db_to_amp(S)
+        S = self._linear_to_mel(np.abs(S))
+        S = self._amp_to_db(S)
+        mel = self.normalize(S)
+        return mel
+
+    ### STFT and ISTFT ###
+    def _stft(self, y: np.ndarray) -> np.ndarray:
+        """Librosa STFT wrapper.
+
+        Args:
+            y (np.ndarray): Audio signal.
+
+        Returns:
+            np.ndarray: Complex number array.
+        """
+        return librosa.stft(
+            y=y,
+            n_fft=self.fft_size,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            pad_mode=self.stft_pad_mode,
+            window="hann",
+            center=True,
+        )
+
+    def _istft(self, y: np.ndarray) -> np.ndarray:
+        """Librosa iSTFT wrapper."""
+        return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length)
+
+    def _griffin_lim(self, S):
+        """Reconstruct a waveform from a magnitude spectrogram by iteratively estimating the phase."""
+        angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
+        S_complex = np.abs(S).astype(np.complex128)  # `np.complex` is removed in recent NumPy releases
+        y = self._istft(S_complex * angles)
+        if not np.isfinite(y).all():
+            print(" [!] Waveform is not finite everywhere. Skipping the GL.")
+            return np.array([0.0])
+        for _ in range(self.griffin_lim_iters):
+            angles = np.exp(1j * np.angle(self._stft(y)))
+            y = self._istft(S_complex * angles)
+        return y
+
+    def compute_stft_paddings(self, x, pad_sides=1):
+        """Compute paddings used by Librosa's STFT.
+        Compute right padding (final frame) or both sides padding
+        (first and final frames)."""
+        assert pad_sides in (1, 2)
+        pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0]
+        if pad_sides == 1:
+            return 0, pad
+        return pad // 2, pad // 2 + pad % 2
+
+    def compute_f0(self, x: np.ndarray) -> np.ndarray:
+        """Compute pitch (f0) of a waveform using the same parameters used for computing the melspectrogram.
+
+        Args:
+            x (np.ndarray): Waveform.
+
+        Returns:
+            np.ndarray: Pitch.
+
+        Examples:
+            >>> WAV_FILE = filename = librosa.util.example_audio_file()
+            >>> from TTS.config import BaseAudioConfig
+            >>> from TTS.utils.audio import AudioProcessor
+            >>> conf = BaseAudioConfig(pitch_fmax=8000)
+            >>> ap = AudioProcessor(**conf)
+            >>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050]
+            >>> pitch = ap.compute_f0(wav)
+        """
+        assert self.pitch_fmax is not None, " [!] Set `pitch_fmax` before calling `compute_f0`."
+        # align F0 length to the spectrogram length
+        if len(x) % self.hop_length == 0:
+            x = np.pad(x, (0, self.hop_length // 2), mode="reflect")
+
+        f0, t = pw.dio(
+            x.astype(np.double),
+            fs=self.sample_rate,
+            f0_ceil=self.pitch_fmax,
+            frame_period=1000 * self.hop_length / self.sample_rate,
+        )
+        f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate)
+        return f0
+
+    ### Audio Processing ###
+    def find_endpoint(self, wav: np.ndarray, min_silence_sec=0.8) -> int:
+        """Find the last point without silence at the end of an audio signal.
+
+        The silence threshold is derived from `self.trim_db`.
+
+        Args:
+            wav (np.ndarray): Audio signal.
+            min_silence_sec (float, optional): Ignore silences that are shorter than this in secs. Defaults to 0.8.
+
+        Returns:
+            int: Last point without silence.
+        """
+        window_length = int(self.sample_rate * min_silence_sec)
+        hop_length = int(window_length / 4)
+        threshold = self._db_to_amp(-self.trim_db)
+        for x in range(hop_length, len(wav) - window_length, hop_length):
+            if np.max(wav[x : x + window_length]) < threshold:
+                return x + hop_length
+        return len(wav)
+
+    def trim_silence(self, wav):
+        """Trim silent parts with a threshold and a 0.01 sec margin."""
+        margin = int(self.sample_rate * 0.01)
+        wav = wav[margin:-margin]
+        return librosa.effects.trim(wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[
+            0
+        ]
+
+    @staticmethod
+    def sound_norm(x: np.ndarray) -> np.ndarray:
+        """Normalize the volume of an audio signal.
+
+        Args:
+            x (np.ndarray): Raw waveform.
+
+        Returns:
+            np.ndarray: Volume normalized waveform.
+        """
+        return x / abs(x).max() * 0.95
+
+    @staticmethod
+    def _rms_norm(wav, db_level=-27):
+        r = 10 ** (db_level / 20)
+        a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2))
+        return wav * a
+
+    def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray:
+        """Normalize the volume based on RMS of the signal.
+
+        Args:
+            x (np.ndarray): Raw waveform.
+            db_level (float, optional): Target RMS level in dB. Defaults to `self.db_level`.
+
+        Returns:
+            np.ndarray: RMS normalized waveform.
+        """
+        if db_level is None:
+            db_level = self.db_level
+        assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0"
+        wav = self._rms_norm(x, db_level)
+        return wav
+
+    ### save and load ###
+    def load_wav(self, filename: str, sr: int = None) -> np.ndarray:
+        """Read a wav file using Librosa and optionally resample, silence trim, volume normalize.
+
+        Resampling slows down loading the file significantly. Therefore, it is recommended to resample the files beforehand.
+
+        Args:
+            filename (str): Path to the wav file.
+            sr (int, optional): Sampling rate for resampling. Defaults to None.
+
+        Returns:
+            np.ndarray: Loaded waveform.
+        """
+        if self.resample:
+            # loading with resampling. It is significantly slower.
+            x, sr = librosa.load(filename, sr=self.sample_rate)
+        elif sr is None:
+            # soundfile is faster than librosa for loading files
+            x, sr = sf.read(filename)
+            assert self.sample_rate == sr, "%s vs %s" % (self.sample_rate, sr)
+        else:
+            x, sr = librosa.load(filename, sr=sr)
+        if self.do_trim_silence:
+            try:
+                x = self.trim_silence(x)
+            except ValueError:
+                print(f" [!] File cannot be trimmed for silence - {filename}")
+        if self.do_sound_norm:
+            x = self.sound_norm(x)
+        if self.do_rms_norm:
+            x = self.rms_volume_norm(x, self.db_level)
+        return x
+
+    def save_wav(self, wav: np.ndarray, path: str, sr: int = None) -> None:
+        """Save a waveform to a file using Scipy.
+
+        Args:
+            wav (np.ndarray): Waveform to save.
+            path (str): Path to an output file.
+            sr (int, optional): Sampling rate used for saving to the file. Defaults to None.
+        """
+        if self.do_rms_norm:
+            wav_norm = self.rms_volume_norm(wav, self.db_level) * 32767
+        else:
+            wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
+
+        scipy.io.wavfile.write(path, sr if sr else self.sample_rate, wav_norm.astype(np.int16))
+
+    def get_duration(self, filename: str) -> float:
+        """Get the duration of a wav file using Librosa.
+
+        Args:
+            filename (str): Path to the wav file.
+
+        Returns:
+            float: Duration of the file in seconds.
+        """
+        return librosa.get_duration(filename=filename)
+
+    @staticmethod
+    def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray:
+        """Mu-law compand a waveform and quantize it to `2**qc` levels."""
+        mu = 2**qc - 1
+        # wav_abs = np.minimum(np.abs(wav), 1.0)
+        signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
+        # Quantize signal to the specified number of levels.
+        signal = (signal + 1) / 2 * mu + 0.5
+        return np.floor(
+            signal,
+        )
+
+    @staticmethod
+    def mulaw_decode(wav, qc):
+        """Recover a waveform from mu-law companded values."""
+        mu = 2**qc - 1
+        x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
+        return x
+
+    @staticmethod
+    def encode_16bits(x):
+        return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
+
+    @staticmethod
+    def quantize(x: np.ndarray, bits: int) -> np.ndarray:
+        """Quantize a waveform to a given number of bits.
+
+        Args:
+            x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`.
+            bits (int): Number of quantization bits.
+
+        Returns:
+            np.ndarray: Quantized waveform.
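A small sketch of the companding helpers above. `mulaw_encode` returns integer codes in `[0, 2**qc - 1]`; `dequantize` (defined just below) maps them back to `[-1, 1]` before the mu-law expansion, so the round trip recovers the input up to quantization error. `qc=8` is an arbitrary example value:

import numpy as np

from TTS.utils.audio import AudioProcessor

x = np.linspace(-1.0, 1.0, 11)
codes = AudioProcessor.mulaw_encode(x, qc=8)  # integers in [0, 255]
x_hat = AudioProcessor.mulaw_decode(AudioProcessor.dequantize(codes, 8), qc=8)
assert np.max(np.abs(x - x_hat)) < 0.02       # coarse 8-bit companding error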
+ """ + return (x + 1.0) * (2**bits - 1) / 2 + + @staticmethod + def dequantize(x, bits): + """Dequantize a waveform from the given number of bits.""" + return 2 * x / (2**bits - 1) - 1 + + +def _log(x, base): + if base == 10: + return np.log10(x) + return np.log(x) + + +def _exp(x, base): + if base == 10: + return np.power(10, x) + return np.exp(x) diff --git a/TTS/utils/callbacks.py b/TTS/utils/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..511d215c656f1ce3ed31484963db64fae4dc77d4 --- /dev/null +++ b/TTS/utils/callbacks.py @@ -0,0 +1,105 @@ +class TrainerCallback: + @staticmethod + def on_init_start(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_init_start"): + trainer.model.module.on_init_start(trainer) + else: + if hasattr(trainer.model, "on_init_start"): + trainer.model.on_init_start(trainer) + + if hasattr(trainer.criterion, "on_init_start"): + trainer.criterion.on_init_start(trainer) + + if hasattr(trainer.optimizer, "on_init_start"): + trainer.optimizer.on_init_start(trainer) + + @staticmethod + def on_init_end(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_init_end"): + trainer.model.module.on_init_end(trainer) + else: + if hasattr(trainer.model, "on_init_end"): + trainer.model.on_init_end(trainer) + + if hasattr(trainer.criterion, "on_init_end"): + trainer.criterion.on_init_end(trainer) + + if hasattr(trainer.optimizer, "on_init_end"): + trainer.optimizer.on_init_end(trainer) + + @staticmethod + def on_epoch_start(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_epoch_start"): + trainer.model.module.on_epoch_start(trainer) + else: + if hasattr(trainer.model, "on_epoch_start"): + trainer.model.on_epoch_start(trainer) + + if hasattr(trainer.criterion, "on_epoch_start"): + trainer.criterion.on_epoch_start(trainer) + + if hasattr(trainer.optimizer, "on_epoch_start"): + trainer.optimizer.on_epoch_start(trainer) + + @staticmethod + def on_epoch_end(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_epoch_end"): + trainer.model.module.on_epoch_end(trainer) + else: + if hasattr(trainer.model, "on_epoch_end"): + trainer.model.on_epoch_end(trainer) + + if hasattr(trainer.criterion, "on_epoch_end"): + trainer.criterion.on_epoch_end(trainer) + + if hasattr(trainer.optimizer, "on_epoch_end"): + trainer.optimizer.on_epoch_end(trainer) + + @staticmethod + def on_train_step_start(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_train_step_start"): + trainer.model.module.on_train_step_start(trainer) + else: + if hasattr(trainer.model, "on_train_step_start"): + trainer.model.on_train_step_start(trainer) + + if hasattr(trainer.criterion, "on_train_step_start"): + trainer.criterion.on_train_step_start(trainer) + + if hasattr(trainer.optimizer, "on_train_step_start"): + trainer.optimizer.on_train_step_start(trainer) + + @staticmethod + def on_train_step_end(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_train_step_end"): + trainer.model.module.on_train_step_end(trainer) + else: + if hasattr(trainer.model, "on_train_step_end"): + trainer.model.on_train_step_end(trainer) + + if hasattr(trainer.criterion, "on_train_step_end"): + trainer.criterion.on_train_step_end(trainer) + + if hasattr(trainer.optimizer, "on_train_step_end"): + trainer.optimizer.on_train_step_end(trainer) + + 
@staticmethod + def on_keyboard_interrupt(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_keyboard_interrupt"): + trainer.model.module.on_keyboard_interrupt(trainer) + else: + if hasattr(trainer.model, "on_keyboard_interrupt"): + trainer.model.on_keyboard_interrupt(trainer) + + if hasattr(trainer.criterion, "on_keyboard_interrupt"): + trainer.criterion.on_keyboard_interrupt(trainer) + + if hasattr(trainer.optimizer, "on_keyboard_interrupt"): + trainer.optimizer.on_keyboard_interrupt(trainer) diff --git a/TTS/utils/capacitron_optimizer.py b/TTS/utils/capacitron_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..c9f075afac86d425d0355a6d678a9c1ca3f0062e --- /dev/null +++ b/TTS/utils/capacitron_optimizer.py @@ -0,0 +1,65 @@ +from typing import Generator + +from trainer.trainer_utils import get_optimizer + + +class CapacitronOptimizer: + """Double optimizer class for the Capacitron model.""" + + def __init__(self, config: dict, model_params: Generator) -> None: + self.primary_params, self.secondary_params = self.split_model_parameters(model_params) + + optimizer_names = list(config.optimizer_params.keys()) + optimizer_parameters = list(config.optimizer_params.values()) + + self.primary_optimizer = get_optimizer( + optimizer_names[0], + optimizer_parameters[0], + config.lr, + parameters=self.primary_params, + ) + + self.secondary_optimizer = get_optimizer( + optimizer_names[1], + self.extract_optimizer_parameters(optimizer_parameters[1]), + optimizer_parameters[1]["lr"], + parameters=self.secondary_params, + ) + + self.param_groups = self.primary_optimizer.param_groups + + def first_step(self): + self.secondary_optimizer.step() + self.secondary_optimizer.zero_grad() + self.primary_optimizer.zero_grad() + + def step(self): + self.primary_optimizer.step() + + def zero_grad(self): + self.primary_optimizer.zero_grad() + self.secondary_optimizer.zero_grad() + + def load_state_dict(self, state_dict): + self.primary_optimizer.load_state_dict(state_dict[0]) + self.secondary_optimizer.load_state_dict(state_dict[1]) + + def state_dict(self): + return [self.primary_optimizer.state_dict(), self.secondary_optimizer.state_dict()] + + @staticmethod + def split_model_parameters(model_params: Generator) -> list: + primary_params = [] + secondary_params = [] + for name, param in model_params: + if param.requires_grad: + if name == "capacitron_vae_layer.beta": + secondary_params.append(param) + else: + primary_params.append(param) + return [iter(primary_params), iter(secondary_params)] + + @staticmethod + def extract_optimizer_parameters(params: dict) -> dict: + """Extract parameters that are not the learning rate""" + return {k: v for k, v in params.items() if k != "lr"} diff --git a/TTS/utils/distribute.py b/TTS/utils/distribute.py new file mode 100644 index 0000000000000000000000000000000000000000..a51ef7661ece97c87c165ad1aba4c9d9700379dc --- /dev/null +++ b/TTS/utils/distribute.py @@ -0,0 +1,20 @@ +# edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py +import torch +import torch.distributed as dist + + +def reduce_tensor(tensor, num_gpus): + rt = tensor.clone() + dist.all_reduce(rt, op=dist.reduce_op.SUM) + rt /= num_gpus + return rt + + +def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url): + assert torch.cuda.is_available(), "Distributed mode requires CUDA." + + # Set cuda device so everything is done on the right GPU. 
+ torch.cuda.set_device(rank % torch.cuda.device_count()) + + # Initialize distributed communication + dist.init_process_group(dist_backend, init_method=dist_url, world_size=num_gpus, rank=rank, group_name=group_name) diff --git a/TTS/utils/download.py b/TTS/utils/download.py new file mode 100644 index 0000000000000000000000000000000000000000..de9b31a7a87071a964cd171b2075b03a7a433a76 --- /dev/null +++ b/TTS/utils/download.py @@ -0,0 +1,207 @@ +# Adapted from https://github.com/pytorch/audio/ + +import hashlib +import logging +import os +import tarfile +import urllib +import urllib.request +import zipfile +from os.path import expanduser +from typing import Any, Iterable, List, Optional + +from torch.utils.model_zoo import tqdm + + +def stream_url( + url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True +) -> Iterable: + """Stream url by chunk + + Args: + url (str): Url. + start_byte (int or None, optional): Start streaming at that point (Default: ``None``). + block_size (int, optional): Size of chunks to stream (Default: ``32 * 1024``). + progress_bar (bool, optional): Display a progress bar (Default: ``True``). + """ + + # If we already have the whole file, there is no need to download it again + req = urllib.request.Request(url, method="HEAD") + with urllib.request.urlopen(req) as response: + url_size = int(response.info().get("Content-Length", -1)) + if url_size == start_byte: + return + + req = urllib.request.Request(url) + if start_byte: + req.headers["Range"] = "bytes={}-".format(start_byte) + + with urllib.request.urlopen(req) as upointer, tqdm( + unit="B", + unit_scale=True, + unit_divisor=1024, + total=url_size, + disable=not progress_bar, + ) as pbar: + + num_bytes = 0 + while True: + chunk = upointer.read(block_size) + if not chunk: + break + yield chunk + num_bytes += len(chunk) + pbar.update(len(chunk)) + + +def download_url( + url: str, + download_folder: str, + filename: Optional[str] = None, + hash_value: Optional[str] = None, + hash_type: str = "sha256", + progress_bar: bool = True, + resume: bool = False, +) -> None: + """Download file to disk. + + Args: + url (str): Url. + download_folder (str): Folder to download file. + filename (str or None, optional): Name of downloaded file. If None, it is inferred from the url + (Default: ``None``). + hash_value (str or None, optional): Hash for url (Default: ``None``). + hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``). + progress_bar (bool, optional): Display a progress bar (Default: ``True``). + resume (bool, optional): Enable resuming download (Default: ``False``). + """ + + req = urllib.request.Request(url, method="HEAD") + req_info = urllib.request.urlopen(req).info() # pylint: disable=consider-using-with + + # Detect filename + filename = filename or req_info.get_filename() or os.path.basename(url) + filepath = os.path.join(download_folder, filename) + if resume and os.path.exists(filepath): + mode = "ab" + local_size: Optional[int] = os.path.getsize(filepath) + + elif not resume and os.path.exists(filepath): + raise RuntimeError("{} already exists. Delete the file manually and retry.".format(filepath)) + else: + mode = "wb" + local_size = None + + if hash_value and local_size == int(req_info.get("Content-Length", -1)): + with open(filepath, "rb") as file_obj: + if validate_file(file_obj, hash_value, hash_type): + return + raise RuntimeError("The hash of {} does not match. 
Delete the file manually and retry.".format(filepath)) + + with open(filepath, mode) as fpointer: + for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar): + fpointer.write(chunk) + + with open(filepath, "rb") as file_obj: + if hash_value and not validate_file(file_obj, hash_value, hash_type): + raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath)) + + +def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> bool: + """Validate a given file object with its hash. + + Args: + file_obj: File object to read from. + hash_value (str): Hash for url. + hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``). + + Returns: + bool: return True if its a valid file, else False. + """ + + if hash_type == "sha256": + hash_func = hashlib.sha256() + elif hash_type == "md5": + hash_func = hashlib.md5() + else: + raise ValueError + + while True: + # Read by chunk to avoid filling memory + chunk = file_obj.read(1024**2) + if not chunk: + break + hash_func.update(chunk) + + return hash_func.hexdigest() == hash_value + + +def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]: + """Extract archive. + Args: + from_path (str): the path of the archive. + to_path (str or None, optional): the root path of the extraced files (directory of from_path) + (Default: ``None``) + overwrite (bool, optional): overwrite existing files (Default: ``False``) + + Returns: + list: List of paths to extracted files even if not overwritten. + """ + + if to_path is None: + to_path = os.path.dirname(from_path) + + try: + with tarfile.open(from_path, "r") as tar: + logging.info("Opened tar file %s.", from_path) + files = [] + for file_ in tar: # type: Any + file_path = os.path.join(to_path, file_.name) + if file_.isfile(): + files.append(file_path) + if os.path.exists(file_path): + logging.info("%s already extracted.", file_path) + if not overwrite: + continue + tar.extract(file_, to_path) + return files + except tarfile.ReadError: + pass + + try: + with zipfile.ZipFile(from_path, "r") as zfile: + logging.info("Opened zip file %s.", from_path) + files = zfile.namelist() + for file_ in files: + file_path = os.path.join(to_path, file_) + if os.path.exists(file_path): + logging.info("%s already extracted.", file_path) + if not overwrite: + continue + zfile.extract(file_, to_path) + return files + except zipfile.BadZipFile: + pass + + raise NotImplementedError(" > [!] only supports tar.gz, tgz, and zip achives.") + + +def download_kaggle_dataset(dataset_path: str, dataset_name: str, output_path: str): + """Download dataset from kaggle. + Args: + dataset_path (str): + This the kaggle link to the dataset. for example vctk is 'mfekadu/english-multispeaker-corpus-for-voice-cloning' + dataset_name (str): Name of the folder the dataset will be saved in. + output_path (str): Path of the location you want the dataset folder to be saved to. + """ + data_path = os.path.join(output_path, dataset_name) + try: + import kaggle # pylint: disable=import-outside-toplevel + + kaggle.api.authenticate() + print(f"""\nDownloading {dataset_name}...""") + kaggle.api.dataset_download_files(dataset_path, path=data_path, unzip=True) + except OSError: + print( + f"""[!] 
in order to download kaggle datasets, you need to have a kaggle api token stored in your {os.path.join(expanduser('~'), '.kaggle/kaggle.json')}""" + ) diff --git a/TTS/utils/downloaders.py b/TTS/utils/downloaders.py new file mode 100644 index 0000000000000000000000000000000000000000..104dc7b94e17b1d7f828103d2396d6c5115b628a --- /dev/null +++ b/TTS/utils/downloaders.py @@ -0,0 +1,126 @@ +import os +from typing import Optional + +from TTS.utils.download import download_kaggle_dataset, download_url, extract_archive + + +def download_ljspeech(path: str): + """Download and extract LJSpeech dataset + + Args: + path (str): path to the directory where the dataset will be stored. + """ + os.makedirs(path, exist_ok=True) + url = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2" + download_url(url, path) + basename = os.path.basename(url) + archive = os.path.join(path, basename) + print(" > Extracting archive file...") + extract_archive(archive) + + +def download_vctk(path: str, use_kaggle: Optional[bool] = False): + """Download and extract VCTK dataset. + + Args: + path (str): path to the directory where the dataset will be stored. + + use_kaggle (bool, optional): Downloads vctk dataset from kaggle. Is generally faster. Defaults to False. + """ + if use_kaggle: + download_kaggle_dataset("mfekadu/english-multispeaker-corpus-for-voice-cloning", "VCTK", path) + else: + os.makedirs(path, exist_ok=True) + url = "https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip" + download_url(url, path) + basename = os.path.basename(url) + archive = os.path.join(path, basename) + print(" > Extracting archive file...") + extract_archive(archive) + + +def download_tweb(path: str): + """Download and extract Tweb dataset + + Args: + path (str): Path to the directory where the dataset will be stored. + """ + download_kaggle_dataset("bryanpark/the-world-english-bible-speech-dataset", "TWEB", path) + + +def download_libri_tts(path: str, subset: Optional[str] = "all"): + """Download and extract libri tts dataset. + + Args: + path (str): Path to the directory where the dataset will be stored. + + subset (str, optional): Name of the subset to download. If you only want to download a certain + portion specify it here. Defaults to 'all'. + """ + + subset_dict = { + "libri-tts-clean-100": "http://www.openslr.org/resources/60/train-clean-100.tar.gz", + "libri-tts-clean-360": "http://www.openslr.org/resources/60/train-clean-360.tar.gz", + "libri-tts-other-500": "http://www.openslr.org/resources/60/train-other-500.tar.gz", + "libri-tts-dev-clean": "http://www.openslr.org/resources/60/dev-clean.tar.gz", + "libri-tts-dev-other": "http://www.openslr.org/resources/60/dev-other.tar.gz", + "libri-tts-test-clean": "http://www.openslr.org/resources/60/test-clean.tar.gz", + "libri-tts-test-other": "http://www.openslr.org/resources/60/test-other.tar.gz", + } + + os.makedirs(path, exist_ok=True) + if subset == "all": + for sub, val in subset_dict.items(): + print(f" > Downloading {sub}...") + download_url(val, path) + basename = os.path.basename(val) + archive = os.path.join(path, basename) + print(" > Extracting archive file...") + extract_archive(archive) + print(" > All subsets downloaded") + else: + url = subset_dict[subset] + download_url(url, path) + basename = os.path.basename(url) + archive = os.path.join(path, basename) + print(" > Extracting archive file...") + extract_archive(archive) + + +def download_thorsten_de(path: str): + """Download and extract Thorsten german male voice dataset. 
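A typical call sequence for the helpers in this module; the target directories are placeholders and the downloads are several gigabytes. LibriTTS subsets are addressed by the keys of `subset_dict` above:

from TTS.utils.downloaders import download_libri_tts, download_ljspeech, download_vctk

download_ljspeech("/data/datasets/LJSpeech")
download_vctk("/data/datasets/VCTK", use_kaggle=False)
download_libri_tts("/data/datasets/LibriTTS", subset="libri-tts-clean-100")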
+ + Args: + path (str): Path to the directory where the dataset will be stored. + """ + os.makedirs(path, exist_ok=True) + url = "https://www.openslr.org/resources/95/thorsten-de_v02.tgz" + download_url(url, path) + basename = os.path.basename(url) + archive = os.path.join(path, basename) + print(" > Extracting archive file...") + extract_archive(archive) + + +def download_mailabs(path: str, language: str = "english"): + """Download and extract Mailabs dataset. + + Args: + path (str): Path to the directory where the dataset will be stored. + + language (str): Language subset to download. Defaults to english. + """ + language_dict = { + "english": "https://data.solak.de/data/Training/stt_tts/en_US.tgz", + "german": "https://data.solak.de/data/Training/stt_tts/de_DE.tgz", + "french": "https://data.solak.de/data/Training/stt_tts/fr_FR.tgz", + "italian": "https://data.solak.de/data/Training/stt_tts/it_IT.tgz", + "spanish": "https://data.solak.de/data/Training/stt_tts/es_ES.tgz", + } + os.makedirs(path, exist_ok=True) + url = language_dict[language] + download_url(url, path) + basename = os.path.basename(url) + archive = os.path.join(path, basename) + print(" > Extracting archive file...") + extract_archive(archive) diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b685210c1179b8adfc1ed57c9a5089aff07f52ae --- /dev/null +++ b/TTS/utils/generic_utils.py @@ -0,0 +1,211 @@ +# -*- coding: utf-8 -*- +import datetime +import importlib +import os +import re +import subprocess +import sys +from pathlib import Path +from typing import Dict + +import fsspec +import torch + + +def to_cuda(x: torch.Tensor) -> torch.Tensor: + if x is None: + return None + if torch.is_tensor(x): + x = x.contiguous() + if torch.cuda.is_available(): + x = x.cuda(non_blocking=True) + return x + + +def get_cuda(): + use_cuda = torch.cuda.is_available() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + return use_cuda, device + + +def get_git_branch(): + try: + out = subprocess.check_output(["git", "branch"]).decode("utf8") + current = next(line for line in out.split("\n") if line.startswith("*")) + current.replace("* ", "") + except subprocess.CalledProcessError: + current = "inside_docker" + except FileNotFoundError: + current = "unknown" + return current + + +def get_commit_hash(): + """https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script""" + # try: + # subprocess.check_output(['git', 'diff-index', '--quiet', + # 'HEAD']) # Verify client is clean + # except: + # raise RuntimeError( + # " !! 
Commit before training to get the commit hash.") + try: + commit = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode().strip() + # Not copying .git folder into docker container + except (subprocess.CalledProcessError, FileNotFoundError): + commit = "0000000" + return commit + + +def get_experiment_folder_path(root_path, model_name): + """Get an experiment folder path with the current date and time""" + date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p") + commit_hash = get_commit_hash() + output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash) + return output_folder + + +def remove_experiment_folder(experiment_path): + """Check folder if there is a checkpoint, otherwise remove the folder""" + fs = fsspec.get_mapper(experiment_path).fs + checkpoint_files = fs.glob(experiment_path + "/*.pth") + if not checkpoint_files: + if fs.exists(experiment_path): + fs.rm(experiment_path, recursive=True) + print(" ! Run is removed from {}".format(experiment_path)) + else: + print(" ! Run is kept in {}".format(experiment_path)) + + +def count_parameters(model): + r"""Count number of trainable parameters in a network""" + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def to_camel(text): + text = text.capitalize() + text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) + text = text.replace("Tts", "TTS") + return text + + +def find_module(module_path: str, module_name: str) -> object: + module_name = module_name.lower() + module = importlib.import_module(module_path + "." + module_name) + class_name = to_camel(module_name) + return getattr(module, class_name) + + +def import_class(module_path: str) -> object: + """Import a class from a module path. + + Args: + module_path (str): The module path of the class. + + Returns: + object: The imported class. + """ + class_name = module_path.split(".")[-1] + module_path = ".".join(module_path.split(".")[:-1]) + module = importlib.import_module(module_path) + return getattr(module, class_name) + + +def get_import_path(obj: object) -> str: + """Get the import path of a class. + + Args: + obj (object): The class object. + + Returns: + str: The import path of the class. + """ + return ".".join([type(obj).__module__, type(obj).__name__]) + + +def get_user_data_dir(appname): + if sys.platform == "win32": + import winreg # pylint: disable=import-outside-toplevel + + key = winreg.OpenKey( + winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders" + ) + dir_, _ = winreg.QueryValueEx(key, "Local AppData") + ans = Path(dir_).resolve(strict=False) + elif sys.platform == "darwin": + ans = Path("~/Library/Application Support/").expanduser() + else: + ans = Path.home().joinpath(".local/share") + return ans.joinpath(appname) + + +def set_init_dict(model_dict, checkpoint_state, c): + # Partial initialization: if there is a mismatch with new and old layer, it is skipped. + for k, v in checkpoint_state.items(): + if k not in model_dict: + print(" | > Layer missing in the model definition: {}".format(k)) + # 1. filter out unnecessary keys + pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict} + # 2. filter out different size layers + pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()} + # 3. 
skip reinit layers + if c.has("reinit_layers") and c.reinit_layers is not None: + for reinit_layer_name in c.reinit_layers: + pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} + # 4. overwrite entries in the existing state dict + model_dict.update(pretrained_dict) + print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict))) + return model_dict + + +def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict: + """Format kwargs to hande auxilary inputs to models. + + Args: + def_args (Dict): A dictionary of argument names and their default values if not defined in `kwargs`. + kwargs (Dict): A `dict` or `kwargs` that includes auxilary inputs to the model. + + Returns: + Dict: arguments with formatted auxilary inputs. + """ + for name in def_args: + if name not in kwargs: + kwargs[def_args[name]] = None + return kwargs + + +class KeepAverage: + def __init__(self): + self.avg_values = {} + self.iters = {} + + def __getitem__(self, key): + return self.avg_values[key] + + def items(self): + return self.avg_values.items() + + def add_value(self, name, init_val=0, init_iter=0): + self.avg_values[name] = init_val + self.iters[name] = init_iter + + def update_value(self, name, value, weighted_avg=False): + if name not in self.avg_values: + # add value if not exist before + self.add_value(name, init_val=value) + else: + # else update existing value + if weighted_avg: + self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value + self.iters[name] += 1 + else: + self.avg_values[name] = self.avg_values[name] * self.iters[name] + value + self.iters[name] += 1 + self.avg_values[name] /= self.iters[name] + + def add_values(self, name_dict): + for key, value in name_dict.items(): + self.add_value(key, init_val=value) + + def update_values(self, value_dict): + for key, value in value_dict.items(): + self.update_value(key, value) diff --git a/TTS/utils/io.py b/TTS/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..0b32f77ab281073c399cc0aabe86670ff8f90969 --- /dev/null +++ b/TTS/utils/io.py @@ -0,0 +1,201 @@ +import datetime +import json +import os +import pickle as pickle_tts +import shutil +from typing import Any, Callable, Dict, Union + +import fsspec +import torch +from coqpit import Coqpit + + +class RenamingUnpickler(pickle_tts.Unpickler): + """Overload default pickler to solve module renaming problem""" + + def find_class(self, module, name): + return super().find_class(module.replace("mozilla_voice_tts", "TTS"), name) + + +class AttrDict(dict): + """A custom dict which converts dict keys + to class attributes""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__dict__ = self + + +def copy_model_files(config: Coqpit, out_path, new_fields=None): + """Copy config.json and other model files to training folder and add + new fields. + + Args: + config (Coqpit): Coqpit config defining the training run. + out_path (str): output path to copy the file. + new_fields (dict): new fileds to be added or edited + in the config file. + """ + copy_config_path = os.path.join(out_path, "config.json") + # add extra information fields + if new_fields: + config.update(new_fields, allow_new=True) + # TODO: Revert to config.save_json() once Coqpit supports arbitrary paths. 
+ with fsspec.open(copy_config_path, "w", encoding="utf8") as f: + json.dump(config.to_dict(), f, indent=4) + + # copy model stats file if available + if config.audio.stats_path is not None: + copy_stats_path = os.path.join(out_path, "scale_stats.npy") + filesystem = fsspec.get_mapper(copy_stats_path).fs + if not filesystem.exists(copy_stats_path): + with fsspec.open(config.audio.stats_path, "rb") as source_file: + with fsspec.open(copy_stats_path, "wb") as target_file: + shutil.copyfileobj(source_file, target_file) + + +def load_fsspec( + path: str, + map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None, + **kwargs, +) -> Any: + """Like torch.load but can load from other locations (e.g. s3:// , gs://). + + Args: + path: Any path or url supported by fsspec. + map_location: torch.device or str. + **kwargs: Keyword arguments forwarded to torch.load. + + Returns: + Object stored in path. + """ + with fsspec.open(path, "rb") as f: + return torch.load(f, map_location=map_location, **kwargs) + + +def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin + try: + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + except ModuleNotFoundError: + pickle_tts.Unpickler = RenamingUnpickler + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts) + model.load_state_dict(state["model"]) + if use_cuda: + model.cuda() + if eval: + model.eval() + return model, state + + +def save_fsspec(state: Any, path: str, **kwargs): + """Like torch.save but can save to other locations (e.g. s3:// , gs://). + + Args: + state: State object to save + path: Any path or url supported by fsspec. + **kwargs: Keyword arguments forwarded to torch.save. 
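A short usage sketch for the fsspec-backed I/O above. The paths are placeholders, `MyModel` stands for any `torch.nn.Module` whose weights match the checkpoint, and the checkpoint is assumed to have been written by `save_model` below (so it carries a `"model"` key):

import torch

from TTS.utils.io import load_checkpoint, load_fsspec, save_fsspec

state = load_fsspec("s3://my-bucket/checkpoint_10000.pth", map_location="cpu")  # any fsspec URL works
save_fsspec(state, "/tmp/checkpoint_10000.pth")  # mirror it locally

model = MyModel()  # placeholder: your model class
model, state = load_checkpoint(model, "/tmp/checkpoint_10000.pth", use_cuda=torch.cuda.is_available(), eval=True)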
+ """ + with fsspec.open(path, "wb") as f: + torch.save(state, f, **kwargs) + + +def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs): + if hasattr(model, "module"): + model_state = model.module.state_dict() + else: + model_state = model.state_dict() + if isinstance(optimizer, list): + optimizer_state = [optim.state_dict() for optim in optimizer] + elif optimizer.__class__.__name__ == "CapacitronOptimizer": + optimizer_state = [optimizer.primary_optimizer.state_dict(), optimizer.secondary_optimizer.state_dict()] + else: + optimizer_state = optimizer.state_dict() if optimizer is not None else None + + if isinstance(scaler, list): + scaler_state = [s.state_dict() for s in scaler] + else: + scaler_state = scaler.state_dict() if scaler is not None else None + + if isinstance(config, Coqpit): + config = config.to_dict() + + state = { + "config": config, + "model": model_state, + "optimizer": optimizer_state, + "scaler": scaler_state, + "step": current_step, + "epoch": epoch, + "date": datetime.date.today().strftime("%B %d, %Y"), + } + state.update(kwargs) + save_fsspec(state, output_path) + + +def save_checkpoint( + config, + model, + optimizer, + scaler, + current_step, + epoch, + output_folder, + **kwargs, +): + file_name = "checkpoint_{}.pth".format(current_step) + checkpoint_path = os.path.join(output_folder, file_name) + print("\n > CHECKPOINT : {}".format(checkpoint_path)) + save_model( + config, + model, + optimizer, + scaler, + current_step, + epoch, + checkpoint_path, + **kwargs, + ) + + +def save_best_model( + current_loss, + best_loss, + config, + model, + optimizer, + scaler, + current_step, + epoch, + out_path, + keep_all_best=False, + keep_after=10000, + **kwargs, +): + if current_loss < best_loss: + best_model_name = f"best_model_{current_step}.pth" + checkpoint_path = os.path.join(out_path, best_model_name) + print(" > BEST MODEL : {}".format(checkpoint_path)) + save_model( + config, + model, + optimizer, + scaler, + current_step, + epoch, + checkpoint_path, + model_loss=current_loss, + **kwargs, + ) + fs = fsspec.get_mapper(out_path).fs + # only delete previous if current is saved successfully + if not keep_all_best or (current_step < keep_after): + model_names = fs.glob(os.path.join(out_path, "best_model*.pth")) + for model_name in model_names: + if os.path.basename(model_name) != best_model_name: + fs.rm(model_name) + # create a shortcut which always points to the currently best model + shortcut_name = "best_model.pth" + shortcut_path = os.path.join(out_path, shortcut_name) + fs.copy(checkpoint_path, shortcut_path) + best_loss = current_loss + return best_loss diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py new file mode 100644 index 0000000000000000000000000000000000000000..281e5af02a65380d6882ca315e2bf1b72b1845a6 --- /dev/null +++ b/TTS/utils/manage.py @@ -0,0 +1,363 @@ +import io +import json +import os +import zipfile +from pathlib import Path +from shutil import copyfile, rmtree +from typing import Dict, Tuple + +import requests + +from TTS.config import load_config +from TTS.utils.generic_utils import get_user_data_dir + +LICENSE_URLS = { + "cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/", + "mpl": "https://www.mozilla.org/en-US/MPL/2.0/", + "mpl2": "https://www.mozilla.org/en-US/MPL/2.0/", + "mpl 2.0": "https://www.mozilla.org/en-US/MPL/2.0/", + "mit": "https://choosealicense.com/licenses/mit/", + "apache 2.0": "https://choosealicense.com/licenses/apache-2.0/", + "apache2": 
"https://choosealicense.com/licenses/apache-2.0/", + "cc-by-sa 4.0": "https://creativecommons.org/licenses/by-sa/4.0/", +} + + +class ModelManager(object): + """Manage TTS models defined in .models.json. + It provides an interface to list and download + models defines in '.model.json' + + Models are downloaded under '.TTS' folder in the user's + home path. + + Args: + models_file (str): path to .model.json + """ + + def __init__(self, models_file=None, output_prefix=None): + super().__init__() + if output_prefix is None: + self.output_prefix = get_user_data_dir("tts") + else: + self.output_prefix = os.path.join(output_prefix, "tts") + self.models_dict = None + if models_file is not None: + self.read_models_file(models_file) + else: + # try the default location + path = Path(__file__).parent / "../.models.json" + self.read_models_file(path) + + def read_models_file(self, file_path): + """Read .models.json as a dict + + Args: + file_path (str): path to .models.json. + """ + with open(file_path, "r", encoding="utf-8") as json_file: + self.models_dict = json.load(json_file) + + def _list_models(self, model_type, model_count=0): + model_list = [] + for lang in self.models_dict[model_type]: + for dataset in self.models_dict[model_type][lang]: + for model in self.models_dict[model_type][lang][dataset]: + model_full_name = f"{model_type}--{lang}--{dataset}--{model}" + output_path = os.path.join(self.output_prefix, model_full_name) + if os.path.exists(output_path): + print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]") + else: + print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}") + model_list.append(f"{model_type}/{lang}/{dataset}/{model}") + model_count += 1 + return model_list + + def _list_for_model_type(self, model_type): + print(" Name format: language/dataset/model") + models_name_list = [] + model_count = 1 + model_type = "tts_models" + models_name_list.extend(self._list_models(model_type, model_count)) + return [name.replace(model_type + "/", "") for name in models_name_list] + + def list_models(self): + print(" Name format: type/language/dataset/model") + models_name_list = [] + model_count = 1 + for model_type in self.models_dict: + model_list = self._list_models(model_type, model_count) + models_name_list.extend(model_list) + return models_name_list + + def model_info_by_idx(self, model_query): + """Print the description of the model from .models.json file using model_idx + + Args: + model_query (str): / + """ + model_name_list = [] + model_type, model_query_idx = model_query.split("/") + try: + model_query_idx = int(model_query_idx) + if model_query_idx <= 0: + print("> model_query_idx should be a positive integer!") + return + except: + print("> model_query_idx should be an integer!") + return + model_count = 0 + if model_type in self.models_dict: + for lang in self.models_dict[model_type]: + for dataset in self.models_dict[model_type][lang]: + for model in self.models_dict[model_type][lang][dataset]: + model_name_list.append(f"{model_type}/{lang}/{dataset}/{model}") + model_count += 1 + else: + print(f"> model_type {model_type} does not exist in the list.") + return + if model_query_idx > model_count: + print(f"model query idx exceeds the number of available models [{model_count}] ") + else: + model_type, lang, dataset, model = model_name_list[model_query_idx - 1].split("/") + print(f"> model type : {model_type}") + print(f"> language supported : {lang}") + print(f"> dataset used : {dataset}") + print(f"> model name : {model}") + if 
"description" in self.models_dict[model_type][lang][dataset][model]: + print(f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}") + else: + print("> description : coming soon") + if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]: + print(f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}") + + def model_info_by_full_name(self, model_query_name): + """Print the description of the model from .models.json file using model_full_name + + Args: + model_query_name (str): Format is /// + """ + model_type, lang, dataset, model = model_query_name.split("/") + if model_type in self.models_dict: + if lang in self.models_dict[model_type]: + if dataset in self.models_dict[model_type][lang]: + if model in self.models_dict[model_type][lang][dataset]: + print(f"> model type : {model_type}") + print(f"> language supported : {lang}") + print(f"> dataset used : {dataset}") + print(f"> model name : {model}") + if "description" in self.models_dict[model_type][lang][dataset][model]: + print( + f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}" + ) + else: + print("> description : coming soon") + if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]: + print( + f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}" + ) + else: + print(f"> model {model} does not exist for {model_type}/{lang}/{dataset}.") + else: + print(f"> dataset {dataset} does not exist for {model_type}/{lang}.") + else: + print(f"> lang {lang} does not exist for {model_type}.") + else: + print(f"> model_type {model_type} does not exist in the list.") + + def list_tts_models(self): + """Print all `TTS` models and return a list of model names + + Format is `language/dataset/model` + """ + return self._list_for_model_type("tts_models") + + def list_vocoder_models(self): + """Print all the `vocoder` models and return a list of model names + + Format is `language/dataset/model` + """ + return self._list_for_model_type("vocoder_models") + + def list_langs(self): + """Print all the available languages""" + print(" Name format: type/language") + for model_type in self.models_dict: + for lang in self.models_dict[model_type]: + print(f" >: {model_type}/{lang} ") + + def list_datasets(self): + """Print all the datasets""" + print(" Name format: type/language/dataset") + for model_type in self.models_dict: + for lang in self.models_dict[model_type]: + for dataset in self.models_dict[model_type][lang]: + print(f" >: {model_type}/{lang}/{dataset}") + + @staticmethod + def print_model_license(model_item: Dict): + """Print the license of a model + + Args: + model_item (dict): model item in the models.json + """ + if "license" in model_item and model_item["license"].strip() != "": + print(f" > Model's license - {model_item['license']}") + if model_item["license"].lower() in LICENSE_URLS: + print(f" > Check {LICENSE_URLS[model_item['license'].lower()]} for more info.") + else: + print(" > Check https://opensource.org/licenses for more info.") + else: + print(" > Model's license - No license information available") + + def download_model(self, model_name): + """Download model files given the full model name. + Model name is in the format + 'type/language/dataset/model' + e.g. 'tts_model/en/ljspeech/tacotron' + + Every model must have the following files: + - *.pth : pytorch model checkpoint file. + - config.json : model config file. 
+ - scale_stats.npy (if exist): scale values for preprocessing. + + Args: + model_name (str): model name as explained above. + """ + # fetch model info from the dict + model_type, lang, dataset, model = model_name.split("/") + model_full_name = f"{model_type}--{lang}--{dataset}--{model}" + model_item = self.models_dict[model_type][lang][dataset][model] + # set the model specific output path + output_path = os.path.join(self.output_prefix, model_full_name) + if os.path.exists(output_path): + print(f" > {model_name} is already downloaded.") + else: + os.makedirs(output_path, exist_ok=True) + print(f" > Downloading model to {output_path}") + # download from github release + self._download_zip_file(model_item["github_rls_url"], output_path) + self.print_model_license(model_item=model_item) + # find downloaded files + output_model_path, output_config_path = self._find_files(output_path) + # update paths in the config.json + self._update_paths(output_path, output_config_path) + return output_model_path, output_config_path, model_item + + @staticmethod + def _find_files(output_path: str) -> Tuple[str, str]: + """Find the model and config files in the output path + + Args: + output_path (str): path to the model files + + Returns: + Tuple[str, str]: path to the model file and config file + """ + model_file = None + config_file = None + for file_name in os.listdir(output_path): + if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]: + model_file = os.path.join(output_path, file_name) + elif file_name == "config.json": + config_file = os.path.join(output_path, file_name) + if model_file is None: + raise ValueError(" [!] Model file not found in the output path") + if config_file is None: + raise ValueError(" [!] Config file not found in the output path") + return model_file, config_file + + @staticmethod + def _find_speaker_encoder(output_path: str) -> str: + """Find the speaker encoder file in the output path + + Args: + output_path (str): path to the model files + + Returns: + str: path to the speaker encoder file + """ + speaker_encoder_file = None + for file_name in os.listdir(output_path): + if file_name in ["model_se.pth", "model_se.pth.tar"]: + speaker_encoder_file = os.path.join(output_path, file_name) + return speaker_encoder_file + + def _update_paths(self, output_path: str, config_path: str) -> None: + """Update paths for certain files in config.json after download. + + Args: + output_path (str): local path the model is downloaded to. + config_path (str): local config.json path. 
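An end-to-end sketch of the manager above. The model name must be one of the entries listed by `list_models()`; the one below is an example entry from `.models.json`:

from TTS.utils.manage import ModelManager

manager = ModelManager()
manager.list_models()
model_path, config_path, model_item = manager.download_model("tts_models/en/ljspeech/tacotron2-DDC")
# fetch the matching vocoder, if the model declares one
if model_item.get("default_vocoder"):
    voc_path, voc_config_path, _ = manager.download_model(model_item["default_vocoder"])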
+ """ + output_stats_path = os.path.join(output_path, "scale_stats.npy") + output_d_vector_file_path = os.path.join(output_path, "speakers.json") + output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json") + speaker_encoder_config_path = os.path.join(output_path, "config_se.json") + speaker_encoder_model_path = self._find_speaker_encoder(output_path) + + # update the scale_path.npy file path in the model config.json + self._update_path("audio.stats_path", output_stats_path, config_path) + + # update the speakers.json file path in the model config.json to the current path + self._update_path("d_vector_file", output_d_vector_file_path, config_path) + self._update_path("model_args.d_vector_file", output_d_vector_file_path, config_path) + + # update the speaker_ids.json file path in the model config.json to the current path + self._update_path("speakers_file", output_speaker_ids_file_path, config_path) + self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path) + + # update the speaker_encoder file path in the model config.json to the current path + self._update_path("speaker_encoder_model_path", speaker_encoder_model_path, config_path) + self._update_path("model_args.speaker_encoder_model_path", speaker_encoder_model_path, config_path) + self._update_path("speaker_encoder_config_path", speaker_encoder_config_path, config_path) + self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path) + + @staticmethod + def _update_path(field_name, new_path, config_path): + """Update the path in the model config.json for the current environment after download""" + if new_path and os.path.exists(new_path): + config = load_config(config_path) + field_names = field_name.split(".") + if len(field_names) > 1: + # field name points to a sub-level field + sub_conf = config + for fd in field_names[:-1]: + if fd in sub_conf: + sub_conf = sub_conf[fd] + else: + return + sub_conf[field_names[-1]] = new_path + else: + # field name points to a top-level field + config[field_name] = new_path + config.save_json(config_path) + + @staticmethod + def _download_zip_file(file_url, output_folder): + """Download the github releases""" + # download the file + r = requests.get(file_url) + # extract the file + try: + with zipfile.ZipFile(io.BytesIO(r.content)) as z: + z.extractall(output_folder) + except zipfile.BadZipFile: + print(f" > Error: Bad zip file - {file_url}") + raise zipfile.BadZipFile # pylint: disable=raise-missing-from + # move the files to the outer path + for file_path in z.namelist()[1:]: + src_path = os.path.join(output_folder, file_path) + dst_path = os.path.join(output_folder, os.path.basename(file_path)) + copyfile(src_path, dst_path) + # remove the extracted folder + rmtree(os.path.join(output_folder, z.namelist()[0])) + + @staticmethod + def _check_dict_key(my_dict, key): + if key in my_dict.keys() and my_dict[key] is not None: + if not isinstance(key, str): + return True + if isinstance(key, str) and len(my_dict[key]) > 0: + return True + return False diff --git a/TTS/utils/radam.py b/TTS/utils/radam.py new file mode 100644 index 0000000000000000000000000000000000000000..73426e6433bc03dfa4d0a2e2eca43d5ed4e919e7 --- /dev/null +++ b/TTS/utils/radam.py @@ -0,0 +1,107 @@ +# modified from https://github.com/LiyuanLucasLiu/RAdam + +import math + +import torch +from torch.optim.optimizer import Optimizer + + +class RAdam(Optimizer): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, 
degenerated_to_sgd=True): + if lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if eps < 0.0: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + self.degenerated_to_sgd = degenerated_to_sgd + if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): + for param in params: + if "betas" in param and (param["betas"][0] != betas[0] or param["betas"][1] != betas[1]): + param["buffer"] = [[None, None, None] for _ in range(10)] + defaults = dict( + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)] + ) + super().__init__(params, defaults) + + def __setstate__(self, state): # pylint: disable=useless-super-delegation + super().__setstate__(state) + + def step(self, closure=None): + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError("RAdam does not support sparse gradients") + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p_data_fp32) + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) + else: + state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) + state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + state["step"] += 1 + buffered = group["buffer"][int(state["step"] % 10)] + if state["step"] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state["step"] + beta2_t = beta2 ** state["step"] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = math.sqrt( + (1 - beta2_t) + * (N_sma - 4) + / (N_sma_max - 4) + * (N_sma - 2) + / N_sma + * N_sma_max + / (N_sma_max - 2) + ) / (1 - beta1 ** state["step"]) + elif self.degenerated_to_sgd: + step_size = 1.0 / (1 - beta1 ** state["step"]) + else: + step_size = -1 + buffered[2] = step_size + + # more conservative since it's an approximated value + if N_sma >= 5: + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"]) + denom = exp_avg_sq.sqrt().add_(group["eps"]) + p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group["lr"]) + p.data.copy_(p_data_fp32) + elif step_size > 0: + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"]) + p_data_fp32.add_(exp_avg, alpha=-step_size * group["lr"]) + p.data.copy_(p_data_fp32) + + return loss diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py new file mode 100644 index 0000000000000000000000000000000000000000..7c065eaf12dedbe02b2e8f7ead9e55d020ca2d75 --- /dev/null +++ b/TTS/utils/synthesizer.py @@ -0,0 +1,427 @@ +import time +from typing import List + +import numpy as np +import pysbd +import torch + +from TTS.config import load_config +from TTS.encoder.models.resnet import 
ResNetSpeakerEncoder +from TTS.tts.configs.shared_configs import BaseAudioConfig +from TTS.tts.models import setup_model as setup_tts_model + +# pylint: disable=unused-wildcard-import +# pylint: disable=wildcard-import +from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence +from TTS.utils.audio import AudioProcessor +from TTS.vocoder.models import setup_model as setup_vocoder_model +from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input + + +class Synthesizer(object): + def __init__( + self, + tts_checkpoint: str, + tts_config_path: str, + tts_speakers_file: str = "", + tts_languages_file: str = "", + vocoder_checkpoint: str = "", + vocoder_config: str = "", + encoder_checkpoint: str = "", + encoder_config: str = "", + use_cuda: bool = False, + ) -> None: + """General 🐸 TTS interface for inference. It takes a tts and a vocoder + model and synthesize speech from the provided text. + + The text is divided into a list of sentences using `pysbd` and synthesize + speech on each sentence separately. + + If you have certain special characters in your text, you need to handle + them before providing the text to Synthesizer. + + TODO: set the segmenter based on the source language + + Args: + tts_checkpoint (str): path to the tts model file. + tts_config_path (str): path to the tts config file. + vocoder_checkpoint (str, optional): path to the vocoder model file. Defaults to None. + vocoder_config (str, optional): path to the vocoder config file. Defaults to None. + encoder_checkpoint (str, optional): path to the speaker encoder model file. Defaults to `""`, + encoder_config (str, optional): path to the speaker encoder config file. Defaults to `""`, + use_cuda (bool, optional): enable/disable cuda. Defaults to False. + """ + self.tts_checkpoint = tts_checkpoint + self.tts_config_path = tts_config_path + self.tts_speakers_file = tts_speakers_file + self.tts_languages_file = tts_languages_file + self.vocoder_checkpoint = vocoder_checkpoint + self.vocoder_config = vocoder_config + self.encoder_checkpoint = encoder_checkpoint + self.encoder_config = encoder_config + self.use_cuda = use_cuda + + self.tts_model = None + self.vocoder_model = None + self.speaker_manager = None + self.num_speakers = 0 + self.tts_speakers = {} + self.language_manager = None + self.num_languages = 0 + self.tts_languages = {} + self.d_vector_dim = 0 + self.seg = self._get_segmenter("en") + self.use_cuda = use_cuda + + if self.use_cuda: + assert torch.cuda.is_available(), "CUDA is not availabe on this machine." + self._load_tts(tts_checkpoint, tts_config_path, use_cuda) + self.output_sample_rate = self.tts_config.audio["sample_rate"] + if vocoder_checkpoint: + self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda) + self.output_sample_rate = self.vocoder_config.audio["sample_rate"] + + @staticmethod + def _get_segmenter(lang: str): + """get the sentence segmenter for the given language. + + Args: + lang (str): target language code. + + Returns: + [type]: [description] + """ + return pysbd.Segmenter(language=lang, clean=True) + + def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None: + """Load the TTS model. + + 1. Load the model config. + 2. Init the model from the config. + 3. Load the model weights. + 4. Move the model to the GPU if CUDA is enabled. + 5. Init the speaker manager in the model. + + Args: + tts_checkpoint (str): path to the model checkpoint. + tts_config_path (str): path to the model config file. 
+ use_cuda (bool): enable/disable CUDA use. + """ + # pylint: disable=global-statement + self.tts_config = load_config(tts_config_path) + if self.tts_config["use_phonemes"] and self.tts_config["phonemizer"] is None: + raise ValueError("Phonemizer is not defined in the TTS config.") + + self.tts_model = setup_tts_model(config=self.tts_config) + + if not self.encoder_checkpoint: + self._set_speaker_encoder_paths_from_tts_config() + + self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True) + if use_cuda: + self.tts_model.cuda() + + self.use_zero_shot_speaker_encoder = False + if self.encoder_checkpoint and self.encoder_config and hasattr(self.tts_model, "speaker_manager"): + self.tts_model.speaker_manager.init_encoder(self.encoder_checkpoint, self.encoder_config, use_cuda) + elif self.encoder_checkpoint and self.encoder_config is None: + self.use_zero_shot_speaker_encoder = True + del self.tts_model.emb_g + state_dict = torch.load(self.encoder_checkpoint)['state_dict'] + state_dict = {k.split('.', 1)[1]:v for k,v in state_dict.items() if k.startswith('speaker_encoder')} + self.zero_shot_speaker_encoder = ResNetSpeakerEncoder( + input_dim=self.tts_config['model_args']['out_channels'], + proj_dim=self.tts_config['model_args']['hidden_channels'], + layers=[3, 4, 6, 3], + num_filters=[32, 64, 128, 256], + encoder_type="ASP", + log_input=False, + use_torch_spec=False, + audio_config=BaseAudioConfig( + **self.tts_config['audio'] + ), + ) + self.zero_shot_speaker_encoder.load_state_dict(state_dict) + if use_cuda: + self.zero_shot_speaker_encoder.cuda() + print("| Loaded zero-shot speaker encoder.") + + def _set_speaker_encoder_paths_from_tts_config(self): + """Set the encoder paths from the tts model config for models with speaker encoders.""" + if hasattr(self.tts_config, "model_args") and hasattr( + self.tts_config.model_args, "speaker_encoder_config_path" + ): + self.encoder_checkpoint = self.tts_config.model_args.speaker_encoder_model_path + self.encoder_config = self.tts_config.model_args.speaker_encoder_config_path + + def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None: + """Load the vocoder model. + + 1. Load the vocoder config. + 2. Init the AudioProcessor for the vocoder. + 3. Init the vocoder model from the config. + 4. Move the model to the GPU if CUDA is enabled. + + Args: + model_file (str): path to the model checkpoint. + model_config (str): path to the model config file. + use_cuda (bool): enable/disable CUDA use. + """ + self.vocoder_config = load_config(model_config) + self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config.audio) + self.vocoder_model = setup_vocoder_model(self.vocoder_config) + self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True) + if use_cuda: + self.vocoder_model.cuda() + + def split_into_sentences(self, text) -> List[str]: + """Split give text into sentences. + + Args: + text (str): input text in string format. + + Returns: + List[str]: list of sentences. + """ + return self.seg.segment(text) + + def save_wav(self, wav: List[int], path: str) -> None: + """Save the waveform as a file. + + Args: + wav (List[int]): waveform as a list of values. + path (str): output path to save the waveform. 
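Putting the `Synthesizer` pieces together, a minimal end-to-end sketch for a single-speaker model; every path below is a placeholder, and the vocoder arguments can simply be omitted to fall back to Griffin-Lim:

```python
# Illustrative sketch with placeholder paths, not part of the diff.
from TTS.utils.synthesizer import Synthesizer

synthesizer = Synthesizer(
    tts_checkpoint="path/to/model.pth",
    tts_config_path="path/to/config.json",
    vocoder_checkpoint="path/to/vocoder.pth",      # optional: drop to use Griffin-Lim
    vocoder_config="path/to/vocoder_config.json",  # optional
    use_cuda=False,
)
wav = synthesizer.tts("Hello, this is a test sentence.")
synthesizer.save_wav(wav, "output.wav")
```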
+ """ + wav = np.array(wav) + self.tts_model.ap.save_wav(wav, path, self.output_sample_rate) + + def tts( + self, + text: str = "", + speaker_name: str = "", + language_name: str = "", + speaker_wav=None, + style_wav=None, + style_text=None, + reference_wav=None, + reference_speaker_name=None, + ) -> List[int]: + """🐸 TTS magic. Run all the models and generate speech. + + Args: + text (str): input text. + speaker_name (str, optional): spekaer id for multi-speaker models. Defaults to "". + language_name (str, optional): language id for multi-language models. Defaults to "". + speaker_wav (Union[str, List[str]], optional): path to the speaker wav. Defaults to None. + style_wav ([type], optional): style waveform for GST. Defaults to None. + style_text ([type], optional): transcription of style_wav for Capacitron. Defaults to None. + reference_wav ([type], optional): reference waveform for voice conversion. Defaults to None. + reference_speaker_name ([type], optional): spekaer id of reference waveform. Defaults to None. + Returns: + List[int]: [description] + """ + start_time = time.time() + wavs = [] + + if not text and not reference_wav: + raise ValueError( + "You need to define either `text` (for sythesis) or a `reference_wav` (for voice conversion) to use the Coqui TTS API." + ) + + if text: + sens = self.split_into_sentences(text) + print(" > Text splitted to sentences.") + print(sens) + + # handle multi-speaker + speaker_embedding = None + speaker_id = None + if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "ids"): + if speaker_name and isinstance(speaker_name, str): + if self.tts_config.use_d_vector_file: + # get the average speaker embedding from the saved d_vectors. + speaker_embedding = self.tts_model.speaker_manager.get_mean_embedding( + speaker_name, num_samples=None, randomize=False + ) + speaker_embedding = np.array(speaker_embedding)[None, :] # [1 x embedding_dim] + else: + # get speaker idx from the speaker name + speaker_id = self.tts_model.speaker_manager.ids[speaker_name] + + elif not speaker_name and not speaker_wav: + raise ValueError( + " [!] Look like you use a multi-speaker model. " + "You need to define either a `speaker_name` or a `speaker_wav` to use a multi-speaker model." + ) + else: + speaker_embedding = None + else: + if speaker_name: + raise ValueError( + f" [!] Missing speakers.json file path for selecting speaker {speaker_name}." + "Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. " + ) + + # handle multi-lingaul + language_id = None + if self.tts_languages_file or ( + hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None + ): + if language_name and isinstance(language_name, str): + language_id = self.tts_model.language_manager.ids[language_name] + + elif not language_name: + raise ValueError( + " [!] Look like you use a multi-lingual model. " + "You need to define either a `language_name` or a `style_wav` to use a multi-lingual model." + ) + + else: + raise ValueError( + f" [!] Missing language_ids.json file path for selecting language {language_name}." + "Define path for language_ids.json if it is a multi-lingual model or remove defined language idx. " + ) + + # compute a new d_vector from the given clip. 
+ if speaker_wav is not None: + if self.use_zero_shot_speaker_encoder: + wav = self.tts_model.ap.load_wav(speaker_wav, sr=22050) + mel = self.tts_model.ap.melspectrogram(wav).astype("float32") + mel = torch.FloatTensor(mel).contiguous().unsqueeze(0) + with torch.no_grad(): + speaker_embedding = self.zero_shot_speaker_encoder(mel)[0] + else: + speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav) + + use_gl = self.vocoder_model is None + + if not reference_wav: + for sen in sens: + # synthesize voice + outputs = synthesis( + model=self.tts_model, + text=sen, + CONFIG=self.tts_config, + use_cuda=self.use_cuda, + speaker_id=speaker_id, + style_wav=style_wav, + style_text=style_text, + use_griffin_lim=use_gl, + d_vector=speaker_embedding, + language_id=language_id, + ) + waveform = outputs["wav"] + mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy() + if not use_gl: + # denormalize tts output based on tts audio config + # ### + # import matplotlib.pyplot as plt + # import seaborn as sns + # img=sns.heatmap(mel_postnet_spec.T) + # fig = img.get_figure() + # fig.savefig('output/fp_1.png') + # fig.clf() + # ### + mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T + # ### + # import matplotlib.pyplot as plt + # import seaborn as sns + # img=sns.heatmap(mel_postnet_spec.T) + # fig = img.get_figure() + # fig.savefig('output/fp_2.png') + # fig.clf() + # ### + device_type = "cuda" if self.use_cuda else "cpu" + # renormalize spectrogram based on vocoder config + vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T) + # compute scale factor for possible sample rate mismatch + scale_factor = [ + 1, + self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate, + ] + if scale_factor[1] != 1: + print(" > interpolating tts model output.") + vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input) + else: + vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable + # run vocoder model + # [1, T, C] + waveform = self.vocoder_model.inference(vocoder_input.to(device_type)) + if self.use_cuda and not use_gl: + waveform = waveform.cpu() + if not use_gl: + waveform = waveform.numpy() + waveform = waveform.squeeze() + + # trim silence + if self.tts_config.audio["do_trim_silence"] is True: + waveform = trim_silence(waveform, self.tts_model.ap) + + wavs += list(waveform) + wavs += [0] * 10000 + else: + # get the speaker embedding or speaker id for the reference wav file + reference_speaker_embedding = None + reference_speaker_id = None + if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "ids"): + if reference_speaker_name and isinstance(reference_speaker_name, str): + if self.tts_config.use_d_vector_file: + # get the speaker embedding from the saved d_vectors. 
+ reference_speaker_embedding = self.tts_model.speaker_manager.get_embeddings_by_name( + reference_speaker_name + )[0] + reference_speaker_embedding = np.array(reference_speaker_embedding)[ + None, : + ] # [1 x embedding_dim] + else: + # get speaker idx from the speaker name + reference_speaker_id = self.tts_model.speaker_manager.ids[reference_speaker_name] + else: + reference_speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip( + reference_wav + ) + + outputs = transfer_voice( + model=self.tts_model, + CONFIG=self.tts_config, + use_cuda=self.use_cuda, + reference_wav=reference_wav, + speaker_id=speaker_id, + d_vector=speaker_embedding, + use_griffin_lim=use_gl, + reference_speaker_id=reference_speaker_id, + reference_d_vector=reference_speaker_embedding, + ) + waveform = outputs + if not use_gl: + mel_postnet_spec = outputs[0].detach().cpu().numpy() + # denormalize tts output based on tts audio config + mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T + device_type = "cuda" if self.use_cuda else "cpu" + # renormalize spectrogram based on vocoder config + vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T) + # compute scale factor for possible sample rate mismatch + scale_factor = [ + 1, + self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate, + ] + if scale_factor[1] != 1: + print(" > interpolating tts model output.") + vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input) + else: + vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable + # run vocoder model + # [1, T, C] + waveform = self.vocoder_model.inference(vocoder_input.to(device_type)) + if self.use_cuda: + waveform = waveform.cpu() + if not use_gl: + waveform = waveform.numpy() + wavs = waveform.squeeze() + + # compute stats + process_time = time.time() - start_time + audio_time = len(wavs) / self.tts_config.audio["sample_rate"] + print(f" > Processing time: {process_time}") + print(f" > Real-time factor: {process_time / audio_time}") + return wavs diff --git a/TTS/utils/training.py b/TTS/utils/training.py new file mode 100644 index 0000000000000000000000000000000000000000..b51f55e92b56bece69ae61f99f68b48c88938261 --- /dev/null +++ b/TTS/utils/training.py @@ -0,0 +1,44 @@ +import numpy as np +import torch + + +def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None): + r"""Check model gradient against unexpected jumps and failures""" + skip_flag = False + if ignore_stopnet: + if not amp_opt_params: + grad_norm = torch.nn.utils.clip_grad_norm_( + [param for name, param in model.named_parameters() if "stopnet" not in name], grad_clip + ) + else: + grad_norm = torch.nn.utils.clip_grad_norm_(amp_opt_params, grad_clip) + else: + if not amp_opt_params: + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) + else: + grad_norm = torch.nn.utils.clip_grad_norm_(amp_opt_params, grad_clip) + + # compatibility with different torch versions + if isinstance(grad_norm, float): + if np.isinf(grad_norm): + print(" | > Gradient is INF !!") + skip_flag = True + else: + if torch.isinf(grad_norm): + print(" | > Gradient is INF !!") + skip_flag = True + return grad_norm, skip_flag + + +def gradual_training_scheduler(global_step, config): + """Setup the gradual training schedule wrt number + of active GPUs""" + num_gpus = torch.cuda.device_count() + if num_gpus == 0: + num_gpus = 1 + new_values = None + # we set the scheduling wrt num_gpus + for values in config.gradual_training: + 
if global_step * num_gpus >= values[0]: + new_values = values + return new_values[1], new_values[2] diff --git a/TTS/utils/vad.py b/TTS/utils/vad.py new file mode 100644 index 0000000000000000000000000000000000000000..033b911a7c188cb90ed342e579e0d428e648e9b8 --- /dev/null +++ b/TTS/utils/vad.py @@ -0,0 +1,81 @@ +import torch +import torchaudio + + +def read_audio(path): + wav, sr = torchaudio.load(path) + + if wav.size(0) > 1: + wav = wav.mean(dim=0, keepdim=True) + + return wav.squeeze(0), sr + + +def resample_wav(wav, sr, new_sr): + wav = wav.unsqueeze(0) + transform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=new_sr) + wav = transform(wav) + return wav.squeeze(0) + + +def map_timestamps_to_new_sr(vad_sr, new_sr, timestamps, just_begging_end=False): + factor = new_sr / vad_sr + new_timestamps = [] + if just_begging_end and timestamps: + # get just the start and end timestamps + new_dict = {"start": int(timestamps[0]["start"] * factor), "end": int(timestamps[-1]["end"] * factor)} + new_timestamps.append(new_dict) + else: + for ts in timestamps: + # map to the new SR + new_dict = {"start": int(ts["start"] * factor), "end": int(ts["end"] * factor)} + new_timestamps.append(new_dict) + + return new_timestamps + + +def get_vad_model_and_utils(use_cuda=False): + model, utils = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=True, onnx=False) + if use_cuda: + model = model.cuda() + + get_speech_timestamps, save_audio, _, _, collect_chunks = utils + return model, get_speech_timestamps, save_audio, collect_chunks + + +def remove_silence( + model_and_utils, audio_path, out_path, vad_sample_rate=8000, trim_just_beginning_and_end=True, use_cuda=False +): + + # get the VAD model and utils functions + model, get_speech_timestamps, save_audio, collect_chunks = model_and_utils + + # read ground truth wav and resample the audio for the VAD + wav, gt_sample_rate = read_audio(audio_path) + + # if needed, resample the audio for the VAD model + if gt_sample_rate != vad_sample_rate: + wav_vad = resample_wav(wav, gt_sample_rate, vad_sample_rate) + else: + wav_vad = wav + + if use_cuda: + wav_vad = wav_vad.cuda() + + # get speech timestamps from full audio file + speech_timestamps = get_speech_timestamps(wav_vad, model, sampling_rate=vad_sample_rate, window_size_samples=768) + + # map the current speech_timestamps to the sample rate of the ground truth audio + new_speech_timestamps = map_timestamps_to_new_sr( + vad_sample_rate, gt_sample_rate, speech_timestamps, trim_just_beginning_and_end + ) + + # if have speech timestamps else save the wav + if new_speech_timestamps: + wav = collect_chunks(new_speech_timestamps, wav) + else: + print(f"> The file {audio_path} probably does not have speech please check it !!") + + # save audio + save_audio(out_path, wav, sampling_rate=gt_sample_rate) + return out_path diff --git a/TTS/vocoder/README.md b/TTS/vocoder/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b9fb17c8f09fa6e8c217087e31fb8c52d96da536 --- /dev/null +++ b/TTS/vocoder/README.md @@ -0,0 +1,39 @@ +# Mozilla TTS Vocoders (Experimental) + +Here there are vocoder model implementations which can be combined with the other TTS models. + +Currently, following models are implemented: + +- Melgan +- MultiBand-Melgan +- ParallelWaveGAN +- GAN-TTS (Discriminator Only) + +It is also very easy to adapt different vocoder models as we provide a flexible and modular (but not too modular) framework. 
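The voice-activity-detection helpers added in `TTS/utils/vad.py` above are self-contained and can be used on their own to trim silence from reference clips; a short sketch with placeholder file names (the Silero VAD model is fetched through `torch.hub` on first use):

```python
from TTS.utils.vad import get_vad_model_and_utils, remove_silence

model_and_utils = get_vad_model_and_utils(use_cuda=False)
remove_silence(
    model_and_utils,
    audio_path="speaker_clip.wav",
    out_path="speaker_clip_trimmed.wav",
    vad_sample_rate=8000,
    trim_just_beginning_and_end=True,
)
```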
+ +## Training a model + +You can see here an example (Soon)[Colab Notebook]() training MelGAN with LJSpeech dataset. + +In order to train a new model, you need to gather all wav files into a folder and give this folder to `data_path` in '''config.json''' + +You need to define other relevant parameters in your ```config.json``` and then start traning with the following command. + +```CUDA_VISIBLE_DEVICES='0' python tts/bin/train_vocoder.py --config_path path/to/config.json``` + +Example config files can be found under `tts/vocoder/configs/` folder. + +You can continue a previous training run by the following command. + +```CUDA_VISIBLE_DEVICES='0' python tts/bin/train_vocoder.py --continue_path path/to/your/model/folder``` + +You can fine-tune a pre-trained model by the following command. + +```CUDA_VISIBLE_DEVICES='0' python tts/bin/train_vocoder.py --restore_path path/to/your/model.pth``` + +Restoring a model starts a new training in a different folder. It only restores model weights with the given checkpoint file. However, continuing a training starts from the same directory where the previous training run left off. + +You can also follow your training runs on Tensorboard as you do with our TTS models. + +## Acknowledgement +Thanks to @kan-bayashi for his [repository](https://github.com/kan-bayashi/ParallelWaveGAN) being the start point of our work. diff --git a/TTS/vocoder/__init__.py b/TTS/vocoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/vocoder/configs/__init__.py b/TTS/vocoder/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b5e11b990c6d7294e7cb00c3e024bbb5f94a8105 --- /dev/null +++ b/TTS/vocoder/configs/__init__.py @@ -0,0 +1,17 @@ +import importlib +import os +from inspect import isclass + +# import all files under configs/ +configs_dir = os.path.dirname(__file__) +for file in os.listdir(configs_dir): + path = os.path.join(configs_dir, file) + if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)): + config_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("TTS.vocoder.configs." + config_name) + for attribute_name in dir(module): + attribute = getattr(module, attribute_name) + + if isclass(attribute): + # Add the class to this package's variables + globals()[attribute_name] = attribute diff --git a/TTS/vocoder/configs/fullband_melgan_config.py b/TTS/vocoder/configs/fullband_melgan_config.py new file mode 100644 index 0000000000000000000000000000000000000000..2ab83aace678e328a8f99a5f0dc63e54ed99d4c4 --- /dev/null +++ b/TTS/vocoder/configs/fullband_melgan_config.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass, field + +from .shared_configs import BaseGANVocoderConfig + + +@dataclass +class FullbandMelganConfig(BaseGANVocoderConfig): + """Defines parameters for FullBand MelGAN vocoder. + + Example: + + >>> from TTS.vocoder.configs import FullbandMelganConfig + >>> config = FullbandMelganConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `fullband_melgan`. + discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to + 'melgan_multiscale_discriminator`. + discriminator_model_params (dict): The discriminator model parameters. 
Defaults to + '{"base_channels": 16, "max_channels": 1024, "downsample_factors": [4, 4, 4, 4]}` + generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is + considered as a generator too. Defaults to `melgan_generator`. + batch_size (int): + Batch size used at training. Larger values use more memory. Defaults to 16. + seq_len (int): + Audio segment length used at training. Larger values use more memory. Defaults to 8192. + pad_short (int): + Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0. + use_noise_augment (bool): + enable / disable random noise added to the input waveform. The noise is added after computing the + features. Defaults to True. + use_cache (bool): + enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is + not large enough. Defaults to True. + use_stft_loss (bool): + enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True. + use_subband_stft (bool): + enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True. + use_mse_gan_loss (bool): + enable / disable using Mean Squeare Error GAN loss. Defaults to True. + use_hinge_gan_loss (bool): + enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models. + Defaults to False. + use_feat_match_loss (bool): + enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True. + use_l1_spec_loss (bool): + enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False. + stft_loss_params (dict): STFT loss parameters. Default to + `{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}` + stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total + model loss. Defaults to 0.5. + subband_stft_loss_weight (float): + Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + mse_G_loss_weight (float): + MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5. + hinge_G_loss_weight (float): + Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + feat_match_loss_weight (float): + Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108. + l1_spec_loss_weight (float): + L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. 
+ """ + + model: str = "fullband_melgan" + + # Model specific params + discriminator_model: str = "melgan_multiscale_discriminator" + discriminator_model_params: dict = field( + default_factory=lambda: {"base_channels": 16, "max_channels": 512, "downsample_factors": [4, 4, 4]} + ) + generator_model: str = "melgan_generator" + generator_model_params: dict = field( + default_factory=lambda: {"upsample_factors": [8, 8, 2, 2], "num_res_blocks": 4} + ) + + # Training - overrides + batch_size: int = 16 + seq_len: int = 8192 + pad_short: int = 2000 + use_noise_augment: bool = True + use_cache: bool = True + + # LOSS PARAMETERS - overrides + use_stft_loss: bool = True + use_subband_stft_loss: bool = False + use_mse_gan_loss: bool = True + use_hinge_gan_loss: bool = False + use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN) + use_l1_spec_loss: bool = False + + stft_loss_params: dict = field( + default_factory=lambda: { + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240], + } + ) + + # loss weights - overrides + stft_loss_weight: float = 0.5 + subband_stft_loss_weight: float = 0 + mse_G_loss_weight: float = 2.5 + hinge_G_loss_weight: float = 0 + feat_match_loss_weight: float = 108 + l1_spec_loss_weight: float = 0.0 diff --git a/TTS/vocoder/configs/hifigan_config.py b/TTS/vocoder/configs/hifigan_config.py new file mode 100644 index 0000000000000000000000000000000000000000..f76bb14c094808e49fda279b6e185ed1d63241d3 --- /dev/null +++ b/TTS/vocoder/configs/hifigan_config.py @@ -0,0 +1,138 @@ +from dataclasses import dataclass, field + +from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig + + +@dataclass +class HifiganConfig(BaseGANVocoderConfig): + """Defines parameters for FullBand MelGAN vocoder. + + Example: + + >>> from TTS.vocoder.configs import HifiganConfig + >>> config = HifiganConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `hifigan`. + discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to + 'hifigan_discriminator`. + generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is + considered as a generator too. Defaults to `hifigan_generator`. + generator_model_params (dict): Parameters of the generator model. Defaults to + ` + { + "use_mel": True, + "sample_rate": 22050, + "n_fft": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mels": 80, + "mel_fmin": 0.0, + "mel_fmax": None, + } + ` + batch_size (int): + Batch size used at training. Larger values use more memory. Defaults to 16. + seq_len (int): + Audio segment length used at training. Larger values use more memory. Defaults to 8192. + pad_short (int): + Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0. + use_noise_augment (bool): + enable / disable random noise added to the input waveform. The noise is added after computing the + features. Defaults to True. + use_cache (bool): + enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is + not large enough. Defaults to True. + use_stft_loss (bool): + enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True. + use_subband_stft (bool): + enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True. 
+ use_mse_gan_loss (bool): + enable / disable using Mean Squeare Error GAN loss. Defaults to True. + use_hinge_gan_loss (bool): + enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models. + Defaults to False. + use_feat_match_loss (bool): + enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True. + use_l1_spec_loss (bool): + enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False. + stft_loss_params (dict): + STFT loss parameters. Default to + `{ + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240] + }` + l1_spec_loss_params (dict): + L1 spectrogram loss parameters. Default to + `{ + "use_mel": True, + "sample_rate": 22050, + "n_fft": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mels": 80, + "mel_fmin": 0.0, + "mel_fmax": None, + }` + stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total + model loss. Defaults to 0.5. + subband_stft_loss_weight (float): + Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + mse_G_loss_weight (float): + MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5. + hinge_G_loss_weight (float): + Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + feat_match_loss_weight (float): + Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108. + l1_spec_loss_weight (float): + L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. 
+ """ + + model: str = "hifigan" + # model specific params + discriminator_model: str = "hifigan_discriminator" + generator_model: str = "hifigan_generator" + generator_model_params: dict = field( + default_factory=lambda: { + "upsample_factors": [8, 8, 2, 2], + "upsample_kernel_sizes": [16, 16, 4, 4], + "upsample_initial_channel": 512, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "resblock_type": "1", + } + ) + + # LOSS PARAMETERS - overrides + use_stft_loss: bool = False + use_subband_stft_loss: bool = False + use_mse_gan_loss: bool = True + use_hinge_gan_loss: bool = False + use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN) + use_l1_spec_loss: bool = True + + # loss weights - overrides + stft_loss_weight: float = 0 + subband_stft_loss_weight: float = 0 + mse_G_loss_weight: float = 1 + hinge_G_loss_weight: float = 0 + feat_match_loss_weight: float = 108 + l1_spec_loss_weight: float = 45 + l1_spec_loss_params: dict = field( + default_factory=lambda: { + "use_mel": True, + "sample_rate": 22050, + "n_fft": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mels": 80, + "mel_fmin": 0.0, + "mel_fmax": None, + } + ) + + # optimizer parameters + lr: float = 1e-4 + wd: float = 1e-6 diff --git a/TTS/vocoder/configs/melgan_config.py b/TTS/vocoder/configs/melgan_config.py new file mode 100644 index 0000000000000000000000000000000000000000..dc35b6f8b70891d4904baefad802d9c62fe67925 --- /dev/null +++ b/TTS/vocoder/configs/melgan_config.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass, field + +from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig + + +@dataclass +class MelganConfig(BaseGANVocoderConfig): + """Defines parameters for MelGAN vocoder. + + Example: + + >>> from TTS.vocoder.configs import MelganConfig + >>> config = MelganConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `melgan`. + discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to + 'melgan_multiscale_discriminator`. + discriminator_model_params (dict): The discriminator model parameters. Defaults to + '{"base_channels": 16, "max_channels": 1024, "downsample_factors": [4, 4, 4, 4]}` + generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is + considered as a generator too. Defaults to `melgan_generator`. + batch_size (int): + Batch size used at training. Larger values use more memory. Defaults to 16. + seq_len (int): + Audio segment length used at training. Larger values use more memory. Defaults to 8192. + pad_short (int): + Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0. + use_noise_augment (bool): + enable / disable random noise added to the input waveform. The noise is added after computing the + features. Defaults to True. + use_cache (bool): + enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is + not large enough. Defaults to True. + use_stft_loss (bool): + enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True. + use_subband_stft (bool): + enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True. + use_mse_gan_loss (bool): + enable / disable using Mean Squeare Error GAN loss. Defaults to True. + use_hinge_gan_loss (bool): + enable / disable using Hinge GAN loss. 
You should choose either Hinge or MSE loss for training GAN models. + Defaults to False. + use_feat_match_loss (bool): + enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True. + use_l1_spec_loss (bool): + enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False. + stft_loss_params (dict): STFT loss parameters. Default to + `{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}` + stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total + model loss. Defaults to 0.5. + subband_stft_loss_weight (float): + Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + mse_G_loss_weight (float): + MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5. + hinge_G_loss_weight (float): + Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + feat_match_loss_weight (float): + Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108. + l1_spec_loss_weight (float): + L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + """ + + model: str = "melgan" + + # Model specific params + discriminator_model: str = "melgan_multiscale_discriminator" + discriminator_model_params: dict = field( + default_factory=lambda: {"base_channels": 16, "max_channels": 1024, "downsample_factors": [4, 4, 4, 4]} + ) + generator_model: str = "melgan_generator" + generator_model_params: dict = field( + default_factory=lambda: {"upsample_factors": [8, 8, 2, 2], "num_res_blocks": 3} + ) + + # Training - overrides + batch_size: int = 16 + seq_len: int = 8192 + pad_short: int = 2000 + use_noise_augment: bool = True + use_cache: bool = True + + # LOSS PARAMETERS - overrides + use_stft_loss: bool = True + use_subband_stft_loss: bool = False + use_mse_gan_loss: bool = True + use_hinge_gan_loss: bool = False + use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN) + use_l1_spec_loss: bool = False + + stft_loss_params: dict = field( + default_factory=lambda: { + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240], + } + ) + + # loss weights - overrides + stft_loss_weight: float = 0.5 + subband_stft_loss_weight: float = 0 + mse_G_loss_weight: float = 2.5 + hinge_G_loss_weight: float = 0 + feat_match_loss_weight: float = 108 + l1_spec_loss_weight: float = 0 diff --git a/TTS/vocoder/configs/multiband_melgan_config.py b/TTS/vocoder/configs/multiband_melgan_config.py new file mode 100644 index 0000000000000000000000000000000000000000..763113537f36a8615b2b77369bf5bde01527fe53 --- /dev/null +++ b/TTS/vocoder/configs/multiband_melgan_config.py @@ -0,0 +1,144 @@ +from dataclasses import dataclass, field + +from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig + + +@dataclass +class MultibandMelganConfig(BaseGANVocoderConfig): + """Defines parameters for MultiBandMelGAN vocoder. + + Example: + + >>> from TTS.vocoder.configs import MultibandMelganConfig + >>> config = MultibandMelganConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `multiband_melgan`. + discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. 
Defaults to + 'melgan_multiscale_discriminator`. + discriminator_model_params (dict): The discriminator model parameters. Defaults to + '{ + "base_channels": 16, + "max_channels": 512, + "downsample_factors": [4, 4, 4] + }` + generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is + considered as a generator too. Defaults to `melgan_generator`. + generator_model_param (dict): + The generator model parameters. Defaults to `{"upsample_factors": [8, 4, 2], "num_res_blocks": 4}`. + use_pqmf (bool): + enable / disable PQMF modulation for multi-band training. Defaults to True. + lr_gen (float): + Initial learning rate for the generator model. Defaults to 0.0001. + lr_disc (float): + Initial learning rate for the discriminator model. Defaults to 0.0001. + optimizer (torch.optim.Optimizer): + Optimizer used for the training. Defaults to `AdamW`. + optimizer_params (dict): + Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}` + lr_scheduler_gen (torch.optim.Scheduler): + Learning rate scheduler for the generator. Defaults to `MultiStepLR`. + lr_scheduler_gen_params (dict): + Parameters for the generator learning rate scheduler. Defaults to + `{"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}`. + lr_scheduler_disc (torch.optim.Scheduler): + Learning rate scheduler for the discriminator. Defaults to `MultiStepLR`. + lr_scheduler_dict_params (dict): + Parameters for the discriminator learning rate scheduler. Defaults to + `{"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}`. + batch_size (int): + Batch size used at training. Larger values use more memory. Defaults to 16. + seq_len (int): + Audio segment length used at training. Larger values use more memory. Defaults to 8192. + pad_short (int): + Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0. + use_noise_augment (bool): + enable / disable random noise added to the input waveform. The noise is added after computing the + features. Defaults to True. + use_cache (bool): + enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is + not large enough. Defaults to True. + steps_to_start_discriminator (int): + Number of steps required to start training the discriminator. Defaults to 0. + use_stft_loss (bool):` + enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True. + use_subband_stft (bool): + enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True. + use_mse_gan_loss (bool): + enable / disable using Mean Squeare Error GAN loss. Defaults to True. + use_hinge_gan_loss (bool): + enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models. + Defaults to False. + use_feat_match_loss (bool): + enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True. + use_l1_spec_loss (bool): + enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False. + stft_loss_params (dict): STFT loss parameters. Default to + `{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}` + stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total + model loss. Defaults to 0.5. 
+ subband_stft_loss_weight (float): + Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + mse_G_loss_weight (float): + MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5. + hinge_G_loss_weight (float): + Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + feat_match_loss_weight (float): + Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108. + l1_spec_loss_weight (float): + L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + """ + + model: str = "multiband_melgan" + + # Model specific params + discriminator_model: str = "melgan_multiscale_discriminator" + discriminator_model_params: dict = field( + default_factory=lambda: {"base_channels": 16, "max_channels": 512, "downsample_factors": [4, 4, 4]} + ) + generator_model: str = "multiband_melgan_generator" + generator_model_params: dict = field(default_factory=lambda: {"upsample_factors": [8, 4, 2], "num_res_blocks": 4}) + use_pqmf: bool = True + + # optimizer - overrides + lr_gen: float = 0.0001 # Initial learning rate. + lr_disc: float = 0.0001 # Initial learning rate. + optimizer: str = "AdamW" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "weight_decay": 0.0}) + lr_scheduler_gen: str = "MultiStepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_gen_params: dict = field( + default_factory=lambda: {"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]} + ) + lr_scheduler_disc: str = "MultiStepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_disc_params: dict = field( + default_factory=lambda: {"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]} + ) + + # Training - overrides + batch_size: int = 64 + seq_len: int = 16384 + pad_short: int = 2000 + use_noise_augment: bool = False + use_cache: bool = True + steps_to_start_discriminator: bool = 200000 + + # LOSS PARAMETERS - overrides + use_stft_loss: bool = True + use_subband_stft_loss: bool = True + use_mse_gan_loss: bool = True + use_hinge_gan_loss: bool = False + use_feat_match_loss: bool = False # requires MelGAN Discriminators (MelGAN and HifiGAN) + use_l1_spec_loss: bool = False + + subband_stft_loss_params: dict = field( + default_factory=lambda: {"n_ffts": [384, 683, 171], "hop_lengths": [30, 60, 10], "win_lengths": [150, 300, 60]} + ) + + # loss weights - overrides + stft_loss_weight: float = 0.5 + subband_stft_loss_weight: float = 0 + mse_G_loss_weight: float = 2.5 + hinge_G_loss_weight: float = 0 + feat_match_loss_weight: float = 108 + l1_spec_loss_weight: float = 0 diff --git a/TTS/vocoder/configs/parallel_wavegan_config.py b/TTS/vocoder/configs/parallel_wavegan_config.py new file mode 100644 index 0000000000000000000000000000000000000000..7845dd6bf835ebab4cc5d8b65962b7347b7711cf --- /dev/null +++ b/TTS/vocoder/configs/parallel_wavegan_config.py @@ -0,0 +1,133 @@ +from dataclasses import dataclass, field + +from .shared_configs import BaseGANVocoderConfig + + +@dataclass +class ParallelWaveganConfig(BaseGANVocoderConfig): + """Defines parameters for ParallelWavegan vocoder. + + Args: + model (str): + Model name used for selecting the right configuration at initialization. Defaults to `gan`. 
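The optimizer and scheduler fields in these GAN configs are plain strings plus kwargs dictionaries; the training code that consumes them is not part of this section, so the following is only a hypothetical sketch of the usual `getattr`-style resolution such fields imply:

```python
import torch

from TTS.vocoder.configs import MultibandMelganConfig

config = MultibandMelganConfig()

generator = torch.nn.Linear(80, 80)  # stand-in module for illustration only
optimizer = getattr(torch.optim, config.optimizer)(
    generator.parameters(), lr=config.lr_gen, **config.optimizer_params
)
scheduler = getattr(torch.optim.lr_scheduler, config.lr_scheduler_gen)(
    optimizer, **config.lr_scheduler_gen_params
)
# -> AdamW(betas=[0.8, 0.99], weight_decay=0.0) stepped by MultiStepLR(gamma=0.5, milestones=[...])
```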
+ discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to + 'parallel_wavegan_discriminator`. + discriminator_model_params (dict): The discriminator model kwargs. Defaults to + '{"num_layers": 10}` + generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is + considered as a generator too. Defaults to `parallel_wavegan_generator`. + generator_model_param (dict): + The generator model kwargs. Defaults to `{"upsample_factors": [4, 4, 4, 4], "stacks": 3, "num_res_blocks": 30}`. + batch_size (int): + Batch size used at training. Larger values use more memory. Defaults to 16. + seq_len (int): + Audio segment length used at training. Larger values use more memory. Defaults to 8192. + pad_short (int): + Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0. + use_noise_augment (bool): + enable / disable random noise added to the input waveform. The noise is added after computing the + features. Defaults to True. + use_cache (bool): + enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is + not large enough. Defaults to True. + steps_to_start_discriminator (int): + Number of steps required to start training the discriminator. Defaults to 0. + use_stft_loss (bool):` + enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True. + use_subband_stft (bool): + enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True. + use_mse_gan_loss (bool): + enable / disable using Mean Squeare Error GAN loss. Defaults to True. + use_hinge_gan_loss (bool): + enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models. + Defaults to False. + use_feat_match_loss (bool): + enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True. + use_l1_spec_loss (bool): + enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False. + stft_loss_params (dict): STFT loss parameters. Default to + `{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}` + stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total + model loss. Defaults to 0.5. + subband_stft_loss_weight (float): + Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + mse_G_loss_weight (float): + MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5. + hinge_G_loss_weight (float): + Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + feat_match_loss_weight (float): + Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 0. + l1_spec_loss_weight (float): + L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + lr_gen (float): + Generator model initial learning rate. Defaults to 0.0002. + lr_disc (float): + Discriminator model initial learning rate. Defaults to 0.0002. + optimizer (torch.optim.Optimizer): + Optimizer used for the training. Defaults to `AdamW`. + optimizer_params (dict): + Optimizer kwargs. 
Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}` + lr_scheduler_gen (torch.optim.Scheduler): + Learning rate scheduler for the generator. Defaults to `ExponentialLR`. + lr_scheduler_gen_params (dict): + Parameters for the generator learning rate scheduler. Defaults to `{"gamma": 0.5, "step_size": 200000, "last_epoch": -1}`. + lr_scheduler_disc (torch.optim.Scheduler): + Learning rate scheduler for the discriminator. Defaults to `ExponentialLR`. + lr_scheduler_dict_params (dict): + Parameters for the discriminator learning rate scheduler. Defaults to `{"gamma": 0.5, "step_size": 200000, "last_epoch": -1}`. + """ + + model: str = "parallel_wavegan" + + # Model specific params + discriminator_model: str = "parallel_wavegan_discriminator" + discriminator_model_params: dict = field(default_factory=lambda: {"num_layers": 10}) + generator_model: str = "parallel_wavegan_generator" + generator_model_params: dict = field( + default_factory=lambda: {"upsample_factors": [4, 4, 4, 4], "stacks": 3, "num_res_blocks": 30} + ) + + # Training - overrides + batch_size: int = 6 + seq_len: int = 25600 + pad_short: int = 2000 + use_noise_augment: bool = False + use_cache: bool = True + steps_to_start_discriminator: int = 200000 + + # LOSS PARAMETERS - overrides + use_stft_loss: bool = True + use_subband_stft_loss: bool = False + use_mse_gan_loss: bool = True + use_hinge_gan_loss: bool = False + use_feat_match_loss: bool = False # requires MelGAN Discriminators (MelGAN and HifiGAN) + use_l1_spec_loss: bool = False + + stft_loss_params: dict = field( + default_factory=lambda: { + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240], + } + ) + + # loss weights - overrides + stft_loss_weight: float = 0.5 + subband_stft_loss_weight: float = 0 + mse_G_loss_weight: float = 2.5 + hinge_G_loss_weight: float = 0 + feat_match_loss_weight: float = 0 + l1_spec_loss_weight: float = 0 + + # optimizer overrides + lr_gen: float = 0.0002 # Initial learning rate. + lr_disc: float = 0.0002 # Initial learning rate. + optimizer: str = "AdamW" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "weight_decay": 0.0}) + lr_scheduler_gen: str = "StepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1}) + lr_scheduler_disc: str = "StepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_disc_params: dict = field( + default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1} + ) + scheduler_after_epoch: bool = False diff --git a/TTS/vocoder/configs/shared_configs.py b/TTS/vocoder/configs/shared_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..a558cfcabbc2abc26be60065d3ac75cebd829f28 --- /dev/null +++ b/TTS/vocoder/configs/shared_configs.py @@ -0,0 +1,182 @@ +from dataclasses import dataclass, field + +from TTS.config import BaseAudioConfig, BaseTrainingConfig + + +@dataclass +class BaseVocoderConfig(BaseTrainingConfig): + """Shared parameters among all the vocoder models. + Args: + audio (BaseAudioConfig): + Audio processor config instance. Defaultsto `BaseAudioConfig()`. + use_noise_augment (bool): + Augment the input audio with random noise. Defaults to False/ + eval_split_size (int): + Number of instances used for evaluation. Defaults to 10. + data_path (str): + Root path of the training data. 
All the audio files found recursively from this root path are used for + training. Defaults to `""`. + feature_path (str): + Root path to the precomputed feature files. Defaults to None. + seq_len (int): + Length of the waveform segments used for training. Defaults to 1000. + pad_short (int): + Extra padding for the waveforms shorter than `seq_len`. Defaults to 0. + conv_path (int): + Extra padding for the feature frames against convolution of the edge frames. Defaults to MISSING. + Defaults to 0. + use_cache (bool): + enable / disable in memory caching of the computed features. If the RAM is not enough, if may cause OOM. + Defaults to False. + epochs (int): + Number of training epochs to. Defaults to 10000. + wd (float): + Weight decay. + optimizer (torch.optim.Optimizer): + Optimizer used for the training. Defaults to `AdamW`. + optimizer_params (dict): + Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}` + """ + + audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) + # dataloading + use_noise_augment: bool = False # enable/disable random noise augmentation in spectrograms. + eval_split_size: int = 10 # number of samples used for evaluation. + # dataset + data_path: str = "" # root data path. It finds all wav files recursively from there. + feature_path: str = None # if you use precomputed features + seq_len: int = 1000 # signal length used in training. + pad_short: int = 0 # additional padding for short wavs + conv_pad: int = 0 # additional padding against convolutions applied to spectrograms + use_cache: bool = False # use in memory cache to keep the computed features. This might cause OOM. + # OPTIMIZER + epochs: int = 10000 # total number of epochs to train. + wd: float = 0.0 # Weight decay weight. + optimizer: str = "AdamW" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "weight_decay": 0.0}) + + +@dataclass +class BaseGANVocoderConfig(BaseVocoderConfig): + """Base config class used among all the GAN based vocoders. + Args: + use_stft_loss (bool): + enable / disable the use of STFT loss. Defaults to True. + use_subband_stft_loss (bool): + enable / disable the use of Subband STFT loss. Defaults to True. + use_mse_gan_loss (bool): + enable / disable the use of Mean Squared Error based GAN loss. Defaults to True. + use_hinge_gan_loss (bool): + enable / disable the use of Hinge GAN loss. Defaults to True. + use_feat_match_loss (bool): + enable / disable feature matching loss. Defaults to True. + use_l1_spec_loss (bool): + enable / disable L1 spectrogram loss. Defaults to True. + stft_loss_weight (float): + Loss weight that multiplies the computed loss value. Defaults to 0. + subband_stft_loss_weight (float): + Loss weight that multiplies the computed loss value. Defaults to 0. + mse_G_loss_weight (float): + Loss weight that multiplies the computed loss value. Defaults to 1. + hinge_G_loss_weight (float): + Loss weight that multiplies the computed loss value. Defaults to 0. + feat_match_loss_weight (float): + Loss weight that multiplies the computed loss value. Defaults to 100. + l1_spec_loss_weight (float): + Loss weight that multiplies the computed loss value. Defaults to 45. + stft_loss_params (dict): + Parameters for the STFT loss. Defaults to `{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}`. + l1_spec_loss_params (dict): + Parameters for the L1 spectrogram loss. 
Defaults to + `{ + "use_mel": True, + "sample_rate": 22050, + "n_fft": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mels": 80, + "mel_fmin": 0.0, + "mel_fmax": None, + }` + target_loss (str): + Target loss name that defines the quality of the model. Defaults to `G_avg_loss`. + grad_clip (list): + A list of gradient clipping theresholds for each optimizer. Any value less than 0 disables clipping. + Defaults to [5, 5]. + lr_gen (float): + Generator model initial learning rate. Defaults to 0.0002. + lr_disc (float): + Discriminator model initial learning rate. Defaults to 0.0002. + lr_scheduler_gen (torch.optim.Scheduler): + Learning rate scheduler for the generator. Defaults to `ExponentialLR`. + lr_scheduler_gen_params (dict): + Parameters for the generator learning rate scheduler. Defaults to `{"gamma": 0.999, "last_epoch": -1}`. + lr_scheduler_disc (torch.optim.Scheduler): + Learning rate scheduler for the discriminator. Defaults to `ExponentialLR`. + lr_scheduler_disc_params (dict): + Parameters for the discriminator learning rate scheduler. Defaults to `{"gamma": 0.999, "last_epoch": -1}`. + scheduler_after_epoch (bool): + Whether to update the learning rate schedulers after each epoch. Defaults to True. + use_pqmf (bool): + enable / disable PQMF for subband approximation at training. Defaults to False. + steps_to_start_discriminator (int): + Number of steps required to start training the discriminator. Defaults to 0. + diff_samples_for_G_and_D (bool): + enable / disable use of different training samples for the generator and the discriminator iterations. + Enabling it results in slower iterations but faster convergance in some cases. Defaults to False. + """ + + model: str = "gan" + + # LOSS PARAMETERS + use_stft_loss: bool = True + use_subband_stft_loss: bool = True + use_mse_gan_loss: bool = True + use_hinge_gan_loss: bool = True + use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN) + use_l1_spec_loss: bool = True + + # loss weights + stft_loss_weight: float = 0 + subband_stft_loss_weight: float = 0 + mse_G_loss_weight: float = 1 + hinge_G_loss_weight: float = 0 + feat_match_loss_weight: float = 100 + l1_spec_loss_weight: float = 45 + + stft_loss_params: dict = field( + default_factory=lambda: { + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240], + } + ) + + l1_spec_loss_params: dict = field( + default_factory=lambda: { + "use_mel": True, + "sample_rate": 22050, + "n_fft": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mels": 80, + "mel_fmin": 0.0, + "mel_fmax": None, + } + ) + + target_loss: str = "loss_0" # loss value to pick the best model to save after each epoch + + # optimizer + grad_clip: float = field(default_factory=lambda: [5, 5]) + lr_gen: float = 0.0002 # Initial learning rate. + lr_disc: float = 0.0002 # Initial learning rate. + lr_scheduler_gen: str = "ExponentialLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1}) + lr_scheduler_disc: str = "ExponentialLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1}) + scheduler_after_epoch: bool = True + + use_pqmf: bool = False # enable/disable using pqmf for multi-band training. 
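# A minimal usage sketch (illustrative only, not part of this config class): when
# `use_pqmf` is enabled, a multi-band generator predicts N sub-band signals and the
# full-band waveform is recovered with the PQMF filter bank added in this PR under
# TTS/vocoder/layers/pqmf.py. Shapes below assume a 4-band setup and a dummy waveform.

import torch
from TTS.vocoder.layers.pqmf import PQMF

pqmf = PQMF(N=4)                    # 4 sub-bands with the default prototype filter
wav = torch.randn(1, 1, 16384)      # (B, 1, T) full-band audio
subbands = pqmf.analysis(wav)       # (B, 4, ~T // 4) sub-band signals
recon = pqmf.synthesis(subbands)    # back to (B, 1, T), up to edge padding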
(Multi-band MelGAN) + steps_to_start_discriminator = 0 # start training the discriminator after this number of steps. + diff_samples_for_G_and_D: bool = False # use different samples for G and D training steps. diff --git a/TTS/vocoder/configs/univnet_config.py b/TTS/vocoder/configs/univnet_config.py new file mode 100644 index 0000000000000000000000000000000000000000..67f324cfce5f701f0d7453beab81590bef6be114 --- /dev/null +++ b/TTS/vocoder/configs/univnet_config.py @@ -0,0 +1,161 @@ +from dataclasses import dataclass, field +from typing import Dict + +from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig + + +@dataclass +class UnivnetConfig(BaseGANVocoderConfig): + """Defines parameters for UnivNet vocoder. + + Example: + + >>> from TTS.vocoder.configs import UnivNetConfig + >>> config = UnivNetConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `UnivNet`. + discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to + 'UnivNet_discriminator`. + generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is + considered as a generator too. Defaults to `UnivNet_generator`. + generator_model_params (dict): Parameters of the generator model. Defaults to + ` + { + "use_mel": True, + "sample_rate": 22050, + "n_fft": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mels": 80, + "mel_fmin": 0.0, + "mel_fmax": None, + } + ` + batch_size (int): + Batch size used at training. Larger values use more memory. Defaults to 32. + seq_len (int): + Audio segment length used at training. Larger values use more memory. Defaults to 8192. + pad_short (int): + Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0. + use_noise_augment (bool): + enable / disable random noise added to the input waveform. The noise is added after computing the + features. Defaults to True. + use_cache (bool): + enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is + not large enough. Defaults to True. + use_stft_loss (bool): + enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True. + use_subband_stft (bool): + enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True. + use_mse_gan_loss (bool): + enable / disable using Mean Squeare Error GAN loss. Defaults to True. + use_hinge_gan_loss (bool): + enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models. + Defaults to False. + use_feat_match_loss (bool): + enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True. + use_l1_spec_loss (bool): + enable / disable using L1 spectrogram loss originally used by univnet model. Defaults to False. + stft_loss_params (dict): + STFT loss parameters. Default to + `{ + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240] + }` + l1_spec_loss_params (dict): + L1 spectrogram loss parameters. Default to + `{ + "use_mel": True, + "sample_rate": 22050, + "n_fft": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mels": 80, + "mel_fmin": 0.0, + "mel_fmax": None, + }` + stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total + model loss. Defaults to 0.5. 
+ subband_stft_loss_weight (float): + Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + mse_G_loss_weight (float): + MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5. + hinge_G_loss_weight (float): + Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + feat_match_loss_weight (float): + Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108. + l1_spec_loss_weight (float): + L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + """ + + model: str = "univnet" + batch_size: int = 32 + # model specific params + discriminator_model: str = "univnet_discriminator" + generator_model: str = "univnet_generator" + generator_model_params: Dict = field( + default_factory=lambda: { + "in_channels": 64, + "out_channels": 1, + "hidden_channels": 32, + "cond_channels": 80, + "upsample_factors": [8, 8, 4], + "lvc_layers_each_block": 4, + "lvc_kernel_size": 3, + "kpnet_hidden_channels": 64, + "kpnet_conv_size": 3, + "dropout": 0.0, + } + ) + + # LOSS PARAMETERS - overrides + use_stft_loss: bool = True + use_subband_stft_loss: bool = False + use_mse_gan_loss: bool = True + use_hinge_gan_loss: bool = False + use_feat_match_loss: bool = False # requires MelGAN Discriminators (MelGAN and univnet) + use_l1_spec_loss: bool = False + + # loss weights - overrides + stft_loss_weight: float = 2.5 + stft_loss_params: Dict = field( + default_factory=lambda: { + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240], + } + ) + subband_stft_loss_weight: float = 0 + mse_G_loss_weight: float = 1 + hinge_G_loss_weight: float = 0 + feat_match_loss_weight: float = 0 + l1_spec_loss_weight: float = 0 + l1_spec_loss_params: Dict = field( + default_factory=lambda: { + "use_mel": True, + "sample_rate": 22050, + "n_fft": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mels": 80, + "mel_fmin": 0.0, + "mel_fmax": None, + } + ) + + # optimizer parameters + lr_gen: float = 1e-4 # Initial learning rate. + lr_disc: float = 1e-4 # Initial learning rate. + lr_scheduler_gen: str = None # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + # lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1}) + lr_scheduler_disc: str = None # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + # lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1}) + optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.5, 0.9], "weight_decay": 0.0}) + steps_to_start_discriminator: int = 200000 + + def __post_init__(self): + super().__post_init__() + self.generator_model_params["cond_channels"] = self.audio.num_mels diff --git a/TTS/vocoder/configs/wavegrad_config.py b/TTS/vocoder/configs/wavegrad_config.py new file mode 100644 index 0000000000000000000000000000000000000000..c39813ae68c3d8c77614c9a5188ac5f2a59d991d --- /dev/null +++ b/TTS/vocoder/configs/wavegrad_config.py @@ -0,0 +1,90 @@ +from dataclasses import dataclass, field + +from TTS.vocoder.configs.shared_configs import BaseVocoderConfig +from TTS.vocoder.models.wavegrad import WavegradArgs + + +@dataclass +class WavegradConfig(BaseVocoderConfig): + """Defines parameters for WaveGrad vocoder. 
+ Example: + + >>> from TTS.vocoder.configs import WavegradConfig + >>> config = WavegradConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `wavegrad`. + generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is + considered as a generator too. Defaults to `wavegrad`. + model_params (WavegradArgs): Model parameters. Check `WavegradArgs` for default values. + target_loss (str): + Target loss name that defines the quality of the model. Defaults to `avg_wavegrad_loss`. + epochs (int): + Number of epochs to traing the model. Defaults to 10000. + batch_size (int): + Batch size used at training. Larger values use more memory. Defaults to 96. + seq_len (int): + Audio segment length used at training. Larger values use more memory. Defaults to 6144. + use_cache (bool): + enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is + not large enough. Defaults to True. + mixed_precision (bool): + enable / disable mixed precision training. Default is True. + eval_split_size (int): + Number of samples used for evalutaion. Defaults to 50. + train_noise_schedule (dict): + Training noise schedule. Defaults to + `{"min_val": 1e-6, "max_val": 1e-2, "num_steps": 1000}` + test_noise_schedule (dict): + Inference noise schedule. For a better performance, you may need to use `bin/tune_wavegrad.py` to find a + better schedule. Defaults to + ` + { + "min_val": 1e-6, + "max_val": 1e-2, + "num_steps": 50, + } + ` + grad_clip (float): + Gradient clipping threshold. If <= 0.0, no clipping is applied. Defaults to 1.0 + lr (float): + Initila leraning rate. Defaults to 1e-4. + lr_scheduler (str): + One of the learning rate schedulers from `torch.optim.scheduler.*`. Defaults to `MultiStepLR`. + lr_scheduler_params (dict): + kwargs for the scheduler. Defaults to `{"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}` + """ + + model: str = "wavegrad" + # Model specific params + generator_model: str = "wavegrad" + model_params: WavegradArgs = field(default_factory=WavegradArgs) + target_loss: str = "loss" # loss value to pick the best model to save after each epoch + + # Training - overrides + epochs: int = 10000 + batch_size: int = 96 + seq_len: int = 6144 + use_cache: bool = True + mixed_precision: bool = True + eval_split_size: int = 50 + + # NOISE SCHEDULE PARAMS + train_noise_schedule: dict = field(default_factory=lambda: {"min_val": 1e-6, "max_val": 1e-2, "num_steps": 1000}) + + test_noise_schedule: dict = field( + default_factory=lambda: { # inference noise schedule. Try TTS/bin/tune_wavegrad.py to find the optimal values. + "min_val": 1e-6, + "max_val": 1e-2, + "num_steps": 50, + } + ) + + # optimizer overrides + grad_clip: float = 1.0 + lr: float = 1e-4 # Initial learning rate. 
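# A minimal sketch (illustrative only, not the training code in this PR): the
# `train_noise_schedule` / `test_noise_schedule` dicts above only store the end
# points and the number of steps; they are typically expanded into a linear range
# of betas for the diffusion process, e.g.:

import numpy as np
from TTS.vocoder.configs import WavegradConfig

config = WavegradConfig()
sched = config.train_noise_schedule
betas = np.linspace(sched["min_val"], sched["max_val"], sched["num_steps"])  # 1000 noise levels
alphas = np.cumprod(1.0 - betas)  # product term commonly used to derive per-step noise levels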
+ lr_scheduler: str = "MultiStepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_params: dict = field( + default_factory=lambda: {"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]} + ) diff --git a/TTS/vocoder/configs/wavernn_config.py b/TTS/vocoder/configs/wavernn_config.py new file mode 100644 index 0000000000000000000000000000000000000000..f39400e5e50b56d4ff79c8c148fd518b3ec3b390 --- /dev/null +++ b/TTS/vocoder/configs/wavernn_config.py @@ -0,0 +1,102 @@ +from dataclasses import dataclass, field + +from TTS.vocoder.configs.shared_configs import BaseVocoderConfig +from TTS.vocoder.models.wavernn import WavernnArgs + + +@dataclass +class WavernnConfig(BaseVocoderConfig): + """Defines parameters for Wavernn vocoder. + Example: + + >>> from TTS.vocoder.configs import WavernnConfig + >>> config = WavernnConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `wavernn`. + mode (str): + Output mode of the WaveRNN vocoder. `mold` for Mixture of Logistic Distribution, `gauss` for a single + Gaussian Distribution and `bits` for quantized bits as the model's output. + mulaw (bool): + enable / disable the use of Mulaw quantization for training. Only applicable if `mode == 'bits'`. Defaults + to `True`. + generator_model (str): + One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is + considered as a generator too. Defaults to `WaveRNN`. + wavernn_model_params (dict): + kwargs for the WaveRNN model. Defaults to + `{ + "rnn_dims": 512, + "fc_dims": 512, + "compute_dims": 128, + "res_out_dims": 128, + "num_res_blocks": 10, + "use_aux_net": True, + "use_upsample_net": True, + "upsample_factors": [4, 8, 8] + }` + batched (bool): + enable / disable the batched inference. It speeds up the inference by splitting the input into segments and + processing the segments in a batch. Then it merges the outputs with a certain overlap and smoothing. If + you set it False, without CUDA, it is too slow to be practical. Defaults to True. + target_samples (int): + Size of the segments in batched mode. Defaults to 11000. + overlap_sampels (int): + Size of the overlap between consecutive segments. Defaults to 550. + batch_size (int): + Batch size used at training. Larger values use more memory. Defaults to 256. + seq_len (int): + Audio segment length used at training. Larger values use more memory. Defaults to 1280. + + use_noise_augment (bool): + enable / disable random noise added to the input waveform. The noise is added after computing the + features. Defaults to True. + use_cache (bool): + enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is + not large enough. Defaults to True. + mixed_precision (bool): + enable / disable mixed precision training. Default is True. + eval_split_size (int): + Number of samples used for evalutaion. Defaults to 50. + num_epochs_before_test (int): + Number of epochs waited to run the next evalution. Since inference takes some time, it is better to + wait some number of epochs not ot waste training time. Defaults to 10. + grad_clip (float): + Gradient clipping threshold. If <= 0.0, no clipping is applied. Defaults to 4.0 + lr (float): + Initila leraning rate. Defaults to 1e-4. + lr_scheduler (str): + One of the learning rate schedulers from `torch.optim.scheduler.*`. Defaults to `MultiStepLR`. + lr_scheduler_params (dict): + kwargs for the scheduler. 
Defaults to `{"gamma": 0.5, "milestones": [200000, 400000, 600000]}` + """ + + model: str = "wavernn" + + # Model specific params + model_args: WavernnArgs = field(default_factory=WavernnArgs) + target_loss: str = "loss" + + # Inference + batched: bool = True + target_samples: int = 11000 + overlap_samples: int = 550 + + # Training - overrides + epochs: int = 10000 + batch_size: int = 256 + seq_len: int = 1280 + use_noise_augment: bool = False + use_cache: bool = True + mixed_precision: bool = True + eval_split_size: int = 50 + num_epochs_before_test: int = ( + 10 # number of epochs to wait until the next test run (synthesizing a full audio clip). + ) + + # optimizer overrides + grad_clip: float = 4.0 + lr: float = 1e-4 # Initial learning rate. + lr_scheduler: str = "MultiStepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_params: dict = field(default_factory=lambda: {"gamma": 0.5, "milestones": [200000, 400000, 600000]}) diff --git a/TTS/vocoder/datasets/__init__.py b/TTS/vocoder/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..871eb0d20276ffc691fd6da796bf65df6c23ea0d --- /dev/null +++ b/TTS/vocoder/datasets/__init__.py @@ -0,0 +1,58 @@ +from typing import List + +from coqpit import Coqpit +from torch.utils.data import Dataset + +from TTS.utils.audio import AudioProcessor +from TTS.vocoder.datasets.gan_dataset import GANDataset +from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data +from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset +from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset + + +def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool) -> Dataset: + if config.model.lower() in "gan": + dataset = GANDataset( + ap=ap, + items=data_items, + seq_len=config.seq_len, + hop_len=ap.hop_length, + pad_short=config.pad_short, + conv_pad=config.conv_pad, + return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False, + is_training=not is_eval, + return_segments=not is_eval, + use_noise_augment=config.use_noise_augment, + use_cache=config.use_cache, + verbose=verbose, + ) + dataset.shuffle_mapping() + elif config.model.lower() == "wavegrad": + dataset = WaveGradDataset( + ap=ap, + items=data_items, + seq_len=config.seq_len, + hop_len=ap.hop_length, + pad_short=config.pad_short, + conv_pad=config.conv_pad, + is_training=not is_eval, + return_segments=True, + use_noise_augment=False, + use_cache=config.use_cache, + verbose=verbose, + ) + elif config.model.lower() == "wavernn": + dataset = WaveRNNDataset( + ap=ap, + items=data_items, + seq_len=config.seq_len, + hop_len=ap.hop_length, + pad=config.model_params.pad, + mode=config.model_params.mode, + mulaw=config.model_params.mulaw, + is_training=not is_eval, + verbose=verbose, + ) + else: + raise ValueError(f" [!] 
Dataset for model {config.model.lower()} cannot be found.") + return dataset diff --git a/TTS/vocoder/datasets/gan_dataset.py b/TTS/vocoder/datasets/gan_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a782067e1badef3522ac5b7d1b6407e3f291502a --- /dev/null +++ b/TTS/vocoder/datasets/gan_dataset.py @@ -0,0 +1,153 @@ +import glob +import os +import random +from multiprocessing import Manager + +import numpy as np +import torch +from torch.utils.data import Dataset + + +class GANDataset(Dataset): + """ + GAN Dataset searchs for all the wav files under root path + and converts them to acoustic features on the fly and returns + random segments of (audio, feature) couples. + """ + + def __init__( + self, + ap, + items, + seq_len, + hop_len, + pad_short, + conv_pad=2, + return_pairs=False, + is_training=True, + return_segments=True, + use_noise_augment=False, + use_cache=False, + verbose=False, + ): + super().__init__() + self.ap = ap + self.item_list = items + self.compute_feat = not isinstance(items[0], (tuple, list)) + self.seq_len = seq_len + self.hop_len = hop_len + self.pad_short = pad_short + self.conv_pad = conv_pad + self.return_pairs = return_pairs + self.is_training = is_training + self.return_segments = return_segments + self.use_cache = use_cache + self.use_noise_augment = use_noise_augment + self.verbose = verbose + + assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len." + self.feat_frame_len = seq_len // hop_len + (2 * conv_pad) + + # map G and D instances + self.G_to_D_mappings = list(range(len(self.item_list))) + self.shuffle_mapping() + + # cache acoustic features + if use_cache: + self.create_feature_cache() + + def create_feature_cache(self): + self.manager = Manager() + self.cache = self.manager.list() + self.cache += [None for _ in range(len(self.item_list))] + + @staticmethod + def find_wav_files(path): + return glob.glob(os.path.join(path, "**", "*.wav"), recursive=True) + + def __len__(self): + return len(self.item_list) + + def __getitem__(self, idx): + """Return different items for Generator and Discriminator and + cache acoustic features""" + + # set the seed differently for each worker + if torch.utils.data.get_worker_info(): + random.seed(torch.utils.data.get_worker_info().seed) + + if self.return_segments: + item1 = self.load_item(idx) + if self.return_pairs: + idx2 = self.G_to_D_mappings[idx] + item2 = self.load_item(idx2) + return item1, item2 + return item1 + item1 = self.load_item(idx) + return item1 + + def _pad_short_samples(self, audio, mel=None): + """Pad samples shorter than the output sequence length""" + if len(audio) < self.seq_len: + audio = np.pad(audio, (0, self.seq_len - len(audio)), mode="constant", constant_values=0.0) + + if mel is not None and mel.shape[1] < self.feat_frame_len: + pad_value = self.ap.melspectrogram(np.zeros([self.ap.win_length]))[:, 0] + mel = np.pad( + mel, + ([0, 0], [0, self.feat_frame_len - mel.shape[1]]), + mode="constant", + constant_values=pad_value.mean(), + ) + return audio, mel + + def shuffle_mapping(self): + random.shuffle(self.G_to_D_mappings) + + def load_item(self, idx): + """load (audio, feat) couple""" + if self.compute_feat: + # compute features from wav + wavpath = self.item_list[idx] + # print(wavpath) + + if self.use_cache and self.cache[idx] is not None: + audio, mel = self.cache[idx] + else: + audio = self.ap.load_wav(wavpath) + mel = self.ap.melspectrogram(audio) + audio, mel = self._pad_short_samples(audio, mel) + else: + + # load precomputed 
features + wavpath, feat_path = self.item_list[idx] + + if self.use_cache and self.cache[idx] is not None: + audio, mel = self.cache[idx] + else: + audio = self.ap.load_wav(wavpath) + mel = np.load(feat_path) + audio, mel = self._pad_short_samples(audio, mel) + + # correct the audio length wrt padding applied in stft + audio = np.pad(audio, (0, self.hop_len), mode="edge") + audio = audio[: mel.shape[-1] * self.hop_len] + assert ( + mel.shape[-1] * self.hop_len == audio.shape[-1] + ), f" [!] {mel.shape[-1] * self.hop_len} vs {audio.shape[-1]}" + + audio = torch.from_numpy(audio).float().unsqueeze(0) + mel = torch.from_numpy(mel).float().squeeze(0) + + if self.return_segments: + max_mel_start = mel.shape[1] - self.feat_frame_len + mel_start = random.randint(0, max_mel_start) + mel_end = mel_start + self.feat_frame_len + mel = mel[:, mel_start:mel_end] + + audio_start = mel_start * self.hop_len + audio = audio[:, audio_start : audio_start + self.seq_len] + + if self.use_noise_augment and self.is_training and self.return_segments: + audio = audio + (1 / 32768) * torch.randn_like(audio) + return (mel, audio) diff --git a/TTS/vocoder/datasets/preprocess.py b/TTS/vocoder/datasets/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..0f69b812fa58949eadc78b450114f03b19e5c80c --- /dev/null +++ b/TTS/vocoder/datasets/preprocess.py @@ -0,0 +1,70 @@ +import glob +import os +from pathlib import Path + +import numpy as np +from coqpit import Coqpit +from tqdm import tqdm + +from TTS.utils.audio import AudioProcessor + + +def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor): + """Process wav and compute mel and quantized wave signal. + It is mainly used by WaveRNN dataloader. + + Args: + out_path (str): Parent folder path to save the files. + config (Coqpit): Model config. + ap (AudioProcessor): Audio processor. + """ + os.makedirs(os.path.join(out_path, "quant"), exist_ok=True) + os.makedirs(os.path.join(out_path, "mel"), exist_ok=True) + wav_files = find_wav_files(config.data_path) + for path in tqdm(wav_files): + wav_name = Path(path).stem + quant_path = os.path.join(out_path, "quant", wav_name + ".npy") + mel_path = os.path.join(out_path, "mel", wav_name + ".npy") + y = ap.load_wav(path) + mel = ap.melspectrogram(y) + np.save(mel_path, mel) + if isinstance(config.mode, int): + quant = ap.mulaw_encode(y, qc=config.mode) if config.model_args.mulaw else ap.quantize(y, bits=config.mode) + np.save(quant_path, quant) + + +def find_wav_files(data_path, file_ext="wav"): + wav_paths = glob.glob(os.path.join(data_path, "**", f"*.{file_ext}"), recursive=True) + return wav_paths + + +def find_feat_files(data_path): + feat_paths = glob.glob(os.path.join(data_path, "**", "*.npy"), recursive=True) + return feat_paths + + +def load_wav_data(data_path, eval_split_size, file_ext="wav"): + wav_paths = find_wav_files(data_path, file_ext=file_ext) + assert len(wav_paths) > 0, f" [!] {data_path} is empty." + np.random.seed(0) + np.random.shuffle(wav_paths) + return wav_paths[:eval_split_size], wav_paths[eval_split_size:] + + +def load_wav_feat_data(data_path, feat_path, eval_split_size): + wav_paths = find_wav_files(data_path) + feat_paths = find_feat_files(feat_path) + + wav_paths.sort(key=lambda x: Path(x).stem) + feat_paths.sort(key=lambda x: Path(x).stem) + + assert len(wav_paths) == len(feat_paths), f" [!] 
{len(wav_paths)} vs {feat_paths}" + for wav, feat in zip(wav_paths, feat_paths): + wav_name = Path(wav).stem + feat_name = Path(feat).stem + assert wav_name == feat_name + + items = list(zip(wav_paths, feat_paths)) + np.random.seed(0) + np.random.shuffle(items) + return items[:eval_split_size], items[eval_split_size:] diff --git a/TTS/vocoder/datasets/wavegrad_dataset.py b/TTS/vocoder/datasets/wavegrad_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..05e0fae8873d8606ddf4ab2743b3cf1f47db85f9 --- /dev/null +++ b/TTS/vocoder/datasets/wavegrad_dataset.py @@ -0,0 +1,152 @@ +import glob +import os +import random +from multiprocessing import Manager +from typing import List, Tuple + +import numpy as np +import torch +from torch.utils.data import Dataset + + +class WaveGradDataset(Dataset): + """ + WaveGrad Dataset searchs for all the wav files under root path + and converts them to acoustic features on the fly and returns + random segments of (audio, feature) couples. + """ + + def __init__( + self, + ap, + items, + seq_len, + hop_len, + pad_short, + conv_pad=2, + is_training=True, + return_segments=True, + use_noise_augment=False, + use_cache=False, + verbose=False, + ): + + super().__init__() + self.ap = ap + self.item_list = items + self.seq_len = seq_len if return_segments else None + self.hop_len = hop_len + self.pad_short = pad_short + self.conv_pad = conv_pad + self.is_training = is_training + self.return_segments = return_segments + self.use_cache = use_cache + self.use_noise_augment = use_noise_augment + self.verbose = verbose + + if return_segments: + assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len." + self.feat_frame_len = seq_len // hop_len + (2 * conv_pad) + + # cache acoustic features + if use_cache: + self.create_feature_cache() + + def create_feature_cache(self): + self.manager = Manager() + self.cache = self.manager.list() + self.cache += [None for _ in range(len(self.item_list))] + + @staticmethod + def find_wav_files(path): + return glob.glob(os.path.join(path, "**", "*.wav"), recursive=True) + + def __len__(self): + return len(self.item_list) + + def __getitem__(self, idx): + item = self.load_item(idx) + return item + + def load_test_samples(self, num_samples: int) -> List[Tuple]: + """Return test samples. + + Args: + num_samples (int): Number of samples to return. + + Returns: + List[Tuple]: melspectorgram and audio. 
+ + Shapes: + - melspectrogram (Tensor): :math:`[C, T]` + - audio (Tensor): :math:`[T_audio]` + """ + samples = [] + return_segments = self.return_segments + self.return_segments = False + for idx in range(num_samples): + mel, audio = self.load_item(idx) + samples.append([mel, audio]) + self.return_segments = return_segments + return samples + + def load_item(self, idx): + """load (audio, feat) couple""" + # compute features from wav + wavpath = self.item_list[idx] + + if self.use_cache and self.cache[idx] is not None: + audio = self.cache[idx] + else: + audio = self.ap.load_wav(wavpath) + + if self.return_segments: + # correct audio length wrt segment length + if audio.shape[-1] < self.seq_len + self.pad_short: + audio = np.pad( + audio, (0, self.seq_len + self.pad_short - len(audio)), mode="constant", constant_values=0.0 + ) + assert ( + audio.shape[-1] >= self.seq_len + self.pad_short + ), f"{audio.shape[-1]} vs {self.seq_len + self.pad_short}" + + # correct the audio length wrt hop length + p = (audio.shape[-1] // self.hop_len + 1) * self.hop_len - audio.shape[-1] + audio = np.pad(audio, (0, p), mode="constant", constant_values=0.0) + + if self.use_cache: + self.cache[idx] = audio + + if self.return_segments: + max_start = len(audio) - self.seq_len + start = random.randint(0, max_start) + end = start + self.seq_len + audio = audio[start:end] + + if self.use_noise_augment and self.is_training and self.return_segments: + audio = audio + (1 / 32768) * torch.randn_like(audio) + + mel = self.ap.melspectrogram(audio) + mel = mel[..., :-1] # ignore the padding + + audio = torch.from_numpy(audio).float() + mel = torch.from_numpy(mel).float().squeeze(0) + return (mel, audio) + + @staticmethod + def collate_full_clips(batch): + """This is used in tune_wavegrad.py. + It pads sequences to the max length.""" + max_mel_length = max([b[0].shape[1] for b in batch]) if len(batch) > 1 else batch[0][0].shape[1] + max_audio_length = max([b[1].shape[0] for b in batch]) if len(batch) > 1 else batch[0][1].shape[0] + + mels = torch.zeros([len(batch), batch[0][0].shape[0], max_mel_length]) + audios = torch.zeros([len(batch), max_audio_length]) + + for idx, b in enumerate(batch): + mel = b[0] + audio = b[1] + mels[idx, :, : mel.shape[1]] = mel + audios[idx, : audio.shape[0]] = audio + + return audios, mels diff --git a/TTS/vocoder/datasets/wavernn_dataset.py b/TTS/vocoder/datasets/wavernn_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2c771cf0ed5bb228eb8f4aaa6c850665c4997170 --- /dev/null +++ b/TTS/vocoder/datasets/wavernn_dataset.py @@ -0,0 +1,117 @@ +import numpy as np +import torch +from torch.utils.data import Dataset + + +class WaveRNNDataset(Dataset): + """ + WaveRNN Dataset searchs for all the wav files under root path + and converts them to acoustic features on the fly. 
+ """ + + def __init__( + self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, verbose=False, return_segments=True + ): + + super().__init__() + self.ap = ap + self.compute_feat = not isinstance(items[0], (tuple, list)) + self.item_list = items + self.seq_len = seq_len + self.hop_len = hop_len + self.mel_len = seq_len // hop_len + self.pad = pad + self.mode = mode + self.mulaw = mulaw + self.is_training = is_training + self.verbose = verbose + self.return_segments = return_segments + + assert self.seq_len % self.hop_len == 0 + + def __len__(self): + return len(self.item_list) + + def __getitem__(self, index): + item = self.load_item(index) + return item + + def load_test_samples(self, num_samples): + samples = [] + return_segments = self.return_segments + self.return_segments = False + for idx in range(num_samples): + mel, audio, _ = self.load_item(idx) + samples.append([mel, audio]) + self.return_segments = return_segments + return samples + + def load_item(self, index): + """ + load (audio, feat) couple if feature_path is set + else compute it on the fly + """ + if self.compute_feat: + + wavpath = self.item_list[index] + audio = self.ap.load_wav(wavpath) + if self.return_segments: + min_audio_len = 2 * self.seq_len + (2 * self.pad * self.hop_len) + else: + min_audio_len = audio.shape[0] + (2 * self.pad * self.hop_len) + if audio.shape[0] < min_audio_len: + print(" [!] Instance is too short! : {}".format(wavpath)) + audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len]) + mel = self.ap.melspectrogram(audio) + + if self.mode in ["gauss", "mold"]: + x_input = audio + elif isinstance(self.mode, int): + x_input = ( + self.ap.mulaw_encode(audio, qc=self.mode) if self.mulaw else self.ap.quantize(audio, bits=self.mode) + ) + else: + raise RuntimeError("Unknown dataset mode - ", self.mode) + + else: + + wavpath, feat_path = self.item_list[index] + mel = np.load(feat_path.replace("/quant/", "/mel/")) + + if mel.shape[-1] < self.mel_len + 2 * self.pad: + print(" [!] Instance is too short! 
: {}".format(wavpath)) + self.item_list[index] = self.item_list[index + 1] + feat_path = self.item_list[index] + mel = np.load(feat_path.replace("/quant/", "/mel/")) + if self.mode in ["gauss", "mold"]: + x_input = self.ap.load_wav(wavpath) + elif isinstance(self.mode, int): + x_input = np.load(feat_path.replace("/mel/", "/quant/")) + else: + raise RuntimeError("Unknown dataset mode - ", self.mode) + + return mel, x_input, wavpath + + def collate(self, batch): + mel_win = self.seq_len // self.hop_len + 2 * self.pad + max_offsets = [x[0].shape[-1] - (mel_win + 2 * self.pad) for x in batch] + + mel_offsets = [np.random.randint(0, offset) for offset in max_offsets] + sig_offsets = [(offset + self.pad) * self.hop_len for offset in mel_offsets] + + mels = [x[0][:, mel_offsets[i] : mel_offsets[i] + mel_win] for i, x in enumerate(batch)] + + coarse = [x[1][sig_offsets[i] : sig_offsets[i] + self.seq_len + 1] for i, x in enumerate(batch)] + + mels = np.stack(mels).astype(np.float32) + if self.mode in ["gauss", "mold"]: + coarse = np.stack(coarse).astype(np.float32) + coarse = torch.FloatTensor(coarse) + x_input = coarse[:, : self.seq_len] + elif isinstance(self.mode, int): + coarse = np.stack(coarse).astype(np.int64) + coarse = torch.LongTensor(coarse) + x_input = 2 * coarse[:, : self.seq_len].float() / (2**self.mode - 1.0) - 1.0 + y_coarse = coarse[:, 1:] + mels = torch.FloatTensor(mels) + return x_input, mels, y_coarse diff --git a/TTS/vocoder/layers/__init__.py b/TTS/vocoder/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/vocoder/layers/hifigan.py b/TTS/vocoder/layers/hifigan.py new file mode 100644 index 0000000000000000000000000000000000000000..f51200724887b04746a125b7d7c368e0315ce7da --- /dev/null +++ b/TTS/vocoder/layers/hifigan.py @@ -0,0 +1,53 @@ +from torch import nn + + +# pylint: disable=dangerous-default-value +class ResStack(nn.Module): + def __init__(self, kernel, channel, padding, dilations=[1, 3, 5]): + super().__init__() + resstack = [] + for dilation in dilations: + resstack += [ + nn.LeakyReLU(0.2), + nn.ReflectionPad1d(dilation), + nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=kernel, dilation=dilation)), + nn.LeakyReLU(0.2), + nn.ReflectionPad1d(padding), + nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)), + ] + self.resstack = nn.Sequential(*resstack) + + self.shortcut = nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)) + + def forward(self, x): + x1 = self.shortcut(x) + x2 = self.resstack(x) + return x1 + x2 + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.shortcut) + nn.utils.remove_weight_norm(self.resstack[2]) + nn.utils.remove_weight_norm(self.resstack[5]) + nn.utils.remove_weight_norm(self.resstack[8]) + nn.utils.remove_weight_norm(self.resstack[11]) + nn.utils.remove_weight_norm(self.resstack[14]) + nn.utils.remove_weight_norm(self.resstack[17]) + + +class MRF(nn.Module): + def __init__(self, kernels, channel, dilations=[1, 3, 5]): # # pylint: disable=dangerous-default-value + super().__init__() + self.resblock1 = ResStack(kernels[0], channel, 0, dilations) + self.resblock2 = ResStack(kernels[1], channel, 6, dilations) + self.resblock3 = ResStack(kernels[2], channel, 12, dilations) + + def forward(self, x): + x1 = self.resblock1(x) + x2 = self.resblock2(x) + x3 = self.resblock3(x) + return x1 + x2 + x3 + + def remove_weight_norm(self): + self.resblock1.remove_weight_norm() + 
self.resblock2.remove_weight_norm() + self.resblock3.remove_weight_norm() diff --git a/TTS/vocoder/layers/losses.py b/TTS/vocoder/layers/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..848e292b8390f054366f1ea9a4f858a0e55cf50c --- /dev/null +++ b/TTS/vocoder/layers/losses.py @@ -0,0 +1,368 @@ +from typing import Dict, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from TTS.utils.audio import TorchSTFT +from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss + +################################# +# GENERATOR LOSSES +################################# + + +class STFTLoss(nn.Module): + """STFT loss. Input generate and real waveforms are converted + to spectrograms compared with L1 and Spectral convergence losses. + It is from ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf""" + + def __init__(self, n_fft, hop_length, win_length): + super().__init__() + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.stft = TorchSTFT(n_fft, hop_length, win_length) + + def forward(self, y_hat, y): + y_hat_M = self.stft(y_hat) + y_M = self.stft(y) + # magnitude loss + loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M)) + # spectral convergence loss + loss_sc = torch.norm(y_M - y_hat_M, p="fro") / torch.norm(y_M, p="fro") + return loss_mag, loss_sc + + +class MultiScaleSTFTLoss(torch.nn.Module): + """Multi-scale STFT loss. Input generate and real waveforms are converted + to spectrograms compared with L1 and Spectral convergence losses. + It is from ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf""" + + def __init__(self, n_ffts=(1024, 2048, 512), hop_lengths=(120, 240, 50), win_lengths=(600, 1200, 240)): + super().__init__() + self.loss_funcs = torch.nn.ModuleList() + for n_fft, hop_length, win_length in zip(n_ffts, hop_lengths, win_lengths): + self.loss_funcs.append(STFTLoss(n_fft, hop_length, win_length)) + + def forward(self, y_hat, y): + N = len(self.loss_funcs) + loss_sc = 0 + loss_mag = 0 + for f in self.loss_funcs: + lm, lsc = f(y_hat, y) + loss_mag += lm + loss_sc += lsc + loss_sc /= N + loss_mag /= N + return loss_mag, loss_sc + + +class L1SpecLoss(nn.Module): + """L1 Loss over Spectrograms as described in HiFiGAN paper https://arxiv.org/pdf/2010.05646.pdf""" + + def __init__( + self, sample_rate, n_fft, hop_length, win_length, mel_fmin=None, mel_fmax=None, n_mels=None, use_mel=True + ): + super().__init__() + self.use_mel = use_mel + self.stft = TorchSTFT( + n_fft, + hop_length, + win_length, + sample_rate=sample_rate, + mel_fmin=mel_fmin, + mel_fmax=mel_fmax, + n_mels=n_mels, + use_mel=use_mel, + ) + + def forward(self, y_hat, y): + y_hat_M = self.stft(y_hat) + y_M = self.stft(y) + # magnitude loss + loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M)) + return loss_mag + + +class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss): + """Multiscale STFT loss for multi band model outputs. 
+ From MultiBand-MelGAN paper https://arxiv.org/abs/2005.05106""" + + # pylint: disable=no-self-use + def forward(self, y_hat, y): + y_hat = y_hat.view(-1, 1, y_hat.shape[2]) + y = y.view(-1, 1, y.shape[2]) + return super().forward(y_hat.squeeze(1), y.squeeze(1)) + + +class MSEGLoss(nn.Module): + """Mean Squared Generator Loss""" + + # pylint: disable=no-self-use + def forward(self, score_real): + loss_fake = F.mse_loss(score_real, score_real.new_ones(score_real.shape)) + return loss_fake + + +class HingeGLoss(nn.Module): + """Hinge Discriminator Loss""" + + # pylint: disable=no-self-use + def forward(self, score_real): + # TODO: this might be wrong + loss_fake = torch.mean(F.relu(1.0 - score_real)) + return loss_fake + + +################################## +# DISCRIMINATOR LOSSES +################################## + + +class MSEDLoss(nn.Module): + """Mean Squared Discriminator Loss""" + + def __init__( + self, + ): + super().__init__() + self.loss_func = nn.MSELoss() + + # pylint: disable=no-self-use + def forward(self, score_fake, score_real): + loss_real = self.loss_func(score_real, score_real.new_ones(score_real.shape)) + loss_fake = self.loss_func(score_fake, score_fake.new_zeros(score_fake.shape)) + loss_d = loss_real + loss_fake + return loss_d, loss_real, loss_fake + + +class HingeDLoss(nn.Module): + """Hinge Discriminator Loss""" + + # pylint: disable=no-self-use + def forward(self, score_fake, score_real): + loss_real = torch.mean(F.relu(1.0 - score_real)) + loss_fake = torch.mean(F.relu(1.0 + score_fake)) + loss_d = loss_real + loss_fake + return loss_d, loss_real, loss_fake + + +class MelganFeatureLoss(nn.Module): + def __init__( + self, + ): + super().__init__() + self.loss_func = nn.L1Loss() + + # pylint: disable=no-self-use + def forward(self, fake_feats, real_feats): + loss_feats = 0 + num_feats = 0 + for idx, _ in enumerate(fake_feats): + for fake_feat, real_feat in zip(fake_feats[idx], real_feats[idx]): + loss_feats += self.loss_func(fake_feat, real_feat) + num_feats += 1 + loss_feats = loss_feats / num_feats + return loss_feats + + +##################################### +# LOSS WRAPPERS +##################################### + + +def _apply_G_adv_loss(scores_fake, loss_func): + """Compute G adversarial loss function + and normalize values""" + adv_loss = 0 + if isinstance(scores_fake, list): + for score_fake in scores_fake: + fake_loss = loss_func(score_fake) + adv_loss += fake_loss + adv_loss /= len(scores_fake) + else: + fake_loss = loss_func(scores_fake) + adv_loss = fake_loss + return adv_loss + + +def _apply_D_loss(scores_fake, scores_real, loss_func): + """Compute D loss func and normalize loss values""" + loss = 0 + real_loss = 0 + fake_loss = 0 + if isinstance(scores_fake, list): + # multi-scale loss + for score_fake, score_real in zip(scores_fake, scores_real): + total_loss, real_loss, fake_loss = loss_func(score_fake=score_fake, score_real=score_real) + loss += total_loss + real_loss += real_loss + fake_loss += fake_loss + # normalize loss values with number of scales (discriminators) + loss /= len(scores_fake) + real_loss /= len(scores_real) + fake_loss /= len(scores_fake) + else: + # single scale loss + total_loss, real_loss, fake_loss = loss_func(scores_fake, scores_real) + loss = total_loss + return loss, real_loss, fake_loss + + +################################## +# MODEL LOSSES +################################## + + +class GeneratorLoss(nn.Module): + """Generator Loss Wrapper. 
Based on model configuration it sets a right set of loss functions and computes + losses. It allows to experiment with different combinations of loss functions with different models by just + changing configurations. + + Args: + C (AttrDict): model configuration. + """ + + def __init__(self, C): + super().__init__() + assert not ( + C.use_mse_gan_loss and C.use_hinge_gan_loss + ), " [!] Cannot use HingeGANLoss and MSEGANLoss together." + + self.use_stft_loss = C.use_stft_loss if "use_stft_loss" in C else False + self.use_subband_stft_loss = C.use_subband_stft_loss if "use_subband_stft_loss" in C else False + self.use_mse_gan_loss = C.use_mse_gan_loss if "use_mse_gan_loss" in C else False + self.use_hinge_gan_loss = C.use_hinge_gan_loss if "use_hinge_gan_loss" in C else False + self.use_feat_match_loss = C.use_feat_match_loss if "use_feat_match_loss" in C else False + self.use_l1_spec_loss = C.use_l1_spec_loss if "use_l1_spec_loss" in C else False + + self.stft_loss_weight = C.stft_loss_weight if "stft_loss_weight" in C else 0.0 + self.subband_stft_loss_weight = C.subband_stft_loss_weight if "subband_stft_loss_weight" in C else 0.0 + self.mse_gan_loss_weight = C.mse_G_loss_weight if "mse_G_loss_weight" in C else 0.0 + self.hinge_gan_loss_weight = C.hinge_G_loss_weight if "hinde_G_loss_weight" in C else 0.0 + self.feat_match_loss_weight = C.feat_match_loss_weight if "feat_match_loss_weight" in C else 0.0 + self.l1_spec_loss_weight = C.l1_spec_loss_weight if "l1_spec_loss_weight" in C else 0.0 + + if C.use_stft_loss: + self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params) + if C.use_subband_stft_loss: + self.subband_stft_loss = MultiScaleSubbandSTFTLoss(**C.subband_stft_loss_params) + if C.use_mse_gan_loss: + self.mse_loss = MSEGLoss() + if C.use_hinge_gan_loss: + self.hinge_loss = HingeGLoss() + if C.use_feat_match_loss: + self.feat_match_loss = MelganFeatureLoss() + if C.use_l1_spec_loss: + assert C.audio["sample_rate"] == C.l1_spec_loss_params["sample_rate"] + self.l1_spec_loss = L1SpecLoss(**C.l1_spec_loss_params) + + def forward( + self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None + ): + gen_loss = 0 + adv_loss = 0 + return_dict = {} + + # STFT Loss + if self.use_stft_loss: + stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat[:, :, : y.size(2)].squeeze(1), y.squeeze(1)) + return_dict["G_stft_loss_mg"] = stft_loss_mg + return_dict["G_stft_loss_sc"] = stft_loss_sc + gen_loss = gen_loss + self.stft_loss_weight * (stft_loss_mg + stft_loss_sc) + + # L1 Spec loss + if self.use_l1_spec_loss: + l1_spec_loss = self.l1_spec_loss(y_hat, y) + return_dict["G_l1_spec_loss"] = l1_spec_loss + gen_loss = gen_loss + self.l1_spec_loss_weight * l1_spec_loss + + # subband STFT Loss + if self.use_subband_stft_loss: + subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub) + return_dict["G_subband_stft_loss_mg"] = subband_stft_loss_mg + return_dict["G_subband_stft_loss_sc"] = subband_stft_loss_sc + gen_loss = gen_loss + self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc) + + # multiscale MSE adversarial loss + if self.use_mse_gan_loss and scores_fake is not None: + mse_fake_loss = _apply_G_adv_loss(scores_fake, self.mse_loss) + return_dict["G_mse_fake_loss"] = mse_fake_loss + adv_loss = adv_loss + self.mse_gan_loss_weight * mse_fake_loss + + # multiscale Hinge adversarial loss + if self.use_hinge_gan_loss and not scores_fake is not None: + hinge_fake_loss = _apply_G_adv_loss(scores_fake, 
self.hinge_loss) + return_dict["G_hinge_fake_loss"] = hinge_fake_loss + adv_loss = adv_loss + self.hinge_gan_loss_weight * hinge_fake_loss + + # Feature Matching Loss + if self.use_feat_match_loss and not feats_fake is None: + feat_match_loss = self.feat_match_loss(feats_fake, feats_real) + return_dict["G_feat_match_loss"] = feat_match_loss + adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss + return_dict["loss"] = gen_loss + adv_loss + return_dict["G_gen_loss"] = gen_loss + return_dict["G_adv_loss"] = adv_loss + return return_dict + + +class DiscriminatorLoss(nn.Module): + """Like ```GeneratorLoss```""" + + def __init__(self, C): + super().__init__() + assert not ( + C.use_mse_gan_loss and C.use_hinge_gan_loss + ), " [!] Cannot use HingeGANLoss and MSEGANLoss together." + + self.use_mse_gan_loss = C.use_mse_gan_loss + self.use_hinge_gan_loss = C.use_hinge_gan_loss + + if C.use_mse_gan_loss: + self.mse_loss = MSEDLoss() + if C.use_hinge_gan_loss: + self.hinge_loss = HingeDLoss() + + def forward(self, scores_fake, scores_real): + loss = 0 + return_dict = {} + + if self.use_mse_gan_loss: + mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss( + scores_fake=scores_fake, scores_real=scores_real, loss_func=self.mse_loss + ) + return_dict["D_mse_gan_loss"] = mse_D_loss + return_dict["D_mse_gan_real_loss"] = mse_D_real_loss + return_dict["D_mse_gan_fake_loss"] = mse_D_fake_loss + loss += mse_D_loss + + if self.use_hinge_gan_loss: + hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss( + scores_fake=scores_fake, scores_real=scores_real, loss_func=self.hinge_loss + ) + return_dict["D_hinge_gan_loss"] = hinge_D_loss + return_dict["D_hinge_gan_real_loss"] = hinge_D_real_loss + return_dict["D_hinge_gan_fake_loss"] = hinge_D_fake_loss + loss += hinge_D_loss + + return_dict["loss"] = loss + return return_dict + + +class WaveRNNLoss(nn.Module): + def __init__(self, wave_rnn_mode: Union[str, int]): + super().__init__() + if wave_rnn_mode == "mold": + self.loss_func = discretized_mix_logistic_loss + elif wave_rnn_mode == "gauss": + self.loss_func = gaussian_loss + elif isinstance(wave_rnn_mode, int): + self.loss_func = torch.nn.CrossEntropyLoss() + else: + raise ValueError(" [!] 
Unknown mode for Wavernn.") + + def forward(self, y_hat, y) -> Dict: + loss = self.loss_func(y_hat, y) + return {"loss": loss} diff --git a/TTS/vocoder/layers/lvc_block.py b/TTS/vocoder/layers/lvc_block.py new file mode 100644 index 0000000000000000000000000000000000000000..8913a1132ec769fd304077412289c01c0d1cb17b --- /dev/null +++ b/TTS/vocoder/layers/lvc_block.py @@ -0,0 +1,198 @@ +import torch +import torch.nn.functional as F + + +class KernelPredictor(torch.nn.Module): + """Kernel predictor for the location-variable convolutions""" + + def __init__( # pylint: disable=dangerous-default-value + self, + cond_channels, + conv_in_channels, + conv_out_channels, + conv_layers, + conv_kernel_size=3, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + kpnet_nonlinear_activation="LeakyReLU", + kpnet_nonlinear_activation_params={"negative_slope": 0.1}, + ): + """ + Args: + cond_channels (int): number of channel for the conditioning sequence, + conv_in_channels (int): number of channel for the input sequence, + conv_out_channels (int): number of channel for the output sequence, + conv_layers (int): + kpnet_ + """ + super().__init__() + + self.conv_in_channels = conv_in_channels + self.conv_out_channels = conv_out_channels + self.conv_kernel_size = conv_kernel_size + self.conv_layers = conv_layers + + l_w = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers + l_b = conv_out_channels * conv_layers + + padding = (kpnet_conv_size - 1) // 2 + self.input_conv = torch.nn.Sequential( + torch.nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=(5 - 1) // 2, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + + self.residual_conv = torch.nn.Sequential( + torch.nn.Dropout(kpnet_dropout), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Dropout(kpnet_dropout), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Dropout(kpnet_dropout), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + + self.kernel_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_w, kpnet_conv_size, padding=padding, bias=True) + self.bias_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_b, kpnet_conv_size, padding=padding, bias=True) + + def forward(self, c): + """ + Args: + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + Returns: + """ + batch, _, cond_length = c.shape + + c = self.input_conv(c) + c = c + self.residual_conv(c) + k = self.kernel_conv(c) + b = self.bias_conv(c) + + kernels 
= k.contiguous().view( + batch, self.conv_layers, self.conv_in_channels, self.conv_out_channels, self.conv_kernel_size, cond_length + ) + bias = b.contiguous().view(batch, self.conv_layers, self.conv_out_channels, cond_length) + return kernels, bias + + +class LVCBlock(torch.nn.Module): + """the location-variable convolutions""" + + def __init__( + self, + in_channels, + cond_channels, + upsample_ratio, + conv_layers=4, + conv_kernel_size=3, + cond_hop_length=256, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + ): + super().__init__() + + self.cond_hop_length = cond_hop_length + self.conv_layers = conv_layers + self.conv_kernel_size = conv_kernel_size + self.convs = torch.nn.ModuleList() + + self.upsample = torch.nn.ConvTranspose1d( + in_channels, + in_channels, + kernel_size=upsample_ratio * 2, + stride=upsample_ratio, + padding=upsample_ratio // 2 + upsample_ratio % 2, + output_padding=upsample_ratio % 2, + ) + + self.kernel_predictor = KernelPredictor( + cond_channels=cond_channels, + conv_in_channels=in_channels, + conv_out_channels=2 * in_channels, + conv_layers=conv_layers, + conv_kernel_size=conv_kernel_size, + kpnet_hidden_channels=kpnet_hidden_channels, + kpnet_conv_size=kpnet_conv_size, + kpnet_dropout=kpnet_dropout, + ) + + for i in range(conv_layers): + padding = (3**i) * int((conv_kernel_size - 1) / 2) + conv = torch.nn.Conv1d( + in_channels, in_channels, kernel_size=conv_kernel_size, padding=padding, dilation=3**i + ) + + self.convs.append(conv) + + def forward(self, x, c): + """forward propagation of the location-variable convolutions. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length) + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + + Returns: + Tensor: the output sequence (batch, in_channels, in_length) + """ + in_channels = x.shape[1] + kernels, bias = self.kernel_predictor(c) + + x = F.leaky_relu(x, 0.2) + x = self.upsample(x) + + for i in range(self.conv_layers): + y = F.leaky_relu(x, 0.2) + y = self.convs[i](y) + y = F.leaky_relu(y, 0.2) + + k = kernels[:, i, :, :, :, :] + b = bias[:, i, :, :] + y = self.location_variable_convolution(y, k, b, 1, self.cond_hop_length) + x = x + torch.sigmoid(y[:, :in_channels, :]) * torch.tanh(y[:, in_channels:, :]) + return x + + @staticmethod + def location_variable_convolution(x, kernel, bias, dilation, hop_size): + """perform location-variable convolution operation on the input sequence (x) using the local convolution kernl. + Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length). + kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length) + bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length) + dilation (int): the dilation of convolution. + hop_size (int): the hop_size of the conditioning sequence. + Returns: + (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length). 
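# A minimal shape check (illustrative only, hypothetical sizes): the padded input is
# unfolded into one window per conditioning frame, an einsum applies the per-frame
# kernels, and the result is flattened back to (batch, out_channels, time).

import torch
from TTS.vocoder.layers.lvc_block import LVCBlock

x = torch.randn(2, 8, 40)         # (batch, in_channels, kernel_length * hop_size)
k = torch.randn(2, 8, 16, 3, 10)  # (batch, in_channels, out_channels, kernel_size, kernel_length)
b = torch.randn(2, 16, 10)        # (batch, out_channels, kernel_length)
y = LVCBlock.location_variable_convolution(x, k, b, dilation=1, hop_size=4)
print(y.shape)                    # torch.Size([2, 16, 40])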
+ """ + batch, _, in_length = x.shape + batch, _, out_channels, kernel_size, kernel_length = kernel.shape + + assert in_length == ( + kernel_length * hop_size + ), f"length of (x, kernel) is not matched, {in_length} vs {kernel_length * hop_size}" + + padding = dilation * int((kernel_size - 1) / 2) + x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding) + x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding) + + if hop_size < dilation: + x = F.pad(x, (0, dilation), "constant", 0) + x = x.unfold( + 3, dilation, dilation + ) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation) + x = x[:, :, :, :, :hop_size] + x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation) + x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size) + + o = torch.einsum("bildsk,biokl->bolsd", x, kernel) + o = o + bias.unsqueeze(-1).unsqueeze(-1) + o = o.contiguous().view(batch, out_channels, -1) + return o diff --git a/TTS/vocoder/layers/melgan.py b/TTS/vocoder/layers/melgan.py new file mode 100644 index 0000000000000000000000000000000000000000..4bb328e98354dc0683b3c5b4f4160dd54d92fabd --- /dev/null +++ b/TTS/vocoder/layers/melgan.py @@ -0,0 +1,42 @@ +from torch import nn +from torch.nn.utils import weight_norm + + +class ResidualStack(nn.Module): + def __init__(self, channels, num_res_blocks, kernel_size): + super().__init__() + + assert (kernel_size - 1) % 2 == 0, " [!] kernel_size has to be odd." + base_padding = (kernel_size - 1) // 2 + + self.blocks = nn.ModuleList() + for idx in range(num_res_blocks): + layer_kernel_size = kernel_size + layer_dilation = layer_kernel_size**idx + layer_padding = base_padding * layer_dilation + self.blocks += [ + nn.Sequential( + nn.LeakyReLU(0.2), + nn.ReflectionPad1d(layer_padding), + weight_norm( + nn.Conv1d(channels, channels, kernel_size=kernel_size, dilation=layer_dilation, bias=True) + ), + nn.LeakyReLU(0.2), + weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)), + ) + ] + + self.shortcuts = nn.ModuleList( + [weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)) for i in range(num_res_blocks)] + ) + + def forward(self, x): + for block, shortcut in zip(self.blocks, self.shortcuts): + x = shortcut(x) + block(x) + return x + + def remove_weight_norm(self): + for block, shortcut in zip(self.blocks, self.shortcuts): + nn.utils.remove_weight_norm(block[2]) + nn.utils.remove_weight_norm(block[4]) + nn.utils.remove_weight_norm(shortcut) diff --git a/TTS/vocoder/layers/parallel_wavegan.py b/TTS/vocoder/layers/parallel_wavegan.py new file mode 100644 index 0000000000000000000000000000000000000000..51142e5eceb20564585635a9040a24bc8eb3b6e3 --- /dev/null +++ b/TTS/vocoder/layers/parallel_wavegan.py @@ -0,0 +1,77 @@ +import torch +from torch.nn import functional as F + + +class ResidualBlock(torch.nn.Module): + """Residual block module in WaveNet.""" + + def __init__( + self, + kernel_size=3, + res_channels=64, + gate_channels=128, + skip_channels=64, + aux_channels=80, + dropout=0.0, + dilation=1, + bias=True, + use_causal_conv=False, + ): + super().__init__() + self.dropout = dropout + # no future time stamps available + if use_causal_conv: + padding = (kernel_size - 1) * dilation + else: + assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." 
+ padding = (kernel_size - 1) // 2 * dilation + self.use_causal_conv = use_causal_conv + + # dilation conv + self.conv = torch.nn.Conv1d( + res_channels, gate_channels, kernel_size, padding=padding, dilation=dilation, bias=bias + ) + + # local conditioning + if aux_channels > 0: + self.conv1x1_aux = torch.nn.Conv1d(aux_channels, gate_channels, 1, bias=False) + else: + self.conv1x1_aux = None + + # conv output is split into two groups + gate_out_channels = gate_channels // 2 + self.conv1x1_out = torch.nn.Conv1d(gate_out_channels, res_channels, 1, bias=bias) + self.conv1x1_skip = torch.nn.Conv1d(gate_out_channels, skip_channels, 1, bias=bias) + + def forward(self, x, c): + """ + x: B x D_res x T + c: B x D_aux x T + """ + residual = x + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.conv(x) + + # remove future time steps if use_causal_conv conv + x = x[:, :, : residual.size(-1)] if self.use_causal_conv else x + + # split into two part for gated activation + splitdim = 1 + xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim) + + # local conditioning + if c is not None: + assert self.conv1x1_aux is not None + c = self.conv1x1_aux(c) + ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim) + xa, xb = xa + ca, xb + cb + + x = torch.tanh(xa) * torch.sigmoid(xb) + + # for skip connection + s = self.conv1x1_skip(x) + + # for residual connection + x = (self.conv1x1_out(x) + residual) * (0.5**2) + + return x, s diff --git a/TTS/vocoder/layers/pqmf.py b/TTS/vocoder/layers/pqmf.py new file mode 100644 index 0000000000000000000000000000000000000000..6253efbbefc32222464a97bee99707d46bcdcf8b --- /dev/null +++ b/TTS/vocoder/layers/pqmf.py @@ -0,0 +1,53 @@ +import numpy as np +import torch +import torch.nn.functional as F +from scipy import signal as sig + + +# adapted from +# https://github.com/kan-bayashi/ParallelWaveGAN/tree/master/parallel_wavegan +class PQMF(torch.nn.Module): + def __init__(self, N=4, taps=62, cutoff=0.15, beta=9.0): + super().__init__() + + self.N = N + self.taps = taps + self.cutoff = cutoff + self.beta = beta + + QMF = sig.firwin(taps + 1, cutoff, window=("kaiser", beta)) + H = np.zeros((N, len(QMF))) + G = np.zeros((N, len(QMF))) + for k in range(N): + constant_factor = ( + (2 * k + 1) * (np.pi / (2 * N)) * (np.arange(taps + 1) - ((taps - 1) / 2)) + ) # TODO: (taps - 1) -> taps + phase = (-1) ** k * np.pi / 4 + H[k] = 2 * QMF * np.cos(constant_factor + phase) + + G[k] = 2 * QMF * np.cos(constant_factor - phase) + + H = torch.from_numpy(H[:, None, :]).float() + G = torch.from_numpy(G[None, :, :]).float() + + self.register_buffer("H", H) + self.register_buffer("G", G) + + updown_filter = torch.zeros((N, N, N)).float() + for k in range(N): + updown_filter[k, k, 0] = 1.0 + self.register_buffer("updown_filter", updown_filter) + self.N = N + + self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0) + + def forward(self, x): + return self.analysis(x) + + def analysis(self, x): + return F.conv1d(x, self.H, padding=self.taps // 2, stride=self.N) + + def synthesis(self, x): + x = F.conv_transpose1d(x, self.updown_filter * self.N, stride=self.N) + x = F.conv1d(x, self.G, padding=self.taps // 2) + return x diff --git a/TTS/vocoder/layers/qmf.dat b/TTS/vocoder/layers/qmf.dat new file mode 100644 index 0000000000000000000000000000000000000000..17eab1379de991c36897c2ce701802ef76849c0d --- /dev/null +++ b/TTS/vocoder/layers/qmf.dat @@ -0,0 +1,640 @@ + 0.0000000e+000 + -5.5252865e-004 + -5.6176926e-004 + -4.9475181e-004 + -4.8752280e-004 + -4.8937912e-004 + -5.0407143e-004 + 
-5.2265643e-004 + -5.4665656e-004 + -5.6778026e-004 + -5.8709305e-004 + -6.1327474e-004 + -6.3124935e-004 + -6.5403334e-004 + -6.7776908e-004 + -6.9416146e-004 + -7.1577365e-004 + -7.2550431e-004 + -7.4409419e-004 + -7.4905981e-004 + -7.6813719e-004 + -7.7248486e-004 + -7.8343323e-004 + -7.7798695e-004 + -7.8036647e-004 + -7.8014496e-004 + -7.7579773e-004 + -7.6307936e-004 + -7.5300014e-004 + -7.3193572e-004 + -7.2153920e-004 + -6.9179375e-004 + -6.6504151e-004 + -6.3415949e-004 + -5.9461189e-004 + -5.5645764e-004 + -5.1455722e-004 + -4.6063255e-004 + -4.0951215e-004 + -3.5011759e-004 + -2.8969812e-004 + -2.0983373e-004 + -1.4463809e-004 + -6.1733441e-005 + 1.3494974e-005 + 1.0943831e-004 + 2.0430171e-004 + 2.9495311e-004 + 4.0265402e-004 + 5.1073885e-004 + 6.2393761e-004 + 7.4580259e-004 + 8.6084433e-004 + 9.8859883e-004 + 1.1250155e-003 + 1.2577885e-003 + 1.3902495e-003 + 1.5443220e-003 + 1.6868083e-003 + 1.8348265e-003 + 1.9841141e-003 + 2.1461584e-003 + 2.3017255e-003 + 2.4625617e-003 + 2.6201759e-003 + 2.7870464e-003 + 2.9469448e-003 + 3.1125421e-003 + 3.2739613e-003 + 3.4418874e-003 + 3.6008268e-003 + 3.7603923e-003 + 3.9207432e-003 + 4.0819753e-003 + 4.2264269e-003 + 4.3730720e-003 + 4.5209853e-003 + 4.6606461e-003 + 4.7932561e-003 + 4.9137604e-003 + 5.0393023e-003 + 5.1407354e-003 + 5.2461166e-003 + 5.3471681e-003 + 5.4196776e-003 + 5.4876040e-003 + 5.5475715e-003 + 5.5938023e-003 + 5.6220643e-003 + 5.6455197e-003 + 5.6389200e-003 + 5.6266114e-003 + 5.5917129e-003 + 5.5404364e-003 + 5.4753783e-003 + 5.3838976e-003 + 5.2715759e-003 + 5.1382275e-003 + 4.9839688e-003 + 4.8109469e-003 + 4.6039530e-003 + 4.3801862e-003 + 4.1251642e-003 + 3.8456408e-003 + 3.5401247e-003 + 3.2091886e-003 + 2.8446758e-003 + 2.4508540e-003 + 2.0274176e-003 + 1.5784683e-003 + 1.0902329e-003 + 5.8322642e-004 + 2.7604519e-005 + -5.4642809e-004 + -1.1568136e-003 + -1.8039473e-003 + -2.4826724e-003 + -3.1933778e-003 + -3.9401124e-003 + -4.7222596e-003 + -5.5337211e-003 + -6.3792293e-003 + -7.2615817e-003 + -8.1798233e-003 + -9.1325330e-003 + -1.0115022e-002 + -1.1131555e-002 + -1.2185000e-002 + -1.3271822e-002 + -1.4390467e-002 + -1.5540555e-002 + -1.6732471e-002 + -1.7943338e-002 + -1.9187243e-002 + -2.0453179e-002 + -2.1746755e-002 + -2.3068017e-002 + -2.4416099e-002 + -2.5787585e-002 + -2.7185943e-002 + -2.8607217e-002 + -3.0050266e-002 + -3.1501761e-002 + -3.2975408e-002 + -3.4462095e-002 + -3.5969756e-002 + -3.7481285e-002 + -3.9005368e-002 + -4.0534917e-002 + -4.2064909e-002 + -4.3609754e-002 + -4.5148841e-002 + -4.6684303e-002 + -4.8216572e-002 + -4.9738576e-002 + -5.1255616e-002 + -5.2763075e-002 + -5.4245277e-002 + -5.5717365e-002 + -5.7161645e-002 + -5.8591568e-002 + -5.9983748e-002 + -6.1345517e-002 + -6.2685781e-002 + -6.3971590e-002 + -6.5224711e-002 + -6.6436751e-002 + -6.7607599e-002 + -6.8704383e-002 + -6.9763024e-002 + -7.0762871e-002 + -7.1700267e-002 + -7.2568258e-002 + -7.3362026e-002 + -7.4100364e-002 + -7.4745256e-002 + -7.5313734e-002 + -7.5800836e-002 + -7.6199248e-002 + -7.6499217e-002 + -7.6709349e-002 + -7.6817398e-002 + -7.6823001e-002 + -7.6720492e-002 + -7.6505072e-002 + -7.6174832e-002 + -7.5730576e-002 + -7.5157626e-002 + -7.4466439e-002 + -7.3640601e-002 + -7.2677464e-002 + -7.1582636e-002 + -7.0353307e-002 + -6.8966401e-002 + -6.7452502e-002 + -6.5769067e-002 + -6.3944481e-002 + -6.1960278e-002 + -5.9816657e-002 + -5.7515269e-002 + -5.5046003e-002 + -5.2409382e-002 + -4.9597868e-002 + -4.6630331e-002 + -4.3476878e-002 + -4.0145828e-002 + -3.6641812e-002 + -3.2958393e-002 + 
-2.9082401e-002 + -2.5030756e-002 + -2.0799707e-002 + -1.6370126e-002 + -1.1762383e-002 + -6.9636862e-003 + -1.9765601e-003 + 3.2086897e-003 + 8.5711749e-003 + 1.4128883e-002 + 1.9883413e-002 + 2.5822729e-002 + 3.1953127e-002 + 3.8277657e-002 + 4.4780682e-002 + 5.1480418e-002 + 5.8370533e-002 + 6.5440985e-002 + 7.2694330e-002 + 8.0137293e-002 + 8.7754754e-002 + 9.5553335e-002 + 1.0353295e-001 + 1.1168269e-001 + 1.2000780e-001 + 1.2850029e-001 + 1.3715518e-001 + 1.4597665e-001 + 1.5496071e-001 + 1.6409589e-001 + 1.7338082e-001 + 1.8281725e-001 + 1.9239667e-001 + 2.0212502e-001 + 2.1197359e-001 + 2.2196527e-001 + 2.3206909e-001 + 2.4230169e-001 + 2.5264803e-001 + 2.6310533e-001 + 2.7366340e-001 + 2.8432142e-001 + 2.9507167e-001 + 3.0590986e-001 + 3.1682789e-001 + 3.2781137e-001 + 3.3887227e-001 + 3.4999141e-001 + 3.6115899e-001 + 3.7237955e-001 + 3.8363500e-001 + 3.9492118e-001 + 4.0623177e-001 + 4.1756969e-001 + 4.2891199e-001 + 4.4025538e-001 + 4.5159965e-001 + 4.6293081e-001 + 4.7424532e-001 + 4.8552531e-001 + 4.9677083e-001 + 5.0798175e-001 + 5.1912350e-001 + 5.3022409e-001 + 5.4125534e-001 + 5.5220513e-001 + 5.6307891e-001 + 5.7385241e-001 + 5.8454032e-001 + 5.9511231e-001 + 6.0557835e-001 + 6.1591099e-001 + 6.2612427e-001 + 6.3619801e-001 + 6.4612697e-001 + 6.5590163e-001 + 6.6551399e-001 + 6.7496632e-001 + 6.8423533e-001 + 6.9332824e-001 + 7.0223887e-001 + 7.1094104e-001 + 7.1944626e-001 + 7.2774489e-001 + 7.3582118e-001 + 7.4368279e-001 + 7.5131375e-001 + 7.5870808e-001 + 7.6586749e-001 + 7.7277809e-001 + 7.7942875e-001 + 7.8583531e-001 + 7.9197358e-001 + 7.9784664e-001 + 8.0344858e-001 + 8.0876950e-001 + 8.1381913e-001 + 8.1857760e-001 + 8.2304199e-001 + 8.2722753e-001 + 8.3110385e-001 + 8.3469374e-001 + 8.3797173e-001 + 8.4095414e-001 + 8.4362383e-001 + 8.4598185e-001 + 8.4803158e-001 + 8.4978052e-001 + 8.5119715e-001 + 8.5230470e-001 + 8.5310209e-001 + 8.5357206e-001 + 8.5373856e-001 + 8.5357206e-001 + 8.5310209e-001 + 8.5230470e-001 + 8.5119715e-001 + 8.4978052e-001 + 8.4803158e-001 + 8.4598185e-001 + 8.4362383e-001 + 8.4095414e-001 + 8.3797173e-001 + 8.3469374e-001 + 8.3110385e-001 + 8.2722753e-001 + 8.2304199e-001 + 8.1857760e-001 + 8.1381913e-001 + 8.0876950e-001 + 8.0344858e-001 + 7.9784664e-001 + 7.9197358e-001 + 7.8583531e-001 + 7.7942875e-001 + 7.7277809e-001 + 7.6586749e-001 + 7.5870808e-001 + 7.5131375e-001 + 7.4368279e-001 + 7.3582118e-001 + 7.2774489e-001 + 7.1944626e-001 + 7.1094104e-001 + 7.0223887e-001 + 6.9332824e-001 + 6.8423533e-001 + 6.7496632e-001 + 6.6551399e-001 + 6.5590163e-001 + 6.4612697e-001 + 6.3619801e-001 + 6.2612427e-001 + 6.1591099e-001 + 6.0557835e-001 + 5.9511231e-001 + 5.8454032e-001 + 5.7385241e-001 + 5.6307891e-001 + 5.5220513e-001 + 5.4125534e-001 + 5.3022409e-001 + 5.1912350e-001 + 5.0798175e-001 + 4.9677083e-001 + 4.8552531e-001 + 4.7424532e-001 + 4.6293081e-001 + 4.5159965e-001 + 4.4025538e-001 + 4.2891199e-001 + 4.1756969e-001 + 4.0623177e-001 + 3.9492118e-001 + 3.8363500e-001 + 3.7237955e-001 + 3.6115899e-001 + 3.4999141e-001 + 3.3887227e-001 + 3.2781137e-001 + 3.1682789e-001 + 3.0590986e-001 + 2.9507167e-001 + 2.8432142e-001 + 2.7366340e-001 + 2.6310533e-001 + 2.5264803e-001 + 2.4230169e-001 + 2.3206909e-001 + 2.2196527e-001 + 2.1197359e-001 + 2.0212502e-001 + 1.9239667e-001 + 1.8281725e-001 + 1.7338082e-001 + 1.6409589e-001 + 1.5496071e-001 + 1.4597665e-001 + 1.3715518e-001 + 1.2850029e-001 + 1.2000780e-001 + 1.1168269e-001 + 1.0353295e-001 + 9.5553335e-002 + 8.7754754e-002 + 8.0137293e-002 + 7.2694330e-002 + 6.5440985e-002 + 
5.8370533e-002 + 5.1480418e-002 + 4.4780682e-002 + 3.8277657e-002 + 3.1953127e-002 + 2.5822729e-002 + 1.9883413e-002 + 1.4128883e-002 + 8.5711749e-003 + 3.2086897e-003 + -1.9765601e-003 + -6.9636862e-003 + -1.1762383e-002 + -1.6370126e-002 + -2.0799707e-002 + -2.5030756e-002 + -2.9082401e-002 + -3.2958393e-002 + -3.6641812e-002 + -4.0145828e-002 + -4.3476878e-002 + -4.6630331e-002 + -4.9597868e-002 + -5.2409382e-002 + -5.5046003e-002 + -5.7515269e-002 + -5.9816657e-002 + -6.1960278e-002 + -6.3944481e-002 + -6.5769067e-002 + -6.7452502e-002 + -6.8966401e-002 + -7.0353307e-002 + -7.1582636e-002 + -7.2677464e-002 + -7.3640601e-002 + -7.4466439e-002 + -7.5157626e-002 + -7.5730576e-002 + -7.6174832e-002 + -7.6505072e-002 + -7.6720492e-002 + -7.6823001e-002 + -7.6817398e-002 + -7.6709349e-002 + -7.6499217e-002 + -7.6199248e-002 + -7.5800836e-002 + -7.5313734e-002 + -7.4745256e-002 + -7.4100364e-002 + -7.3362026e-002 + -7.2568258e-002 + -7.1700267e-002 + -7.0762871e-002 + -6.9763024e-002 + -6.8704383e-002 + -6.7607599e-002 + -6.6436751e-002 + -6.5224711e-002 + -6.3971590e-002 + -6.2685781e-002 + -6.1345517e-002 + -5.9983748e-002 + -5.8591568e-002 + -5.7161645e-002 + -5.5717365e-002 + -5.4245277e-002 + -5.2763075e-002 + -5.1255616e-002 + -4.9738576e-002 + -4.8216572e-002 + -4.6684303e-002 + -4.5148841e-002 + -4.3609754e-002 + -4.2064909e-002 + -4.0534917e-002 + -3.9005368e-002 + -3.7481285e-002 + -3.5969756e-002 + -3.4462095e-002 + -3.2975408e-002 + -3.1501761e-002 + -3.0050266e-002 + -2.8607217e-002 + -2.7185943e-002 + -2.5787585e-002 + -2.4416099e-002 + -2.3068017e-002 + -2.1746755e-002 + -2.0453179e-002 + -1.9187243e-002 + -1.7943338e-002 + -1.6732471e-002 + -1.5540555e-002 + -1.4390467e-002 + -1.3271822e-002 + -1.2185000e-002 + -1.1131555e-002 + -1.0115022e-002 + -9.1325330e-003 + -8.1798233e-003 + -7.2615817e-003 + -6.3792293e-003 + -5.5337211e-003 + -4.7222596e-003 + -3.9401124e-003 + -3.1933778e-003 + -2.4826724e-003 + -1.8039473e-003 + -1.1568136e-003 + -5.4642809e-004 + 2.7604519e-005 + 5.8322642e-004 + 1.0902329e-003 + 1.5784683e-003 + 2.0274176e-003 + 2.4508540e-003 + 2.8446758e-003 + 3.2091886e-003 + 3.5401247e-003 + 3.8456408e-003 + 4.1251642e-003 + 4.3801862e-003 + 4.6039530e-003 + 4.8109469e-003 + 4.9839688e-003 + 5.1382275e-003 + 5.2715759e-003 + 5.3838976e-003 + 5.4753783e-003 + 5.5404364e-003 + 5.5917129e-003 + 5.6266114e-003 + 5.6389200e-003 + 5.6455197e-003 + 5.6220643e-003 + 5.5938023e-003 + 5.5475715e-003 + 5.4876040e-003 + 5.4196776e-003 + 5.3471681e-003 + 5.2461166e-003 + 5.1407354e-003 + 5.0393023e-003 + 4.9137604e-003 + 4.7932561e-003 + 4.6606461e-003 + 4.5209853e-003 + 4.3730720e-003 + 4.2264269e-003 + 4.0819753e-003 + 3.9207432e-003 + 3.7603923e-003 + 3.6008268e-003 + 3.4418874e-003 + 3.2739613e-003 + 3.1125421e-003 + 2.9469448e-003 + 2.7870464e-003 + 2.6201759e-003 + 2.4625617e-003 + 2.3017255e-003 + 2.1461584e-003 + 1.9841141e-003 + 1.8348265e-003 + 1.6868083e-003 + 1.5443220e-003 + 1.3902495e-003 + 1.2577885e-003 + 1.1250155e-003 + 9.8859883e-004 + 8.6084433e-004 + 7.4580259e-004 + 6.2393761e-004 + 5.1073885e-004 + 4.0265402e-004 + 2.9495311e-004 + 2.0430171e-004 + 1.0943831e-004 + 1.3494974e-005 + -6.1733441e-005 + -1.4463809e-004 + -2.0983373e-004 + -2.8969812e-004 + -3.5011759e-004 + -4.0951215e-004 + -4.6063255e-004 + -5.1455722e-004 + -5.5645764e-004 + -5.9461189e-004 + -6.3415949e-004 + -6.6504151e-004 + -6.9179375e-004 + -7.2153920e-004 + -7.3193572e-004 + -7.5300014e-004 + -7.6307936e-004 + -7.7579773e-004 + -7.8014496e-004 + -7.8036647e-004 + -7.7798695e-004 
+ -7.8343323e-004 + -7.7248486e-004 + -7.6813719e-004 + -7.4905981e-004 + -7.4409419e-004 + -7.2550431e-004 + -7.1577365e-004 + -6.9416146e-004 + -6.7776908e-004 + -6.5403334e-004 + -6.3124935e-004 + -6.1327474e-004 + -5.8709305e-004 + -5.6778026e-004 + -5.4665656e-004 + -5.2265643e-004 + -5.0407143e-004 + -4.8937912e-004 + -4.8752280e-004 + -4.9475181e-004 + -5.6176926e-004 + -5.5252865e-004 diff --git a/TTS/vocoder/layers/upsample.py b/TTS/vocoder/layers/upsample.py new file mode 100644 index 0000000000000000000000000000000000000000..e169db00b2749493e1cec07ee51c93178dada118 --- /dev/null +++ b/TTS/vocoder/layers/upsample.py @@ -0,0 +1,102 @@ +import torch +from torch.nn import functional as F + + +class Stretch2d(torch.nn.Module): + def __init__(self, x_scale, y_scale, mode="nearest"): + super().__init__() + self.x_scale = x_scale + self.y_scale = y_scale + self.mode = mode + + def forward(self, x): + """ + x (Tensor): Input tensor (B, C, F, T). + Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale), + """ + return F.interpolate(x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode) + + +class UpsampleNetwork(torch.nn.Module): + # pylint: disable=dangerous-default-value + def __init__( + self, + upsample_factors, + nonlinear_activation=None, + nonlinear_activation_params={}, + interpolate_mode="nearest", + freq_axis_kernel_size=1, + use_causal_conv=False, + ): + super().__init__() + self.use_causal_conv = use_causal_conv + self.up_layers = torch.nn.ModuleList() + for scale in upsample_factors: + # interpolation layer + stretch = Stretch2d(scale, 1, interpolate_mode) + self.up_layers += [stretch] + + # conv layer + assert (freq_axis_kernel_size - 1) % 2 == 0, "Not support even number freq axis kernel size." + freq_axis_padding = (freq_axis_kernel_size - 1) // 2 + kernel_size = (freq_axis_kernel_size, scale * 2 + 1) + if use_causal_conv: + padding = (freq_axis_padding, scale * 2) + else: + padding = (freq_axis_padding, scale) + conv = torch.nn.Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False) + self.up_layers += [conv] + + # nonlinear + if nonlinear_activation is not None: + nonlinear = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params) + self.up_layers += [nonlinear] + + def forward(self, c): + """ + c : (B, C, T_in). 
+ Tensor: (B, C, T_upsample) + """ + c = c.unsqueeze(1) # (B, 1, C, T) + for f in self.up_layers: + c = f(c) + return c.squeeze(1) # (B, C, T') + + +class ConvUpsample(torch.nn.Module): + # pylint: disable=dangerous-default-value + def __init__( + self, + upsample_factors, + nonlinear_activation=None, + nonlinear_activation_params={}, + interpolate_mode="nearest", + freq_axis_kernel_size=1, + aux_channels=80, + aux_context_window=0, + use_causal_conv=False, + ): + super().__init__() + self.aux_context_window = aux_context_window + self.use_causal_conv = use_causal_conv and aux_context_window > 0 + # To capture wide-context information in conditional features + kernel_size = aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1 + # NOTE(kan-bayashi): Here do not use padding because the input is already padded + self.conv_in = torch.nn.Conv1d(aux_channels, aux_channels, kernel_size=kernel_size, bias=False) + self.upsample = UpsampleNetwork( + upsample_factors=upsample_factors, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, + interpolate_mode=interpolate_mode, + freq_axis_kernel_size=freq_axis_kernel_size, + use_causal_conv=use_causal_conv, + ) + + def forward(self, c): + """ + c : (B, C, T_in). + Tensor: (B, C, T_upsampled), + """ + c_ = self.conv_in(c) + c = c_[:, :, : -self.aux_context_window] if self.use_causal_conv else c_ + return self.upsample(c) diff --git a/TTS/vocoder/layers/wavegrad.py b/TTS/vocoder/layers/wavegrad.py new file mode 100644 index 0000000000000000000000000000000000000000..24b905f994b69075fc5e46249ce0c7719fe4b174 --- /dev/null +++ b/TTS/vocoder/layers/wavegrad.py @@ -0,0 +1,165 @@ +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.utils import weight_norm + + +class Conv1d(nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + nn.init.orthogonal_(self.weight) + nn.init.zeros_(self.bias) + + +class PositionalEncoding(nn.Module): + """Positional encoding with noise level conditioning""" + + def __init__(self, n_channels, max_len=10000): + super().__init__() + self.n_channels = n_channels + self.max_len = max_len + self.C = 5000 + self.pe = torch.zeros(0, 0) + + def forward(self, x, noise_level): + if x.shape[2] > self.pe.shape[1]: + self.init_pe_matrix(x.shape[1], x.shape[2], x) + return x + noise_level[..., None, None] + self.pe[:, : x.size(2)].repeat(x.shape[0], 1, 1) / self.C + + def init_pe_matrix(self, n_channels, max_len, x): + pe = torch.zeros(max_len, n_channels) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.pow(10000, torch.arange(0, n_channels, 2).float() / n_channels) + + pe[:, 0::2] = torch.sin(position / div_term) + pe[:, 1::2] = torch.cos(position / div_term) + self.pe = pe.transpose(0, 1).to(x) + + +class FiLM(nn.Module): + def __init__(self, input_size, output_size): + super().__init__() + self.encoding = PositionalEncoding(input_size) + self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1) + self.output_conv = nn.Conv1d(input_size, output_size * 2, 3, padding=1) + + nn.init.xavier_uniform_(self.input_conv.weight) + nn.init.xavier_uniform_(self.output_conv.weight) + nn.init.zeros_(self.input_conv.bias) + nn.init.zeros_(self.output_conv.bias) + + def forward(self, x, noise_scale): + o = self.input_conv(x) + o = F.leaky_relu(o, 0.2) + o = self.encoding(o, noise_scale) + shift, scale = torch.chunk(self.output_conv(o), 2, dim=1) + return shift, scale + + def 
remove_weight_norm(self): + nn.utils.remove_weight_norm(self.input_conv) + nn.utils.remove_weight_norm(self.output_conv) + + def apply_weight_norm(self): + self.input_conv = weight_norm(self.input_conv) + self.output_conv = weight_norm(self.output_conv) + + +@torch.jit.script +def shif_and_scale(x, scale, shift): + o = shift + scale * x + return o + + +class UBlock(nn.Module): + def __init__(self, input_size, hidden_size, factor, dilation): + super().__init__() + assert isinstance(dilation, (list, tuple)) + assert len(dilation) == 4 + + self.factor = factor + self.res_block = Conv1d(input_size, hidden_size, 1) + self.main_block = nn.ModuleList( + [ + Conv1d(input_size, hidden_size, 3, dilation=dilation[0], padding=dilation[0]), + Conv1d(hidden_size, hidden_size, 3, dilation=dilation[1], padding=dilation[1]), + ] + ) + self.out_block = nn.ModuleList( + [ + Conv1d(hidden_size, hidden_size, 3, dilation=dilation[2], padding=dilation[2]), + Conv1d(hidden_size, hidden_size, 3, dilation=dilation[3], padding=dilation[3]), + ] + ) + + def forward(self, x, shift, scale): + x_inter = F.interpolate(x, size=x.shape[-1] * self.factor) + res = self.res_block(x_inter) + o = F.leaky_relu(x_inter, 0.2) + o = F.interpolate(o, size=x.shape[-1] * self.factor) + o = self.main_block[0](o) + o = shif_and_scale(o, scale, shift) + o = F.leaky_relu(o, 0.2) + o = self.main_block[1](o) + res2 = res + o + o = shif_and_scale(res2, scale, shift) + o = F.leaky_relu(o, 0.2) + o = self.out_block[0](o) + o = shif_and_scale(o, scale, shift) + o = F.leaky_relu(o, 0.2) + o = self.out_block[1](o) + o = o + res2 + return o + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.res_block) + for _, layer in enumerate(self.main_block): + if len(layer.state_dict()) != 0: + nn.utils.remove_weight_norm(layer) + for _, layer in enumerate(self.out_block): + if len(layer.state_dict()) != 0: + nn.utils.remove_weight_norm(layer) + + def apply_weight_norm(self): + self.res_block = weight_norm(self.res_block) + for idx, layer in enumerate(self.main_block): + if len(layer.state_dict()) != 0: + self.main_block[idx] = weight_norm(layer) + for idx, layer in enumerate(self.out_block): + if len(layer.state_dict()) != 0: + self.out_block[idx] = weight_norm(layer) + + +class DBlock(nn.Module): + def __init__(self, input_size, hidden_size, factor): + super().__init__() + self.factor = factor + self.res_block = Conv1d(input_size, hidden_size, 1) + self.main_block = nn.ModuleList( + [ + Conv1d(input_size, hidden_size, 3, dilation=1, padding=1), + Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2), + Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4), + ] + ) + + def forward(self, x): + size = x.shape[-1] // self.factor + res = self.res_block(x) + res = F.interpolate(res, size=size) + o = F.interpolate(x, size=size) + for layer in self.main_block: + o = F.leaky_relu(o, 0.2) + o = layer(o) + return o + res + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.res_block) + for _, layer in enumerate(self.main_block): + if len(layer.state_dict()) != 0: + nn.utils.remove_weight_norm(layer) + + def apply_weight_norm(self): + self.res_block = weight_norm(self.res_block) + for idx, layer in enumerate(self.main_block): + if len(layer.state_dict()) != 0: + self.main_block[idx] = weight_norm(layer) diff --git a/TTS/vocoder/models/__init__.py b/TTS/vocoder/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..65901617b69d3ae708e09226c5e4ad903f89a929 --- /dev/null +++ 
b/TTS/vocoder/models/__init__.py @@ -0,0 +1,154 @@ +import importlib +import re + +from coqpit import Coqpit + + +def to_camel(text): + text = text.capitalize() + return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) + + +def setup_model(config: Coqpit): + """Load models directly from configuration.""" + if "discriminator_model" in config and "generator_model" in config: + MyModel = importlib.import_module("TTS.vocoder.models.gan") + MyModel = getattr(MyModel, "GAN") + else: + MyModel = importlib.import_module("TTS.vocoder.models." + config.model.lower()) + if config.model.lower() == "wavernn": + MyModel = getattr(MyModel, "Wavernn") + elif config.model.lower() == "gan": + MyModel = getattr(MyModel, "GAN") + elif config.model.lower() == "wavegrad": + MyModel = getattr(MyModel, "Wavegrad") + else: + try: + MyModel = getattr(MyModel, to_camel(config.model)) + except ModuleNotFoundError as e: + raise ValueError(f"Model {config.model} not exist!") from e + print(" > Vocoder Model: {}".format(config.model)) + return MyModel.init_from_config(config) + + +def setup_generator(c): + """TODO: use config object as arguments""" + print(" > Generator Model: {}".format(c.generator_model)) + MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower()) + MyModel = getattr(MyModel, to_camel(c.generator_model)) + # this is to preserve the Wavernn class name (instead of Wavernn) + if c.generator_model.lower() in "hifigan_generator": + model = MyModel(in_channels=c.audio["num_mels"], out_channels=1, **c.generator_model_params) + elif c.generator_model.lower() in "melgan_generator": + model = MyModel( + in_channels=c.audio["num_mels"], + out_channels=1, + proj_kernel=7, + base_channels=512, + upsample_factors=c.generator_model_params["upsample_factors"], + res_kernel=3, + num_res_blocks=c.generator_model_params["num_res_blocks"], + ) + elif c.generator_model in "melgan_fb_generator": + raise ValueError("melgan_fb_generator is now fullband_melgan_generator") + elif c.generator_model.lower() in "multiband_melgan_generator": + model = MyModel( + in_channels=c.audio["num_mels"], + out_channels=4, + proj_kernel=7, + base_channels=384, + upsample_factors=c.generator_model_params["upsample_factors"], + res_kernel=3, + num_res_blocks=c.generator_model_params["num_res_blocks"], + ) + elif c.generator_model.lower() in "fullband_melgan_generator": + model = MyModel( + in_channels=c.audio["num_mels"], + out_channels=1, + proj_kernel=7, + base_channels=512, + upsample_factors=c.generator_model_params["upsample_factors"], + res_kernel=3, + num_res_blocks=c.generator_model_params["num_res_blocks"], + ) + elif c.generator_model.lower() in "parallel_wavegan_generator": + model = MyModel( + in_channels=1, + out_channels=1, + kernel_size=3, + num_res_blocks=c.generator_model_params["num_res_blocks"], + stacks=c.generator_model_params["stacks"], + res_channels=64, + gate_channels=128, + skip_channels=64, + aux_channels=c.audio["num_mels"], + dropout=0.0, + bias=True, + use_weight_norm=True, + upsample_factors=c.generator_model_params["upsample_factors"], + ) + elif c.generator_model.lower() in "univnet_generator": + model = MyModel(**c.generator_model_params) + else: + raise NotImplementedError(f"Model {c.generator_model} not implemented!") + return model + + +def setup_discriminator(c): + """TODO: use config objekt as arguments""" + print(" > Discriminator Model: {}".format(c.discriminator_model)) + if "parallel_wavegan" in c.discriminator_model: + MyModel = 
importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator") + else: + MyModel = importlib.import_module("TTS.vocoder.models." + c.discriminator_model.lower()) + MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower())) + if c.discriminator_model in "hifigan_discriminator": + model = MyModel() + if c.discriminator_model in "random_window_discriminator": + model = MyModel( + cond_channels=c.audio["num_mels"], + hop_length=c.audio["hop_length"], + uncond_disc_donwsample_factors=c.discriminator_model_params["uncond_disc_donwsample_factors"], + cond_disc_downsample_factors=c.discriminator_model_params["cond_disc_downsample_factors"], + cond_disc_out_channels=c.discriminator_model_params["cond_disc_out_channels"], + window_sizes=c.discriminator_model_params["window_sizes"], + ) + if c.discriminator_model in "melgan_multiscale_discriminator": + model = MyModel( + in_channels=1, + out_channels=1, + kernel_sizes=(5, 3), + base_channels=c.discriminator_model_params["base_channels"], + max_channels=c.discriminator_model_params["max_channels"], + downsample_factors=c.discriminator_model_params["downsample_factors"], + ) + if c.discriminator_model == "residual_parallel_wavegan_discriminator": + model = MyModel( + in_channels=1, + out_channels=1, + kernel_size=3, + num_layers=c.discriminator_model_params["num_layers"], + stacks=c.discriminator_model_params["stacks"], + res_channels=64, + gate_channels=128, + skip_channels=64, + dropout=0.0, + bias=True, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + ) + if c.discriminator_model == "parallel_wavegan_discriminator": + model = MyModel( + in_channels=1, + out_channels=1, + kernel_size=3, + num_layers=c.discriminator_model_params["num_layers"], + conv_channels=64, + dilation_factor=1, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + bias=True, + ) + if c.discriminator_model == "univnet_discriminator": + model = MyModel() + return model diff --git a/TTS/vocoder/models/base_vocoder.py b/TTS/vocoder/models/base_vocoder.py new file mode 100644 index 0000000000000000000000000000000000000000..01a7ff68771c72f89f9d0fb6708706f6f92ba96a --- /dev/null +++ b/TTS/vocoder/models/base_vocoder.py @@ -0,0 +1,53 @@ +from coqpit import Coqpit + +from TTS.model import BaseTrainerModel + +# pylint: skip-file + + +class BaseVocoder(BaseTrainerModel): + """Base `vocoder` class. Every new `vocoder` model must inherit this. + + It defines `vocoder` specific functions on top of `Model`. + + Notes on input/output tensor shapes: + Any input or output tensor of the model must be shaped as + + - 3D tensors `batch x time x channels` + - 2D tensors `batch x channels` + - 1D tensors `batch x 1` + """ + + def __init__(self, config): + super().__init__() + self._set_model_args(config) + + def _set_model_args(self, config: Coqpit): + """Setup model args based on the config type. + + If the config is for training with a name like "*Config", then the model args are embeded in the + config.model_args + + If the config is for the model with a name like "*Args", then we assign the directly. 
+ """ + # don't use isintance not to import recursively + if "Config" in config.__class__.__name__: + if "characters" in config: + _, self.config, num_chars = self.get_characters(config) + self.config.num_chars = num_chars + if hasattr(self.config, "model_args"): + config.model_args.num_chars = num_chars + if "model_args" in config: + self.args = self.config.model_args + # This is for backward compatibility + if "model_params" in config: + self.args = self.config.model_params + else: + self.config = config + if "model_args" in config: + self.args = self.config.model_args + # This is for backward compatibility + if "model_params" in config: + self.args = self.config.model_params + else: + raise ValueError("config must be either a *Config or *Args") diff --git a/TTS/vocoder/models/fullband_melgan_generator.py b/TTS/vocoder/models/fullband_melgan_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..ee25559af0d468aac535841bdfdd33b366250f43 --- /dev/null +++ b/TTS/vocoder/models/fullband_melgan_generator.py @@ -0,0 +1,33 @@ +import torch + +from TTS.vocoder.models.melgan_generator import MelganGenerator + + +class FullbandMelganGenerator(MelganGenerator): + def __init__( + self, + in_channels=80, + out_channels=1, + proj_kernel=7, + base_channels=512, + upsample_factors=(2, 8, 2, 2), + res_kernel=3, + num_res_blocks=4, + ): + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + proj_kernel=proj_kernel, + base_channels=base_channels, + upsample_factors=upsample_factors, + res_kernel=res_kernel, + num_res_blocks=num_res_blocks, + ) + + @torch.no_grad() + def inference(self, cond_features): + cond_features = cond_features.to(self.layers[1].weight.device) + cond_features = torch.nn.functional.pad( + cond_features, (self.inference_padding, self.inference_padding), "replicate" + ) + return self.layers(cond_features) diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py new file mode 100644 index 0000000000000000000000000000000000000000..a3803f7714aa5c537c3de334bf9ba81496169502 --- /dev/null +++ b/TTS/vocoder/models/gan.py @@ -0,0 +1,373 @@ +from inspect import signature +from typing import Dict, List, Tuple + +import numpy as np +import torch +from coqpit import Coqpit +from torch import nn +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.utils.audio import AudioProcessor +from TTS.utils.io import load_fsspec +from TTS.vocoder.datasets.gan_dataset import GANDataset +from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss +from TTS.vocoder.models import setup_discriminator, setup_generator +from TTS.vocoder.models.base_vocoder import BaseVocoder +from TTS.vocoder.utils.generic_utils import plot_results + + +class GAN(BaseVocoder): + def __init__(self, config: Coqpit, ap: AudioProcessor = None): + """Wrap a generator and a discriminator network. It provides a compatible interface for the trainer. + It also helps mixing and matching different generator and disciminator networks easily. + + To implement a new GAN models, you just need to define the generator and the discriminator networks, the rest + is handled by the `GAN` class. + + Args: + config (Coqpit): Model configuration. + ap (AudioProcessor): 🐸TTS AudioProcessor instance. Defaults to None. + + Examples: + Initializing the GAN model with HifiGAN generator and discriminator. 
+ >>> from TTS.vocoder.configs import HifiganConfig + >>> config = HifiganConfig() + >>> model = GAN(config) + """ + super().__init__(config) + self.config = config + self.model_g = setup_generator(config) + self.model_d = setup_discriminator(config) + self.train_disc = False # if False, train only the generator. + self.y_hat_g = None # the last generator prediction to be passed onto the discriminator + self.ap = ap + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Run the generator's forward pass. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: output of the GAN generator network. + """ + return self.model_g.forward(x) + + def inference(self, x: torch.Tensor) -> torch.Tensor: + """Run the generator's inference pass. + + Args: + x (torch.Tensor): Input tensor. + Returns: + torch.Tensor: output of the GAN generator network. + """ + return self.model_g.inference(x) + + def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[Dict, Dict]: + """Compute model outputs and the loss values. `optimizer_idx` selects the generator or the discriminator for + network on the current pass. + + Args: + batch (Dict): Batch of samples returned by the dataloader. + criterion (Dict): Criterion used to compute the losses. + optimizer_idx (int): ID of the optimizer in use on the current pass. + + Raises: + ValueError: `optimizer_idx` is an unexpected value. + + Returns: + Tuple[Dict, Dict]: model outputs and the computed loss values. + """ + outputs = {} + loss_dict = {} + + x = batch["input"] + y = batch["waveform"] + + if optimizer_idx not in [0, 1]: + raise ValueError(" [!] Unexpected `optimizer_idx`.") + + if optimizer_idx == 0: + # DISCRIMINATOR optimization + + # generator pass + y_hat = self.model_g(x)[:, :, : y.size(2)] + + # cache for generator loss + # pylint: disable=W0201 + self.y_hat_g = y_hat + self.y_hat_sub = None + self.y_sub_g = None + + # PQMF formatting + if y_hat.shape[1] > 1: + self.y_hat_sub = y_hat + y_hat = self.model_g.pqmf_synthesis(y_hat) + self.y_hat_g = y_hat # save for generator loss + self.y_sub_g = self.model_g.pqmf_analysis(y) + + scores_fake, feats_fake, feats_real = None, None, None + + if self.train_disc: + # use different samples for G and D trainings + if self.config.diff_samples_for_G_and_D: + x_d = batch["input_disc"] + y_d = batch["waveform_disc"] + # use a different sample than generator + with torch.no_grad(): + y_hat = self.model_g(x_d) + + # PQMF formatting + if y_hat.shape[1] > 1: + y_hat = self.model_g.pqmf_synthesis(y_hat) + else: + # use the same samples as generator + x_d = x.clone() + y_d = y.clone() + y_hat = self.y_hat_g + + # run D with or without cond. 
features + if len(signature(self.model_d.forward).parameters) == 2: + D_out_fake = self.model_d(y_hat.detach().clone(), x_d) + D_out_real = self.model_d(y_d, x_d) + else: + D_out_fake = self.model_d(y_hat.detach()) + D_out_real = self.model_d(y_d) + + # format D outputs + if isinstance(D_out_fake, tuple): + # self.model_d returns scores and features + scores_fake, feats_fake = D_out_fake + if D_out_real is None: + scores_real, feats_real = None, None + else: + scores_real, feats_real = D_out_real + else: + # model D returns only scores + scores_fake = D_out_fake + scores_real = D_out_real + + # compute losses + loss_dict = criterion[optimizer_idx](scores_fake, scores_real) + outputs = {"model_outputs": y_hat} + + if optimizer_idx == 1: + # GENERATOR loss + scores_fake, feats_fake, feats_real = None, None, None + if self.train_disc: + if len(signature(self.model_d.forward).parameters) == 2: + D_out_fake = self.model_d(self.y_hat_g, x) + else: + D_out_fake = self.model_d(self.y_hat_g) + D_out_real = None + + if self.config.use_feat_match_loss: + with torch.no_grad(): + D_out_real = self.model_d(y) + + # format D outputs + if isinstance(D_out_fake, tuple): + scores_fake, feats_fake = D_out_fake + if D_out_real is None: + feats_real = None + else: + _, feats_real = D_out_real + else: + scores_fake = D_out_fake + feats_fake, feats_real = None, None + + # compute losses + loss_dict = criterion[optimizer_idx]( + self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g + ) + outputs = {"model_outputs": self.y_hat_g} + return outputs, loss_dict + + def _log(self, name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]: + """Logging shared by the training and evaluation. + + Args: + name (str): Name of the run. `train` or `eval`, + ap (AudioProcessor): Audio processor used in training. + batch (Dict): Batch used in the last train/eval step. + outputs (Dict): Model outputs from the last train/eval step. + + Returns: + Tuple[Dict, Dict]: log figures and audio samples. 
+ """ + y_hat = outputs[0]["model_outputs"] if self.train_disc else outputs[1]["model_outputs"] + y = batch["waveform"] + figures = plot_results(y_hat, y, ap, name) + sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() + audios = {f"{name}/audio": sample_voice} + return figures, audios + + def train_log( + self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument + ) -> Tuple[Dict, np.ndarray]: + """Call `_log()` for training.""" + figures, audios = self._log("eval", self.ap, batch, outputs) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + @torch.no_grad() + def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]: + """Call `train_step()` with `no_grad()`""" + self.train_disc = True # Avoid a bug in the Training with the missing discriminator loss + return self.train_step(batch, criterion, optimizer_idx) + + def eval_log( + self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument + ) -> Tuple[Dict, np.ndarray]: + """Call `_log()` for evaluation.""" + figures, audios = self._log("eval", self.ap, batch, outputs) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def load_checkpoint( + self, + config: Coqpit, + checkpoint_path: str, + eval: bool = False, # pylint: disable=unused-argument, redefined-builtin + ) -> None: + """Load a GAN checkpoint and initialize model parameters. + + Args: + config (Coqpit): Model config. + checkpoint_path (str): Checkpoint file path. + eval (bool, optional): If true, load the model for inference. If falseDefaults to False. + """ + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + # band-aid for older than v0.0.15 GAN models + if "model_disc" in state: + self.model_g.load_checkpoint(config, checkpoint_path, eval) + else: + self.load_state_dict(state["model"]) + if eval: + self.model_d = None + if hasattr(self.model_g, "remove_weight_norm"): + self.model_g.remove_weight_norm() + + def on_train_step_start(self, trainer) -> None: + """Enable the discriminator training based on `steps_to_start_discriminator` + + Args: + trainer (Trainer): Trainer object. + """ + self.train_disc = trainer.total_steps_done >= self.config.steps_to_start_discriminator + + def get_optimizer(self) -> List: + """Initiate and return the GAN optimizers based on the config parameters. + + It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator. + + Returns: + List: optimizers. + """ + optimizer1 = get_optimizer( + self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, self.model_g + ) + optimizer2 = get_optimizer( + self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.model_d + ) + return [optimizer2, optimizer1] + + def get_lr(self) -> List: + """Set the initial learning rates for each optimizer. + + Returns: + List: learning rates for each optimizer. + """ + return [self.config.lr_disc, self.config.lr_gen] + + def get_scheduler(self, optimizer) -> List: + """Set the schedulers for each optimizer. + + Args: + optimizer (List[`torch.optim.Optimizer`]): List of optimizers. + + Returns: + List: Schedulers, one for each optimizer. 
+ """ + scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0]) + scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1]) + return [scheduler2, scheduler1] + + @staticmethod + def format_batch(batch: List) -> Dict: + """Format the batch for training. + + Args: + batch (List): Batch out of the dataloader. + + Returns: + Dict: formatted model inputs. + """ + if isinstance(batch[0], list): + x_G, y_G = batch[0] + x_D, y_D = batch[1] + return {"input": x_G, "waveform": y_G, "input_disc": x_D, "waveform_disc": y_D} + x, y = batch + return {"input": x, "waveform": y} + + def get_data_loader( # pylint: disable=no-self-use, unused-argument + self, + config: Coqpit, + assets: Dict, + is_eval: True, + samples: List, + verbose: bool, + num_gpus: int, + rank: int = None, # pylint: disable=unused-argument + ): + """Initiate and return the GAN dataloader. + + Args: + config (Coqpit): Model config. + ap (AudioProcessor): Audio processor. + is_eval (True): Set the dataloader for evaluation if true. + samples (List): Data samples. + verbose (bool): Log information if true. + num_gpus (int): Number of GPUs in use. + rank (int): Rank of the current GPU. Defaults to None. + + Returns: + DataLoader: Torch dataloader. + """ + dataset = GANDataset( + ap=self.ap, + items=samples, + seq_len=config.seq_len, + hop_len=self.ap.hop_length, + pad_short=config.pad_short, + conv_pad=config.conv_pad, + return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False, + is_training=not is_eval, + return_segments=not is_eval, + use_noise_augment=config.use_noise_augment, + use_cache=config.use_cache, + verbose=verbose, + ) + dataset.shuffle_mapping() + sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None + loader = DataLoader( + dataset, + batch_size=1 if is_eval else config.batch_size, + shuffle=num_gpus == 0, + drop_last=False, + sampler=sampler, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + return loader + + def get_criterion(self): + """Return criterions for the optimizers""" + return [DiscriminatorLoss(self.config), GeneratorLoss(self.config)] + + @staticmethod + def init_from_config(config: Coqpit, verbose=True) -> "GAN": + ap = AudioProcessor.init_from_config(config, verbose=verbose) + return GAN(config, ap=ap) diff --git a/TTS/vocoder/models/hifigan_discriminator.py b/TTS/vocoder/models/hifigan_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5eaf408c95372ea26f4e83db6f470b4dd92dfb --- /dev/null +++ b/TTS/vocoder/models/hifigan_discriminator.py @@ -0,0 +1,217 @@ +# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py +import torch +from torch import nn +from torch.nn import functional as F + +LRELU_SLOPE = 0.1 + + +class DiscriminatorP(torch.nn.Module): + """HiFiGAN Periodic Discriminator + + Takes every Pth value from the input waveform and applied a stack of convoluations. + + Note: + if `period` is 2 + `waveform = [1, 2, 3, 4, 5, 6 ...] --> [1, 3, 5 ... ] --> convs -> score, feat` + + Args: + x (Tensor): input waveform. + + Returns: + [Tensor]: discriminator scores per sample in the batch. + [List[Tensor]]: list of features from each convolutional layer. 
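    Example (a minimal sketch; batch size and input length are arbitrary):
        >>> disc = DiscriminatorP(period=2)
        >>> scores, feats = disc(torch.rand(4, 1, 8192))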
+ + Shapes: + x: [B, 1, T] + """ + + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super().__init__() + self.period = period + get_padding = lambda k, d: int((k * d - d) / 2) + norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm + self.convs = nn.ModuleList( + [ + norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ] + ) + self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + [Tensor]: discriminator scores per sample in the batch. + [List[Tensor]]: list of features from each convolutional layer. + + Shapes: + x: [B, 1, T] + """ + feat = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + feat.append(x) + x = self.conv_post(x) + feat.append(x) + x = torch.flatten(x, 1, -1) + + return x, feat + + +class MultiPeriodDiscriminator(torch.nn.Module): + """HiFiGAN Multi-Period Discriminator (MPD) + Wrapper for the `PeriodDiscriminator` to apply it in different periods. + Periods are suggested to be prime numbers to reduce the overlap between each discriminator. + """ + + def __init__(self, use_spectral_norm=False): + super().__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorP(2, use_spectral_norm=use_spectral_norm), + DiscriminatorP(3, use_spectral_norm=use_spectral_norm), + DiscriminatorP(5, use_spectral_norm=use_spectral_norm), + DiscriminatorP(7, use_spectral_norm=use_spectral_norm), + DiscriminatorP(11, use_spectral_norm=use_spectral_norm), + ] + ) + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + [List[Tensor]]: list of scores from each discriminator. + [List[List[Tensor]]]: list of list of features from each discriminator's each convolutional layer. + + Shapes: + x: [B, 1, T] + """ + scores = [] + feats = [] + for _, d in enumerate(self.discriminators): + score, feat = d(x) + scores.append(score) + feats.append(feat) + return scores, feats + + +class DiscriminatorS(torch.nn.Module): + """HiFiGAN Scale Discriminator. + It is similar to `MelganDiscriminator` but with a specific architecture explained in the paper. + + Args: + use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm. 
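    Example (a minimal sketch with one second of random audio at an assumed 16 kHz):
        >>> disc = DiscriminatorS()
        >>> scores, feats = disc(torch.rand(2, 1, 16000))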
+ + """ + + def __init__(self, use_spectral_norm=False): + super().__init__() + norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm + self.convs = nn.ModuleList( + [ + norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)), + norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + Tensor: discriminator scores. + List[Tensor]: list of features from the convolutiona layers. + """ + feat = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + feat.append(x) + x = self.conv_post(x) + feat.append(x) + x = torch.flatten(x, 1, -1) + return x, feat + + +class MultiScaleDiscriminator(torch.nn.Module): + """HiFiGAN Multi-Scale Discriminator. + It is similar to `MultiScaleMelganDiscriminator` but specially tailored for HiFiGAN as in the paper. + """ + + def __init__(self): + super().__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ] + ) + self.meanpools = nn.ModuleList([nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)]) + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + List[Tensor]: discriminator scores. + List[List[Tensor]]: list of list of features from each layers of each discriminator. + """ + scores = [] + feats = [] + for i, d in enumerate(self.discriminators): + if i != 0: + x = self.meanpools[i - 1](x) + score, feat = d(x) + scores.append(score) + feats.append(feat) + return scores, feats + + +class HifiganDiscriminator(nn.Module): + """HiFiGAN discriminator wrapping MPD and MSD.""" + + def __init__(self): + super().__init__() + self.mpd = MultiPeriodDiscriminator() + self.msd = MultiScaleDiscriminator() + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + List[Tensor]: discriminator scores. + List[List[Tensor]]: list of list of features from each layers of each discriminator. + """ + scores, feats = self.mpd(x) + scores_, feats_ = self.msd(x) + return scores + scores_, feats + feats_ diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..fc15f3af1033470990001cc5106dfe08c2930749 --- /dev/null +++ b/TTS/vocoder/models/hifigan_generator.py @@ -0,0 +1,300 @@ +# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py +import torch +from torch import nn +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn import functional as F +from torch.nn.utils import remove_weight_norm, weight_norm + +from TTS.utils.io import load_fsspec + +LRELU_SLOPE = 0.1 + + +def get_padding(k, d): + return int((k * d - d) / 2) + + +class ResBlock1(torch.nn.Module): + """Residual Block Type 1. It has 3 convolutional layers in each convolutiona block. 
+ + Network:: + + x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o + |--------------------------------------------------------------------------------------------------| + + + Args: + channels (int): number of hidden channels for the convolutional layers. + kernel_size (int): size of the convolution filter in each layer. + dilations (list): list of dilation value for each conv layer in a block. + """ + + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super().__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + ] + ) + + def forward(self, x): + """ + Args: + x (Tensor): input tensor. + Returns: + Tensor: output tensor. + Shapes: + x: [B, C, T] + """ + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + """Residual Block Type 1. It has 3 convolutional layers in each convolutiona block. + + Network:: + + x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o + |---------------------------------------------------| + + + Args: + channels (int): number of hidden channels for the convolutional layers. + kernel_size (int): size of the convolution filter in each layer. + dilations (list): list of dilation value for each conv layer in a block. 
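    Example (illustrative; the padded dilated convs preserve the time dimension):
        >>> block = ResBlock2(channels=128, kernel_size=3, dilation=(1, 3))
        >>> block(torch.rand(1, 128, 100)).shape
        torch.Size([1, 128, 100])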
+ """ + + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super().__init__() + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class HifiganGenerator(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + resblock_type, + resblock_dilation_sizes, + resblock_kernel_sizes, + upsample_kernel_sizes, + upsample_initial_channel, + upsample_factors, + inference_padding=5, + cond_channels=0, + conv_pre_weight_norm=True, + conv_post_weight_norm=True, + conv_post_bias=True, + ): + r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF) + + Network: + x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o + .. -> zI ---| + resblockN_kNx1 -> zN ---' + + Args: + in_channels (int): number of input tensor channels. + out_channels (int): number of output tensor channels. + resblock_type (str): type of the `ResBlock`. '1' or '2'. + resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`. + resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`. + upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution. + upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2 + for each consecutive upsampling layer. + upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer. + inference_padding (int): constant padding applied to the input at inference time. Defaults to 5. + """ + super().__init__() + self.inference_padding = inference_padding + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_factors) + # initial upsampling layers + self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if resblock_type == "1" else ResBlock2 + # upsampling layers + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + # MRF blocks + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + # post convolution layer + self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)) + if cond_channels > 0: + self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1) + + if not conv_pre_weight_norm: + remove_weight_norm(self.conv_pre) + + if not conv_post_weight_norm: + remove_weight_norm(self.conv_post) + + def forward(self, x, g=None): + """ + Args: + x (Tensor): feature input tensor. + g (Tensor): global conditioning input tensor. + + Returns: + Tensor: output waveform. 
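        Example (illustrative sketch; the constructor arguments below are arbitrary small values, not a released config):
            >>> gen = HifiganGenerator(80, 1, "1", [[1, 3, 5]] * 3, [3, 7, 11], [16, 16, 4, 4], 512, [8, 8, 2, 2])
            >>> gen(torch.rand(1, 80, 50)).shape   # time axis upsampled by prod(upsample_factors) = 256
            torch.Size([1, 1, 12800])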
+ + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + o = self.conv_pre(x) + if hasattr(self, "cond_layer"): + o = o + self.cond_layer(g) + for i in range(self.num_upsamples): + o = F.leaky_relu(o, LRELU_SLOPE) + o = self.ups[i](o) + z_sum = None + for j in range(self.num_kernels): + if z_sum is None: + z_sum = self.resblocks[i * self.num_kernels + j](o) + else: + z_sum += self.resblocks[i * self.num_kernels + j](o) + o = z_sum / self.num_kernels + o = F.leaky_relu(o) + o = self.conv_post(o) + o = torch.tanh(o) + return o + + @torch.no_grad() + def inference(self, c): + """ + Args: + x (Tensor): conditioning input tensor. + + Returns: + Tensor: output waveform. + + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + c = c.to(self.conv_pre.weight.device) + c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate") + return self.forward(c) + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + self.remove_weight_norm() diff --git a/TTS/vocoder/models/melgan_discriminator.py b/TTS/vocoder/models/melgan_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..14f00c5927cb28449c4fb0dc0727cde014370c2b --- /dev/null +++ b/TTS/vocoder/models/melgan_discriminator.py @@ -0,0 +1,84 @@ +import numpy as np +from torch import nn +from torch.nn.utils import weight_norm + + +class MelganDiscriminator(nn.Module): + def __init__( + self, + in_channels=1, + out_channels=1, + kernel_sizes=(5, 3), + base_channels=16, + max_channels=1024, + downsample_factors=(4, 4, 4, 4), + groups_denominator=4, + ): + super().__init__() + self.layers = nn.ModuleList() + + layer_kernel_size = np.prod(kernel_sizes) + layer_padding = (layer_kernel_size - 1) // 2 + + # initial layer + self.layers += [ + nn.Sequential( + nn.ReflectionPad1d(layer_padding), + weight_norm(nn.Conv1d(in_channels, base_channels, layer_kernel_size, stride=1)), + nn.LeakyReLU(0.2, inplace=True), + ) + ] + + # downsampling layers + layer_in_channels = base_channels + for downsample_factor in downsample_factors: + layer_out_channels = min(layer_in_channels * downsample_factor, max_channels) + layer_kernel_size = downsample_factor * 10 + 1 + layer_padding = (layer_kernel_size - 1) // 2 + layer_groups = layer_in_channels // groups_denominator + self.layers += [ + nn.Sequential( + weight_norm( + nn.Conv1d( + layer_in_channels, + layer_out_channels, + kernel_size=layer_kernel_size, + stride=downsample_factor, + padding=layer_padding, + groups=layer_groups, + ) + ), + nn.LeakyReLU(0.2, inplace=True), + ) + ] + layer_in_channels = layer_out_channels + + # last 2 layers + layer_padding1 = (kernel_sizes[0] - 1) // 2 + layer_padding2 = (kernel_sizes[1] - 1) // 2 + self.layers += [ + nn.Sequential( + weight_norm( + nn.Conv1d( + layer_out_channels, + layer_out_channels, + kernel_size=kernel_sizes[0], + stride=1, + padding=layer_padding1, + ) + ), + nn.LeakyReLU(0.2, inplace=True), + ), + weight_norm( + nn.Conv1d( + layer_out_channels, out_channels, kernel_size=kernel_sizes[1], stride=1, padding=layer_padding2 + ) + ), + ] + + def 
forward(self, x): + feats = [] + for layer in self.layers: + x = layer(x) + feats.append(x) + return x, feats diff --git a/TTS/vocoder/models/melgan_generator.py b/TTS/vocoder/models/melgan_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..80b478704ebdbcde2a1871a2481bc1b7f1f22fa9 --- /dev/null +++ b/TTS/vocoder/models/melgan_generator.py @@ -0,0 +1,95 @@ +import torch +from torch import nn +from torch.nn.utils import weight_norm + +from TTS.utils.io import load_fsspec +from TTS.vocoder.layers.melgan import ResidualStack + + +class MelganGenerator(nn.Module): + def __init__( + self, + in_channels=80, + out_channels=1, + proj_kernel=7, + base_channels=512, + upsample_factors=(8, 8, 2, 2), + res_kernel=3, + num_res_blocks=3, + ): + super().__init__() + + # assert model parameters + assert (proj_kernel - 1) % 2 == 0, " [!] proj_kernel should be an odd number." + + # setup additional model parameters + base_padding = (proj_kernel - 1) // 2 + act_slope = 0.2 + self.inference_padding = 2 + + # initial layer + layers = [] + layers += [ + nn.ReflectionPad1d(base_padding), + weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=proj_kernel, stride=1, bias=True)), + ] + + # upsampling layers and residual stacks + for idx, upsample_factor in enumerate(upsample_factors): + layer_in_channels = base_channels // (2**idx) + layer_out_channels = base_channels // (2 ** (idx + 1)) + layer_filter_size = upsample_factor * 2 + layer_stride = upsample_factor + layer_output_padding = upsample_factor % 2 + layer_padding = upsample_factor // 2 + layer_output_padding + layers += [ + nn.LeakyReLU(act_slope), + weight_norm( + nn.ConvTranspose1d( + layer_in_channels, + layer_out_channels, + layer_filter_size, + stride=layer_stride, + padding=layer_padding, + output_padding=layer_output_padding, + bias=True, + ) + ), + ResidualStack(channels=layer_out_channels, num_res_blocks=num_res_blocks, kernel_size=res_kernel), + ] + + layers += [nn.LeakyReLU(act_slope)] + + # final layer + layers += [ + nn.ReflectionPad1d(base_padding), + weight_norm(nn.Conv1d(layer_out_channels, out_channels, proj_kernel, stride=1, bias=True)), + nn.Tanh(), + ] + self.layers = nn.Sequential(*layers) + + def forward(self, c): + return self.layers(c) + + def inference(self, c): + c = c.to(self.layers[1].weight.device) + c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate") + return self.layers(c) + + def remove_weight_norm(self): + for _, layer in enumerate(self.layers): + if len(layer.state_dict()) != 0: + try: + nn.utils.remove_weight_norm(layer) + except ValueError: + layer.remove_weight_norm() + + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + self.remove_weight_norm() diff --git a/TTS/vocoder/models/melgan_multiscale_discriminator.py b/TTS/vocoder/models/melgan_multiscale_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..b4909f37c0c91c6fee8bb0baab98a8662039dea1 --- /dev/null +++ b/TTS/vocoder/models/melgan_multiscale_discriminator.py @@ -0,0 +1,50 @@ +from torch import nn + +from TTS.vocoder.models.melgan_discriminator import MelganDiscriminator + + +class MelganMultiscaleDiscriminator(nn.Module): + def __init__( + self, + in_channels=1, + out_channels=1, + num_scales=3, + kernel_sizes=(5, 3), 
+ base_channels=16, + max_channels=1024, + downsample_factors=(4, 4, 4), + pooling_kernel_size=4, + pooling_stride=2, + pooling_padding=2, + groups_denominator=4, + ): + super().__init__() + + self.discriminators = nn.ModuleList( + [ + MelganDiscriminator( + in_channels=in_channels, + out_channels=out_channels, + kernel_sizes=kernel_sizes, + base_channels=base_channels, + max_channels=max_channels, + downsample_factors=downsample_factors, + groups_denominator=groups_denominator, + ) + for _ in range(num_scales) + ] + ) + + self.pooling = nn.AvgPool1d( + kernel_size=pooling_kernel_size, stride=pooling_stride, padding=pooling_padding, count_include_pad=False + ) + + def forward(self, x): + scores = [] + feats = [] + for disc in self.discriminators: + score, feat = disc(x) + scores.append(score) + feats.append(feat) + x = self.pooling(x) + return scores, feats diff --git a/TTS/vocoder/models/multiband_melgan_generator.py b/TTS/vocoder/models/multiband_melgan_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..25d6590659cf5863176eb6609c7609b0e1b28d12 --- /dev/null +++ b/TTS/vocoder/models/multiband_melgan_generator.py @@ -0,0 +1,41 @@ +import torch + +from TTS.vocoder.layers.pqmf import PQMF +from TTS.vocoder.models.melgan_generator import MelganGenerator + + +class MultibandMelganGenerator(MelganGenerator): + def __init__( + self, + in_channels=80, + out_channels=4, + proj_kernel=7, + base_channels=384, + upsample_factors=(2, 8, 2, 2), + res_kernel=3, + num_res_blocks=3, + ): + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + proj_kernel=proj_kernel, + base_channels=base_channels, + upsample_factors=upsample_factors, + res_kernel=res_kernel, + num_res_blocks=num_res_blocks, + ) + self.pqmf_layer = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0) + + def pqmf_analysis(self, x): + return self.pqmf_layer.analysis(x) + + def pqmf_synthesis(self, x): + return self.pqmf_layer.synthesis(x) + + @torch.no_grad() + def inference(self, cond_features): + cond_features = cond_features.to(self.layers[1].weight.device) + cond_features = torch.nn.functional.pad( + cond_features, (self.inference_padding, self.inference_padding), "replicate" + ) + return self.pqmf_synthesis(self.layers(cond_features)) diff --git a/TTS/vocoder/models/parallel_wavegan_discriminator.py b/TTS/vocoder/models/parallel_wavegan_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..adf1bdaea040e99dd66829b9b8ed184146e155cb --- /dev/null +++ b/TTS/vocoder/models/parallel_wavegan_discriminator.py @@ -0,0 +1,186 @@ +import math + +import torch +from torch import nn + +from TTS.vocoder.layers.parallel_wavegan import ResidualBlock + + +class ParallelWaveganDiscriminator(nn.Module): + """PWGAN discriminator as in https://arxiv.org/abs/1910.11480. + It classifies each audio window real/fake and returns a sequence + of predictions. + It is a stack of convolutional blocks with dilation. + """ + + # pylint: disable=dangerous-default-value + def __init__( + self, + in_channels=1, + out_channels=1, + kernel_size=3, + num_layers=10, + conv_channels=64, + dilation_factor=1, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + bias=True, + ): + super().__init__() + assert (kernel_size - 1) % 2 == 0, " [!] does not support even number kernel size." + assert dilation_factor > 0, " [!] dilation factor must be > 0." 
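+        # stack of dilated non-causal convolutions; layer 0 uses dilation 1, later
+        # layers use dilation i (when dilation_factor == 1) or dilation_factor**i,
+        # so the receptive field grows with depth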
+ self.conv_layers = nn.ModuleList() + conv_in_channels = in_channels + for i in range(num_layers - 1): + if i == 0: + dilation = 1 + else: + dilation = i if dilation_factor == 1 else dilation_factor**i + conv_in_channels = conv_channels + padding = (kernel_size - 1) // 2 * dilation + conv_layer = [ + nn.Conv1d( + conv_in_channels, + conv_channels, + kernel_size=kernel_size, + padding=padding, + dilation=dilation, + bias=bias, + ), + getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params), + ] + self.conv_layers += conv_layer + padding = (kernel_size - 1) // 2 + last_conv_layer = nn.Conv1d(conv_in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias) + self.conv_layers += [last_conv_layer] + self.apply_weight_norm() + + def forward(self, x): + """ + x : (B, 1, T). + Returns: + Tensor: (B, 1, T) + """ + for f in self.conv_layers: + x = f(x) + return x + + def apply_weight_norm(self): + def _apply_weight_norm(m): + if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) + + def remove_weight_norm(self): + def _remove_weight_norm(m): + try: + # print(f"Weight norm is removed from {m}.") + nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + +class ResidualParallelWaveganDiscriminator(nn.Module): + # pylint: disable=dangerous-default-value + def __init__( + self, + in_channels=1, + out_channels=1, + kernel_size=3, + num_layers=30, + stacks=3, + res_channels=64, + gate_channels=128, + skip_channels=64, + dropout=0.0, + bias=True, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + ): + super().__init__() + assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." + + self.in_channels = in_channels + self.out_channels = out_channels + self.num_layers = num_layers + self.stacks = stacks + self.kernel_size = kernel_size + self.res_factor = math.sqrt(1.0 / num_layers) + + # check the number of num_layers and stacks + assert num_layers % stacks == 0 + layers_per_stack = num_layers // stacks + + # define first convolution + self.first_conv = nn.Sequential( + nn.Conv1d(in_channels, res_channels, kernel_size=1, padding=0, dilation=1, bias=True), + getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params), + ) + + # define residual blocks + self.conv_layers = nn.ModuleList() + for layer in range(num_layers): + dilation = 2 ** (layer % layers_per_stack) + conv = ResidualBlock( + kernel_size=kernel_size, + res_channels=res_channels, + gate_channels=gate_channels, + skip_channels=skip_channels, + aux_channels=-1, + dilation=dilation, + dropout=dropout, + bias=bias, + use_causal_conv=False, + ) + self.conv_layers += [conv] + + # define output layers + self.last_conv_layers = nn.ModuleList( + [ + getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params), + nn.Conv1d(skip_channels, skip_channels, kernel_size=1, padding=0, dilation=1, bias=True), + getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params), + nn.Conv1d(skip_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=True), + ] + ) + + # apply weight norm + self.apply_weight_norm() + + def forward(self, x): + """ + x: (B, 1, T). 
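+        Returns:
+            Tensor: (B, out_channels, T)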
+ """ + x = self.first_conv(x) + + skips = 0 + for f in self.conv_layers: + x, h = f(x, None) + skips += h + skips *= self.res_factor + + # apply final layers + x = skips + for f in self.last_conv_layers: + x = f(x) + return x + + def apply_weight_norm(self): + def _apply_weight_norm(m): + if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) + + def remove_weight_norm(self): + def _remove_weight_norm(m): + try: + print(f"Weight norm is removed from {m}.") + nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) diff --git a/TTS/vocoder/models/parallel_wavegan_generator.py b/TTS/vocoder/models/parallel_wavegan_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..ee9d8ad5c2b14902763ca39654e09ad4dceff060 --- /dev/null +++ b/TTS/vocoder/models/parallel_wavegan_generator.py @@ -0,0 +1,164 @@ +import math + +import numpy as np +import torch + +from TTS.utils.io import load_fsspec +from TTS.vocoder.layers.parallel_wavegan import ResidualBlock +from TTS.vocoder.layers.upsample import ConvUpsample + + +class ParallelWaveganGenerator(torch.nn.Module): + """PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf. + It is similar to WaveNet with no causal convolution. + It is conditioned on an aux feature (spectrogram) to generate + an output waveform from an input noise. + """ + + # pylint: disable=dangerous-default-value + def __init__( + self, + in_channels=1, + out_channels=1, + kernel_size=3, + num_res_blocks=30, + stacks=3, + res_channels=64, + gate_channels=128, + skip_channels=64, + aux_channels=80, + dropout=0.0, + bias=True, + use_weight_norm=True, + upsample_factors=[4, 4, 4, 4], + inference_padding=2, + ): + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.aux_channels = aux_channels + self.num_res_blocks = num_res_blocks + self.stacks = stacks + self.kernel_size = kernel_size + self.upsample_factors = upsample_factors + self.upsample_scale = np.prod(upsample_factors) + self.inference_padding = inference_padding + self.use_weight_norm = use_weight_norm + + # check the number of layers and stacks + assert num_res_blocks % stacks == 0 + layers_per_stack = num_res_blocks // stacks + + # define first convolution + self.first_conv = torch.nn.Conv1d(in_channels, res_channels, kernel_size=1, bias=True) + + # define conv + upsampling network + self.upsample_net = ConvUpsample(upsample_factors=upsample_factors) + + # define residual blocks + self.conv_layers = torch.nn.ModuleList() + for layer in range(num_res_blocks): + dilation = 2 ** (layer % layers_per_stack) + conv = ResidualBlock( + kernel_size=kernel_size, + res_channels=res_channels, + gate_channels=gate_channels, + skip_channels=skip_channels, + aux_channels=aux_channels, + dilation=dilation, + dropout=dropout, + bias=bias, + ) + self.conv_layers += [conv] + + # define output layers + self.last_conv_layers = torch.nn.ModuleList( + [ + torch.nn.ReLU(inplace=True), + torch.nn.Conv1d(skip_channels, skip_channels, kernel_size=1, bias=True), + torch.nn.ReLU(inplace=True), + torch.nn.Conv1d(skip_channels, out_channels, kernel_size=1, bias=True), + ] + ) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + def forward(self, c): + """ + c: (B, C ,T'). 
+ o: Output tensor (B, out_channels, T) + """ + # random noise + x = torch.randn([c.shape[0], 1, c.shape[2] * self.upsample_scale]) + x = x.to(self.first_conv.bias.device) + + # perform upsampling + if c is not None and self.upsample_net is not None: + c = self.upsample_net(c) + assert ( + c.shape[-1] == x.shape[-1] + ), f" [!] Upsampling scale does not match the expected output. {c.shape} vs {x.shape}" + + # encode to hidden representation + x = self.first_conv(x) + skips = 0 + for f in self.conv_layers: + x, h = f(x, c) + skips += h + skips *= math.sqrt(1.0 / len(self.conv_layers)) + + # apply final layers + x = skips + for f in self.last_conv_layers: + x = f(x) + + return x + + @torch.no_grad() + def inference(self, c): + c = c.to(self.first_conv.weight.device) + c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate") + return self.forward(c) + + def remove_weight_norm(self): + def _remove_weight_norm(m): + try: + # print(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + def _apply_weight_norm(m): + if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): + torch.nn.utils.weight_norm(m) + # print(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + @staticmethod + def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x): + assert layers % stacks == 0 + layers_per_cycle = layers // stacks + dilations = [dilation(i % layers_per_cycle) for i in range(layers)] + return (kernel_size - 1) * sum(dilations) + 1 + + @property + def receptive_field_size(self): + return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size) + + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + if self.use_weight_norm: + self.remove_weight_norm() diff --git a/TTS/vocoder/models/random_window_discriminator.py b/TTS/vocoder/models/random_window_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..ea95668a5fb6408488f0243c2e4e7f95ee4c6a6f --- /dev/null +++ b/TTS/vocoder/models/random_window_discriminator.py @@ -0,0 +1,204 @@ +import numpy as np +from torch import nn + + +class GBlock(nn.Module): + def __init__(self, in_channels, cond_channels, downsample_factor): + super().__init__() + + self.in_channels = in_channels + self.cond_channels = cond_channels + self.downsample_factor = downsample_factor + + self.start = nn.Sequential( + nn.AvgPool1d(downsample_factor, stride=downsample_factor), + nn.ReLU(), + nn.Conv1d(in_channels, in_channels * 2, kernel_size=3, padding=1), + ) + self.lc_conv1d = nn.Conv1d(cond_channels, in_channels * 2, kernel_size=1) + self.end = nn.Sequential( + nn.ReLU(), nn.Conv1d(in_channels * 2, in_channels * 2, kernel_size=3, dilation=2, padding=2) + ) + self.residual = nn.Sequential( + nn.Conv1d(in_channels, in_channels * 2, kernel_size=1), + nn.AvgPool1d(downsample_factor, stride=downsample_factor), + ) + + def forward(self, inputs, conditions): + outputs = self.start(inputs) + self.lc_conv1d(conditions) + outputs = self.end(outputs) + residual_outputs = self.residual(inputs) + outputs = outputs + residual_outputs + + return outputs + + +class DBlock(nn.Module): + 
def __init__(self, in_channels, out_channels, downsample_factor): + super().__init__() + + self.in_channels = in_channels + self.downsample_factor = downsample_factor + self.out_channels = out_channels + + self.donwsample_layer = nn.AvgPool1d(downsample_factor, stride=downsample_factor) + self.layers = nn.Sequential( + nn.ReLU(), + nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1), + nn.ReLU(), + nn.Conv1d(out_channels, out_channels, kernel_size=3, dilation=2, padding=2), + ) + self.residual = nn.Sequential( + nn.Conv1d(in_channels, out_channels, kernel_size=1), + ) + + def forward(self, inputs): + if self.downsample_factor > 1: + outputs = self.layers(self.donwsample_layer(inputs)) + self.donwsample_layer(self.residual(inputs)) + else: + outputs = self.layers(inputs) + self.residual(inputs) + return outputs + + +class ConditionalDiscriminator(nn.Module): + def __init__(self, in_channels, cond_channels, downsample_factors=(2, 2, 2), out_channels=(128, 256)): + super().__init__() + + assert len(downsample_factors) == len(out_channels) + 1 + + self.in_channels = in_channels + self.cond_channels = cond_channels + self.downsample_factors = downsample_factors + self.out_channels = out_channels + + self.pre_cond_layers = nn.ModuleList() + self.post_cond_layers = nn.ModuleList() + + # layers before condition features + self.pre_cond_layers += [DBlock(in_channels, 64, 1)] + in_channels = 64 + for (i, channel) in enumerate(out_channels): + self.pre_cond_layers.append(DBlock(in_channels, channel, downsample_factors[i])) + in_channels = channel + + # condition block + self.cond_block = GBlock(in_channels, cond_channels, downsample_factors[-1]) + + # layers after condition block + self.post_cond_layers += [ + DBlock(in_channels * 2, in_channels * 2, 1), + DBlock(in_channels * 2, in_channels * 2, 1), + nn.AdaptiveAvgPool1d(1), + nn.Conv1d(in_channels * 2, 1, kernel_size=1), + ] + + def forward(self, inputs, conditions): + batch_size = inputs.size()[0] + outputs = inputs.view(batch_size, self.in_channels, -1) + for layer in self.pre_cond_layers: + outputs = layer(outputs) + outputs = self.cond_block(outputs, conditions) + for layer in self.post_cond_layers: + outputs = layer(outputs) + + return outputs + + +class UnconditionalDiscriminator(nn.Module): + def __init__(self, in_channels, base_channels=64, downsample_factors=(8, 4), out_channels=(128, 256)): + super().__init__() + + self.downsample_factors = downsample_factors + self.in_channels = in_channels + self.downsample_factors = downsample_factors + self.out_channels = out_channels + + self.layers = nn.ModuleList() + self.layers += [DBlock(self.in_channels, base_channels, 1)] + in_channels = base_channels + for (i, factor) in enumerate(downsample_factors): + self.layers.append(DBlock(in_channels, out_channels[i], factor)) + in_channels *= 2 + self.layers += [ + DBlock(in_channels, in_channels, 1), + DBlock(in_channels, in_channels, 1), + nn.AdaptiveAvgPool1d(1), + nn.Conv1d(in_channels, 1, kernel_size=1), + ] + + def forward(self, inputs): + batch_size = inputs.size()[0] + outputs = inputs.view(batch_size, self.in_channels, -1) + for layer in self.layers: + outputs = layer(outputs) + return outputs + + +class RandomWindowDiscriminator(nn.Module): + """Random Window Discriminator as described in + http://arxiv.org/abs/1909.11646""" + + def __init__( + self, + cond_channels, + hop_length, + uncond_disc_donwsample_factors=(8, 4), + cond_disc_downsample_factors=((8, 4, 2, 2, 2), (8, 4, 2, 2), (8, 4, 2), (8, 4), (4, 2, 2)), + 
cond_disc_out_channels=((128, 128, 256, 256), (128, 256, 256), (128, 256), (256,), (128, 256)), + window_sizes=(512, 1024, 2048, 4096, 8192), + ): + + super().__init__() + self.cond_channels = cond_channels + self.window_sizes = window_sizes + self.hop_length = hop_length + self.base_window_size = self.hop_length * 2 + self.ks = [ws // self.base_window_size for ws in window_sizes] + + # check arguments + assert len(cond_disc_downsample_factors) == len(cond_disc_out_channels) == len(window_sizes) + for ws in window_sizes: + assert ws % hop_length == 0 + + for idx, cf in enumerate(cond_disc_downsample_factors): + assert np.prod(cf) == hop_length // self.ks[idx] + + # define layers + self.unconditional_discriminators = nn.ModuleList([]) + for k in self.ks: + layer = UnconditionalDiscriminator( + in_channels=k, base_channels=64, downsample_factors=uncond_disc_donwsample_factors + ) + self.unconditional_discriminators.append(layer) + + self.conditional_discriminators = nn.ModuleList([]) + for idx, k in enumerate(self.ks): + layer = ConditionalDiscriminator( + in_channels=k, + cond_channels=cond_channels, + downsample_factors=cond_disc_downsample_factors[idx], + out_channels=cond_disc_out_channels[idx], + ) + self.conditional_discriminators.append(layer) + + def forward(self, x, c): + scores = [] + feats = [] + # unconditional pass + for (window_size, layer) in zip(self.window_sizes, self.unconditional_discriminators): + index = np.random.randint(x.shape[-1] - window_size) + + score = layer(x[:, :, index : index + window_size]) + scores.append(score) + + # conditional pass + for (window_size, layer) in zip(self.window_sizes, self.conditional_discriminators): + frame_size = window_size // self.hop_length + lc_index = np.random.randint(c.shape[-1] - frame_size) + sample_index = lc_index * self.hop_length + x_sub = x[:, :, sample_index : (lc_index + frame_size) * self.hop_length] + c_sub = c[:, :, lc_index : lc_index + frame_size] + + score = layer(x_sub, c_sub) + scores.append(score) + return scores, feats diff --git a/TTS/vocoder/models/univnet_discriminator.py b/TTS/vocoder/models/univnet_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..d6b0e5d52c79873623988a8324c94175e5812d5d --- /dev/null +++ b/TTS/vocoder/models/univnet_discriminator.py @@ -0,0 +1,96 @@ +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.utils import spectral_norm, weight_norm + +from TTS.utils.audio import TorchSTFT +from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator + +LRELU_SLOPE = 0.1 + + +class SpecDiscriminator(nn.Module): + """docstring for Discriminator.""" + + def __init__(self, fft_size=1024, hop_length=120, win_length=600, use_spectral_norm=False): + super().__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.fft_size = fft_size + self.hop_length = hop_length + self.win_length = win_length + self.stft = TorchSTFT(fft_size, hop_length, win_length) + self.discriminators = nn.ModuleList( + [ + norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))), + ] + ) + + self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1)) + + def forward(self, y): + + fmap = [] + with torch.no_grad(): 
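+            # convert the waveform to a spectrogram with the fixed STFT front-end;
+            # no gradients are tracked through this step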
+ y = y.squeeze(1) + y = self.stft(y) + y = y.unsqueeze(1) + for _, d in enumerate(self.discriminators): + y = d(y) + y = F.leaky_relu(y, LRELU_SLOPE) + fmap.append(y) + + y = self.out(y) + fmap.append(y) + + return torch.flatten(y, 1, -1), fmap + + +class MultiResSpecDiscriminator(torch.nn.Module): + def __init__( # pylint: disable=dangerous-default-value + self, fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240], window="hann_window" + ): + + super().__init__() + self.discriminators = nn.ModuleList( + [ + SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window), + SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window), + SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window), + ] + ) + + def forward(self, x): + scores = [] + feats = [] + for d in self.discriminators: + score, feat = d(x) + scores.append(score) + feats.append(feat) + + return scores, feats + + +class UnivnetDiscriminator(nn.Module): + """Univnet discriminator wrapping MPD and MSD.""" + + def __init__(self): + super().__init__() + self.mpd = MultiPeriodDiscriminator() + self.msd = MultiResSpecDiscriminator() + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + List[Tensor]: discriminator scores. + List[List[Tensor]]: list of list of features from each layers of each discriminator. + """ + scores, feats = self.mpd(x) + scores_, feats_ = self.msd(x) + return scores + scores_, feats + feats_ diff --git a/TTS/vocoder/models/univnet_generator.py b/TTS/vocoder/models/univnet_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..2ee28c7b85852c6b15df28907b6fd1195f3218cd --- /dev/null +++ b/TTS/vocoder/models/univnet_generator.py @@ -0,0 +1,156 @@ +from typing import List + +import numpy as np +import torch +import torch.nn.functional as F + +from TTS.vocoder.layers.lvc_block import LVCBlock + +LRELU_SLOPE = 0.1 + + +class UnivnetGenerator(torch.nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int, + cond_channels: int, + upsample_factors: List[int], + lvc_layers_each_block: int, + lvc_kernel_size: int, + kpnet_hidden_channels: int, + kpnet_conv_size: int, + dropout: float, + use_weight_norm=True, + ): + """Univnet Generator network. + + Paper: https://arxiv.org/pdf/2106.07889.pdf + + Args: + in_channels (int): Number of input tensor channels. + out_channels (int): Number of channels of the output tensor. + hidden_channels (int): Number of hidden network channels. + cond_channels (int): Number of channels of the conditioning tensors. + upsample_factors (List[int]): List of uplsample factors for the upsampling layers. + lvc_layers_each_block (int): Number of LVC layers in each block. + lvc_kernel_size (int): Kernel size of the LVC layers. + kpnet_hidden_channels (int): Number of hidden channels in the key-point network. + kpnet_conv_size (int): Number of convolution channels in the key-point network. + dropout (float): Dropout rate. + use_weight_norm (bool, optional): Enable/disable weight norm. Defaults to True. 
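+
+        Example:
+            A construction sketch with illustrative hyperparameters; these are
+            not necessarily the published UnivNet settings (see the project
+            configs for real values):
+
+            >>> model = UnivnetGenerator(
+            ...     in_channels=64,
+            ...     out_channels=1,
+            ...     hidden_channels=32,
+            ...     cond_channels=80,
+            ...     upsample_factors=[8, 8, 4],
+            ...     lvc_layers_each_block=4,
+            ...     lvc_kernel_size=3,
+            ...     kpnet_hidden_channels=64,
+            ...     kpnet_conv_size=3,
+            ...     dropout=0.0,
+            ... )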
+ """ + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.cond_channels = cond_channels + self.upsample_scale = np.prod(upsample_factors) + self.lvc_block_nums = len(upsample_factors) + + # define first convolution + self.first_conv = torch.nn.Conv1d( + in_channels, hidden_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True + ) + + # define residual blocks + self.lvc_blocks = torch.nn.ModuleList() + cond_hop_length = 1 + for n in range(self.lvc_block_nums): + cond_hop_length = cond_hop_length * upsample_factors[n] + lvcb = LVCBlock( + in_channels=hidden_channels, + cond_channels=cond_channels, + upsample_ratio=upsample_factors[n], + conv_layers=lvc_layers_each_block, + conv_kernel_size=lvc_kernel_size, + cond_hop_length=cond_hop_length, + kpnet_hidden_channels=kpnet_hidden_channels, + kpnet_conv_size=kpnet_conv_size, + kpnet_dropout=dropout, + ) + self.lvc_blocks += [lvcb] + + # define output layers + self.last_conv_layers = torch.nn.ModuleList( + [ + torch.nn.Conv1d( + hidden_channels, out_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True + ), + ] + ) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + def forward(self, c): + """Calculate forward propagation. + Args: + c (Tensor): Local conditioning auxiliary features (B, C ,T'). + Returns: + Tensor: Output tensor (B, out_channels, T) + """ + # random noise + x = torch.randn([c.shape[0], self.in_channels, c.shape[2]]) + x = x.to(self.first_conv.bias.device) + x = self.first_conv(x) + + for n in range(self.lvc_block_nums): + x = self.lvc_blocks[n](x, c) + + # apply final layers + for f in self.last_conv_layers: + x = F.leaky_relu(x, LRELU_SLOPE) + x = f(x) + x = torch.tanh(x) + return x + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + # print(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): + torch.nn.utils.weight_norm(m) + # print(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + @staticmethod + def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x): + assert layers % stacks == 0 + layers_per_cycle = layers // stacks + dilations = [dilation(i % layers_per_cycle) for i in range(layers)] + return (kernel_size - 1) * sum(dilations) + 1 + + @property + def receptive_field_size(self): + """Return receptive field size.""" + return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size) + + @torch.no_grad() + def inference(self, c): + """Perform inference. + Args: + c (Tensor): Local conditioning auxiliary features :math:`(B, C, T)`. 
+ Returns: + Tensor: Output tensor (T, out_channels) + """ + x = torch.randn([c.shape[0], self.in_channels, c.shape[2]]) + x = x.to(self.first_conv.bias.device) + + c = c.to(next(self.parameters())) + return self.forward(c) diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py new file mode 100644 index 0000000000000000000000000000000000000000..c4968f1f1788613e89ddd2b3d38993278139e73f --- /dev/null +++ b/TTS/vocoder/models/wavegrad.py @@ -0,0 +1,344 @@ +from dataclasses import dataclass, field +from typing import Dict, List, Tuple + +import numpy as np +import torch +from coqpit import Coqpit +from torch import nn +from torch.nn.utils import weight_norm +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.utils.io import load_fsspec +from TTS.vocoder.datasets import WaveGradDataset +from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock +from TTS.vocoder.models.base_vocoder import BaseVocoder +from TTS.vocoder.utils.generic_utils import plot_results + + +@dataclass +class WavegradArgs(Coqpit): + in_channels: int = 80 + out_channels: int = 1 + use_weight_norm: bool = False + y_conv_channels: int = 32 + x_conv_channels: int = 768 + dblock_out_channels: List[int] = field(default_factory=lambda: [128, 128, 256, 512]) + ublock_out_channels: List[int] = field(default_factory=lambda: [512, 512, 256, 128, 128]) + upsample_factors: List[int] = field(default_factory=lambda: [4, 4, 4, 2, 2]) + upsample_dilations: List[List[int]] = field( + default_factory=lambda: [[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]] + ) + + +class Wavegrad(BaseVocoder): + """🐸 🌊 WaveGrad 🌊 model. + Paper - https://arxiv.org/abs/2009.00713 + + Examples: + Initializing the model. + + >>> from TTS.vocoder.configs import WavegradConfig + >>> config = WavegradConfig() + >>> model = Wavegrad(config) + + Paper Abstract: + This paper introduces WaveGrad, a conditional model for waveform generation which estimates gradients of the + data density. The model is built on prior work on score matching and diffusion probabilistic models. It starts + from a Gaussian white noise signal and iteratively refines the signal via a gradient-based sampler conditioned + on the mel-spectrogram. WaveGrad offers a natural way to trade inference speed for sample quality by adjusting + the number of refinement steps, and bridges the gap between non-autoregressive and autoregressive models in + terms of audio quality. We find that it can generate high fidelity audio samples using as few as six iterations. + Experiments reveal WaveGrad to generate high fidelity audio, outperforming adversarial non-autoregressive + baselines and matching a strong likelihood-based autoregressive baseline using fewer sequential operations. + Audio samples are available at this https URL. 
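+
+    Example:
+        Continuing the example above, an inference noise schedule can be set up
+        explicitly (the beta range and step count here are illustrative; in
+        practice they come from `test_noise_schedule` in the config):
+
+        >>> import numpy as np
+        >>> betas = np.linspace(1e-6, 1e-2, 50)
+        >>> model.compute_noise_level(betas)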
+ """ + + # pylint: disable=dangerous-default-value + def __init__(self, config: Coqpit): + super().__init__(config) + self.config = config + self.use_weight_norm = config.model_params.use_weight_norm + self.hop_len = np.prod(config.model_params.upsample_factors) + self.noise_level = None + self.num_steps = None + self.beta = None + self.alpha = None + self.alpha_hat = None + self.c1 = None + self.c2 = None + self.sigma = None + + # dblocks + self.y_conv = Conv1d(1, config.model_params.y_conv_channels, 5, padding=2) + self.dblocks = nn.ModuleList([]) + ic = config.model_params.y_conv_channels + for oc, df in zip(config.model_params.dblock_out_channels, reversed(config.model_params.upsample_factors)): + self.dblocks.append(DBlock(ic, oc, df)) + ic = oc + + # film + self.film = nn.ModuleList([]) + ic = config.model_params.y_conv_channels + for oc in reversed(config.model_params.ublock_out_channels): + self.film.append(FiLM(ic, oc)) + ic = oc + + # ublocksn + self.ublocks = nn.ModuleList([]) + ic = config.model_params.x_conv_channels + for oc, uf, ud in zip( + config.model_params.ublock_out_channels, + config.model_params.upsample_factors, + config.model_params.upsample_dilations, + ): + self.ublocks.append(UBlock(ic, oc, uf, ud)) + ic = oc + + self.x_conv = Conv1d(config.model_params.in_channels, config.model_params.x_conv_channels, 3, padding=1) + self.out_conv = Conv1d(oc, config.model_params.out_channels, 3, padding=1) + + if config.model_params.use_weight_norm: + self.apply_weight_norm() + + def forward(self, x, spectrogram, noise_scale): + shift_and_scale = [] + + x = self.y_conv(x) + shift_and_scale.append(self.film[0](x, noise_scale)) + + for film, layer in zip(self.film[1:], self.dblocks): + x = layer(x) + shift_and_scale.append(film(x, noise_scale)) + + x = self.x_conv(spectrogram) + for layer, (film_shift, film_scale) in zip(self.ublocks, reversed(shift_and_scale)): + x = layer(x, film_shift, film_scale) + x = self.out_conv(x) + return x + + def load_noise_schedule(self, path): + beta = np.load(path, allow_pickle=True).item()["beta"] # pylint: disable=unexpected-keyword-arg + self.compute_noise_level(beta) + + @torch.no_grad() + def inference(self, x, y_n=None): + """ + Shapes: + x: :math:`[B, C , T]` + y_n: :math:`[B, 1, T]` + """ + if y_n is None: + y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1]) + else: + y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0) + y_n = y_n.type_as(x) + sqrt_alpha_hat = self.noise_level.to(x) + for n in range(len(self.alpha) - 1, -1, -1): + y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0]))) + if n > 0: + z = torch.randn_like(y_n) + y_n += self.sigma[n - 1] * z + y_n.clamp_(-1.0, 1.0) + return y_n + + def compute_y_n(self, y_0): + """Compute noisy audio based on noise schedule""" + self.noise_level = self.noise_level.to(y_0) + if len(y_0.shape) == 3: + y_0 = y_0.squeeze(1) + s = torch.randint(0, self.num_steps - 1, [y_0.shape[0]]) + l_a, l_b = self.noise_level[s], self.noise_level[s + 1] + noise_scale = l_a + torch.rand(y_0.shape[0]).to(y_0) * (l_b - l_a) + noise_scale = noise_scale.unsqueeze(1) + noise = torch.randn_like(y_0) + noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2) ** 0.5 * noise + return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0] + + def compute_noise_level(self, beta): + """Compute noise schedule parameters""" + self.num_steps = len(beta) + alpha = 1 - beta + alpha_hat = np.cumprod(alpha) + noise_level = np.concatenate([[1.0], alpha_hat**0.5], 
axis=0) + noise_level = alpha_hat**0.5 + + # pylint: disable=not-callable + self.beta = torch.tensor(beta.astype(np.float32)) + self.alpha = torch.tensor(alpha.astype(np.float32)) + self.alpha_hat = torch.tensor(alpha_hat.astype(np.float32)) + self.noise_level = torch.tensor(noise_level.astype(np.float32)) + + self.c1 = 1 / self.alpha**0.5 + self.c2 = (1 - self.alpha) / (1 - self.alpha_hat) ** 0.5 + self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:]) ** 0.5 + + def remove_weight_norm(self): + for _, layer in enumerate(self.dblocks): + if len(layer.state_dict()) != 0: + try: + nn.utils.remove_weight_norm(layer) + except ValueError: + layer.remove_weight_norm() + + for _, layer in enumerate(self.film): + if len(layer.state_dict()) != 0: + try: + nn.utils.remove_weight_norm(layer) + except ValueError: + layer.remove_weight_norm() + + for _, layer in enumerate(self.ublocks): + if len(layer.state_dict()) != 0: + try: + nn.utils.remove_weight_norm(layer) + except ValueError: + layer.remove_weight_norm() + + nn.utils.remove_weight_norm(self.x_conv) + nn.utils.remove_weight_norm(self.out_conv) + nn.utils.remove_weight_norm(self.y_conv) + + def apply_weight_norm(self): + for _, layer in enumerate(self.dblocks): + if len(layer.state_dict()) != 0: + layer.apply_weight_norm() + + for _, layer in enumerate(self.film): + if len(layer.state_dict()) != 0: + layer.apply_weight_norm() + + for _, layer in enumerate(self.ublocks): + if len(layer.state_dict()) != 0: + layer.apply_weight_norm() + + self.x_conv = weight_norm(self.x_conv) + self.out_conv = weight_norm(self.out_conv) + self.y_conv = weight_norm(self.y_conv) + + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + if self.config.model_params.use_weight_norm: + self.remove_weight_norm() + betas = np.linspace( + config["test_noise_schedule"]["min_val"], + config["test_noise_schedule"]["max_val"], + config["test_noise_schedule"]["num_steps"], + ) + self.compute_noise_level(betas) + else: + betas = np.linspace( + config["train_noise_schedule"]["min_val"], + config["train_noise_schedule"]["max_val"], + config["train_noise_schedule"]["num_steps"], + ) + self.compute_noise_level(betas) + + def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]: + # format data + x = batch["input"] + y = batch["waveform"] + + # set noise scale + noise, x_noisy, noise_scale = self.compute_y_n(y) + + # forward pass + noise_hat = self.forward(x_noisy, x, noise_scale) + + # compute losses + loss = criterion(noise, noise_hat) + return {"model_output": noise_hat}, {"loss": loss} + + def train_log( # pylint: disable=no-self-use + self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument + ) -> Tuple[Dict, np.ndarray]: + pass + + @torch.no_grad() + def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]: + return self.train_step(batch, criterion) + + def eval_log( # pylint: disable=no-self-use + self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument + ) -> None: + pass + + def test(self, assets: Dict, test_loader: "DataLoader", outputs=None): # pylint: disable=unused-argument + # setup noise schedule and inference + ap = assets["audio_processor"] + noise_schedule = 
self.config["test_noise_schedule"] + betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"]) + self.compute_noise_level(betas) + samples = test_loader.dataset.load_test_samples(1) + for sample in samples: + x = sample[0] + x = x[None, :, :].to(next(self.parameters()).device) + y = sample[1] + y = y[None, :] + # compute voice + y_pred = self.inference(x) + # compute spectrograms + figures = plot_results(y_pred, y, ap, "test") + # Sample audio + sample_voice = y_pred[0].squeeze(0).detach().cpu().numpy() + return figures, {"test/audio": sample_voice} + + def get_optimizer(self): + return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self) + + def get_scheduler(self, optimizer): + return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer) + + @staticmethod + def get_criterion(): + return torch.nn.L1Loss() + + @staticmethod + def format_batch(batch: Dict) -> Dict: + # return a whole audio segment + m, y = batch[0], batch[1] + y = y.unsqueeze(1) + return {"input": m, "waveform": y} + + def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int): + ap = assets["audio_processor"] + dataset = WaveGradDataset( + ap=ap, + items=samples, + seq_len=self.config.seq_len, + hop_len=ap.hop_length, + pad_short=self.config.pad_short, + conv_pad=self.config.conv_pad, + is_training=not is_eval, + return_segments=True, + use_noise_augment=False, + use_cache=config.use_cache, + verbose=verbose, + ) + sampler = DistributedSampler(dataset) if num_gpus > 1 else None + loader = DataLoader( + dataset, + batch_size=self.config.batch_size, + shuffle=num_gpus <= 1, + drop_last=False, + sampler=sampler, + num_workers=self.config.num_eval_loader_workers if is_eval else self.config.num_loader_workers, + pin_memory=False, + ) + return loader + + def on_epoch_start(self, trainer): # pylint: disable=unused-argument + noise_schedule = self.config["train_noise_schedule"] + betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"]) + self.compute_noise_level(betas) + + @staticmethod + def init_from_config(config: "WavegradConfig"): + return Wavegrad(config) diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py new file mode 100644 index 0000000000000000000000000000000000000000..6686db45dd32b6f4f3bd54e702787412fc344a6b --- /dev/null +++ b/TTS/vocoder/models/wavernn.py @@ -0,0 +1,638 @@ +import sys +import time +from dataclasses import dataclass, field +from typing import Dict, List, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from coqpit import Coqpit +from torch import nn +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from TTS.tts.utils.visual import plot_spectrogram +from TTS.utils.audio import AudioProcessor +from TTS.utils.io import load_fsspec +from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset +from TTS.vocoder.layers.losses import WaveRNNLoss +from TTS.vocoder.models.base_vocoder import BaseVocoder +from TTS.vocoder.utils.distribution import sample_from_discretized_mix_logistic, sample_from_gaussian + + +def stream(string, variables): + sys.stdout.write(f"\r{string}" % variables) + + +# pylint: disable=abstract-method +# relates https://github.com/pytorch/pytorch/issues/42305 +class ResBlock(nn.Module): + def __init__(self, dims): + super().__init__() + self.conv1 = nn.Conv1d(dims, dims, 
kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False) + self.batch_norm1 = nn.BatchNorm1d(dims) + self.batch_norm2 = nn.BatchNorm1d(dims) + + def forward(self, x): + residual = x + x = self.conv1(x) + x = self.batch_norm1(x) + x = F.relu(x) + x = self.conv2(x) + x = self.batch_norm2(x) + return x + residual + + +class MelResNet(nn.Module): + def __init__(self, num_res_blocks, in_dims, compute_dims, res_out_dims, pad): + super().__init__() + k_size = pad * 2 + 1 + self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False) + self.batch_norm = nn.BatchNorm1d(compute_dims) + self.layers = nn.ModuleList() + for _ in range(num_res_blocks): + self.layers.append(ResBlock(compute_dims)) + self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1) + + def forward(self, x): + x = self.conv_in(x) + x = self.batch_norm(x) + x = F.relu(x) + for f in self.layers: + x = f(x) + x = self.conv_out(x) + return x + + +class Stretch2d(nn.Module): + def __init__(self, x_scale, y_scale): + super().__init__() + self.x_scale = x_scale + self.y_scale = y_scale + + def forward(self, x): + b, c, h, w = x.size() + x = x.unsqueeze(-1).unsqueeze(3) + x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale) + return x.view(b, c, h * self.y_scale, w * self.x_scale) + + +class UpsampleNetwork(nn.Module): + def __init__( + self, + feat_dims, + upsample_scales, + compute_dims, + num_res_blocks, + res_out_dims, + pad, + use_aux_net, + ): + super().__init__() + self.total_scale = np.cumproduct(upsample_scales)[-1] + self.indent = pad * self.total_scale + self.use_aux_net = use_aux_net + if use_aux_net: + self.resnet = MelResNet(num_res_blocks, feat_dims, compute_dims, res_out_dims, pad) + self.resnet_stretch = Stretch2d(self.total_scale, 1) + self.up_layers = nn.ModuleList() + for scale in upsample_scales: + k_size = (1, scale * 2 + 1) + padding = (0, scale) + stretch = Stretch2d(scale, 1) + conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False) + conv.weight.data.fill_(1.0 / k_size[1]) + self.up_layers.append(stretch) + self.up_layers.append(conv) + + def forward(self, m): + if self.use_aux_net: + aux = self.resnet(m).unsqueeze(1) + aux = self.resnet_stretch(aux) + aux = aux.squeeze(1) + aux = aux.transpose(1, 2) + else: + aux = None + m = m.unsqueeze(1) + for f in self.up_layers: + m = f(m) + m = m.squeeze(1)[:, :, self.indent : -self.indent] + return m.transpose(1, 2), aux + + +class Upsample(nn.Module): + def __init__(self, scale, pad, num_res_blocks, feat_dims, compute_dims, res_out_dims, use_aux_net): + super().__init__() + self.scale = scale + self.pad = pad + self.indent = pad * scale + self.use_aux_net = use_aux_net + self.resnet = MelResNet(num_res_blocks, feat_dims, compute_dims, res_out_dims, pad) + + def forward(self, m): + if self.use_aux_net: + aux = self.resnet(m) + aux = torch.nn.functional.interpolate(aux, scale_factor=self.scale, mode="linear", align_corners=True) + aux = aux.transpose(1, 2) + else: + aux = None + m = torch.nn.functional.interpolate(m, scale_factor=self.scale, mode="linear", align_corners=True) + m = m[:, :, self.indent : -self.indent] + m = m * 0.045 # empirically found + + return m.transpose(1, 2), aux + + +@dataclass +class WavernnArgs(Coqpit): + """🐸 WaveRNN model arguments. + + rnn_dims (int): + Number of hidden channels in RNN layers. Defaults to 512. + fc_dims (int): + Number of hidden channels in fully-conntected layers. Defaults to 512. 
+ compute_dims (int): + Number of hidden channels in the feature ResNet. Defaults to 128. + res_out_dim (int): + Number of hidden channels in the feature ResNet output. Defaults to 128. + num_res_blocks (int): + Number of residual blocks in the ResNet. Defaults to 10. + use_aux_net (bool): + enable/disable the feature ResNet. Defaults to True. + use_upsample_net (bool): + enable/ disable the upsampling networl. If False, basic upsampling is used. Defaults to True. + upsample_factors (list): + Upsampling factors. The multiply of the values must match the `hop_length`. Defaults to ```[4, 8, 8]```. + mode (str): + Output mode of the WaveRNN vocoder. `mold` for Mixture of Logistic Distribution, `gauss` for a single + Gaussian Distribution and `bits` for quantized bits as the model's output. + mulaw (bool): + enable / disable the use of Mulaw quantization for training. Only applicable if `mode == 'bits'`. Defaults + to `True`. + pad (int): + Padding applied to the input feature frames against the convolution layers of the feature network. + Defaults to 2. + """ + + rnn_dims: int = 512 + fc_dims: int = 512 + compute_dims: int = 128 + res_out_dims: int = 128 + num_res_blocks: int = 10 + use_aux_net: bool = True + use_upsample_net: bool = True + upsample_factors: List[int] = field(default_factory=lambda: [4, 8, 8]) + mode: str = "mold" # mold [string], gauss [string], bits [int] + mulaw: bool = True # apply mulaw if mode is bits + pad: int = 2 + feat_dims: int = 80 + + +class Wavernn(BaseVocoder): + def __init__(self, config: Coqpit): + """🐸 WaveRNN model. + Original paper - https://arxiv.org/abs/1802.08435 + Official implementation - https://github.com/fatchord/WaveRNN + + Args: + config (Coqpit): [description] + + Raises: + RuntimeError: [description] + + Examples: + >>> from TTS.vocoder.configs import WavernnConfig + >>> config = WavernnConfig() + >>> model = Wavernn(config) + + Paper Abstract: + Sequential models achieve state-of-the-art results in audio, visual and textual domains with respect to + both estimating the data distribution and generating high-quality samples. Efficient sampling for this + class of models has however remained an elusive problem. With a focus on text-to-speech synthesis, we + describe a set of general techniques for reducing sampling time while maintaining high output quality. + We first describe a single-layer recurrent neural network, the WaveRNN, with a dual softmax layer that + matches the quality of the state-of-the-art WaveNet model. The compact form of the network makes it + possible to generate 24kHz 16-bit audio 4x faster than real time on a GPU. Second, we apply a weight + pruning technique to reduce the number of weights in the WaveRNN. We find that, for a constant number of + parameters, large sparse networks perform better than small dense networks and this relationship holds for + sparsity levels beyond 96%. The small number of weights in a Sparse WaveRNN makes it possible to sample + high-fidelity audio on a mobile CPU in real time. Finally, we propose a new generation scheme based on + subscaling that folds a long sequence into a batch of shorter sequences and allows one to generate multiple + samples at once. The Subscale WaveRNN produces 16 samples per step without loss of quality and offers an + orthogonal method for increasing sampling efficiency. 
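+
+        Example:
+            Continuing the example above, a minimal CPU inference sketch (the
+            conditioning mel spectrogram below is random noise and its shape is
+            illustrative):
+
+            >>> import torch
+            >>> mel = torch.randn(80, 50)  # [C, T] mel frames
+            >>> wav = model.inference(mel, batched=False)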
+ """ + super().__init__(config) + + if isinstance(self.args.mode, int): + self.n_classes = 2**self.args.mode + elif self.args.mode == "mold": + self.n_classes = 3 * 10 + elif self.args.mode == "gauss": + self.n_classes = 2 + else: + raise RuntimeError("Unknown model mode value - ", self.args.mode) + + self.aux_dims = self.args.res_out_dims // 4 + + if self.args.use_upsample_net: + assert ( + np.cumproduct(self.args.upsample_factors)[-1] == config.audio.hop_length + ), " [!] upsample scales needs to be equal to hop_length" + self.upsample = UpsampleNetwork( + self.args.feat_dims, + self.args.upsample_factors, + self.args.compute_dims, + self.args.num_res_blocks, + self.args.res_out_dims, + self.args.pad, + self.args.use_aux_net, + ) + else: + self.upsample = Upsample( + config.audio.hop_length, + self.args.pad, + self.args.num_res_blocks, + self.args.feat_dims, + self.args.compute_dims, + self.args.res_out_dims, + self.args.use_aux_net, + ) + if self.args.use_aux_net: + self.I = nn.Linear(self.args.feat_dims + self.aux_dims + 1, self.args.rnn_dims) + self.rnn1 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True) + self.rnn2 = nn.GRU(self.args.rnn_dims + self.aux_dims, self.args.rnn_dims, batch_first=True) + self.fc1 = nn.Linear(self.args.rnn_dims + self.aux_dims, self.args.fc_dims) + self.fc2 = nn.Linear(self.args.fc_dims + self.aux_dims, self.args.fc_dims) + self.fc3 = nn.Linear(self.args.fc_dims, self.n_classes) + else: + self.I = nn.Linear(self.args.feat_dims + 1, self.args.rnn_dims) + self.rnn1 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True) + self.rnn2 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True) + self.fc1 = nn.Linear(self.args.rnn_dims, self.args.fc_dims) + self.fc2 = nn.Linear(self.args.fc_dims, self.args.fc_dims) + self.fc3 = nn.Linear(self.args.fc_dims, self.n_classes) + + def forward(self, x, mels): + bsize = x.size(0) + h1 = torch.zeros(1, bsize, self.args.rnn_dims).to(x.device) + h2 = torch.zeros(1, bsize, self.args.rnn_dims).to(x.device) + mels, aux = self.upsample(mels) + + if self.args.use_aux_net: + aux_idx = [self.aux_dims * i for i in range(5)] + a1 = aux[:, :, aux_idx[0] : aux_idx[1]] + a2 = aux[:, :, aux_idx[1] : aux_idx[2]] + a3 = aux[:, :, aux_idx[2] : aux_idx[3]] + a4 = aux[:, :, aux_idx[3] : aux_idx[4]] + + x = ( + torch.cat([x.unsqueeze(-1), mels, a1], dim=2) + if self.args.use_aux_net + else torch.cat([x.unsqueeze(-1), mels], dim=2) + ) + x = self.I(x) + res = x + self.rnn1.flatten_parameters() + x, _ = self.rnn1(x, h1) + + x = x + res + res = x + x = torch.cat([x, a2], dim=2) if self.args.use_aux_net else x + self.rnn2.flatten_parameters() + x, _ = self.rnn2(x, h2) + + x = x + res + x = torch.cat([x, a3], dim=2) if self.args.use_aux_net else x + x = F.relu(self.fc1(x)) + + x = torch.cat([x, a4], dim=2) if self.args.use_aux_net else x + x = F.relu(self.fc2(x)) + return self.fc3(x) + + def inference(self, mels, batched=None, target=None, overlap=None): + + self.eval() + output = [] + start = time.time() + rnn1 = self.get_gru_cell(self.rnn1) + rnn2 = self.get_gru_cell(self.rnn2) + + with torch.no_grad(): + if isinstance(mels, np.ndarray): + mels = torch.FloatTensor(mels).to(str(next(self.parameters()).device)) + + if mels.ndim == 2: + mels = mels.unsqueeze(0) + wave_len = (mels.size(-1) - 1) * self.config.audio.hop_length + + mels = self.pad_tensor(mels.transpose(1, 2), pad=self.args.pad, side="both") + mels, aux = self.upsample(mels.transpose(1, 2)) + + if batched: + mels = self.fold_with_overlap(mels, target, 
overlap) + if aux is not None: + aux = self.fold_with_overlap(aux, target, overlap) + + b_size, seq_len, _ = mels.size() + + h1 = torch.zeros(b_size, self.args.rnn_dims).type_as(mels) + h2 = torch.zeros(b_size, self.args.rnn_dims).type_as(mels) + x = torch.zeros(b_size, 1).type_as(mels) + + if self.args.use_aux_net: + d = self.aux_dims + aux_split = [aux[:, :, d * i : d * (i + 1)] for i in range(4)] + + for i in range(seq_len): + + m_t = mels[:, i, :] + + if self.args.use_aux_net: + a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split) + + x = torch.cat([x, m_t, a1_t], dim=1) if self.args.use_aux_net else torch.cat([x, m_t], dim=1) + x = self.I(x) + h1 = rnn1(x, h1) + + x = x + h1 + inp = torch.cat([x, a2_t], dim=1) if self.args.use_aux_net else x + h2 = rnn2(inp, h2) + + x = x + h2 + x = torch.cat([x, a3_t], dim=1) if self.args.use_aux_net else x + x = F.relu(self.fc1(x)) + + x = torch.cat([x, a4_t], dim=1) if self.args.use_aux_net else x + x = F.relu(self.fc2(x)) + + logits = self.fc3(x) + + if self.args.mode == "mold": + sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2)) + output.append(sample.view(-1)) + x = sample.transpose(0, 1).type_as(mels) + elif self.args.mode == "gauss": + sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2)) + output.append(sample.view(-1)) + x = sample.transpose(0, 1).type_as(mels) + elif isinstance(self.args.mode, int): + posterior = F.softmax(logits, dim=1) + distrib = torch.distributions.Categorical(posterior) + + sample = 2 * distrib.sample().float() / (self.n_classes - 1.0) - 1.0 + output.append(sample) + x = sample.unsqueeze(-1) + else: + raise RuntimeError("Unknown model mode value - ", self.args.mode) + + if i % 100 == 0: + self.gen_display(i, seq_len, b_size, start) + + output = torch.stack(output).transpose(0, 1) + output = output.cpu() + if batched: + output = output.numpy() + output = output.astype(np.float64) + + output = self.xfade_and_unfold(output, target, overlap) + else: + output = output[0] + + if self.args.mulaw and isinstance(self.args.mode, int): + output = AudioProcessor.mulaw_decode(output, self.args.mode) + + # Fade-out at the end to avoid signal cutting out suddenly + fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length) + output = output[:wave_len] + + if wave_len > len(fade_out): + output[-20 * self.config.audio.hop_length :] *= fade_out + + self.train() + return output + + def gen_display(self, i, seq_len, b_size, start): + gen_rate = (i + 1) / (time.time() - start) * b_size / 1000 + realtime_ratio = gen_rate * 1000 / self.config.audio.sample_rate + stream( + "%i/%i -- batch_size: %i -- gen_rate: %.1f kHz -- x_realtime: %.1f ", + (i * b_size, seq_len * b_size, b_size, gen_rate, realtime_ratio), + ) + + def fold_with_overlap(self, x, target, overlap): + """Fold the tensor with overlap for quick batched inference. + Overlap will be used for crossfading in xfade_and_unfold() + Args: + x (tensor) : Upsampled conditioning features. + shape=(1, timesteps, features) + target (int) : Target timesteps for each index of batch + overlap (int) : Timesteps for both xfade and rnn warmup + Return: + (tensor) : shape=(num_folds, target + 2 * overlap, features) + Details: + x = [[h1, h2, ... 
hn]] + Where each h is a vector of conditioning features + E.g. target=2, overlap=1 with x.size(1)=10 + folded = [[h1, h2, h3, h4], + [h4, h5, h6, h7], + [h7, h8, h9, h10]] + """ + + _, total_len, features = x.size() + + # Calculate variables needed + num_folds = (total_len - overlap) // (target + overlap) + extended_len = num_folds * (overlap + target) + overlap + remaining = total_len - extended_len + + # Pad if some time steps are poking out + if remaining != 0: + num_folds += 1 + padding = target + 2 * overlap - remaining + x = self.pad_tensor(x, padding, side="after") + + folded = torch.zeros(num_folds, target + 2 * overlap, features).to(x.device) + + # Get the values for the folded tensor + for i in range(num_folds): + start = i * (target + overlap) + end = start + target + 2 * overlap + folded[i] = x[:, start:end, :] + + return folded + + @staticmethod + def get_gru_cell(gru): + # copy the trained GRU weights into an equivalent GRUCell for step-wise inference + gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size) + gru_cell.weight_hh.data = gru.weight_hh_l0.data + gru_cell.weight_ih.data = gru.weight_ih_l0.data + gru_cell.bias_hh.data = gru.bias_hh_l0.data + gru_cell.bias_ih.data = gru.bias_ih_l0.data + return gru_cell + + @staticmethod + def pad_tensor(x, pad, side="both"): + # NB - this is just a quick padding helper for this model + # i.e., it won't generalise to other shapes/dims + b, t, c = x.size() + total = t + 2 * pad if side == "both" else t + pad + padded = torch.zeros(b, total, c).to(x.device) + if side in ("before", "both"): + padded[:, pad : pad + t, :] = x + elif side == "after": + padded[:, :t, :] = x + return padded + + @staticmethod + def xfade_and_unfold(y, target, overlap): + """Applies a crossfade and unfolds into a 1d array. + Args: + y (ndarray) : Batched sequences of audio samples + shape=(num_folds, target + 2 * overlap) + dtype=np.float64 + overlap (int) : Timesteps for both xfade and rnn warmup + Return: + (ndarray) : audio samples in a 1d array + shape=(total_len) + dtype=np.float64 + Details: + y = [[seq1], + [seq2], + [seq3]] + Apply a gain envelope at both ends of the sequences + y = [[seq1_in, seq1_target, seq1_out], + [seq2_in, seq2_target, seq2_out], + [seq3_in, seq3_target, seq3_out]] + Stagger and add up the groups of samples: + [seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...] 
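+ Note: the gain envelopes form an equal-power crossfade (fade_in**2 + fade_out**2 == 1), so the summed overlap regions keep constant energy.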
+ """ + + num_folds, length = y.shape + target = length - 2 * overlap + total_len = num_folds * (target + overlap) + overlap + + # Need some silence for the rnn warmup + silence_len = overlap // 2 + fade_len = overlap - silence_len + silence = np.zeros((silence_len), dtype=np.float64) + + # Equal power crossfade + t = np.linspace(-1, 1, fade_len, dtype=np.float64) + fade_in = np.sqrt(0.5 * (1 + t)) + fade_out = np.sqrt(0.5 * (1 - t)) + + # Concat the silence to the fades + fade_in = np.concatenate([silence, fade_in]) + fade_out = np.concatenate([fade_out, silence]) + + # Apply the gain to the overlap samples + y[:, :overlap] *= fade_in + y[:, -overlap:] *= fade_out + + unfolded = np.zeros((total_len), dtype=np.float64) + + # Loop to add up all the samples + for i in range(num_folds): + start = i * (target + overlap) + end = start + target + 2 * overlap + unfolded[start:end] += y[i] + + return unfolded + + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]: + mels = batch["input"] + waveform = batch["waveform"] + waveform_coarse = batch["waveform_coarse"] + + y_hat = self.forward(waveform, mels) + if isinstance(self.args.mode, int): + y_hat = y_hat.transpose(1, 2).unsqueeze(-1) + else: + waveform_coarse = waveform_coarse.float() + waveform_coarse = waveform_coarse.unsqueeze(-1) + # compute losses + loss_dict = criterion(y_hat, waveform_coarse) + return {"model_output": y_hat}, loss_dict + + def eval_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]: + return self.train_step(batch, criterion) + + @torch.no_grad() + def test( + self, assets: Dict, test_loader: "DataLoader", output: Dict # pylint: disable=unused-argument + ) -> Tuple[Dict, Dict]: + ap = assets["audio_processor"] + figures = {} + audios = {} + samples = test_loader.dataset.load_test_samples(1) + for idx, sample in enumerate(samples): + x = torch.FloatTensor(sample[0]) + x = x.to(next(self.parameters()).device) + y_hat = self.inference(x, self.config.batched, self.config.target_samples, self.config.overlap_samples) + x_hat = ap.melspectrogram(y_hat) + figures.update( + { + f"test_{idx}/ground_truth": plot_spectrogram(x.T), + f"test_{idx}/prediction": plot_spectrogram(x_hat.T), + } + ) + audios.update({f"test_{idx}/audio": y_hat}) + return figures, audios + + @staticmethod + def format_batch(batch: Dict) -> Dict: + waveform = batch[0] + mels = batch[1] + waveform_coarse = batch[2] + return {"input": mels, "waveform": waveform, "waveform_coarse": waveform_coarse} + + def get_data_loader( # pylint: disable=no-self-use + self, + config: Coqpit, + assets: Dict, + is_eval: True, + samples: List, + verbose: bool, + num_gpus: int, + ): + ap = assets["audio_processor"] + dataset = WaveRNNDataset( + ap=ap, + items=samples, + seq_len=config.seq_len, + hop_len=ap.hop_length, + pad=config.model_args.pad, + mode=config.model_args.mode, + mulaw=config.model_args.mulaw, + is_training=not is_eval, + verbose=verbose, + ) + sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None + loader = DataLoader( + dataset, + batch_size=1 if is_eval else config.batch_size, + shuffle=num_gpus == 0, + collate_fn=dataset.collate, + sampler=sampler, + num_workers=config.num_eval_loader_workers if is_eval else 
config.num_loader_workers, + pin_memory=True, + ) + return loader + + def get_criterion(self): + # define train functions + return WaveRNNLoss(self.args.mode) + + @staticmethod + def init_from_config(config: "WavernnConfig"): + return Wavernn(config) diff --git a/TTS/vocoder/pqmf_output.wav b/TTS/vocoder/pqmf_output.wav new file mode 100644 index 0000000000000000000000000000000000000000..8a77747b00198a4adfd6c398998517df5b4bdb8d Binary files /dev/null and b/TTS/vocoder/pqmf_output.wav differ diff --git a/TTS/vocoder/utils/__init__.py b/TTS/vocoder/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/vocoder/utils/distribution.py b/TTS/vocoder/utils/distribution.py new file mode 100644 index 0000000000000000000000000000000000000000..fe706ba9ffbc3f8aad75285bca34a910246666b3 --- /dev/null +++ b/TTS/vocoder/utils/distribution.py @@ -0,0 +1,154 @@ +import math + +import numpy as np +import torch +import torch.nn.functional as F +from torch.distributions.normal import Normal + + +def gaussian_loss(y_hat, y, log_std_min=-7.0): + assert y_hat.dim() == 3 + assert y_hat.size(2) == 2 + mean = y_hat[:, :, :1] + log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min) + # TODO: replace with pytorch dist + log_probs = -0.5 * (-math.log(2.0 * math.pi) - 2.0 * log_std - torch.pow(y - mean, 2) * torch.exp((-2.0 * log_std))) + return log_probs.squeeze().mean() + + +def sample_from_gaussian(y_hat, log_std_min=-7.0, scale_factor=1.0): + assert y_hat.size(2) == 2 + mean = y_hat[:, :, :1] + log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min) + dist = Normal( + mean, + torch.exp(log_std), + ) + sample = dist.sample() + sample = torch.clamp(torch.clamp(sample, min=-scale_factor), max=scale_factor) + del dist + return sample + + +def log_sum_exp(x): + """numerically stable log_sum_exp implementation that prevents overflow""" + # TF ordering + axis = len(x.size()) - 1 + m, _ = torch.max(x, dim=axis) + m2, _ = torch.max(x, dim=axis, keepdim=True) + return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis)) + + +# It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py +def discretized_mix_logistic_loss(y_hat, y, num_classes=65536, log_scale_min=None, reduce=True): + if log_scale_min is None: + log_scale_min = float(np.log(1e-14)) + y_hat = y_hat.permute(0, 2, 1) + assert y_hat.dim() == 3 + assert y_hat.size(1) % 3 == 0 + nr_mix = y_hat.size(1) // 3 + + # (B x T x C) + y_hat = y_hat.transpose(1, 2) + + # unpack parameters. 
(B, T, num_mixtures) x 3 + logit_probs = y_hat[:, :, :nr_mix] + means = y_hat[:, :, nr_mix : 2 * nr_mix] + log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min) + + # B x T x 1 -> B x T x num_mixtures + y = y.expand_as(means) + + centered_y = y - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_y + 1.0 / (num_classes - 1)) + cdf_plus = torch.sigmoid(plus_in) + min_in = inv_stdv * (centered_y - 1.0 / (num_classes - 1)) + cdf_min = torch.sigmoid(min_in) + + # log probability for edge case of 0 (before scaling) + # equivalent: torch.log(F.sigmoid(plus_in)) + log_cdf_plus = plus_in - F.softplus(plus_in) + + # log probability for edge case of 255 (before scaling) + # equivalent: (1 - F.sigmoid(min_in)).log() + log_one_minus_cdf_min = -F.softplus(min_in) + + # probability for all other cases + cdf_delta = cdf_plus - cdf_min + + mid_in = inv_stdv * centered_y + # log probability in the center of the bin, to be used in extreme cases + # (not actually used in our code) + log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in) + + # tf equivalent + + # log_probs = tf.where(x < -0.999, log_cdf_plus, + # tf.where(x > 0.999, log_one_minus_cdf_min, + # tf.where(cdf_delta > 1e-5, + # tf.log(tf.maximum(cdf_delta, 1e-12)), + # log_pdf_mid - np.log(127.5)))) + + # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value + # for num_classes=65536 case? 1e-7? not sure.. + inner_inner_cond = (cdf_delta > 1e-5).float() + + inner_inner_out = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1.0 - inner_inner_cond) * ( + log_pdf_mid - np.log((num_classes - 1) / 2) + ) + inner_cond = (y > 0.999).float() + inner_out = inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out + cond = (y < -0.999).float() + log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out + + log_probs = log_probs + F.log_softmax(logit_probs, -1) + + if reduce: + return -torch.mean(log_sum_exp(log_probs)) + return -log_sum_exp(log_probs).unsqueeze(-1) + + +def sample_from_discretized_mix_logistic(y, log_scale_min=None): + """ + Sample from discretized mixture of logistic distributions + Args: + y (Tensor): :math:`[B, C, T]` + log_scale_min (float): Log scale minimum value + Returns: + Tensor: sample in range of [-1, 1]. 
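+ Note: + The mixture component is picked with the Gumbel-max trick (argmax of logit_probs plus -log(-log(u)) noise), + and the sample is drawn from the chosen logistic via its inverse CDF, log(u) - log(1 - u).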
+ """ + if log_scale_min is None: + log_scale_min = float(np.log(1e-14)) + assert y.size(1) % 3 == 0 + nr_mix = y.size(1) // 3 + + # B x T x C + y = y.transpose(1, 2) + logit_probs = y[:, :, :nr_mix] + + # sample mixture indicator from softmax + temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5) + temp = logit_probs.data - torch.log(-torch.log(temp)) + _, argmax = temp.max(dim=-1) + + # (B, T) -> (B, T, nr_mix) + one_hot = to_one_hot(argmax, nr_mix) + # select logistic parameters + means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1) + log_scales = torch.clamp(torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1), min=log_scale_min) + # sample from logistic & clip to interval + # we don't actually round to the nearest 8bit value when sampling + u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5) + x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1.0 - u)) + + x = torch.clamp(torch.clamp(x, min=-1.0), max=1.0) + + return x + + +def to_one_hot(tensor, n, fill_with=1.0): + # we perform one hot encore with respect to the last axis + one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_().type_as(tensor) + one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with) + return one_hot diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..63a0af4445b5684e928b83d2f4fdfaf7e8f5b9a2 --- /dev/null +++ b/TTS/vocoder/utils/generic_utils.py @@ -0,0 +1,72 @@ +from typing import Dict + +import numpy as np +import torch +from matplotlib import pyplot as plt + +from TTS.tts.utils.visual import plot_spectrogram +from TTS.utils.audio import AudioProcessor + + +def interpolate_vocoder_input(scale_factor, spec): + """Interpolate spectrogram by the scale factor. + It is mainly used to match the sampling rates of + the tts and vocoder models. + + Args: + scale_factor (float): scale factor to interpolate the spectrogram + spec (np.array): spectrogram to be interpolated + + Returns: + torch.tensor: interpolated spectrogram. + """ + print(" > before interpolation :", spec.shape) + spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0) # pylint: disable=not-callable + spec = torch.nn.functional.interpolate( + spec, scale_factor=scale_factor, recompute_scale_factor=True, mode="bilinear", align_corners=False + ).squeeze(0) + print(" > after interpolation :", spec.shape) + return spec + + +def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_prefix: str = None) -> Dict: + """Plot the predicted and the real waveform and their spectrograms. + + Args: + y_hat (torch.tensor): Predicted waveform. + y (torch.tensor): Real waveform. + ap (AudioProcessor): Audio processor used to process the waveform. + name_prefix (str, optional): Name prefix used to name the figures. Defaults to None. + + Returns: + Dict: output figures keyed by the name of the figures. 
+ """ """Plot vocoder model results""" + if name_prefix is None: + name_prefix = "" + + # select an instance from batch + y_hat = y_hat[0].squeeze().detach().cpu().numpy() + y = y[0].squeeze().detach().cpu().numpy() + + spec_fake = ap.melspectrogram(y_hat).T + spec_real = ap.melspectrogram(y).T + spec_diff = np.abs(spec_fake - spec_real) + + # plot figure and save it + fig_wave = plt.figure() + plt.subplot(2, 1, 1) + plt.plot(y) + plt.title("groundtruth speech") + plt.subplot(2, 1, 2) + plt.plot(y_hat) + plt.title("generated speech") + plt.tight_layout() + plt.close() + + figures = { + name_prefix + "spectrogram/fake": plot_spectrogram(spec_fake), + name_prefix + "spectrogram/real": plot_spectrogram(spec_real), + name_prefix + "spectrogram/diff": plot_spectrogram(spec_diff), + name_prefix + "speech_comparison": fig_wave, + } + return figures diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..33111b2452bc2adc9e930b783cd25003934fe840 --- /dev/null +++ b/app.py @@ -0,0 +1,38 @@ +import os +import gradio as gr + +models = [ + "https://github.com/AI4Bharat/Indic-TTS/releases/download/v1-checkpoints-release/hi.zip", + "https://github.com/AI4Bharat/Indic-TTS/releases/download/v1-checkpoints-release/bn.zip" +] + +for model in models: + os.system(f"wget {model}") + os.system(f"unzip {model.split('/')[-1]}") + os.system(f"rm -fr {{model.split('/')[-1]}}") + +def convert(text, language, out = "out.wav"): + if language == "Hindi": + m = "hi" + else: + m = "bn" + + os.system(f'python3 -m TTS.bin.synthesize --text "{text}" --model_path {m}/fastpitch/best_model.pth --config_path {m}/fastpitch/config.json --vocoder_path {m}/hifigan/best_model.pth --vocoder_config_path {m}/hifigan/config.json --speaker_idx "male" --out_path {out}') + + return out + +text = gr.Textbox(value = "यह कल का दिन अद्भुत था क्योंकि हम संगीत कार्यक्रम से वापस आ गए हैं।", + placeholder = "Enter a text to synthesize", + label = "Text") + +language = gr.Dropdown(choices = ["Hindi", "Bangla"], + value = "Hindi", + type = "value", + label = "Language") + +inputs = [text, language] +outputs = gr.outputs.Audio(label = "Output Audio", type = 'filepath') + +title = "Indic Languages Speech Synthesis" + +gr.Interface(convert, inputs, outputs, title=title, enable_queue=True).launch() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..4f102ee7930a338b04343d93df8b399a6e1cfd57 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,36 @@ +# core deps +numpy==1.23.0 +cython==0.29.28 +scipy>=1.4.0 +torch>=1.7 +torchaudio +soundfile +librosa==0.8.0 +numba +inflect +tqdm +anyascii +pyyaml +fsspec>=2021.04.0 +# deps for examples +flask +# deps for inference +pysbd +# deps for notebooks +umap-learn==0.5.1 +pandas +# deps for training +matplotlib +pyworld==0.2.10 # > 0.2.10 is not p3.10.x compatible +# coqui stack +trainer +# config management +coqpit>=0.0.16 +# chinese g2p deps +jieba +pypinyin +# japanese g2p deps +mecab-python3==1.0.5 +unidic-lite==1.0.8 +# gruut+supported langs +gruut[cs,de,es,fr,it,nl,pt,ru,sv]==2.2.3