diff --git a/viXTTS/.gitattributes b/viXTTS/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..dab9a4e17afd2ef39d90ccb0b40ef2786fe77422 --- /dev/null +++ b/viXTTS/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/viXTTS/TTS/.models.json b/viXTTS/TTS/.models.json new file mode 100644 index 0000000000000000000000000000000000000000..b349e7397b1b9ea7e50f537969d1a0cf087c27e0 --- /dev/null +++ b/viXTTS/TTS/.models.json @@ -0,0 +1,938 @@ +{ + "tts_models": { + "multilingual": { + "multi-dataset": { + "xtts_v2": { + "description": "XTTS-v2.0.3 by Coqui with 17 languages.", + "hf_url": [ + "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth", + "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json", + "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json", + "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/hash.md5", + "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/speakers_xtts.pth" + ], + "model_hash": "10f92b55c512af7a8d39d650547a15a7", + "default_vocoder": null, + "commit": "480a6cdf7", + "license": "CPML", + "contact": "info@coqui.ai", + "tos_required": true + }, + "xtts_v1.1": { + "description": "XTTS-v1.1 by Coqui with 14 languages, cross-language voice cloning and reference leak fixed.", + "hf_url": [ + "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/model.pth", + "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/config.json", + "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/vocab.json", + "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.2/hash.md5" + ], + "model_hash": "7c62beaf58d39b729de287330dc254e7b515677416839b649a50e7cf74c3df59", + "default_vocoder": null, + "commit": "82910a63", + "license": "CPML", + "contact": "info@coqui.ai", + "tos_required": true + }, + "your_tts": { + "description": "Your TTS model accompanying the paper https://arxiv.org/abs/2112.02418", + "github_rls_url": 
"https://coqui.gateway.scarf.sh/v0.10.1_models/tts_models--multilingual--multi-dataset--your_tts.zip", + "default_vocoder": null, + "commit": "e9a1953e", + "license": "CC BY-NC-ND 4.0", + "contact": "egolge@coqui.ai" + }, + "bark": { + "description": "🐶 Bark TTS model released by suno-ai. You can find the original implementation in https://github.com/suno-ai/bark.", + "hf_url": [ + "https://coqui.gateway.scarf.sh/hf/bark/coarse_2.pt", + "https://coqui.gateway.scarf.sh/hf/bark/fine_2.pt", + "https://coqui.gateway.scarf.sh/hf/text_2.pt", + "https://coqui.gateway.scarf.sh/hf/bark/config.json", + "https://coqui.gateway.scarf.sh/hf/bark/hubert.pt", + "https://coqui.gateway.scarf.sh/hf/bark/tokenizer.pth" + ], + "default_vocoder": null, + "commit": "e9a1953e", + "license": "MIT", + "contact": "https://www.suno.ai/" + } + } + }, + "bg": { + "cv": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--bg--cv--vits.zip", + "default_vocoder": null, + "commit": null, + "author": "@NeonGeckoCom", + "license": "bsd-3-clause" + } + } + }, + "cs": { + "cv": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--cs--cv--vits.zip", + "default_vocoder": null, + "commit": null, + "author": "@NeonGeckoCom", + "license": "bsd-3-clause" + } + } + }, + "da": { + "cv": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--da--cv--vits.zip", + "default_vocoder": null, + "commit": null, + "author": "@NeonGeckoCom", + "license": "bsd-3-clause" + } + } + }, + "et": { + "cv": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--et--cv--vits.zip", + "default_vocoder": null, + "commit": null, + "author": "@NeonGeckoCom", + "license": "bsd-3-clause" + } + } + }, + "ga": { + "cv": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--ga--cv--vits.zip", + "default_vocoder": null, + "commit": null, + "author": "@NeonGeckoCom", + "license": "bsd-3-clause" + } + } + }, + "en": { + "ek1": { + "tacotron2": { + "description": "EK1 en-rp tacotron2 by NMStoker", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--en--ek1--tacotron2.zip", + "default_vocoder": "vocoder_models/en/ek1/wavegrad", + "commit": "c802255", + "license": "apache 2.0" + } + }, + "ljspeech": { + "tacotron2-DDC": { + "description": "Tacotron2 with Double Decoder Consistency.", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--en--ljspeech--tacotron2-DDC.zip", + "default_vocoder": "vocoder_models/en/ljspeech/hifigan_v2", + "commit": "bae2ad0f", + "author": "Eren Gölge @erogol", + "license": "apache 2.0", + "contact": "egolge@coqui.com" + }, + "tacotron2-DDC_ph": { + "description": "Tacotron2 with Double Decoder Consistency with phonemes.", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--en--ljspeech--tacotron2-DDC_ph.zip", + "default_vocoder": "vocoder_models/en/ljspeech/univnet", + "commit": "3900448", + "author": "Eren Gölge @erogol", + "license": "apache 2.0", + "contact": "egolge@coqui.com" + }, + "glow-tts": { + "description": "", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--en--ljspeech--glow-tts.zip", + "stats_file": null, + "default_vocoder": "vocoder_models/en/ljspeech/multiband-melgan", + "commit": "", + "author": "Eren Gölge @erogol", + "license": "MPL", + "contact": "egolge@coqui.com" + }, + "speedy-speech": { + "description": "Speedy Speech model 
trained on LJSpeech dataset using the Alignment Network for learning the durations.", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--en--ljspeech--speedy-speech.zip", + "stats_file": null, + "default_vocoder": "vocoder_models/en/ljspeech/hifigan_v2", + "commit": "4581e3d", + "author": "Eren Gölge @erogol", + "license": "apache 2.0", + "contact": "egolge@coqui.com" + }, + "tacotron2-DCA": { + "description": "", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--en--ljspeech--tacotron2-DCA.zip", + "default_vocoder": "vocoder_models/en/ljspeech/multiband-melgan", + "commit": "", + "author": "Eren Gölge @erogol", + "license": "MPL", + "contact": "egolge@coqui.com" + }, + "vits": { + "description": "VITS is an End2End TTS model trained on LJSpeech dataset with phonemes.", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--en--ljspeech--vits.zip", + "default_vocoder": null, + "commit": "3900448", + "author": "Eren Gölge @erogol", + "license": "apache 2.0", + "contact": "egolge@coqui.com" + }, + "vits--neon": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--en--ljspeech--vits.zip", + "default_vocoder": null, + "author": "@NeonGeckoCom", + "license": "bsd-3-clause", + "contact": null, + "commit": null + }, + "fast_pitch": { + "description": "FastPitch model trained on LJSpeech using the Aligner Network", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--en--ljspeech--fast_pitch.zip", + "default_vocoder": "vocoder_models/en/ljspeech/hifigan_v2", + "commit": "b27b3ba", + "author": "Eren Gölge @erogol", + "license": "apache 2.0", + "contact": "egolge@coqui.com" + }, + "overflow": { + "description": "Overflow model trained on LJSpeech", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.10.0_models/tts_models--en--ljspeech--overflow.zip", + "default_vocoder": "vocoder_models/en/ljspeech/hifigan_v2", + "commit": "3b1a28f", + "author": "Eren Gölge @erogol", + "license": "apache 2.0", + "contact": "egolge@coqui.ai" + }, + "neural_hmm": { + "description": "Neural HMM model trained on LJSpeech", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.11.0_models/tts_models--en--ljspeech--neural_hmm.zip", + "default_vocoder": "vocoder_models/en/ljspeech/hifigan_v2", + "commit": "3b1a28f", + "author": "Shivam Metha @shivammehta25", + "license": "apache 2.0", + "contact": "d83ee8fe45e3c0d776d4a865aca21d7c2ac324c4" + } + }, + "vctk": { + "vits": { + "description": "VITS End2End TTS model trained on VCTK dataset with 109 different speakers with EN accent.", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--en--vctk--vits.zip", + "default_vocoder": null, + "commit": "3900448", + "author": "Eren @erogol", + "license": "apache 2.0", + "contact": "egolge@coqui.ai" + }, + "fast_pitch": { + "description": "FastPitch model trained on VCTK dataseset.", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--en--vctk--fast_pitch.zip", + "default_vocoder": null, + "commit": "bdab788d", + "author": "Eren @erogol", + "license": "CC BY-NC-ND 4.0", + "contact": "egolge@coqui.ai" + } + }, + "sam": { + "tacotron-DDC": { + "description": "Tacotron2 with Double Decoder Consistency trained with Aceenture's Sam dataset.", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--en--sam--tacotron-DDC.zip", + "default_vocoder": "vocoder_models/en/sam/hifigan_v2", + "commit": "bae2ad0f", + "author": "Eren Gölge 
@erogol", + "license": "apache 2.0", + "contact": "egolge@coqui.com" + } + }, + "blizzard2013": { + "capacitron-t2-c50": { + "description": "Capacitron additions to Tacotron 2 with Capacity at 50 as in https://arxiv.org/pdf/1906.03402.pdf", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.7.0_models/tts_models--en--blizzard2013--capacitron-t2-c50.zip", + "commit": "d6284e7", + "default_vocoder": "vocoder_models/en/blizzard2013/hifigan_v2", + "author": "Adam Froghyar @a-froghyar", + "license": "apache 2.0", + "contact": "adamfroghyar@gmail.com" + }, + "capacitron-t2-c150_v2": { + "description": "Capacitron additions to Tacotron 2 with Capacity at 150 as in https://arxiv.org/pdf/1906.03402.pdf", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.7.1_models/tts_models--en--blizzard2013--capacitron-t2-c150_v2.zip", + "commit": "a67039d", + "default_vocoder": "vocoder_models/en/blizzard2013/hifigan_v2", + "author": "Adam Froghyar @a-froghyar", + "license": "apache 2.0", + "contact": "adamfroghyar@gmail.com" + } + }, + "multi-dataset": { + "tortoise-v2": { + "description": "Tortoise tts model https://github.com/neonbjb/tortoise-tts", + "github_rls_url": [ + "https://coqui.gateway.scarf.sh/v0.14.1_models/autoregressive.pth", + "https://coqui.gateway.scarf.sh/v0.14.1_models/clvp2.pth", + "https://coqui.gateway.scarf.sh/v0.14.1_models/cvvp.pth", + "https://coqui.gateway.scarf.sh/v0.14.1_models/diffusion_decoder.pth", + "https://coqui.gateway.scarf.sh/v0.14.1_models/rlg_auto.pth", + "https://coqui.gateway.scarf.sh/v0.14.1_models/rlg_diffuser.pth", + "https://coqui.gateway.scarf.sh/v0.14.1_models/vocoder.pth", + "https://coqui.gateway.scarf.sh/v0.14.1_models/mel_norms.pth", + "https://coqui.gateway.scarf.sh/v0.14.1_models/config.json" + ], + "commit": "c1875f6", + "default_vocoder": null, + "author": "@neonbjb - James Betker, @manmay-nakhashi Manmay Nakhashi", + "license": "apache 2.0" + } + }, + "jenny": { + "jenny": { + "description": "VITS model trained with Jenny(Dioco) dataset. Named as Jenny as demanded by the license. 
Original URL for the model https://www.kaggle.com/datasets/noml4u/tts-models--en--jenny-dioco--vits", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.14.0_models/tts_models--en--jenny--jenny.zip", + "default_vocoder": null, + "commit": "ba40a1c", + "license": "custom - see https://github.com/dioco-group/jenny-tts-dataset#important", + "author": "@noml4u" + } + } + }, + "es": { + "mai": { + "tacotron2-DDC": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--es--mai--tacotron2-DDC.zip", + "default_vocoder": "vocoder_models/universal/libri-tts/fullband-melgan", + "commit": "", + "author": "Eren Gölge @erogol", + "license": "MPL", + "contact": "egolge@coqui.com" + } + }, + "css10": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--es--css10--vits.zip", + "default_vocoder": null, + "commit": null, + "author": "@NeonGeckoCom", + "license": "bsd-3-clause" + } + } + }, + "fr": { + "mai": { + "tacotron2-DDC": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--fr--mai--tacotron2-DDC.zip", + "default_vocoder": "vocoder_models/universal/libri-tts/fullband-melgan", + "commit": null, + "author": "Eren Gölge @erogol", + "license": "MPL", + "contact": "egolge@coqui.com" + } + }, + "css10": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--fr--css10--vits.zip", + "default_vocoder": null, + "commit": null, + "author": "@NeonGeckoCom", + "license": "bsd-3-clause" + } + } + }, + "uk": { + "mai": { + "glow-tts": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--uk--mai--glow-tts.zip", + "author": "@robinhad", + "commit": "bdab788d", + "license": "MIT", + "contact": "", + "default_vocoder": "vocoder_models/uk/mai/multiband-melgan" + }, + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--uk--mai--vits.zip", + "default_vocoder": null, + "commit": null, + "author": "@NeonGeckoCom", + "license": "bsd-3-clause" + } + } + }, + "zh-CN": { + "baker": { + "tacotron2-DDC-GST": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--zh-CN--baker--tacotron2-DDC-GST.zip", + "commit": "unknown", + "author": "@kirianguiller", + "license": "apache 2.0", + "default_vocoder": null + } + } + }, + "nl": { + "mai": { + "tacotron2-DDC": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--nl--mai--tacotron2-DDC.zip", + "author": "@r-dh", + "license": "apache 2.0", + "default_vocoder": "vocoder_models/nl/mai/parallel-wavegan", + "stats_file": null, + "commit": "540d811" + } + }, + "css10": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--nl--css10--vits.zip", + "default_vocoder": null, + "commit": null, + "author": "@NeonGeckoCom", + "license": "bsd-3-clause" + } + } + }, + "de": { + "thorsten": { + "tacotron2-DCA": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--de--thorsten--tacotron2-DCA.zip", + "default_vocoder": "vocoder_models/de/thorsten/fullband-melgan", + "author": "@thorstenMueller", + "license": "apache 2.0", + "commit": "unknown" + }, + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.7.0_models/tts_models--de--thorsten--vits.zip", + "default_vocoder": null, + "author": "@thorstenMueller", + "license": "apache 2.0", + "commit": "unknown" + }, + "tacotron2-DDC": { + "github_rls_url": 
"https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--de--thorsten--tacotron2-DDC.zip", + "default_vocoder": "vocoder_models/de/thorsten/hifigan_v1", + "description": "Thorsten-Dec2021-22k-DDC", + "author": "@thorstenMueller", + "license": "apache 2.0", + "commit": "unknown" + } + }, + "css10": { + "vits-neon": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--de--css10--vits.zip", + "default_vocoder": null, + "author": "@NeonGeckoCom", + "license": "bsd-3-clause", + "commit": null + } + } + }, + "ja": { + "kokoro": { + "tacotron2-DDC": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--ja--kokoro--tacotron2-DDC.zip", + "default_vocoder": "vocoder_models/ja/kokoro/hifigan_v1", + "description": "Tacotron2 with Double Decoder Consistency trained with Kokoro Speech Dataset.", + "author": "@kaiidams", + "license": "apache 2.0", + "commit": "401fbd89" + } + } + }, + "tr": { + "common-voice": { + "glow-tts": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--tr--common-voice--glow-tts.zip", + "default_vocoder": "vocoder_models/tr/common-voice/hifigan", + "license": "MIT", + "description": "Turkish GlowTTS model using an unknown speaker from the Common-Voice dataset.", + "author": "Fatih Akademi", + "commit": null + } + } + }, + "it": { + "mai_female": { + "glow-tts": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--it--mai_female--glow-tts.zip", + "default_vocoder": null, + "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.", + "author": "@nicolalandro", + "license": "apache 2.0", + "commit": null + }, + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--it--mai_female--vits.zip", + "default_vocoder": null, + "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.", + "author": "@nicolalandro", + "license": "apache 2.0", + "commit": null + } + }, + "mai_male": { + "glow-tts": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--it--mai_male--glow-tts.zip", + "default_vocoder": null, + "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.", + "author": "@nicolalandro", + "license": "apache 2.0", + "commit": null + }, + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/tts_models--it--mai_male--vits.zip", + "default_vocoder": null, + "description": "GlowTTS model as explained on https://github.com/coqui-ai/TTS/issues/1148.", + "author": "@nicolalandro", + "license": "apache 2.0", + "commit": null + } + } + }, + "ewe": { + "openbible": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.2_models/tts_models--ewe--openbible--vits.zip", + "default_vocoder": null, + "license": "CC-BY-SA 4.0", + "description": "Original work (audio and text) by Biblica available for free at www.biblica.com and open.bible.", + "author": "@coqui_ai", + "commit": "1b22f03" + } + } + }, + "hau": { + "openbible": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.2_models/tts_models--hau--openbible--vits.zip", + "default_vocoder": null, + "license": "CC-BY-SA 4.0", + "description": "Original work (audio and text) by Biblica available for free at www.biblica.com and open.bible.", + "author": "@coqui_ai", + "commit": "1b22f03" + } + } + }, + "lin": { + "openbible": { + "vits": { + "github_rls_url": 
"https://coqui.gateway.scarf.sh/v0.6.2_models/tts_models--lin--openbible--vits.zip", + "default_vocoder": null, + "license": "CC-BY-SA 4.0", + "description": "Original work (audio and text) by Biblica available for free at www.biblica.com and open.bible.", + "author": "@coqui_ai", + "commit": "1b22f03" + } + } + }, + "tw_akuapem": { + "openbible": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.2_models/tts_models--tw_akuapem--openbible--vits.zip", + "default_vocoder": null, + "license": "CC-BY-SA 4.0", + "description": "Original work (audio and text) by Biblica available for free at www.biblica.com and open.bible.", + "author": "@coqui_ai", + "commit": "1b22f03" + } + } + }, + "tw_asante": { + "openbible": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.2_models/tts_models--tw_asante--openbible--vits.zip", + "default_vocoder": null, + "license": "CC-BY-SA 4.0", + "description": "Original work (audio and text) by Biblica available for free at www.biblica.com and open.bible.", + "author": "@coqui_ai", + "commit": "1b22f03" + } + } + }, + "yor": { + "openbible": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.2_models/tts_models--yor--openbible--vits.zip", + "default_vocoder": null, + "license": "CC-BY-SA 4.0", + "description": "Original work (audio and text) by Biblica available for free at www.biblica.com and open.bible.", + "author": "@coqui_ai", + "commit": "1b22f03" + } + } + }, + "hu": { + "css10": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--hu--css10--vits.zip", + "default_vocoder": null, + "commit": null, + "author": "@NeonGeckoCom", + "license": "bsd-3-clause" + } + } + }, + "el": { + "cv": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--el--cv--vits.zip", + "default_vocoder": null, + "commit": null, + "author": "@NeonGeckoCom", + "license": "bsd-3-clause" + } + } + }, + "fi": { + "css10": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--fi--css10--vits.zip", + "default_vocoder": null, + "commit": null, + "author": "@NeonGeckoCom", + "license": "bsd-3-clause" + } + } + }, + "hr": { + "cv": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--hr--cv--vits.zip", + "default_vocoder": null, + "commit": null, + "author": "@NeonGeckoCom", + "license": "bsd-3-clause" + } + } + }, + "lt": { + "cv": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--lt--cv--vits.zip", + "default_vocoder": null, + "commit": null, + "author": "@NeonGeckoCom", + "license": "bsd-3-clause" + } + } + }, + "lv": { + "cv": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--lv--cv--vits.zip", + "default_vocoder": null, + "commit": null, + "author": "@NeonGeckoCom", + "license": "bsd-3-clause" + } + } + }, + "mt": { + "cv": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--mt--cv--vits.zip", + "default_vocoder": null, + "commit": null, + "author": "@NeonGeckoCom", + "license": "bsd-3-clause" + } + } + }, + "pl": { + "mai_female": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--pl--mai_female--vits.zip", + "default_vocoder": null, + "commit": null, + "author": "@NeonGeckoCom", + "license": "bsd-3-clause" + } + } + }, + "pt": { + "cv": { + "vits": { + "github_rls_url": 
"https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--pt--cv--vits.zip", + "default_vocoder": null, + "commit": null, + "author": "@NeonGeckoCom", + "license": "bsd-3-clause" + } + } + }, + "ro": { + "cv": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--ro--cv--vits.zip", + "default_vocoder": null, + "commit": null, + "author": "@NeonGeckoCom", + "license": "bsd-3-clause" + } + } + }, + "sk": { + "cv": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--sk--cv--vits.zip", + "default_vocoder": null, + "commit": null, + "author": "@NeonGeckoCom", + "license": "bsd-3-clause" + } + } + }, + "sl": { + "cv": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--sl--cv--vits.zip", + "default_vocoder": null, + "commit": null, + "author": "@NeonGeckoCom", + "license": "bsd-3-clause" + } + } + }, + "sv": { + "cv": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/tts_models--sv--cv--vits.zip", + "default_vocoder": null, + "commit": null, + "author": "@NeonGeckoCom", + "license": "bsd-3-clause" + } + } + }, + "ca": { + "custom": { + "vits": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.10.1_models/tts_models--ca--custom--vits.zip", + "default_vocoder": null, + "commit": null, + "description": " It is trained from zero with 101460 utterances consisting of 257 speakers, approx 138 hours of speech. We used three datasets;\nFestcat and Google Catalan TTS (both TTS datasets) and also a part of Common Voice 8. It is trained with TTS v0.8.0.\nhttps://github.com/coqui-ai/TTS/discussions/930#discussioncomment-4466345", + "author": "@gullabi", + "license": "CC-BY-4.0" + } + } + }, + "fa": { + "custom": { + "glow-tts": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.10.1_models/tts_models--fa--custom--glow-tts.zip", + "default_vocoder": null, + "commit": null, + "description": "persian-tts-female-glow_tts model for text to speech purposes. Single-speaker female voice Trained on persian-tts-dataset-famale. \nThis model has no compatible vocoder thus the output quality is not very good. \nDataset: https://www.kaggle.com/datasets/magnoliasis/persian-tts-dataset-famale.", + "author": "@karim23657", + "license": "CC-BY-4.0" + } + } + }, + "bn": { + "custom": { + "vits-male": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.13.3_models/tts_models--bn--custom--vits_male.zip", + "default_vocoder": null, + "commit": null, + "description": "Single speaker Bangla male model. For more information -> https://github.com/mobassir94/comprehensive-bangla-tts", + "author": "@mobassir94", + "license": "Apache 2.0" + }, + "vits-female": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.13.3_models/tts_models--bn--custom--vits_female.zip", + "default_vocoder": null, + "commit": null, + "description": "Single speaker Bangla female model. 
For more information -> https://github.com/mobassir94/comprehensive-bangla-tts", + "author": "@mobassir94", + "license": "Apache 2.0" + } + } + }, + "be": { + "common-voice": { + "glow-tts":{ + "description": "Belarusian GlowTTS model created by @alex73 (Github).", + "github_rls_url":"https://coqui.gateway.scarf.sh/v0.16.6/tts_models--be--common-voice--glow-tts.zip", + "default_vocoder": "vocoder_models/be/common-voice/hifigan", + "commit": "c0aabb85", + "license": "CC-BY-SA 4.0", + "contact": "alex73mail@gmail.com" + } + } + } + }, + "vocoder_models": { + "universal": { + "libri-tts": { + "wavegrad": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--universal--libri-tts--wavegrad.zip", + "commit": "ea976b0", + "author": "Eren Gölge @erogol", + "license": "MPL", + "contact": "egolge@coqui.com" + }, + "fullband-melgan": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--universal--libri-tts--fullband-melgan.zip", + "commit": "4132240", + "author": "Eren Gölge @erogol", + "license": "MPL", + "contact": "egolge@coqui.com" + } + } + }, + "en": { + "ek1": { + "wavegrad": { + "description": "EK1 en-rp wavegrad by NMStoker", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--en--ek1--wavegrad.zip", + "commit": "c802255", + "license": "apache 2.0" + } + }, + "ljspeech": { + "multiband-melgan": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--en--ljspeech--multiband-melgan.zip", + "commit": "ea976b0", + "author": "Eren Gölge @erogol", + "license": "MPL", + "contact": "egolge@coqui.com" + }, + "hifigan_v2": { + "description": "HiFiGAN_v2 LJSpeech vocoder from https://arxiv.org/abs/2010.05646.", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--en--ljspeech--hifigan_v2.zip", + "commit": "bae2ad0f", + "author": "@erogol", + "license": "apache 2.0", + "contact": "egolge@coqui.ai" + }, + "univnet": { + "description": "UnivNet model finetuned on TacotronDDC_ph spectrograms for better compatibility.", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--en--ljspeech--univnet_v2.zip", + "commit": "4581e3d", + "author": "Eren @erogol", + "license": "apache 2.0", + "contact": "egolge@coqui.ai" + } + }, + "blizzard2013": { + "hifigan_v2": { + "description": "HiFiGAN_v2 LJSpeech vocoder from https://arxiv.org/abs/2010.05646.", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.7.0_models/vocoder_models--en--blizzard2013--hifigan_v2.zip", + "commit": "d6284e7", + "author": "Adam Froghyar @a-froghyar", + "license": "apache 2.0", + "contact": "adamfroghyar@gmail.com" + } + }, + "vctk": { + "hifigan_v2": { + "description": "Finetuned and intended to be used with tts_models/en/vctk/sc-glow-tts", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--en--vctk--hifigan_v2.zip", + "commit": "2f07160", + "author": "Edresson Casanova", + "license": "apache 2.0", + "contact": "" + } + }, + "sam": { + "hifigan_v2": { + "description": "Finetuned and intended to be used with tts_models/en/sam/tacotron_DDC", + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--en--sam--hifigan_v2.zip", + "commit": "2f07160", + "author": "Eren Gölge @erogol", + "license": "apache 2.0", + "contact": "egolge@coqui.ai" + } + } + }, + "nl": { + "mai": { + "parallel-wavegan": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--nl--mai--parallel-wavegan.zip", + 
"author": "@r-dh", + "license": "apache 2.0", + "commit": "unknown" + } + } + }, + "de": { + "thorsten": { + "wavegrad": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--de--thorsten--wavegrad.zip", + "author": "@thorstenMueller", + "license": "apache 2.0", + "commit": "unknown" + }, + "fullband-melgan": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--de--thorsten--fullband-melgan.zip", + "author": "@thorstenMueller", + "license": "apache 2.0", + "commit": "unknown" + }, + "hifigan_v1": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.8.0_models/vocoder_models--de--thorsten--hifigan_v1.zip", + "description": "HifiGAN vocoder model for Thorsten Neutral Dec2021 22k Samplerate Tacotron2 DDC model", + "author": "@thorstenMueller", + "license": "apache 2.0", + "commit": "unknown" + } + } + }, + "ja": { + "kokoro": { + "hifigan_v1": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--ja--kokoro--hifigan_v1.zip", + "description": "HifiGAN model trained for kokoro dataset by @kaiidams", + "author": "@kaiidams", + "license": "apache 2.0", + "commit": "3900448" + } + } + }, + "uk": { + "mai": { + "multiband-melgan": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--uk--mai--multiband-melgan.zip", + "author": "@robinhad", + "commit": "bdab788d", + "license": "MIT", + "contact": "" + } + } + }, + "tr": { + "common-voice": { + "hifigan": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.6.1_models/vocoder_models--tr--common-voice--hifigan.zip", + "description": "HifiGAN model using an unknown speaker from the Common-Voice dataset.", + "author": "Fatih Akademi", + "license": "MIT", + "commit": null + } + } + }, + "be": { + "common-voice": { + "hifigan": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.16.6/vocoder_models--be--common-voice--hifigan.zip", + "description": "Belarusian HiFiGAN model created by @alex73 (Github).", + "author": "@alex73", + "license": "CC-BY-SA 4.0", + "commit": "c0aabb85" + } + } + } + }, + "voice_conversion_models": { + "multilingual": { + "vctk": { + "freevc24": { + "github_rls_url": "https://coqui.gateway.scarf.sh/v0.13.0_models/voice_conversion_models--multilingual--vctk--freevc24.zip", + "description": "FreeVC model trained on VCTK dataset from https://github.com/OlaWod/FreeVC", + "author": "Jing-Yi Li @OlaWod", + "license": "MIT", + "commit": null + } + } + } + } +} diff --git a/viXTTS/TTS/VERSION b/viXTTS/TTS/VERSION new file mode 100644 index 0000000000000000000000000000000000000000..2157409059873c80aa93884ecb847639add77b7a --- /dev/null +++ b/viXTTS/TTS/VERSION @@ -0,0 +1 @@ +0.22.0 diff --git a/viXTTS/TTS/__init__.py b/viXTTS/TTS/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eaf05db1b950d82bfd7e20857e09a0fef45b430a --- /dev/null +++ b/viXTTS/TTS/__init__.py @@ -0,0 +1,6 @@ +import os + +with open(os.path.join(os.path.dirname(__file__), "VERSION"), "r", encoding="utf-8") as f: + version = f.read().strip() + +__version__ = version diff --git a/viXTTS/TTS/__pycache__/__init__.cpython-310.pyc b/viXTTS/TTS/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7eebdd6c3c63822de3155feafd61d7ec20e7cc7e Binary files /dev/null and b/viXTTS/TTS/__pycache__/__init__.cpython-310.pyc differ diff --git a/viXTTS/TTS/__pycache__/model.cpython-310.pyc b/viXTTS/TTS/__pycache__/model.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..c09f4833aceb3d03e4548985803975bea8e9adb9 Binary files /dev/null and b/viXTTS/TTS/__pycache__/model.cpython-310.pyc differ diff --git a/viXTTS/TTS/api.py b/viXTTS/TTS/api.py new file mode 100644 index 0000000000000000000000000000000000000000..7abc188e74032ae4afc66fd8b639733c83b34f3e --- /dev/null +++ b/viXTTS/TTS/api.py @@ -0,0 +1,458 @@ +import tempfile +import warnings +from pathlib import Path +from typing import Union + +import numpy as np +from torch import nn + +from TTS.utils.audio.numpy_transforms import save_wav +from TTS.utils.manage import ModelManager +from TTS.utils.synthesizer import Synthesizer +from TTS.config import load_config + + +class TTS(nn.Module): + """TODO: Add voice conversion and Capacitron support.""" + + def __init__( + self, + model_name: str = "", + model_path: str = None, + config_path: str = None, + vocoder_path: str = None, + vocoder_config_path: str = None, + progress_bar: bool = True, + gpu=False, + ): + """🐸TTS python interface that allows you to load and use the released models. + + Example with a multi-speaker model: + >>> from TTS.api import TTS + >>> tts = TTS(TTS.list_models()[0]) + >>> wav = tts.tts("This is a test! This is also a test!!", speaker=tts.speakers[0], language=tts.languages[0]) + >>> tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path="output.wav") + + Example with a single-speaker model: + >>> tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False, gpu=False) + >>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav") + + Example loading a model from a path: + >>> tts = TTS(model_path="/path/to/checkpoint_100000.pth", config_path="/path/to/config.json", progress_bar=False, gpu=False) + >>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav") + + Example voice cloning with YourTTS in English, French and Portuguese: + >>> tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True) + >>> tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="thisisit.wav") + >>> tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="thisisit.wav") + >>> tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="thisisit.wav") + + Example Fairseq TTS models (uses ISO language codes in https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html): + >>> tts = TTS(model_name="tts_models/eng/fairseq/vits", progress_bar=False, gpu=True) + >>> tts.tts_to_file("This is a test.", file_path="output.wav") + + Args: + model_name (str, optional): Model name to load. You can list models by ```tts.models```. Defaults to None. + model_path (str, optional): Path to the model checkpoint. Defaults to None. + config_path (str, optional): Path to the model config. Defaults to None. + vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None. + vocoder_config_path (str, optional): Path to the vocoder config. Defaults to None. + progress_bar (bool, optional): Whether to print a progress bar while downloading a model. Defaults to True. + gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. 
+ """ + super().__init__() + self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False) + self.config = load_config(config_path) if config_path else None + self.synthesizer = None + self.voice_converter = None + self.model_name = "" + if gpu: + warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.") + + if model_name is not None and len(model_name) > 0: + if "tts_models" in model_name: + self.load_tts_model_by_name(model_name, gpu) + elif "voice_conversion_models" in model_name: + self.load_vc_model_by_name(model_name, gpu) + else: + self.load_model_by_name(model_name, gpu) + + if model_path: + self.load_tts_model_by_path( + model_path, config_path, vocoder_path=vocoder_path, vocoder_config=vocoder_config_path, gpu=gpu + ) + + @property + def models(self): + return self.manager.list_tts_models() + + @property + def is_multi_speaker(self): + if hasattr(self.synthesizer.tts_model, "speaker_manager") and self.synthesizer.tts_model.speaker_manager: + return self.synthesizer.tts_model.speaker_manager.num_speakers > 1 + return False + + @property + def is_multi_lingual(self): + # Not sure what sets this to None, but applied a fix to prevent crashing. + if ( + isinstance(self.model_name, str) + and "xtts" in self.model_name + or self.config + and ("xtts" in self.config.model or len(self.config.languages) > 1) + ): + return True + if hasattr(self.synthesizer.tts_model, "language_manager") and self.synthesizer.tts_model.language_manager: + return self.synthesizer.tts_model.language_manager.num_languages > 1 + return False + + @property + def speakers(self): + if not self.is_multi_speaker: + return None + return self.synthesizer.tts_model.speaker_manager.speaker_names + + @property + def languages(self): + if not self.is_multi_lingual: + return None + return self.synthesizer.tts_model.language_manager.language_names + + @staticmethod + def get_models_file_path(): + return Path(__file__).parent / ".models.json" + + def list_models(self): + return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False, verbose=False) + + def download_model_by_name(self, model_name: str): + model_path, config_path, model_item = self.manager.download_model(model_name) + if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)): + # return model directory if there are multiple files + # we assume that the model knows how to load itself + return None, None, None, None, model_path + if model_item.get("default_vocoder") is None: + return model_path, config_path, None, None, None + vocoder_path, vocoder_config_path, _ = self.manager.download_model(model_item["default_vocoder"]) + return model_path, config_path, vocoder_path, vocoder_config_path, None + + def load_model_by_name(self, model_name: str, gpu: bool = False): + """Load one of the 🐸TTS models by name. + + Args: + model_name (str): Model name to load. You can list models by ```tts.models```. + gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. + """ + self.load_tts_model_by_name(model_name, gpu) + + def load_vc_model_by_name(self, model_name: str, gpu: bool = False): + """Load one of the voice conversion models by name. + + Args: + model_name (str): Model name to load. You can list models by ```tts.models```. + gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. 
+ """ + self.model_name = model_name + model_path, config_path, _, _, _ = self.download_model_by_name(model_name) + self.voice_converter = Synthesizer(vc_checkpoint=model_path, vc_config=config_path, use_cuda=gpu) + + def load_tts_model_by_name(self, model_name: str, gpu: bool = False): + """Load one of 🐸TTS models by name. + + Args: + model_name (str): Model name to load. You can list models by ```tts.models```. + gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. + + TODO: Add tests + """ + self.synthesizer = None + self.model_name = model_name + + model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name( + model_name + ) + + # init synthesizer + # None values are fetch from the model + self.synthesizer = Synthesizer( + tts_checkpoint=model_path, + tts_config_path=config_path, + tts_speakers_file=None, + tts_languages_file=None, + vocoder_checkpoint=vocoder_path, + vocoder_config=vocoder_config_path, + encoder_checkpoint=None, + encoder_config=None, + model_dir=model_dir, + use_cuda=gpu, + ) + + def load_tts_model_by_path( + self, model_path: str, config_path: str, vocoder_path: str = None, vocoder_config: str = None, gpu: bool = False + ): + """Load a model from a path. + + Args: + model_path (str): Path to the model checkpoint. + config_path (str): Path to the model config. + vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None. + vocoder_config (str, optional): Path to the vocoder config. Defaults to None. + gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. + """ + + self.synthesizer = Synthesizer( + tts_checkpoint=model_path, + tts_config_path=config_path, + tts_speakers_file=None, + tts_languages_file=None, + vocoder_checkpoint=vocoder_path, + vocoder_config=vocoder_config, + encoder_checkpoint=None, + encoder_config=None, + use_cuda=gpu, + ) + + def _check_arguments( + self, + speaker: str = None, + language: str = None, + speaker_wav: str = None, + emotion: str = None, + speed: float = None, + **kwargs, + ) -> None: + """Check if the arguments are valid for the model.""" + # check for the coqui tts models + if self.is_multi_speaker and (speaker is None and speaker_wav is None): + raise ValueError("Model is multi-speaker but no `speaker` is provided.") + if self.is_multi_lingual and language is None: + raise ValueError("Model is multi-lingual but no `language` is provided.") + if not self.is_multi_speaker and speaker is not None and "voice_dir" not in kwargs: + raise ValueError("Model is not multi-speaker but `speaker` is provided.") + if not self.is_multi_lingual and language is not None: + raise ValueError("Model is not multi-lingual but `language` is provided.") + if not emotion is None and not speed is None: + raise ValueError("Emotion and speed can only be used with Coqui Studio models. Which is discontinued.") + + def tts( + self, + text: str, + speaker: str = None, + language: str = None, + speaker_wav: str = None, + emotion: str = None, + speed: float = None, + split_sentences: bool = True, + **kwargs, + ): + """Convert text to speech. + + Args: + text (str): + Input text to synthesize. + speaker (str, optional): + Speaker name for multi-speaker. You can check whether loaded model is multi-speaker by + `tts.is_multi_speaker` and list speakers by `tts.speakers`. Defaults to None. + language (str): Language of the text. If None, the default language of the speaker is used. 
Language is only + supported by `XTTS` model. + speaker_wav (str, optional): + Path to a reference wav file to use for voice cloning with supporting models like YourTTS. + Defaults to None. + emotion (str, optional): + Emotion to use for 🐸Coqui Studio models. If None, Studio models use "Neutral". Defaults to None. + speed (float, optional): + Speed factor to use for 🐸Coqui Studio models, between 0 and 2.0. If None, Studio models use 1.0. + Defaults to None. + split_sentences (bool, optional): + Split text into sentences, synthesize them separately and concatenate the file audio. + Setting it False uses more VRAM and possibly hit model specific text length or VRAM limits. Only + applicable to the 🐸TTS models. Defaults to True. + kwargs (dict, optional): + Additional arguments for the model. + """ + self._check_arguments( + speaker=speaker, language=language, speaker_wav=speaker_wav, emotion=emotion, speed=speed, **kwargs + ) + wav = self.synthesizer.tts( + text=text, + speaker_name=speaker, + language_name=language, + speaker_wav=speaker_wav, + reference_wav=None, + style_wav=None, + style_text=None, + reference_speaker_name=None, + split_sentences=split_sentences, + **kwargs, + ) + return wav + + def tts_to_file( + self, + text: str, + speaker: str = None, + language: str = None, + speaker_wav: str = None, + emotion: str = None, + speed: float = 1.0, + pipe_out=None, + file_path: str = "output.wav", + split_sentences: bool = True, + **kwargs, + ): + """Convert text to speech. + + Args: + text (str): + Input text to synthesize. + speaker (str, optional): + Speaker name for multi-speaker. You can check whether loaded model is multi-speaker by + `tts.is_multi_speaker` and list speakers by `tts.speakers`. Defaults to None. + language (str, optional): + Language code for multi-lingual models. You can check whether loaded model is multi-lingual + `tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None. + speaker_wav (str, optional): + Path to a reference wav file to use for voice cloning with supporting models like YourTTS. + Defaults to None. + emotion (str, optional): + Emotion to use for 🐸Coqui Studio models. Defaults to "Neutral". + speed (float, optional): + Speed factor to use for 🐸Coqui Studio models, between 0.0 and 2.0. Defaults to None. + pipe_out (BytesIO, optional): + Flag to stdout the generated TTS wav file for shell pipe. + file_path (str, optional): + Output file path. Defaults to "output.wav". + split_sentences (bool, optional): + Split text into sentences, synthesize them separately and concatenate the file audio. + Setting it False uses more VRAM and possibly hit model specific text length or VRAM limits. Only + applicable to the 🐸TTS models. Defaults to True. + kwargs (dict, optional): + Additional arguments for the model. + """ + self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs) + + wav = self.tts( + text=text, + speaker=speaker, + language=language, + speaker_wav=speaker_wav, + split_sentences=split_sentences, + **kwargs, + ) + self.synthesizer.save_wav(wav=wav, path=file_path, pipe_out=pipe_out) + return file_path + + def voice_conversion( + self, + source_wav: str, + target_wav: str, + ): + """Voice conversion with FreeVC. Convert source wav to target speaker. + + Args: + source_wav (str): + Path to the source wav file. + target_wav (str): + Path to the target wav file. 
+ """ + wav = self.voice_converter.voice_conversion(source_wav=source_wav, target_wav=target_wav) + return wav + + def voice_conversion_to_file( + self, + source_wav: str, + target_wav: str, + file_path: str = "output.wav", + ): + """Voice conversion with FreeVC. Convert source wav to target speaker. + + Args: + source_wav (str): + Path to the source wav file. + target_wav (str): + Path to the target wav file. + file_path (str, optional): + Output file path. Defaults to "output.wav". + """ + wav = self.voice_conversion(source_wav=source_wav, target_wav=target_wav) + save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate) + return file_path + + def tts_with_vc( + self, + text: str, + language: str = None, + speaker_wav: str = None, + speaker: str = None, + split_sentences: bool = True, + ): + """Convert text to speech with voice conversion. + + It combines tts with voice conversion to fake voice cloning. + + - Convert text to speech with tts. + - Convert the output wav to target speaker with voice conversion. + + Args: + text (str): + Input text to synthesize. + language (str, optional): + Language code for multi-lingual models. You can check whether loaded model is multi-lingual + `tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None. + speaker_wav (str, optional): + Path to a reference wav file to use for voice cloning with supporting models like YourTTS. + Defaults to None. + speaker (str, optional): + Speaker name for multi-speaker. You can check whether loaded model is multi-speaker by + `tts.is_multi_speaker` and list speakers by `tts.speakers`. Defaults to None. + split_sentences (bool, optional): + Split text into sentences, synthesize them separately and concatenate the file audio. + Setting it False uses more VRAM and possibly hit model specific text length or VRAM limits. Only + applicable to the 🐸TTS models. Defaults to True. + """ + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp: + # Lazy code... save it to a temp file to resample it while reading it for VC + self.tts_to_file( + text=text, speaker=speaker, language=language, file_path=fp.name, split_sentences=split_sentences + ) + if self.voice_converter is None: + self.load_vc_model_by_name("voice_conversion_models/multilingual/vctk/freevc24") + wav = self.voice_converter.voice_conversion(source_wav=fp.name, target_wav=speaker_wav) + return wav + + def tts_with_vc_to_file( + self, + text: str, + language: str = None, + speaker_wav: str = None, + file_path: str = "output.wav", + speaker: str = None, + split_sentences: bool = True, + ): + """Convert text to speech with voice conversion and save to file. + + Check `tts_with_vc` for more details. + + Args: + text (str): + Input text to synthesize. + language (str, optional): + Language code for multi-lingual models. You can check whether loaded model is multi-lingual + `tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None. + speaker_wav (str, optional): + Path to a reference wav file to use for voice cloning with supporting models like YourTTS. + Defaults to None. + file_path (str, optional): + Output file path. Defaults to "output.wav". + speaker (str, optional): + Speaker name for multi-speaker. You can check whether loaded model is multi-speaker by + `tts.is_multi_speaker` and list speakers by `tts.speakers`. Defaults to None. 
+ split_sentences (bool, optional): + Split text into sentences, synthesize them separately and concatenate the file audio. + Setting it False uses more VRAM and possibly hit model specific text length or VRAM limits. Only + applicable to the 🐸TTS models. Defaults to True. + """ + wav = self.tts_with_vc( + text=text, language=language, speaker_wav=speaker_wav, speaker=speaker, split_sentences=split_sentences + ) + save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate) diff --git a/viXTTS/TTS/bin/__init__.py b/viXTTS/TTS/bin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/bin/collect_env_info.py b/viXTTS/TTS/bin/collect_env_info.py new file mode 100644 index 0000000000000000000000000000000000000000..662fcd02ece0fad387b6bfc4bad9316c7e2a0bad --- /dev/null +++ b/viXTTS/TTS/bin/collect_env_info.py @@ -0,0 +1,48 @@ +"""Get detailed info about the working environment.""" +import os +import platform +import sys + +import numpy +import torch + +sys.path += [os.path.abspath(".."), os.path.abspath(".")] +import json + +import TTS + + +def system_info(): + return { + "OS": platform.system(), + "architecture": platform.architecture(), + "version": platform.version(), + "processor": platform.processor(), + "python": platform.python_version(), + } + + +def cuda_info(): + return { + "GPU": [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())], + "available": torch.cuda.is_available(), + "version": torch.version.cuda, + } + + +def package_info(): + return { + "numpy": numpy.__version__, + "PyTorch_version": torch.__version__, + "PyTorch_debug": torch.version.debug, + "TTS": TTS.__version__, + } + + +def main(): + details = {"System": system_info(), "CUDA": cuda_info(), "Packages": package_info()} + print(json.dumps(details, indent=4, sort_keys=True)) + + +if __name__ == "__main__": + main() diff --git a/viXTTS/TTS/bin/compute_attention_masks.py b/viXTTS/TTS/bin/compute_attention_masks.py new file mode 100644 index 0000000000000000000000000000000000000000..9ab520be7d9f41ecf4f124446400b5e1b597ae8b --- /dev/null +++ b/viXTTS/TTS/bin/compute_attention_masks.py @@ -0,0 +1,165 @@ +import argparse +import importlib +import os +from argparse import RawTextHelpFormatter + +import numpy as np +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from TTS.config import load_config +from TTS.tts.datasets.TTSDataset import TTSDataset +from TTS.tts.models import setup_model +from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols +from TTS.utils.audio import AudioProcessor +from TTS.utils.io import load_checkpoint + +if __name__ == "__main__": + # pylint: disable=bad-option-value + parser = argparse.ArgumentParser( + description="""Extract attention masks from trained Tacotron/Tacotron2 models. +These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n""" + """Each attention mask is written to the same path as the input wav file with ".npy" file extension. +(e.g. 
path/bla.wav (wav file) --> path/bla.npy (attention mask))\n""" """ +Example run: + CUDA_VISIBLE_DEVICE="0" python TTS/bin/compute_attention_masks.py + --model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth + --config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json + --dataset_metafile metadata.csv + --data_path /root/LJSpeech-1.1/ + --batch_size 32 + --dataset ljspeech + --use_cuda True +""", + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument("--model_path", type=str, required=True, help="Path to Tacotron/Tacotron2 model file ") + parser.add_argument( + "--config_path", + type=str, + required=True, + help="Path to Tacotron/Tacotron2 config file.", + ) + parser.add_argument( + "--dataset", + type=str, + default="", + required=True, + help="Target dataset processor name from TTS.tts.dataset.preprocess.", + ) + + parser.add_argument( + "--dataset_metafile", + type=str, + default="", + required=True, + help="Dataset metafile including file paths with transcripts.", + ) + parser.add_argument("--data_path", type=str, default="", help="Defines the data path. It overwrites config.json.") + parser.add_argument("--use_cuda", type=bool, default=False, help="enable/disable cuda.") + + parser.add_argument( + "--batch_size", default=16, type=int, help="Batch size for the model. Use batch_size=1 if you have no CUDA." + ) + args = parser.parse_args() + + C = load_config(args.config_path) + ap = AudioProcessor(**C.audio) + + # if the vocabulary was passed, replace the default + if "characters" in C.keys(): + symbols, phonemes = make_symbols(**C.characters) + + # load the model + num_chars = len(phonemes) if C.use_phonemes else len(symbols) + # TODO: handle multi-speaker + model = setup_model(C) + model, _ = load_checkpoint(model, args.model_path, args.use_cuda, True) + + # data loader + preprocessor = importlib.import_module("TTS.tts.datasets.formatters") + preprocessor = getattr(preprocessor, args.dataset) + meta_data = preprocessor(args.data_path, args.dataset_metafile) + dataset = TTSDataset( + model.decoder.r, + C.text_cleaner, + compute_linear_spec=False, + ap=ap, + meta_data=meta_data, + characters=C.characters if "characters" in C.keys() else None, + add_blank=C["add_blank"] if "add_blank" in C.keys() else False, + use_phonemes=C.use_phonemes, + phoneme_cache_path=C.phoneme_cache_path, + phoneme_language=C.phoneme_language, + enable_eos_bos=C.enable_eos_bos_chars, + ) + + dataset.sort_and_filter_items(C.get("sort_by_audio_len", default=False)) + loader = DataLoader( + dataset, + batch_size=args.batch_size, + num_workers=4, + collate_fn=dataset.collate_fn, + shuffle=False, + drop_last=False, + ) + + # compute attentions + file_paths = [] + with torch.no_grad(): + for data in tqdm(loader): + # setup input data + text_input = data[0] + text_lengths = data[1] + linear_input = data[3] + mel_input = data[4] + mel_lengths = data[5] + stop_targets = data[6] + item_idxs = data[7] + + # dispatch data to GPU + if args.use_cuda: + text_input = text_input.cuda() + text_lengths = text_lengths.cuda() + mel_input = mel_input.cuda() + mel_lengths = mel_lengths.cuda() + + model_outputs = model.forward(text_input, text_lengths, mel_input) + + alignments = model_outputs["alignments"].detach() + for idx, alignment in enumerate(alignments): + item_idx = item_idxs[idx] + # interpolate if r > 1 + alignment = ( + torch.nn.functional.interpolate( + alignment.transpose(0, 1).unsqueeze(0), + size=None, 
scale_factor=model.decoder.r, + mode="nearest", + align_corners=None, + recompute_scale_factor=None, + ) + .squeeze(0) + .transpose(0, 1) + ) + # remove paddings + alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy() + # set file paths + wav_file_name = os.path.basename(item_idx) + align_file_name = os.path.splitext(wav_file_name)[0] + "_attn.npy" + file_path = item_idx.replace(wav_file_name, align_file_name) + # save output + wav_file_abs_path = os.path.abspath(item_idx) + file_abs_path = os.path.abspath(file_path) + file_paths.append([wav_file_abs_path, file_abs_path]) + np.save(file_path, alignment) + + # ourput metafile + metafile = os.path.join(args.data_path, "metadata_attn_mask.txt") + + with open(metafile, "w", encoding="utf-8") as f: + for p in file_paths: + f.write(f"{p[0]}|{p[1]}\n") + print(f" >> Metafile created: {metafile}") diff --git a/viXTTS/TTS/bin/compute_embeddings.py b/viXTTS/TTS/bin/compute_embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..5b5a37df736fd75c8228ceefd818c6ec4a63867f --- /dev/null +++ b/viXTTS/TTS/bin/compute_embeddings.py @@ -0,0 +1,197 @@ +import argparse +import os +from argparse import RawTextHelpFormatter + +import torch +from tqdm import tqdm + +from TTS.config import load_config +from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.datasets import load_tts_samples +from TTS.tts.utils.managers import save_file +from TTS.tts.utils.speakers import SpeakerManager + + +def compute_embeddings( + model_path, + config_path, + output_path, + old_speakers_file=None, + old_append=False, + config_dataset_path=None, + formatter_name=None, + dataset_name=None, + dataset_path=None, + meta_file_train=None, + meta_file_val=None, + disable_cuda=False, + no_eval=False, +): + use_cuda = torch.cuda.is_available() and not disable_cuda + + if config_dataset_path is not None: + c_dataset = load_config(config_dataset_path) + meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=not no_eval) + else: + c_dataset = BaseDatasetConfig() + c_dataset.formatter = formatter_name + c_dataset.dataset_name = dataset_name + c_dataset.path = dataset_path + if meta_file_train is not None: + c_dataset.meta_file_train = meta_file_train + if meta_file_val is not None: + c_dataset.meta_file_val = meta_file_val + meta_data_train, meta_data_eval = load_tts_samples(c_dataset, eval_split=not no_eval) + + if meta_data_eval is None: + samples = meta_data_train + else: + samples = meta_data_train + meta_data_eval + + encoder_manager = SpeakerManager( + encoder_model_path=model_path, + encoder_config_path=config_path, + d_vectors_file_path=old_speakers_file, + use_cuda=use_cuda, + ) + + class_name_key = encoder_manager.encoder_config.class_name_key + + # compute speaker embeddings + if old_speakers_file is not None and old_append: + speaker_mapping = encoder_manager.embeddings + else: + speaker_mapping = {} + + for fields in tqdm(samples): + class_name = fields[class_name_key] + audio_file = fields["audio_file"] + embedding_key = fields["audio_unique_name"] + + # Only update the speaker name when the embedding is already in the old file. 
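+            # (with --old_append the mapping is pre-filled from --old_file, so those clips are only relabelled and then skipped)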
+ if embedding_key in speaker_mapping: + speaker_mapping[embedding_key]["name"] = class_name + continue + + if old_speakers_file is not None and embedding_key in encoder_manager.clip_ids: + # get the embedding from the old file + embedd = encoder_manager.get_embedding_by_clip(embedding_key) + else: + # extract the embedding + embedd = encoder_manager.compute_embedding_from_clip(audio_file) + + # create speaker_mapping if target dataset is defined + speaker_mapping[embedding_key] = {} + speaker_mapping[embedding_key]["name"] = class_name + speaker_mapping[embedding_key]["embedding"] = embedd + + if speaker_mapping: + # save speaker_mapping if target dataset is defined + if os.path.isdir(output_path): + mapping_file_path = os.path.join(output_path, "speakers.pth") + else: + mapping_file_path = output_path + + if os.path.dirname(mapping_file_path) != "": + os.makedirs(os.path.dirname(mapping_file_path), exist_ok=True) + + save_file(speaker_mapping, mapping_file_path) + print("Speaker embeddings saved at:", mapping_file_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="""Compute embedding vectors for each audio file in a dataset and store them keyed by `{dataset_name}#{file_path}` in a .pth file\n\n""" + """ + Example runs: + python TTS/bin/compute_embeddings.py --model_path speaker_encoder_model.pth --config_path speaker_encoder_config.json --config_dataset_path dataset_config.json + + python TTS/bin/compute_embeddings.py --model_path speaker_encoder_model.pth --config_path speaker_encoder_config.json --formatter_name coqui --dataset_path /path/to/vctk/dataset --dataset_name my_vctk --meta_file_train /path/to/vctk/metafile_train.csv --meta_file_val /path/to/vctk/metafile_eval.csv + """, + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument( + "--model_path", + type=str, + help="Path to model checkpoint file. It defaults to the released speaker encoder.", + default="https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/model_se.pth.tar", + ) + parser.add_argument( + "--config_path", + type=str, + help="Path to model config file. It defaults to the released speaker encoder config.", + default="https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/config_se.json", + ) + parser.add_argument( + "--config_dataset_path", + type=str, + help="Path to dataset config file. You either need to provide this or `formatter_name`, `dataset_name` and `dataset_path` arguments.", + default=None, + ) + parser.add_argument( + "--output_path", + type=str, + help="Path for output `pth` or `json` file.", + default="speakers.pth", + ) + parser.add_argument( + "--old_file", + type=str, + help="The old existing embedding file, from which the embeddings will be directly loaded for already computed audio clips.", + default=None, + ) + parser.add_argument( + "--old_append", + help="Append new audio clip embeddings to the old embedding file, generate a new non-duplicated merged embedding file. Default False", + default=False, + action="store_true", + ) + parser.add_argument("--disable_cuda", type=bool, help="Flag to disable cuda.", default=False) + parser.add_argument("--no_eval", help="Do not compute eval?. Default False", default=False, action="store_true") + parser.add_argument( + "--formatter_name", + type=str, + help="Name of the formatter to use. You either need to provide this or `config_dataset_path`", + default=None, + ) + parser.add_argument( + "--dataset_name", + type=str, + help="Name of the dataset to use. 
You either need to provide this or `config_dataset_path`", + default=None, + ) + parser.add_argument( + "--dataset_path", + type=str, + help="Path to the dataset. You either need to provide this or `config_dataset_path`", + default=None, + ) + parser.add_argument( + "--meta_file_train", + type=str, + help="Path to the train meta file. If not set, dataset formatter uses the default metafile if it is defined in the formatter. You either need to provide this or `config_dataset_path`", + default=None, + ) + parser.add_argument( + "--meta_file_val", + type=str, + help="Path to the evaluation meta file. If not set, dataset formatter uses the default metafile if it is defined in the formatter. You either need to provide this or `config_dataset_path`", + default=None, + ) + args = parser.parse_args() + + compute_embeddings( + args.model_path, + args.config_path, + args.output_path, + old_speakers_file=args.old_file, + old_append=args.old_append, + config_dataset_path=args.config_dataset_path, + formatter_name=args.formatter_name, + dataset_name=args.dataset_name, + dataset_path=args.dataset_path, + meta_file_train=args.meta_file_train, + meta_file_val=args.meta_file_val, + disable_cuda=args.disable_cuda, + no_eval=args.no_eval, + ) diff --git a/viXTTS/TTS/bin/compute_statistics.py b/viXTTS/TTS/bin/compute_statistics.py new file mode 100644 index 0000000000000000000000000000000000000000..3ab7ea7a3b10ec3cc23d8a744c7bdc79de52dbf2 --- /dev/null +++ b/viXTTS/TTS/bin/compute_statistics.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import argparse +import glob +import os + +import numpy as np +from tqdm import tqdm + +# from TTS.utils.io import load_config +from TTS.config import load_config +from TTS.tts.datasets import load_tts_samples +from TTS.utils.audio import AudioProcessor + + +def main(): + """Run preprocessing process.""" + parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogtram features.") + parser.add_argument("config_path", type=str, help="TTS config file path to define audio processin parameters.") + parser.add_argument("out_path", type=str, help="save path (directory and filename).") + parser.add_argument( + "--data_path", + type=str, + required=False, + help="folder including the target set of wavs overriding dataset config.", + ) + args, overrides = parser.parse_known_args() + + CONFIG = load_config(args.config_path) + CONFIG.parse_known_args(overrides, relaxed_parser=True) + + # load config + CONFIG.audio.signal_norm = False # do not apply earlier normalization + CONFIG.audio.stats_path = None # discard pre-defined stats + + # load audio processor + ap = AudioProcessor(**CONFIG.audio.to_dict()) + + # load the meta data of target dataset + if args.data_path: + dataset_items = glob.glob(os.path.join(args.data_path, "**", "*.wav"), recursive=True) + else: + dataset_items = load_tts_samples(CONFIG.datasets)[0] # take only train data + print(f" > There are {len(dataset_items)} files.") + + mel_sum = 0 + mel_square_sum = 0 + linear_sum = 0 + linear_square_sum = 0 + N = 0 + for item in tqdm(dataset_items): + # compute features + wav = ap.load_wav(item if isinstance(item, str) else item["audio_file"]) + linear = ap.spectrogram(wav) + mel = ap.melspectrogram(wav) + + # compute stats + N += mel.shape[1] + mel_sum += mel.sum(1) + linear_sum += linear.sum(1) + mel_square_sum += (mel**2).sum(axis=1) + linear_square_sum += (linear**2).sum(axis=1) + + mel_mean = mel_sum / N + mel_scale = np.sqrt(mel_square_sum / N - mel_mean**2) + linear_mean = 
linear_sum / N + linear_scale = np.sqrt(linear_square_sum / N - linear_mean**2) + + output_file_path = args.out_path + stats = {} + stats["mel_mean"] = mel_mean + stats["mel_std"] = mel_scale + stats["linear_mean"] = linear_mean + stats["linear_std"] = linear_scale + + print(f" > Avg mel spec mean: {mel_mean.mean()}") + print(f" > Avg mel spec scale: {mel_scale.mean()}") + print(f" > Avg linear spec mean: {linear_mean.mean()}") + print(f" > Avg linear spec scale: {linear_scale.mean()}") + + # set default config values for mean-var scaling + CONFIG.audio.stats_path = output_file_path + CONFIG.audio.signal_norm = True + # remove redundant values + del CONFIG.audio.max_norm + del CONFIG.audio.min_level_db + del CONFIG.audio.symmetric_norm + del CONFIG.audio.clip_norm + stats["audio_config"] = CONFIG.audio.to_dict() + np.save(output_file_path, stats, allow_pickle=True) + print(f" > stats saved to {output_file_path}") + + +if __name__ == "__main__": + main() diff --git a/viXTTS/TTS/bin/eval_encoder.py b/viXTTS/TTS/bin/eval_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..60fed1393215cd5e2e349795b585ae12f2e227fa --- /dev/null +++ b/viXTTS/TTS/bin/eval_encoder.py @@ -0,0 +1,88 @@ +import argparse +from argparse import RawTextHelpFormatter + +import torch +from tqdm import tqdm + +from TTS.config import load_config +from TTS.tts.datasets import load_tts_samples +from TTS.tts.utils.speakers import SpeakerManager + + +def compute_encoder_accuracy(dataset_items, encoder_manager): + class_name_key = encoder_manager.encoder_config.class_name_key + map_classid_to_classname = getattr(encoder_manager.encoder_config, "map_classid_to_classname", None) + + class_acc_dict = {} + + # compute embeddings for all wav_files + for item in tqdm(dataset_items): + class_name = item[class_name_key] + wav_file = item["audio_file"] + + # extract the embedding + embedd = encoder_manager.compute_embedding_from_clip(wav_file) + if encoder_manager.encoder_criterion is not None and map_classid_to_classname is not None: + embedding = torch.FloatTensor(embedd).unsqueeze(0) + if encoder_manager.use_cuda: + embedding = embedding.cuda() + + class_id = encoder_manager.encoder_criterion.softmax.inference(embedding).item() + predicted_label = map_classid_to_classname[str(class_id)] + else: + predicted_label = None + + if class_name is not None and predicted_label is not None: + is_equal = int(class_name == predicted_label) + if class_name not in class_acc_dict: + class_acc_dict[class_name] = [is_equal] + else: + class_acc_dict[class_name].append(is_equal) + else: + raise RuntimeError("Error: class_name or/and predicted_label are None") + + acc_avg = 0 + for key, values in class_acc_dict.items(): + acc = sum(values) / len(values) + print("Class", key, "Accuracy:", acc) + acc_avg += acc + + print("Average Accuracy:", acc_avg / len(class_acc_dict)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="""Compute the accuracy of the encoder.\n\n""" + """ + Example runs: + python TTS/bin/eval_encoder.py emotion_encoder_model.pth emotion_encoder_config.json dataset_config.json + """, + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument("model_path", type=str, help="Path to model checkpoint file.") + parser.add_argument( + "config_path", + type=str, + help="Path to model config file.", + ) + + parser.add_argument( + "config_dataset_path", + type=str, + help="Path to dataset config file.", + ) + parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", 
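+    # NOTE: argparse's type=bool turns any non-empty string into True, so passing "--use_cuda False" still enables CUDA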
default=True) + parser.add_argument("--eval", type=bool, help="compute eval.", default=True) + + args = parser.parse_args() + + c_dataset = load_config(args.config_dataset_path) + + meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=args.eval) + items = meta_data_train + meta_data_eval + + enc_manager = SpeakerManager( + encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda + ) + + compute_encoder_accuracy(items, enc_manager) diff --git a/viXTTS/TTS/bin/extract_tts_spectrograms.py b/viXTTS/TTS/bin/extract_tts_spectrograms.py new file mode 100644 index 0000000000000000000000000000000000000000..c6048626b3cb89daee37b42f757a4ba1e8b7843d --- /dev/null +++ b/viXTTS/TTS/bin/extract_tts_spectrograms.py @@ -0,0 +1,287 @@ +#!/usr/bin/env python3 +"""Extract Mel spectrograms with teacher forcing.""" + +import argparse +import os + +import numpy as np +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from TTS.config import load_config +from TTS.tts.datasets import TTSDataset, load_tts_samples +from TTS.tts.models import setup_model +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import quantize +from TTS.utils.generic_utils import count_parameters + +use_cuda = torch.cuda.is_available() + + +def setup_loader(ap, r, verbose=False): + tokenizer, _ = TTSTokenizer.init_from_config(c) + dataset = TTSDataset( + outputs_per_step=r, + compute_linear_spec=False, + samples=meta_data, + tokenizer=tokenizer, + ap=ap, + batch_group_size=0, + min_text_len=c.min_text_len, + max_text_len=c.max_text_len, + min_audio_len=c.min_audio_len, + max_audio_len=c.max_audio_len, + phoneme_cache_path=c.phoneme_cache_path, + precompute_num_workers=0, + use_noise_augment=False, + verbose=verbose, + speaker_id_mapping=speaker_manager.name_to_id if c.use_speaker_embedding else None, + d_vector_mapping=speaker_manager.embeddings if c.use_d_vector_file else None, + ) + + if c.use_phonemes and c.compute_input_seq_cache: + # precompute phonemes to have a better estimate of sequence lengths. 
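+        # (the precomputation is parallelized across c.num_loader_workers worker processes)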
+ dataset.compute_input_seq(c.num_loader_workers) + dataset.preprocess_samples() + + loader = DataLoader( + dataset, + batch_size=c.batch_size, + shuffle=False, + collate_fn=dataset.collate_fn, + drop_last=False, + sampler=None, + num_workers=c.num_loader_workers, + pin_memory=False, + ) + return loader + + +def set_filename(wav_path, out_path): + wav_file = os.path.basename(wav_path) + file_name = wav_file.split(".")[0] + os.makedirs(os.path.join(out_path, "quant"), exist_ok=True) + os.makedirs(os.path.join(out_path, "mel"), exist_ok=True) + os.makedirs(os.path.join(out_path, "wav_gl"), exist_ok=True) + os.makedirs(os.path.join(out_path, "wav"), exist_ok=True) + wavq_path = os.path.join(out_path, "quant", file_name) + mel_path = os.path.join(out_path, "mel", file_name) + wav_gl_path = os.path.join(out_path, "wav_gl", file_name + ".wav") + wav_path = os.path.join(out_path, "wav", file_name + ".wav") + return file_name, wavq_path, mel_path, wav_gl_path, wav_path + + +def format_data(data): + # setup input data + text_input = data["token_id"] + text_lengths = data["token_id_lengths"] + mel_input = data["mel"] + mel_lengths = data["mel_lengths"] + item_idx = data["item_idxs"] + d_vectors = data["d_vectors"] + speaker_ids = data["speaker_ids"] + attn_mask = data["attns"] + avg_text_length = torch.mean(text_lengths.float()) + avg_spec_length = torch.mean(mel_lengths.float()) + + # dispatch data to GPU + if use_cuda: + text_input = text_input.cuda(non_blocking=True) + text_lengths = text_lengths.cuda(non_blocking=True) + mel_input = mel_input.cuda(non_blocking=True) + mel_lengths = mel_lengths.cuda(non_blocking=True) + if speaker_ids is not None: + speaker_ids = speaker_ids.cuda(non_blocking=True) + if d_vectors is not None: + d_vectors = d_vectors.cuda(non_blocking=True) + if attn_mask is not None: + attn_mask = attn_mask.cuda(non_blocking=True) + return ( + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_ids, + d_vectors, + avg_text_length, + avg_spec_length, + attn_mask, + item_idx, + ) + + +@torch.no_grad() +def inference( + model_name, + model, + ap, + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_ids=None, + d_vectors=None, +): + if model_name == "glow_tts": + speaker_c = None + if speaker_ids is not None: + speaker_c = speaker_ids + elif d_vectors is not None: + speaker_c = d_vectors + outputs = model.inference_with_MAS( + text_input, + text_lengths, + mel_input, + mel_lengths, + aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids}, + ) + model_output = outputs["model_outputs"] + model_output = model_output.detach().cpu().numpy() + + elif "tacotron" in model_name: + aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors} + outputs = model(text_input, text_lengths, mel_input, mel_lengths, aux_input) + postnet_outputs = outputs["model_outputs"] + # normalize tacotron output + if model_name == "tacotron": + mel_specs = [] + postnet_outputs = postnet_outputs.data.cpu().numpy() + for b in range(postnet_outputs.shape[0]): + postnet_output = postnet_outputs[b] + mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T)) + model_output = torch.stack(mel_specs).cpu().numpy() + + elif model_name == "tacotron2": + model_output = postnet_outputs.detach().cpu().numpy() + return model_output + + +def extract_spectrograms( + data_loader, model, ap, output_path, quantize_bits=0, save_audio=False, debug=False, metada_name="metada.txt" +): + model.eval() + export_metadata = [] + for _, data in tqdm(enumerate(data_loader), 
total=len(data_loader)): + # format data + ( + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_ids, + d_vectors, + _, + _, + _, + item_idx, + ) = format_data(data) + + model_output = inference( + c.model.lower(), + model, + ap, + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_ids, + d_vectors, + ) + + for idx in range(text_input.shape[0]): + wav_file_path = item_idx[idx] + wav = ap.load_wav(wav_file_path) + _, wavq_path, mel_path, wav_gl_path, wav_path = set_filename(wav_file_path, output_path) + + # quantize and save wav + if quantize_bits > 0: + wavq = quantize(wav, quantize_bits) + np.save(wavq_path, wavq) + + # save TTS mel + mel = model_output[idx] + mel_length = mel_lengths[idx] + mel = mel[:mel_length, :].T + np.save(mel_path, mel) + + export_metadata.append([wav_file_path, mel_path]) + if save_audio: + ap.save_wav(wav, wav_path) + + if debug: + print("Audio for debug saved at:", wav_gl_path) + wav = ap.inv_melspectrogram(mel) + ap.save_wav(wav, wav_gl_path) + + with open(os.path.join(output_path, metada_name), "w", encoding="utf-8") as f: + for data in export_metadata: + f.write(f"{data[0]}|{data[1]+'.npy'}\n") + + +def main(args): # pylint: disable=redefined-outer-name + # pylint: disable=global-variable-undefined + global meta_data, speaker_manager + + # Audio processor + ap = AudioProcessor(**c.audio) + + # load data instances + meta_data_train, meta_data_eval = load_tts_samples( + c.datasets, eval_split=args.eval, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size + ) + + # use eval and training partitions + meta_data = meta_data_train + meta_data_eval + + # init speaker manager + if c.use_speaker_embedding: + speaker_manager = SpeakerManager(data_items=meta_data) + elif c.use_d_vector_file: + speaker_manager = SpeakerManager(d_vectors_file_path=c.d_vector_file) + else: + speaker_manager = None + + # setup model + model = setup_model(c) + + # restore model + model.load_checkpoint(c, args.checkpoint_path, eval=True) + + if use_cuda: + model.cuda() + + num_params = count_parameters(model) + print("\n > Model has {} parameters".format(num_params), flush=True) + # set r + r = 1 if c.model.lower() == "glow_tts" else model.decoder.r + own_loader = setup_loader(ap, r, verbose=True) + + extract_spectrograms( + own_loader, + model, + ap, + args.output_path, + quantize_bits=args.quantize_bits, + save_audio=args.save_audio, + debug=args.debug, + metada_name="metada.txt", + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True) + parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True) + parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True) + parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug") + parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files") + parser.add_argument("--quantize_bits", type=int, default=0, help="Save quantized audio files if non-zero") + parser.add_argument("--eval", type=bool, help="compute eval.", default=True) + args = parser.parse_args() + + c = load_config(args.config_path) + c.audio.trim_silence = False + main(args) diff --git a/viXTTS/TTS/bin/find_unique_chars.py b/viXTTS/TTS/bin/find_unique_chars.py new file mode 100644 index 0000000000000000000000000000000000000000..ea16974839df6cf9942ef24a5535597940fde5b2 
--- /dev/null +++ b/viXTTS/TTS/bin/find_unique_chars.py @@ -0,0 +1,45 @@ +"""Find all the unique characters in a dataset""" +import argparse +from argparse import RawTextHelpFormatter + +from TTS.config import load_config +from TTS.tts.datasets import load_tts_samples + + +def main(): + # pylint: disable=bad-option-value + parser = argparse.ArgumentParser( + description="""Find all the unique characters or phonemes in a dataset.\n\n""" + """ + Example runs: + + python TTS/bin/find_unique_chars.py --config_path config.json + """, + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument("--config_path", type=str, help="Path to dataset config file.", required=True) + args = parser.parse_args() + + c = load_config(args.config_path) + + # load all datasets + train_items, eval_items = load_tts_samples( + c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size + ) + + items = train_items + eval_items + + texts = "".join(item["text"] for item in items) + chars = set(texts) + lower_chars = filter(lambda c: c.islower(), chars) + chars_force_lower = [c.lower() for c in chars] + chars_force_lower = set(chars_force_lower) + + print(f" > Number of unique characters: {len(chars)}") + print(f" > Unique characters: {''.join(sorted(chars))}") + print(f" > Unique lower characters: {''.join(sorted(lower_chars))}") + print(f" > Unique all forced to lower characters: {''.join(sorted(chars_force_lower))}") + + +if __name__ == "__main__": + main() diff --git a/viXTTS/TTS/bin/find_unique_phonemes.py b/viXTTS/TTS/bin/find_unique_phonemes.py new file mode 100644 index 0000000000000000000000000000000000000000..4bd7a78eef2c4850bca9369def55d68336cd53aa --- /dev/null +++ b/viXTTS/TTS/bin/find_unique_phonemes.py @@ -0,0 +1,74 @@ +"""Find all the unique characters in a dataset""" +import argparse +import multiprocessing +from argparse import RawTextHelpFormatter + +from tqdm.contrib.concurrent import process_map + +from TTS.config import load_config +from TTS.tts.datasets import load_tts_samples +from TTS.tts.utils.text.phonemizers import Gruut + + +def compute_phonemes(item): + text = item["text"] + ph = phonemizer.phonemize(text).replace("|", "") + return set(list(ph)) + + +def main(): + # pylint: disable=W0601 + global c, phonemizer + # pylint: disable=bad-option-value + parser = argparse.ArgumentParser( + description="""Find all the unique characters or phonemes in a dataset.\n\n""" + """ + Example runs: + + python TTS/bin/find_unique_phonemes.py --config_path config.json + """, + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument("--config_path", type=str, help="Path to dataset config file.", required=True) + args = parser.parse_args() + + c = load_config(args.config_path) + + # load all datasets + train_items, eval_items = load_tts_samples( + c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size + ) + items = train_items + eval_items + print("Num items:", len(items)) + + language_list = [item["language"] for item in items] + is_lang_def = all(language_list) + + if not c.phoneme_language or not is_lang_def: + raise ValueError("Phoneme language must be defined in config.") + + if not language_list.count(language_list[0]) == len(language_list): + raise ValueError( + "Currently, just one phoneme language per config file is supported !! Please split the dataset config into different configs and run it individually for each language !!" 
+ ) + + phonemizer = Gruut(language=language_list[0], keep_puncs=True) + + phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15) + phones = [] + for ph in phonemes: + phones.extend(ph) + + phones = set(phones) + lower_phones = filter(lambda c: c.islower(), phones) + phones_force_lower = [c.lower() for c in phones] + phones_force_lower = set(phones_force_lower) + + print(f" > Number of unique phonemes: {len(phones)}") + print(f" > Unique phonemes: {''.join(sorted(phones))}") + print(f" > Unique lower phonemes: {''.join(sorted(lower_phones))}") + print(f" > Unique all forced to lower phonemes: {''.join(sorted(phones_force_lower))}") + + +if __name__ == "__main__": + main() diff --git a/viXTTS/TTS/bin/remove_silence_using_vad.py b/viXTTS/TTS/bin/remove_silence_using_vad.py new file mode 100644 index 0000000000000000000000000000000000000000..a1eaf4c9a713e2e72a9e8434397ac430ff10aef1 --- /dev/null +++ b/viXTTS/TTS/bin/remove_silence_using_vad.py @@ -0,0 +1,124 @@ +import argparse +import glob +import multiprocessing +import os +import pathlib + +import torch +from tqdm import tqdm + +from TTS.utils.vad import get_vad_model_and_utils, remove_silence + +torch.set_num_threads(1) + + +def adjust_path_and_remove_silence(audio_path): + output_path = audio_path.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, "")) + # ignore if the file exists + if os.path.exists(output_path) and not args.force: + return output_path, False + + # create all directory structure + pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True) + # remove the silence and save the audio + output_path, is_speech = remove_silence( + model_and_utils, + audio_path, + output_path, + trim_just_beginning_and_end=args.trim_just_beginning_and_end, + use_cuda=args.use_cuda, + ) + return output_path, is_speech + + +def preprocess_audios(): + files = sorted(glob.glob(os.path.join(args.input_dir, args.glob), recursive=True)) + print("> Number of files: ", len(files)) + if not args.force: + print("> Ignoring files that already exist in the output idrectory.") + + if args.trim_just_beginning_and_end: + print("> Trimming just the beginning and the end with nonspeech parts.") + else: + print("> Trimming all nonspeech parts.") + + filtered_files = [] + if files: + # create threads + # num_threads = multiprocessing.cpu_count() + # process_map(adjust_path_and_remove_silence, files, max_workers=num_threads, chunksize=15) + + if args.num_processes > 1: + with multiprocessing.Pool(processes=args.num_processes) as pool: + results = list( + tqdm( + pool.imap_unordered(adjust_path_and_remove_silence, files), + total=len(files), + desc="Processing audio files", + ) + ) + for output_path, is_speech in results: + if not is_speech: + filtered_files.append(output_path) + else: + for f in tqdm(files): + output_path, is_speech = adjust_path_and_remove_silence(f) + if not is_speech: + filtered_files.append(output_path) + + # write files that do not have speech + with open(os.path.join(args.output_dir, "filtered_files.txt"), "w", encoding="utf-8") as f: + for file in filtered_files: + f.write(str(file) + "\n") + else: + print("> No files Found !") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="python TTS/bin/remove_silence_using_vad.py -i=VCTK-Corpus/ -o=VCTK-Corpus-removed-silence/ -g=wav48_silence_trimmed/*/*_mic1.flac --trim_just_beginning_and_end True" + ) + parser.add_argument("-i", "--input_dir", type=str, help="Dataset root dir", 
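+    # when --output_dir is left empty it falls back to --input_dir (set just before preprocessing below)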
required=True) + parser.add_argument("-o", "--output_dir", type=str, help="Output Dataset dir", default="") + parser.add_argument("-f", "--force", default=False, action="store_true", help="Force the replace of exists files") + parser.add_argument( + "-g", + "--glob", + type=str, + default="**/*.wav", + help="path in glob format for acess wavs from input_dir. ex: wav48/*/*.wav", + ) + parser.add_argument( + "-t", + "--trim_just_beginning_and_end", + type=bool, + default=True, + help="If True this script will trim just the beginning and end nonspeech parts. If False all nonspeech parts will be trim. Default True", + ) + parser.add_argument( + "-c", + "--use_cuda", + type=bool, + default=False, + help="If True use cuda", + ) + parser.add_argument( + "--use_onnx", + type=bool, + default=False, + help="If True use onnx", + ) + parser.add_argument( + "--num_processes", + type=int, + default=1, + help="Number of processes to use", + ) + args = parser.parse_args() + + if args.output_dir == "": + args.output_dir = args.input_dir + + # load the model and utils + model_and_utils = get_vad_model_and_utils(use_cuda=args.use_cuda, use_onnx=args.use_onnx) + preprocess_audios() diff --git a/viXTTS/TTS/bin/resample.py b/viXTTS/TTS/bin/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..a3f28485d1fb235ab0d521ee30318c64b48fbd5a --- /dev/null +++ b/viXTTS/TTS/bin/resample.py @@ -0,0 +1,90 @@ +import argparse +import glob +import os +from argparse import RawTextHelpFormatter +from multiprocessing import Pool +from shutil import copytree + +import librosa +import soundfile as sf +from tqdm import tqdm + + +def resample_file(func_args): + filename, output_sr = func_args + y, sr = librosa.load(filename, sr=output_sr) + sf.write(filename, y, sr) + + +def resample_files(input_dir, output_sr, output_dir=None, file_ext="wav", n_jobs=10): + if output_dir: + print("Recursively copying the input folder...") + copytree(input_dir, output_dir) + input_dir = output_dir + + print("Resampling the audio files...") + audio_files = glob.glob(os.path.join(input_dir, f"**/*.{file_ext}"), recursive=True) + print(f"Found {len(audio_files)} files...") + audio_files = list(zip(audio_files, len(audio_files) * [output_sr])) + with Pool(processes=n_jobs) as p: + with tqdm(total=len(audio_files)) as pbar: + for _, _ in enumerate(p.imap_unordered(resample_file, audio_files)): + pbar.update() + + print("Done !") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="""Resample a folder recusively with librosa + Can be used in place or create a copy of the folder as an output.\n\n + Example run: + python TTS/bin/resample.py + --input_dir /root/LJSpeech-1.1/ + --output_sr 22050 + --output_dir /root/resampled_LJSpeech-1.1/ + --file_ext wav + --n_jobs 24 + """, + formatter_class=RawTextHelpFormatter, + ) + + parser.add_argument( + "--input_dir", + type=str, + default=None, + required=True, + help="Path of the folder containing the audio files to resample", + ) + + parser.add_argument( + "--output_sr", + type=int, + default=22050, + required=False, + help="Samlple rate to which the audio files should be resampled", + ) + + parser.add_argument( + "--output_dir", + type=str, + default=None, + required=False, + help="Path of the destination folder. 
If not defined, the operation is done in place", + ) + + parser.add_argument( + "--file_ext", + type=str, + default="wav", + required=False, + help="Extension of the audio files to resample", + ) + + parser.add_argument( + "--n_jobs", type=int, default=None, help="Number of threads to use, by default it uses all cores" + ) + + args = parser.parse_args() + + resample_files(args.input_dir, args.output_sr, args.output_dir, args.file_ext, args.n_jobs) diff --git a/viXTTS/TTS/bin/synthesize.py b/viXTTS/TTS/bin/synthesize.py new file mode 100644 index 0000000000000000000000000000000000000000..b86252ab676bcc1acfab1f6616153ea16a4528e6 --- /dev/null +++ b/viXTTS/TTS/bin/synthesize.py @@ -0,0 +1,494 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import argparse +import contextlib +import sys +from argparse import RawTextHelpFormatter + +# pylint: disable=redefined-outer-name, unused-argument +from pathlib import Path + +description = """ +Synthesize speech on command line. + +You can either use your trained model or choose a model from the provided list. + +If you don't specify any models, then it uses LJSpeech based English model. + +#### Single Speaker Models + +- List provided models: + + ``` + $ tts --list_models + ``` + +- Get model info (for both tts_models and vocoder_models): + + - Query by type/name: + The model_info_by_name uses the name as it from the --list_models. + ``` + $ tts --model_info_by_name "///" + ``` + For example: + ``` + $ tts --model_info_by_name tts_models/tr/common-voice/glow-tts + $ tts --model_info_by_name vocoder_models/en/ljspeech/hifigan_v2 + ``` + - Query by type/idx: + The model_query_idx uses the corresponding idx from --list_models. + + ``` + $ tts --model_info_by_idx "/" + ``` + + For example: + + ``` + $ tts --model_info_by_idx tts_models/3 + ``` + + - Query info for model info by full name: + ``` + $ tts --model_info_by_name "///" + ``` + +- Run TTS with default models: + + ``` + $ tts --text "Text for TTS" --out_path output/path/speech.wav + ``` + +- Run TTS and pipe out the generated TTS wav file data: + + ``` + $ tts --text "Text for TTS" --pipe_out --out_path output/path/speech.wav | aplay + ``` + +- Run a TTS model with its default vocoder model: + + ``` + $ tts --text "Text for TTS" --model_name "///" --out_path output/path/speech.wav + ``` + + For example: + + ``` + $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --out_path output/path/speech.wav + ``` + +- Run with specific TTS and vocoder models from the list: + + ``` + $ tts --text "Text for TTS" --model_name "///" --vocoder_name "///" --out_path output/path/speech.wav + ``` + + For example: + + ``` + $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --vocoder_name "vocoder_models/en/ljspeech/univnet" --out_path output/path/speech.wav + ``` + +- Run your own TTS model (Using Griffin-Lim Vocoder): + + ``` + $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav + ``` + +- Run your own TTS and Vocoder models: + + ``` + $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav + --vocoder_path path/to/vocoder.pth --vocoder_config_path path/to/vocoder_config.json + ``` + +#### Multi-speaker Models + +- List the available speakers and choose a among them: + + ``` + $ tts --model_name "//" --list_speaker_idxs + ``` + +- Run the multi-speaker TTS model with the target speaker ID: + + ``` + $ tts --text "Text 
for TTS." --out_path output/path/speech.wav --model_name "//" --speaker_idx + ``` + +- Run your own multi-speaker TTS model: + + ``` + $ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/model.pth --config_path path/to/config.json --speakers_file_path path/to/speaker.json --speaker_idx + ``` + +### Voice Conversion Models + +``` +$ tts --out_path output/path/speech.wav --model_name "//" --source_wav --target_wav +``` +""" + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + if v.lower() in ("no", "false", "f", "n", "0"): + return False + raise argparse.ArgumentTypeError("Boolean value expected.") + + +def main(): + parser = argparse.ArgumentParser( + description=description.replace(" ```\n", ""), + formatter_class=RawTextHelpFormatter, + ) + + parser.add_argument( + "--list_models", + type=str2bool, + nargs="?", + const=True, + default=False, + help="list available pre-trained TTS and vocoder models.", + ) + + parser.add_argument( + "--model_info_by_idx", + type=str, + default=None, + help="model info using query format: /", + ) + + parser.add_argument( + "--model_info_by_name", + type=str, + default=None, + help="model info using query format: ///", + ) + + parser.add_argument("--text", type=str, default=None, help="Text to generate speech.") + + # Args for running pre-trained TTS models. + parser.add_argument( + "--model_name", + type=str, + default="tts_models/en/ljspeech/tacotron2-DDC", + help="Name of one of the pre-trained TTS models in format //", + ) + parser.add_argument( + "--vocoder_name", + type=str, + default=None, + help="Name of one of the pre-trained vocoder models in format //", + ) + + # Args for running custom models + parser.add_argument("--config_path", default=None, type=str, help="Path to model config file.") + parser.add_argument( + "--model_path", + type=str, + default=None, + help="Path to model file.", + ) + parser.add_argument( + "--out_path", + type=str, + default="tts_output.wav", + help="Output wav file path.", + ) + parser.add_argument("--use_cuda", type=bool, help="Run model on CUDA.", default=False) + parser.add_argument("--device", type=str, help="Device to run model on.", default="cpu") + parser.add_argument( + "--vocoder_path", + type=str, + help="Path to vocoder model file. If it is not defined, model uses GL as vocoder. 
Please make sure that you installed vocoder library before (WaveRNN).", + default=None, + ) + parser.add_argument("--vocoder_config_path", type=str, help="Path to vocoder model config file.", default=None) + parser.add_argument( + "--encoder_path", + type=str, + help="Path to speaker encoder model file.", + default=None, + ) + parser.add_argument("--encoder_config_path", type=str, help="Path to speaker encoder config file.", default=None) + parser.add_argument( + "--pipe_out", + help="stdout the generated TTS wav file for shell pipe.", + type=str2bool, + nargs="?", + const=True, + default=False, + ) + + # args for multi-speaker synthesis + parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None) + parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None) + parser.add_argument( + "--speaker_idx", + type=str, + help="Target speaker ID for a multi-speaker TTS model.", + default=None, + ) + parser.add_argument( + "--language_idx", + type=str, + help="Target language ID for a multi-lingual TTS model.", + default=None, + ) + parser.add_argument( + "--speaker_wav", + nargs="+", + help="wav file(s) to condition a multi-speaker TTS model with a Speaker Encoder. You can give multiple file paths. The d_vectors is computed as their average.", + default=None, + ) + parser.add_argument("--gst_style", help="Wav path file for GST style reference.", default=None) + parser.add_argument( + "--capacitron_style_wav", type=str, help="Wav path file for Capacitron prosody reference.", default=None + ) + parser.add_argument("--capacitron_style_text", type=str, help="Transcription of the reference.", default=None) + parser.add_argument( + "--list_speaker_idxs", + help="List available speaker ids for the defined multi-speaker model.", + type=str2bool, + nargs="?", + const=True, + default=False, + ) + parser.add_argument( + "--list_language_idxs", + help="List available language ids for the defined multi-lingual model.", + type=str2bool, + nargs="?", + const=True, + default=False, + ) + # aux args + parser.add_argument( + "--save_spectogram", + type=bool, + help="If true save raw spectogram for further (vocoder) processing in out_path.", + default=False, + ) + parser.add_argument( + "--reference_wav", + type=str, + help="Reference wav file to convert in the voice of the speaker_idx or speaker_wav", + default=None, + ) + parser.add_argument( + "--reference_speaker_idx", + type=str, + help="speaker ID of the reference_wav speaker (If not provided the embedding will be computed using the Speaker Encoder).", + default=None, + ) + parser.add_argument( + "--progress_bar", + type=str2bool, + help="If true shows a progress bar for the model download. 
Defaults to True", + default=True, + ) + + # voice conversion args + parser.add_argument( + "--source_wav", + type=str, + default=None, + help="Original audio file to convert in the voice of the target_wav", + ) + parser.add_argument( + "--target_wav", + type=str, + default=None, + help="Target audio file to convert in the voice of the source_wav", + ) + + parser.add_argument( + "--voice_dir", + type=str, + default=None, + help="Voice dir for tortoise model", + ) + + args = parser.parse_args() + + # print the description if either text or list_models is not set + check_args = [ + args.text, + args.list_models, + args.list_speaker_idxs, + args.list_language_idxs, + args.reference_wav, + args.model_info_by_idx, + args.model_info_by_name, + args.source_wav, + args.target_wav, + ] + if not any(check_args): + parser.parse_args(["-h"]) + + pipe_out = sys.stdout if args.pipe_out else None + + with contextlib.redirect_stdout(None if args.pipe_out else sys.stdout): + # Late-import to make things load faster + from TTS.api import TTS + from TTS.utils.manage import ModelManager + from TTS.utils.synthesizer import Synthesizer + + # load model manager + path = Path(__file__).parent / "../.models.json" + manager = ModelManager(path, progress_bar=args.progress_bar) + api = TTS() + + tts_path = None + tts_config_path = None + speakers_file_path = None + language_ids_file_path = None + vocoder_path = None + vocoder_config_path = None + encoder_path = None + encoder_config_path = None + vc_path = None + vc_config_path = None + model_dir = None + + # CASE1 #list : list pre-trained TTS models + if args.list_models: + manager.list_models() + sys.exit() + + # CASE2 #info : model info for pre-trained TTS models + if args.model_info_by_idx: + model_query = args.model_info_by_idx + manager.model_info_by_idx(model_query) + sys.exit() + + if args.model_info_by_name: + model_query_full_name = args.model_info_by_name + manager.model_info_by_full_name(model_query_full_name) + sys.exit() + + # CASE3: load pre-trained model paths + if args.model_name is not None and not args.model_path: + model_path, config_path, model_item = manager.download_model(args.model_name) + # tts model + if model_item["model_type"] == "tts_models": + tts_path = model_path + tts_config_path = config_path + if "default_vocoder" in model_item: + args.vocoder_name = ( + model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name + ) + + # voice conversion model + if model_item["model_type"] == "voice_conversion_models": + vc_path = model_path + vc_config_path = config_path + + # tts model with multiple files to be loaded from the directory path + if model_item.get("author", None) == "fairseq" or isinstance(model_item["model_url"], list): + model_dir = model_path + tts_path = None + tts_config_path = None + args.vocoder_name = None + + # load vocoder + if args.vocoder_name is not None and not args.vocoder_path: + vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name) + + # CASE4: set custom model paths + if args.model_path is not None: + tts_path = args.model_path + tts_config_path = args.config_path + speakers_file_path = args.speakers_file_path + language_ids_file_path = args.language_ids_file_path + + if args.vocoder_path is not None: + vocoder_path = args.vocoder_path + vocoder_config_path = args.vocoder_config_path + + if args.encoder_path is not None: + encoder_path = args.encoder_path + encoder_config_path = args.encoder_config_path + + device = args.device + if args.use_cuda: + device = "cuda" 
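+            # --use_cuda takes precedence over --device when both are given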
+ + # load models + synthesizer = Synthesizer( + tts_path, + tts_config_path, + speakers_file_path, + language_ids_file_path, + vocoder_path, + vocoder_config_path, + encoder_path, + encoder_config_path, + vc_path, + vc_config_path, + model_dir, + args.voice_dir, + ).to(device) + + # query speaker ids of a multi-speaker model. + if args.list_speaker_idxs: + print( + " > Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model." + ) + print(synthesizer.tts_model.speaker_manager.name_to_id) + return + + # query langauge ids of a multi-lingual model. + if args.list_language_idxs: + print( + " > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model." + ) + print(synthesizer.tts_model.language_manager.name_to_id) + return + + # check the arguments against a multi-speaker model. + if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav): + print( + " [!] Looks like you use a multi-speaker model. Define `--speaker_idx` to " + "select the target speaker. You can list the available speakers for this model by `--list_speaker_idxs`." + ) + return + + # RUN THE SYNTHESIS + if args.text: + print(" > Text: {}".format(args.text)) + + # kick it + if tts_path is not None: + wav = synthesizer.tts( + args.text, + speaker_name=args.speaker_idx, + language_name=args.language_idx, + speaker_wav=args.speaker_wav, + reference_wav=args.reference_wav, + style_wav=args.capacitron_style_wav, + style_text=args.capacitron_style_text, + reference_speaker_name=args.reference_speaker_idx, + ) + elif vc_path is not None: + wav = synthesizer.voice_conversion( + source_wav=args.source_wav, + target_wav=args.target_wav, + ) + elif model_dir is not None: + wav = synthesizer.tts( + args.text, speaker_name=args.speaker_idx, language_name=args.language_idx, speaker_wav=args.speaker_wav + ) + + # save the results + print(" > Saving output to {}".format(args.out_path)) + synthesizer.save_wav(wav, args.out_path, pipe_out=pipe_out) + + +if __name__ == "__main__": + main() diff --git a/viXTTS/TTS/bin/train_encoder.py b/viXTTS/TTS/bin/train_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a32ad00f56ee730ff1abe5770c802b01246aed06 --- /dev/null +++ b/viXTTS/TTS/bin/train_encoder.py @@ -0,0 +1,332 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import os +import sys +import time +import traceback + +import torch +from torch.utils.data import DataLoader +from trainer.io import copy_model_files, save_best_model, save_checkpoint +from trainer.torch import NoamLR +from trainer.trainer_utils import get_optimizer + +from TTS.encoder.dataset import EncoderDataset +from TTS.encoder.utils.generic_utils import setup_encoder_model +from TTS.encoder.utils.training import init_training +from TTS.encoder.utils.visual import plot_embeddings +from TTS.tts.datasets import load_tts_samples +from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import count_parameters, remove_experiment_folder +from TTS.utils.samplers import PerfectBatchSampler +from TTS.utils.training import check_update + +torch.backends.cudnn.enabled = True +torch.backends.cudnn.benchmark = True +torch.manual_seed(54321) +use_cuda = torch.cuda.is_available() +num_gpus = torch.cuda.device_count() +print(" > Using CUDA: ", use_cuda) +print(" > Number of GPUs: ", num_gpus) + + +def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False): + num_utter_per_class = c.num_utter_per_class if not 
is_val else c.eval_num_utter_per_class + num_classes_in_batch = c.num_classes_in_batch if not is_val else c.eval_num_classes_in_batch + + dataset = EncoderDataset( + c, + ap, + meta_data_eval if is_val else meta_data_train, + voice_len=c.voice_len, + num_utter_per_class=num_utter_per_class, + num_classes_in_batch=num_classes_in_batch, + verbose=verbose, + augmentation_config=c.audio_augmentation if not is_val else None, + use_torch_spec=c.model_params.get("use_torch_spec", False), + ) + # get classes list + classes = dataset.get_class_list() + + sampler = PerfectBatchSampler( + dataset.items, + classes, + batch_size=num_classes_in_batch * num_utter_per_class, # total batch size + num_classes_in_batch=num_classes_in_batch, + num_gpus=1, + shuffle=not is_val, + drop_last=True, + ) + + if len(classes) < num_classes_in_batch: + if is_val: + raise RuntimeError( + f"config.eval_num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Eval dataset) !" + ) + raise RuntimeError( + f"config.num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Train dataset) !" + ) + + # set the classes to avoid get wrong class_id when the number of training and eval classes are not equal + if is_val: + dataset.set_classes(train_classes) + + loader = DataLoader( + dataset, + num_workers=c.num_loader_workers, + batch_sampler=sampler, + collate_fn=dataset.collate_fn, + ) + + return loader, classes, dataset.get_map_classid_to_classname() + + +def evaluation(model, criterion, data_loader, global_step): + eval_loss = 0 + for _, data in enumerate(data_loader): + with torch.no_grad(): + # setup input data + inputs, labels = data + + # agroup samples of each class in the batch. perfect sampler produces [3,2,1,3,2,1] we need [3,3,2,2,1,1] + labels = torch.transpose( + labels.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch), 0, 1 + ).reshape(labels.shape) + inputs = torch.transpose( + inputs.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch, -1), 0, 1 + ).reshape(inputs.shape) + + # dispatch data to GPU + if use_cuda: + inputs = inputs.cuda(non_blocking=True) + labels = labels.cuda(non_blocking=True) + + # forward pass model + outputs = model(inputs) + + # loss computation + loss = criterion( + outputs.view(c.eval_num_classes_in_batch, outputs.shape[0] // c.eval_num_classes_in_batch, -1), labels + ) + + eval_loss += loss.item() + + eval_avg_loss = eval_loss / len(data_loader) + # save stats + dashboard_logger.eval_stats(global_step, {"loss": eval_avg_loss}) + # plot the last batch in the evaluation + figures = { + "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch), + } + dashboard_logger.eval_figures(global_step, figures) + return eval_avg_loss + + +def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, global_step): + model.train() + best_loss = {"train_loss": None, "eval_loss": float("inf")} + avg_loader_time = 0 + end_time = time.time() + for epoch in range(c.epochs): + tot_loss = 0 + epoch_time = 0 + for _, data in enumerate(data_loader): + start_time = time.time() + + # setup input data + inputs, labels = data + # agroup samples of each class in the batch. 
perfect sampler produces [3,2,1,3,2,1] we need [3,3,2,2,1,1] + labels = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape( + labels.shape + ) + inputs = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape( + inputs.shape + ) + # ToDo: move it to a unit test + # labels_converted = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape) + # inputs_converted = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape) + # idx = 0 + # for j in range(0, c.num_classes_in_batch, 1): + # for i in range(j, len(labels), c.num_classes_in_batch): + # if not torch.all(labels[i].eq(labels_converted[idx])) or not torch.all(inputs[i].eq(inputs_converted[idx])): + # print("Invalid") + # print(labels) + # exit() + # idx += 1 + # labels = labels_converted + # inputs = inputs_converted + + loader_time = time.time() - end_time + global_step += 1 + + # setup lr + if c.lr_decay: + scheduler.step() + optimizer.zero_grad() + + # dispatch data to GPU + if use_cuda: + inputs = inputs.cuda(non_blocking=True) + labels = labels.cuda(non_blocking=True) + + # forward pass model + outputs = model(inputs) + + # loss computation + loss = criterion( + outputs.view(c.num_classes_in_batch, outputs.shape[0] // c.num_classes_in_batch, -1), labels + ) + loss.backward() + grad_norm, _ = check_update(model, c.grad_clip) + optimizer.step() + + step_time = time.time() - start_time + epoch_time += step_time + + # acumulate the total epoch loss + tot_loss += loss.item() + + # Averaged Loader Time + num_loader_workers = c.num_loader_workers if c.num_loader_workers > 0 else 1 + avg_loader_time = ( + 1 / num_loader_workers * loader_time + (num_loader_workers - 1) / num_loader_workers * avg_loader_time + if avg_loader_time != 0 + else loader_time + ) + current_lr = optimizer.param_groups[0]["lr"] + + if global_step % c.steps_plot_stats == 0: + # Plot Training Epoch Stats + train_stats = { + "loss": loss.item(), + "lr": current_lr, + "grad_norm": grad_norm, + "step_time": step_time, + "avg_loader_time": avg_loader_time, + } + dashboard_logger.train_epoch_stats(global_step, train_stats) + figures = { + "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch), + } + dashboard_logger.train_figures(global_step, figures) + + if global_step % c.print_step == 0: + print( + " | > Step:{} Loss:{:.5f} GradNorm:{:.5f} " + "StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format( + global_step, loss.item(), grad_norm, step_time, loader_time, avg_loader_time, current_lr + ), + flush=True, + ) + + if global_step % c.save_step == 0: + # save model + save_checkpoint( + c, model, optimizer, None, global_step, epoch, OUT_PATH, criterion=criterion.state_dict() + ) + + end_time = time.time() + + print("") + print( + ">>> Epoch:{} AvgLoss: {:.5f} GradNorm:{:.5f} " + "EpochTime:{:.2f} AvGLoaderTime:{:.2f} ".format( + epoch, tot_loss / len(data_loader), grad_norm, epoch_time, avg_loader_time + ), + flush=True, + ) + # evaluation + if c.run_eval: + model.eval() + eval_loss = evaluation(model, criterion, eval_data_loader, global_step) + print("\n\n") + print("--> EVAL PERFORMANCE") + print( + " | > Epoch:{} AvgLoss: {:.5f} ".format(epoch, eval_loss), + flush=True, + ) + # save the best checkpoint + best_loss = save_best_model( + {"train_loss": None, "eval_loss": eval_loss}, + best_loss, + c, + model, + optimizer, + None, + global_step, + epoch, 
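+                    # "train_loss" is passed as None above, so best-model selection relies on eval_loss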
+ OUT_PATH, + criterion=criterion.state_dict(), + ) + model.train() + + return best_loss, global_step + + +def main(args): # pylint: disable=redefined-outer-name + # pylint: disable=global-variable-undefined + global meta_data_train + global meta_data_eval + global train_classes + + ap = AudioProcessor(**c.audio) + model = setup_encoder_model(c) + + optimizer = get_optimizer(c.optimizer, c.optimizer_params, c.lr, model) + + # pylint: disable=redefined-outer-name + meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True) + + train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False, verbose=True) + if c.run_eval: + eval_data_loader, _, _ = setup_loader(ap, is_val=True, verbose=True) + else: + eval_data_loader = None + + num_classes = len(train_classes) + criterion = model.get_criterion(c, num_classes) + + if c.loss == "softmaxproto" and c.model != "speaker_encoder": + c.map_classid_to_classname = map_classid_to_classname + copy_model_files(c, OUT_PATH, new_fields={}) + + if args.restore_path: + criterion, args.restore_step = model.load_checkpoint( + c, args.restore_path, eval=False, use_cuda=use_cuda, criterion=criterion + ) + print(" > Model restored from step %d" % args.restore_step, flush=True) + else: + args.restore_step = 0 + + if c.lr_decay: + scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1) + else: + scheduler = None + + num_params = count_parameters(model) + print("\n > Model has {} parameters".format(num_params), flush=True) + + if use_cuda: + model = model.cuda() + criterion.cuda() + + global_step = args.restore_step + _, global_step = train(model, optimizer, scheduler, criterion, train_data_loader, eval_data_loader, global_step) + + +if __name__ == "__main__": + args, c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training() + + try: + main(args) + except KeyboardInterrupt: + remove_experiment_folder(OUT_PATH) + try: + sys.exit(0) + except SystemExit: + os._exit(0) # pylint: disable=protected-access + except Exception: # pylint: disable=broad-except + remove_experiment_folder(OUT_PATH) + traceback.print_exc() + sys.exit(1) diff --git a/viXTTS/TTS/bin/train_tts.py b/viXTTS/TTS/bin/train_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..bdb4f6f69122a4a9aa4e07695f1816ce9727f323 --- /dev/null +++ b/viXTTS/TTS/bin/train_tts.py @@ -0,0 +1,71 @@ +import os +from dataclasses import dataclass, field + +from trainer import Trainer, TrainerArgs + +from TTS.config import load_config, register_config +from TTS.tts.datasets import load_tts_samples +from TTS.tts.models import setup_model + + +@dataclass +class TrainTTSArgs(TrainerArgs): + config_path: str = field(default=None, metadata={"help": "Path to the config file."}) + + +def main(): + """Run `tts` model training directly by a `config.json` file.""" + # init trainer args + train_args = TrainTTSArgs() + parser = train_args.init_argparse(arg_prefix="") + + # override trainer args from comman-line args + args, config_overrides = parser.parse_known_args() + train_args.parse_args(args) + + # load config.json and register + if args.config_path or args.continue_path: + if args.config_path: + # init from a file + config = load_config(args.config_path) + if len(config_overrides) > 0: + config.parse_known_args(config_overrides, relaxed_parser=True) + elif args.continue_path: + # continue from a prev experiment + config = load_config(os.path.join(args.continue_path, "config.json")) + if len(config_overrides) > 0: 
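+                # leftover command-line arguments are applied as overrides on top of the restored config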
+ config.parse_known_args(config_overrides, relaxed_parser=True) + else: + # init from console args + from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel + + config_base = BaseTrainingConfig() + config_base.parse_known_args(config_overrides) + config = register_config(config_base.model)() + + # load training samples + train_samples, eval_samples = load_tts_samples( + config.datasets, + eval_split=True, + eval_split_max_size=config.eval_split_max_size, + eval_split_size=config.eval_split_size, + ) + + # init the model from config + model = setup_model(config, train_samples + eval_samples) + + # init the trainer and 🚀 + trainer = Trainer( + train_args, + model.config, + config.output_path, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + parse_command_line_args=False, + ) + trainer.fit() + + +if __name__ == "__main__": + main() diff --git a/viXTTS/TTS/bin/train_vocoder.py b/viXTTS/TTS/bin/train_vocoder.py new file mode 100644 index 0000000000000000000000000000000000000000..32ecd7bdc3652b3683be846bdd9518e937aee904 --- /dev/null +++ b/viXTTS/TTS/bin/train_vocoder.py @@ -0,0 +1,77 @@ +import os +from dataclasses import dataclass, field + +from trainer import Trainer, TrainerArgs + +from TTS.config import load_config, register_config +from TTS.utils.audio import AudioProcessor +from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data +from TTS.vocoder.models import setup_model + + +@dataclass +class TrainVocoderArgs(TrainerArgs): + config_path: str = field(default=None, metadata={"help": "Path to the config file."}) + + +def main(): + """Run `tts` model training directly by a `config.json` file.""" + # init trainer args + train_args = TrainVocoderArgs() + parser = train_args.init_argparse(arg_prefix="") + + # override trainer args from comman-line args + args, config_overrides = parser.parse_known_args() + train_args.parse_args(args) + + # load config.json and register + if args.config_path or args.continue_path: + if args.config_path: + # init from a file + config = load_config(args.config_path) + if len(config_overrides) > 0: + config.parse_known_args(config_overrides, relaxed_parser=True) + elif args.continue_path: + # continue from a prev experiment + config = load_config(os.path.join(args.continue_path, "config.json")) + if len(config_overrides) > 0: + config.parse_known_args(config_overrides, relaxed_parser=True) + else: + # init from console args + from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel + + config_base = BaseTrainingConfig() + config_base.parse_known_args(config_overrides) + config = register_config(config_base.model)() + + # load training samples + if "feature_path" in config and config.feature_path: + # load pre-computed features + print(f" > Loading features from: {config.feature_path}") + eval_samples, train_samples = load_wav_feat_data(config.data_path, config.feature_path, config.eval_split_size) + else: + # load data raw wav files + eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size) + + # setup audio processor + ap = AudioProcessor(**config.audio) + + # init the model from config + model = setup_model(config) + + # init the trainer and 🚀 + trainer = Trainer( + train_args, + config, + config.output_path, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + training_assets={"audio_processor": ap}, + parse_command_line_args=False, + ) + trainer.fit() + + +if __name__ == 
"__main__": + main() diff --git a/viXTTS/TTS/bin/tune_wavegrad.py b/viXTTS/TTS/bin/tune_wavegrad.py new file mode 100644 index 0000000000000000000000000000000000000000..09582cea7c7962b098efcde5754a02573d18264a --- /dev/null +++ b/viXTTS/TTS/bin/tune_wavegrad.py @@ -0,0 +1,103 @@ +"""Search a good noise schedule for WaveGrad for a given number of inference iterations""" +import argparse +from itertools import product as cartesian_product + +import numpy as np +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from TTS.config import load_config +from TTS.utils.audio import AudioProcessor +from TTS.vocoder.datasets.preprocess import load_wav_data +from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset +from TTS.vocoder.models import setup_model + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, help="Path to model checkpoint.") + parser.add_argument("--config_path", type=str, help="Path to model config file.") + parser.add_argument("--data_path", type=str, help="Path to data directory.") + parser.add_argument("--output_path", type=str, help="path for output file including file name and extension.") + parser.add_argument( + "--num_iter", + type=int, + help="Number of model inference iterations that you like to optimize noise schedule for.", + ) + parser.add_argument("--use_cuda", action="store_true", help="enable CUDA.") + parser.add_argument("--num_samples", type=int, default=1, help="Number of datasamples used for inference.") + parser.add_argument( + "--search_depth", + type=int, + default=3, + help="Search granularity. Increasing this increases the run-time exponentially.", + ) + + # load config + args = parser.parse_args() + config = load_config(args.config_path) + + # setup audio processor + ap = AudioProcessor(**config.audio) + + # load dataset + _, train_data = load_wav_data(args.data_path, 0) + train_data = train_data[: args.num_samples] + dataset = WaveGradDataset( + ap=ap, + items=train_data, + seq_len=-1, + hop_len=ap.hop_length, + pad_short=config.pad_short, + conv_pad=config.conv_pad, + is_training=True, + return_segments=False, + use_noise_augment=False, + use_cache=False, + verbose=True, + ) + loader = DataLoader( + dataset, + batch_size=1, + shuffle=False, + collate_fn=dataset.collate_full_clips, + drop_last=False, + num_workers=config.num_loader_workers, + pin_memory=False, + ) + + # setup the model + model = setup_model(config) + if args.use_cuda: + model.cuda() + + # setup optimization parameters + base_values = sorted(10 * np.random.uniform(size=args.search_depth)) + print(f" > base values: {base_values}") + exponents = 10 ** np.linspace(-6, -1, num=args.num_iter) + best_error = float("inf") + best_schedule = None # pylint: disable=C0103 + total_search_iter = len(base_values) ** args.num_iter + for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter): + beta = exponents * base + model.compute_noise_level(beta) + for data in loader: + mel, audio = data + y_hat = model.inference(mel.cuda() if args.use_cuda else mel) + + if args.use_cuda: + y_hat = y_hat.cpu() + y_hat = y_hat.numpy() + + mel_hat = [] + for i in range(y_hat.shape[0]): + m = ap.melspectrogram(y_hat[i, 0])[:, :-1] + mel_hat.append(torch.from_numpy(m)) + + mel_hat = torch.stack(mel_hat) + mse = torch.sum((mel - mel_hat) ** 2).mean() + if mse.item() < best_error: + best_error = mse.item() + best_schedule = {"beta": beta} + print(f" > Found a better schedule. 
- MSE: {mse.item()}") + np.save(args.output_path, best_schedule) diff --git a/viXTTS/TTS/config/__init__.py b/viXTTS/TTS/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c5a6dd68e24e7f2c2c67504bee4c19f1eb6660a1 --- /dev/null +++ b/viXTTS/TTS/config/__init__.py @@ -0,0 +1,135 @@ +import json +import os +import re +from typing import Dict + +import fsspec +import yaml +from coqpit import Coqpit + +from TTS.config.shared_configs import * +from TTS.utils.generic_utils import find_module + + +def read_json_with_comments(json_path): + """for backward compat.""" + # fallback to json + with fsspec.open(json_path, "r", encoding="utf-8") as f: + input_str = f.read() + # handle comments but not urls with // + input_str = re.sub(r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str) + return json.loads(input_str) + +def register_config(model_name: str) -> Coqpit: + """Find the right config for the given model name. + + Args: + model_name (str): Model name. + + Raises: + ModuleNotFoundError: No matching config for the model name. + + Returns: + Coqpit: config class. + """ + config_class = None + config_name = model_name + "_config" + + # TODO: fix this + if model_name == "xtts": + from TTS.tts.configs.xtts_config import XttsConfig + + config_class = XttsConfig + paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder.configs", "TTS.vc.configs"] + for path in paths: + try: + config_class = find_module(path, config_name) + except ModuleNotFoundError: + pass + if config_class is None: + raise ModuleNotFoundError(f" [!] Config for {model_name} cannot be found.") + return config_class + + +def _process_model_name(config_dict: Dict) -> str: + """Format the model name as expected. It is a band-aid for the old `vocoder` model names. + + Args: + config_dict (Dict): A dictionary including the config fields. + + Returns: + str: Formatted modelname. + """ + model_name = config_dict["model"] if "model" in config_dict else config_dict["generator_model"] + model_name = model_name.replace("_generator", "").replace("_discriminator", "") + return model_name + + +def load_config(config_path: str) -> Coqpit: + """Import `json` or `yaml` files as TTS configs. First, load the input file as a `dict` and check the model name + to find the corresponding Config class. Then initialize the Config. + + Args: + config_path (str): path to the config file. + + Raises: + TypeError: given config file has an unknown type. + + Returns: + Coqpit: TTS config object. + """ + config_dict = {} + ext = os.path.splitext(config_path)[1] + if ext in (".yml", ".yaml"): + with fsspec.open(config_path, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) + elif ext == ".json": + try: + with fsspec.open(config_path, "r", encoding="utf-8") as f: + data = json.load(f) + except json.decoder.JSONDecodeError: + # backwards compat. + data = read_json_with_comments(config_path) + else: + raise TypeError(f" [!] Unknown config file type {ext}") + config_dict.update(data) + model_name = _process_model_name(config_dict) + config_class = register_config(model_name.lower()) + config = config_class() + config.from_dict(config_dict) + return config + + +def check_config_and_model_args(config, arg_name, value): + """Check the give argument in `config.model_args` if exist or in `config` for + the given value. + + Return False if the argument does not exist in `config.model_args` or `config`. 
+ This is to patch up the compatibility between models with and without `model_args`. + + TODO: Remove this in the future with a unified approach. + """ + if hasattr(config, "model_args"): + if arg_name in config.model_args: + return config.model_args[arg_name] == value + if hasattr(config, arg_name): + return config[arg_name] == value + return False + + +def get_from_config_or_model_args(config, arg_name): + """Get the given argument from `config.model_args` if exist or in `config`.""" + if hasattr(config, "model_args"): + if arg_name in config.model_args: + return config.model_args[arg_name] + return config[arg_name] + + +def get_from_config_or_model_args_with_default(config, arg_name, def_val): + """Get the given argument from `config.model_args` if exist or in `config`.""" + if hasattr(config, "model_args"): + if arg_name in config.model_args: + return config.model_args[arg_name] + if hasattr(config, arg_name): + return config[arg_name] + return def_val diff --git a/viXTTS/TTS/config/__pycache__/__init__.cpython-310.pyc b/viXTTS/TTS/config/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96751159b1b2b969c0787d415828b4e4b97b37a7 Binary files /dev/null and b/viXTTS/TTS/config/__pycache__/__init__.cpython-310.pyc differ diff --git a/viXTTS/TTS/config/__pycache__/shared_configs.cpython-310.pyc b/viXTTS/TTS/config/__pycache__/shared_configs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..752967c3610446b70d38e52ede16435b785b0aaf Binary files /dev/null and b/viXTTS/TTS/config/__pycache__/shared_configs.cpython-310.pyc differ diff --git a/viXTTS/TTS/config/shared_configs.py b/viXTTS/TTS/config/shared_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..7fae77d61361eff8c8fa521a0f4a90dc46f63c75 --- /dev/null +++ b/viXTTS/TTS/config/shared_configs.py @@ -0,0 +1,268 @@ +from dataclasses import asdict, dataclass +from typing import List + +from coqpit import Coqpit, check_argument +from trainer import TrainerConfig + + +@dataclass +class BaseAudioConfig(Coqpit): + """Base config to definge audio processing parameters. It is used to initialize + ```TTS.utils.audio.AudioProcessor.``` + + Args: + fft_size (int): + Number of STFT frequency levels aka.size of the linear spectogram frame. Defaults to 1024. + + win_length (int): + Each frame of audio is windowed by window of length ```win_length``` and then padded with zeros to match + ```fft_size```. Defaults to 1024. + + hop_length (int): + Number of audio samples between adjacent STFT columns. Defaults to 1024. + + frame_shift_ms (int): + Set ```hop_length``` based on milliseconds and sampling rate. + + frame_length_ms (int): + Set ```win_length``` based on milliseconds and sampling rate. + + stft_pad_mode (str): + Padding method used in STFT. 'reflect' or 'center'. Defaults to 'reflect'. + + sample_rate (int): + Audio sampling rate. Defaults to 22050. + + resample (bool): + Enable / Disable resampling audio to ```sample_rate```. Defaults to ```False```. + + preemphasis (float): + Preemphasis coefficient. Defaults to 0.0. + + ref_level_db (int): 20 + Reference Db level to rebase the audio signal and ignore the level below. 20Db is assumed the sound of air. + Defaults to 20. + + do_sound_norm (bool): + Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False. + + log_func (str): + Numpy log function used for amplitude to DB conversion. Defaults to 'np.log10'. 
+
+        do_trim_silence (bool):
+            Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```.
+
+        do_amp_to_db_linear (bool, optional):
+            enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.
+
+        do_amp_to_db_mel (bool, optional):
+            enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
+
+        pitch_fmax (float, optional):
+            Maximum frequency of the F0 frames. Defaults to ```640```.
+
+        pitch_fmin (float, optional):
+            Minimum frequency of the F0 frames. Defaults to ```1```.
+
+        trim_db (int):
+            Silence threshold used for silence trimming. Defaults to 45.
+
+        do_rms_norm (bool, optional):
+            enable/disable RMS volume normalization when loading an audio file. Defaults to False.
+
+        db_level (int, optional):
+            dB level used for rms normalization. The range is -99 to 0. Defaults to None.
+
+        power (float):
+            Exponent used for expanding spectrogram levels before running Griffin-Lim. It helps to reduce the
+            artifacts in the synthesized voice. Defaults to 1.5.
+
+        griffin_lim_iters (int):
+            Number of Griffin-Lim iterations. Defaults to 60.
+
+        num_mels (int):
+            Number of mel-basis frames that defines the frame lengths of each mel-spectrogram frame. Defaults to 80.
+
+        mel_fmin (float): Min frequency level used for the mel-basis filters. ~50 for male and ~95 for female voices.
+            It needs to be adjusted for a dataset. Defaults to 0.
+
+        mel_fmax (float):
+            Max frequency level used for the mel-basis filters. It needs to be adjusted for a dataset.
+
+        spec_gain (int):
+            Gain applied when converting amplitude to DB. Defaults to 20.
+
+        signal_norm (bool):
+            enable/disable signal normalization. Defaults to True.
+
+        min_level_db (int):
+            minimum db threshold for the computed melspectrograms. Defaults to -100.
+
+        symmetric_norm (bool):
+            enable/disable symmetric normalization. If set True, normalization is performed in the range [-k, k], else
+            [0, k]. Defaults to True.
+
+        max_norm (float):
+            ```k``` defining the normalization range. Defaults to 4.0.
+
+        clip_norm (bool):
+            enable/disable clipping the out-of-range values in the normalized audio signal. Defaults to True.
+
+        stats_path (str):
+            Path to the computed stats file. Defaults to None.
+ """ + + # stft parameters + fft_size: int = 1024 + win_length: int = 1024 + hop_length: int = 256 + frame_shift_ms: int = None + frame_length_ms: int = None + stft_pad_mode: str = "reflect" + # audio processing parameters + sample_rate: int = 22050 + resample: bool = False + preemphasis: float = 0.0 + ref_level_db: int = 20 + do_sound_norm: bool = False + log_func: str = "np.log10" + # silence trimming + do_trim_silence: bool = True + trim_db: int = 45 + # rms volume normalization + do_rms_norm: bool = False + db_level: float = None + # griffin-lim params + power: float = 1.5 + griffin_lim_iters: int = 60 + # mel-spec params + num_mels: int = 80 + mel_fmin: float = 0.0 + mel_fmax: float = None + spec_gain: int = 20 + do_amp_to_db_linear: bool = True + do_amp_to_db_mel: bool = True + # f0 params + pitch_fmax: float = 640.0 + pitch_fmin: float = 1.0 + # normalization params + signal_norm: bool = True + min_level_db: int = -100 + symmetric_norm: bool = True + max_norm: float = 4.0 + clip_norm: bool = True + stats_path: str = None + + def check_values( + self, + ): + """Check config fields""" + c = asdict(self) + check_argument("num_mels", c, restricted=True, min_val=10, max_val=2056) + check_argument("fft_size", c, restricted=True, min_val=128, max_val=4058) + check_argument("sample_rate", c, restricted=True, min_val=512, max_val=100000) + check_argument( + "frame_length_ms", + c, + restricted=True, + min_val=10, + max_val=1000, + alternative="win_length", + ) + check_argument("frame_shift_ms", c, restricted=True, min_val=1, max_val=1000, alternative="hop_length") + check_argument("preemphasis", c, restricted=True, min_val=0, max_val=1) + check_argument("min_level_db", c, restricted=True, min_val=-1000, max_val=10) + check_argument("ref_level_db", c, restricted=True, min_val=0, max_val=1000) + check_argument("power", c, restricted=True, min_val=1, max_val=5) + check_argument("griffin_lim_iters", c, restricted=True, min_val=10, max_val=1000) + + # normalization parameters + check_argument("signal_norm", c, restricted=True) + check_argument("symmetric_norm", c, restricted=True) + check_argument("max_norm", c, restricted=True, min_val=0.1, max_val=1000) + check_argument("clip_norm", c, restricted=True) + check_argument("mel_fmin", c, restricted=True, min_val=0.0, max_val=1000) + check_argument("mel_fmax", c, restricted=True, min_val=500.0, allow_none=True) + check_argument("spec_gain", c, restricted=True, min_val=1, max_val=100) + check_argument("do_trim_silence", c, restricted=True) + check_argument("trim_db", c, restricted=True) + + +@dataclass +class BaseDatasetConfig(Coqpit): + """Base config for TTS datasets. + + Args: + formatter (str): + Formatter name that defines used formatter in ```TTS.tts.datasets.formatter```. Defaults to `""`. + + dataset_name (str): + Unique name for the dataset. Defaults to `""`. + + path (str): + Root path to the dataset files. Defaults to `""`. + + meta_file_train (str): + Name of the dataset meta file. Or a list of speakers to be ignored at training for multi-speaker datasets. + Defaults to `""`. + + ignored_speakers (List): + List of speakers IDs that are not used at the training. Default None. + + language (str): + Language code of the dataset. If defined, it overrides `phoneme_language`. Defaults to `""`. + + phonemizer (str): + Phonemizer used for that dataset's language. By default it uses `DEF_LANG_TO_PHONEMIZER`. Defaults to `""`. + + meta_file_val (str): + Name of the dataset meta file that defines the instances used at validation. 
+ + meta_file_attn_mask (str): + Path to the file that lists the attention mask files used with models that require attention masks to + train the duration predictor. + """ + + formatter: str = "" + dataset_name: str = "" + path: str = "" + meta_file_train: str = "" + ignored_speakers: List[str] = None + language: str = "" + phonemizer: str = "" + meta_file_val: str = "" + meta_file_attn_mask: str = "" + + def check_values( + self, + ): + """Check config fields""" + c = asdict(self) + check_argument("formatter", c, restricted=True) + check_argument("path", c, restricted=True) + check_argument("meta_file_train", c, restricted=True) + check_argument("meta_file_val", c, restricted=False) + check_argument("meta_file_attn_mask", c, restricted=False) + + +@dataclass +class BaseTrainingConfig(TrainerConfig): + """Base config to define the basic 🐸TTS training parameters that are shared + among all the models. It is based on ```Trainer.TrainingConfig```. + + Args: + model (str): + Name of the model that is used in the training. + + num_loader_workers (int): + Number of workers for training time dataloader. + + num_eval_loader_workers (int): + Number of workers for evaluation time dataloader. + """ + + model: str = None + # dataloading + num_loader_workers: int = 0 + num_eval_loader_workers: int = 0 + use_noise_augment: bool = False diff --git a/viXTTS/TTS/demos/xtts_ft_demo/requirements.txt b/viXTTS/TTS/demos/xtts_ft_demo/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..cb5b16f66e295fe0de66ae16cfa4ad34afa8845f --- /dev/null +++ b/viXTTS/TTS/demos/xtts_ft_demo/requirements.txt @@ -0,0 +1,2 @@ +faster_whisper==0.9.0 +gradio==4.7.1 \ No newline at end of file diff --git a/viXTTS/TTS/demos/xtts_ft_demo/utils/formatter.py b/viXTTS/TTS/demos/xtts_ft_demo/utils/formatter.py new file mode 100644 index 0000000000000000000000000000000000000000..536faa01086d788c82287e87b660279fbcc2ad7c --- /dev/null +++ b/viXTTS/TTS/demos/xtts_ft_demo/utils/formatter.py @@ -0,0 +1,160 @@ +import os +import gc +import torchaudio +import pandas +from faster_whisper import WhisperModel +from glob import glob + +from tqdm import tqdm + +import torch +import torchaudio +# torch.set_num_threads(1) + +from TTS.tts.layers.xtts.tokenizer import multilingual_cleaners + +torch.set_num_threads(16) + + +import os + +audio_types = (".wav", ".mp3", ".flac") + + +def list_audios(basePath, contains=None): + # return the set of files that are valid + return list_files(basePath, validExts=audio_types, contains=contains) + +def list_files(basePath, validExts=None, contains=None): + # loop over the directory structure + for (rootDir, dirNames, filenames) in os.walk(basePath): + # loop over the filenames in the current directory + for filename in filenames: + # if the contains string is not none and the filename does not contain + # the supplied string, then ignore the file + if contains is not None and filename.find(contains) == -1: + continue + + # determine the file extension of the current file + ext = filename[filename.rfind("."):].lower() + + # check to see if the file is an audio and should be processed + if validExts is None or ext.endswith(validExts): + # construct the path to the audio and yield it + audioPath = os.path.join(rootDir, filename) + yield audioPath + +def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None): + audio_total_size = 0 + # make sure that ooutput file exists + os.makedirs(out_path, 
exist_ok=True) + + # Loading Whisper + device = "cuda" if torch.cuda.is_available() else "cpu" + + print("Loading Whisper Model!") + asr_model = WhisperModel("large-v2", device=device, compute_type="float16") + + metadata = {"audio_file": [], "text": [], "speaker_name": []} + + if gradio_progress is not None: + tqdm_object = gradio_progress.tqdm(audio_files, desc="Formatting...") + else: + tqdm_object = tqdm(audio_files) + + for audio_path in tqdm_object: + wav, sr = torchaudio.load(audio_path) + # stereo to mono if needed + if wav.size(0) != 1: + wav = torch.mean(wav, dim=0, keepdim=True) + + wav = wav.squeeze() + audio_total_size += (wav.size(-1) / sr) + + segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language) + segments = list(segments) + i = 0 + sentence = "" + sentence_start = None + first_word = True + # added all segments words in a unique list + words_list = [] + for _, segment in enumerate(segments): + words = list(segment.words) + words_list.extend(words) + + # process each word + for word_idx, word in enumerate(words_list): + if first_word: + sentence_start = word.start + # If it is the first sentence, add buffer or get the begining of the file + if word_idx == 0: + sentence_start = max(sentence_start - buffer, 0) # Add buffer to the sentence start + else: + # get previous sentence end + previous_word_end = words_list[word_idx - 1].end + # add buffer or get the silence midle between the previous sentence and the current one + sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start)/2) + + sentence = word.word + first_word = False + else: + sentence += word.word + + if word.word[-1] in ["!", ".", "?"]: + sentence = sentence[1:] + # Expand number and abbreviations plus normalization + sentence = multilingual_cleaners(sentence, target_language) + audio_file_name, _ = os.path.splitext(os.path.basename(audio_path)) + + audio_file = f"wavs/{audio_file_name}_{str(i).zfill(8)}.wav" + + # Check for the next word's existence + if word_idx + 1 < len(words_list): + next_word_start = words_list[word_idx + 1].start + else: + # If don't have more words it means that it is the last sentence then use the audio len as next word start + next_word_start = (wav.shape[0] - 1) / sr + + # Average the current word end and next word start + word_end = min((word.end + next_word_start) / 2, word.end + buffer) + + absoulte_path = os.path.join(out_path, audio_file) + os.makedirs(os.path.dirname(absoulte_path), exist_ok=True) + i += 1 + first_word = True + + audio = wav[int(sr*sentence_start):int(sr*word_end)].unsqueeze(0) + # if the audio is too short ignore it (i.e < 0.33 seconds) + if audio.size(-1) >= sr/3: + torchaudio.save(absoulte_path, + audio, + sr + ) + else: + continue + + metadata["audio_file"].append(audio_file) + metadata["text"].append(sentence) + metadata["speaker_name"].append(speaker_name) + + df = pandas.DataFrame(metadata) + df = df.sample(frac=1) + num_val_samples = int(len(df)*eval_percentage) + + df_eval = df[:num_val_samples] + df_train = df[num_val_samples:] + + df_train = df_train.sort_values('audio_file') + train_metadata_path = os.path.join(out_path, "metadata_train.csv") + df_train.to_csv(train_metadata_path, sep="|", index=False) + + eval_metadata_path = os.path.join(out_path, "metadata_eval.csv") + df_eval = df_eval.sort_values('audio_file') + df_eval.to_csv(eval_metadata_path, sep="|", index=False) + + # deallocate VRAM and RAM + del asr_model, df_train, df_eval, df, metadata + gc.collect() + + return 
train_metadata_path, eval_metadata_path, audio_total_size \ No newline at end of file diff --git a/viXTTS/TTS/demos/xtts_ft_demo/utils/gpt_train.py b/viXTTS/TTS/demos/xtts_ft_demo/utils/gpt_train.py new file mode 100644 index 0000000000000000000000000000000000000000..a98765c3e79b6227797ba528d425ec14cc8dbd45 --- /dev/null +++ b/viXTTS/TTS/demos/xtts_ft_demo/utils/gpt_train.py @@ -0,0 +1,172 @@ +import os +import gc + +from trainer import Trainer, TrainerArgs + +from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.datasets import load_tts_samples +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig +from TTS.utils.manage import ModelManager + + +def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path, max_audio_length=255995): + # Logging parameters + RUN_NAME = "GPT_XTTS_FT" + PROJECT_NAME = "XTTS_trainer" + DASHBOARD_LOGGER = "tensorboard" + LOGGER_URI = None + + # Set here the path that the checkpoints will be saved. Default: ./run/training/ + OUT_PATH = os.path.join(output_path, "run", "training") + + # Training Parameters + OPTIMIZER_WD_ONLY_ON_WEIGHTS = True # for multi-gpu training please make it False + START_WITH_EVAL = False # if True it will star with evaluation + BATCH_SIZE = batch_size # set here the batch size + GRAD_ACUMM_STEPS = grad_acumm # set here the grad accumulation steps + + + # Define here the dataset that you want to use for the fine-tuning on. + config_dataset = BaseDatasetConfig( + formatter="coqui", + dataset_name="ft_dataset", + path=os.path.dirname(train_csv), + meta_file_train=train_csv, + meta_file_val=eval_csv, + language=language, + ) + + # Add here the configs of the datasets + DATASETS_CONFIG_LIST = [config_dataset] + + # Define the path where XTTS v2.0.1 files will be downloaded + CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/") + os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True) + + + # DVAE files + DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth" + MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth" + + # Set the path to the downloaded files + DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(DVAE_CHECKPOINT_LINK)) + MEL_NORM_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(MEL_NORM_LINK)) + + # download DVAE files if needed + if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE): + print(" > Downloading DVAE files!") + ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True) + + + # Download XTTS v2.0 checkpoint if needed + TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json" + XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth" + XTTS_CONFIG_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json" + + # XTTS transfer learning parameters: You we need to provide the paths of XTTS model checkpoint that you want to do the fine tuning. 
+ TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(TOKENIZER_FILE_LINK)) # vocab.json file + XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CHECKPOINT_LINK)) # model.pth file + XTTS_CONFIG_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CONFIG_LINK)) # config.json file + + # download XTTS v2.0 files if needed + if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT): + print(" > Downloading XTTS v2.0 files!") + ModelManager._download_model_files( + [TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK, XTTS_CONFIG_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True + ) + + # init args and config + model_args = GPTArgs( + max_conditioning_length=132300, # 6 secs + min_conditioning_length=66150, # 3 secs + debug_loading_failures=False, + max_wav_length=max_audio_length, # ~11.6 seconds + max_text_length=200, + mel_norm_file=MEL_NORM_FILE, + dvae_checkpoint=DVAE_CHECKPOINT, + xtts_checkpoint=XTTS_CHECKPOINT, # checkpoint path of the model that you want to fine-tune + tokenizer_file=TOKENIZER_FILE, + gpt_num_audio_tokens=1026, + gpt_start_audio_token=1024, + gpt_stop_audio_token=1025, + gpt_use_masking_gt_prompt_approach=True, + gpt_use_perceiver_resampler=True, + ) + # define audio config + audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000) + # training parameters config + config = GPTTrainerConfig( + epochs=num_epochs, + output_path=OUT_PATH, + model_args=model_args, + run_name=RUN_NAME, + project_name=PROJECT_NAME, + run_description=""" + GPT XTTS training + """, + dashboard_logger=DASHBOARD_LOGGER, + logger_uri=LOGGER_URI, + audio=audio_config, + batch_size=BATCH_SIZE, + batch_group_size=48, + eval_batch_size=BATCH_SIZE, + num_loader_workers=8, + eval_split_max_size=256, + print_step=50, + plot_step=100, + log_model_step=100, + save_step=1000, + save_n_checkpoints=1, + save_checkpoints=True, + # target_loss="loss", + print_eval=False, + # Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters. 
+ optimizer="AdamW", + optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS, + optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2}, + lr=5e-06, # learning rate + lr_scheduler="MultiStepLR", + # it was adjusted accordly for the new step scheme + lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1}, + test_sentences=[], + ) + + # init the model from config + model = GPTTrainer.init_from_config(config) + + # load training samples + train_samples, eval_samples = load_tts_samples( + DATASETS_CONFIG_LIST, + eval_split=True, + eval_split_max_size=config.eval_split_max_size, + eval_split_size=config.eval_split_size, + ) + + # init the trainer and 🚀 + trainer = Trainer( + TrainerArgs( + restore_path=None, # xtts checkpoint is restored via xtts_checkpoint key so no need of restore it using Trainer restore_path parameter + skip_train_epoch=False, + start_with_eval=START_WITH_EVAL, + grad_accum_steps=GRAD_ACUMM_STEPS, + ), + config, + output_path=OUT_PATH, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + ) + trainer.fit() + + # get the longest text audio file to use as speaker reference + samples_len = [len(item["text"].split(" ")) for item in train_samples] + longest_text_idx = samples_len.index(max(samples_len)) + speaker_ref = train_samples[longest_text_idx]["audio_file"] + + trainer_out_path = trainer.output_path + + # deallocate VRAM and RAM + del model, trainer, train_samples, eval_samples + gc.collect() + + return XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer_out_path, speaker_ref diff --git a/viXTTS/TTS/demos/xtts_ft_demo/xtts_demo.py b/viXTTS/TTS/demos/xtts_ft_demo/xtts_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..ebb11f29d16084c9fb6dd46b42fbed017f487374 --- /dev/null +++ b/viXTTS/TTS/demos/xtts_ft_demo/xtts_demo.py @@ -0,0 +1,415 @@ +import argparse +import os +import sys +import tempfile + +import gradio as gr +import librosa.display +import numpy as np + +import os +import torch +import torchaudio +import traceback +from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list +from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt + +from TTS.tts.configs.xtts_config import XttsConfig +from TTS.tts.models.xtts import Xtts + + +def clear_gpu_cache(): + # clear the GPU cache + if torch.cuda.is_available(): + torch.cuda.empty_cache() + +XTTS_MODEL = None +def load_model(xtts_checkpoint, xtts_config, xtts_vocab): + global XTTS_MODEL + clear_gpu_cache() + if not xtts_checkpoint or not xtts_config or not xtts_vocab: + return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!" + config = XttsConfig() + config.load_json(xtts_config) + XTTS_MODEL = Xtts.init_from_config(config) + print("Loading XTTS model! ") + XTTS_MODEL.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, use_deepspeed=False) + if torch.cuda.is_available(): + XTTS_MODEL.cuda() + + print("Model Loaded!") + return "Model Loaded!" 
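For reference, `load_model`/`run_tts` above are thin wrappers around the plain XTTS API. The sketch below shows the same load-then-infer flow without Gradio; every path and the reference wav are placeholders for the files produced by the training tab, the sample text reuses the demo's default sentence, and the extra sampling parameters of `inference()` are left at their config defaults (versions of the TTS package may differ slightly in keyword defaults).

```python
# Minimal standalone sketch (placeholder paths) mirroring load_model()/run_tts() above.
import torch
import torchaudio

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# Placeholders: point these at the config/checkpoint/vocab written by the fine-tuning run.
XTTS_CONFIG = "run/training/config.json"
XTTS_CHECKPOINT = "run/training/best_model.pth"
XTTS_VOCAB = "run/training/vocab.json"
SPEAKER_WAV = "dataset/wavs/reference.wav"

config = XttsConfig()
config.load_json(XTTS_CONFIG)
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_path=XTTS_CHECKPOINT, vocab_path=XTTS_VOCAB, use_deepspeed=False)
if torch.cuda.is_available():
    model.cuda()

# Conditioning latents computed from the reference speaker audio.
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=SPEAKER_WAV)

out = model.inference(
    text="This model sounds really good and above all, it's reasonably fast.",
    language="en",
    gpt_cond_latent=gpt_cond_latent,
    speaker_embedding=speaker_embedding,
)
# XTTS decodes audio at 24 kHz.
torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
```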
+ +def run_tts(lang, tts_text, speaker_audio_file): + if XTTS_MODEL is None or not speaker_audio_file: + return "You need to run the previous step to load the model !!", None, None + + gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs) + out = XTTS_MODEL.inference( + text=tts_text, + language=lang, + gpt_cond_latent=gpt_cond_latent, + speaker_embedding=speaker_embedding, + temperature=XTTS_MODEL.config.temperature, # Add custom parameters here + length_penalty=XTTS_MODEL.config.length_penalty, + repetition_penalty=XTTS_MODEL.config.repetition_penalty, + top_k=XTTS_MODEL.config.top_k, + top_p=XTTS_MODEL.config.top_p, + ) + + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp: + out["wav"] = torch.tensor(out["wav"]).unsqueeze(0) + out_path = fp.name + torchaudio.save(out_path, out["wav"], 24000) + + return "Speech generated !", out_path, speaker_audio_file + + + + +# define a logger to redirect +class Logger: + def __init__(self, filename="log.out"): + self.log_file = filename + self.terminal = sys.stdout + self.log = open(self.log_file, "w") + + def write(self, message): + self.terminal.write(message) + self.log.write(message) + + def flush(self): + self.terminal.flush() + self.log.flush() + + def isatty(self): + return False + +# redirect stdout and stderr to a file +sys.stdout = Logger() +sys.stderr = sys.stdout + + +# logging.basicConfig(stream=sys.stdout, level=logging.INFO) +import logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[ + logging.StreamHandler(sys.stdout) + ] +) + +def read_logs(): + sys.stdout.flush() + with open(sys.stdout.log_file, "r") as f: + return f.read() + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description="""XTTS fine-tuning demo\n\n""" + """ + Example runs: + python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port + """, + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "--port", + type=int, + help="Port to run the gradio demo. Default: 5003", + default=5003, + ) + parser.add_argument( + "--out_path", + type=str, + help="Output path (where data and checkpoints will be saved) Default: /tmp/xtts_ft/", + default="/tmp/xtts_ft/", + ) + + parser.add_argument( + "--num_epochs", + type=int, + help="Number of epochs to train. Default: 10", + default=10, + ) + parser.add_argument( + "--batch_size", + type=int, + help="Batch size. Default: 4", + default=4, + ) + parser.add_argument( + "--grad_acumm", + type=int, + help="Grad accumulation steps. Default: 1", + default=1, + ) + parser.add_argument( + "--max_audio_length", + type=int, + help="Max permitted audio size in seconds. 
Default: 11", + default=11, + ) + + args = parser.parse_args() + + with gr.Blocks() as demo: + with gr.Tab("1 - Data processing"): + out_path = gr.Textbox( + label="Output path (where data and checkpoints will be saved):", + value=args.out_path, + ) + # upload_file = gr.Audio( + # sources="upload", + # label="Select here the audio files that you want to use for XTTS trainining !", + # type="filepath", + # ) + upload_file = gr.File( + file_count="multiple", + label="Select here the audio files that you want to use for XTTS trainining (Supported formats: wav, mp3, and flac)", + ) + lang = gr.Dropdown( + label="Dataset Language", + value="en", + choices=[ + "en", + "es", + "fr", + "de", + "it", + "pt", + "pl", + "tr", + "ru", + "nl", + "cs", + "ar", + "zh", + "hu", + "ko", + "ja" + ], + ) + progress_data = gr.Label( + label="Progress:" + ) + logs = gr.Textbox( + label="Logs:", + interactive=False, + ) + demo.load(read_logs, None, logs, every=1) + + prompt_compute_btn = gr.Button(value="Step 1 - Create dataset") + + def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)): + clear_gpu_cache() + out_path = os.path.join(out_path, "dataset") + os.makedirs(out_path, exist_ok=True) + if audio_path is None: + return "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", "", "" + else: + try: + train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress) + except: + traceback.print_exc() + error = traceback.format_exc() + return f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", "", "" + + clear_gpu_cache() + + # if audio total len is less than 2 minutes raise an error + if audio_total_size < 120: + message = "The sum of the duration of the audios that you provided should be at least 2 minutes!" 
+ print(message) + return message, "", "" + + print("Dataset Processed!") + return "Dataset Processed!", train_meta, eval_meta + + with gr.Tab("2 - Fine-tuning XTTS Encoder"): + train_csv = gr.Textbox( + label="Train CSV:", + ) + eval_csv = gr.Textbox( + label="Eval CSV:", + ) + num_epochs = gr.Slider( + label="Number of epochs:", + minimum=1, + maximum=100, + step=1, + value=args.num_epochs, + ) + batch_size = gr.Slider( + label="Batch size:", + minimum=2, + maximum=512, + step=1, + value=args.batch_size, + ) + grad_acumm = gr.Slider( + label="Grad accumulation steps:", + minimum=2, + maximum=128, + step=1, + value=args.grad_acumm, + ) + max_audio_length = gr.Slider( + label="Max permitted audio size in seconds:", + minimum=2, + maximum=20, + step=1, + value=args.max_audio_length, + ) + progress_train = gr.Label( + label="Progress:" + ) + logs_tts_train = gr.Textbox( + label="Logs:", + interactive=False, + ) + demo.load(read_logs, None, logs_tts_train, every=1) + train_btn = gr.Button(value="Step 2 - Run the training") + + def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length): + clear_gpu_cache() + if not train_csv or not eval_csv: + return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", "" + try: + # convert seconds to waveform frames + max_audio_length = int(max_audio_length * 22050) + config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length) + except: + traceback.print_exc() + error = traceback.format_exc() + return f"The training was interrupted due an error !! Please check the console to check the full error message! 
\n Error summary: {error}", "", "", "", "" + + # copy original files to avoid parameters changes issues + os.system(f"cp {config_path} {exp_path}") + os.system(f"cp {vocab_file} {exp_path}") + + ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth") + print("Model training done!") + clear_gpu_cache() + return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav + + with gr.Tab("3 - Inference"): + with gr.Row(): + with gr.Column() as col1: + xtts_checkpoint = gr.Textbox( + label="XTTS checkpoint path:", + value="", + ) + xtts_config = gr.Textbox( + label="XTTS config path:", + value="", + ) + + xtts_vocab = gr.Textbox( + label="XTTS vocab path:", + value="", + ) + progress_load = gr.Label( + label="Progress:" + ) + load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model") + + with gr.Column() as col2: + speaker_reference_audio = gr.Textbox( + label="Speaker reference audio:", + value="", + ) + tts_language = gr.Dropdown( + label="Language", + value="en", + choices=[ + "en", + "es", + "fr", + "de", + "it", + "pt", + "pl", + "tr", + "ru", + "nl", + "cs", + "ar", + "zh", + "hu", + "ko", + "ja", + ] + ) + tts_text = gr.Textbox( + label="Input Text.", + value="This model sounds really good and above all, it's reasonably fast.", + ) + tts_btn = gr.Button(value="Step 4 - Inference") + + with gr.Column() as col3: + progress_gen = gr.Label( + label="Progress:" + ) + tts_output_audio = gr.Audio(label="Generated Audio.") + reference_audio = gr.Audio(label="Reference audio used.") + + prompt_compute_btn.click( + fn=preprocess_dataset, + inputs=[ + upload_file, + lang, + out_path, + ], + outputs=[ + progress_data, + train_csv, + eval_csv, + ], + ) + + + train_btn.click( + fn=train_model, + inputs=[ + lang, + train_csv, + eval_csv, + num_epochs, + batch_size, + grad_acumm, + out_path, + max_audio_length, + ], + outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio], + ) + + load_btn.click( + fn=load_model, + inputs=[ + xtts_checkpoint, + xtts_config, + xtts_vocab + ], + outputs=[progress_load], + ) + + tts_btn.click( + fn=run_tts, + inputs=[ + tts_language, + tts_text, + speaker_reference_audio, + ], + outputs=[progress_gen, tts_output_audio, reference_audio], + ) + + demo.launch( + share=True, + debug=False, + server_port=args.port, + server_name="0.0.0.0" + ) diff --git a/viXTTS/TTS/encoder/README.md b/viXTTS/TTS/encoder/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b38b20052b707b0358068bc0ce58bc300a149def --- /dev/null +++ b/viXTTS/TTS/encoder/README.md @@ -0,0 +1,18 @@ +### Speaker Encoder + +This is an implementation of https://arxiv.org/abs/1710.10467. This model can be used for voice and speaker embedding. + +With the code here you can generate d-vectors for both multi-speaker and single-speaker TTS datasets, then visualise and explore them along with the associated audio files in an interactive chart. + +Below is an example showing embedding results of various speakers. You can generate the same plot with the provided notebook as demonstrated in [this video](https://youtu.be/KW3oO7JVa7Q). + +![](umap.png) + +Download a pretrained model from [Released Models](https://github.com/mozilla/TTS/wiki/Released-Models) page. + +To run the code, you need to follow the same flow as in TTS. + +- Define 'config.json' for your needs. Note that, audio parameters should match your TTS model. 
+- Example training call ```python speaker_encoder/train.py --config_path speaker_encoder/config.json --data_path ~/Data/Libri-TTS/train-clean-360``` +- Generate embedding vectors ```python speaker_encoder/compute_embeddings.py --use_cuda true /model/path/best_model.pth model/config/path/config.json dataset/path/ output_path``` . This code parses all .wav files at the given dataset path and generates the same folder structure under the output path with the generated embedding files. +- Watch training on Tensorboard as in TTS diff --git a/viXTTS/TTS/encoder/__init__.py b/viXTTS/TTS/encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/encoder/__pycache__/__init__.cpython-310.pyc b/viXTTS/TTS/encoder/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6f452b4342155456aaca590996e012dc6ba444e Binary files /dev/null and b/viXTTS/TTS/encoder/__pycache__/__init__.cpython-310.pyc differ diff --git a/viXTTS/TTS/encoder/__pycache__/losses.cpython-310.pyc b/viXTTS/TTS/encoder/__pycache__/losses.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37f6382464ffb8b2ed7dcf9395d602b3e6f361cf Binary files /dev/null and b/viXTTS/TTS/encoder/__pycache__/losses.cpython-310.pyc differ diff --git a/viXTTS/TTS/encoder/configs/base_encoder_config.py b/viXTTS/TTS/encoder/configs/base_encoder_config.py new file mode 100644 index 0000000000000000000000000000000000000000..ebbaa0457bb55aef70d54dd36fd9b2b7f7c702bb --- /dev/null +++ b/viXTTS/TTS/encoder/configs/base_encoder_config.py @@ -0,0 +1,61 @@ +from dataclasses import asdict, dataclass, field +from typing import Dict, List + +from coqpit import MISSING + +from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig + + +@dataclass +class BaseEncoderConfig(BaseTrainingConfig): + """Defines parameters for a Generic Encoder model.""" + + model: str = None + audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) + datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) + # model params + model_params: Dict = field( + default_factory=lambda: { + "model_name": "lstm", + "input_dim": 80, + "proj_dim": 256, + "lstm_dim": 768, + "num_lstm_layers": 3, + "use_lstm_with_projection": True, + } + ) + + audio_augmentation: Dict = field(default_factory=lambda: {}) + + # training params + epochs: int = 10000 + loss: str = "angleproto" + grad_clip: float = 3.0 + lr: float = 0.0001 + optimizer: str = "radam" + optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.9, 0.999], "weight_decay": 0}) + lr_decay: bool = False + warmup_steps: int = 4000 + + # logging params + tb_model_param_stats: bool = False + steps_plot_stats: int = 10 + save_step: int = 1000 + print_step: int = 20 + run_eval: bool = False + + # data loader + num_classes_in_batch: int = MISSING + num_utter_per_class: int = MISSING + eval_num_classes_in_batch: int = None + eval_num_utter_per_class: int = None + + num_loader_workers: int = MISSING + voice_len: float = 1.6 + + def check_values(self): + super().check_values() + c = asdict(self) + assert ( + c["model_params"]["input_dim"] == self.audio.num_mels + ), " [!] model input dimendion must be equal to melspectrogram dimension." 
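These encoder config classes (together with the `SpeakerEncoderConfig`/`EmotionEncoderConfig` subclasses added just below) are what `TTS/bin/train_encoder.py` loads as its `config.json`. A minimal sketch of filling one in follows; the formatter name, dataset path, and batch sizes are placeholders, and `save_json` is Coqpit's standard serializer.

```python
# Sketch only: placeholder dataset values; adjust for your corpus before training.
from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig
from TTS.encoder.configs.speaker_encoder_config import SpeakerEncoderConfig

dataset = BaseDatasetConfig(
    formatter="vctk",          # placeholder formatter name
    dataset_name="my_corpus",
    path="/data/my_corpus",    # placeholder path
    language="en",
)

config = SpeakerEncoderConfig(
    audio=BaseAudioConfig(sample_rate=16000, num_mels=80),
    datasets=[dataset],
    loss="angleproto",
    num_classes_in_batch=32,   # speakers per batch
    num_utter_per_class=4,     # utterances per speaker
    num_loader_workers=4,
    voice_len=1.6,             # seconds per training segment
)
config.check_values()          # asserts e.g. model input_dim == audio.num_mels
config.save_json("speaker_encoder_config.json")
```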
diff --git a/viXTTS/TTS/encoder/configs/emotion_encoder_config.py b/viXTTS/TTS/encoder/configs/emotion_encoder_config.py new file mode 100644 index 0000000000000000000000000000000000000000..5eda2671be980abce4a0506a075387b601a1596c --- /dev/null +++ b/viXTTS/TTS/encoder/configs/emotion_encoder_config.py @@ -0,0 +1,12 @@ +from dataclasses import asdict, dataclass + +from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig + + +@dataclass +class EmotionEncoderConfig(BaseEncoderConfig): + """Defines parameters for Emotion Encoder model.""" + + model: str = "emotion_encoder" + map_classid_to_classname: dict = None + class_name_key: str = "emotion_name" diff --git a/viXTTS/TTS/encoder/configs/speaker_encoder_config.py b/viXTTS/TTS/encoder/configs/speaker_encoder_config.py new file mode 100644 index 0000000000000000000000000000000000000000..6dceb00277ba68efe128936ff7f9456338f9753f --- /dev/null +++ b/viXTTS/TTS/encoder/configs/speaker_encoder_config.py @@ -0,0 +1,11 @@ +from dataclasses import asdict, dataclass + +from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig + + +@dataclass +class SpeakerEncoderConfig(BaseEncoderConfig): + """Defines parameters for Speaker Encoder model.""" + + model: str = "speaker_encoder" + class_name_key: str = "speaker_name" diff --git a/viXTTS/TTS/encoder/dataset.py b/viXTTS/TTS/encoder/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..582b1fe9ca35cb9afbc20b8f72b6173282201272 --- /dev/null +++ b/viXTTS/TTS/encoder/dataset.py @@ -0,0 +1,147 @@ +import random + +import torch +from torch.utils.data import Dataset + +from TTS.encoder.utils.generic_utils import AugmentWAV + + +class EncoderDataset(Dataset): + def __init__( + self, + config, + ap, + meta_data, + voice_len=1.6, + num_classes_in_batch=64, + num_utter_per_class=10, + verbose=False, + augmentation_config=None, + use_torch_spec=None, + ): + """ + Args: + ap (TTS.tts.utils.AudioProcessor): audio processor object. + meta_data (list): list of dataset instances. + seq_len (int): voice segment length in seconds. + verbose (bool): print diagnostic information. 
+ """ + super().__init__() + self.config = config + self.items = meta_data + self.sample_rate = ap.sample_rate + self.seq_len = int(voice_len * self.sample_rate) + self.num_utter_per_class = num_utter_per_class + self.ap = ap + self.verbose = verbose + self.use_torch_spec = use_torch_spec + self.classes, self.items = self.__parse_items() + + self.classname_to_classid = {key: i for i, key in enumerate(self.classes)} + + # Data Augmentation + self.augmentator = None + self.gaussian_augmentation_config = None + if augmentation_config: + self.data_augmentation_p = augmentation_config["p"] + if self.data_augmentation_p and ("additive" in augmentation_config or "rir" in augmentation_config): + self.augmentator = AugmentWAV(ap, augmentation_config) + + if "gaussian" in augmentation_config.keys(): + self.gaussian_augmentation_config = augmentation_config["gaussian"] + + if self.verbose: + print("\n > DataLoader initialization") + print(f" | > Classes per Batch: {num_classes_in_batch}") + print(f" | > Number of instances : {len(self.items)}") + print(f" | > Sequence length: {self.seq_len}") + print(f" | > Num Classes: {len(self.classes)}") + print(f" | > Classes: {self.classes}") + + def load_wav(self, filename): + audio = self.ap.load_wav(filename, sr=self.ap.sample_rate) + return audio + + def __parse_items(self): + class_to_utters = {} + for item in self.items: + path_ = item["audio_file"] + class_name = item[self.config.class_name_key] + if class_name in class_to_utters.keys(): + class_to_utters[class_name].append(path_) + else: + class_to_utters[class_name] = [ + path_, + ] + + # skip classes with number of samples >= self.num_utter_per_class + class_to_utters = {k: v for (k, v) in class_to_utters.items() if len(v) >= self.num_utter_per_class} + + classes = list(class_to_utters.keys()) + classes.sort() + + new_items = [] + for item in self.items: + path_ = item["audio_file"] + class_name = item["emotion_name"] if self.config.model == "emotion_encoder" else item["speaker_name"] + # ignore filtered classes + if class_name not in classes: + continue + # ignore small audios + if self.load_wav(path_).shape[0] - self.seq_len <= 0: + continue + + new_items.append({"wav_file_path": path_, "class_name": class_name}) + + return classes, new_items + + def __len__(self): + return len(self.items) + + def get_num_classes(self): + return len(self.classes) + + def get_class_list(self): + return self.classes + + def set_classes(self, classes): + self.classes = classes + self.classname_to_classid = {key: i for i, key in enumerate(self.classes)} + + def get_map_classid_to_classname(self): + return dict((c_id, c_n) for c_n, c_id in self.classname_to_classid.items()) + + def __getitem__(self, idx): + return self.items[idx] + + def collate_fn(self, batch): + # get the batch class_ids + labels = [] + feats = [] + for item in batch: + utter_path = item["wav_file_path"] + class_name = item["class_name"] + + # get classid + class_id = self.classname_to_classid[class_name] + # load wav file + wav = self.load_wav(utter_path) + offset = random.randint(0, wav.shape[0] - self.seq_len) + wav = wav[offset : offset + self.seq_len] + + if self.augmentator is not None and self.data_augmentation_p: + if random.random() < self.data_augmentation_p: + wav = self.augmentator.apply_one(wav) + + if not self.use_torch_spec: + mel = self.ap.melspectrogram(wav) + feats.append(torch.FloatTensor(mel)) + else: + feats.append(torch.FloatTensor(wav)) + + labels.append(class_id) + + feats = torch.stack(feats) + labels = torch.LongTensor(labels) 
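+        # feats: (batch, num_mels, T) mel frames, or (batch, seq_len) raw waveforms when use_torch_spec is set; labels: (batch,) class ids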
+ + return feats, labels diff --git a/viXTTS/TTS/encoder/losses.py b/viXTTS/TTS/encoder/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..5b5aa0fc48fe00aeedeff28ba48ed2af498ce582 --- /dev/null +++ b/viXTTS/TTS/encoder/losses.py @@ -0,0 +1,226 @@ +import torch +import torch.nn.functional as F +from torch import nn + + +# adapted from https://github.com/cvqluu/GE2E-Loss +class GE2ELoss(nn.Module): + def __init__(self, init_w=10.0, init_b=-5.0, loss_method="softmax"): + """ + Implementation of the Generalized End-to-End loss defined in https://arxiv.org/abs/1710.10467 [1] + Accepts an input of size (N, M, D) + where N is the number of speakers in the batch, + M is the number of utterances per speaker, + and D is the dimensionality of the embedding vector (e.g. d-vector) + Args: + - init_w (float): defines the initial value of w in Equation (5) of [1] + - init_b (float): definies the initial value of b in Equation (5) of [1] + """ + super().__init__() + # pylint: disable=E1102 + self.w = nn.Parameter(torch.tensor(init_w)) + # pylint: disable=E1102 + self.b = nn.Parameter(torch.tensor(init_b)) + self.loss_method = loss_method + + print(" > Initialized Generalized End-to-End loss") + + assert self.loss_method in ["softmax", "contrast"] + + if self.loss_method == "softmax": + self.embed_loss = self.embed_loss_softmax + if self.loss_method == "contrast": + self.embed_loss = self.embed_loss_contrast + + # pylint: disable=R0201 + def calc_new_centroids(self, dvecs, centroids, spkr, utt): + """ + Calculates the new centroids excluding the reference utterance + """ + excl = torch.cat((dvecs[spkr, :utt], dvecs[spkr, utt + 1 :])) + excl = torch.mean(excl, 0) + new_centroids = [] + for i, centroid in enumerate(centroids): + if i == spkr: + new_centroids.append(excl) + else: + new_centroids.append(centroid) + return torch.stack(new_centroids) + + def calc_cosine_sim(self, dvecs, centroids): + """ + Make the cosine similarity matrix with dims (N,M,N) + """ + cos_sim_matrix = [] + for spkr_idx, speaker in enumerate(dvecs): + cs_row = [] + for utt_idx, utterance in enumerate(speaker): + new_centroids = self.calc_new_centroids(dvecs, centroids, spkr_idx, utt_idx) + # vector based cosine similarity for speed + cs_row.append( + torch.clamp( + torch.mm( + utterance.unsqueeze(1).transpose(0, 1), + new_centroids.transpose(0, 1), + ) + / (torch.norm(utterance) * torch.norm(new_centroids, dim=1)), + 1e-6, + ) + ) + cs_row = torch.cat(cs_row, dim=0) + cos_sim_matrix.append(cs_row) + return torch.stack(cos_sim_matrix) + + # pylint: disable=R0201 + def embed_loss_softmax(self, dvecs, cos_sim_matrix): + """ + Calculates the loss on each embedding $L(e_{ji})$ by taking softmax + """ + N, M, _ = dvecs.shape + L = [] + for j in range(N): + L_row = [] + for i in range(M): + L_row.append(-F.log_softmax(cos_sim_matrix[j, i], 0)[j]) + L_row = torch.stack(L_row) + L.append(L_row) + return torch.stack(L) + + # pylint: disable=R0201 + def embed_loss_contrast(self, dvecs, cos_sim_matrix): + """ + Calculates the loss on each embedding $L(e_{ji})$ by contrast loss with closest centroid + """ + N, M, _ = dvecs.shape + L = [] + for j in range(N): + L_row = [] + for i in range(M): + centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i]) + excl_centroids_sigmoids = torch.cat((centroids_sigmoids[:j], centroids_sigmoids[j + 1 :])) + L_row.append(1.0 - torch.sigmoid(cos_sim_matrix[j, i, j]) + torch.max(excl_centroids_sigmoids)) + L_row = torch.stack(L_row) + L.append(L_row) + return torch.stack(L) + + def 
forward(self, x, _label=None): + """ + Calculates the GE2E loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) + """ + + assert x.size()[1] >= 2 + + centroids = torch.mean(x, 1) + cos_sim_matrix = self.calc_cosine_sim(x, centroids) + torch.clamp(self.w, 1e-6) + cos_sim_matrix = self.w * cos_sim_matrix + self.b + L = self.embed_loss(x, cos_sim_matrix) + return L.mean() + + +# adapted from https://github.com/clovaai/voxceleb_trainer/blob/master/loss/angleproto.py +class AngleProtoLoss(nn.Module): + """ + Implementation of the Angular Prototypical loss defined in https://arxiv.org/abs/2003.11982 + Accepts an input of size (N, M, D) + where N is the number of speakers in the batch, + M is the number of utterances per speaker, + and D is the dimensionality of the embedding vector + Args: + - init_w (float): defines the initial value of w + - init_b (float): definies the initial value of b + """ + + def __init__(self, init_w=10.0, init_b=-5.0): + super().__init__() + # pylint: disable=E1102 + self.w = nn.Parameter(torch.tensor(init_w)) + # pylint: disable=E1102 + self.b = nn.Parameter(torch.tensor(init_b)) + self.criterion = torch.nn.CrossEntropyLoss() + + print(" > Initialized Angular Prototypical loss") + + def forward(self, x, _label=None): + """ + Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) + """ + + assert x.size()[1] >= 2 + + out_anchor = torch.mean(x[:, 1:, :], 1) + out_positive = x[:, 0, :] + num_speakers = out_anchor.size()[0] + + cos_sim_matrix = F.cosine_similarity( + out_positive.unsqueeze(-1).expand(-1, -1, num_speakers), + out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2), + ) + torch.clamp(self.w, 1e-6) + cos_sim_matrix = cos_sim_matrix * self.w + self.b + label = torch.arange(num_speakers).to(cos_sim_matrix.device) + L = self.criterion(cos_sim_matrix, label) + return L + + +class SoftmaxLoss(nn.Module): + """ + Implementation of the Softmax loss as defined in https://arxiv.org/abs/2003.11982 + Args: + - embedding_dim (float): speaker embedding dim + - n_speakers (float): number of speakers + """ + + def __init__(self, embedding_dim, n_speakers): + super().__init__() + + self.criterion = torch.nn.CrossEntropyLoss() + self.fc = nn.Linear(embedding_dim, n_speakers) + + print("Initialised Softmax Loss") + + def forward(self, x, label=None): + # reshape for compatibility + x = x.reshape(-1, x.size()[-1]) + label = label.reshape(-1) + + x = self.fc(x) + L = self.criterion(x, label) + + return L + + def inference(self, embedding): + x = self.fc(embedding) + activations = torch.nn.functional.softmax(x, dim=1).squeeze(0) + class_id = torch.argmax(activations) + return class_id + + +class SoftmaxAngleProtoLoss(nn.Module): + """ + Implementation of the Softmax AnglePrototypical loss as defined in https://arxiv.org/abs/2009.14153 + Args: + - embedding_dim (float): speaker embedding dim + - n_speakers (float): number of speakers + - init_w (float): defines the initial value of w + - init_b (float): definies the initial value of b + """ + + def __init__(self, embedding_dim, n_speakers, init_w=10.0, init_b=-5.0): + super().__init__() + + self.softmax = SoftmaxLoss(embedding_dim, n_speakers) + self.angleproto = AngleProtoLoss(init_w, init_b) + + print("Initialised SoftmaxAnglePrototypical Loss") + + def forward(self, x, label=None): + """ + Calculates the SoftmaxAnglePrototypical loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) + """ + + Lp = 
self.angleproto(x) + + Ls = self.softmax(x, label) + + return Ls + Lp diff --git a/viXTTS/TTS/encoder/models/__pycache__/base_encoder.cpython-310.pyc b/viXTTS/TTS/encoder/models/__pycache__/base_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3066e40767b35a0490a2f518ed6bd110ca282312 Binary files /dev/null and b/viXTTS/TTS/encoder/models/__pycache__/base_encoder.cpython-310.pyc differ diff --git a/viXTTS/TTS/encoder/models/__pycache__/lstm.cpython-310.pyc b/viXTTS/TTS/encoder/models/__pycache__/lstm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdedb0bde7d1dac4595a301f307a43a7050ef0e5 Binary files /dev/null and b/viXTTS/TTS/encoder/models/__pycache__/lstm.cpython-310.pyc differ diff --git a/viXTTS/TTS/encoder/models/__pycache__/resnet.cpython-310.pyc b/viXTTS/TTS/encoder/models/__pycache__/resnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4acb7fed1a718740de3511baf45a92b793d86336 Binary files /dev/null and b/viXTTS/TTS/encoder/models/__pycache__/resnet.cpython-310.pyc differ diff --git a/viXTTS/TTS/encoder/models/base_encoder.py b/viXTTS/TTS/encoder/models/base_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..957ea3c4ca719c2a054c93382787909e418288b2 --- /dev/null +++ b/viXTTS/TTS/encoder/models/base_encoder.py @@ -0,0 +1,161 @@ +import numpy as np +import torch +import torchaudio +from coqpit import Coqpit +from torch import nn + +from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss +from TTS.utils.generic_utils import set_init_dict +from TTS.utils.io import load_fsspec + + +class PreEmphasis(nn.Module): + def __init__(self, coefficient=0.97): + super().__init__() + self.coefficient = coefficient + self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0)) + + def forward(self, x): + assert len(x.size()) == 2 + + x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect") + return torch.nn.functional.conv1d(x, self.filter).squeeze(1) + + +class BaseEncoder(nn.Module): + """Base `encoder` class. Every new `encoder` model must inherit this. + + It defines common `encoder` specific functions. 
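All three loss classes above consume embeddings grouped as `(num_classes_in_batch, num_utters_per_class, embedding_dim)`. A minimal usage sketch with random tensors standing in for real d-vectors, assuming the `TTS` package from this diff is importable:

```python
# Hedged usage sketch, not part of the diff: exercising the encoder losses
# with dummy (N, M, D) embeddings.
import torch

from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss

num_speakers, utts_per_speaker, embed_dim = 4, 3, 256
dvecs = torch.randn(num_speakers, utts_per_speaker, embed_dim)

print(GE2ELoss(loss_method="softmax")(dvecs))  # scalar GE2E loss
print(AngleProtoLoss()(dvecs))                 # scalar angular prototypical loss

# The softmax branch additionally needs one class id per utterance.
labels = torch.arange(num_speakers).repeat_interleave(utts_per_speaker)
print(SoftmaxAngleProtoLoss(embedding_dim=embed_dim, n_speakers=num_speakers)(dvecs, labels))
```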
+ """ + + # pylint: disable=W0102 + def __init__(self): + super(BaseEncoder, self).__init__() + + def get_torch_mel_spectrogram_class(self, audio_config): + return torch.nn.Sequential( + PreEmphasis(audio_config["preemphasis"]), + # TorchSTFT( + # n_fft=audio_config["fft_size"], + # hop_length=audio_config["hop_length"], + # win_length=audio_config["win_length"], + # sample_rate=audio_config["sample_rate"], + # window="hamming_window", + # mel_fmin=0.0, + # mel_fmax=None, + # use_htk=True, + # do_amp_to_db=False, + # n_mels=audio_config["num_mels"], + # power=2.0, + # use_mel=True, + # mel_norm=None, + # ) + torchaudio.transforms.MelSpectrogram( + sample_rate=audio_config["sample_rate"], + n_fft=audio_config["fft_size"], + win_length=audio_config["win_length"], + hop_length=audio_config["hop_length"], + window_fn=torch.hamming_window, + n_mels=audio_config["num_mels"], + ), + ) + + @torch.no_grad() + def inference(self, x, l2_norm=True): + return self.forward(x, l2_norm) + + @torch.no_grad() + def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True): + """ + Generate embeddings for a batch of utterances + x: 1xTxD + """ + # map to the waveform size + if self.use_torch_spec: + num_frames = num_frames * self.audio_config["hop_length"] + + max_len = x.shape[1] + + if max_len < num_frames: + num_frames = max_len + + offsets = np.linspace(0, max_len - num_frames, num=num_eval) + + frames_batch = [] + for offset in offsets: + offset = int(offset) + end_offset = int(offset + num_frames) + frames = x[:, offset:end_offset] + frames_batch.append(frames) + + frames_batch = torch.cat(frames_batch, dim=0) + embeddings = self.inference(frames_batch, l2_norm=l2_norm) + + if return_mean: + embeddings = torch.mean(embeddings, dim=0, keepdim=True) + return embeddings + + def get_criterion(self, c: Coqpit, num_classes=None): + if c.loss == "ge2e": + criterion = GE2ELoss(loss_method="softmax") + elif c.loss == "angleproto": + criterion = AngleProtoLoss() + elif c.loss == "softmaxproto": + criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes) + else: + raise Exception("The %s not is a loss supported" % c.loss) + return criterion + + def load_checkpoint( + self, + config: Coqpit, + checkpoint_path: str, + eval: bool = False, + use_cuda: bool = False, + criterion=None, + cache=False, + ): + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + try: + self.load_state_dict(state["model"]) + print(" > Model fully restored. 
") + except (KeyError, RuntimeError) as error: + # If eval raise the error + if eval: + raise error + + print(" > Partial model initialization.") + model_dict = self.state_dict() + model_dict = set_init_dict(model_dict, state["model"], c) + self.load_state_dict(model_dict) + del model_dict + + # load the criterion for restore_path + if criterion is not None and "criterion" in state: + try: + criterion.load_state_dict(state["criterion"]) + except (KeyError, RuntimeError) as error: + print(" > Criterion load ignored because of:", error) + + # instance and load the criterion for the encoder classifier in inference time + if ( + eval + and criterion is None + and "criterion" in state + and getattr(config, "map_classid_to_classname", None) is not None + ): + criterion = self.get_criterion(config, len(config.map_classid_to_classname)) + criterion.load_state_dict(state["criterion"]) + + if use_cuda: + self.cuda() + if criterion is not None: + criterion = criterion.cuda() + + if eval: + self.eval() + assert not self.training + + if not eval: + return criterion, state["step"] + return criterion diff --git a/viXTTS/TTS/encoder/models/lstm.py b/viXTTS/TTS/encoder/models/lstm.py new file mode 100644 index 0000000000000000000000000000000000000000..51852b5b820d181824b0db1a205cd5d7bd4fb20d --- /dev/null +++ b/viXTTS/TTS/encoder/models/lstm.py @@ -0,0 +1,99 @@ +import torch +from torch import nn + +from TTS.encoder.models.base_encoder import BaseEncoder + + +class LSTMWithProjection(nn.Module): + def __init__(self, input_size, hidden_size, proj_size): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.proj_size = proj_size + self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True) + self.linear = nn.Linear(hidden_size, proj_size, bias=False) + + def forward(self, x): + self.lstm.flatten_parameters() + o, (_, _) = self.lstm(x) + return self.linear(o) + + +class LSTMWithoutProjection(nn.Module): + def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers): + super().__init__() + self.lstm = nn.LSTM(input_size=input_dim, hidden_size=lstm_dim, num_layers=num_lstm_layers, batch_first=True) + self.linear = nn.Linear(lstm_dim, proj_dim, bias=True) + self.relu = nn.ReLU() + + def forward(self, x): + _, (hidden, _) = self.lstm(x) + return self.relu(self.linear(hidden[-1])) + + +class LSTMSpeakerEncoder(BaseEncoder): + def __init__( + self, + input_dim, + proj_dim=256, + lstm_dim=768, + num_lstm_layers=3, + use_lstm_with_projection=True, + use_torch_spec=False, + audio_config=None, + ): + super().__init__() + self.use_lstm_with_projection = use_lstm_with_projection + self.use_torch_spec = use_torch_spec + self.audio_config = audio_config + self.proj_dim = proj_dim + + layers = [] + # choise LSTM layer + if use_lstm_with_projection: + layers.append(LSTMWithProjection(input_dim, lstm_dim, proj_dim)) + for _ in range(num_lstm_layers - 1): + layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim)) + self.layers = nn.Sequential(*layers) + else: + self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers) + + self.instancenorm = nn.InstanceNorm1d(input_dim) + + if self.use_torch_spec: + self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config) + else: + self.torch_spec = None + + self._init_layers() + + def _init_layers(self): + for name, param in self.layers.named_parameters(): + if "bias" in name: + nn.init.constant_(param, 0.0) + elif "weight" in name: + nn.init.xavier_normal_(param) + + def forward(self, x, 
l2_norm=True): + """Forward pass of the model. + + Args: + x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True` + to compute the spectrogram on-the-fly. + l2_norm (bool): Whether to L2-normalize the outputs. + + Shapes: + - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})` + """ + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + if self.use_torch_spec: + x.squeeze_(1) + x = self.torch_spec(x) + x = self.instancenorm(x).transpose(1, 2) + d = self.layers(x) + if self.use_lstm_with_projection: + d = d[:, -1] + if l2_norm: + d = torch.nn.functional.normalize(d, p=2, dim=1) + return d diff --git a/viXTTS/TTS/encoder/models/resnet.py b/viXTTS/TTS/encoder/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..5eafcd6005739fcdc454fb20def3e66791766a53 --- /dev/null +++ b/viXTTS/TTS/encoder/models/resnet.py @@ -0,0 +1,198 @@ +import torch +from torch import nn + +# from TTS.utils.audio.torch_transforms import TorchSTFT +from TTS.encoder.models.base_encoder import BaseEncoder + + +class SELayer(nn.Module): + def __init__(self, channel, reduction=8): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), + nn.Sigmoid(), + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +class SEBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8): + super(SEBasicBlock, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.se = SELayer(planes, reduction) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.relu(out) + out = self.bn1(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + return out + + +class ResNetSpeakerEncoder(BaseEncoder): + """Implementation of the model H/ASP without batch normalization in speaker embedding. 
This model was proposed in: https://arxiv.org/abs/2009.14153 + Adapted from: https://github.com/clovaai/voxceleb_trainer + """ + + # pylint: disable=W0102 + def __init__( + self, + input_dim=64, + proj_dim=512, + layers=[3, 4, 6, 3], + num_filters=[32, 64, 128, 256], + encoder_type="ASP", + log_input=False, + use_torch_spec=False, + audio_config=None, + ): + super(ResNetSpeakerEncoder, self).__init__() + + self.encoder_type = encoder_type + self.input_dim = input_dim + self.log_input = log_input + self.use_torch_spec = use_torch_spec + self.audio_config = audio_config + self.proj_dim = proj_dim + + self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1) + self.relu = nn.ReLU(inplace=True) + self.bn1 = nn.BatchNorm2d(num_filters[0]) + + self.inplanes = num_filters[0] + self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0]) + self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2)) + self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2)) + self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2)) + + self.instancenorm = nn.InstanceNorm1d(input_dim) + + if self.use_torch_spec: + self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config) + else: + self.torch_spec = None + + outmap_size = int(self.input_dim / 8) + + self.attention = nn.Sequential( + nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1), + nn.Softmax(dim=2), + ) + + if self.encoder_type == "SAP": + out_dim = num_filters[3] * outmap_size + elif self.encoder_type == "ASP": + out_dim = num_filters[3] * outmap_size * 2 + else: + raise ValueError("Undefined encoder") + + self.fc = nn.Linear(out_dim, proj_dim) + + self._init_layers() + + def _init_layers(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def create_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + # pylint: disable=R0201 + def new_parameter(self, *size): + out = nn.Parameter(torch.FloatTensor(*size)) + nn.init.xavier_normal_(out) + return out + + def forward(self, x, l2_norm=False): + """Forward pass of the model. + + Args: + x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True` + to compute the spectrogram on-the-fly. + l2_norm (bool): Whether to L2-normalize the outputs. 
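A hedged inference sketch for the encoders defined in these files, using `ResNetSpeakerEncoder` together with `compute_embedding` from `BaseEncoder`. The audio settings below are placeholder assumptions (they normally come from the trained encoder's config file), and `input_dim` must match `num_mels`:

```python
# Hypothetical usage sketch, not part of the diff.
import torch

from TTS.encoder.models.resnet import ResNetSpeakerEncoder

audio_config = {  # assumed values; normally read from the trained encoder's config
    "sample_rate": 16000,
    "fft_size": 512,
    "win_length": 400,
    "hop_length": 160,
    "num_mels": 64,
    "preemphasis": 0.97,
}
encoder = ResNetSpeakerEncoder(input_dim=64, proj_dim=512, use_torch_spec=True, audio_config=audio_config)
encoder.eval()

wav = torch.randn(1, 3 * 16000)             # (B, T): 3 s of placeholder audio
embedding = encoder.compute_embedding(wav)  # (1, 512): L2-normalized, averaged over several crops
print(embedding.shape)
```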
+ + Shapes: + - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})` + """ + x.squeeze_(1) + # if you torch spec compute it otherwise use the mel spec computed by the AP + if self.use_torch_spec: + x = self.torch_spec(x) + + if self.log_input: + x = (x + 1e-6).log() + x = self.instancenorm(x).unsqueeze(1) + + x = self.conv1(x) + x = self.relu(x) + x = self.bn1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = x.reshape(x.size()[0], -1, x.size()[-1]) + + w = self.attention(x) + + if self.encoder_type == "SAP": + x = torch.sum(x * w, dim=2) + elif self.encoder_type == "ASP": + mu = torch.sum(x * w, dim=2) + sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5)) + x = torch.cat((mu, sg), 1) + + x = x.view(x.size()[0], -1) + x = self.fc(x) + + if l2_norm: + x = torch.nn.functional.normalize(x, p=2, dim=1) + return x diff --git a/viXTTS/TTS/encoder/requirements.txt b/viXTTS/TTS/encoder/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a486cc45ddb44591bd03c9c0df294fbe98c13884 --- /dev/null +++ b/viXTTS/TTS/encoder/requirements.txt @@ -0,0 +1,2 @@ +umap-learn +numpy>=1.17.0 diff --git a/viXTTS/TTS/encoder/utils/__init__.py b/viXTTS/TTS/encoder/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/encoder/utils/__pycache__/__init__.cpython-310.pyc b/viXTTS/TTS/encoder/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e286431eb1743f3bbf2d8575c07dd8758929bf2 Binary files /dev/null and b/viXTTS/TTS/encoder/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/viXTTS/TTS/encoder/utils/__pycache__/generic_utils.cpython-310.pyc b/viXTTS/TTS/encoder/utils/__pycache__/generic_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e8bdf7b88fa6f5319d141e7145fef3926bf9df9 Binary files /dev/null and b/viXTTS/TTS/encoder/utils/__pycache__/generic_utils.cpython-310.pyc differ diff --git a/viXTTS/TTS/encoder/utils/generic_utils.py b/viXTTS/TTS/encoder/utils/generic_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..236d6fe937a9637fd86f06bea5fb45fad4ee2502 --- /dev/null +++ b/viXTTS/TTS/encoder/utils/generic_utils.py @@ -0,0 +1,136 @@ +import glob +import os +import random + +import numpy as np +from scipy import signal + +from TTS.encoder.models.lstm import LSTMSpeakerEncoder +from TTS.encoder.models.resnet import ResNetSpeakerEncoder + + +class AugmentWAV(object): + def __init__(self, ap, augmentation_config): + self.ap = ap + self.use_additive_noise = False + + if "additive" in augmentation_config.keys(): + self.additive_noise_config = augmentation_config["additive"] + additive_path = self.additive_noise_config["sounds_path"] + if additive_path: + self.use_additive_noise = True + # get noise types + self.additive_noise_types = [] + for key in self.additive_noise_config.keys(): + if isinstance(self.additive_noise_config[key], dict): + self.additive_noise_types.append(key) + + additive_files = glob.glob(os.path.join(additive_path, "**/*.wav"), recursive=True) + + self.noise_list = {} + + for wav_file in additive_files: + noise_dir = wav_file.replace(additive_path, "").split(os.sep)[0] + # ignore not listed directories + if noise_dir not in self.additive_noise_types: + continue + if not noise_dir in self.noise_list: + self.noise_list[noise_dir] = [] + 
self.noise_list[noise_dir].append(wav_file) + + print( + f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}" + ) + + self.use_rir = False + + if "rir" in augmentation_config.keys(): + self.rir_config = augmentation_config["rir"] + if self.rir_config["rir_path"]: + self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True) + self.use_rir = True + + print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audios instances") + + self.create_augmentation_global_list() + + def create_augmentation_global_list(self): + if self.use_additive_noise: + self.global_noise_list = self.additive_noise_types + else: + self.global_noise_list = [] + if self.use_rir: + self.global_noise_list.append("RIR_AUG") + + def additive_noise(self, noise_type, audio): + clean_db = 10 * np.log10(np.mean(audio**2) + 1e-4) + + noise_list = random.sample( + self.noise_list[noise_type], + random.randint( + self.additive_noise_config[noise_type]["min_num_noises"], + self.additive_noise_config[noise_type]["max_num_noises"], + ), + ) + + audio_len = audio.shape[0] + noises_wav = None + for noise in noise_list: + noiseaudio = self.ap.load_wav(noise, sr=self.ap.sample_rate)[:audio_len] + + if noiseaudio.shape[0] < audio_len: + continue + + noise_snr = random.uniform( + self.additive_noise_config[noise_type]["min_snr_in_db"], + self.additive_noise_config[noise_type]["max_num_noises"], + ) + noise_db = 10 * np.log10(np.mean(noiseaudio**2) + 1e-4) + noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio + + if noises_wav is None: + noises_wav = noise_wav + else: + noises_wav += noise_wav + + # if all possible files is less than audio, choose other files + if noises_wav is None: + return self.additive_noise(noise_type, audio) + + return audio + noises_wav + + def reverberate(self, audio): + audio_len = audio.shape[0] + + rir_file = random.choice(self.rir_files) + rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate) + rir = rir / np.sqrt(np.sum(rir**2)) + return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len] + + def apply_one(self, audio): + noise_type = random.choice(self.global_noise_list) + if noise_type == "RIR_AUG": + return self.reverberate(audio) + + return self.additive_noise(noise_type, audio) + + +def setup_encoder_model(config: "Coqpit"): + if config.model_params["model_name"].lower() == "lstm": + model = LSTMSpeakerEncoder( + config.model_params["input_dim"], + config.model_params["proj_dim"], + config.model_params["lstm_dim"], + config.model_params["num_lstm_layers"], + use_torch_spec=config.model_params.get("use_torch_spec", False), + audio_config=config.audio, + ) + elif config.model_params["model_name"].lower() == "resnet": + model = ResNetSpeakerEncoder( + input_dim=config.model_params["input_dim"], + proj_dim=config.model_params["proj_dim"], + log_input=config.model_params.get("log_input", False), + use_torch_spec=config.model_params.get("use_torch_spec", False), + audio_config=config.audio, + ) + return model diff --git a/viXTTS/TTS/encoder/utils/prepare_voxceleb.py b/viXTTS/TTS/encoder/utils/prepare_voxceleb.py new file mode 100644 index 0000000000000000000000000000000000000000..b93baf9e60f0d5c35a4e86f6746e29f6097174b5 --- /dev/null +++ b/viXTTS/TTS/encoder/utils/prepare_voxceleb.py @@ -0,0 +1,219 @@ +# coding=utf-8 +# Copyright (C) 2020 ATHENA AUTHORS; Yiping Peng; Ne Luo +# All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Only support eager mode and TF>=2.0.0 +# pylint: disable=no-member, invalid-name, relative-beyond-top-level +# pylint: disable=too-many-locals, too-many-statements, too-many-arguments, too-many-instance-attributes +""" voxceleb 1 & 2 """ + +import hashlib +import os +import subprocess +import sys +import zipfile + +import pandas +import soundfile as sf +from absl import logging + +SUBSETS = { + "vox1_dev_wav": [ + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partaa", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partab", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partac", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partad", + ], + "vox1_test_wav": ["https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip"], + "vox2_dev_aac": [ + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaa", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partab", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partac", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partad", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partae", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaf", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partag", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partah", + ], + "vox2_test_aac": ["https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_test_aac.zip"], +} + +MD5SUM = { + "vox1_dev_wav": "ae63e55b951748cc486645f532ba230b", + "vox2_dev_aac": "bbc063c46078a602ca71605645c2a402", + "vox1_test_wav": "185fdc63c3c739954633d50379a3d102", + "vox2_test_aac": "0d2b3ea430a821c33263b5ea37ede312", +} + +USER = {"user": "", "password": ""} + +speaker_id_dict = {} + + +def download_and_extract(directory, subset, urls): + """Download and extract the given split of dataset. + + Args: + directory: the directory where to put the downloaded data. + subset: subset name of the corpus. + urls: the list of urls to download the data file. 
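For orientation, a hedged driver sketch for this script: the `processor` function defined further below chains `download_and_extract` and `convert_audio_and_make_label` for one subset. The directory and credentials are placeholders, and VoxCeleb access credentials are required:

```python
# Hypothetical usage sketch, not part of the diff.
from TTS.encoder.utils.prepare_voxceleb import USER, processor

USER["user"] = "your_voxceleb_username"      # placeholder credentials
USER["password"] = "your_voxceleb_password"

csv_path = processor("/data/voxceleb", "vox1_test_wav", force_process=False)
print(csv_path)  # -> /data/voxceleb/vox1_test_wav.csv
```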
+ """ + os.makedirs(directory, exist_ok=True) + + try: + for url in urls: + zip_filepath = os.path.join(directory, url.split("/")[-1]) + if os.path.exists(zip_filepath): + continue + logging.info("Downloading %s to %s" % (url, zip_filepath)) + subprocess.call( + "wget %s --user %s --password %s -O %s" % (url, USER["user"], USER["password"], zip_filepath), + shell=True, + ) + + statinfo = os.stat(zip_filepath) + logging.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size)) + + # concatenate all parts into zip files + if ".zip" not in zip_filepath: + zip_filepath = "_".join(zip_filepath.split("_")[:-1]) + subprocess.call("cat %s* > %s.zip" % (zip_filepath, zip_filepath), shell=True) + zip_filepath += ".zip" + extract_path = zip_filepath.strip(".zip") + + # check zip file md5sum + with open(zip_filepath, "rb") as f_zip: + md5 = hashlib.md5(f_zip.read()).hexdigest() + if md5 != MD5SUM[subset]: + raise ValueError("md5sum of %s mismatch" % zip_filepath) + + with zipfile.ZipFile(zip_filepath, "r") as zfile: + zfile.extractall(directory) + extract_path_ori = os.path.join(directory, zfile.infolist()[0].filename) + subprocess.call("mv %s %s" % (extract_path_ori, extract_path), shell=True) + finally: + # os.remove(zip_filepath) + pass + + +def exec_cmd(cmd): + """Run a command in a subprocess. + Args: + cmd: command line to be executed. + Return: + int, the return code. + """ + try: + retcode = subprocess.call(cmd, shell=True) + if retcode < 0: + logging.info(f"Child was terminated by signal {retcode}") + except OSError as e: + logging.info(f"Execution failed: {e}") + retcode = -999 + return retcode + + +def decode_aac_with_ffmpeg(aac_file, wav_file): + """Decode a given AAC file into WAV using ffmpeg. + Args: + aac_file: file path to input AAC file. + wav_file: file path to output WAV file. + Return: + bool, True if success. + """ + cmd = f"ffmpeg -i {aac_file} {wav_file}" + logging.info(f"Decoding aac file using command line: {cmd}") + ret = exec_cmd(cmd) + if ret != 0: + logging.error(f"Failed to decode aac file with retcode {ret}") + logging.error("Please check your ffmpeg installation.") + return False + return True + + +def convert_audio_and_make_label(input_dir, subset, output_dir, output_file): + """Optionally convert AAC to WAV and make speaker labels. + Args: + input_dir: the directory which holds the input dataset. + subset: the name of the specified subset. e.g. vox1_dev_wav + output_dir: the directory to place the newly generated csv files. + output_file: the name of the newly generated csv file. e.g. vox1_dev_wav.csv + """ + + logging.info("Preprocessing audio and label for subset %s" % subset) + source_dir = os.path.join(input_dir, subset) + + files = [] + # Convert all AAC file into WAV format. At the same time, generate the csv + for root, _, filenames in os.walk(source_dir): + for filename in filenames: + name, ext = os.path.splitext(filename) + if ext.lower() == ".wav": + _, ext2 = os.path.splitext(name) + if ext2: + continue + wav_file = os.path.join(root, filename) + elif ext.lower() == ".m4a": + # Convert AAC to WAV. 
+ aac_file = os.path.join(root, filename) + wav_file = aac_file + ".wav" + if not os.path.exists(wav_file): + if not decode_aac_with_ffmpeg(aac_file, wav_file): + raise RuntimeError("Audio decoding failed.") + else: + continue + speaker_name = root.split(os.path.sep)[-2] + if speaker_name not in speaker_id_dict: + num = len(speaker_id_dict) + speaker_id_dict[speaker_name] = num + # wav_filesize = os.path.getsize(wav_file) + wav_length = len(sf.read(wav_file)[0]) + files.append((os.path.abspath(wav_file), wav_length, speaker_id_dict[speaker_name], speaker_name)) + + # Write to CSV file which contains four columns: + # "wav_filename", "wav_length_ms", "speaker_id", "speaker_name". + csv_file_path = os.path.join(output_dir, output_file) + df = pandas.DataFrame(data=files, columns=["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"]) + df.to_csv(csv_file_path, index=False, sep="\t") + logging.info("Successfully generated csv file {}".format(csv_file_path)) + + +def processor(directory, subset, force_process): + """download and process""" + urls = SUBSETS + if subset not in urls: + raise ValueError(subset, "is not in voxceleb") + + subset_csv = os.path.join(directory, subset + ".csv") + if not force_process and os.path.exists(subset_csv): + return subset_csv + + logging.info("Downloading and process the voxceleb in %s", directory) + logging.info("Preparing subset %s", subset) + download_and_extract(directory, subset, urls[subset]) + convert_audio_and_make_label(directory, subset, directory, subset + ".csv") + logging.info("Finished downloading and processing") + return subset_csv + + +if __name__ == "__main__": + logging.set_verbosity(logging.INFO) + if len(sys.argv) != 4: + print("Usage: python prepare_data.py save_directory user password") + sys.exit() + + DIR, USER["user"], USER["password"] = sys.argv[1], sys.argv[2], sys.argv[3] + for SUBSET in SUBSETS: + processor(DIR, SUBSET, False) diff --git a/viXTTS/TTS/encoder/utils/training.py b/viXTTS/TTS/encoder/utils/training.py new file mode 100644 index 0000000000000000000000000000000000000000..ff8f271d80c40ff8fa5bbb824615c19d0f99d19d --- /dev/null +++ b/viXTTS/TTS/encoder/utils/training.py @@ -0,0 +1,99 @@ +import os +from dataclasses import dataclass, field + +from coqpit import Coqpit +from trainer import TrainerArgs, get_last_checkpoint +from trainer.io import copy_model_files +from trainer.logging import logger_factory +from trainer.logging.console_logger import ConsoleLogger + +from TTS.config import load_config, register_config +from TTS.tts.utils.text.characters import parse_symbols +from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch + + +@dataclass +class TrainArgs(TrainerArgs): + config_path: str = field(default=None, metadata={"help": "Path to the config file."}) + + +def getarguments(): + train_config = TrainArgs() + parser = train_config.init_argparse(arg_prefix="") + return parser + + +def process_args(args, config=None): + """Process parsed comand line arguments and initialize the config if not provided. + Args: + args (argparse.Namespace or dict like): Parsed input arguments. + config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None. + Returns: + c (TTS.utils.io.AttrDict): Config paramaters. + out_path (str): Path to save models and logging. + audio_path (str): Path to save generated test audios. + c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does + logging to the console. 
+ dashboard_logger (WandbLogger or TensorboardLogger): Class that does the dashboard Logging + TODO: + - Interactive config definition. + """ + if isinstance(args, tuple): + args, coqpit_overrides = args + if args.continue_path: + # continue a previous training from its output folder + experiment_path = args.continue_path + args.config_path = os.path.join(args.continue_path, "config.json") + args.restore_path, best_model = get_last_checkpoint(args.continue_path) + if not args.best_path: + args.best_path = best_model + # init config if not already defined + if config is None: + if args.config_path: + # init from a file + config = load_config(args.config_path) + else: + # init from console args + from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel + + config_base = BaseTrainingConfig() + config_base.parse_known_args(coqpit_overrides) + config = register_config(config_base.model)() + # override values from command-line args + config.parse_known_args(coqpit_overrides, relaxed_parser=True) + experiment_path = args.continue_path + if not experiment_path: + experiment_path = get_experiment_folder_path(config.output_path, config.run_name) + audio_path = os.path.join(experiment_path, "test_audios") + config.output_log_path = experiment_path + # setup rank 0 process in distributed training + dashboard_logger = None + if args.rank == 0: + new_fields = {} + if args.restore_path: + new_fields["restore_path"] = args.restore_path + new_fields["github_branch"] = get_git_branch() + # if model characters are not set in the config file + # save the default set to the config file for future + # compatibility. + if config.has("characters") and config.characters is None: + used_characters = parse_symbols() + new_fields["characters"] = used_characters + copy_model_files(config, experiment_path, new_fields) + dashboard_logger = logger_factory(config, experiment_path) + c_logger = ConsoleLogger() + return config, experiment_path, audio_path, c_logger, dashboard_logger + + +def init_arguments(): + train_config = TrainArgs() + parser = train_config.init_argparse(arg_prefix="") + return parser + + +def init_training(config: Coqpit = None): + """Initialization of a training run.""" + parser = init_arguments() + args = parser.parse_known_args() + config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = process_args(args, config) + return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger diff --git a/viXTTS/TTS/encoder/utils/visual.py b/viXTTS/TTS/encoder/utils/visual.py new file mode 100644 index 0000000000000000000000000000000000000000..6575b86ec22818fe1dc0c1e6336a7fd255855330 --- /dev/null +++ b/viXTTS/TTS/encoder/utils/visual.py @@ -0,0 +1,50 @@ +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import umap + +matplotlib.use("Agg") + + +colormap = ( + np.array( + [ + [76, 255, 0], + [0, 127, 70], + [255, 0, 0], + [255, 217, 38], + [0, 135, 255], + [165, 0, 165], + [255, 167, 255], + [0, 255, 255], + [255, 96, 38], + [142, 76, 0], + [33, 0, 127], + [0, 0, 0], + [183, 183, 183], + ], + dtype=float, + ) + / 255 +) + + +def plot_embeddings(embeddings, num_classes_in_batch): + num_utter_per_class = embeddings.shape[0] // num_classes_in_batch + + # if necessary get just the first 10 classes + if num_classes_in_batch > 10: + num_classes_in_batch = 10 + embeddings = embeddings[: num_classes_in_batch * num_utter_per_class] + + model = umap.UMAP() + projection = model.fit_transform(embeddings) + ground_truth = 
np.repeat(np.arange(num_classes_in_batch), num_utter_per_class) + colors = [colormap[i] for i in ground_truth] + fig, ax = plt.subplots(figsize=(16, 10)) + _ = ax.scatter(projection[:, 0], projection[:, 1], c=colors) + plt.gca().set_aspect("equal", "datalim") + plt.title("UMAP projection") + plt.tight_layout() + plt.savefig("umap") + return fig diff --git a/viXTTS/TTS/model.py b/viXTTS/TTS/model.py new file mode 100644 index 0000000000000000000000000000000000000000..ae6be7b444695756c00c4faa8f2f6c787dfcf9d8 --- /dev/null +++ b/viXTTS/TTS/model.py @@ -0,0 +1,59 @@ +from abc import abstractmethod +from typing import Dict + +import torch +from coqpit import Coqpit +from trainer import TrainerModel + +# pylint: skip-file + + +class BaseTrainerModel(TrainerModel): + """BaseTrainerModel model expanding TrainerModel with required functions by 🐸TTS. + + Every new 🐸TTS model must inherit it. + """ + + @staticmethod + @abstractmethod + def init_from_config(config: Coqpit): + """Init the model and all its attributes from the given config. + + Override this depending on your model. + """ + ... + + @abstractmethod + def inference(self, input: torch.Tensor, aux_input={}) -> Dict: + """Forward pass for inference. + + It must return a dictionary with the main model output and all the auxiliary outputs. The key ```model_outputs``` + is considered to be the main output and you can add any other auxiliary outputs as you want. + + We don't use `*kwargs` since it is problematic with the TorchScript API. + + Args: + input (torch.Tensor): [description] + aux_input (Dict): Auxiliary inputs like speaker embeddings, durations etc. + + Returns: + Dict: [description] + """ + outputs_dict = {"model_outputs": None} + ... + return outputs_dict + + @abstractmethod + def load_checkpoint( + self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False + ) -> None: + """Load a model checkpoint gile and get ready for training or inference. + + Args: + config (Coqpit): Model configuration. + checkpoint_path (str): Path to the model checkpoint file. + eval (bool, optional): If true, init model for inference else for training. Defaults to False. + strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True. + cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False. + """ + ... diff --git a/viXTTS/TTS/server/README.md b/viXTTS/TTS/server/README.md new file mode 100644 index 0000000000000000000000000000000000000000..270656c4e39dc11636efbb1ba51eba7c9b4a8f04 --- /dev/null +++ b/viXTTS/TTS/server/README.md @@ -0,0 +1,18 @@ +# :frog: TTS demo server +Before you use the server, make sure you [install](https://github.com/coqui-ai/TTS/tree/dev#install-tts)) :frog: TTS properly. Then, you can follow the steps below. + +**Note:** If you install :frog:TTS using ```pip```, you can also use the ```tts-server``` end point on the terminal. + +Examples runs: + +List officially released models. +```python TTS/server/server.py --list_models ``` + +Run the server with the official models. +```python TTS/server/server.py --model_name tts_models/en/ljspeech/tacotron2-DCA --vocoder_name vocoder_models/en/ljspeech/multiband-melgan``` + +Run the server with the official models on a GPU. 
+```CUDA_VISIBLE_DEVICES="0" python TTS/server/server.py --model_name tts_models/en/ljspeech/tacotron2-DCA --vocoder_name vocoder_models/en/ljspeech/multiband-melgan --use_cuda True``` + +Run the server with a custom models. +```python TTS/server/server.py --tts_checkpoint /path/to/tts/model.pth --tts_config /path/to/tts/config.json --vocoder_checkpoint /path/to/vocoder/model.pth --vocoder_config /path/to/vocoder/config.json``` diff --git a/viXTTS/TTS/server/__init__.py b/viXTTS/TTS/server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/server/conf.json b/viXTTS/TTS/server/conf.json new file mode 100644 index 0000000000000000000000000000000000000000..49b6c09c3848a224dfb39a1f653aa1b289a4b6e5 --- /dev/null +++ b/viXTTS/TTS/server/conf.json @@ -0,0 +1,12 @@ +{ + "tts_path":"/media/erogol/data_ssd/Models/libri_tts/5049/", // tts model root folder + "tts_file":"best_model.pth", // tts checkpoint file + "tts_config":"config.json", // tts config.json file + "tts_speakers": null, // json file listing speaker ids. null if no speaker embedding. + "vocoder_config":null, + "vocoder_file": null, + "is_wavernn_batched":true, + "port": 5002, + "use_cuda": true, + "debug": true +} diff --git a/viXTTS/TTS/server/server.py b/viXTTS/TTS/server/server.py new file mode 100644 index 0000000000000000000000000000000000000000..6b2141a9aa419b9095956ccae317621fa3a604da --- /dev/null +++ b/viXTTS/TTS/server/server.py @@ -0,0 +1,258 @@ +#!flask/bin/python +import argparse +import io +import json +import os +import sys +from pathlib import Path +from threading import Lock +from typing import Union +from urllib.parse import parse_qs + +from flask import Flask, render_template, render_template_string, request, send_file + +from TTS.config import load_config +from TTS.utils.manage import ModelManager +from TTS.utils.synthesizer import Synthesizer + + +def create_argparser(): + def convert_boolean(x): + return x.lower() in ["true", "1", "yes"] + + parser = argparse.ArgumentParser() + parser.add_argument( + "--list_models", + type=convert_boolean, + nargs="?", + const=True, + default=False, + help="list available pre-trained tts and vocoder models.", + ) + parser.add_argument( + "--model_name", + type=str, + default="tts_models/en/ljspeech/tacotron2-DDC", + help="Name of one of the pre-trained tts models in format //", + ) + parser.add_argument("--vocoder_name", type=str, default=None, help="name of one of the released vocoder models.") + + # Args for running custom models + parser.add_argument("--config_path", default=None, type=str, help="Path to model config file.") + parser.add_argument( + "--model_path", + type=str, + default=None, + help="Path to model file.", + ) + parser.add_argument( + "--vocoder_path", + type=str, + help="Path to vocoder model file. If it is not defined, model uses GL as vocoder. 
Please make sure that you installed vocoder library before (WaveRNN).", + default=None, + ) + parser.add_argument("--vocoder_config_path", type=str, help="Path to vocoder model config file.", default=None) + parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None) + parser.add_argument("--port", type=int, default=5002, help="port to listen on.") + parser.add_argument("--use_cuda", type=convert_boolean, default=False, help="true to use CUDA.") + parser.add_argument("--debug", type=convert_boolean, default=False, help="true to enable Flask debug mode.") + parser.add_argument("--show_details", type=convert_boolean, default=False, help="Generate model detail page.") + return parser + + +# parse the args +args = create_argparser().parse_args() + +path = Path(__file__).parent / "../.models.json" +manager = ModelManager(path) + +if args.list_models: + manager.list_models() + sys.exit() + +# update in-use models to the specified released models. +model_path = None +config_path = None +speakers_file_path = None +vocoder_path = None +vocoder_config_path = None + +# CASE1: list pre-trained TTS models +if args.list_models: + manager.list_models() + sys.exit() + +# CASE2: load pre-trained model paths +if args.model_name is not None and not args.model_path: + model_path, config_path, model_item = manager.download_model(args.model_name) + args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name + +if args.vocoder_name is not None and not args.vocoder_path: + vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name) + +# CASE3: set custom model paths +if args.model_path is not None: + model_path = args.model_path + config_path = args.config_path + speakers_file_path = args.speakers_file_path + +if args.vocoder_path is not None: + vocoder_path = args.vocoder_path + vocoder_config_path = args.vocoder_config_path + +# load models +synthesizer = Synthesizer( + tts_checkpoint=model_path, + tts_config_path=config_path, + tts_speakers_file=speakers_file_path, + tts_languages_file=None, + vocoder_checkpoint=vocoder_path, + vocoder_config=vocoder_config_path, + encoder_checkpoint="", + encoder_config="", + use_cuda=args.use_cuda, +) + +use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and ( + synthesizer.tts_model.num_speakers > 1 or synthesizer.tts_speakers_file is not None +) +speaker_manager = getattr(synthesizer.tts_model, "speaker_manager", None) + +use_multi_language = hasattr(synthesizer.tts_model, "num_languages") and ( + synthesizer.tts_model.num_languages > 1 or synthesizer.tts_languages_file is not None +) +language_manager = getattr(synthesizer.tts_model, "language_manager", None) + +# TODO: set this from SpeakerManager +use_gst = synthesizer.tts_config.get("use_gst", False) +app = Flask(__name__) + + +def style_wav_uri_to_dict(style_wav: str) -> Union[str, dict]: + """Transform an uri style_wav, in either a string (path to wav file to be use for style transfer) + or a dict (gst tokens/values to be use for styling) + + Args: + style_wav (str): uri + + Returns: + Union[str, dict]: path to file (str) or gst style (dict) + """ + if style_wav: + if os.path.isfile(style_wav) and style_wav.endswith(".wav"): + return style_wav # style_wav is a .wav file located on the server + + style_wav = json.loads(style_wav) + return style_wav # style_wav is a gst dictionary with {token1_id : token1_weigth, ...} + return None + + +@app.route("/") +def index(): + return render_template( 
+ "index.html", + show_details=args.show_details, + use_multi_speaker=use_multi_speaker, + use_multi_language=use_multi_language, + speaker_ids=speaker_manager.name_to_id if speaker_manager is not None else None, + language_ids=language_manager.name_to_id if language_manager is not None else None, + use_gst=use_gst, + ) + + +@app.route("/details") +def details(): + if args.config_path is not None and os.path.isfile(args.config_path): + model_config = load_config(args.config_path) + else: + if args.model_name is not None: + model_config = load_config(config_path) + + if args.vocoder_config_path is not None and os.path.isfile(args.vocoder_config_path): + vocoder_config = load_config(args.vocoder_config_path) + else: + if args.vocoder_name is not None: + vocoder_config = load_config(vocoder_config_path) + else: + vocoder_config = None + + return render_template( + "details.html", + show_details=args.show_details, + model_config=model_config, + vocoder_config=vocoder_config, + args=args.__dict__, + ) + + +lock = Lock() + + +@app.route("/api/tts", methods=["GET", "POST"]) +def tts(): + with lock: + text = request.headers.get("text") or request.values.get("text", "") + speaker_idx = request.headers.get("speaker-id") or request.values.get("speaker_id", "") + language_idx = request.headers.get("language-id") or request.values.get("language_id", "") + style_wav = request.headers.get("style-wav") or request.values.get("style_wav", "") + style_wav = style_wav_uri_to_dict(style_wav) + + print(f" > Model input: {text}") + print(f" > Speaker Idx: {speaker_idx}") + print(f" > Language Idx: {language_idx}") + wavs = synthesizer.tts(text, speaker_name=speaker_idx, language_name=language_idx, style_wav=style_wav) + out = io.BytesIO() + synthesizer.save_wav(wavs, out) + return send_file(out, mimetype="audio/wav") + + +# Basic MaryTTS compatibility layer + + +@app.route("/locales", methods=["GET"]) +def mary_tts_api_locales(): + """MaryTTS-compatible /locales endpoint""" + # NOTE: We currently assume there is only one model active at the same time + if args.model_name is not None: + model_details = args.model_name.split("/") + else: + model_details = ["", "en", "", "default"] + return render_template_string("{{ locale }}\n", locale=model_details[1]) + + +@app.route("/voices", methods=["GET"]) +def mary_tts_api_voices(): + """MaryTTS-compatible /voices endpoint""" + # NOTE: We currently assume there is only one model active at the same time + if args.model_name is not None: + model_details = args.model_name.split("/") + else: + model_details = ["", "en", "", "default"] + return render_template_string( + "{{ name }} {{ locale }} {{ gender }}\n", name=model_details[3], locale=model_details[1], gender="u" + ) + + +@app.route("/process", methods=["GET", "POST"]) +def mary_tts_api_process(): + """MaryTTS-compatible /process endpoint""" + with lock: + if request.method == "POST": + data = parse_qs(request.get_data(as_text=True)) + # NOTE: we ignore param. 
LOCALE and VOICE for now since we have only one active model + text = data.get("INPUT_TEXT", [""])[0] + else: + text = request.args.get("INPUT_TEXT", "") + print(f" > Model input: {text}") + wavs = synthesizer.tts(text) + out = io.BytesIO() + synthesizer.save_wav(wavs, out) + return send_file(out, mimetype="audio/wav") + + +def main(): + app.run(debug=args.debug, host="::", port=args.port) + + +if __name__ == "__main__": + main() diff --git a/viXTTS/TTS/server/static/coqui-log-green-TTS.png b/viXTTS/TTS/server/static/coqui-log-green-TTS.png new file mode 100644 index 0000000000000000000000000000000000000000..6ad188b8c03a170097c0393c6769996f03cf9054 Binary files /dev/null and b/viXTTS/TTS/server/static/coqui-log-green-TTS.png differ diff --git a/viXTTS/TTS/server/templates/details.html b/viXTTS/TTS/server/templates/details.html new file mode 100644 index 0000000000000000000000000000000000000000..51c9ed85a83ac0aab045623ee1e6c430fbe51b9d --- /dev/null +++ b/viXTTS/TTS/server/templates/details.html @@ -0,0 +1,131 @@ + + + + + + + + + + + TTS engine + + + + + + + + + + Fork me on GitHub + + {% if show_details == true %} + +
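A hedged client-side sketch (not part of the diff) for the `/api/tts` endpoint defined in `server.py` above, assuming the server is running locally on the default port 5002 and that `requests` is installed; `speaker_id` and `language_id` parameters are only needed for multi-speaker or multilingual models:

```python
# Hypothetical client call against the Flask server defined above.
import requests

resp = requests.get(
    "http://localhost:5002/api/tts",
    params={"text": "Hello from the demo server."},
)
resp.raise_for_status()
with open("out.wav", "wb") as f:
    f.write(resp.content)  # the endpoint returns audio/wav bytes
```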
+ Model details +
+ +
+
+ CLI arguments: + + + + + + + {% for key, value in args.items() %} + + + + + + + {% endfor %} +
CLI key Value
{{ key }}{{ value }}
+
+

+ +
+ + {% if model_config != None %} + +
+ Model config: + + + + + + + + + {% for key, value in model_config.items() %} + + + + + + + {% endfor %} + +
Key Value
{{ key }}{{ value }}
+
+ + {% endif %} + +

+ + + +
+ {% if vocoder_config != None %} +
+ Vocoder model config: + + + + + + + + + {% for key, value in vocoder_config.items() %} + + + + + + + {% endfor %} + + +
Key Value
{{ key }}{{ value }}
+
+ {% endif %} +

+ + {% else %} +
+ Please start the server with --show_details=true to see details.
+ + {% endif %} + + + + \ No newline at end of file diff --git a/viXTTS/TTS/server/templates/index.html b/viXTTS/TTS/server/templates/index.html new file mode 100644 index 0000000000000000000000000000000000000000..6354d3919d9a1e9c1e22e9866c84c4eb8284bc13 --- /dev/null +++ b/viXTTS/TTS/server/templates/index.html @@ -0,0 +1,154 @@ + + + + + + + + + + + TTS engine + + + + + + + + + + Fork me on GitHub + + + + + +
+
+
+ + +
    +
+ + {%if use_gst%} + + {%endif%} + + +

+ + {%if use_multi_speaker%} + Choose a speaker: +

+ {%endif%} + + {%if use_multi_language%} + Choose a language: +

+ {%endif%} + + + {%if show_details%} +

+ {%endif%} + +

+
+
+
+ + + + + + + \ No newline at end of file diff --git a/viXTTS/TTS/tts/__init__.py b/viXTTS/TTS/tts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/tts/__pycache__/__init__.cpython-310.pyc b/viXTTS/TTS/tts/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09e6f28e63fcef3011877fd31b994b170de67234 Binary files /dev/null and b/viXTTS/TTS/tts/__pycache__/__init__.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/configs/__init__.py b/viXTTS/TTS/tts/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3146ac1c116cb807a81889b7a9ab223b9a051036 --- /dev/null +++ b/viXTTS/TTS/tts/configs/__init__.py @@ -0,0 +1,17 @@ +import importlib +import os +from inspect import isclass + +# import all files under configs/ +# configs_dir = os.path.dirname(__file__) +# for file in os.listdir(configs_dir): +# path = os.path.join(configs_dir, file) +# if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)): +# config_name = file[: file.find(".py")] if file.endswith(".py") else file +# module = importlib.import_module("TTS.tts.configs." + config_name) +# for attribute_name in dir(module): +# attribute = getattr(module, attribute_name) + +# if isclass(attribute): +# # Add the class to this package's variables +# globals()[attribute_name] = attribute diff --git a/viXTTS/TTS/tts/configs/__pycache__/__init__.cpython-310.pyc b/viXTTS/TTS/tts/configs/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f6b8c91b33038a5fa56244b380dd5dbda681627 Binary files /dev/null and b/viXTTS/TTS/tts/configs/__pycache__/__init__.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/configs/__pycache__/shared_configs.cpython-310.pyc b/viXTTS/TTS/tts/configs/__pycache__/shared_configs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0502e3ee1dee7550743d1f333ad41d0527e6748 Binary files /dev/null and b/viXTTS/TTS/tts/configs/__pycache__/shared_configs.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/configs/__pycache__/xtts_config.cpython-310.pyc b/viXTTS/TTS/tts/configs/__pycache__/xtts_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27ee39d0cc4173b2b0bb35daee37bb3850cdb650 Binary files /dev/null and b/viXTTS/TTS/tts/configs/__pycache__/xtts_config.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/configs/align_tts_config.py b/viXTTS/TTS/tts/configs/align_tts_config.py new file mode 100644 index 0000000000000000000000000000000000000000..317a01af53ce26914d83610a913eb44b5836dac2 --- /dev/null +++ b/viXTTS/TTS/tts/configs/align_tts_config.py @@ -0,0 +1,107 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.align_tts import AlignTTSArgs + + +@dataclass +class AlignTTSConfig(BaseTTSConfig): + """Defines parameters for AlignTTS model. + Example: + + >>> from TTS.tts.configs.align_tts_config import AlignTTSConfig + >>> config = AlignTTSConfig() + + Args: + model(str): + Model name used for selecting the right model at initialization. Defaults to `align_tts`. + positional_encoding (bool): + enable / disable positional encoding applied to the encoder output. Defaults to True. + hidden_channels (int): + Base number of hidden channels. 
Defines all the layers expect ones defined by the specific encoder or decoder + parameters. Defaults to 256. + hidden_channels_dp (int): + Number of hidden channels of the duration predictor's layers. Defaults to 256. + encoder_type (str): + Type of the encoder used by the model. Look at `TTS.tts.layers.feed_forward.encoder` for more details. + Defaults to `fftransformer`. + encoder_params (dict): + Parameters used to define the encoder network. Look at `TTS.tts.layers.feed_forward.encoder` for more details. + Defaults to `{"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}`. + decoder_type (str): + Type of the decoder used by the model. Look at `TTS.tts.layers.feed_forward.decoder` for more details. + Defaults to `fftransformer`. + decoder_params (dict): + Parameters used to define the decoder network. Look at `TTS.tts.layers.feed_forward.decoder` for more details. + Defaults to `{"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}`. + phase_start_steps (List[int]): + A list of number of steps required to start the next training phase. AlignTTS has 4 different training + phases. Thus you need to define 4 different values to enable phase based training. If None, it + trains the whole model together. Defaults to None. + ssim_alpha (float): + Weight for the SSIM loss. If set <= 0, disables the SSIM loss. Defaults to 1.0. + duration_loss_alpha (float): + Weight for the duration predictor's loss. Defaults to 1.0. + mdn_alpha (float): + Weight for the MDN loss. Defaults to 1.0. + spec_loss_alpha (float): + Weight for the MSE spectrogram loss. If set <= 0, disables the L1 loss. Defaults to 1.0. + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + noam_schedule (bool): + enable / disable the use of Noam LR scheduler. Defaults to False. + warmup_steps (int): + Number of warm-up steps for the Noam scheduler. Defaults 4000. + lr (float): + Initial learning rate. Defaults to `1e-3`. + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + min_seq_len (int): + Minimum input sequence length to be used at training. + max_seq_len (int): + Maximum input sequence length to be used at training. 
Larger values result in more VRAM usage.""" + + model: str = "align_tts" + # model specific params + model_args: AlignTTSArgs = field(default_factory=AlignTTSArgs) + phase_start_steps: List[int] = None + + ssim_alpha: float = 1.0 + spec_loss_alpha: float = 1.0 + dur_loss_alpha: float = 1.0 + mdn_alpha: float = 1.0 + + # multi-speaker settings + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_file: str = False + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = None + lr_scheduler_params: dict = None + lr: float = 1e-4 + grad_clip: float = 5.0 + + # overrides + min_seq_len: int = 13 + max_seq_len: int = 200 + r: int = 1 + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) diff --git a/viXTTS/TTS/tts/configs/bark_config.py b/viXTTS/TTS/tts/configs/bark_config.py new file mode 100644 index 0000000000000000000000000000000000000000..4d1cd1374afe8d5f0b9e87ed81db25d7e4032af9 --- /dev/null +++ b/viXTTS/TTS/tts/configs/bark_config.py @@ -0,0 +1,105 @@ +import os +from dataclasses import dataclass, field +from typing import Dict + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.layers.bark.model import GPTConfig +from TTS.tts.layers.bark.model_fine import FineGPTConfig +from TTS.tts.models.bark import BarkAudioConfig +from TTS.utils.generic_utils import get_user_data_dir + + +@dataclass +class BarkConfig(BaseTTSConfig): + """Bark TTS configuration + + Args: + model (str): model name that registers the model. + audio (BarkAudioConfig): audio configuration. Defaults to BarkAudioConfig(). + num_chars (int): number of characters in the alphabet. Defaults to 0. + semantic_config (GPTConfig): semantic configuration. Defaults to GPTConfig(). + fine_config (FineGPTConfig): fine configuration. Defaults to FineGPTConfig(). + coarse_config (GPTConfig): coarse configuration. Defaults to GPTConfig(). + CONTEXT_WINDOW_SIZE (int): GPT context window size. Defaults to 1024. + SEMANTIC_RATE_HZ (float): semantic tokens rate in Hz. Defaults to 49.9. + SEMANTIC_VOCAB_SIZE (int): semantic vocabulary size. Defaults to 10_000. + CODEBOOK_SIZE (int): encodec codebook size. Defaults to 1024. + N_COARSE_CODEBOOKS (int): number of coarse codebooks. Defaults to 2. + N_FINE_CODEBOOKS (int): number of fine codebooks. Defaults to 8. + COARSE_RATE_HZ (int): coarse tokens rate in Hz. Defaults to 75. + SAMPLE_RATE (int): sample rate. Defaults to 24_000. + USE_SMALLER_MODELS (bool): use smaller models. Defaults to False. + TEXT_ENCODING_OFFSET (int): text encoding offset. Defaults to 10_048. + SEMANTIC_PAD_TOKEN (int): semantic pad token. Defaults to 10_000. + TEXT_PAD_TOKEN ([type]): text pad token. Defaults to 10_048. + TEXT_EOS_TOKEN ([type]): text end of sentence token. Defaults to 10_049. + TEXT_SOS_TOKEN ([type]): text start of sentence token. Defaults to 10_050. + SEMANTIC_INFER_TOKEN (int): semantic infer token. Defaults to 10_051. + COARSE_SEMANTIC_PAD_TOKEN (int): coarse semantic pad token. Defaults to 12_048. + COARSE_INFER_TOKEN (int): coarse infer token. Defaults to 12_050. + REMOTE_BASE_URL ([type]): remote base url. 
Defaults to "https://huggingface.co/erogol/bark/tree". + REMOTE_MODEL_PATHS (Dict): remote model paths. Defaults to None. + LOCAL_MODEL_PATHS (Dict): local model paths. Defaults to None. + SMALL_REMOTE_MODEL_PATHS (Dict): small remote model paths. Defaults to None. + CACHE_DIR (str): local cache directory. Defaults to get_user_data_dir(). + DEF_SPEAKER_DIR (str): default speaker directory to stoke speaker values for voice cloning. Defaults to get_user_data_dir(). + """ + + model: str = "bark" + audio: BarkAudioConfig = field(default_factory=BarkAudioConfig) + num_chars: int = 0 + semantic_config: GPTConfig = field(default_factory=GPTConfig) + fine_config: FineGPTConfig = field(default_factory=FineGPTConfig) + coarse_config: GPTConfig = field(default_factory=GPTConfig) + CONTEXT_WINDOW_SIZE: int = 1024 + SEMANTIC_RATE_HZ: float = 49.9 + SEMANTIC_VOCAB_SIZE: int = 10_000 + CODEBOOK_SIZE: int = 1024 + N_COARSE_CODEBOOKS: int = 2 + N_FINE_CODEBOOKS: int = 8 + COARSE_RATE_HZ: int = 75 + SAMPLE_RATE: int = 24_000 + USE_SMALLER_MODELS: bool = False + + TEXT_ENCODING_OFFSET: int = 10_048 + SEMANTIC_PAD_TOKEN: int = 10_000 + TEXT_PAD_TOKEN: int = 129_595 + SEMANTIC_INFER_TOKEN: int = 129_599 + COARSE_SEMANTIC_PAD_TOKEN: int = 12_048 + COARSE_INFER_TOKEN: int = 12_050 + + REMOTE_BASE_URL = "https://huggingface.co/erogol/bark/tree/main/" + REMOTE_MODEL_PATHS: Dict = None + LOCAL_MODEL_PATHS: Dict = None + SMALL_REMOTE_MODEL_PATHS: Dict = None + CACHE_DIR: str = str(get_user_data_dir("tts/suno/bark_v0")) + DEF_SPEAKER_DIR: str = str(get_user_data_dir("tts/bark_v0/speakers")) + + def __post_init__(self): + self.REMOTE_MODEL_PATHS = { + "text": { + "path": os.path.join(self.REMOTE_BASE_URL, "text_2.pt"), + "checksum": "54afa89d65e318d4f5f80e8e8799026a", + }, + "coarse": { + "path": os.path.join(self.REMOTE_BASE_URL, "coarse_2.pt"), + "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28", + }, + "fine": { + "path": os.path.join(self.REMOTE_BASE_URL, "fine_2.pt"), + "checksum": "59d184ed44e3650774a2f0503a48a97b", + }, + } + self.LOCAL_MODEL_PATHS = { + "text": os.path.join(self.CACHE_DIR, "text_2.pt"), + "coarse": os.path.join(self.CACHE_DIR, "coarse_2.pt"), + "fine": os.path.join(self.CACHE_DIR, "fine_2.pt"), + "hubert_tokenizer": os.path.join(self.CACHE_DIR, "tokenizer.pth"), + "hubert": os.path.join(self.CACHE_DIR, "hubert.pt"), + } + self.SMALL_REMOTE_MODEL_PATHS = { + "text": {"path": os.path.join(self.REMOTE_BASE_URL, "text.pt")}, + "coarse": {"path": os.path.join(self.REMOTE_BASE_URL, "coarse.pt")}, + "fine": {"path": os.path.join(self.REMOTE_BASE_URL, "fine.pt")}, + } + self.sample_rate = self.SAMPLE_RATE # pylint: disable=attribute-defined-outside-init diff --git a/viXTTS/TTS/tts/configs/delightful_tts_config.py b/viXTTS/TTS/tts/configs/delightful_tts_config.py new file mode 100644 index 0000000000000000000000000000000000000000..805d995369e29fce7d6aa87750356b21458cd64a --- /dev/null +++ b/viXTTS/TTS/tts/configs/delightful_tts_config.py @@ -0,0 +1,170 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.delightful_tts import DelightfulTtsArgs, DelightfulTtsAudioConfig, VocoderConfig + + +@dataclass +class DelightfulTTSConfig(BaseTTSConfig): + """ + Configuration class for the DelightfulTTS model. + + Attributes: + model (str): Name of the model ("delightful_tts"). + audio (DelightfulTtsAudioConfig): Configuration for audio settings. + model_args (DelightfulTtsArgs): Configuration for model arguments. 
+ use_attn_priors (bool): Whether to use attention priors. + vocoder (VocoderConfig): Configuration for the vocoder. + init_discriminator (bool): Whether to initialize the discriminator. + steps_to_start_discriminator (int): Number of steps to start the discriminator. + grad_clip (List[float]): Gradient clipping values. + lr_gen (float): Learning rate for the gan generator. + lr_disc (float): Learning rate for the gan discriminator. + lr_scheduler_gen (str): Name of the learning rate scheduler for the generator. + lr_scheduler_gen_params (dict): Parameters for the learning rate scheduler for the generator. + lr_scheduler_disc (str): Name of the learning rate scheduler for the discriminator. + lr_scheduler_disc_params (dict): Parameters for the learning rate scheduler for the discriminator. + scheduler_after_epoch (bool): Whether to schedule after each epoch. + optimizer (str): Name of the optimizer. + optimizer_params (dict): Parameters for the optimizer. + ssim_loss_alpha (float): Alpha value for the SSIM loss. + mel_loss_alpha (float): Alpha value for the mel loss. + aligner_loss_alpha (float): Alpha value for the aligner loss. + pitch_loss_alpha (float): Alpha value for the pitch loss. + energy_loss_alpha (float): Alpha value for the energy loss. + u_prosody_loss_alpha (float): Alpha value for the utterance prosody loss. + p_prosody_loss_alpha (float): Alpha value for the phoneme prosody loss. + dur_loss_alpha (float): Alpha value for the duration loss. + char_dur_loss_alpha (float): Alpha value for the character duration loss. + binary_align_loss_alpha (float): Alpha value for the binary alignment loss. + binary_loss_warmup_epochs (int): Number of warm-up epochs for the binary loss. + disc_loss_alpha (float): Alpha value for the discriminator loss. + gen_loss_alpha (float): Alpha value for the generator loss. + feat_loss_alpha (float): Alpha value for the feature loss. + vocoder_mel_loss_alpha (float): Alpha value for the vocoder mel loss. + multi_scale_stft_loss_alpha (float): Alpha value for the multi-scale STFT loss. + multi_scale_stft_loss_params (dict): Parameters for the multi-scale STFT loss. + return_wav (bool): Whether to return audio waveforms. + use_weighted_sampler (bool): Whether to use a weighted sampler. + weighted_sampler_attrs (dict): Attributes for the weighted sampler. + weighted_sampler_multipliers (dict): Multipliers for the weighted sampler. + r (int): Value for the `r` override. + compute_f0 (bool): Whether to compute F0 values. + f0_cache_path (str): Path to the F0 cache. + attn_prior_cache_path (str): Path to the attention prior cache. + num_speakers (int): Number of speakers. + use_speaker_embedding (bool): Whether to use speaker embedding. + speakers_file (str): Path to the speaker file. + speaker_embedding_channels (int): Number of channels for the speaker embedding. + language_ids_file (str): Path to the language IDs file. 
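+
+ Example (a minimal usage sketch; the field values shown are illustrative only, not recommended settings):
+
+ >>> from TTS.tts.configs.delightful_tts_config import DelightfulTTSConfig
+ >>> config = DelightfulTTSConfig()
+ >>> config.model
+ 'delightful_tts'
+ >>> # hypothetical multi-speaker setup; `__post_init__` mirrors these values onto `model_args`
+ >>> config = DelightfulTTSConfig(num_speakers=4, use_speaker_embedding=True)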
+ """ + + model: str = "delightful_tts" + + # model specific params + audio: DelightfulTtsAudioConfig = field(default_factory=DelightfulTtsAudioConfig) + model_args: DelightfulTtsArgs = field(default_factory=DelightfulTtsArgs) + use_attn_priors: bool = True + + # vocoder + vocoder: VocoderConfig = field(default_factory=VocoderConfig) + init_discriminator: bool = True + + # optimizer + steps_to_start_discriminator: int = 200000 + grad_clip: List[float] = field(default_factory=lambda: [1000, 1000]) + lr_gen: float = 0.0002 + lr_disc: float = 0.0002 + lr_scheduler_gen: str = "ExponentialLR" + lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1}) + lr_scheduler_disc: str = "ExponentialLR" + lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1}) + scheduler_after_epoch: bool = True + optimizer: str = "AdamW" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01}) + + # acoustic model loss params + ssim_loss_alpha: float = 1.0 + mel_loss_alpha: float = 1.0 + aligner_loss_alpha: float = 1.0 + pitch_loss_alpha: float = 1.0 + energy_loss_alpha: float = 1.0 + u_prosody_loss_alpha: float = 0.5 + p_prosody_loss_alpha: float = 0.5 + dur_loss_alpha: float = 1.0 + char_dur_loss_alpha: float = 0.01 + binary_align_loss_alpha: float = 0.1 + binary_loss_warmup_epochs: int = 10 + + # vocoder loss params + disc_loss_alpha: float = 1.0 + gen_loss_alpha: float = 1.0 + feat_loss_alpha: float = 1.0 + vocoder_mel_loss_alpha: float = 10.0 + multi_scale_stft_loss_alpha: float = 2.5 + multi_scale_stft_loss_params: dict = field( + default_factory=lambda: { + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240], + } + ) + + # data loader params + return_wav: bool = True + use_weighted_sampler: bool = False + weighted_sampler_attrs: dict = field(default_factory=lambda: {}) + weighted_sampler_multipliers: dict = field(default_factory=lambda: {}) + + # overrides + r: int = 1 + + # dataset configs + compute_f0: bool = True + f0_cache_path: str = None + attn_prior_cache_path: str = None + + # multi-speaker settings + # use speaker embedding layer + num_speakers: int = 0 + use_speaker_embedding: bool = False + speakers_file: str = None + speaker_embedding_channels: int = 256 + language_ids_file: str = None + use_language_embedding: bool = False + + # use d-vectors + use_d_vector_file: bool = False + d_vector_file: str = None + d_vector_dim: int = None + + # testing + test_sentences: List[List[str]] = field( + default_factory=lambda: [ + ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."], + ["Be a voice, not an echo."], + ["I'm sorry Dave. I'm afraid I can't do that."], + ["This cake is great. It's so delicious and moist."], + ["Prior to November 22, 1963."], + ] + ) + + def __post_init__(self): + # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there. 
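+ # Each value below is copied into `model_args` only when it is explicitly set (non-zero / truthy), so the defaults defined on `DelightfulTtsArgs` are kept otherwise.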
+ if self.num_speakers > 0: + self.model_args.num_speakers = self.num_speakers + + # speaker embedding settings + if self.use_speaker_embedding: + self.model_args.use_speaker_embedding = True + if self.speakers_file: + self.model_args.speakers_file = self.speakers_file + + # d-vector settings + if self.use_d_vector_file: + self.model_args.use_d_vector_file = True + if self.d_vector_dim is not None and self.d_vector_dim > 0: + self.model_args.d_vector_dim = self.d_vector_dim + if self.d_vector_file: + self.model_args.d_vector_file = self.d_vector_file diff --git a/viXTTS/TTS/tts/configs/fast_pitch_config.py b/viXTTS/TTS/tts/configs/fast_pitch_config.py new file mode 100644 index 0000000000000000000000000000000000000000..d086d26564450c60fa04a7f3a068506f4147d3be --- /dev/null +++ b/viXTTS/TTS/tts/configs/fast_pitch_config.py @@ -0,0 +1,183 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.forward_tts import ForwardTTSArgs + + +@dataclass +class FastPitchConfig(BaseTTSConfig): + """Configure `ForwardTTS` as FastPitch model. + + Example: + + >>> from TTS.tts.configs.fast_pitch_config import FastPitchConfig + >>> config = FastPitchConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `fast_pitch`. + + base_model (str): + Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate + the base model rather than searching for the `model` implementation. Defaults to `forward_tts`. + + model_args (Coqpit): + Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`. + + data_dep_init_steps (int): + Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses + Activation Normalization that pre-computes normalization stats at the beginning and use the same values + for the rest. Defaults to 10. + + speakers_file (str): + Path to the file containing the list of speakers. Needed at inference for loading matching speaker ids to + speaker names. Defaults to `None`. + + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + + d_vector_dim (int): + Dimension of the external speaker embeddings. Defaults to 0. + + optimizer (str): + Name of the model optimizer. Defaults to `Adam`. + + optimizer_params (dict): + Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`. + + lr_scheduler (str): + Name of the learning rate scheduler. Defaults to `Noam`. + + lr_scheduler_params (dict): + Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`. + + lr (float): + Initial learning rate. Defaults to `1e-3`. + + grad_clip (float): + Gradient norm clipping value. Defaults to `5.0`. + + spec_loss_type (str): + Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + duration_loss_type (str): + Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + use_ssim_loss (bool): + Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True. 
+ + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + + ssim_loss_alpha (float): + Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0. + + dur_loss_alpha (float): + Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0. + + spec_loss_alpha (float): + Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0. + + pitch_loss_alpha (float): + Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0. + + binary_align_loss_alpha (float): + Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. + + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. + + min_seq_len (int): + Minimum input sequence length to be used at training. + + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + + # dataset configs + compute_f0(bool): + Compute pitch. defaults to True + + f0_cache_path(str): + pith cache path. defaults to None + """ + + model: str = "fast_pitch" + base_model: str = "forward_tts" + + # model specific params + model_args: ForwardTTSArgs = field(default_factory=ForwardTTSArgs) + + # multi-speaker settings + num_speakers: int = 0 + speakers_file: str = None + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_file: str = False + d_vector_dim: int = 0 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + lr: float = 1e-4 + grad_clip: float = 5.0 + + # loss params + spec_loss_type: str = "mse" + duration_loss_type: str = "mse" + use_ssim_loss: bool = True + ssim_loss_alpha: float = 1.0 + spec_loss_alpha: float = 1.0 + aligner_loss_alpha: float = 1.0 + pitch_loss_alpha: float = 0.1 + dur_loss_alpha: float = 0.1 + binary_align_loss_alpha: float = 0.1 + binary_loss_warmup_epochs: int = 150 + + # overrides + min_seq_len: int = 13 + max_seq_len: int = 200 + r: int = 1 # DO NOT CHANGE + + # dataset configs + compute_f0: bool = True + f0_cache_path: str = None + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) + + def __post_init__(self): + # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there. 
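+ # e.g. a hypothetical `FastPitchConfig(num_speakers=4, use_speaker_embedding=True)` ends up with both values mirrored onto `model_args` by the checks below.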
+ if self.num_speakers > 0: + self.model_args.num_speakers = self.num_speakers + + # speaker embedding settings + if self.use_speaker_embedding: + self.model_args.use_speaker_embedding = True + if self.speakers_file: + self.model_args.speakers_file = self.speakers_file + + # d-vector settings + if self.use_d_vector_file: + self.model_args.use_d_vector_file = True + if self.d_vector_dim is not None and self.d_vector_dim > 0: + self.model_args.d_vector_dim = self.d_vector_dim + if self.d_vector_file: + self.model_args.d_vector_file = self.d_vector_file diff --git a/viXTTS/TTS/tts/configs/fast_speech_config.py b/viXTTS/TTS/tts/configs/fast_speech_config.py new file mode 100644 index 0000000000000000000000000000000000000000..af6c2db6faf55ee2b15047fff86281d42dab1b87 --- /dev/null +++ b/viXTTS/TTS/tts/configs/fast_speech_config.py @@ -0,0 +1,177 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.forward_tts import ForwardTTSArgs + + +@dataclass +class FastSpeechConfig(BaseTTSConfig): + """Configure `ForwardTTS` as FastSpeech model. + + Example: + + >>> from TTS.tts.configs.fast_speech_config import FastSpeechConfig + >>> config = FastSpeechConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `fast_pitch`. + + base_model (str): + Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate + the base model rather than searching for the `model` implementation. Defaults to `forward_tts`. + + model_args (Coqpit): + Model class arguments. Check `FastSpeechArgs` for more details. Defaults to `FastSpeechArgs()`. + + data_dep_init_steps (int): + Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses + Activation Normalization that pre-computes normalization stats at the beginning and use the same values + for the rest. Defaults to 10. + + speakers_file (str): + Path to the file containing the list of speakers. Needed at inference for loading matching speaker ids to + speaker names. Defaults to `None`. + + + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + + d_vector_dim (int): + Dimension of the external speaker embeddings. Defaults to 0. + + optimizer (str): + Name of the model optimizer. Defaults to `Adam`. + + optimizer_params (dict): + Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`. + + lr_scheduler (str): + Name of the learning rate scheduler. Defaults to `Noam`. + + lr_scheduler_params (dict): + Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`. + + lr (float): + Initial learning rate. Defaults to `1e-3`. + + grad_clip (float): + Gradient norm clipping value. Defaults to `5.0`. + + spec_loss_type (str): + Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + duration_loss_type (str): + Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + use_ssim_loss (bool): + Enable/disable the use of SSIM (Structural Similarity) loss. 
Defaults to True. + + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + + ssim_loss_alpha (float): + Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0. + + dur_loss_alpha (float): + Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0. + + spec_loss_alpha (float): + Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0. + + pitch_loss_alpha (float): + Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0. + + binary_loss_alpha (float): + Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. + + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. + + min_seq_len (int): + Minimum input sequence length to be used at training. + + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + """ + + model: str = "fast_speech" + base_model: str = "forward_tts" + + # model specific params + model_args: ForwardTTSArgs = field(default_factory=lambda: ForwardTTSArgs(use_pitch=False)) + + # multi-speaker settings + num_speakers: int = 0 + speakers_file: str = None + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_file: str = False + d_vector_dim: int = 0 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + lr: float = 1e-4 + grad_clip: float = 5.0 + + # loss params + spec_loss_type: str = "mse" + duration_loss_type: str = "mse" + use_ssim_loss: bool = True + ssim_loss_alpha: float = 1.0 + dur_loss_alpha: float = 1.0 + spec_loss_alpha: float = 1.0 + pitch_loss_alpha: float = 0.0 + aligner_loss_alpha: float = 1.0 + binary_align_loss_alpha: float = 1.0 + binary_loss_warmup_epochs: int = 150 + + # overrides + min_seq_len: int = 13 + max_seq_len: int = 200 + r: int = 1 # DO NOT CHANGE + + # dataset configs + compute_f0: bool = False + f0_cache_path: str = None + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) + + def __post_init__(self): + # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there. 
+ if self.num_speakers > 0: + self.model_args.num_speakers = self.num_speakers + + # speaker embedding settings + if self.use_speaker_embedding: + self.model_args.use_speaker_embedding = True + if self.speakers_file: + self.model_args.speakers_file = self.speakers_file + + # d-vector settings + if self.use_d_vector_file: + self.model_args.use_d_vector_file = True + if self.d_vector_dim is not None and self.d_vector_dim > 0: + self.model_args.d_vector_dim = self.d_vector_dim + if self.d_vector_file: + self.model_args.d_vector_file = self.d_vector_file diff --git a/viXTTS/TTS/tts/configs/fastspeech2_config.py b/viXTTS/TTS/tts/configs/fastspeech2_config.py new file mode 100644 index 0000000000000000000000000000000000000000..d179617fb034fff269355ce7e3d78b67db90aacd --- /dev/null +++ b/viXTTS/TTS/tts/configs/fastspeech2_config.py @@ -0,0 +1,198 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.forward_tts import ForwardTTSArgs + + +@dataclass +class Fastspeech2Config(BaseTTSConfig): + """Configure `ForwardTTS` as FastPitch model. + + Example: + + >>> from TTS.tts.configs.fastspeech2_config import FastSpeech2Config + >>> config = FastSpeech2Config() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `fast_pitch`. + + base_model (str): + Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate + the base model rather than searching for the `model` implementation. Defaults to `forward_tts`. + + model_args (Coqpit): + Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`. + + data_dep_init_steps (int): + Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses + Activation Normalization that pre-computes normalization stats at the beginning and use the same values + for the rest. Defaults to 10. + + speakers_file (str): + Path to the file containing the list of speakers. Needed at inference for loading matching speaker ids to + speaker names. Defaults to `None`. + + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + + d_vector_dim (int): + Dimension of the external speaker embeddings. Defaults to 0. + + optimizer (str): + Name of the model optimizer. Defaults to `Adam`. + + optimizer_params (dict): + Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`. + + lr_scheduler (str): + Name of the learning rate scheduler. Defaults to `Noam`. + + lr_scheduler_params (dict): + Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`. + + lr (float): + Initial learning rate. Defaults to `1e-3`. + + grad_clip (float): + Gradient norm clipping value. Defaults to `5.0`. + + spec_loss_type (str): + Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + duration_loss_type (str): + Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + use_ssim_loss (bool): + Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True. 
+ + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + + ssim_loss_alpha (float): + Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0. + + dur_loss_alpha (float): + Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0. + + spec_loss_alpha (float): + Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0. + + pitch_loss_alpha (float): + Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0. + + energy_loss_alpha (float): + Weight for the energy predictor's loss. If set 0, disables the energy predictor. Defaults to 1.0. + + binary_align_loss_alpha (float): + Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. + + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. + + min_seq_len (int): + Minimum input sequence length to be used at training. + + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + + # dataset configs + compute_f0(bool): + Compute pitch. defaults to True + + f0_cache_path(str): + pith cache path. defaults to None + + # dataset configs + compute_energy(bool): + Compute energy. defaults to True + + energy_cache_path(str): + energy cache path. defaults to None + """ + + model: str = "fastspeech2" + base_model: str = "forward_tts" + + # model specific params + model_args: ForwardTTSArgs = field(default_factory=lambda: ForwardTTSArgs(use_pitch=True, use_energy=True)) + + # multi-speaker settings + num_speakers: int = 0 + speakers_file: str = None + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_file: str = False + d_vector_dim: int = 0 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + lr: float = 1e-4 + grad_clip: float = 5.0 + + # loss params + spec_loss_type: str = "mse" + duration_loss_type: str = "mse" + use_ssim_loss: bool = True + ssim_loss_alpha: float = 1.0 + spec_loss_alpha: float = 1.0 + aligner_loss_alpha: float = 1.0 + pitch_loss_alpha: float = 0.1 + energy_loss_alpha: float = 0.1 + dur_loss_alpha: float = 0.1 + binary_align_loss_alpha: float = 0.1 + binary_loss_warmup_epochs: int = 150 + + # overrides + min_seq_len: int = 13 + max_seq_len: int = 200 + r: int = 1 # DO NOT CHANGE + + # dataset configs + compute_f0: bool = True + f0_cache_path: str = None + + # dataset configs + compute_energy: bool = True + energy_cache_path: str = None + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) + + def __post_init__(self): + # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there. 
+ if self.num_speakers > 0: + self.model_args.num_speakers = self.num_speakers + + # speaker embedding settings + if self.use_speaker_embedding: + self.model_args.use_speaker_embedding = True + if self.speakers_file: + self.model_args.speakers_file = self.speakers_file + + # d-vector settings + if self.use_d_vector_file: + self.model_args.use_d_vector_file = True + if self.d_vector_dim is not None and self.d_vector_dim > 0: + self.model_args.d_vector_dim = self.d_vector_dim + if self.d_vector_file: + self.model_args.d_vector_file = self.d_vector_file diff --git a/viXTTS/TTS/tts/configs/glow_tts_config.py b/viXTTS/TTS/tts/configs/glow_tts_config.py new file mode 100644 index 0000000000000000000000000000000000000000..f42f3e5a510bacf1b2312ccea7d46201bbcb774f --- /dev/null +++ b/viXTTS/TTS/tts/configs/glow_tts_config.py @@ -0,0 +1,182 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig + + +@dataclass +class GlowTTSConfig(BaseTTSConfig): + """Defines parameters for GlowTTS model. + + Example: + + >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig + >>> config = GlowTTSConfig() + + Args: + model(str): + Model name used for selecting the right model at initialization. Defaults to `glow_tts`. + encoder_type (str): + Type of the encoder used by the model. Look at `TTS.tts.layers.glow_tts.encoder` for more details. + Defaults to `rel_pos_transformers`. + encoder_params (dict): + Parameters used to define the encoder network. Look at `TTS.tts.layers.glow_tts.encoder` for more details. + Defaults to `{"kernel_size": 3, "dropout_p": 0.1, "num_layers": 6, "num_heads": 2, "hidden_channels_ffn": 768}` + use_encoder_prenet (bool): + enable / disable the use of a prenet for the encoder. Defaults to True. + hidden_channels_enc (int): + Number of base hidden channels used by the encoder network. It defines the input and the output channel sizes, + and for some encoder types internal hidden channels sizes too. Defaults to 192. + hidden_channels_dec (int): + Number of base hidden channels used by the decoder WaveNet network. Defaults to 192 as in the original work. + hidden_channels_dp (int): + Number of layer channels of the duration predictor network. Defaults to 256 as in the original work. + mean_only (bool): + If true predict only the mean values by the decoder flow. Defaults to True. + out_channels (int): + Number of channels of the model output tensor. Defaults to 80. + num_flow_blocks_dec (int): + Number of decoder blocks. Defaults to 12. + inference_noise_scale (float): + Noise scale used at inference. Defaults to 0.33. + kernel_size_dec (int): + Decoder kernel size. Defaults to 5 + dilation_rate (int): + Rate to increase dilation by each layer in a decoder block. Defaults to 1. + num_block_layers (int): + Number of decoder layers in each decoder block. Defaults to 4. + dropout_p_dec (float): + Dropout rate for decoder. Defaults to 0.1. + num_speaker (int): + Number of speaker to define the size of speaker embedding layer. Defaults to 0. + c_in_channels (int): + Number of speaker embedding channels. It is set to 512 if embeddings are learned. Defaults to 0. + num_splits (int): + Number of split levels in inversible conv1x1 operation. Defaults to 4. + num_squeeze (int): + Number of squeeze levels. When squeezing channels increases and time steps reduces by the factor + 'num_squeeze'. Defaults to 2. + sigmoid_scale (bool): + enable/disable sigmoid scaling in decoder. Defaults to False. 
+ mean_only (bool): + If True, the encoder only computes the mean value and uses a constant variance for each time step. Defaults to True. + encoder_type (str): + Encoder module type. Possible values are `["rel_pos_transformer", "gated_conv", "residual_conv_bn", "time_depth_separable"]`. + Check `TTS.tts.layers.glow_tts.encoder` for more details. Defaults to `rel_pos_transformer` as in the original paper. + encoder_params (dict): + Encoder module parameters. Defaults to None. + d_vector_dim (int): + Channels of external speaker embedding vectors. Defaults to 0. + data_dep_init_steps (int): + Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses + Activation Normalization that pre-computes normalization stats at the beginning and uses the same values + for the rest. Defaults to 10. + style_wav_for_test (str): + Path to the wav file used for changing the style of the speech. Defaults to None. + inference_noise_scale (float): + Variance used for sampling the random noise added to the decoder's input at inference. Defaults to 0.0. + length_scale (float): + Multiply the predicted durations with this value to change the speech speed. Defaults to 1. + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + use_d_vector_file (bool): + enable / disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + noam_schedule (bool): + enable / disable the use of Noam LR scheduler. Defaults to False. + warmup_steps (int): + Number of warm-up steps for the Noam scheduler. Defaults to 4000. + lr (float): + Initial learning rate. Defaults to `1e-3`. + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + min_seq_len (int): + Minimum input sequence length to be used at training. + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
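+
+ Example of overriding individual fields (a sketch only; the values below are illustrative, not tuned recommendations):
+
+ >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig
+ >>> config = GlowTTSConfig(inference_noise_scale=0.33, length_scale=1.1)
+ >>> config.encoder_type
+ 'rel_pos_transformer'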
+ """ + + model: str = "glow_tts" + + # model params + num_chars: int = None + encoder_type: str = "rel_pos_transformer" + encoder_params: dict = field( + default_factory=lambda: { + "kernel_size": 3, + "dropout_p": 0.1, + "num_layers": 6, + "num_heads": 2, + "hidden_channels_ffn": 768, + } + ) + use_encoder_prenet: bool = True + hidden_channels_enc: int = 192 + hidden_channels_dec: int = 192 + hidden_channels_dp: int = 256 + dropout_p_dp: float = 0.1 + dropout_p_dec: float = 0.05 + mean_only: bool = True + out_channels: int = 80 + num_flow_blocks_dec: int = 12 + inference_noise_scale: float = 0.33 + kernel_size_dec: int = 5 + dilation_rate: int = 1 + num_block_layers: int = 4 + num_speakers: int = 0 + c_in_channels: int = 0 + num_splits: int = 4 + num_squeeze: int = 2 + sigmoid_scale: bool = False + encoder_type: str = "rel_pos_transformer" + encoder_params: dict = field( + default_factory=lambda: { + "kernel_size": 3, + "dropout_p": 0.1, + "num_layers": 6, + "num_heads": 2, + "hidden_channels_ffn": 768, + "input_length": None, + } + ) + d_vector_dim: int = 0 + + # training params + data_dep_init_steps: int = 10 + + # inference params + style_wav_for_test: str = None + inference_noise_scale: float = 0.0 + length_scale: float = 1.0 + + # multi-speaker settings + use_speaker_embedding: bool = False + speakers_file: str = None + use_d_vector_file: bool = False + d_vector_file: str = False + + # optimizer parameters + optimizer: str = "RAdam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + grad_clip: float = 5.0 + lr: float = 1e-3 + + # overrides + min_seq_len: int = 3 + max_seq_len: int = 500 + r: int = 1 # DO NOT CHANGE - TODO: make this immutable once coqpit implements it. + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) diff --git a/viXTTS/TTS/tts/configs/neuralhmm_tts_config.py b/viXTTS/TTS/tts/configs/neuralhmm_tts_config.py new file mode 100644 index 0000000000000000000000000000000000000000..50f72847ed3e1c7089915ef8fd77ae5775c5b260 --- /dev/null +++ b/viXTTS/TTS/tts/configs/neuralhmm_tts_config.py @@ -0,0 +1,170 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig + + +@dataclass +class NeuralhmmTTSConfig(BaseTTSConfig): + """ + Define parameters for Neural HMM TTS model. + + Example: + + >>> from TTS.tts.configs.overflow_config import OverflowConfig + >>> config = OverflowConfig() + + Args: + model (str): + Model name used to select the right model class to initilize. Defaults to `Overflow`. + run_eval_steps (int): + Run evalulation epoch after N steps. If None, waits until training epoch is completed. Defaults to None. + save_step (int): + Save local checkpoint every save_step steps. Defaults to 500. + plot_step (int): + Plot training stats on the logger every plot_step steps. Defaults to 1. + model_param_stats (bool): + Log model parameters stats on the logger dashboard. Defaults to False. + force_generate_statistics (bool): + Force generate mel normalization statistics. Defaults to False. 
+ mel_statistics_parameter_path (str): + Path to the mel normalization statistics.If the model doesn't finds a file there it will generate statistics. + Defaults to None. + num_chars (int): + Number of characters used by the model. It must be defined before initializing the model. Defaults to None. + state_per_phone (int): + Generates N states per phone. Similar, to `add_blank` parameter in GlowTTS but in Overflow it is upsampled by model's encoder. Defaults to 2. + encoder_in_out_features (int): + Channels of encoder input and character embedding tensors. Defaults to 512. + encoder_n_convolutions (int): + Number of convolution layers in the encoder. Defaults to 3. + out_channels (int): + Channels of the final model output. It must match the spectragram size. Defaults to 80. + ar_order (int): + Autoregressive order of the model. Defaults to 1. In ablations of Neural HMM it was found that more autoregression while giving more variation hurts naturalness of the synthesised audio. + sampling_temp (float): + Variation added to the sample from the latent space of neural HMM. Defaults to 0.334. + deterministic_transition (bool): + deterministic duration generation based on duration quantiles as defiend in "S. Ronanki, O. Watts, S. King, and G. E. Henter, “Medianbased generation of synthetic speech durations using a nonparametric approach,” in Proc. SLT, 2016.". Defaults to True. + duration_threshold (float): + Threshold for duration quantiles. Defaults to 0.55. Tune this to change the speaking rate of the synthesis, where lower values defines a slower speaking rate and higher values defines a faster speaking rate. + use_grad_checkpointing (bool): + Use gradient checkpointing to save memory. In a multi-GPU setting currently pytorch does not supports gradient checkpoint inside a loop so we will have to turn it off then.Adjust depending on whatever get more batch size either by using a single GPU or multi-GPU. Defaults to True. + max_sampling_time (int): + Maximum sampling time while synthesising latents from neural HMM. Defaults to 1000. + prenet_type (str): + `original` or `bn`. `original` sets the default Prenet and `bn` uses Batch Normalization version of the + Prenet. Defaults to `original`. + prenet_dim (int): + Dimension of the Prenet. Defaults to 256. + prenet_n_layers (int): + Number of layers in the Prenet. Defaults to 2. + prenet_dropout (float): + Dropout rate of the Prenet. Defaults to 0.5. + prenet_dropout_at_inference (bool): + Use dropout at inference time. Defaults to False. + memory_rnn_dim (int): + Dimension of the memory LSTM to process the prenet output. Defaults to 1024. + outputnet_size (list[int]): + Size of the output network inside the neural HMM. Defaults to [1024]. + flat_start_params (dict): + Parameters for the flat start initialization of the neural HMM. Defaults to `{"mean": 0.0, "std": 1.0, "transition_p": 0.14}`. + It will be recomputed when you pass the dataset. + std_floor (float): + Floor value for the standard deviation of the neural HMM. Prevents model cheating by putting point mass and getting infinite likelihood at any datapoint. Defaults to 0.01. + It is called `variance flooring` in standard HMM literature. + optimizer (str): + Optimizer to use for training. Defaults to `adam`. + optimizer_params (dict): + Parameters for the optimizer. Defaults to `{"weight_decay": 1e-6}`. + grad_clip (float): + Gradient clipping threshold. Defaults to 40_000. + lr (float): + Learning rate. Defaults to 1e-3. + lr_scheduler (str): + Learning rate scheduler for the training. 
Use one from `torch.optim.Scheduler` schedulers or + `TTS.utils.training`. Defaults to `None`. + min_seq_len (int): + Minimum input sequence length to be used at training. + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + """ + + model: str = "NeuralHMM_TTS" + + # Training and Checkpoint configs + run_eval_steps: int = 100 + save_step: int = 500 + plot_step: int = 1 + model_param_stats: bool = False + + # data parameters + force_generate_statistics: bool = False + mel_statistics_parameter_path: str = None + + # Encoder parameters + num_chars: int = None + state_per_phone: int = 2 + encoder_in_out_features: int = 512 + encoder_n_convolutions: int = 3 + + # HMM parameters + out_channels: int = 80 + ar_order: int = 1 + sampling_temp: float = 0 + deterministic_transition: bool = True + duration_threshold: float = 0.43 + use_grad_checkpointing: bool = True + max_sampling_time: int = 1000 + + ## Prenet parameters + prenet_type: str = "original" + prenet_dim: int = 256 + prenet_n_layers: int = 2 + prenet_dropout: float = 0.5 + prenet_dropout_at_inference: bool = True + memory_rnn_dim: int = 1024 + + ## Outputnet parameters + outputnet_size: List[int] = field(default_factory=lambda: [1024]) + flat_start_params: dict = field(default_factory=lambda: {"mean": 0.0, "std": 1.0, "transition_p": 0.14}) + std_floor: float = 0.001 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"weight_decay": 1e-6}) + grad_clip: float = 40000.0 + lr: float = 1e-3 + lr_scheduler: str = None + + # overrides + min_text_len: int = 10 + max_text_len: int = 500 + min_audio_len: int = 512 + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "Be a voice, not an echo.", + ] + ) + + # Extra needed config + r: int = 1 + use_d_vector_file: bool = False + use_speaker_embedding: bool = False + + def check_values(self): + """Validate the hyperparameters. + + Raises: + AssertionError: when the parameters network is not defined + AssertionError: transition probability is not between 0 and 1 + """ + assert self.ar_order > 0, "AR order must be greater than 0 it is an autoregressive model." + assert ( + len(self.outputnet_size) >= 1 + ), f"Parameter Network must have atleast one layer check the config file for parameter network. Provided: {self.parameternetwork}" + assert ( + 0 < self.flat_start_params["transition_p"] < 1 + ), f"Transition probability must be between 0 and 1. Provided: {self.flat_start_params['transition_p']}" diff --git a/viXTTS/TTS/tts/configs/overflow_config.py b/viXTTS/TTS/tts/configs/overflow_config.py new file mode 100644 index 0000000000000000000000000000000000000000..dc3e5548b8f62f76c88acca85d19e2cee8687ebd --- /dev/null +++ b/viXTTS/TTS/tts/configs/overflow_config.py @@ -0,0 +1,201 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig + + +@dataclass +class OverflowConfig(BaseTTSConfig): # The classname has to be camel case + """ + Define parameters for OverFlow model. + + Example: + + >>> from TTS.tts.configs.overflow_config import OverflowConfig + >>> config = OverflowConfig() + + Args: + model (str): + Model name used to select the right model class to initilize. Defaults to `Overflow`. + run_eval_steps (int): + Run evalulation epoch after N steps. If None, waits until training epoch is completed. Defaults to None. + save_step (int): + Save local checkpoint every save_step steps. 
Defaults to 500. + plot_step (int): + Plot training stats on the logger every plot_step steps. Defaults to 1. + model_param_stats (bool): + Log model parameters stats on the logger dashboard. Defaults to False. + force_generate_statistics (bool): + Force generate mel normalization statistics. Defaults to False. + mel_statistics_parameter_path (str): + Path to the mel normalization statistics.If the model doesn't finds a file there it will generate statistics. + Defaults to None. + num_chars (int): + Number of characters used by the model. It must be defined before initializing the model. Defaults to None. + state_per_phone (int): + Generates N states per phone. Similar, to `add_blank` parameter in GlowTTS but in Overflow it is upsampled by model's encoder. Defaults to 2. + encoder_in_out_features (int): + Channels of encoder input and character embedding tensors. Defaults to 512. + encoder_n_convolutions (int): + Number of convolution layers in the encoder. Defaults to 3. + out_channels (int): + Channels of the final model output. It must match the spectragram size. Defaults to 80. + ar_order (int): + Autoregressive order of the model. Defaults to 1. In ablations of Neural HMM it was found that more autoregression while giving more variation hurts naturalness of the synthesised audio. + sampling_temp (float): + Variation added to the sample from the latent space of neural HMM. Defaults to 0.334. + deterministic_transition (bool): + deterministic duration generation based on duration quantiles as defiend in "S. Ronanki, O. Watts, S. King, and G. E. Henter, “Medianbased generation of synthetic speech durations using a nonparametric approach,” in Proc. SLT, 2016.". Defaults to True. + duration_threshold (float): + Threshold for duration quantiles. Defaults to 0.55. Tune this to change the speaking rate of the synthesis, where lower values defines a slower speaking rate and higher values defines a faster speaking rate. + use_grad_checkpointing (bool): + Use gradient checkpointing to save memory. In a multi-GPU setting currently pytorch does not supports gradient checkpoint inside a loop so we will have to turn it off then.Adjust depending on whatever get more batch size either by using a single GPU or multi-GPU. Defaults to True. + max_sampling_time (int): + Maximum sampling time while synthesising latents from neural HMM. Defaults to 1000. + prenet_type (str): + `original` or `bn`. `original` sets the default Prenet and `bn` uses Batch Normalization version of the + Prenet. Defaults to `original`. + prenet_dim (int): + Dimension of the Prenet. Defaults to 256. + prenet_n_layers (int): + Number of layers in the Prenet. Defaults to 2. + prenet_dropout (float): + Dropout rate of the Prenet. Defaults to 0.5. + prenet_dropout_at_inference (bool): + Use dropout at inference time. Defaults to False. + memory_rnn_dim (int): + Dimension of the memory LSTM to process the prenet output. Defaults to 1024. + outputnet_size (list[int]): + Size of the output network inside the neural HMM. Defaults to [1024]. + flat_start_params (dict): + Parameters for the flat start initialization of the neural HMM. Defaults to `{"mean": 0.0, "std": 1.0, "transition_p": 0.14}`. + It will be recomputed when you pass the dataset. + std_floor (float): + Floor value for the standard deviation of the neural HMM. Prevents model cheating by putting point mass and getting infinite likelihood at any datapoint. Defaults to 0.01. + It is called `variance flooring` in standard HMM literature. 
+ hidden_channels_dec (int): + Number of base hidden channels used by the decoder WaveNet network. Defaults to 150. + kernel_size_dec (int): + Decoder kernel size. Defaults to 5 + dilation_rate (int): + Rate to increase dilation by each layer in a decoder block. Defaults to 1. + num_flow_blocks_dec (int): + Number of decoder layers in each decoder block. Defaults to 4. + dropout_p_dec (float): + Dropout rate of the decoder. Defaults to 0.05. + num_splits (int): + Number of split levels in inversible conv1x1 operation. Defaults to 4. + num_squeeze (int): + Number of squeeze levels. When squeezing channels increases and time steps reduces by the factor + 'num_squeeze'. Defaults to 2. + sigmoid_scale (bool): + enable/disable sigmoid scaling in decoder. Defaults to False. + c_in_channels (int): + Unused parameter from GlowTTS's decoder. Defaults to 0. + optimizer (str): + Optimizer to use for training. Defaults to `adam`. + optimizer_params (dict): + Parameters for the optimizer. Defaults to `{"weight_decay": 1e-6}`. + grad_clip (float): + Gradient clipping threshold. Defaults to 40_000. + lr (float): + Learning rate. Defaults to 1e-3. + lr_scheduler (str): + Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or + `TTS.utils.training`. Defaults to `None`. + min_seq_len (int): + Minimum input sequence length to be used at training. + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + """ + + model: str = "Overflow" + + # Training and Checkpoint configs + run_eval_steps: int = 100 + save_step: int = 500 + plot_step: int = 1 + model_param_stats: bool = False + + # data parameters + force_generate_statistics: bool = False + mel_statistics_parameter_path: str = None + + # Encoder parameters + num_chars: int = None + state_per_phone: int = 2 + encoder_in_out_features: int = 512 + encoder_n_convolutions: int = 3 + + # HMM parameters + out_channels: int = 80 + ar_order: int = 1 + sampling_temp: float = 0.334 + deterministic_transition: bool = True + duration_threshold: float = 0.55 + use_grad_checkpointing: bool = True + max_sampling_time: int = 1000 + + ## Prenet parameters + prenet_type: str = "original" + prenet_dim: int = 256 + prenet_n_layers: int = 2 + prenet_dropout: float = 0.5 + prenet_dropout_at_inference: bool = False + memory_rnn_dim: int = 1024 + + ## Outputnet parameters + outputnet_size: List[int] = field(default_factory=lambda: [1024]) + flat_start_params: dict = field(default_factory=lambda: {"mean": 0.0, "std": 1.0, "transition_p": 0.14}) + std_floor: float = 0.01 + + # Decoder parameters + hidden_channels_dec: int = 150 + kernel_size_dec: int = 5 + dilation_rate: int = 1 + num_flow_blocks_dec: int = 12 + num_block_layers: int = 4 + dropout_p_dec: float = 0.05 + num_splits: int = 4 + num_squeeze: int = 2 + sigmoid_scale: bool = False + c_in_channels: int = 0 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"weight_decay": 1e-6}) + grad_clip: float = 40000.0 + lr: float = 1e-3 + lr_scheduler: str = None + + # overrides + min_text_len: int = 10 + max_text_len: int = 500 + min_audio_len: int = 512 + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "Be a voice, not an echo.", + ] + ) + + # Extra needed config + r: int = 1 + use_d_vector_file: bool = False + use_speaker_embedding: bool = False + + def check_values(self): + """Validate the hyperparameters. 
+ + Raises: + AssertionError: when the parameters network is not defined + AssertionError: transition probability is not between 0 and 1 + """ + assert self.ar_order > 0, "AR order must be greater than 0 it is an autoregressive model." + assert ( + len(self.outputnet_size) >= 1 + ), f"Parameter Network must have atleast one layer check the config file for parameter network. Provided: {self.parameternetwork}" + assert ( + 0 < self.flat_start_params["transition_p"] < 1 + ), f"Transition probability must be between 0 and 1. Provided: {self.flat_start_params['transition_p']}" diff --git a/viXTTS/TTS/tts/configs/shared_configs.py b/viXTTS/TTS/tts/configs/shared_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..bf17322c190bb234d4e27c6196e53b276fb5f09d --- /dev/null +++ b/viXTTS/TTS/tts/configs/shared_configs.py @@ -0,0 +1,344 @@ +from dataclasses import asdict, dataclass, field +from typing import Dict, List + +from coqpit import Coqpit, check_argument + +from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig + + +@dataclass +class GSTConfig(Coqpit): + """Defines the Global Style Token Module + + Args: + gst_style_input_wav (str): + Path to the wav file used to define the style of the output speech at inference. Defaults to None. + + gst_style_input_weights (dict): + Defines the weights for each style token used at inference. Defaults to None. + + gst_embedding_dim (int): + Defines the size of the GST embedding vector dimensions. Defaults to 256. + + gst_num_heads (int): + Number of attention heads used by the multi-head attention. Defaults to 4. + + gst_num_style_tokens (int): + Number of style token vectors. Defaults to 10. + """ + + gst_style_input_wav: str = None + gst_style_input_weights: dict = None + gst_embedding_dim: int = 256 + gst_use_speaker_embedding: bool = False + gst_num_heads: int = 4 + gst_num_style_tokens: int = 10 + + def check_values( + self, + ): + """Check config fields""" + c = asdict(self) + super().check_values() + check_argument("gst_style_input_weights", c, restricted=False) + check_argument("gst_style_input_wav", c, restricted=False) + check_argument("gst_embedding_dim", c, restricted=True, min_val=0, max_val=1000) + check_argument("gst_use_speaker_embedding", c, restricted=False) + check_argument("gst_num_heads", c, restricted=True, min_val=2, max_val=10) + check_argument("gst_num_style_tokens", c, restricted=True, min_val=1, max_val=1000) + + +@dataclass +class CapacitronVAEConfig(Coqpit): + """Defines the capacitron VAE Module + Args: + capacitron_capacity (int): + Defines the variational capacity limit of the prosody embeddings. Defaults to 150. + capacitron_VAE_embedding_dim (int): + Defines the size of the Capacitron embedding vector dimension. Defaults to 128. + capacitron_use_text_summary_embeddings (bool): + If True, use a text summary embedding in Capacitron. Defaults to True. + capacitron_text_summary_embedding_dim (int): + Defines the size of the capacitron text embedding vector dimension. Defaults to 128. + capacitron_use_speaker_embedding (bool): + if True use speaker embeddings in Capacitron. Defaults to False. + capacitron_VAE_loss_alpha (float): + Weight for the VAE loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + capacitron_grad_clip (float): + Gradient clipping value for all gradients except beta. 
Defaults to 5.0 + """ + + capacitron_loss_alpha: int = 1 + capacitron_capacity: int = 150 + capacitron_VAE_embedding_dim: int = 128 + capacitron_use_text_summary_embeddings: bool = True + capacitron_text_summary_embedding_dim: int = 128 + capacitron_use_speaker_embedding: bool = False + capacitron_VAE_loss_alpha: float = 0.25 + capacitron_grad_clip: float = 5.0 + + def check_values( + self, + ): + """Check config fields""" + c = asdict(self) + super().check_values() + check_argument("capacitron_capacity", c, restricted=True, min_val=10, max_val=500) + check_argument("capacitron_VAE_embedding_dim", c, restricted=True, min_val=16, max_val=1024) + check_argument("capacitron_use_speaker_embedding", c, restricted=False) + check_argument("capacitron_text_summary_embedding_dim", c, restricted=False, min_val=16, max_val=512) + check_argument("capacitron_VAE_loss_alpha", c, restricted=False) + check_argument("capacitron_grad_clip", c, restricted=False) + + +@dataclass +class CharactersConfig(Coqpit): + """Defines arguments for the `BaseCharacters` or `BaseVocabulary` and their subclasses. + + Args: + characters_class (str): + Defines the class of the characters used. If None, we pick ```Phonemes``` or ```Graphemes``` based on + the configuration. Defaults to None. + + vocab_dict (dict): + Defines the vocabulary dictionary used to encode the characters. Defaults to None. + + pad (str): + characters in place of empty padding. Defaults to None. + + eos (str): + characters showing the end of a sentence. Defaults to None. + + bos (str): + characters showing the beginning of a sentence. Defaults to None. + + blank (str): + Optional character used between characters by some models for better prosody. Defaults to `_blank`. + + characters (str): + character set used by the model. Characters not in this list are ignored when converting input text to + a list of sequence IDs. Defaults to None. + + punctuations (str): + characters considered as punctuation as parsing the input sentence. Defaults to None. + + phonemes (str): + characters considered as parsing phonemes. This is only for backwards compat. Use `characters` for new + models. Defaults to None. + + is_unique (bool): + remove any duplicate characters in the character lists. It is a bandaid for compatibility with the old + models trained with character lists with duplicates. Defaults to True. + + is_sorted (bool): + Sort the characters in alphabetical order. Defaults to True. + """ + + characters_class: str = None + + # using BaseVocabulary + vocab_dict: Dict = None + + # using on BaseCharacters + pad: str = None + eos: str = None + bos: str = None + blank: str = None + characters: str = None + punctuations: str = None + phonemes: str = None + is_unique: bool = True # for backwards compatibility of models trained with char sets with duplicates + is_sorted: bool = True + + +@dataclass +class BaseTTSConfig(BaseTrainingConfig): + """Shared parameters among all the tts models. + + Args: + + audio (BaseAudioConfig): + Audio processor config object instance. + + use_phonemes (bool): + enable / disable phoneme use. + + phonemizer (str): + Name of the phonemizer to use. If set None, the phonemizer will be selected by `phoneme_language`. + Defaults to None. + + phoneme_language (str): + Language code for the phonemizer. You can check the list of supported languages by running + `python TTS/tts/utils/text/phonemizers/__init__.py`. Defaults to None. + + compute_input_seq_cache (bool): + enable / disable precomputation of the phoneme sequences. 
At the expense of some delay at the beginning of + the training, It allows faster data loader time and precise limitation with `max_seq_len` and + `min_seq_len`. + + text_cleaner (str): + Name of the text cleaner used for cleaning and formatting transcripts. + + enable_eos_bos_chars (bool): + enable / disable the use of eos and bos characters. + + test_senteces_file (str): + Path to a txt file that has sentences used at test time. The file must have a sentence per line. + + phoneme_cache_path (str): + Path to the output folder caching the computed phonemes for each sample. + + characters (CharactersConfig): + Instance of a CharactersConfig class. + + batch_group_size (int): + Size of the batch groups used for bucketing. By default, the dataloader orders samples by the sequence + length for a more efficient and stable training. If `batch_group_size > 1` then it performs bucketing to + prevent using the same batches for each epoch. + + loss_masking (bool): + enable / disable masking loss values against padded segments of samples in a batch. + + min_text_len (int): + Minimum length of input text to be used. All shorter samples will be ignored. Defaults to 0. + + max_text_len (int): + Maximum length of input text to be used. All longer samples will be ignored. Defaults to float("inf"). + + min_audio_len (int): + Minimum length of input audio to be used. All shorter samples will be ignored. Defaults to 0. + + max_audio_len (int): + Maximum length of input audio to be used. All longer samples will be ignored. The maximum length in the + dataset defines the VRAM used in the training. Hence, pay attention to this value if you encounter an + OOM error in training. Defaults to float("inf"). + + compute_f0 (int): + (Not in use yet). + + compute_energy (int): + (Not in use yet). + + compute_linear_spec (bool): + If True data loader computes and returns linear spectrograms alongside the other data. + + precompute_num_workers (int): + Number of workers to precompute features. Defaults to 0. + + use_noise_augment (bool): + Augment the input audio with random noise. + + start_by_longest (bool): + If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues. + Defaults to False. + + shuffle (bool): + If True, the data loader will shuffle the dataset when there is not sampler defined. Defaults to True. + + drop_last (bool): + If True, the data loader will drop the last batch if it is not complete. It helps to prevent + issues that emerge from the partial batch statistics. Defaults to True. + + add_blank (bool): + Add blank characters between each other two characters. It improves performance for some models at expense + of slower run-time due to the longer input sequence. + + datasets (List[BaseDatasetConfig]): + List of datasets used for training. If multiple datasets are provided, they are merged and used together + for training. + + optimizer (str): + Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`. + Defaults to ``. + + optimizer_params (dict): + Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}` + + lr_scheduler (str): + Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or + `TTS.utils.training`. Defaults to ``. + + lr_scheduler_params (dict): + Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`. + + test_sentences (List[str]): + List of sentences to be used at testing. 
Defaults to '[]' + + eval_split_max_size (int): + Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled). + + eval_split_size (float): + If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. + If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). + + use_speaker_weighted_sampler (bool): + Enable / Disable the batch balancer by speaker. Defaults to ```False```. + + speaker_weighted_sampler_alpha (float): + Number that control the influence of the speaker sampler weights. Defaults to ```1.0```. + + use_language_weighted_sampler (bool): + Enable / Disable the batch balancer by language. Defaults to ```False```. + + language_weighted_sampler_alpha (float): + Number that control the influence of the language sampler weights. Defaults to ```1.0```. + + use_length_weighted_sampler (bool): + Enable / Disable the batch balancer by audio length. If enabled the dataset will be divided + into 10 buckets considering the min and max audio of the dataset. The sampler weights will be + computed forcing to have the same quantity of data for each bucket in each training batch. Defaults to ```False```. + + length_weighted_sampler_alpha (float): + Number that control the influence of the length sampler weights. Defaults to ```1.0```. + """ + + audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) + # phoneme settings + use_phonemes: bool = False + phonemizer: str = None + phoneme_language: str = None + compute_input_seq_cache: bool = False + text_cleaner: str = None + enable_eos_bos_chars: bool = False + test_sentences_file: str = "" + phoneme_cache_path: str = None + # vocabulary parameters + characters: CharactersConfig = None + add_blank: bool = False + # training params + batch_group_size: int = 0 + loss_masking: bool = None + # dataloading + min_audio_len: int = 1 + max_audio_len: int = float("inf") + min_text_len: int = 1 + max_text_len: int = float("inf") + compute_f0: bool = False + compute_energy: bool = False + compute_linear_spec: bool = False + precompute_num_workers: int = 0 + use_noise_augment: bool = False + start_by_longest: bool = False + shuffle: bool = False + drop_last: bool = False + # dataset + datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) + # optimizer + optimizer: str = "radam" + optimizer_params: dict = None + # scheduler + lr_scheduler: str = None + lr_scheduler_params: dict = field(default_factory=lambda: {}) + # testing + test_sentences: List[str] = field(default_factory=lambda: []) + # evaluation + eval_split_max_size: int = None + eval_split_size: float = 0.01 + # weighted samplers + use_speaker_weighted_sampler: bool = False + speaker_weighted_sampler_alpha: float = 1.0 + use_language_weighted_sampler: bool = False + language_weighted_sampler_alpha: float = 1.0 + use_length_weighted_sampler: bool = False + length_weighted_sampler_alpha: float = 1.0 diff --git a/viXTTS/TTS/tts/configs/speedy_speech_config.py b/viXTTS/TTS/tts/configs/speedy_speech_config.py new file mode 100644 index 0000000000000000000000000000000000000000..bf8517dfc478a135978df19f3126313a616c14c2 --- /dev/null +++ b/viXTTS/TTS/tts/configs/speedy_speech_config.py @@ -0,0 +1,194 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.forward_tts import ForwardTTSArgs + + +@dataclass +class SpeedySpeechConfig(BaseTTSConfig): + """Configure 
`ForwardTTS` as SpeedySpeech model. + + Example: + + >>> from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig + >>> config = SpeedySpeechConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `speedy_speech`. + + base_model (str): + Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate + the base model rather than searching for the `model` implementation. Defaults to `forward_tts`. + + model_args (Coqpit): + Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`. + + data_dep_init_steps (int): + Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses + Activation Normalization that pre-computes normalization stats at the beginning and use the same values + for the rest. Defaults to 10. + + speakers_file (str): + Path to the file containing the list of speakers. Needed at inference for loading matching speaker ids to + speaker names. Defaults to `None`. + + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + + d_vector_dim (int): + Dimension of the external speaker embeddings. Defaults to 0. + + optimizer (str): + Name of the model optimizer. Defaults to `RAdam`. + + optimizer_params (dict): + Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`. + + lr_scheduler (str): + Name of the learning rate scheduler. Defaults to `Noam`. + + lr_scheduler_params (dict): + Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`. + + lr (float): + Initial learning rate. Defaults to `1e-3`. + + grad_clip (float): + Gradient norm clipping value. Defaults to `5.0`. + + spec_loss_type (str): + Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `l1`. + + duration_loss_type (str): + Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `huber`. + + use_ssim_loss (bool): + Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True. + + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + + ssim_loss_alpha (float): + Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0. + + dur_loss_alpha (float): + Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0. + + spec_loss_alpha (float): + Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0. + + binary_loss_alpha (float): + Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. + + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. + + min_seq_len (int): + Minimum input sequence length to be used at training. + + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. 
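Editor's aside: an illustrative sketch, not part of the diff. The `__post_init__` further down in this file copies the flat speaker settings into `model_args`, so setting the d-vector fields on the config itself is enough (assuming `ForwardTTSArgs` exposes the matching attributes, as that method implies; the file path below is a placeholder).

from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig

# "my_d_vectors.pth" is a hypothetical path, not a file from this repository.
cfg = SpeedySpeechConfig(use_d_vector_file=True, d_vector_file="my_d_vectors.pth", d_vector_dim=256)

# __post_init__ mirrors the flat settings into the ForwardTTSArgs instance:
assert cfg.model_args.use_d_vector_file is True
assert cfg.model_args.d_vector_file == "my_d_vectors.pth"
assert cfg.model_args.d_vector_dim == 256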
+ """ + + model: str = "speedy_speech" + base_model: str = "forward_tts" + + # set model args as SpeedySpeech + model_args: ForwardTTSArgs = field( + default_factory=lambda: ForwardTTSArgs( + use_pitch=False, + encoder_type="residual_conv_bn", + encoder_params={ + "kernel_size": 4, + "dilations": 4 * [1, 2, 4] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 13, + }, + decoder_type="residual_conv_bn", + decoder_params={ + "kernel_size": 4, + "dilations": 4 * [1, 2, 4, 8] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 17, + }, + out_channels=80, + hidden_channels=128, + positional_encoding=True, + detach_duration_predictor=True, + ) + ) + + # multi-speaker settings + num_speakers: int = 0 + speakers_file: str = None + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_file: str = False + d_vector_dim: int = 0 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + lr: float = 1e-4 + grad_clip: float = 5.0 + + # loss params + spec_loss_type: str = "l1" + duration_loss_type: str = "huber" + use_ssim_loss: bool = False + ssim_loss_alpha: float = 1.0 + dur_loss_alpha: float = 1.0 + spec_loss_alpha: float = 1.0 + aligner_loss_alpha: float = 1.0 + binary_align_loss_alpha: float = 0.3 + binary_loss_warmup_epochs: int = 150 + + # overrides + min_seq_len: int = 13 + max_seq_len: int = 200 + r: int = 1 # DO NOT CHANGE + + # dataset configs + compute_f0: bool = False + f0_cache_path: str = None + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) + + def __post_init__(self): + # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there. + if self.num_speakers > 0: + self.model_args.num_speakers = self.num_speakers + + # speaker embedding settings + if self.use_speaker_embedding: + self.model_args.use_speaker_embedding = True + if self.speakers_file: + self.model_args.speakers_file = self.speakers_file + + # d-vector settings + if self.use_d_vector_file: + self.model_args.use_d_vector_file = True + if self.d_vector_dim is not None and self.d_vector_dim > 0: + self.model_args.d_vector_dim = self.d_vector_dim + if self.d_vector_file: + self.model_args.d_vector_file = self.d_vector_file diff --git a/viXTTS/TTS/tts/configs/tacotron2_config.py b/viXTTS/TTS/tts/configs/tacotron2_config.py new file mode 100644 index 0000000000000000000000000000000000000000..95b65202218cf3aa0dd70c8d8cd55a3f913ed308 --- /dev/null +++ b/viXTTS/TTS/tts/configs/tacotron2_config.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass + +from TTS.tts.configs.tacotron_config import TacotronConfig + + +@dataclass +class Tacotron2Config(TacotronConfig): + """Defines parameters for Tacotron2 based models. + + Example: + + >>> from TTS.tts.configs.tacotron2_config import Tacotron2Config + >>> config = Tacotron2Config() + + Check `TacotronConfig` for argument descriptions. 
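Editor's aside on the SpeedySpeech `model_args` above (illustrative only): the dilation schedules are ordinary Python list arithmetic, and their lengths line up with the configured `num_res_blocks` values.

encoder_dilations = 4 * [1, 2, 4] + [1]     # four repeats of [1, 2, 4] plus a final 1
decoder_dilations = 4 * [1, 2, 4, 8] + [1]  # four repeats of [1, 2, 4, 8] plus a final 1
assert len(encoder_dilations) == 13         # matches the encoder's num_res_blocks
assert len(decoder_dilations) == 17         # matches the decoder's num_res_blocks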
+ """ + + model: str = "tacotron2" + out_channels: int = 80 + encoder_in_features: int = 512 + decoder_in_features: int = 512 diff --git a/viXTTS/TTS/tts/configs/tacotron_config.py b/viXTTS/TTS/tts/configs/tacotron_config.py new file mode 100644 index 0000000000000000000000000000000000000000..350b5ea99633569d6977851875d5d8d83175ac36 --- /dev/null +++ b/viXTTS/TTS/tts/configs/tacotron_config.py @@ -0,0 +1,235 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig, CapacitronVAEConfig, GSTConfig + + +@dataclass +class TacotronConfig(BaseTTSConfig): + """Defines parameters for Tacotron based models. + + Example: + + >>> from TTS.tts.configs.tacotron_config import TacotronConfig + >>> config = TacotronConfig() + + Args: + model (str): + Model name used to select the right model class to initilize. Defaults to `Tacotron`. + use_gst (bool): + enable / disable the use of Global Style Token modules. Defaults to False. + gst (GSTConfig): + Instance of `GSTConfig` class. + gst_style_input (str): + Path to the wav file used at inference to set the speech style through GST. If `GST` is enabled and + this is not defined, the model uses a zero vector as an input. Defaults to None. + use_capacitron_vae (bool): + enable / disable the use of Capacitron modules. Defaults to False. + capacitron_vae (CapacitronConfig): + Instance of `CapacitronConfig` class. + num_chars (int): + Number of characters used by the model. It must be defined before initializing the model. Defaults to None. + num_speakers (int): + Number of speakers for multi-speaker models. Defaults to 1. + r (int): + Initial number of output frames that the decoder computed per iteration. Larger values makes training and inference + faster but reduces the quality of the output frames. This must be equal to the largest `r` value used in + `gradual_training` schedule. Defaults to 1. + gradual_training (List[List]): + Parameters for the gradual training schedule. It is in the form `[[a, b, c], [d ,e ,f] ..]` where `a` is + the step number to start using the rest of the values, `b` is the `r` value and `c` is the batch size. + If sets None, no gradual training is used. Defaults to None. + memory_size (int): + Defines the number of previous frames used by the Prenet. If set to < 0, then it uses only the last frame. + Defaults to -1. + prenet_type (str): + `original` or `bn`. `original` sets the default Prenet and `bn` uses Batch Normalization version of the + Prenet. Defaults to `original`. + prenet_dropout (bool): + enables / disables the use of dropout in the Prenet. Defaults to True. + prenet_dropout_at_inference (bool): + enable / disable the use of dropout in the Prenet at the inference time. Defaults to False. + stopnet (bool): + enable /disable the Stopnet that predicts the end of the decoder sequence. Defaults to True. + stopnet_pos_weight (float): + Weight that is applied to over-weight positive instances in the Stopnet loss. Use larger values with + datasets with longer sentences. Defaults to 0.2. + max_decoder_steps (int): + Max number of steps allowed for the decoder. Defaults to 50. + encoder_in_features (int): + Channels of encoder input and character embedding tensors. Defaults to 256. + decoder_in_features (int): + Channels of decoder input and encoder output tensors. Defaults to 256. + out_channels (int): + Channels of the final model output. It must match the spectragram size. Defaults to 80. 
+ separate_stopnet (bool): + Use a distinct Stopnet which is trained separately from the rest of the model. Defaults to True. + attention_type (str): + attention type. Check ```TTS.tts.layers.attentions.init_attn```. Defaults to 'original'. + attention_heads (int): + Number of attention heads for GMM attention. Defaults to 5. + windowing (bool): + It especially useful at inference to keep attention alignment diagonal. Defaults to False. + use_forward_attn (bool): + It is only valid if ```attn_type``` is ```original```. Defaults to False. + forward_attn_mask (bool): + enable/disable extra masking over forward attention. It is useful at inference to prevent + possible attention failures. Defaults to False. + transition_agent (bool): + enable/disable transition agent in forward attention. Defaults to False. + location_attn (bool): + enable/disable location sensitive attention as in the original Tacotron2 paper. + It is only valid if ```attn_type``` is ```original```. Defaults to True. + bidirectional_decoder (bool): + enable/disable bidirectional decoding. Defaults to False. + double_decoder_consistency (bool): + enable/disable double decoder consistency. Defaults to False. + ddc_r (int): + reduction rate used by the coarse decoder when `double_decoder_consistency` is in use. Set this + as a multiple of the `r` value. Defaults to 6. + speakers_file (str): + Path to the speaker mapping file for the Speaker Manager. Defaults to None. + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + optimizer (str): + Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`. + Defaults to `RAdam`. + optimizer_params (dict): + Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}` + lr_scheduler (str): + Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or + `TTS.utils.training`. Defaults to `NoamLR`. + lr_scheduler_params (dict): + Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`. + lr (float): + Initial learning rate. Defaults to `1e-4`. + wd (float): + Weight decay coefficient. Defaults to `1e-6`. + grad_clip (float): + Gradient clipping threshold. Defaults to `5`. + seq_len_norm (bool): + enable / disable the sequnce length normalization in the loss functions. If set True, loss of a sample + is divided by the sequence length. Defaults to False. + loss_masking (bool): + enable / disable masking the paddings of the samples in loss computation. Defaults to True. + decoder_loss_alpha (float): + Weight for the decoder loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + postnet_loss_alpha (float): + Weight for the postnet loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + postnet_diff_spec_alpha (float): + Weight for the postnet differential loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + decoder_diff_spec_alpha (float): + + Weight for the decoder differential loss of the Tacotron model. 
If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + decoder_ssim_alpha (float): + Weight for the decoder SSIM loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + postnet_ssim_alpha (float): + Weight for the postnet SSIM loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + ga_alpha (float): + Weight for the guided attention loss. If set less than or equal to zero, it disables the corresponding loss + function. Defaults to 5. + """ + + model: str = "tacotron" + # model_params: TacotronArgs = field(default_factory=lambda: TacotronArgs()) + use_gst: bool = False + gst: GSTConfig = None + gst_style_input: str = None + + use_capacitron_vae: bool = False + capacitron_vae: CapacitronVAEConfig = None + + # model specific params + num_speakers: int = 1 + num_chars: int = 0 + r: int = 2 + gradual_training: List[List[int]] = None + memory_size: int = -1 + prenet_type: str = "original" + prenet_dropout: bool = True + prenet_dropout_at_inference: bool = False + stopnet: bool = True + separate_stopnet: bool = True + stopnet_pos_weight: float = 0.2 + max_decoder_steps: int = 10000 + encoder_in_features: int = 256 + decoder_in_features: int = 256 + decoder_output_dim: int = 80 + out_channels: int = 513 + + # attention layers + attention_type: str = "original" + attention_heads: int = None + attention_norm: str = "sigmoid" + attention_win: bool = False + windowing: bool = False + use_forward_attn: bool = False + forward_attn_mask: bool = False + transition_agent: bool = False + location_attn: bool = True + + # advance methods + bidirectional_decoder: bool = False + double_decoder_consistency: bool = False + ddc_r: int = 6 + + # multi-speaker settings + speakers_file: str = None + use_speaker_embedding: bool = False + speaker_embedding_dim: int = 512 + use_d_vector_file: bool = False + d_vector_file: str = False + d_vector_dim: int = None + + # optimizer parameters + optimizer: str = "RAdam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + lr: float = 1e-4 + grad_clip: float = 5.0 + seq_len_norm: bool = False + loss_masking: bool = True + + # loss params + decoder_loss_alpha: float = 0.25 + postnet_loss_alpha: float = 0.25 + postnet_diff_spec_alpha: float = 0.25 + decoder_diff_spec_alpha: float = 0.25 + decoder_ssim_alpha: float = 0.25 + postnet_ssim_alpha: float = 0.25 + ga_alpha: float = 5.0 + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) + + def check_values(self): + if self.gradual_training: + assert ( + self.gradual_training[0][1] == self.r + ), f"[!] the first scheduled gradual training `r` must be equal to the model's `r` value. 
{self.gradual_training[0][1]} vs {self.r}" + if self.model == "tacotron" and self.audio is not None: + assert self.out_channels == ( + self.audio.fft_size // 2 + 1 + ), f"{self.out_channels} vs {self.audio.fft_size // 2 + 1}" + if self.model == "tacotron2" and self.audio is not None: + assert self.out_channels == self.audio.num_mels diff --git a/viXTTS/TTS/tts/configs/tortoise_config.py b/viXTTS/TTS/tts/configs/tortoise_config.py new file mode 100644 index 0000000000000000000000000000000000000000..d60e43d71280bfa085988e31a52acfeef015c5f0 --- /dev/null +++ b/viXTTS/TTS/tts/configs/tortoise_config.py @@ -0,0 +1,87 @@ +from dataclasses import dataclass, field + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.tortoise import TortoiseArgs, TortoiseAudioConfig + + +@dataclass +class TortoiseConfig(BaseTTSConfig): + """Defines parameters for Tortoise TTS model. + + Args: + model (str): + Model name. Do not change unless you know what you are doing. + + model_args (TortoiseArgs): + Model architecture arguments. Defaults to `TortoiseArgs()`. + + audio (TortoiseAudioConfig): + Audio processing configuration. Defaults to `TortoiseAudioConfig()`. + + model_dir (str): + Path to the folder that has all the Tortoise models. Defaults to None. + + temperature (float): + Temperature for the autoregressive model inference. Larger values makes predictions more creative sacrificing stability. Defaults to `0.2`. + + length_penalty (float): + Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to the sequence length, + which in turn is used to divide the score of the sequence. Since the score is the log likelihood of the sequence (i.e. negative), + length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences. + + reperation_penalty (float): + The parameter for repetition penalty. 1.0 means no penalty. Defaults to `2.0`. + + top_p (float): + If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. + Defaults to `0.8`. + + cond_free_k (float): + Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf]. + As cond_free_k increases, the output becomes dominated by the conditioning-free signal. + Formula is: output=cond_present_output*(cond_free_k+1)-cond_absenct_output*cond_free_k. Defaults to `2.0`. + + diffusion_temperature (float): + Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0 + are the "mean" prediction of the diffusion network and will sound bland and smeared. + Defaults to `1.0`. + + num_autoregressive_samples (int): + Number of samples taken from the autoregressive model, all of which are filtered using CLVP. + As Tortoise is a probabilistic model, more samples means a higher probability of creating something "great". + Defaults to `16`. + + diffusion_iterations (int): + Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine + the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better, + however. Defaults to `30`. + + sampler (str): + Diffusion sampler to be used. `ddim` or `dpm++2m`. Defaults to `ddim`. + Note: + Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters. 
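Editor's aside: the formula quoted in the `cond_free_k` description above is a simple extrapolation between the conditioned and unconditioned predictions. A plain numeric illustration (not Tortoise code):

def cond_free_mix(cond_present_output, cond_absent_output, cond_free_k):
    # Mirrors the docstring formula: output = present*(k+1) - absent*k
    return cond_present_output * (cond_free_k + 1) - cond_absent_output * cond_free_k

print(cond_free_mix(1.0, 0.4, 0.0))  # 1.0 -> k=0 returns the conditioned prediction unchanged
print(cond_free_mix(1.0, 0.4, 2.0))  # 2.2 -> larger k pushes further away from the unconditioned prediction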
+ + Example: + + >>> from TTS.tts.configs.tortoise_config import TortoiseConfig + >>> config = TortoiseConfig() + """ + + model: str = "tortoise" + # model specific params + model_args: TortoiseArgs = field(default_factory=TortoiseArgs) + audio: TortoiseAudioConfig = field(default_factory=TortoiseAudioConfig) + model_dir: str = None + + # settings + temperature: float = 0.2 + length_penalty: float = 1.0 + repetition_penalty: float = 2.0 + top_p: float = 0.8 + cond_free_k: float = 2.0 + diffusion_temperature: float = 1.0 + + # inference params + num_autoregressive_samples: int = 16 + diffusion_iterations: int = 30 + sampler: str = "ddim" diff --git a/viXTTS/TTS/tts/configs/vits_config.py b/viXTTS/TTS/tts/configs/vits_config.py new file mode 100644 index 0000000000000000000000000000000000000000..2d0242bf131a25d6b2cef7a297a3c32b283f908a --- /dev/null +++ b/viXTTS/TTS/tts/configs/vits_config.py @@ -0,0 +1,176 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.vits import VitsArgs, VitsAudioConfig + + +@dataclass +class VitsConfig(BaseTTSConfig): + """Defines parameters for VITS End2End TTS model. + + Args: + model (str): + Model name. Do not change unless you know what you are doing. + + model_args (VitsArgs): + Model architecture arguments. Defaults to `VitsArgs()`. + + audio (VitsAudioConfig): + Audio processing configuration. Defaults to `VitsAudioConfig()`. + + grad_clip (List): + Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`. + + lr_gen (float): + Initial learning rate for the generator. Defaults to 0.0002. + + lr_disc (float): + Initial learning rate for the discriminator. Defaults to 0.0002. + + lr_scheduler_gen (str): + Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to + `ExponentialLR`. + + lr_scheduler_gen_params (dict): + Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`. + + lr_scheduler_disc (str): + Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to + `ExponentialLR`. + + lr_scheduler_disc_params (dict): + Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`. + + scheduler_after_epoch (bool): + If true, step the schedulers after each epoch else after each step. Defaults to `False`. + + optimizer (str): + Name of the optimizer to use with both the generator and the discriminator networks. One of the + `torch.optim.*`. Defaults to `AdamW`. + + kl_loss_alpha (float): + Loss weight for KL loss. Defaults to 1.0. + + disc_loss_alpha (float): + Loss weight for the discriminator loss. Defaults to 1.0. + + gen_loss_alpha (float): + Loss weight for the generator loss. Defaults to 1.0. + + feat_loss_alpha (float): + Loss weight for the feature matching loss. Defaults to 1.0. + + mel_loss_alpha (float): + Loss weight for the mel loss. Defaults to 45.0. + + return_wav (bool): + If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`. + + compute_linear_spec (bool): + If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`. + + use_weighted_sampler (bool): + If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`. 
+ + weighted_sampler_attrs (dict): + Key retuned by the formatter to be used for weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities + by overweighting `root_path` by 2.0. Defaults to `{}`. + + weighted_sampler_multipliers (dict): + Weight each unique value of a key returned by the formatter for weighted sampling. + For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}`. + It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`. + + r (int): + Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`. + + add_blank (bool): + If true, a blank token is added in between every character. Defaults to `True`. + + test_sentences (List[List]): + List of sentences with speaker and language information to be used for testing. + + language_ids_file (str): + Path to the language ids file. + + use_language_embedding (bool): + If true, language embedding is used. Defaults to `False`. + + Note: + Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters. + + Example: + + >>> from TTS.tts.configs.vits_config import VitsConfig + >>> config = VitsConfig() + """ + + model: str = "vits" + # model specific params + model_args: VitsArgs = field(default_factory=VitsArgs) + audio: VitsAudioConfig = field(default_factory=VitsAudioConfig) + + # optimizer + grad_clip: List[float] = field(default_factory=lambda: [1000, 1000]) + lr_gen: float = 0.0002 + lr_disc: float = 0.0002 + lr_scheduler_gen: str = "ExponentialLR" + lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1}) + lr_scheduler_disc: str = "ExponentialLR" + lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1}) + scheduler_after_epoch: bool = True + optimizer: str = "AdamW" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01}) + + # loss params + kl_loss_alpha: float = 1.0 + disc_loss_alpha: float = 1.0 + gen_loss_alpha: float = 1.0 + feat_loss_alpha: float = 1.0 + mel_loss_alpha: float = 45.0 + dur_loss_alpha: float = 1.0 + speaker_encoder_loss_alpha: float = 1.0 + + # data loader params + return_wav: bool = True + compute_linear_spec: bool = True + + # sampler params + use_weighted_sampler: bool = False # TODO: move it to the base config + weighted_sampler_attrs: dict = field(default_factory=lambda: {}) + weighted_sampler_multipliers: dict = field(default_factory=lambda: {}) + + # overrides + r: int = 1 # DO NOT CHANGE + add_blank: bool = True + + # testing + test_sentences: List[List] = field( + default_factory=lambda: [ + ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."], + ["Be a voice, not an echo."], + ["I'm sorry Dave. I'm afraid I can't do that."], + ["This cake is great. 
It's so delicious and moist."], + ["Prior to November 22, 1963."], + ] + ) + + # multi-speaker settings + # use speaker embedding layer + num_speakers: int = 0 + use_speaker_embedding: bool = False + speakers_file: str = None + speaker_embedding_channels: int = 256 + language_ids_file: str = None + use_language_embedding: bool = False + + # use d-vectors + use_d_vector_file: bool = False + d_vector_file: List[str] = None + d_vector_dim: int = None + + def __post_init__(self): + for key, val in self.model_args.items(): + if hasattr(self, key): + self[key] = val diff --git a/viXTTS/TTS/tts/configs/xtts_config.py b/viXTTS/TTS/tts/configs/xtts_config.py new file mode 100644 index 0000000000000000000000000000000000000000..a0766d425c2e950a7078e77d3f3d7cf9ba0278ae --- /dev/null +++ b/viXTTS/TTS/tts/configs/xtts_config.py @@ -0,0 +1,108 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig + + +@dataclass +class XttsConfig(BaseTTSConfig): + """Defines parameters for XTTS TTS model. + + Args: + model (str): + Model name. Do not change unless you know what you are doing. + + model_args (XttsArgs): + Model architecture arguments. Defaults to `XttsArgs()`. + + audio (XttsAudioConfig): + Audio processing configuration. Defaults to `XttsAudioConfig()`. + + model_dir (str): + Path to the folder that has all the XTTS models. Defaults to None. + + temperature (float): + Temperature for the autoregressive model inference. Larger values makes predictions more creative sacrificing stability. Defaults to `0.2`. + + length_penalty (float): + Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to the sequence length, + which in turn is used to divide the score of the sequence. Since the score is the log likelihood of the sequence (i.e. negative), + length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences. + + repetition_penalty (float): + The parameter for repetition penalty. 1.0 means no penalty. Defaults to `2.0`. + + top_p (float): + If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. + Defaults to `0.8`. + + num_gpt_outputs (int): + Number of samples taken from the autoregressive model, all of which are filtered using CLVP. + As XTTS is a probabilistic model, more samples means a higher probability of creating something "great". + Defaults to `16`. + + gpt_cond_len (int): + Secs audio to be used as conditioning for the autoregressive model. Defaults to `12`. + + gpt_cond_chunk_len (int): + Audio chunk size in secs. Audio is split into chunks and latents are extracted for each chunk. Then the + latents are averaged. Chunking improves the stability. It must be <= gpt_cond_len. + If gpt_cond_len == gpt_cond_chunk_len, no chunking. Defaults to `4`. + + max_ref_len (int): + Maximum number of seconds of audio to be used as conditioning for the decoder. Defaults to `10`. + + sound_norm_refs (bool): + Whether to normalize the conditioning audio. Defaults to `False`. + + Note: + Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters. 
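Editor's aside: to make the `gpt_cond_len` / `gpt_cond_chunk_len` relationship above concrete, a rough sketch of the chunking arithmetic only (the real latent extraction lives in the XTTS model code, which is not shown here).

gpt_cond_len = 12       # seconds of reference audio used for conditioning (default in this config)
gpt_cond_chunk_len = 4  # seconds per chunk; must be <= gpt_cond_len

assert gpt_cond_chunk_len <= gpt_cond_len
num_chunks = gpt_cond_len // gpt_cond_chunk_len
print(num_chunks)  # 3 -> one latent per chunk, averaged into the final conditioning latent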
+ + Example: + + >>> from TTS.tts.configs.xtts_config import XttsConfig + >>> config = XttsConfig() + """ + + model: str = "xtts" + # model specific params + model_args: XttsArgs = field(default_factory=XttsArgs) + audio: XttsAudioConfig = field(default_factory=XttsAudioConfig) + model_dir: str = None + languages: List[str] = field( + default_factory=lambda: [ + "en", + "es", + "fr", + "de", + "it", + "pt", + "pl", + "tr", + "ru", + "nl", + "cs", + "ar", + "zh-cn", + "hu", + "ko", + "ja", + "hi", + "vi", + ] + ) + + # inference params + temperature: float = 0.85 + length_penalty: float = 1.0 + repetition_penalty: float = 2.0 + top_k: int = 50 + top_p: float = 0.85 + num_gpt_outputs: int = 1 + + # cloning + gpt_cond_len: int = 12 + gpt_cond_chunk_len: int = 4 + max_ref_len: int = 10 + sound_norm_refs: bool = False diff --git a/viXTTS/TTS/tts/datasets/__init__.py b/viXTTS/TTS/tts/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..192138561fdb4e85978fe8beb52eae2edf73888e --- /dev/null +++ b/viXTTS/TTS/tts/datasets/__init__.py @@ -0,0 +1,181 @@ +import os +import sys +from collections import Counter +from pathlib import Path +from typing import Callable, Dict, List, Tuple, Union + +import numpy as np + +from TTS.tts.datasets.dataset import * +from TTS.tts.datasets.formatters import * + + +def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01): + """Split a dataset into train and eval. Consider speaker distribution in multi-speaker training. + + Args: + items (List[List]): + A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`. + + eval_split_max_size (int): + Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled). + + eval_split_size (float): + If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. + If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). + """ + speakers = [item["speaker_name"] for item in items] + is_multi_speaker = len(set(speakers)) > 1 + if eval_split_size > 1: + eval_split_size = int(eval_split_size) + else: + if eval_split_max_size: + eval_split_size = min(eval_split_max_size, int(len(items) * eval_split_size)) + else: + eval_split_size = int(len(items) * eval_split_size) + + assert ( + eval_split_size > 0 + ), " [!] You do not have enough samples for the evaluation set. 
You can work around this setting the 'eval_split_size' parameter to a minimum of {}".format( + 1 / len(items) + ) + np.random.seed(0) + np.random.shuffle(items) + if is_multi_speaker: + items_eval = [] + speakers = [item["speaker_name"] for item in items] + speaker_counter = Counter(speakers) + while len(items_eval) < eval_split_size: + item_idx = np.random.randint(0, len(items)) + speaker_to_be_removed = items[item_idx]["speaker_name"] + if speaker_counter[speaker_to_be_removed] > 1: + items_eval.append(items[item_idx]) + speaker_counter[speaker_to_be_removed] -= 1 + del items[item_idx] + return items_eval, items + return items[:eval_split_size], items[eval_split_size:] + + +def add_extra_keys(metadata, language, dataset_name): + for item in metadata: + # add language name + item["language"] = language + # add unique audio name + relfilepath = os.path.splitext(os.path.relpath(item["audio_file"], item["root_path"]))[0] + audio_unique_name = f"{dataset_name}#{relfilepath}" + item["audio_unique_name"] = audio_unique_name + return metadata + + +def load_tts_samples( + datasets: Union[List[Dict], Dict], + eval_split=True, + formatter: Callable = None, + eval_split_max_size=None, + eval_split_size=0.01, +) -> Tuple[List[List], List[List]]: + """Parse the dataset from the datasets config, load the samples as a List and load the attention alignments if provided. + If `formatter` is not None, apply the formatter to the samples else pick the formatter from the available ones based + on the dataset name. + + Args: + datasets (List[Dict], Dict): A list of datasets or a single dataset dictionary. If multiple datasets are + in the list, they are all merged. + + eval_split (bool, optional): If true, create a evaluation split. If an eval split provided explicitly, generate + an eval split automatically. Defaults to True. + + formatter (Callable, optional): The preprocessing function to be applied to create the list of samples. It + must take the root_path and the meta_file name and return a list of samples in the format of + `[[text, audio_path, speaker_id], ...]]`. See the available formatters in `TTS.tts.dataset.formatter` as + example. Defaults to None. + + eval_split_max_size (int): + Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled). + + eval_split_size (float): + If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. + If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). + + Returns: + Tuple[List[List], List[List]: training and evaluation splits of the dataset. + """ + meta_data_train_all = [] + meta_data_eval_all = [] if eval_split else None + if not isinstance(datasets, list): + datasets = [datasets] + for dataset in datasets: + formatter_name = dataset["formatter"] + dataset_name = dataset["dataset_name"] + root_path = dataset["path"] + meta_file_train = dataset["meta_file_train"] + meta_file_val = dataset["meta_file_val"] + ignored_speakers = dataset["ignored_speakers"] + language = dataset["language"] + + # setup the right data processor + if formatter is None: + formatter = _get_formatter_by_name(formatter_name) + # load train set + meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers) + assert len(meta_data_train) > 0, f" [!] 
No training samples found in {root_path}/{meta_file_train}" + + meta_data_train = add_extra_keys(meta_data_train, language, dataset_name) + + print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}") + # load evaluation split if set + if eval_split: + if meta_file_val: + meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers) + meta_data_eval = add_extra_keys(meta_data_eval, language, dataset_name) + else: + eval_size_per_dataset = eval_split_max_size // len(datasets) if eval_split_max_size else None + meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_size_per_dataset, eval_split_size) + meta_data_eval_all += meta_data_eval + meta_data_train_all += meta_data_train + # load attention masks for the duration predictor training + if dataset.meta_file_attn_mask: + meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"])) + for idx, ins in enumerate(meta_data_train_all): + attn_file = meta_data[ins["audio_file"]].strip() + meta_data_train_all[idx].update({"alignment_file": attn_file}) + if meta_data_eval_all: + for idx, ins in enumerate(meta_data_eval_all): + attn_file = meta_data[ins["audio_file"]].strip() + meta_data_eval_all[idx].update({"alignment_file": attn_file}) + # set none for the next iter + formatter = None + return meta_data_train_all, meta_data_eval_all + + +def load_attention_mask_meta_data(metafile_path): + """Load meta data file created by compute_attention_masks.py""" + with open(metafile_path, "r", encoding="utf-8") as f: + lines = f.readlines() + + meta_data = [] + for line in lines: + wav_file, attn_file = line.split("|") + meta_data.append([wav_file, attn_file]) + return meta_data + + +def _get_formatter_by_name(name): + """Returns the respective preprocessing function.""" + thismodule = sys.modules[__name__] + return getattr(thismodule, name.lower()) + + +def find_unique_chars(data_samples, verbose=True): + texts = "".join(item[0] for item in data_samples) + chars = set(texts) + lower_chars = filter(lambda c: c.islower(), chars) + chars_force_lower = [c.lower() for c in chars] + chars_force_lower = set(chars_force_lower) + + if verbose: + print(f" > Number of unique characters: {len(chars)}") + print(f" > Unique characters: {''.join(sorted(chars))}") + print(f" > Unique lower characters: {''.join(sorted(lower_chars))}") + print(f" > Unique all forced to lower characters: {''.join(sorted(chars_force_lower))}") + return chars_force_lower diff --git a/viXTTS/TTS/tts/datasets/__pycache__/__init__.cpython-310.pyc b/viXTTS/TTS/tts/datasets/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69850d7cbc210808e22c17c6730c4a40e9f4c4f5 Binary files /dev/null and b/viXTTS/TTS/tts/datasets/__pycache__/__init__.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/datasets/__pycache__/dataset.cpython-310.pyc b/viXTTS/TTS/tts/datasets/__pycache__/dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ec51082024341d10e8e04dcffbca52e30051c34 Binary files /dev/null and b/viXTTS/TTS/tts/datasets/__pycache__/dataset.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/datasets/__pycache__/formatters.cpython-310.pyc b/viXTTS/TTS/tts/datasets/__pycache__/formatters.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2e4b257bac8a86accadacda0e1c4268a7718cbf Binary files /dev/null and b/viXTTS/TTS/tts/datasets/__pycache__/formatters.cpython-310.pyc differ diff --git 
a/viXTTS/TTS/tts/datasets/dataset.py b/viXTTS/TTS/tts/datasets/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..19fb25bef88375c99b5aeb7f608843b15393263e --- /dev/null +++ b/viXTTS/TTS/tts/datasets/dataset.py @@ -0,0 +1,973 @@ +import base64 +import collections +import os +import random +from typing import Dict, List, Union + +import numpy as np +import torch +import tqdm +from torch.utils.data import Dataset + +from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor +from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import compute_energy as calculate_energy + +import mutagen + +# to prevent too many open files error as suggested here +# https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 +torch.multiprocessing.set_sharing_strategy("file_system") + + +def _parse_sample(item): + language_name = None + attn_file = None + if len(item) == 5: + text, wav_file, speaker_name, language_name, attn_file = item + elif len(item) == 4: + text, wav_file, speaker_name, language_name = item + elif len(item) == 3: + text, wav_file, speaker_name = item + else: + raise ValueError(" [!] Dataset cannot parse the sample.") + return text, wav_file, speaker_name, language_name, attn_file + + +def noise_augment_audio(wav): + return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape) + + +def string2filename(string): + # generate a safe and reversible filename based on a string + filename = base64.urlsafe_b64encode(string.encode("utf-8")).decode("utf-8", "ignore") + return filename + + +def get_audio_size(audiopath): + extension = audiopath.rpartition(".")[-1].lower() + if extension not in {"mp3", "wav", "flac"}: + raise RuntimeError(f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!") + + audio_info = mutagen.File(audiopath).info + return int(audio_info.length * audio_info.sample_rate) + + +class TTSDataset(Dataset): + def __init__( + self, + outputs_per_step: int = 1, + compute_linear_spec: bool = False, + ap: AudioProcessor = None, + samples: List[Dict] = None, + tokenizer: "TTSTokenizer" = None, + compute_f0: bool = False, + compute_energy: bool = False, + f0_cache_path: str = None, + energy_cache_path: str = None, + return_wav: bool = False, + batch_group_size: int = 0, + min_text_len: int = 0, + max_text_len: int = float("inf"), + min_audio_len: int = 0, + max_audio_len: int = float("inf"), + phoneme_cache_path: str = None, + precompute_num_workers: int = 0, + speaker_id_mapping: Dict = None, + d_vector_mapping: Dict = None, + language_id_mapping: Dict = None, + use_noise_augment: bool = False, + start_by_longest: bool = False, + verbose: bool = False, + ): + """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs. + + If you need something different, you can subclass and override. + + Args: + outputs_per_step (int): Number of time frames predicted per step. + + compute_linear_spec (bool): compute linear spectrogram if True. + + ap (TTS.tts.utils.AudioProcessor): Audio processor object. + + samples (list): List of dataset samples. + + tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else + use the given. Defaults to None. + + compute_f0 (bool): compute f0 if True. Defaults to False. + + compute_energy (bool): compute energy if True. Defaults to False. + + f0_cache_path (str): Path to store f0 cache. Defaults to None. 
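Editor's aside on `split_dataset()` earlier in this diff: the dual meaning of `eval_split_size` is easy to miss, so here is a hypothetical helper (the name is ours, not the library's) that mirrors its sizing rules.

def resolve_eval_size(num_samples, eval_split_size=0.01, eval_split_max_size=None):
    # > 1 means an absolute number of eval samples, otherwise a proportion,
    # optionally capped by eval_split_max_size -- the same rules split_dataset() applies.
    if eval_split_size > 1:
        return int(eval_split_size)
    size = int(num_samples * eval_split_size)
    return min(eval_split_max_size, size) if eval_split_max_size else size


print(resolve_eval_size(10_000))             # 100 (1% of the dataset)
print(resolve_eval_size(10_000, 250))        # 250 (absolute count)
print(resolve_eval_size(10_000, 0.05, 200))  # 200 (5% would be 500, capped at 200)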
+ + energy_cache_path (str): Path to store energy cache. Defaults to None. + + return_wav (bool): Return the waveform of the sample. Defaults to False. + + batch_group_size (int): Range of batch randomization after sorting + sequences by length. It shuffles each batch with bucketing to gather similar lenght sequences in a + batch. Set 0 to disable. Defaults to 0. + + min_text_len (int): Minimum length of input text to be used. All shorter samples will be ignored. + Defaults to 0. + + max_text_len (int): Maximum length of input text to be used. All longer samples will be ignored. + Defaults to float("inf"). + + min_audio_len (int): Minimum length of input audio to be used. All shorter samples will be ignored. + Defaults to 0. + + max_audio_len (int): Maximum length of input audio to be used. All longer samples will be ignored. + The maximum length in the dataset defines the VRAM used in the training. Hence, pay attention to + this value if you encounter an OOM error in training. Defaults to float("inf"). + + phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a + separate file. Defaults to None. + + precompute_num_workers (int): Number of workers to precompute features. Defaults to 0. + + speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the + embedding layer. Defaults to None. + + d_vector_mapping (dict): Mapping of wav files to computed d-vectors. Defaults to None. + + use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False. + + start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False. + + verbose (bool): Print diagnostic information. Defaults to false. + """ + super().__init__() + self.batch_group_size = batch_group_size + self._samples = samples + self.outputs_per_step = outputs_per_step + self.compute_linear_spec = compute_linear_spec + self.return_wav = return_wav + self.compute_f0 = compute_f0 + self.compute_energy = compute_energy + self.f0_cache_path = f0_cache_path + self.energy_cache_path = energy_cache_path + self.min_audio_len = min_audio_len + self.max_audio_len = max_audio_len + self.min_text_len = min_text_len + self.max_text_len = max_text_len + self.ap = ap + self.phoneme_cache_path = phoneme_cache_path + self.speaker_id_mapping = speaker_id_mapping + self.d_vector_mapping = d_vector_mapping + self.language_id_mapping = language_id_mapping + self.use_noise_augment = use_noise_augment + self.start_by_longest = start_by_longest + + self.verbose = verbose + self.rescue_item_idx = 1 + self.pitch_computed = False + self.tokenizer = tokenizer + + if self.tokenizer.use_phonemes: + self.phoneme_dataset = PhonemeDataset( + self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers + ) + + if compute_f0: + self.f0_dataset = F0Dataset( + self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers + ) + if compute_energy: + self.energy_dataset = EnergyDataset( + self.samples, self.ap, cache_path=energy_cache_path, precompute_num_workers=precompute_num_workers + ) + if self.verbose: + self.print_logs() + + @property + def lengths(self): + lens = [] + for item in self.samples: + _, wav_file, *_ = _parse_sample(item) + audio_len = get_audio_size(wav_file) + lens.append(audio_len) + return lens + + @property + def samples(self): + return self._samples + + @samples.setter + def samples(self, new_samples): + self._samples = new_samples + if 
hasattr(self, "f0_dataset"): + self.f0_dataset.samples = new_samples + if hasattr(self, "energy_dataset"): + self.energy_dataset.samples = new_samples + if hasattr(self, "phoneme_dataset"): + self.phoneme_dataset.samples = new_samples + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + return self.load_data(idx) + + def print_logs(self, level: int = 0) -> None: + indent = "\t" * level + print("\n") + print(f"{indent}> DataLoader initialization") + print(f"{indent}| > Tokenizer:") + self.tokenizer.print_logs(level + 1) + print(f"{indent}| > Number of instances : {len(self.samples)}") + + def load_wav(self, filename): + waveform = self.ap.load_wav(filename) + assert waveform.size > 0 + return waveform + + def get_phonemes(self, idx, text): + out_dict = self.phoneme_dataset[idx] + assert text == out_dict["text"], f"{text} != {out_dict['text']}" + assert len(out_dict["token_ids"]) > 0 + return out_dict + + def get_f0(self, idx): + out_dict = self.f0_dataset[idx] + item = self.samples[idx] + assert item["audio_unique_name"] == out_dict["audio_unique_name"] + return out_dict + + def get_energy(self, idx): + out_dict = self.energy_dataset[idx] + item = self.samples[idx] + assert item["audio_unique_name"] == out_dict["audio_unique_name"] + return out_dict + + @staticmethod + def get_attn_mask(attn_file): + return np.load(attn_file) + + def get_token_ids(self, idx, text): + if self.tokenizer.use_phonemes: + token_ids = self.get_phonemes(idx, text)["token_ids"] + else: + token_ids = self.tokenizer.text_to_ids(text) + return np.array(token_ids, dtype=np.int32) + + def load_data(self, idx): + item = self.samples[idx] + + raw_text = item["text"] + + wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32) + + # apply noise for augmentation + if self.use_noise_augment: + wav = noise_augment_audio(wav) + + # get token ids + token_ids = self.get_token_ids(idx, item["text"]) + + # get pre-computed attention maps + attn = None + if "alignment_file" in item: + attn = self.get_attn_mask(item["alignment_file"]) + + # after phonemization the text length may change + # this is a shareful 🤭 hack to prevent longer phonemes + # TODO: find a better fix + if len(token_ids) > self.max_text_len or len(wav) < self.min_audio_len: + self.rescue_item_idx += 1 + return self.load_data(self.rescue_item_idx) + + # get f0 values + f0 = None + if self.compute_f0: + f0 = self.get_f0(idx)["f0"] + energy = None + if self.compute_energy: + energy = self.get_energy(idx)["energy"] + + sample = { + "raw_text": raw_text, + "token_ids": token_ids, + "wav": wav, + "pitch": f0, + "energy": energy, + "attn": attn, + "item_idx": item["audio_file"], + "speaker_name": item["speaker_name"], + "language_name": item["language"], + "wav_file_name": os.path.basename(item["audio_file"]), + "audio_unique_name": item["audio_unique_name"], + } + return sample + + @staticmethod + def _compute_lengths(samples): + new_samples = [] + for item in samples: + audio_length = get_audio_size(item["audio_file"]) + text_lenght = len(item["text"]) + item["audio_length"] = audio_length + item["text_length"] = text_lenght + new_samples += [item] + return new_samples + + @staticmethod + def filter_by_length(lengths: List[int], min_len: int, max_len: int): + idxs = np.argsort(lengths) # ascending order + ignore_idx = [] + keep_idx = [] + for idx in idxs: + length = lengths[idx] + if length < min_len or length > max_len: + ignore_idx.append(idx) + else: + keep_idx.append(idx) + return ignore_idx, keep_idx + + @staticmethod + 
def sort_by_length(samples: List[List]): + audio_lengths = [s["audio_length"] for s in samples] + idxs = np.argsort(audio_lengths) # ascending order + return idxs + + @staticmethod + def create_buckets(samples, batch_group_size: int): + assert batch_group_size > 0 + for i in range(len(samples) // batch_group_size): + offset = i * batch_group_size + end_offset = offset + batch_group_size + temp_items = samples[offset:end_offset] + random.shuffle(temp_items) + samples[offset:end_offset] = temp_items + return samples + + @staticmethod + def _select_samples_by_idx(idxs, samples): + samples_new = [] + for idx in idxs: + samples_new.append(samples[idx]) + return samples_new + + def preprocess_samples(self): + r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length + range. + """ + samples = self._compute_lengths(self.samples) + + # sort items based on the sequence length in ascending order + text_lengths = [i["text_length"] for i in samples] + audio_lengths = [i["audio_length"] for i in samples] + text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len) + audio_ignore_idx, audio_keep_idx = self.filter_by_length(audio_lengths, self.min_audio_len, self.max_audio_len) + keep_idx = list(set(audio_keep_idx) & set(text_keep_idx)) + ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx)) + + samples = self._select_samples_by_idx(keep_idx, samples) + + sorted_idxs = self.sort_by_length(samples) + + if self.start_by_longest: + longest_idxs = sorted_idxs[-1] + sorted_idxs[-1] = sorted_idxs[0] + sorted_idxs[0] = longest_idxs + + samples = self._select_samples_by_idx(sorted_idxs, samples) + + if len(samples) == 0: + raise RuntimeError(" [!] No samples left") + + # shuffle batch groups + # create batches with similar length items + # the larger the `batch_group_size`, the higher the length variety in a batch. + if self.batch_group_size > 0: + samples = self.create_buckets(samples, self.batch_group_size) + + # update items to the new sorted items + audio_lengths = [s["audio_length"] for s in samples] + text_lengths = [s["text_length"] for s in samples] + self.samples = samples + + if self.verbose: + print(" | > Preprocessing samples") + print(" | > Max text length: {}".format(np.max(text_lengths))) + print(" | > Min text length: {}".format(np.min(text_lengths))) + print(" | > Avg text length: {}".format(np.mean(text_lengths))) + print(" | ") + print(" | > Max audio length: {}".format(np.max(audio_lengths))) + print(" | > Min audio length: {}".format(np.min(audio_lengths))) + print(" | > Avg audio length: {}".format(np.mean(audio_lengths))) + print(f" | > Num. instances discarded samples: {len(ignore_idx)}") + print(" | > Batch group size: {}.".format(self.batch_group_size)) + + @staticmethod + def _sort_batch(batch, text_lengths): + """Sort the batch by the input text length for RNN efficiency. + + Args: + batch (Dict): Batch returned by `__getitem__`. + text_lengths (List[int]): Lengths of the input character sequences. + """ + text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True) + batch = [batch[idx] for idx in ids_sorted_decreasing] + return batch, text_lengths, ids_sorted_decreasing + + def collate_fn(self, batch): + r""" + Perform preprocessing and create a final data batch: + 1. Sort batch instances by text-length + 2. Convert Audio signal to features. + 3. PAD sequences wrt r. + 4. Load to Torch. 
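        A minimal wiring sketch (the variable names and the batch size below are illustrative, not part of this module): the collate function is meant to be handed to a standard PyTorch loader, with shuffling left off because `preprocess_samples()` already sorts and optionally buckets the samples.

            loader = torch.utils.data.DataLoader(
                dataset,                        # an initialized TTSDataset
                batch_size=32,                  # illustrative value
                shuffle=False,                  # ordering comes from preprocess_samples()
                collate_fn=dataset.collate_fn,
                drop_last=False,
            )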
+ """ + + # Puts each data field into a tensor with outer dimension batch size + if isinstance(batch[0], collections.abc.Mapping): + token_ids_lengths = np.array([len(d["token_ids"]) for d in batch]) + + # sort items with text input length for RNN efficiency + batch, token_ids_lengths, ids_sorted_decreasing = self._sort_batch(batch, token_ids_lengths) + + # convert list of dicts to dict of lists + batch = {k: [dic[k] for dic in batch] for k in batch[0]} + + # get language ids from language names + if self.language_id_mapping is not None: + language_ids = [self.language_id_mapping[ln] for ln in batch["language_name"]] + else: + language_ids = None + # get pre-computed d-vectors + if self.d_vector_mapping is not None: + embedding_keys = list(batch["audio_unique_name"]) + d_vectors = [self.d_vector_mapping[w]["embedding"] for w in embedding_keys] + else: + d_vectors = None + + # get numerical speaker ids from speaker names + if self.speaker_id_mapping: + speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]] + else: + speaker_ids = None + # compute features + mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]] + + mel_lengths = [m.shape[1] for m in mel] + + # lengths adjusted by the reduction factor + mel_lengths_adjusted = [ + m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step)) + if m.shape[1] % self.outputs_per_step + else m.shape[1] + for m in mel + ] + + # compute 'stop token' targets + stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths] + + # PAD stop targets + stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step) + + # PAD sequences with longest instance in the batch + token_ids = prepare_data(batch["token_ids"]).astype(np.int32) + + # PAD features with longest instance + mel = prepare_tensor(mel, self.outputs_per_step) + + # B x D x T --> B x T x D + mel = mel.transpose(0, 2, 1) + + # convert things to pytorch + token_ids_lengths = torch.LongTensor(token_ids_lengths) + token_ids = torch.LongTensor(token_ids) + mel = torch.FloatTensor(mel).contiguous() + mel_lengths = torch.LongTensor(mel_lengths) + stop_targets = torch.FloatTensor(stop_targets) + + # speaker vectors + if d_vectors is not None: + d_vectors = torch.FloatTensor(d_vectors) + + if speaker_ids is not None: + speaker_ids = torch.LongTensor(speaker_ids) + + if language_ids is not None: + language_ids = torch.LongTensor(language_ids) + + # compute linear spectrogram + linear = None + if self.compute_linear_spec: + linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]] + linear = prepare_tensor(linear, self.outputs_per_step) + linear = linear.transpose(0, 2, 1) + assert mel.shape[1] == linear.shape[1] + linear = torch.FloatTensor(linear).contiguous() + + # format waveforms + wav_padded = None + if self.return_wav: + wav_lengths = [w.shape[0] for w in batch["wav"]] + max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length + wav_lengths = torch.LongTensor(wav_lengths) + wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len) + for i, w in enumerate(batch["wav"]): + mel_length = mel_lengths_adjusted[i] + w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge") + w = w[: mel_length * self.ap.hop_length] + wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w) + wav_padded.transpose_(1, 2) + + # format F0 + if self.compute_f0: + pitch = prepare_data(batch["pitch"]) + assert mel.shape[1] == pitch.shape[1], f"[!] 
{mel.shape} vs {pitch.shape}" + pitch = torch.FloatTensor(pitch)[:, None, :].contiguous() # B x 1 xT + else: + pitch = None + # format energy + if self.compute_energy: + energy = prepare_data(batch["energy"]) + assert mel.shape[1] == energy.shape[1], f"[!] {mel.shape} vs {energy.shape}" + energy = torch.FloatTensor(energy)[:, None, :].contiguous() # B x 1 xT + else: + energy = None + # format attention masks + attns = None + if batch["attn"][0] is not None: + attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing] + for idx, attn in enumerate(attns): + pad2 = mel.shape[1] - attn.shape[1] + pad1 = token_ids.shape[1] - attn.shape[0] + assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}" + attn = np.pad(attn, [[0, pad1], [0, pad2]]) + attns[idx] = attn + attns = prepare_tensor(attns, self.outputs_per_step) + attns = torch.FloatTensor(attns).unsqueeze(1) + + return { + "token_id": token_ids, + "token_id_lengths": token_ids_lengths, + "speaker_names": batch["speaker_name"], + "linear": linear, + "mel": mel, + "mel_lengths": mel_lengths, + "stop_targets": stop_targets, + "item_idxs": batch["item_idx"], + "d_vectors": d_vectors, + "speaker_ids": speaker_ids, + "attns": attns, + "waveform": wav_padded, + "raw_text": batch["raw_text"], + "pitch": pitch, + "energy": energy, + "language_ids": language_ids, + "audio_unique_names": batch["audio_unique_name"], + } + + raise TypeError( + ( + "batch must contain tensors, numbers, dicts or lists;\ + found {}".format( + type(batch[0]) + ) + ) + ) + + +class PhonemeDataset(Dataset): + """Phoneme Dataset for converting input text to phonemes and then token IDs + + At initialization, it pre-computes the phonemes under `cache_path` and loads them in training to reduce data + loading latency. If `cache_path` is already present, it skips the pre-computation. + + Args: + samples (Union[List[List], List[Dict]]): + List of samples. Each sample is a list or a dict. + + tokenizer (TTSTokenizer): + Tokenizer to convert input text to phonemes. + + cache_path (str): + Path to cache phonemes. If `cache_path` is already present or None, it skips the pre-computation. + + precompute_num_workers (int): + Number of workers used for pre-computing the phonemes. Defaults to 0. + """ + + def __init__( + self, + samples: Union[List[Dict], List[List]], + tokenizer: "TTSTokenizer", + cache_path: str, + precompute_num_workers=0, + ): + self.samples = samples + self.tokenizer = tokenizer + self.cache_path = cache_path + if cache_path is not None and not os.path.exists(cache_path): + os.makedirs(cache_path) + self.precompute(precompute_num_workers) + + def __getitem__(self, index): + item = self.samples[index] + ids = self.compute_or_load(string2filename(item["audio_unique_name"]), item["text"], item["language"]) + ph_hat = self.tokenizer.ids_to_text(ids) + return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)} + + def __len__(self): + return len(self.samples) + + def compute_or_load(self, file_name, text, language): + """Compute phonemes for the given text. + + If the phonemes are already cached, load them from cache. 
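        Cache layout sketch (the unique name below is made up for illustration; `cache_path` is the directory given at construction): the token IDs of a sample live in a per-sample file whose name is derived from its `audio_unique_name`,

            token_ids = np.load(
                os.path.join(cache_path, string2filename("ljspeech#LJ001-0001") + "_phoneme.npy")
            )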
+ """ + file_ext = "_phoneme.npy" + cache_path = os.path.join(self.cache_path, file_name + file_ext) + try: + ids = np.load(cache_path) + except FileNotFoundError: + ids = self.tokenizer.text_to_ids(text, language=language) + np.save(cache_path, ids) + return ids + + def get_pad_id(self): + """Get pad token ID for sequence padding""" + return self.tokenizer.pad_id + + def precompute(self, num_workers=1): + """Precompute phonemes for all samples. + + We use pytorch dataloader because we are lazy. + """ + print("[*] Pre-computing phonemes...") + with tqdm.tqdm(total=len(self)) as pbar: + batch_size = num_workers if num_workers > 0 else 1 + dataloder = torch.utils.data.DataLoader( + batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn + ) + for _ in dataloder: + pbar.update(batch_size) + + def collate_fn(self, batch): + ids = [item["token_ids"] for item in batch] + ids_lens = [item["token_ids_len"] for item in batch] + texts = [item["text"] for item in batch] + texts_hat = [item["ph_hat"] for item in batch] + ids_lens_max = max(ids_lens) + ids_torch = torch.LongTensor(len(ids), ids_lens_max).fill_(self.get_pad_id()) + for i, ids_len in enumerate(ids_lens): + ids_torch[i, :ids_len] = torch.LongTensor(ids[i]) + return {"text": texts, "ph_hat": texts_hat, "token_ids": ids_torch} + + def print_logs(self, level: int = 0) -> None: + indent = "\t" * level + print("\n") + print(f"{indent}> PhonemeDataset ") + print(f"{indent}| > Tokenizer:") + self.tokenizer.print_logs(level + 1) + print(f"{indent}| > Number of instances : {len(self.samples)}") + + +class F0Dataset: + """F0 Dataset for computing F0 from wav files in CPU + + Pre-compute F0 values for all the samples at initialization if `cache_path` is not None or already present. It + also computes the mean and std of F0 values if `normalize_f0` is True. + + Args: + samples (Union[List[List], List[Dict]]): + List of samples. Each sample is a list or a dict. + + ap (AudioProcessor): + AudioProcessor to compute F0 from wav files. + + cache_path (str): + Path to cache F0 values. If `cache_path` is already present or None, it skips the pre-computation. + Defaults to None. + + precompute_num_workers (int): + Number of workers used for pre-computing the F0 values. Defaults to 0. + + normalize_f0 (bool): + Whether to normalize F0 values by mean and std. Defaults to True. + """ + + def __init__( + self, + samples: Union[List[List], List[Dict]], + ap: "AudioProcessor", + audio_config=None, # pylint: disable=unused-argument + verbose=False, + cache_path: str = None, + precompute_num_workers=0, + normalize_f0=True, + ): + self.samples = samples + self.ap = ap + self.verbose = verbose + self.cache_path = cache_path + self.normalize_f0 = normalize_f0 + self.pad_id = 0.0 + self.mean = None + self.std = None + if cache_path is not None and not os.path.exists(cache_path): + os.makedirs(cache_path) + self.precompute(precompute_num_workers) + if normalize_f0: + self.load_stats(cache_path) + + def __getitem__(self, idx): + item = self.samples[idx] + f0 = self.compute_or_load(item["audio_file"], string2filename(item["audio_unique_name"])) + if self.normalize_f0: + assert self.mean is not None and self.std is not None, " [!] 
Mean and STD is not available" + f0 = self.normalize(f0) + return {"audio_unique_name": item["audio_unique_name"], "f0": f0} + + def __len__(self): + return len(self.samples) + + def precompute(self, num_workers=0): + print("[*] Pre-computing F0s...") + with tqdm.tqdm(total=len(self)) as pbar: + batch_size = num_workers if num_workers > 0 else 1 + # we do not normalize at preproessing + normalize_f0 = self.normalize_f0 + self.normalize_f0 = False + dataloder = torch.utils.data.DataLoader( + batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn + ) + computed_data = [] + for batch in dataloder: + f0 = batch["f0"] + computed_data.append(f for f in f0) + pbar.update(batch_size) + self.normalize_f0 = normalize_f0 + + if self.normalize_f0: + computed_data = [tensor for batch in computed_data for tensor in batch] # flatten + pitch_mean, pitch_std = self.compute_pitch_stats(computed_data) + pitch_stats = {"mean": pitch_mean, "std": pitch_std} + np.save(os.path.join(self.cache_path, "pitch_stats"), pitch_stats, allow_pickle=True) + + def get_pad_id(self): + return self.pad_id + + @staticmethod + def create_pitch_file_path(file_name, cache_path): + pitch_file = os.path.join(cache_path, file_name + "_pitch.npy") + return pitch_file + + @staticmethod + def _compute_and_save_pitch(ap, wav_file, pitch_file=None): + wav = ap.load_wav(wav_file) + pitch = ap.compute_f0(wav) + if pitch_file: + np.save(pitch_file, pitch) + return pitch + + @staticmethod + def compute_pitch_stats(pitch_vecs): + nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs]) + mean, std = np.mean(nonzeros), np.std(nonzeros) + return mean, std + + def load_stats(self, cache_path): + stats_path = os.path.join(cache_path, "pitch_stats.npy") + stats = np.load(stats_path, allow_pickle=True).item() + self.mean = stats["mean"].astype(np.float32) + self.std = stats["std"].astype(np.float32) + + def normalize(self, pitch): + zero_idxs = np.where(pitch == 0.0)[0] + pitch = pitch - self.mean + pitch = pitch / self.std + pitch[zero_idxs] = 0.0 + return pitch + + def denormalize(self, pitch): + zero_idxs = np.where(pitch == 0.0)[0] + pitch *= self.std + pitch += self.mean + pitch[zero_idxs] = 0.0 + return pitch + + def compute_or_load(self, wav_file, audio_unique_name): + """ + compute pitch and return a numpy array of pitch values + """ + pitch_file = self.create_pitch_file_path(audio_unique_name, self.cache_path) + if not os.path.exists(pitch_file): + pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file) + else: + pitch = np.load(pitch_file) + return pitch.astype(np.float32) + + def collate_fn(self, batch): + audio_unique_name = [item["audio_unique_name"] for item in batch] + f0s = [item["f0"] for item in batch] + f0_lens = [len(item["f0"]) for item in batch] + f0_lens_max = max(f0_lens) + f0s_torch = torch.LongTensor(len(f0s), f0_lens_max).fill_(self.get_pad_id()) + for i, f0_len in enumerate(f0_lens): + f0s_torch[i, :f0_len] = torch.LongTensor(f0s[i]) + return {"audio_unique_name": audio_unique_name, "f0": f0s_torch, "f0_lens": f0_lens} + + def print_logs(self, level: int = 0) -> None: + indent = "\t" * level + print("\n") + print(f"{indent}> F0Dataset ") + print(f"{indent}| > Number of instances : {len(self.samples)}") + + +class EnergyDataset: + """Energy Dataset for computing Energy from wav files in CPU + + Pre-compute Energy values for all the samples at initialization if `cache_path` is not None or already present. 
It + also computes the mean and std of Energy values if `normalize_Energy` is True. + + Args: + samples (Union[List[List], List[Dict]]): + List of samples. Each sample is a list or a dict. + + ap (AudioProcessor): + AudioProcessor to compute Energy from wav files. + + cache_path (str): + Path to cache Energy values. If `cache_path` is already present or None, it skips the pre-computation. + Defaults to None. + + precompute_num_workers (int): + Number of workers used for pre-computing the Energy values. Defaults to 0. + + normalize_Energy (bool): + Whether to normalize Energy values by mean and std. Defaults to True. + """ + + def __init__( + self, + samples: Union[List[List], List[Dict]], + ap: "AudioProcessor", + verbose=False, + cache_path: str = None, + precompute_num_workers=0, + normalize_energy=True, + ): + self.samples = samples + self.ap = ap + self.verbose = verbose + self.cache_path = cache_path + self.normalize_energy = normalize_energy + self.pad_id = 0.0 + self.mean = None + self.std = None + if cache_path is not None and not os.path.exists(cache_path): + os.makedirs(cache_path) + self.precompute(precompute_num_workers) + if normalize_energy: + self.load_stats(cache_path) + + def __getitem__(self, idx): + item = self.samples[idx] + energy = self.compute_or_load(item["audio_file"], string2filename(item["audio_unique_name"])) + if self.normalize_energy: + assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available" + energy = self.normalize(energy) + return {"audio_unique_name": item["audio_unique_name"], "energy": energy} + + def __len__(self): + return len(self.samples) + + def precompute(self, num_workers=0): + print("[*] Pre-computing energys...") + with tqdm.tqdm(total=len(self)) as pbar: + batch_size = num_workers if num_workers > 0 else 1 + # we do not normalize at preproessing + normalize_energy = self.normalize_energy + self.normalize_energy = False + dataloder = torch.utils.data.DataLoader( + batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn + ) + computed_data = [] + for batch in dataloder: + energy = batch["energy"] + computed_data.append(e for e in energy) + pbar.update(batch_size) + self.normalize_energy = normalize_energy + + if self.normalize_energy: + computed_data = [tensor for batch in computed_data for tensor in batch] # flatten + energy_mean, energy_std = self.compute_energy_stats(computed_data) + energy_stats = {"mean": energy_mean, "std": energy_std} + np.save(os.path.join(self.cache_path, "energy_stats"), energy_stats, allow_pickle=True) + + def get_pad_id(self): + return self.pad_id + + @staticmethod + def create_energy_file_path(wav_file, cache_path): + file_name = os.path.splitext(os.path.basename(wav_file))[0] + energy_file = os.path.join(cache_path, file_name + "_energy.npy") + return energy_file + + @staticmethod + def _compute_and_save_energy(ap, wav_file, energy_file=None): + wav = ap.load_wav(wav_file) + energy = calculate_energy(wav, fft_size=ap.fft_size, hop_length=ap.hop_length, win_length=ap.win_length) + if energy_file: + np.save(energy_file, energy) + return energy + + @staticmethod + def compute_energy_stats(energy_vecs): + nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in energy_vecs]) + mean, std = np.mean(nonzeros), np.std(nonzeros) + return mean, std + + def load_stats(self, cache_path): + stats_path = os.path.join(cache_path, "energy_stats.npy") + stats = np.load(stats_path, allow_pickle=True).item() + self.mean = 
stats["mean"].astype(np.float32) + self.std = stats["std"].astype(np.float32) + + def normalize(self, energy): + zero_idxs = np.where(energy == 0.0)[0] + energy = energy - self.mean + energy = energy / self.std + energy[zero_idxs] = 0.0 + return energy + + def denormalize(self, energy): + zero_idxs = np.where(energy == 0.0)[0] + energy *= self.std + energy += self.mean + energy[zero_idxs] = 0.0 + return energy + + def compute_or_load(self, wav_file, audio_unique_name): + """ + compute energy and return a numpy array of energy values + """ + energy_file = self.create_energy_file_path(audio_unique_name, self.cache_path) + if not os.path.exists(energy_file): + energy = self._compute_and_save_energy(self.ap, wav_file, energy_file) + else: + energy = np.load(energy_file) + return energy.astype(np.float32) + + def collate_fn(self, batch): + audio_unique_name = [item["audio_unique_name"] for item in batch] + energys = [item["energy"] for item in batch] + energy_lens = [len(item["energy"]) for item in batch] + energy_lens_max = max(energy_lens) + energys_torch = torch.LongTensor(len(energys), energy_lens_max).fill_(self.get_pad_id()) + for i, energy_len in enumerate(energy_lens): + energys_torch[i, :energy_len] = torch.LongTensor(energys[i]) + return {"audio_unique_name": audio_unique_name, "energy": energys_torch, "energy_lens": energy_lens} + + def print_logs(self, level: int = 0) -> None: + indent = "\t" * level + print("\n") + print(f"{indent}> energyDataset ") + print(f"{indent}| > Number of instances : {len(self.samples)}") diff --git a/viXTTS/TTS/tts/datasets/formatters.py b/viXTTS/TTS/tts/datasets/formatters.py new file mode 100644 index 0000000000000000000000000000000000000000..053444b0c1010d5a93970cc26b2c225e305279a9 --- /dev/null +++ b/viXTTS/TTS/tts/datasets/formatters.py @@ -0,0 +1,655 @@ +import os +import re +import xml.etree.ElementTree as ET +from glob import glob +from pathlib import Path +from typing import List + +import pandas as pd +from tqdm import tqdm + +######################## +# DATASETS +######################## + + +def cml_tts(root_path, meta_file, ignored_speakers=None): + """Normalizes the CML-TTS meta data file to TTS format + https://github.com/freds0/CML-TTS-Dataset/""" + filepath = os.path.join(root_path, meta_file) + # ensure there are 4 columns for every line + with open(filepath, "r", encoding="utf8") as f: + lines = f.readlines() + num_cols = len(lines[0].split("|")) # take the first row as reference + for idx, line in enumerate(lines[1:]): + if len(line.split("|")) != num_cols: + print(f" > Missing column in line {idx + 1} -> {line.strip()}") + # load metadata + metadata = pd.read_csv(os.path.join(root_path, meta_file), sep="|") + assert all(x in metadata.columns for x in ["wav_filename", "transcript"]) + client_id = None if "client_id" in metadata.columns else "default" + emotion_name = None if "emotion_name" in metadata.columns else "neutral" + items = [] + not_found_counter = 0 + for row in metadata.itertuples(): + if client_id is None and ignored_speakers is not None and row.client_id in ignored_speakers: + continue + audio_path = os.path.join(root_path, row.wav_filename) + if not os.path.exists(audio_path): + not_found_counter += 1 + continue + items.append( + { + "text": row.transcript, + "audio_file": audio_path, + "speaker_name": client_id if client_id is not None else row.client_id, + "emotion_name": emotion_name if emotion_name is not None else row.emotion_name, + "root_path": root_path, + } + ) + if not_found_counter > 0: + print(f" | > [!] 
{not_found_counter} files not found") + return items + + +def coqui(root_path, meta_file, ignored_speakers=None): + """Interal dataset formatter.""" + filepath = os.path.join(root_path, meta_file) + # ensure there are 4 columns for every line + with open(filepath, "r", encoding="utf8") as f: + lines = f.readlines() + num_cols = len(lines[0].split("|")) # take the first row as reference + for idx, line in enumerate(lines[1:]): + if len(line.split("|")) != num_cols: + print(f" > Missing column in line {idx + 1} -> {line.strip()}") + # load metadata + metadata = pd.read_csv(os.path.join(root_path, meta_file), sep="|") + assert all(x in metadata.columns for x in ["audio_file", "text"]) + speaker_name = None if "speaker_name" in metadata.columns else "coqui" + emotion_name = None if "emotion_name" in metadata.columns else "neutral" + items = [] + not_found_counter = 0 + for row in metadata.itertuples(): + if speaker_name is None and ignored_speakers is not None and row.speaker_name in ignored_speakers: + continue + audio_path = os.path.join(root_path, row.audio_file) + if not os.path.exists(audio_path): + not_found_counter += 1 + continue + items.append( + { + "text": row.text, + "audio_file": audio_path, + "speaker_name": speaker_name if speaker_name is not None else row.speaker_name, + "emotion_name": emotion_name if emotion_name is not None else row.emotion_name, + "root_path": root_path, + } + ) + if not_found_counter > 0: + print(f" | > [!] {not_found_counter} files not found") + return items + + +def tweb(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalize TWEB dataset. + https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset + """ + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "tweb" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("\t") + wav_file = os.path.join(root_path, cols[0] + ".wav") + text = cols[1] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def mozilla(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes Mozilla meta data files to TTS format""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "mozilla" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = cols[1].strip() + text = cols[0].strip() + wav_file = os.path.join(root_path, "wavs", wav_file) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def mozilla_de(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes Mozilla meta data files to TTS format""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "mozilla" + with open(txt_file, "r", encoding="ISO 8859-1") as ttf: + for line in ttf: + cols = line.strip().split("|") + wav_file = cols[0].strip() + text = cols[1].strip() + folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL" + wav_file = os.path.join(root_path, folder_name, wav_file) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def mailabs(root_path, meta_files=None, ignored_speakers=None): + """Normalizes M-AI-Labs meta data files to TTS format + + Args: + root_path (str): root folder of the MAILAB language folder. + meta_files (str): list of meta files to be used in the training. 
If None, finds all the csv files + recursively. Defaults to None + """ + speaker_regex = re.compile(f"by_book{os.sep}(male|female){os.sep}(?P[^{os.sep}]+){os.sep}") + if not meta_files: + csv_files = glob(root_path + f"{os.sep}**{os.sep}metadata.csv", recursive=True) + else: + csv_files = meta_files + + # meta_files = [f.strip() for f in meta_files.split(",")] + items = [] + for csv_file in csv_files: + if os.path.isfile(csv_file): + txt_file = csv_file + else: + txt_file = os.path.join(root_path, csv_file) + + folder = os.path.dirname(txt_file) + # determine speaker based on folder structure... + speaker_name_match = speaker_regex.search(txt_file) + if speaker_name_match is None: + continue + speaker_name = speaker_name_match.group("speaker_name") + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_name in ignored_speakers: + continue + print(" | > {}".format(csv_file)) + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + if not meta_files: + wav_file = os.path.join(folder, "wavs", cols[0] + ".wav") + else: + wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav") + if os.path.isfile(wav_file): + text = cols[1].strip() + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path} + ) + else: + # M-AI-Labs have some missing samples, so just print the warning + print("> File %s does not exist!" % (wav_file)) + return items + + +def ljspeech(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the LJSpeech meta data file to TTS format + https://keithito.com/LJ-Speech-Dataset/""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "ljspeech" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") + text = cols[2] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def ljspeech_test(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the LJSpeech meta data file for TTS testing + https://keithito.com/LJ-Speech-Dataset/""" + txt_file = os.path.join(root_path, meta_file) + items = [] + with open(txt_file, "r", encoding="utf-8") as ttf: + speaker_id = 0 + for idx, line in enumerate(ttf): + # 2 samples per speaker to avoid eval split issues + if idx % 2 == 0: + speaker_id += 1 + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") + text = cols[2] + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}", "root_path": root_path} + ) + return items + + +def thorsten(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the thorsten meta data file to TTS format + https://github.com/thorstenMueller/deep-learning-german-tts/""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "thorsten" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") + text = cols[1] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def sam_accenture(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the sam-accenture meta data file to TTS format + 
https://github.com/Sam-Accenture-Non-Binary-Voice/non-binary-voice-files""" + xml_file = os.path.join(root_path, "voice_over_recordings", meta_file) + xml_root = ET.parse(xml_file).getroot() + items = [] + speaker_name = "sam_accenture" + for item in xml_root.findall("./fileid"): + text = item.text + wav_file = os.path.join(root_path, "vo_voice_quality_transformation", item.get("id") + ".wav") + if not os.path.exists(wav_file): + print(f" [!] {wav_file} in metafile does not exist. Skipping...") + continue + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def ruslan(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the RUSLAN meta data file to TTS format + https://ruslan-corpus.github.io/""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "ruslan" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav") + text = cols[1] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def css10(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the CSS10 dataset file to TTS format""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "css10" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, cols[0]) + text = cols[1] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def nancy(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the Nancy meta data file to TTS format""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "nancy" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + utt_id = line.split()[1] + text = line[line.find('"') + 1 : line.rfind('"') - 1] + wav_file = os.path.join(root_path, "wavn", utt_id + ".wav") + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def common_voice(root_path, meta_file, ignored_speakers=None): + """Normalize the common voice meta data file to TTS format.""" + txt_file = os.path.join(root_path, meta_file) + items = [] + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + if line.startswith("client_id"): + continue + cols = line.split("\t") + text = cols[2] + speaker_name = cols[0] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_name in ignored_speakers: + continue + wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav")) + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name, "root_path": root_path} + ) + return items + + +def libri_tts(root_path, meta_files=None, ignored_speakers=None): + """https://ai.google/tools/datasets/libri-tts/""" + items = [] + if not meta_files: + meta_files = glob(f"{root_path}/**/*trans.tsv", recursive=True) + else: + if isinstance(meta_files, str): + meta_files = [os.path.join(root_path, meta_files)] + + for meta_file in meta_files: + _meta_file = os.path.basename(meta_file).split(".")[0] + with open(meta_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("\t") + file_name = cols[0] + speaker_name, chapter_id, *_ = 
cols[0].split("_") + _root_path = os.path.join(root_path, f"{speaker_name}/{chapter_id}") + wav_file = os.path.join(_root_path, file_name + ".wav") + text = cols[2] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_name in ignored_speakers: + continue + items.append( + { + "text": text, + "audio_file": wav_file, + "speaker_name": f"LTTS_{speaker_name}", + "root_path": root_path, + } + ) + for item in items: + assert os.path.exists(item["audio_file"]), f" [!] wav files don't exist - {item['audio_file']}" + return items + + +def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "turkish-female" + skipped_files = [] + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0].strip() + ".wav") + if not os.path.exists(wav_file): + skipped_files.append(wav_file) + continue + text = cols[1].strip() + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + print(f" [!] {len(skipped_files)} files skipped. They don't exist...") + return items + + +# ToDo: add the dataset link when the dataset is released publicly +def brspeech(root_path, meta_file, ignored_speakers=None): + """BRSpeech 3.0 beta""" + txt_file = os.path.join(root_path, meta_file) + items = [] + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + if line.startswith("wav_filename"): + continue + cols = line.split("|") + wav_file = os.path.join(root_path, cols[0]) + text = cols[2] + speaker_id = cols[3] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: + continue + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id, "root_path": root_path}) + return items + + +def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic1", ignored_speakers=None): + """VCTK dataset v0.92. + + URL: + https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip + + This dataset has 2 recordings per speaker that are annotated with ```mic1``` and ```mic2```. + It is believed that (😄 ) ```mic1``` files are the same as the previous version of the dataset. + + mic1: + Audio recorded using an omni-directional microphone (DPA 4035). + Contains very low frequency noises. + This is the same audio released in previous versions of VCTK: + https://doi.org/10.7488/ds/1994 + + mic2: + Audio recorded using a small diaphragm condenser microphone with + very wide bandwidth (Sennheiser MKH 800). + Two speakers, p280 and p315 had technical issues of the audio + recordings using MKH 800. 
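    Example sketch (the dataset root below is illustrative):

        items = vctk("/data/VCTK-Corpus-0.92", mic="mic2")
        # each item is a dict whose "audio_file" points into wav48_silence_trimmed/<speaker>/,
        # ends in "_mic2.flac" ("_mic1.flac" for p280, which has no mic2 recordings),
        # and whose "speaker_name" carries a "VCTK_" prefix.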
+ """ + file_ext = "flac" + items = [] + meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) + for meta_file in meta_files: + _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep) + file_id = txt_file.split(".")[0] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: + continue + with open(meta_file, "r", encoding="utf-8") as file_text: + text = file_text.readlines()[0] + # p280 has no mic2 recordings + if speaker_id == "p280": + wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_mic1.{file_ext}") + else: + wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_{mic}.{file_ext}") + if os.path.exists(wav_file): + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id, "root_path": root_path} + ) + else: + print(f" [!] wav files don't exist - {wav_file}") + return items + + +def vctk_old(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None): + """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz""" + items = [] + meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) + for meta_file in meta_files: + _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep) + file_id = txt_file.split(".")[0] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: + continue + with open(meta_file, "r", encoding="utf-8") as file_text: + text = file_text.readlines()[0] + wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav") + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id, "root_path": root_path} + ) + return items + + +def synpaflex(root_path, metafiles=None, **kwargs): # pylint: disable=unused-argument + items = [] + speaker_name = "synpaflex" + root_path = os.path.join(root_path, "") + wav_files = glob(f"{root_path}**/*.wav", recursive=True) + for wav_file in wav_files: + if os.sep + "wav" + os.sep in wav_file: + txt_file = wav_file.replace("wav", "txt") + else: + txt_file = os.path.join( + os.path.dirname(wav_file), "txt", os.path.basename(wav_file).replace(".wav", ".txt") + ) + if os.path.exists(txt_file) and os.path.exists(wav_file): + with open(txt_file, "r", encoding="utf-8") as file_text: + text = file_text.readlines()[0] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def open_bible(root_path, meta_files="train", ignore_digits_sentences=True, ignored_speakers=None): + """ToDo: Refer the paper when available""" + items = [] + split_dir = meta_files + meta_files = glob(f"{os.path.join(root_path, split_dir)}/**/*.txt", recursive=True) + for meta_file in meta_files: + _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep) + file_id = txt_file.split(".")[0] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: + continue + with open(meta_file, "r", encoding="utf-8") as file_text: + text = file_text.readline().replace("\n", "") + # ignore sentences that contains digits + if ignore_digits_sentences and any(map(str.isdigit, text)): + continue + wav_file = os.path.join(root_path, split_dir, speaker_id, file_id + ".flac") + items.append({"text": text, "audio_file": wav_file, "speaker_name": "OB_" + speaker_id, "root_path": root_path}) + return items + + +def mls(root_path, meta_files=None, ignored_speakers=None): + 
"""http://www.openslr.org/94/""" + items = [] + with open(os.path.join(root_path, meta_files), "r", encoding="utf-8") as meta: + for line in meta: + file, text = line.split("\t") + text = text[:-1] + speaker, book, *_ = file.split("_") + wav_file = os.path.join(root_path, os.path.dirname(meta_files), "audio", speaker, book, file + ".wav") + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker in ignored_speakers: + continue + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker, "root_path": root_path} + ) + return items + + +# ======================================== VOX CELEB =========================================== +def voxceleb2(root_path, meta_file=None, **kwargs): # pylint: disable=unused-argument + """ + :param meta_file Used only for consistency with load_tts_samples api + """ + return _voxcel_x(root_path, meta_file, voxcel_idx="2") + + +def voxceleb1(root_path, meta_file=None, **kwargs): # pylint: disable=unused-argument + """ + :param meta_file Used only for consistency with load_tts_samples api + """ + return _voxcel_x(root_path, meta_file, voxcel_idx="1") + + +def _voxcel_x(root_path, meta_file, voxcel_idx): + assert voxcel_idx in ["1", "2"] + expected_count = 148_000 if voxcel_idx == "1" else 1_000_000 + voxceleb_path = Path(root_path) + cache_to = voxceleb_path / f"metafile_voxceleb{voxcel_idx}.csv" + cache_to.parent.mkdir(exist_ok=True) + + # if not exists meta file, crawl recursively for 'wav' files + if meta_file is not None: + with open(str(meta_file), "r", encoding="utf-8") as f: + return [x.strip().split("|") for x in f.readlines()] + + elif not cache_to.exists(): + cnt = 0 + meta_data = [] + wav_files = voxceleb_path.rglob("**/*.wav") + for path in tqdm( + wav_files, + desc=f"Building VoxCeleb {voxcel_idx} Meta file ... this needs to be done only once.", + total=expected_count, + ): + speaker_id = str(Path(path).parent.parent.stem) + assert speaker_id.startswith("id") + text = None # VoxCel does not provide transciptions, and they are not needed for training the SE + meta_data.append(f"{text}|{path}|voxcel{voxcel_idx}_{speaker_id}\n") + cnt += 1 + with open(str(cache_to), "w", encoding="utf-8") as f: + f.write("".join(meta_data)) + if cnt < expected_count: + raise ValueError(f"Found too few instances for Voxceleb. 
Should be around {expected_count}, is: {cnt}") + + with open(str(cache_to), "r", encoding="utf-8") as f: + return [x.strip().split("|") for x in f.readlines()] + + +def emotion(root_path, meta_file, ignored_speakers=None): + """Generic emotion dataset""" + txt_file = os.path.join(root_path, meta_file) + items = [] + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + if line.startswith("file_path"): + continue + cols = line.split(",") + wav_file = os.path.join(root_path, cols[0]) + speaker_id = cols[1] + emotion_id = cols[2].replace("\n", "") + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: + continue + items.append( + {"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id, "root_path": root_path} + ) + return items + + +def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylint: disable=unused-argument + """Normalizes the Baker meta data file to TTS format + + Args: + root_path (str): path to the baker dataset + meta_file (str): name of the meta dataset containing names of wav to select and the transcript of the sentence + Returns: + List[List[str]]: List of (text, wav_path, speaker_name) associated with each sentences + """ + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "baker" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + wav_name, text = line.rstrip("\n").split("|") + wav_path = os.path.join(root_path, "clips_22", wav_name) + items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def kokoro(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Japanese single-speaker dataset from https://github.com/kaiidams/Kokoro-Speech-Dataset""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "kokoro" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") + text = cols[2].replace(" ", "") + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def kss(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Korean single-speaker dataset from https://www.kaggle.com/datasets/bryanpark/korean-single-speaker-speech-dataset""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "kss" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, cols[0]) + text = cols[2] # cols[1] => 6월, cols[2] => 유월 + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def bel_tts_formatter(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "bel_tts" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, cols[0]) + text = cols[1] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items diff --git a/viXTTS/TTS/tts/layers/__init__.py b/viXTTS/TTS/tts/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f93efdb7fc41109ec3497d8e5e37ba05b0a4315e --- /dev/null +++ b/viXTTS/TTS/tts/layers/__init__.py @@ -0,0 
+1 @@ +from TTS.tts.layers.losses import * diff --git a/viXTTS/TTS/tts/layers/__pycache__/__init__.cpython-310.pyc b/viXTTS/TTS/tts/layers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e7c9eec331c8ae4ad14428d0456aefe3475d47f Binary files /dev/null and b/viXTTS/TTS/tts/layers/__pycache__/__init__.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/layers/__pycache__/losses.cpython-310.pyc b/viXTTS/TTS/tts/layers/__pycache__/losses.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bca27e428c3ca707e8cfadc181096e569c3e8ffd Binary files /dev/null and b/viXTTS/TTS/tts/layers/__pycache__/losses.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/layers/align_tts/__init__.py b/viXTTS/TTS/tts/layers/align_tts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/tts/layers/align_tts/duration_predictor.py b/viXTTS/TTS/tts/layers/align_tts/duration_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..b2b83894cc3f87575a89ea8fd7bf4a584ca22c28 --- /dev/null +++ b/viXTTS/TTS/tts/layers/align_tts/duration_predictor.py @@ -0,0 +1,21 @@ +from torch import nn + +from TTS.tts.layers.generic.pos_encoding import PositionalEncoding +from TTS.tts.layers.generic.transformer import FFTransformerBlock + + +class DurationPredictor(nn.Module): + def __init__(self, num_chars, hidden_channels, hidden_channels_ffn, num_heads): + super().__init__() + self.embed = nn.Embedding(num_chars, hidden_channels) + self.pos_enc = PositionalEncoding(hidden_channels, dropout_p=0.1) + self.FFT = FFTransformerBlock(hidden_channels, num_heads, hidden_channels_ffn, 2, 0.1) + self.out_layer = nn.Conv1d(hidden_channels, 1, 1) + + def forward(self, text, text_lengths): + # B, L -> B, L + emb = self.embed(text) + emb = self.pos_enc(emb.transpose(1, 2)) + x = self.FFT(emb, text_lengths) + x = self.out_layer(x).squeeze(-1) + return x diff --git a/viXTTS/TTS/tts/layers/align_tts/mdn.py b/viXTTS/TTS/tts/layers/align_tts/mdn.py new file mode 100644 index 0000000000000000000000000000000000000000..cdb332524bf7a5fec6a23da9e7977de6325a0324 --- /dev/null +++ b/viXTTS/TTS/tts/layers/align_tts/mdn.py @@ -0,0 +1,30 @@ +from torch import nn + + +class MDNBlock(nn.Module): + """Mixture of Density Network implementation + https://arxiv.org/pdf/2003.01950.pdf + """ + + def __init__(self, in_channels, out_channels): + super().__init__() + self.out_channels = out_channels + self.conv1 = nn.Conv1d(in_channels, in_channels, 1) + self.norm = nn.LayerNorm(in_channels) + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.1) + self.conv2 = nn.Conv1d(in_channels, out_channels, 1) + + def forward(self, x): + o = self.conv1(x) + o = o.transpose(1, 2) + o = self.norm(o) + o = o.transpose(1, 2) + o = self.relu(o) + o = self.dropout(o) + mu_sigma = self.conv2(o) + # TODO: check this sigmoid + # mu = torch.sigmoid(mu_sigma[:, :self.out_channels//2, :]) + mu = mu_sigma[:, : self.out_channels // 2, :] + log_sigma = mu_sigma[:, self.out_channels // 2 :, :] + return mu, log_sigma diff --git a/viXTTS/TTS/tts/layers/bark/__init__.py b/viXTTS/TTS/tts/layers/bark/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/tts/layers/bark/hubert/__init__.py b/viXTTS/TTS/tts/layers/bark/hubert/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/tts/layers/bark/hubert/hubert_manager.py b/viXTTS/TTS/tts/layers/bark/hubert/hubert_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..4bc199294164da0e8c480e292dd5a478e72f4daf --- /dev/null +++ b/viXTTS/TTS/tts/layers/bark/hubert/hubert_manager.py @@ -0,0 +1,35 @@ +# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer + +import os.path +import shutil +import urllib.request + +import huggingface_hub + + +class HubertManager: + @staticmethod + def make_sure_hubert_installed( + download_url: str = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt", model_path: str = "" + ): + if not os.path.isfile(model_path): + print("Downloading HuBERT base model") + urllib.request.urlretrieve(download_url, model_path) + print("Downloaded HuBERT") + return model_path + return None + + @staticmethod + def make_sure_tokenizer_installed( + model: str = "quantifier_hubert_base_ls960_14.pth", + repo: str = "GitMylo/bark-voice-cloning", + model_path: str = "", + ): + model_dir = os.path.dirname(model_path) + if not os.path.isfile(model_path): + print("Downloading HuBERT custom tokenizer") + huggingface_hub.hf_hub_download(repo, model, local_dir=model_dir, local_dir_use_symlinks=False) + shutil.move(os.path.join(model_dir, model), model_path) + print("Downloaded tokenizer") + return model_path + return None diff --git a/viXTTS/TTS/tts/layers/bark/hubert/kmeans_hubert.py b/viXTTS/TTS/tts/layers/bark/hubert/kmeans_hubert.py new file mode 100644 index 0000000000000000000000000000000000000000..a6a3b9aeb1111ca0abeccb6142007ecc5b39d78d --- /dev/null +++ b/viXTTS/TTS/tts/layers/bark/hubert/kmeans_hubert.py @@ -0,0 +1,82 @@ +""" +Modified HuBERT model without kmeans. 
+Original author: https://github.com/lucidrains/ +Modified by: https://www.github.com/gitmylo/ +License: MIT +""" + +# Modified code from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/hubert_kmeans.py + +import logging +from pathlib import Path + +import torch +from einops import pack, unpack +from torch import nn +from torchaudio.functional import resample +from transformers import HubertModel + + +def round_down_nearest_multiple(num, divisor): + return num // divisor * divisor + + +def curtail_to_multiple(t, mult, from_left=False): + data_len = t.shape[-1] + rounded_seq_len = round_down_nearest_multiple(data_len, mult) + seq_slice = slice(None, rounded_seq_len) if not from_left else slice(-rounded_seq_len, None) + return t[..., seq_slice] + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +class CustomHubert(nn.Module): + """ + checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert + or you can train your own + """ + + def __init__(self, checkpoint_path, target_sample_hz=16000, seq_len_multiple_of=None, output_layer=9, device=None): + super().__init__() + self.target_sample_hz = target_sample_hz + self.seq_len_multiple_of = seq_len_multiple_of + self.output_layer = output_layer + if device is not None: + self.to(device) + self.model = HubertModel.from_pretrained("facebook/hubert-base-ls960") + if device is not None: + self.model.to(device) + self.model.eval() + + @property + def groups(self): + return 1 + + @torch.no_grad() + def forward(self, wav_input, flatten=True, input_sample_hz=None): + device = wav_input.device + + if exists(input_sample_hz): + wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz) + + if exists(self.seq_len_multiple_of): + wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of) + + outputs = self.model.forward( + wav_input, + output_hidden_states=True, + ) + embed = outputs["hidden_states"][self.output_layer] + embed, packed_shape = pack([embed], "* d") + codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device) + if flatten: + return codebook_indices + + (codebook_indices,) = unpack(codebook_indices, packed_shape, "*") + return codebook_indices diff --git a/viXTTS/TTS/tts/layers/bark/hubert/tokenizer.py b/viXTTS/TTS/tts/layers/bark/hubert/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..3070241f1cc1ac95867f2d4173495b9a7047a15e --- /dev/null +++ b/viXTTS/TTS/tts/layers/bark/hubert/tokenizer.py @@ -0,0 +1,195 @@ +""" +Custom tokenizer model. 
+Author: https://www.github.com/gitmylo/ +License: MIT +""" + +import json +import os.path +from zipfile import ZipFile + +import numpy +import torch +from torch import nn, optim + + +class HubertTokenizer(nn.Module): + def __init__(self, hidden_size=1024, input_size=768, output_size=10000, version=0): + super().__init__() + next_size = input_size + if version == 0: + self.lstm = nn.LSTM(input_size, hidden_size, 2, batch_first=True) + next_size = hidden_size + if version == 1: + self.lstm = nn.LSTM(input_size, hidden_size, 2, batch_first=True) + self.intermediate = nn.Linear(hidden_size, 4096) + next_size = 4096 + + self.fc = nn.Linear(next_size, output_size) + self.softmax = nn.LogSoftmax(dim=1) + self.optimizer: optim.Optimizer = None + self.lossfunc = nn.CrossEntropyLoss() + self.input_size = input_size + self.hidden_size = hidden_size + self.output_size = output_size + self.version = version + + def forward(self, x): + x, _ = self.lstm(x) + if self.version == 1: + x = self.intermediate(x) + x = self.fc(x) + x = self.softmax(x) + return x + + @torch.no_grad() + def get_token(self, x): + """ + Used to get the token for the first + :param x: An array with shape (N, input_size) where N is a whole number greater or equal to 1, and input_size is the input size used when creating the model. + :return: An array with shape (N,) where N is the same as N from the input. Every number in the array is a whole number in range 0...output_size - 1 where output_size is the output size used when creating the model. + """ + return torch.argmax(self(x), dim=1) + + def prepare_training(self): + self.optimizer = optim.Adam(self.parameters(), 0.001) + + def train_step(self, x_train, y_train, log_loss=False): + # y_train = y_train[:-1] + # y_train = y_train[1:] + + optimizer = self.optimizer + lossfunc = self.lossfunc + # Zero the gradients + self.zero_grad() + + # Forward pass + y_pred = self(x_train) + + y_train_len = len(y_train) + y_pred_len = y_pred.shape[0] + + if y_train_len > y_pred_len: + diff = y_train_len - y_pred_len + y_train = y_train[diff:] + elif y_train_len < y_pred_len: + diff = y_pred_len - y_train_len + y_pred = y_pred[:-diff, :] + + y_train_hot = torch.zeros(len(y_train), self.output_size) + y_train_hot[range(len(y_train)), y_train] = 1 + y_train_hot = y_train_hot.to("cuda") + + # Calculate the loss + loss = lossfunc(y_pred, y_train_hot) + + # Print loss + if log_loss: + print("Loss", loss.item()) + + # Backward pass + loss.backward() + + # Update the weights + optimizer.step() + + def save(self, path): + info_path = ".".join(os.path.basename(path).split(".")[:-1]) + "/.info" + torch.save(self.state_dict(), path) + data_from_model = Data(self.input_size, self.hidden_size, self.output_size, self.version) + with ZipFile(path, "a") as model_zip: + model_zip.writestr(info_path, data_from_model.save()) + model_zip.close() + + @staticmethod + def load_from_checkpoint(path, map_location=None): + old = True + with ZipFile(path) as model_zip: + filesMatch = [file for file in model_zip.namelist() if file.endswith("/.info")] + file = filesMatch[0] if filesMatch else None + if file: + old = False + data_from_model = Data.load(model_zip.read(file).decode("utf-8")) + model_zip.close() + if old: + model = HubertTokenizer() + else: + model = HubertTokenizer( + data_from_model.hidden_size, + data_from_model.input_size, + data_from_model.output_size, + data_from_model.version, + ) + model.load_state_dict(torch.load(path, map_location=map_location)) + if map_location: + model = model.to(map_location) + 
return model + + +class Data: + input_size: int + hidden_size: int + output_size: int + version: int + + def __init__(self, input_size=768, hidden_size=1024, output_size=10000, version=0): + self.input_size = input_size + self.hidden_size = hidden_size + self.output_size = output_size + self.version = version + + @staticmethod + def load(string): + data = json.loads(string) + return Data(data["input_size"], data["hidden_size"], data["output_size"], data["version"]) + + def save(self): + data = { + "input_size": self.input_size, + "hidden_size": self.hidden_size, + "output_size": self.output_size, + "version": self.version, + } + return json.dumps(data) + + +def auto_train(data_path, save_path="model.pth", load_model: str = None, save_epochs=1): + data_x, data_y = [], [] + + if load_model and os.path.isfile(load_model): + print("Loading model from", load_model) + model_training = HubertTokenizer.load_from_checkpoint(load_model, "cuda") + else: + print("Creating new model.") + model_training = HubertTokenizer(version=1).to("cuda") # Settings for the model to run without lstm + save_path = os.path.join(data_path, save_path) + base_save_path = ".".join(save_path.split(".")[:-1]) + + sem_string = "_semantic.npy" + feat_string = "_semantic_features.npy" + + ready = os.path.join(data_path, "ready") + for input_file in os.listdir(ready): + full_path = os.path.join(ready, input_file) + if input_file.endswith(sem_string): + data_y.append(numpy.load(full_path)) + elif input_file.endswith(feat_string): + data_x.append(numpy.load(full_path)) + model_training.prepare_training() + + epoch = 1 + + while 1: + for _ in range(save_epochs): + j = 0 + for x, y in zip(data_x, data_y): + model_training.train_step( + torch.tensor(x).to("cuda"), torch.tensor(y).to("cuda"), j % 50 == 0 + ) # Print loss every 50 steps + j += 1 + save_p = save_path + save_p_2 = f"{base_save_path}_epoch_{epoch}.pth" + model_training.save(save_p) + model_training.save(save_p_2) + print(f"Epoch {epoch} completed") + epoch += 1 diff --git a/viXTTS/TTS/tts/layers/bark/inference_funcs.py b/viXTTS/TTS/tts/layers/bark/inference_funcs.py new file mode 100644 index 0000000000000000000000000000000000000000..f3d3fee9371fae0cd06187c967a5b0028940138e --- /dev/null +++ b/viXTTS/TTS/tts/layers/bark/inference_funcs.py @@ -0,0 +1,606 @@ +import logging +import os +import re +from glob import glob +from typing import Dict, List + +import librosa +import numpy as np +import torch +import torchaudio +import tqdm +from encodec.utils import convert_audio +from scipy.special import softmax +from torch.nn import functional as F + +from TTS.tts.layers.bark.hubert.hubert_manager import HubertManager +from TTS.tts.layers.bark.hubert.kmeans_hubert import CustomHubert +from TTS.tts.layers.bark.hubert.tokenizer import HubertTokenizer +from TTS.tts.layers.bark.load_model import clear_cuda_cache, inference_mode + +logger = logging.getLogger(__name__) + + +def _tokenize(tokenizer, text): + return tokenizer.encode(text, add_special_tokens=False) + + +def _detokenize(tokenizer, enc_text): + return tokenizer.decode(enc_text) + + +def _normalize_whitespace(text): + return re.sub(r"\s+", " ", text).strip() + + +def get_voices(extra_voice_dirs: List[str] = []): # pylint: disable=dangerous-default-value + dirs = extra_voice_dirs + voices: Dict[str, List[str]] = {} + for d in dirs: + subs = os.listdir(d) + for sub in subs: + subj = os.path.join(d, sub) + if os.path.isdir(subj): + voices[sub] = list(glob(f"{subj}/*.npz")) + # fetch audio files if no npz files are found + if 
len(voices[sub]) == 0: + voices[sub] = list(glob(f"{subj}/*.wav")) + list(glob(f"{subj}/*.mp3")) + return voices + + +def load_npz(npz_file): + x_history = np.load(npz_file) + semantic = x_history["semantic_prompt"] + coarse = x_history["coarse_prompt"] + fine = x_history["fine_prompt"] + return semantic, coarse, fine + + +def load_voice(model, voice: str, extra_voice_dirs: List[str] = []): # pylint: disable=dangerous-default-value + if voice == "random": + return None, None, None + + voices = get_voices(extra_voice_dirs) + paths = voices[voice] + + # bark only uses a single sample for cloning + if len(paths) > 1: + raise ValueError(f"Voice {voice} has multiple paths: {paths}") + + try: + path = voices[voice] + except KeyError as e: + raise KeyError(f"Voice {voice} not found in {extra_voice_dirs}") from e + + if len(paths) == 1 and paths[0].endswith(".npz"): + return load_npz(path[0]) + + audio_path = paths[0] + # replace the file extension with .npz + output_path = os.path.splitext(audio_path)[0] + ".npz" + generate_voice(audio=audio_path, model=model, output_path=output_path) + return load_voice(model, voice, extra_voice_dirs) + + +def zero_crossing_rate(audio, frame_length=1024, hop_length=512): + zero_crossings = np.sum(np.abs(np.diff(np.sign(audio))) / 2) + total_frames = 1 + int((len(audio) - frame_length) / hop_length) + return zero_crossings / total_frames + + +def compute_spectral_contrast(audio_data, sample_rate, n_bands=6, fmin=200.0): + spectral_contrast = librosa.feature.spectral_contrast(y=audio_data, sr=sample_rate, n_bands=n_bands, fmin=fmin) + return np.mean(spectral_contrast) + + +def compute_average_bass_energy(audio_data, sample_rate, max_bass_freq=250): + stft = librosa.stft(audio_data) + power_spectrogram = np.abs(stft) ** 2 + frequencies = librosa.fft_frequencies(sr=sample_rate, n_fft=stft.shape[0]) + bass_mask = frequencies <= max_bass_freq + bass_energy = power_spectrogram[np.ix_(bass_mask, np.arange(power_spectrogram.shape[1]))].mean() + return bass_energy + + +def generate_voice( + audio, + model, + output_path, +): + """Generate a new voice from a given audio and text prompt. + + Args: + audio (np.ndarray): The audio to use as a base for the new voice. + text (str): Transcription of the audio you are clonning. + model (BarkModel): The BarkModel to use for generating the new voice. + output_path (str): The path to save the generated voice to. 
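For orientation, this is what consuming the .npz prompt written by generate_voice looks like; it relies only on the three arrays saved at the end of the function, and "speaker.npz" is a placeholder path:

import numpy as np

prompt = np.load("speaker.npz")                 # file produced by generate_voice()
semantic = prompt["semantic_prompt"]            # 1-D array of semantic token ids
coarse = prompt["coarse_prompt"]                # first two EnCodec codebooks, shape [2, T]
fine = prompt["fine_prompt"]                    # all EnCodec codebooks, shape [n_q, T]
print(semantic.shape, coarse.shape, fine.shape)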
+ """ + if isinstance(audio, str): + audio, sr = torchaudio.load(audio) + audio = convert_audio(audio, sr, model.config.sample_rate, model.encodec.channels) + audio = audio.unsqueeze(0).to(model.device) + + with torch.no_grad(): + encoded_frames = model.encodec.encode(audio) + codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T] + + # move codes to cpu + codes = codes.cpu().numpy() + + # generate semantic tokens + # Load the HuBERT model + hubert_manager = HubertManager() + # hubert_manager.make_sure_hubert_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert"]) + hubert_manager.make_sure_tokenizer_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"]) + + hubert_model = CustomHubert(checkpoint_path=model.config.LOCAL_MODEL_PATHS["hubert"]).to(model.device) + + # Load the CustomTokenizer model + tokenizer = HubertTokenizer.load_from_checkpoint( + model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"], map_location=model.device + ) + # semantic_tokens = model.text_to_semantic( + # text, max_gen_duration_s=seconds, top_k=50, top_p=0.95, temp=0.7 + # ) # not 100% + semantic_vectors = hubert_model.forward(audio[0], input_sample_hz=model.config.sample_rate) + semantic_tokens = tokenizer.get_token(semantic_vectors) + semantic_tokens = semantic_tokens.cpu().numpy() + + np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens) + + +def generate_text_semantic( + text, + model, + history_prompt=None, + temp=0.7, + top_k=None, + top_p=None, + silent=False, + min_eos_p=0.2, + max_gen_duration_s=None, + allow_early_stop=True, + base=None, + use_kv_caching=True, + **kwargs, # pylint: disable=unused-argument +): + """Generate semantic tokens from text. + + Args: + text (str): The text to generate semantic tokens from. + model (BarkModel): The BarkModel to use for generating the semantic tokens. + history_prompt (tuple): A tuple of (semantic_history, coarse_history, fine_history) to use as a prompt for the generation. + temp (float): The temperature to use for the generation. + top_k (int): The number of top tokens to consider for the generation. + top_p (float): The cumulative probability to consider for the generation. + silent (bool): Whether to silence the tqdm progress bar. + min_eos_p (float): The minimum probability to consider for the end of sentence token. + max_gen_duration_s (float): The maximum duration in seconds to generate for. + allow_early_stop (bool): Whether to allow the generation to stop early. + base (tuple): A tuple of (semantic_history, coarse_history, fine_history) to use as a base for the generation. + use_kv_caching (bool): Whether to use key-value caching for the generation. + **kwargs: Additional keyword arguments. They are ignored. + + Returns: + np.ndarray: The generated semantic tokens. 
+ """ + assert isinstance(text, str) + text = _normalize_whitespace(text) + assert len(text.strip()) > 0 + if all(v is not None for v in history_prompt) or base is not None: + if history_prompt is not None: + semantic_history = history_prompt[0] + if base is not None: + semantic_history = base[0] + assert ( + isinstance(semantic_history, np.ndarray) + and len(semantic_history.shape) == 1 + and len(semantic_history) > 0 + and semantic_history.min() >= 0 + and semantic_history.max() <= model.config.SEMANTIC_VOCAB_SIZE - 1 + ) + else: + semantic_history = None + encoded_text = np.array(_tokenize(model.tokenizer, text)) + model.config.TEXT_ENCODING_OFFSET + if len(encoded_text) > 256: + p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1) + logger.warning(f"warning, text too long, lopping of last {p}%") + encoded_text = encoded_text[:256] + encoded_text = np.pad( + encoded_text, + (0, 256 - len(encoded_text)), + constant_values=model.config.TEXT_PAD_TOKEN, + mode="constant", + ) + if semantic_history is not None: + semantic_history = semantic_history.astype(np.int64) + # lop off if history is too long, pad if needed + semantic_history = semantic_history[-256:] + semantic_history = np.pad( + semantic_history, + (0, 256 - len(semantic_history)), + constant_values=model.config.SEMANTIC_PAD_TOKEN, + mode="constant", + ) + else: + semantic_history = np.array([model.config.SEMANTIC_PAD_TOKEN] * 256) + x = torch.from_numpy( + np.hstack([encoded_text, semantic_history, np.array([model.config.SEMANTIC_INFER_TOKEN])]).astype(np.int64) + )[None] + assert x.shape[1] == 256 + 256 + 1 + with inference_mode(): + x = x.to(model.device) + n_tot_steps = 768 + # custom tqdm updates since we don't know when eos will occur + pbar = tqdm.tqdm(disable=silent, total=100) + pbar_state = 0 + tot_generated_duration_s = 0 + kv_cache = None + for n in range(n_tot_steps): + if use_kv_caching and kv_cache is not None: + x_input = x[:, [-1]] + else: + x_input = x + logits, kv_cache = model.semantic_model( + x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache + ) + relevant_logits = logits[0, 0, : model.config.SEMANTIC_VOCAB_SIZE] + if allow_early_stop: + relevant_logits = torch.hstack( + (relevant_logits, logits[0, 0, [model.config.SEMANTIC_PAD_TOKEN]]) + ) # eos + if top_p is not None: + # faster to convert to numpy + logits_device = relevant_logits.device + logits_dtype = relevant_logits.type() + relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy() + sorted_indices = np.argsort(relevant_logits)[::-1] + sorted_logits = relevant_logits[sorted_indices] + cumulative_probs = np.cumsum(softmax(sorted_logits)) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy() + sorted_indices_to_remove[0] = False + relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf + relevant_logits = torch.from_numpy(relevant_logits) + relevant_logits = relevant_logits.to(logits_device).type(logits_dtype) + if top_k is not None: + v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1))) + relevant_logits[relevant_logits < v[-1]] = -float("Inf") + probs = torch.softmax(relevant_logits / temp, dim=-1) + item_next = torch.multinomial(probs, num_samples=1) + if allow_early_stop and ( + item_next == model.config.SEMANTIC_VOCAB_SIZE or (min_eos_p is not None and probs[-1] >= min_eos_p) + ): + # eos found, so break + pbar.update(100 - pbar_state) + break + x = torch.cat((x, item_next[None]), dim=1) + 
tot_generated_duration_s += 1 / model.config.SEMANTIC_RATE_HZ + if max_gen_duration_s is not None and tot_generated_duration_s > max_gen_duration_s: + pbar.update(100 - pbar_state) + break + if n == n_tot_steps - 1: + pbar.update(100 - pbar_state) + break + del logits, relevant_logits, probs, item_next + req_pbar_state = np.min([100, int(round(100 * n / n_tot_steps))]) + if req_pbar_state > pbar_state: + pbar.update(req_pbar_state - pbar_state) + pbar_state = req_pbar_state + pbar.close() + out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :] + assert all(out >= 0) and all(out < model.config.SEMANTIC_VOCAB_SIZE) + clear_cuda_cache() + return out + + +def _flatten_codebooks(arr, offset_size): + assert len(arr.shape) == 2 + arr = arr.copy() + if offset_size is not None: + for n in range(1, arr.shape[0]): + arr[n, :] += offset_size * n + flat_arr = arr.ravel("F") + return flat_arr + + +def generate_coarse( + x_semantic, + model, + history_prompt=None, + temp=0.7, + top_k=None, + top_p=None, + silent=False, + max_coarse_history=630, # min 60 (faster), max 630 (more context) + sliding_window_len=60, + base=None, + use_kv_caching=True, +): + """Generate coarse audio codes from semantic tokens. + + Args: + x_semantic (np.ndarray): The semantic tokens to generate coarse audio codes from. + model (BarkModel): The BarkModel to use for generating the coarse audio codes. + history_prompt (tuple): A tuple of (semantic_history, coarse_history, fine_history) to use as a prompt for the generation. + temp (float): The temperature to use for the generation. + top_k (int): The number of top tokens to consider for the generation. + top_p (float): The cumulative probability to consider for the generation. + silent (bool): Whether to silence the tqdm progress bar. + max_coarse_history (int): The maximum number of coarse audio codes to use as history. + sliding_window_len (int): The length of the sliding window to use for the generation. + base (tuple): A tuple of (semantic_history, coarse_history, fine_history) to use as a base for the generation. + use_kv_caching (bool): Whether to use key-value caching for the generation. + + Returns: + np.ndarray: The generated coarse audio codes. 
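Before coarse decoding, _flatten_codebooks (defined above) interleaves the coarse codebooks frame by frame and shifts each row into its own id range; generate_coarse then adds SEMANTIC_VOCAB_SIZE on top so coarse ids never collide with semantic ids. A tiny numeric sketch with made-up codes and an offset standing in for CODEBOOK_SIZE:

import numpy as np

codes = np.array([[0, 1, 2],
                  [3, 4, 5]])    # 2 coarse codebooks x 3 frames, made-up values
offset = 1024                    # stands in for model.config.CODEBOOK_SIZE
shifted = codes.copy()
shifted[1, :] += offset          # row n is shifted by n * offset
flat = shifted.ravel("F")        # column-major interleave: frame 0 of both rows, then frame 1, ...
print(flat)                      # [   0 1027    1 1028    2 1029]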
+ """ + assert ( + isinstance(x_semantic, np.ndarray) + and len(x_semantic.shape) == 1 + and len(x_semantic) > 0 + and x_semantic.min() >= 0 + and x_semantic.max() <= model.config.SEMANTIC_VOCAB_SIZE - 1 + ) + assert 60 <= max_coarse_history <= 630 + assert max_coarse_history + sliding_window_len <= 1024 - 256 + semantic_to_coarse_ratio = ( + model.config.COARSE_RATE_HZ / model.config.SEMANTIC_RATE_HZ * model.config.N_COARSE_CODEBOOKS + ) + max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio)) + if all(v is not None for v in history_prompt) or base is not None: + if history_prompt is not None: + x_history = history_prompt + x_semantic_history = x_history[0] + x_coarse_history = x_history[1] + if base is not None: + x_semantic_history = base[0] + x_coarse_history = base[1] + assert ( + isinstance(x_semantic_history, np.ndarray) + and len(x_semantic_history.shape) == 1 + and len(x_semantic_history) > 0 + and x_semantic_history.min() >= 0 + and x_semantic_history.max() <= model.config.SEMANTIC_VOCAB_SIZE - 1 + and isinstance(x_coarse_history, np.ndarray) + and len(x_coarse_history.shape) == 2 + and x_coarse_history.shape[0] == model.config.N_COARSE_CODEBOOKS + and x_coarse_history.shape[-1] >= 0 + and x_coarse_history.min() >= 0 + and x_coarse_history.max() <= model.config.CODEBOOK_SIZE - 1 + and ( + round(x_coarse_history.shape[-1] / len(x_semantic_history), 1) + == round(semantic_to_coarse_ratio / model.config.N_COARSE_CODEBOOKS, 1) + ) + ) + x_coarse_history = ( + _flatten_codebooks(x_coarse_history, model.config.CODEBOOK_SIZE) + model.config.SEMANTIC_VOCAB_SIZE + ) + # trim histories correctly + n_semantic_hist_provided = np.min( + [ + max_semantic_history, + len(x_semantic_history) - len(x_semantic_history) % 2, + int(np.floor(len(x_coarse_history) / semantic_to_coarse_ratio)), + ] + ) + n_coarse_hist_provided = int(round(n_semantic_hist_provided * semantic_to_coarse_ratio)) + x_semantic_history = x_semantic_history[-n_semantic_hist_provided:].astype(np.int32) + x_coarse_history = x_coarse_history[-n_coarse_hist_provided:].astype(np.int32) + # TODO: bit of a hack for time alignment (sounds better) + x_coarse_history = x_coarse_history[:-2] + else: + x_semantic_history = np.array([], dtype=np.int32) + x_coarse_history = np.array([], dtype=np.int32) + # start loop + n_steps = int( + round( + np.floor(len(x_semantic) * semantic_to_coarse_ratio / model.config.N_COARSE_CODEBOOKS) + * model.config.N_COARSE_CODEBOOKS + ) + ) + assert n_steps > 0 and n_steps % model.config.N_COARSE_CODEBOOKS == 0 + x_semantic = np.hstack([x_semantic_history, x_semantic]).astype(np.int32) + x_coarse = x_coarse_history.astype(np.int32) + base_semantic_idx = len(x_semantic_history) + with inference_mode(): + x_semantic_in = torch.from_numpy(x_semantic)[None].to(model.device) + x_coarse_in = torch.from_numpy(x_coarse)[None].to(model.device) + n_window_steps = int(np.ceil(n_steps / sliding_window_len)) + n_step = 0 + for _ in tqdm.tqdm(range(n_window_steps), total=n_window_steps, disable=silent): + semantic_idx = base_semantic_idx + int(round(n_step / semantic_to_coarse_ratio)) + # pad from right side + x_in = x_semantic_in[:, np.max([0, semantic_idx - max_semantic_history]) :] + x_in = x_in[:, :256] + x_in = F.pad( + x_in, + (0, 256 - x_in.shape[-1]), + "constant", + model.config.COARSE_SEMANTIC_PAD_TOKEN, + ) + x_in = torch.hstack( + [ + x_in, + torch.tensor([model.config.COARSE_INFER_TOKEN])[None].to(model.device), + x_coarse_in[:, -max_coarse_history:], + ] + ) + kv_cache = None + 
for _ in range(sliding_window_len): + if n_step >= n_steps: + continue + is_major_step = n_step % model.config.N_COARSE_CODEBOOKS == 0 + + if use_kv_caching and kv_cache is not None: + x_input = x_in[:, [-1]] + else: + x_input = x_in + + logits, kv_cache = model.coarse_model(x_input, use_cache=use_kv_caching, past_kv=kv_cache) + logit_start_idx = ( + model.config.SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * model.config.CODEBOOK_SIZE + ) + logit_end_idx = model.config.SEMANTIC_VOCAB_SIZE + (2 - int(is_major_step)) * model.config.CODEBOOK_SIZE + relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx] + if top_p is not None: + # faster to convert to numpy + logits_device = relevant_logits.device + logits_dtype = relevant_logits.type() + relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy() + sorted_indices = np.argsort(relevant_logits)[::-1] + sorted_logits = relevant_logits[sorted_indices] + cumulative_probs = np.cumsum(torch.nn.functional.softmax(sorted_logits)) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy() + sorted_indices_to_remove[0] = False + relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf + relevant_logits = torch.from_numpy(relevant_logits) + relevant_logits = relevant_logits.to(logits_device).type(logits_dtype) + if top_k is not None: + v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1))) + relevant_logits[relevant_logits < v[-1]] = -float("Inf") + probs = torch.nn.functional.softmax(relevant_logits / temp, dim=-1) + item_next = torch.multinomial(probs, num_samples=1) + item_next += logit_start_idx + x_coarse_in = torch.cat((x_coarse_in, item_next[None]), dim=1) + x_in = torch.cat((x_in, item_next[None]), dim=1) + del logits, relevant_logits, probs, item_next + n_step += 1 + del x_in + del x_semantic_in + gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history) :] + del x_coarse_in + assert len(gen_coarse_arr) == n_steps + gen_coarse_audio_arr = ( + gen_coarse_arr.reshape(-1, model.config.N_COARSE_CODEBOOKS).T - model.config.SEMANTIC_VOCAB_SIZE + ) + for n in range(1, model.config.N_COARSE_CODEBOOKS): + gen_coarse_audio_arr[n, :] -= n * model.config.CODEBOOK_SIZE + clear_cuda_cache() + return gen_coarse_audio_arr + + +def generate_fine( + x_coarse_gen, + model, + history_prompt=None, + temp=0.5, + silent=True, + base=None, +): + """Generate full audio codes from coarse audio codes. + + Args: + x_coarse_gen (np.ndarray): The coarse audio codes to generate full audio codes from. + model (BarkModel): The BarkModel to use for generating the full audio codes. + history_prompt (tuple): A tuple of (semantic_history, coarse_history, fine_history) to use as a prompt for the generation. + temp (float): The temperature to use for the generation. + silent (bool): Whether to silence the tqdm progress bar. + base (tuple): A tuple of (semantic_history, coarse_history, fine_history) to use as a base for the generation. + + Returns: + np.ndarray: The generated full audio codes. 
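The body of generate_fine walks the sequence with 1024-frame windows advanced 512 frames at a time, refilling only the second half of each window after the first. The loop-index arithmetic is easy to sanity-check on its own; the sizes below are invented for illustration:

import numpy as np

n_history = 512   # fine-history frames prepended (capped at 512)
t_coarse = 1500   # frames of coarse codes to refine (made up)
total = n_history + t_coarse

n_loops = max(0, int(np.ceil((t_coarse - (1024 - n_history)) / 512))) + 1
for n in range(n_loops):
    start_idx = min(n * 512, total - 1024)            # left edge of the 1024-frame window
    start_fill_idx = min(n_history + n * 512, total - 512)
    rel_start_fill_idx = start_fill_idx - start_idx   # first frame overwritten inside the window
    print(n, start_idx, start_fill_idx, rel_start_fill_idx)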
+ """ + assert ( + isinstance(x_coarse_gen, np.ndarray) + and len(x_coarse_gen.shape) == 2 + and 1 <= x_coarse_gen.shape[0] <= model.config.N_FINE_CODEBOOKS - 1 + and x_coarse_gen.shape[1] > 0 + and x_coarse_gen.min() >= 0 + and x_coarse_gen.max() <= model.config.CODEBOOK_SIZE - 1 + ) + if all(v is not None for v in history_prompt) or base is not None: + if history_prompt is not None: + x_fine_history = history_prompt[2] + if base is not None: + x_fine_history = base[2] + assert ( + isinstance(x_fine_history, np.ndarray) + and len(x_fine_history.shape) == 2 + and x_fine_history.shape[0] == model.config.N_FINE_CODEBOOKS + and x_fine_history.shape[1] >= 0 + and x_fine_history.min() >= 0 + and x_fine_history.max() <= model.config.CODEBOOK_SIZE - 1 + ) + else: + x_fine_history = None + n_coarse = x_coarse_gen.shape[0] + # make input arr + in_arr = np.vstack( + [ + x_coarse_gen, + np.zeros((model.config.N_FINE_CODEBOOKS - n_coarse, x_coarse_gen.shape[1])) + + model.config.CODEBOOK_SIZE, # padding + ] + ).astype(np.int32) + # prepend history if available (max 512) + if x_fine_history is not None: + x_fine_history = x_fine_history.astype(np.int32) + in_arr = np.hstack( + [ + x_fine_history[:, -512:].astype(np.int32), + in_arr, + ] + ) + n_history = x_fine_history[:, -512:].shape[1] + else: + n_history = 0 + n_remove_from_end = 0 + # need to pad if too short (since non-causal model) + if in_arr.shape[1] < 1024: + n_remove_from_end = 1024 - in_arr.shape[1] + in_arr = np.hstack( + [ + in_arr, + np.zeros((model.config.N_FINE_CODEBOOKS, n_remove_from_end), dtype=np.int32) + + model.config.CODEBOOK_SIZE, + ] + ) + # we can be lazy about fractional loop and just keep overwriting codebooks + n_loops = np.max([0, int(np.ceil((x_coarse_gen.shape[1] - (1024 - n_history)) / 512))]) + 1 + with inference_mode(): + in_arr = torch.tensor(in_arr.T).to(model.device) + for n in tqdm.tqdm(range(n_loops), disable=silent): + start_idx = np.min([n * 512, in_arr.shape[0] - 1024]) + start_fill_idx = np.min([n_history + n * 512, in_arr.shape[0] - 512]) + rel_start_fill_idx = start_fill_idx - start_idx + in_buffer = in_arr[start_idx : start_idx + 1024, :][None] + for nn in range(n_coarse, model.config.N_FINE_CODEBOOKS): + logits = model.fine_model(nn, in_buffer) + if temp is None: + relevant_logits = logits[0, rel_start_fill_idx:, : model.config.CODEBOOK_SIZE] + codebook_preds = torch.argmax(relevant_logits, -1) + else: + relevant_logits = logits[0, :, : model.config.CODEBOOK_SIZE] / temp + probs = F.softmax(relevant_logits, dim=-1) + codebook_preds = torch.hstack( + [torch.multinomial(probs[n], num_samples=1) for n in range(rel_start_fill_idx, 1024)] + ) + in_buffer[0, rel_start_fill_idx:, nn] = codebook_preds + del logits, codebook_preds + # transfer over info into model_in and convert to numpy + for nn in range(n_coarse, model.config.N_FINE_CODEBOOKS): + in_arr[start_fill_idx : start_fill_idx + (1024 - rel_start_fill_idx), nn] = in_buffer[ + 0, rel_start_fill_idx:, nn + ] + del in_buffer + gen_fine_arr = in_arr.detach().cpu().numpy().squeeze().T + del in_arr + gen_fine_arr = gen_fine_arr[:, n_history:] + if n_remove_from_end > 0: + gen_fine_arr = gen_fine_arr[:, :-n_remove_from_end] + assert gen_fine_arr.shape[-1] == x_coarse_gen.shape[-1] + clear_cuda_cache() + return gen_fine_arr + + +def codec_decode(fine_tokens, model): + """Turn quantized audio codes into audio array using encodec.""" + arr = torch.from_numpy(fine_tokens)[None] + arr = arr.to(model.device) + arr = arr.transpose(0, 1) + emb = 
model.encodec.quantizer.decode(arr) + out = model.encodec.decoder(emb) + audio_arr = out.detach().cpu().numpy().squeeze() + return audio_arr diff --git a/viXTTS/TTS/tts/layers/bark/load_model.py b/viXTTS/TTS/tts/layers/bark/load_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ce6b757f054ce98b91601b494854ef8e7b56b131 --- /dev/null +++ b/viXTTS/TTS/tts/layers/bark/load_model.py @@ -0,0 +1,160 @@ +import contextlib +import functools +import hashlib +import logging +import os + +import requests +import torch +import tqdm + +from TTS.tts.layers.bark.model import GPT, GPTConfig +from TTS.tts.layers.bark.model_fine import FineGPT, FineGPTConfig + +if ( + torch.cuda.is_available() + and hasattr(torch.cuda, "amp") + and hasattr(torch.cuda.amp, "autocast") + and torch.cuda.is_bf16_supported() +): + autocast = functools.partial(torch.cuda.amp.autocast, dtype=torch.bfloat16) +else: + + @contextlib.contextmanager + def autocast(): + yield + + +# hold models in global scope to lazy load + +logger = logging.getLogger(__name__) + + +if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): + logger.warning( + "torch version does not support flash attention. You will get significantly faster" + + " inference speed by upgrade torch to newest version / nightly." + ) + + +def _md5(fname): + hash_md5 = hashlib.md5() + with open(fname, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + +def _download(from_s3_path, to_local_path, CACHE_DIR): + os.makedirs(CACHE_DIR, exist_ok=True) + response = requests.get(from_s3_path, stream=True) + total_size_in_bytes = int(response.headers.get("content-length", 0)) + block_size = 1024 # 1 Kibibyte + progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + with open(to_local_path, "wb") as file: + for data in response.iter_content(block_size): + progress_bar.update(len(data)) + file.write(data) + progress_bar.close() + if total_size_in_bytes not in [0, progress_bar.n]: + raise ValueError("ERROR, something went wrong") + + +class InferenceContext: + def __init__(self, benchmark=False): + # we can't expect inputs to be the same length, so disable benchmarking by default + self._chosen_cudnn_benchmark = benchmark + self._cudnn_benchmark = None + + def __enter__(self): + self._cudnn_benchmark = torch.backends.cudnn.benchmark + torch.backends.cudnn.benchmark = self._chosen_cudnn_benchmark + + def __exit__(self, exc_type, exc_value, exc_traceback): + torch.backends.cudnn.benchmark = self._cudnn_benchmark + + +if torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + +@contextlib.contextmanager +def inference_mode(): + with InferenceContext(), torch.inference_mode(), torch.no_grad(), autocast(): + yield + + +def clear_cuda_cache(): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + +def load_model(ckpt_path, device, config, model_type="text"): + logger.info(f"loading {model_type} model from {ckpt_path}...") + + if device == "cpu": + logger.warning("No GPU being used. 
Careful, Inference might be extremely slow!") + if model_type == "text": + ConfigClass = GPTConfig + ModelClass = GPT + elif model_type == "coarse": + ConfigClass = GPTConfig + ModelClass = GPT + elif model_type == "fine": + ConfigClass = FineGPTConfig + ModelClass = FineGPT + else: + raise NotImplementedError() + if ( + not config.USE_SMALLER_MODELS + and os.path.exists(ckpt_path) + and _md5(ckpt_path) != config.REMOTE_MODEL_PATHS[model_type]["checksum"] + ): + logger.warning(f"found outdated {model_type} model, removing...") + os.remove(ckpt_path) + if not os.path.exists(ckpt_path): + logger.info(f"{model_type} model not found, downloading...") + _download(config.REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path, config.CACHE_DIR) + + checkpoint = torch.load(ckpt_path, map_location=device) + # this is a hack + model_args = checkpoint["model_args"] + if "input_vocab_size" not in model_args: + model_args["input_vocab_size"] = model_args["vocab_size"] + model_args["output_vocab_size"] = model_args["vocab_size"] + del model_args["vocab_size"] + + gptconf = ConfigClass(**checkpoint["model_args"]) + if model_type == "text": + config.semantic_config = gptconf + elif model_type == "coarse": + config.coarse_config = gptconf + elif model_type == "fine": + config.fine_config = gptconf + + model = ModelClass(gptconf) + state_dict = checkpoint["model"] + # fixup checkpoint + unwanted_prefix = "_orig_mod." + for k, _ in list(state_dict.items()): + if k.startswith(unwanted_prefix): + state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k) + extra_keys = set(state_dict.keys()) - set(model.state_dict().keys()) + extra_keys = set(k for k in extra_keys if not k.endswith(".attn.bias")) + missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) + missing_keys = set(k for k in missing_keys if not k.endswith(".attn.bias")) + if len(extra_keys) != 0: + raise ValueError(f"extra keys found: {extra_keys}") + if len(missing_keys) != 0: + raise ValueError(f"missing keys: {missing_keys}") + model.load_state_dict(state_dict, strict=False) + n_params = model.get_num_params() + val_loss = checkpoint["best_val_loss"].item() + logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss") + model.eval() + model.to(device) + del checkpoint, state_dict + clear_cuda_cache() + return model, config diff --git a/viXTTS/TTS/tts/layers/bark/model.py b/viXTTS/TTS/tts/layers/bark/model.py new file mode 100644 index 0000000000000000000000000000000000000000..c84022bd08bcdd2f3f9f3caadfc15a7bf80ddaf3 --- /dev/null +++ b/viXTTS/TTS/tts/layers/bark/model.py @@ -0,0 +1,233 @@ +""" +Much of this code is adapted from Andrej Karpathy's NanoGPT +(https://github.com/karpathy/nanoGPT) +""" +import math +from dataclasses import dataclass + +import torch +from coqpit import Coqpit +from torch import nn +from torch.nn import functional as F + + +class LayerNorm(nn.Module): + """LayerNorm but with an optional bias. 
PyTorch doesn't support simply bias=False""" + + def __init__(self, ndim, bias): + super().__init__() + self.weight = nn.Parameter(torch.ones(ndim)) + self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None + + def forward(self, x): + return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5) + + +class CausalSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + # regularization + self.attn_dropout = nn.Dropout(config.dropout) + self.resid_dropout = nn.Dropout(config.dropout) + self.n_head = config.n_head + self.n_embd = config.n_embd + self.dropout = config.dropout + # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary + self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") + if not self.flash: + # print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0") + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "bias", + torch.tril(torch.ones(config.block_size, config.block_size)).view( + 1, 1, config.block_size, config.block_size + ), + ) + + def forward(self, x, past_kv=None, use_cache=False): + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + if past_kv is not None: + past_key = past_kv[0] + past_value = past_kv[1] + k = torch.cat((past_key, k), dim=-2) + v = torch.cat((past_value, v), dim=-2) + + FULL_T = k.shape[-2] + + if use_cache is True: + present = (k, v) + else: + present = None + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + if self.flash: + # efficient attention using Flash Attention CUDA kernels + if past_kv is not None: + # When `past_kv` is provided, we're doing incremental decoding and `q.shape[2] == 1`: q only contains + # the query for the last token. scaled_dot_product_attention interprets this as the first token in the + # sequence, so if is_causal=True it will mask out all attention from it. This is not what we want, so + # to work around this we set is_causal=False. 
+ is_causal = False + else: + is_causal = True + + # efficient attention using Flash Attention CUDA kernels + y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=is_causal) + else: + # manual implementation of attention + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:, :, FULL_T - T : FULL_T, :FULL_T] == 0, float("-inf")) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return (y, present) + + +class MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) + self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) + self.dropout = nn.Dropout(config.dropout) + self.gelu = nn.GELU() + + def forward(self, x): + x = self.c_fc(x) + x = self.gelu(x) + x = self.c_proj(x) + x = self.dropout(x) + return x + + +class Block(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) + self.attn = CausalSelfAttention(config) + self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) + self.mlp = MLP(config) + self.layer_idx = layer_idx + + def forward(self, x, past_kv=None, use_cache=False): + attn_output, prev_kvs = self.attn(self.ln_1(x), past_kv=past_kv, use_cache=use_cache) + x = x + attn_output + x = x + self.mlp(self.ln_2(x)) + return (x, prev_kvs) + + +@dataclass +class GPTConfig(Coqpit): + block_size: int = 1024 + input_vocab_size: int = 10_048 + output_vocab_size: int = 10_048 + n_layer: int = 12 + n_head: int = 12 + n_embd: int = 768 + dropout: float = 0.0 + bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + + +class GPT(nn.Module): + def __init__(self, config): + super().__init__() + assert config.input_vocab_size is not None + assert config.output_vocab_size is not None + assert config.block_size is not None + self.config = config + + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.input_vocab_size, config.n_embd), + wpe=nn.Embedding(config.block_size, config.n_embd), + drop=nn.Dropout(config.dropout), + h=nn.ModuleList([Block(config, idx) for idx in range(config.n_layer)]), + ln_f=LayerNorm(config.n_embd, bias=config.bias), + ) + ) + self.lm_head = nn.Linear(config.n_embd, config.output_vocab_size, bias=False) + + def get_num_params(self, non_embedding=True): + """ + Return the number of parameters in the model. + For non-embedding count (default), the position embeddings get subtracted. + The token embeddings would too, except due to the parameter sharing these + params are actually used as weights in the final layer, so we include them. 
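To see how the GPT class in this file is meant to be driven, here is a toy decode loop with a deliberately tiny, randomly initialised configuration (all sizes invented): the first call processes the whole prompt, and later calls pass only the newest token together with past_kv, which is why the attention above switches is_causal off once a cache is present.

import torch

from TTS.tts.layers.bark.model import GPT, GPTConfig

cfg = GPTConfig(block_size=32, input_vocab_size=100, output_vocab_size=100,
                n_layer=2, n_head=2, n_embd=32, dropout=0.0, bias=True)
model = GPT(cfg).eval()

idx = torch.randint(0, 100, (1, 8))                      # fake 8-token prompt
with torch.no_grad():
    logits, kv = model(idx, use_cache=True)              # full prompt pass, returns the cache
    next_tok = logits[:, -1].argmax(-1, keepdim=True)
    for _ in range(4):                                   # incremental steps reuse the cache
        logits, kv = model(next_tok, past_kv=kv, use_cache=True)
        next_tok = logits[:, -1].argmax(-1, keepdim=True)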
+ """ + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.transformer.wte.weight.numel() + n_params -= self.transformer.wpe.weight.numel() + return n_params + + def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False): + device = idx.device + _, t = idx.size() + if past_kv is not None: + assert t == 1 + tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + else: + if merge_context: + assert idx.shape[1] >= 256 + 256 + 1 + t = idx.shape[1] - 256 + else: + assert ( + t <= self.config.block_size + ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" + + # forward the GPT model itself + if merge_context: + tok_emb = torch.cat( + [ + self.transformer.wte(idx[:, :256]) + self.transformer.wte(idx[:, 256 : 256 + 256]), + self.transformer.wte(idx[:, 256 + 256 :]), + ], + dim=1, + ) + else: + tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + + if past_kv is None: + past_length = 0 + past_kv = tuple([None] * len(self.transformer.h)) + else: + past_length = past_kv[0][0].size(-2) + + if position_ids is None: + position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) # shape (1, t) + assert position_ids.shape == (1, t) + + pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd) + + x = self.transformer.drop(tok_emb + pos_emb) + + new_kv = () if use_cache else None + + for _, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)): + x, kv = block(x, past_kv=past_layer_kv, use_cache=use_cache) + + if use_cache: + new_kv = new_kv + (kv,) + + x = self.transformer.ln_f(x) + + # inference-time mini-optimization: only forward the lm_head on the very last position + logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim + + return (logits, new_kv) diff --git a/viXTTS/TTS/tts/layers/bark/model_fine.py b/viXTTS/TTS/tts/layers/bark/model_fine.py new file mode 100644 index 0000000000000000000000000000000000000000..09e5f4765dce8743db2a3ed879e7811d2b9d23d6 --- /dev/null +++ b/viXTTS/TTS/tts/layers/bark/model_fine.py @@ -0,0 +1,142 @@ +""" +Much of this code is adapted from Andrej Karpathy's NanoGPT +(https://github.com/karpathy/nanoGPT) +""" +import math +from dataclasses import dataclass + +import torch +from torch import nn +from torch.nn import functional as F + +from .model import GPT, MLP, GPTConfig + + +class NonCausalSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + # regularization + self.attn_dropout = nn.Dropout(config.dropout) + self.resid_dropout = nn.Dropout(config.dropout) + self.n_head = config.n_head + self.n_embd = config.n_embd + self.dropout = config.dropout + # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary + self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0 + + def forward(self, x): + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, 
k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + if self.flash: + # efficient attention using Flash Attention CUDA kernels + y = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=False + ) + else: + # manual implementation of attention + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + + +class FineBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.ln_1 = nn.LayerNorm(config.n_embd) + self.attn = NonCausalSelfAttention(config) + self.ln_2 = nn.LayerNorm(config.n_embd) + self.mlp = MLP(config) + + def forward(self, x): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class FineGPT(GPT): + def __init__(self, config): + super().__init__(config) + del self.lm_head + self.config = config + self.n_codes_total = config.n_codes_total + self.transformer = nn.ModuleDict( + dict( + wtes=nn.ModuleList( + [nn.Embedding(config.input_vocab_size, config.n_embd) for _ in range(config.n_codes_total)] + ), + wpe=nn.Embedding(config.block_size, config.n_embd), + drop=nn.Dropout(config.dropout), + h=nn.ModuleList([FineBlock(config) for _ in range(config.n_layer)]), + ln_f=nn.LayerNorm(config.n_embd), + ) + ) + self.lm_heads = nn.ModuleList( + [ + nn.Linear(config.n_embd, config.output_vocab_size, bias=False) + for _ in range(config.n_codes_given, self.n_codes_total) + ] + ) + for i in range(self.n_codes_total - config.n_codes_given): + self.transformer.wtes[i + 1].weight = self.lm_heads[i].weight + + def forward(self, pred_idx, idx): + device = idx.device + b, t, codes = idx.size() + assert ( + t <= self.config.block_size + ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" + assert pred_idx > 0, "cannot predict 0th codebook" + assert codes == self.n_codes_total, (b, t, codes) + pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) + + # forward the GPT model itself + tok_embs = [ + wte(idx[:, :, i]).unsqueeze(-1) for i, wte in enumerate(self.transformer.wtes) + ] # token embeddings of shape (b, t, n_embd) + tok_emb = torch.cat(tok_embs, dim=-1) + pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd) + x = tok_emb[:, :, :, : pred_idx + 1].sum(dim=-1) + x = self.transformer.drop(x + pos_emb) + for block in self.transformer.h: + x = block(x) + x = self.transformer.ln_f(x) + logits = self.lm_heads[pred_idx - self.config.n_codes_given](x) + return logits + + def get_num_params(self, non_embedding=True): + """ + Return the number of parameters in the model. + For non-embedding count (default), the position embeddings get subtracted. + The token embeddings would too, except due to the parameter sharing these + params are actually used as weights in the final layer, so we include them. 
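FineGPT (together with the FineGPTConfig defined just below) predicts one codebook at a time, conditioning on codebooks 0..pred_idx over the full window with no causal mask along time. A toy forward pass with an invented tiny configuration shows the expected input layout, (batch, frames, n_codes_total), and the output shape:

import torch

from TTS.tts.layers.bark.model_fine import FineGPT, FineGPTConfig

cfg = FineGPTConfig(block_size=16, input_vocab_size=64, output_vocab_size=64,
                    n_layer=1, n_head=2, n_embd=16, dropout=0.0, bias=True,
                    n_codes_total=8, n_codes_given=1)
model = FineGPT(cfg).eval()

codes = torch.randint(0, 64, (1, 16, 8))       # (batch, frames, all 8 codebooks)
with torch.no_grad():
    logits = model(2, codes)                   # predict codebook 2 from codebooks 0..2
print(logits.shape)                            # torch.Size([1, 16, 64])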
+ """ + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + for wte in self.transformer.wtes: + n_params -= wte.weight.numel() + n_params -= self.transformer.wpe.weight.numel() + return n_params + + +@dataclass +class FineGPTConfig(GPTConfig): + n_codes_total: int = 8 + n_codes_given: int = 1 diff --git a/viXTTS/TTS/tts/layers/delightful_tts/__init__.py b/viXTTS/TTS/tts/layers/delightful_tts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/tts/layers/delightful_tts/acoustic_model.py b/viXTTS/TTS/tts/layers/delightful_tts/acoustic_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c906b882e567fade64139a8b932c71d554117547 --- /dev/null +++ b/viXTTS/TTS/tts/layers/delightful_tts/acoustic_model.py @@ -0,0 +1,563 @@ +### credit: https://github.com/dunky11/voicesmith +from typing import Callable, Dict, Tuple + +import torch +import torch.nn.functional as F +from coqpit import Coqpit +from torch import nn + +from TTS.tts.layers.delightful_tts.conformer import Conformer +from TTS.tts.layers.delightful_tts.encoders import ( + PhonemeLevelProsodyEncoder, + UtteranceLevelProsodyEncoder, + get_mask_from_lengths, +) +from TTS.tts.layers.delightful_tts.energy_adaptor import EnergyAdaptor +from TTS.tts.layers.delightful_tts.networks import EmbeddingPadded, positional_encoding +from TTS.tts.layers.delightful_tts.phoneme_prosody_predictor import PhonemeProsodyPredictor +from TTS.tts.layers.delightful_tts.pitch_adaptor import PitchAdaptor +from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor +from TTS.tts.layers.generic.aligner import AlignmentNetwork +from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask + + +class AcousticModel(torch.nn.Module): + def __init__( + self, + args: "ModelArgs", + tokenizer: "TTSTokenizer" = None, + speaker_manager: "SpeakerManager" = None, + ): + super().__init__() + self.args = args + self.tokenizer = tokenizer + self.speaker_manager = speaker_manager + + self.init_multispeaker(args) + # self.set_embedding_dims() + + self.length_scale = ( + float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale + ) + + self.emb_dim = args.n_hidden_conformer_encoder + self.encoder = Conformer( + dim=self.args.n_hidden_conformer_encoder, + n_layers=self.args.n_layers_conformer_encoder, + n_heads=self.args.n_heads_conformer_encoder, + speaker_embedding_dim=self.embedded_speaker_dim, + p_dropout=self.args.dropout_conformer_encoder, + kernel_size_conv_mod=self.args.kernel_size_conv_mod_conformer_encoder, + lrelu_slope=self.args.lrelu_slope, + ) + self.pitch_adaptor = PitchAdaptor( + n_input=self.args.n_hidden_conformer_encoder, + n_hidden=self.args.n_hidden_variance_adaptor, + n_out=1, + kernel_size=self.args.kernel_size_variance_adaptor, + emb_kernel_size=self.args.emb_kernel_size_variance_adaptor, + p_dropout=self.args.dropout_variance_adaptor, + lrelu_slope=self.args.lrelu_slope, + ) + self.energy_adaptor = EnergyAdaptor( + channels_in=self.args.n_hidden_conformer_encoder, + channels_hidden=self.args.n_hidden_variance_adaptor, + channels_out=1, + kernel_size=self.args.kernel_size_variance_adaptor, + emb_kernel_size=self.args.emb_kernel_size_variance_adaptor, + dropout=self.args.dropout_variance_adaptor, + lrelu_slope=self.args.lrelu_slope, + ) + + self.aligner = AlignmentNetwork( + in_query_channels=self.args.out_channels, + 
in_key_channels=self.args.n_hidden_conformer_encoder, + ) + + self.duration_predictor = VariancePredictor( + channels_in=self.args.n_hidden_conformer_encoder, + channels=self.args.n_hidden_variance_adaptor, + channels_out=1, + kernel_size=self.args.kernel_size_variance_adaptor, + p_dropout=self.args.dropout_variance_adaptor, + lrelu_slope=self.args.lrelu_slope, + ) + + self.utterance_prosody_encoder = UtteranceLevelProsodyEncoder( + num_mels=self.args.num_mels, + ref_enc_filters=self.args.ref_enc_filters_reference_encoder, + ref_enc_size=self.args.ref_enc_size_reference_encoder, + ref_enc_gru_size=self.args.ref_enc_gru_size_reference_encoder, + ref_enc_strides=self.args.ref_enc_strides_reference_encoder, + n_hidden=self.args.n_hidden_conformer_encoder, + dropout=self.args.dropout_conformer_encoder, + bottleneck_size_u=self.args.bottleneck_size_u_reference_encoder, + token_num=self.args.token_num_reference_encoder, + ) + + self.utterance_prosody_predictor = PhonemeProsodyPredictor( + hidden_size=self.args.n_hidden_conformer_encoder, + kernel_size=self.args.predictor_kernel_size_reference_encoder, + dropout=self.args.dropout_conformer_encoder, + bottleneck_size=self.args.bottleneck_size_u_reference_encoder, + lrelu_slope=self.args.lrelu_slope, + ) + + self.phoneme_prosody_encoder = PhonemeLevelProsodyEncoder( + num_mels=self.args.num_mels, + ref_enc_filters=self.args.ref_enc_filters_reference_encoder, + ref_enc_size=self.args.ref_enc_size_reference_encoder, + ref_enc_gru_size=self.args.ref_enc_gru_size_reference_encoder, + ref_enc_strides=self.args.ref_enc_strides_reference_encoder, + n_hidden=self.args.n_hidden_conformer_encoder, + dropout=self.args.dropout_conformer_encoder, + bottleneck_size_p=self.args.bottleneck_size_p_reference_encoder, + n_heads=self.args.n_heads_conformer_encoder, + ) + + self.phoneme_prosody_predictor = PhonemeProsodyPredictor( + hidden_size=self.args.n_hidden_conformer_encoder, + kernel_size=self.args.predictor_kernel_size_reference_encoder, + dropout=self.args.dropout_conformer_encoder, + bottleneck_size=self.args.bottleneck_size_p_reference_encoder, + lrelu_slope=self.args.lrelu_slope, + ) + + self.u_bottle_out = nn.Linear( + self.args.bottleneck_size_u_reference_encoder, + self.args.n_hidden_conformer_encoder, + ) + + self.u_norm = nn.InstanceNorm1d(self.args.bottleneck_size_u_reference_encoder) + self.p_bottle_out = nn.Linear( + self.args.bottleneck_size_p_reference_encoder, + self.args.n_hidden_conformer_encoder, + ) + self.p_norm = nn.InstanceNorm1d( + self.args.bottleneck_size_p_reference_encoder, + ) + self.decoder = Conformer( + dim=self.args.n_hidden_conformer_decoder, + n_layers=self.args.n_layers_conformer_decoder, + n_heads=self.args.n_heads_conformer_decoder, + speaker_embedding_dim=self.embedded_speaker_dim, + p_dropout=self.args.dropout_conformer_decoder, + kernel_size_conv_mod=self.args.kernel_size_conv_mod_conformer_decoder, + lrelu_slope=self.args.lrelu_slope, + ) + + padding_idx = self.tokenizer.characters.pad_id + self.src_word_emb = EmbeddingPadded( + self.args.num_chars, self.args.n_hidden_conformer_encoder, padding_idx=padding_idx + ) + self.to_mel = nn.Linear( + self.args.n_hidden_conformer_decoder, + self.args.num_mels, + ) + + self.energy_scaler = torch.nn.BatchNorm1d(1, affine=False, track_running_stats=True, momentum=None) + self.energy_scaler.requires_grad_(False) + + def init_multispeaker(self, args: Coqpit): # pylint: disable=unused-argument + """Init for multi-speaker training.""" + self.embedded_speaker_dim = 0 + self.num_speakers 
= self.args.num_speakers + self.audio_transform = None + + if self.speaker_manager: + self.num_speakers = self.speaker_manager.num_speakers + + if self.args.use_speaker_embedding: + self._init_speaker_embedding() + + if self.args.use_d_vector_file: + self._init_d_vector() + + @staticmethod + def _set_cond_input(aux_input: Dict): + """Set the speaker conditioning input based on the multi-speaker mode.""" + sid, g, lid, durations = None, None, None, None + if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None: + sid = aux_input["speaker_ids"] + if sid.ndim == 0: + sid = sid.unsqueeze_(0) + if "d_vectors" in aux_input and aux_input["d_vectors"] is not None: + g = F.normalize(aux_input["d_vectors"]) # .unsqueeze_(-1) + if g.ndim == 2: + g = g # .unsqueeze_(0) # pylint: disable=self-assigning-variable + + if "durations" in aux_input and aux_input["durations"] is not None: + durations = aux_input["durations"] + + return sid, g, lid, durations + + def get_aux_input(self, aux_input: Dict): + sid, g, lid, _ = self._set_cond_input(aux_input) + return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid} + + def _set_speaker_input(self, aux_input: Dict): + d_vectors = aux_input.get("d_vectors", None) + speaker_ids = aux_input.get("speaker_ids", None) + + if d_vectors is not None and speaker_ids is not None: + raise ValueError("[!] Cannot use d-vectors and speaker-ids together.") + + if speaker_ids is not None and not hasattr(self, "emb_g"): + raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.") + + g = speaker_ids if speaker_ids is not None else d_vectors + return g + + # def set_embedding_dims(self): + # if self.embedded_speaker_dim > 0: + # self.embedding_dims = self.embedded_speaker_dim + # else: + # self.embedding_dims = 0 + + def _init_speaker_embedding(self): + # pylint: disable=attribute-defined-outside-init + if self.num_speakers > 0: + print(" > initialization of speaker-embedding layers.") + self.embedded_speaker_dim = self.args.speaker_embedding_channels + self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) + + def _init_d_vector(self): + # pylint: disable=attribute-defined-outside-init + if hasattr(self, "emb_g"): + raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.") + self.embedded_speaker_dim = self.args.d_vector_dim + + @staticmethod + def generate_attn(dr, x_mask, y_mask=None): + """Generate an attention mask from the linear scale durations. + + Args: + dr (Tensor): Linear scale durations. + x_mask (Tensor): Mask for the input (character) sequence. + y_mask (Tensor): Mask for the output (spectrogram) sequence. Compute it from the predicted durations + if None. Defaults to None. 
+ + Shapes + - dr: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_mask: :math:`(B, T_{de})` + """ + # compute decode mask from the durations + if y_mask is None: + y_lengths = dr.sum(1).long() + y_lengths[y_lengths < 1] = 1 + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype) + return attn + + def _expand_encoder_with_durations( + self, + o_en: torch.FloatTensor, + dr: torch.IntTensor, + x_mask: torch.IntTensor, + y_lengths: torch.IntTensor, + ): + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) + attn = self.generate_attn(dr, x_mask, y_mask) + o_en_ex = torch.einsum("kmn, kjm -> kjn", [attn.float(), o_en]) + return y_mask, o_en_ex, attn.transpose(1, 2) + + def _forward_aligner( + self, + x: torch.FloatTensor, + y: torch.FloatTensor, + x_mask: torch.IntTensor, + y_mask: torch.IntTensor, + attn_priors: torch.FloatTensor, + ) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Aligner forward pass. + + 1. Compute a mask to apply to the attention map. + 2. Run the alignment network. + 3. Apply MAS to compute the hard alignment map. + 4. Compute the durations from the hard alignment map. + + Args: + x (torch.FloatTensor): Input sequence. + y (torch.FloatTensor): Output sequence. + x_mask (torch.IntTensor): Input sequence mask. + y_mask (torch.IntTensor): Output sequence mask. + attn_priors (torch.FloatTensor): Prior for the aligner network map. + + Returns: + Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + Durations from the hard alignment map, soft alignment potentials, log scale alignment potentials, + hard alignment map. 
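+
+        Example (illustrative): for three tokens and four frames, a hard alignment map of
+            ``[[1, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]`` gives ``aligner_durations = [2, 1, 1]``,
+            i.e. each duration counts the frames assigned to the corresponding token.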
+ + Shapes: + - x: :math:`[B, T_en, C_en]` + - y: :math:`[B, T_de, C_de]` + - x_mask: :math:`[B, 1, T_en]` + - y_mask: :math:`[B, 1, T_de]` + - attn_priors: :math:`[B, T_de, T_en]` + + - aligner_durations: :math:`[B, T_en]` + - aligner_soft: :math:`[B, T_de, T_en]` + - aligner_logprob: :math:`[B, 1, T_de, T_en]` + - aligner_mas: :math:`[B, T_de, T_en]` + """ + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) # [B, 1, T_en, T_de] + aligner_soft, aligner_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, attn_priors) + aligner_mas = maximum_path( + aligner_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous() + ) + aligner_durations = torch.sum(aligner_mas, -1).int() + aligner_soft = aligner_soft.squeeze(1) # [B, T_max2, T_max] + aligner_mas = aligner_mas.transpose(1, 2) # [B, T_max, T_max2] -> [B, T_max2, T_max] + return aligner_durations, aligner_soft, aligner_logprob, aligner_mas + + def average_utterance_prosody( # pylint: disable=no-self-use + self, u_prosody_pred: torch.Tensor, src_mask: torch.Tensor + ) -> torch.Tensor: + lengths = ((~src_mask) * 1.0).sum(1) + u_prosody_pred = u_prosody_pred.sum(1, keepdim=True) / lengths.view(-1, 1, 1) + return u_prosody_pred + + def forward( + self, + tokens: torch.Tensor, + src_lens: torch.Tensor, + mels: torch.Tensor, + mel_lens: torch.Tensor, + pitches: torch.Tensor, + energies: torch.Tensor, + attn_priors: torch.Tensor, + use_ground_truth: bool = True, + d_vectors: torch.Tensor = None, + speaker_idx: torch.Tensor = None, + ) -> Dict[str, torch.Tensor]: + sid, g, lid, _ = self._set_cond_input( # pylint: disable=unused-variable + {"d_vectors": d_vectors, "speaker_ids": speaker_idx} + ) # pylint: disable=unused-variable + + src_mask = get_mask_from_lengths(src_lens) # [B, T_src] + mel_mask = get_mask_from_lengths(mel_lens) # [B, T_mel] + + # Token embeddings + token_embeddings = self.src_word_emb(tokens) # [B, T_src, C_hidden] + token_embeddings = token_embeddings.masked_fill(src_mask.unsqueeze(-1), 0.0) + + # Alignment network and durations + aligner_durations, aligner_soft, aligner_logprob, aligner_mas = self._forward_aligner( + x=token_embeddings, + y=mels.transpose(1, 2), + x_mask=~src_mask[:, None], + y_mask=~mel_mask[:, None], + attn_priors=attn_priors, + ) + dr = aligner_durations # [B, T_en] + + # Embeddings + speaker_embedding = None + if d_vectors is not None: + speaker_embedding = g + elif speaker_idx is not None: + speaker_embedding = F.normalize(self.emb_g(sid)) + + pos_encoding = positional_encoding( + self.emb_dim, + max(token_embeddings.shape[1], max(mel_lens)), + device=token_embeddings.device, + ) + encoder_outputs = self.encoder( + token_embeddings, + src_mask, + speaker_embedding=speaker_embedding, + encoding=pos_encoding, + ) + + u_prosody_ref = self.u_norm(self.utterance_prosody_encoder(mels=mels, mel_lens=mel_lens)) + u_prosody_pred = self.u_norm( + self.average_utterance_prosody( + u_prosody_pred=self.utterance_prosody_predictor(x=encoder_outputs, mask=src_mask), + src_mask=src_mask, + ) + ) + + if use_ground_truth: + encoder_outputs = encoder_outputs + self.u_bottle_out(u_prosody_ref) + else: + encoder_outputs = encoder_outputs + self.u_bottle_out(u_prosody_pred) + + p_prosody_ref = self.p_norm( + self.phoneme_prosody_encoder( + x=encoder_outputs, src_mask=src_mask, mels=mels, mel_lens=mel_lens, encoding=pos_encoding + ) + ) + p_prosody_pred = self.p_norm(self.phoneme_prosody_predictor(x=encoder_outputs, mask=src_mask)) + + if use_ground_truth: + 
encoder_outputs = encoder_outputs + self.p_bottle_out(p_prosody_ref) + else: + encoder_outputs = encoder_outputs + self.p_bottle_out(p_prosody_pred) + + encoder_outputs_res = encoder_outputs + + pitch_pred, avg_pitch_target, pitch_emb = self.pitch_adaptor.get_pitch_embedding_train( + x=encoder_outputs, + target=pitches, + dr=dr, + mask=src_mask, + ) + + energy_pred, avg_energy_target, energy_emb = self.energy_adaptor.get_energy_embedding_train( + x=encoder_outputs, + target=energies, + dr=dr, + mask=src_mask, + ) + + encoder_outputs = encoder_outputs.transpose(1, 2) + pitch_emb + energy_emb + log_duration_prediction = self.duration_predictor(x=encoder_outputs_res.detach(), mask=src_mask) + + mel_pred_mask, encoder_outputs_ex, alignments = self._expand_encoder_with_durations( + o_en=encoder_outputs, y_lengths=mel_lens, dr=dr, x_mask=~src_mask[:, None] + ) + + x = self.decoder( + encoder_outputs_ex.transpose(1, 2), + mel_mask, + speaker_embedding=speaker_embedding, + encoding=pos_encoding, + ) + x = self.to_mel(x) + + dr = torch.log(dr + 1) + + dr_pred = torch.exp(log_duration_prediction) - 1 + alignments_dp = self.generate_attn(dr_pred, src_mask.unsqueeze(1), mel_pred_mask) # [B, T_max, T_max2'] + + return { + "model_outputs": x, + "pitch_pred": pitch_pred, + "pitch_target": avg_pitch_target, + "energy_pred": energy_pred, + "energy_target": avg_energy_target, + "u_prosody_pred": u_prosody_pred, + "u_prosody_ref": u_prosody_ref, + "p_prosody_pred": p_prosody_pred, + "p_prosody_ref": p_prosody_ref, + "alignments_dp": alignments_dp, + "alignments": alignments, # [B, T_de, T_en] + "aligner_soft": aligner_soft, + "aligner_mas": aligner_mas, + "aligner_durations": aligner_durations, + "aligner_logprob": aligner_logprob, + "dr_log_pred": log_duration_prediction.squeeze(1), # [B, T] + "dr_log_target": dr.squeeze(1), # [B, T] + "spk_emb": speaker_embedding, + } + + @torch.no_grad() + def inference( + self, + tokens: torch.Tensor, + speaker_idx: torch.Tensor, + p_control: float = None, # TODO # pylint: disable=unused-argument + d_control: float = None, # TODO # pylint: disable=unused-argument + d_vectors: torch.Tensor = None, + pitch_transform: Callable = None, + energy_transform: Callable = None, + ) -> torch.Tensor: + src_mask = get_mask_from_lengths(torch.tensor([tokens.shape[1]], dtype=torch.int64, device=tokens.device)) + src_lens = torch.tensor(tokens.shape[1:2]).to(tokens.device) # pylint: disable=unused-variable + sid, g, lid, _ = self._set_cond_input( # pylint: disable=unused-variable + {"d_vectors": d_vectors, "speaker_ids": speaker_idx} + ) # pylint: disable=unused-variable + + token_embeddings = self.src_word_emb(tokens) + token_embeddings = token_embeddings.masked_fill(src_mask.unsqueeze(-1), 0.0) + + # Embeddings + speaker_embedding = None + if d_vectors is not None: + speaker_embedding = g + elif speaker_idx is not None: + speaker_embedding = F.normalize(self.emb_g(sid)) + + pos_encoding = positional_encoding( + self.emb_dim, + token_embeddings.shape[1], + device=token_embeddings.device, + ) + encoder_outputs = self.encoder( + token_embeddings, + src_mask, + speaker_embedding=speaker_embedding, + encoding=pos_encoding, + ) + + u_prosody_pred = self.u_norm( + self.average_utterance_prosody( + u_prosody_pred=self.utterance_prosody_predictor(x=encoder_outputs, mask=src_mask), + src_mask=src_mask, + ) + ) + encoder_outputs = encoder_outputs + self.u_bottle_out(u_prosody_pred).expand_as(encoder_outputs) + + p_prosody_pred = self.p_norm( + self.phoneme_prosody_predictor( + x=encoder_outputs, 
+ mask=src_mask, + ) + ) + encoder_outputs = encoder_outputs + self.p_bottle_out(p_prosody_pred).expand_as(encoder_outputs) + + encoder_outputs_res = encoder_outputs + + pitch_emb_pred, pitch_pred = self.pitch_adaptor.get_pitch_embedding( + x=encoder_outputs, + mask=src_mask, + pitch_transform=pitch_transform, + pitch_mean=self.pitch_mean if hasattr(self, "pitch_mean") else None, + pitch_std=self.pitch_std if hasattr(self, "pitch_std") else None, + ) + + energy_emb_pred, energy_pred = self.energy_adaptor.get_energy_embedding( + x=encoder_outputs, mask=src_mask, energy_transform=energy_transform + ) + encoder_outputs = encoder_outputs.transpose(1, 2) + pitch_emb_pred + energy_emb_pred + + log_duration_pred = self.duration_predictor( + x=encoder_outputs_res.detach(), mask=src_mask + ) # [B, C_hidden, T_src] -> [B, T_src] + duration_pred = (torch.exp(log_duration_pred) - 1) * (~src_mask) * self.length_scale # -> [B, T_src] + duration_pred[duration_pred < 1] = 1.0 # -> [B, T_src] + duration_pred = torch.round(duration_pred) # -> [B, T_src] + mel_lens = duration_pred.sum(1) # -> [B,] + + _, encoder_outputs_ex, alignments = self._expand_encoder_with_durations( + o_en=encoder_outputs, y_lengths=mel_lens, dr=duration_pred.squeeze(1), x_mask=~src_mask[:, None] + ) + + mel_mask = get_mask_from_lengths( + torch.tensor([encoder_outputs_ex.shape[2]], dtype=torch.int64, device=encoder_outputs_ex.device) + ) + + if encoder_outputs_ex.shape[1] > pos_encoding.shape[1]: + encoding = positional_encoding(self.emb_dim, encoder_outputs_ex.shape[2], device=tokens.device) + + # [B, C_hidden, T_src], [B, 1, T_src], [B, C_emb], [B, T_src, C_hidden] -> [B, C_hidden, T_src] + x = self.decoder( + encoder_outputs_ex.transpose(1, 2), + mel_mask, + speaker_embedding=speaker_embedding, + encoding=encoding, + ) + x = self.to_mel(x) + outputs = { + "model_outputs": x, + "alignments": alignments, + # "pitch": pitch_emb_pred, + "durations": duration_pred, + "pitch": pitch_pred, + "energy": energy_pred, + "spk_emb": speaker_embedding, + } + return outputs diff --git a/viXTTS/TTS/tts/layers/delightful_tts/conformer.py b/viXTTS/TTS/tts/layers/delightful_tts/conformer.py new file mode 100644 index 0000000000000000000000000000000000000000..b2175b3b965c6b100846e87d405a753dc272c9e7 --- /dev/null +++ b/viXTTS/TTS/tts/layers/delightful_tts/conformer.py @@ -0,0 +1,450 @@ +### credit: https://github.com/dunky11/voicesmith +import math +from typing import Tuple + +import torch +import torch.nn as nn # pylint: disable=consider-using-from-import +import torch.nn.functional as F + +from TTS.tts.layers.delightful_tts.conv_layers import Conv1dGLU, DepthWiseConv1d, PointwiseConv1d +from TTS.tts.layers.delightful_tts.networks import GLUActivation + + +def calc_same_padding(kernel_size: int) -> Tuple[int, int]: + pad = kernel_size // 2 + return (pad, pad - (kernel_size + 1) % 2) + + +class Conformer(nn.Module): + def __init__( + self, + dim: int, + n_layers: int, + n_heads: int, + speaker_embedding_dim: int, + p_dropout: float, + kernel_size_conv_mod: int, + lrelu_slope: float, + ): + """ + A Transformer variant that integrates both CNNs and Transformers components. + Conformer proposes a novel combination of self-attention and convolution, in which self-attention + learns the global interaction while the convolutions efficiently capture the local correlations. + + Args: + dim (int): Number of the dimensions for the model. + n_layers (int): Number of model layers. + n_heads (int): The number of attention heads. 
+ speaker_embedding_dim (int): Number of speaker embedding dimensions. + p_dropout (float): Probabilty of dropout. + kernel_size_conv_mod (int): Size of kernels for convolution modules. + + Inputs: inputs, mask + - **inputs** (batch, time, dim): Tensor containing input vector + - **encoding** (batch, time, dim): Positional embedding tensor + - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked + Returns: + - **outputs** (batch, time, dim): Tensor produced by Conformer Encoder. + """ + super().__init__() + d_k = d_v = dim // n_heads + self.layer_stack = nn.ModuleList( + [ + ConformerBlock( + dim, + n_heads, + d_k, + d_v, + kernel_size_conv_mod=kernel_size_conv_mod, + dropout=p_dropout, + speaker_embedding_dim=speaker_embedding_dim, + lrelu_slope=lrelu_slope, + ) + for _ in range(n_layers) + ] + ) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + speaker_embedding: torch.Tensor, + encoding: torch.Tensor, + ) -> torch.Tensor: + """ + Shapes: + - x: :math:`[B, T_src, C]` + - mask: :math: `[B]` + - speaker_embedding: :math: `[B, C]` + - encoding: :math: `[B, T_max2, C]` + """ + + attn_mask = mask.view((mask.shape[0], 1, 1, mask.shape[1])) + for enc_layer in self.layer_stack: + x = enc_layer( + x, + mask=mask, + slf_attn_mask=attn_mask, + speaker_embedding=speaker_embedding, + encoding=encoding, + ) + return x + + +class ConformerBlock(torch.nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + d_k: int, # pylint: disable=unused-argument + d_v: int, # pylint: disable=unused-argument + kernel_size_conv_mod: int, + speaker_embedding_dim: int, + dropout: float, + lrelu_slope: float = 0.3, + ): + """ + A Conformer block is composed of four modules stacked together, + A feed-forward module, a self-attention module, a convolution module, + and a second feed-forward module in the end. The block starts with two Feed forward + modules sandwiching the Multi-Headed Self-Attention module and the Conv module. + + Args: + d_model (int): The dimension of model + n_head (int): The number of attention heads. + kernel_size_conv_mod (int): Size of kernels for convolution modules. + speaker_embedding_dim (int): Number of speaker embedding dimensions. + emotion_embedding_dim (int): Number of emotion embedding dimensions. + dropout (float): Probabilty of dropout. + + Inputs: inputs, mask + - **inputs** (batch, time, dim): Tensor containing input vector + - **encoding** (batch, time, dim): Positional embedding tensor + - **slf_attn_mask** (batch, 1, 1, time1): Tensor containing indices to be masked in self attention module + - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked + Returns: + - **outputs** (batch, time, dim): Tensor produced by the Conformer Block. 
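+
+        Example (illustrative sketch; all values below are arbitrary, not defaults of this module):
+            block = ConformerBlock(d_model=256, n_head=4, d_k=64, d_v=64, kernel_size_conv_mod=7,
+                                   speaker_embedding_dim=128, dropout=0.1)
+            x = torch.randn(2, 50, 256)                  # (batch, time, dim)
+            mask = torch.zeros(2, 50, dtype=torch.bool)  # no positions masked
+            slf_attn_mask = mask.view(2, 1, 1, 50)
+            spk = torch.randn(2, 128)
+            encoding = positional_encoding(256, 50, device=x.device)  # helper from networks.py
+            out = block(x, speaker_embedding=spk, mask=mask, slf_attn_mask=slf_attn_mask, encoding=encoding)
+            # out: (2, 50, 256)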
+ """ + super().__init__() + if isinstance(speaker_embedding_dim, int): + self.conditioning = Conv1dGLU( + d_model=d_model, + kernel_size=kernel_size_conv_mod, + padding=kernel_size_conv_mod // 2, + embedding_dim=speaker_embedding_dim, + ) + + self.ff = FeedForward(d_model=d_model, dropout=dropout, kernel_size=3, lrelu_slope=lrelu_slope) + self.conformer_conv_1 = ConformerConvModule( + d_model, kernel_size=kernel_size_conv_mod, dropout=dropout, lrelu_slope=lrelu_slope + ) + self.ln = nn.LayerNorm(d_model) + self.slf_attn = ConformerMultiHeadedSelfAttention(d_model=d_model, num_heads=n_head, dropout_p=dropout) + self.conformer_conv_2 = ConformerConvModule( + d_model, kernel_size=kernel_size_conv_mod, dropout=dropout, lrelu_slope=lrelu_slope + ) + + def forward( + self, + x: torch.Tensor, + speaker_embedding: torch.Tensor, + mask: torch.Tensor, + slf_attn_mask: torch.Tensor, + encoding: torch.Tensor, + ) -> torch.Tensor: + """ + Shapes: + - x: :math:`[B, T_src, C]` + - mask: :math: `[B]` + - slf_attn_mask: :math: `[B, 1, 1, T_src]` + - speaker_embedding: :math: `[B, C]` + - emotion_embedding: :math: `[B, C]` + - encoding: :math: `[B, T_max2, C]` + """ + if speaker_embedding is not None: + x = self.conditioning(x, embeddings=speaker_embedding) + x = self.ff(x) + x + x = self.conformer_conv_1(x) + x + res = x + x = self.ln(x) + x, _ = self.slf_attn(query=x, key=x, value=x, mask=slf_attn_mask, encoding=encoding) + x = x + res + x = x.masked_fill(mask.unsqueeze(-1), 0) + + x = self.conformer_conv_2(x) + x + return x + + +class FeedForward(nn.Module): + def __init__( + self, + d_model: int, + kernel_size: int, + dropout: float, + lrelu_slope: float, + expansion_factor: int = 4, + ): + """ + Feed Forward module for conformer block. + + Args: + d_model (int): The dimension of model. + kernel_size (int): Size of the kernels for conv layers. + dropout (float): probability of dropout. + expansion_factor (int): The factor by which to project the number of channels. + lrelu_slope (int): the negative slope factor for the leaky relu activation. + + Inputs: inputs + - **inputs** (batch, time, dim): Tensor containing input vector + Returns: + - **outputs** (batch, time, dim): Tensor produced by the feed forward module. + """ + super().__init__() + self.dropout = nn.Dropout(dropout) + self.ln = nn.LayerNorm(d_model) + self.conv_1 = nn.Conv1d( + d_model, + d_model * expansion_factor, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + self.act = nn.LeakyReLU(lrelu_slope) + self.conv_2 = nn.Conv1d(d_model * expansion_factor, d_model, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Shapes: + x: :math: `[B, T, C]` + """ + x = self.ln(x) + x = x.permute((0, 2, 1)) + x = self.conv_1(x) + x = x.permute((0, 2, 1)) + x = self.act(x) + x = self.dropout(x) + x = x.permute((0, 2, 1)) + x = self.conv_2(x) + x = x.permute((0, 2, 1)) + x = self.dropout(x) + x = 0.5 * x + return x + + +class ConformerConvModule(nn.Module): + def __init__( + self, + d_model: int, + expansion_factor: int = 2, + kernel_size: int = 7, + dropout: float = 0.1, + lrelu_slope: float = 0.3, + ): + """ + Convolution module for conformer. Starts with a gating machanism. + a pointwise convolution and a gated linear unit (GLU). This is followed + by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution + to help with training. it also contains an expansion factor to project the number of channels. + + Args: + d_model (int): The dimension of model. 
+ expansion_factor (int): The factor by which to project the number of channels. + kernel_size (int): Size of kernels for convolution modules. + dropout (float): Probabilty of dropout. + lrelu_slope (float): The slope coefficient for leaky relu activation. + + Inputs: inputs + - **inputs** (batch, time, dim): Tensor containing input vector + Returns: + - **outputs** (batch, time, dim): Tensor produced by the conv module. + + """ + super().__init__() + inner_dim = d_model * expansion_factor + self.ln_1 = nn.LayerNorm(d_model) + self.conv_1 = PointwiseConv1d(d_model, inner_dim * 2) + self.conv_act = GLUActivation(slope=lrelu_slope) + self.depthwise = DepthWiseConv1d( + inner_dim, + inner_dim, + kernel_size=kernel_size, + padding=calc_same_padding(kernel_size)[0], + ) + self.ln_2 = nn.GroupNorm(1, inner_dim) + self.activation = nn.LeakyReLU(lrelu_slope) + self.conv_2 = PointwiseConv1d(inner_dim, d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Shapes: + x: :math: `[B, T, C]` + """ + x = self.ln_1(x) + x = x.permute(0, 2, 1) + x = self.conv_1(x) + x = self.conv_act(x) + x = self.depthwise(x) + x = self.ln_2(x) + x = self.activation(x) + x = self.conv_2(x) + x = x.permute(0, 2, 1) + x = self.dropout(x) + return x + + +class ConformerMultiHeadedSelfAttention(nn.Module): + """ + Conformer employ multi-headed self-attention (MHSA) while integrating an important technique from Transformer-XL, + the relative sinusoidal positional encoding scheme. The relative positional encoding allows the self-attention + module to generalize better on different input length and the resulting encoder is more robust to the variance of + the utterance length. Conformer use prenorm residual units with dropout which helps training + and regularizing deeper models. + Args: + d_model (int): The dimension of model + num_heads (int): The number of attention heads. + dropout_p (float): probability of dropout + Inputs: inputs, mask + - **inputs** (batch, time, dim): Tensor containing input vector + - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked + Returns: + - **outputs** (batch, time, dim): Tensor produces by relative multi headed self attention module. + """ + + def __init__(self, d_model: int, num_heads: int, dropout_p: float): + super().__init__() + self.attention = RelativeMultiHeadAttention(d_model=d_model, num_heads=num_heads) + self.dropout = nn.Dropout(p=dropout_p) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor, + encoding: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size, seq_length, _ = key.size() # pylint: disable=unused-variable + encoding = encoding[:, : key.shape[1]] + encoding = encoding.repeat(batch_size, 1, 1) + outputs, attn = self.attention(query, key, value, pos_embedding=encoding, mask=mask) + outputs = self.dropout(outputs) + return outputs, attn + + +class RelativeMultiHeadAttention(nn.Module): + """ + Multi-head attention with relative positional encoding. + This concept was proposed in the "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Args: + d_model (int): The dimension of model + num_heads (int): The number of attention heads. 
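+    Example (illustrative; shapes follow the Inputs/Returns description below):
+        attention = RelativeMultiHeadAttention(d_model=256, num_heads=4)
+        q = k = v = torch.randn(2, 50, 256)
+        pos = torch.randn(2, 50, 256)
+        mask = torch.zeros(2, 1, 1, 50, dtype=torch.bool)
+        out, attn = attention(q, k, v, pos_embedding=pos, mask=mask)
+        # out: (2, 50, 256), attn: (2, 4, 50, 50)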
+ Inputs: query, key, value, pos_embedding, mask + - **query** (batch, time, dim): Tensor containing query vector + - **key** (batch, time, dim): Tensor containing key vector + - **value** (batch, time, dim): Tensor containing value vector + - **pos_embedding** (batch, time, dim): Positional embedding tensor + - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked + Returns: + - **outputs**: Tensor produces by relative multi head attention module. + """ + + def __init__( + self, + d_model: int = 512, + num_heads: int = 16, + ): + super().__init__() + assert d_model % num_heads == 0, "d_model % num_heads should be zero." + self.d_model = d_model + self.d_head = int(d_model / num_heads) + self.num_heads = num_heads + self.sqrt_dim = math.sqrt(d_model) + + self.query_proj = nn.Linear(d_model, d_model) + self.key_proj = nn.Linear(d_model, d_model, bias=False) + self.value_proj = nn.Linear(d_model, d_model, bias=False) + self.pos_proj = nn.Linear(d_model, d_model, bias=False) + + self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head)) + self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head)) + torch.nn.init.xavier_uniform_(self.u_bias) + torch.nn.init.xavier_uniform_(self.v_bias) + self.out_proj = nn.Linear(d_model, d_model) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + pos_embedding: torch.Tensor, + mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = query.shape[0] + query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head) + key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3) + value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3) + pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head) + u_bias = self.u_bias.expand_as(query) + v_bias = self.v_bias.expand_as(query) + a = (query + u_bias).transpose(1, 2) + content_score = a @ key.transpose(2, 3) + b = (query + v_bias).transpose(1, 2) + pos_score = b @ pos_embedding.permute(0, 2, 3, 1) + pos_score = self._relative_shift(pos_score) + + score = content_score + pos_score + score = score * (1.0 / self.sqrt_dim) + + score.masked_fill_(mask, -1e9) + + attn = F.softmax(score, -1) + + context = (attn @ value).transpose(1, 2) + context = context.contiguous().view(batch_size, -1, self.d_model) + + return self.out_proj(context), attn + + def _relative_shift(self, pos_score: torch.Tensor) -> torch.Tensor: # pylint: disable=no-self-use + batch_size, num_heads, seq_length1, seq_length2 = pos_score.size() + zeros = torch.zeros((batch_size, num_heads, seq_length1, 1), device=pos_score.device) + padded_pos_score = torch.cat([zeros, pos_score], dim=-1) + padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1) + pos_score = padded_pos_score[:, :, 1:].view_as(pos_score) + return pos_score + + +class MultiHeadAttention(nn.Module): + """ + input: + query --- [N, T_q, query_dim] + key --- [N, T_k, key_dim] + output: + out --- [N, T_q, num_units] + """ + + def __init__(self, query_dim: int, key_dim: int, num_units: int, num_heads: int): + super().__init__() + self.num_units = num_units + self.num_heads = num_heads + self.key_dim = key_dim + + self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False) + self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) + self.W_value = 
nn.Linear(in_features=key_dim, out_features=num_units, bias=False) + + def forward(self, query: torch.Tensor, key: torch.Tensor) -> torch.Tensor: + querys = self.W_query(query) # [N, T_q, num_units] + keys = self.W_key(key) # [N, T_k, num_units] + values = self.W_value(key) + split_size = self.num_units // self.num_heads + querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0) # [h, N, T_q, num_units/h] + keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] + values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] + # score = softmax(QK^T / (d_k ** 0.5)) + scores = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k] + scores = scores / (self.key_dim**0.5) + scores = F.softmax(scores, dim=3) + # out = score * V + out = torch.matmul(scores, values) # [h, N, T_q, num_units/h] + out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units] + return out diff --git a/viXTTS/TTS/tts/layers/delightful_tts/conv_layers.py b/viXTTS/TTS/tts/layers/delightful_tts/conv_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..fb9aa4495fd9b25fb88ce0dd1493cfa1f2c47d5c --- /dev/null +++ b/viXTTS/TTS/tts/layers/delightful_tts/conv_layers.py @@ -0,0 +1,671 @@ +from typing import Tuple + +import torch +import torch.nn as nn # pylint: disable=consider-using-from-import +import torch.nn.functional as F +from torch.nn.utils import parametrize + +from TTS.tts.layers.delightful_tts.kernel_predictor import KernelPredictor + + +def calc_same_padding(kernel_size: int) -> Tuple[int, int]: + pad = kernel_size // 2 + return (pad, pad - (kernel_size + 1) % 2) + + +class ConvNorm(nn.Module): + """A 1-dimensional convolutional layer with optional weight normalization. + + This layer wraps a 1D convolutional layer from PyTorch and applies + optional weight normalization. The layer can be used in a similar way to + the convolutional layers in PyTorch's `torch.nn` module. + + Args: + in_channels (int): The number of channels in the input signal. + out_channels (int): The number of channels in the output signal. + kernel_size (int, optional): The size of the convolving kernel. + Defaults to 1. + stride (int, optional): The stride of the convolution. Defaults to 1. + padding (int, optional): Zero-padding added to both sides of the input. + If `None`, the padding will be calculated so that the output has + the same length as the input. Defaults to `None`. + dilation (int, optional): Spacing between kernel elements. Defaults to 1. + bias (bool, optional): If `True`, add bias after convolution. Defaults to `True`. + w_init_gain (str, optional): The weight initialization function to use. + Can be either 'linear' or 'relu'. Defaults to 'linear'. + use_weight_norm (bool, optional): If `True`, apply weight normalization + to the convolutional weights. Defaults to `False`. + + Shapes: + - Input: :math:`[N, D, T]` + + - Output: :math:`[N, out_dim, T]` where `out_dim` is the number of output dimensions. 
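+
+    Example (illustrative):
+        conv = ConvNorm(80, 256, kernel_size=3)  # padding derived automatically to preserve T
+        x = torch.randn(4, 80, 100)              # [N, D, T]
+        y = conv(x)                              # [N, 256, T]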
+ + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=None, + dilation=1, + bias=True, + w_init_gain="linear", + use_weight_norm=False, + ): + super(ConvNorm, self).__init__() # pylint: disable=super-with-arguments + if padding is None: + assert kernel_size % 2 == 1 + padding = int(dilation * (kernel_size - 1) / 2) + self.kernel_size = kernel_size + self.dilation = dilation + self.use_weight_norm = use_weight_norm + conv_fn = nn.Conv1d + self.conv = conv_fn( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) + nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain(w_init_gain)) + if self.use_weight_norm: + self.conv = nn.utils.parametrizations.weight_norm(self.conv) + + def forward(self, signal, mask=None): + conv_signal = self.conv(signal) + if mask is not None: + # always re-zero output if mask is + # available to match zero-padding + conv_signal = conv_signal * mask + return conv_signal + + +class ConvLSTMLinear(nn.Module): + def __init__( + self, + in_dim, + out_dim, + n_layers=2, + n_channels=256, + kernel_size=3, + p_dropout=0.1, + lstm_type="bilstm", + use_linear=True, + ): + super(ConvLSTMLinear, self).__init__() # pylint: disable=super-with-arguments + self.out_dim = out_dim + self.lstm_type = lstm_type + self.use_linear = use_linear + self.dropout = nn.Dropout(p=p_dropout) + + convolutions = [] + for i in range(n_layers): + conv_layer = ConvNorm( + in_dim if i == 0 else n_channels, + n_channels, + kernel_size=kernel_size, + stride=1, + padding=int((kernel_size - 1) / 2), + dilation=1, + w_init_gain="relu", + ) + conv_layer = nn.utils.parametrizations.weight_norm(conv_layer.conv, name="weight") + convolutions.append(conv_layer) + + self.convolutions = nn.ModuleList(convolutions) + + if not self.use_linear: + n_channels = out_dim + + if self.lstm_type != "": + use_bilstm = False + lstm_channels = n_channels + if self.lstm_type == "bilstm": + use_bilstm = True + lstm_channels = int(n_channels // 2) + + self.bilstm = nn.LSTM(n_channels, lstm_channels, 1, batch_first=True, bidirectional=use_bilstm) + lstm_norm_fn_pntr = nn.utils.spectral_norm + self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0") + if self.lstm_type == "bilstm": + self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0_reverse") + + if self.use_linear: + self.dense = nn.Linear(n_channels, out_dim) + + def run_padded_sequence(self, context, lens): + context_embedded = [] + for b_ind in range(context.size()[0]): # TODO: speed up + curr_context = context[b_ind : b_ind + 1, :, : lens[b_ind]].clone() + for conv in self.convolutions: + curr_context = self.dropout(F.relu(conv(curr_context))) + context_embedded.append(curr_context[0].transpose(0, 1)) + context = nn.utils.rnn.pad_sequence(context_embedded, batch_first=True) + return context + + def run_unsorted_inputs(self, fn, context, lens): # pylint: disable=no-self-use + lens_sorted, ids_sorted = torch.sort(lens, descending=True) + unsort_ids = [0] * lens.size(0) + for i in range(len(ids_sorted)): # pylint: disable=consider-using-enumerate + unsort_ids[ids_sorted[i]] = i + lens_sorted = lens_sorted.long().cpu() + + context = context[ids_sorted] + context = nn.utils.rnn.pack_padded_sequence(context, lens_sorted, batch_first=True) + context = fn(context)[0] + context = nn.utils.rnn.pad_packed_sequence(context, batch_first=True)[0] + + # map back to original indices + context = context[unsort_ids] + return context + + def 
forward(self, context, lens): + if context.size()[0] > 1: + context = self.run_padded_sequence(context, lens) + # to B, D, T + context = context.transpose(1, 2) + else: + for conv in self.convolutions: + context = self.dropout(F.relu(conv(context))) + + if self.lstm_type != "": + context = context.transpose(1, 2) + self.bilstm.flatten_parameters() + if lens is not None: + context = self.run_unsorted_inputs(self.bilstm, context, lens) + else: + context = self.bilstm(context)[0] + context = context.transpose(1, 2) + + x_hat = context + if self.use_linear: + x_hat = self.dense(context.transpose(1, 2)).transpose(1, 2) + + return x_hat + + +class DepthWiseConv1d(nn.Module): + def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: int): + super().__init__() + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, groups=in_channels) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + +class PointwiseConv1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int = 1, + padding: int = 0, + bias: bool = True, + ): + super().__init__() + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=stride, + padding=padding, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + +class BSConv1d(nn.Module): + """https://arxiv.org/pdf/2003.13549.pdf""" + + def __init__(self, channels_in: int, channels_out: int, kernel_size: int, padding: int): + super().__init__() + self.pointwise = nn.Conv1d(channels_in, channels_out, kernel_size=1) + self.depthwise = nn.Conv1d( + channels_out, + channels_out, + kernel_size=kernel_size, + padding=padding, + groups=channels_out, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1 = self.pointwise(x) + x2 = self.depthwise(x1) + return x2 + + +class BSConv2d(nn.Module): + """https://arxiv.org/pdf/2003.13549.pdf""" + + def __init__(self, channels_in: int, channels_out: int, kernel_size: int, padding: int): + super().__init__() + self.pointwise = nn.Conv2d(channels_in, channels_out, kernel_size=1) + self.depthwise = nn.Conv2d( + channels_out, + channels_out, + kernel_size=kernel_size, + padding=padding, + groups=channels_out, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1 = self.pointwise(x) + x2 = self.depthwise(x1) + return x2 + + +class Conv1dGLU(nn.Module): + """From DeepVoice 3""" + + def __init__(self, d_model: int, kernel_size: int, padding: int, embedding_dim: int): + super().__init__() + self.conv = BSConv1d(d_model, 2 * d_model, kernel_size=kernel_size, padding=padding) + self.embedding_proj = nn.Linear(embedding_dim, d_model) + self.register_buffer("sqrt", torch.sqrt(torch.FloatTensor([0.5])).squeeze(0)) + self.softsign = torch.nn.Softsign() + + def forward(self, x: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor: + x = x.permute((0, 2, 1)) + residual = x + x = self.conv(x) + splitdim = 1 + a, b = x.split(x.size(splitdim) // 2, dim=splitdim) + embeddings = self.embedding_proj(embeddings).unsqueeze(2) + softsign = self.softsign(embeddings) + softsign = softsign.expand_as(a) + a = a + softsign + x = a * torch.sigmoid(b) + x = x + residual + x = x * self.sqrt + x = x.permute((0, 2, 1)) + return x + + +class ConvTransposed(nn.Module): + """ + A 1D convolutional transposed layer for PyTorch. 
+ This layer applies a 1D convolutional transpose operation to its input tensor, + where the number of channels of the input tensor is the same as the number of channels of the output tensor. + + Attributes: + in_channels (int): The number of channels in the input tensor. + out_channels (int): The number of channels in the output tensor. + kernel_size (int): The size of the convolutional kernel. Default: 1. + padding (int): The number of padding elements to add to the input tensor. Default: 0. + conv (BSConv1d): The 1D convolutional transpose layer. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 1, + padding: int = 0, + ): + super().__init__() + self.conv = BSConv1d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=padding, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.contiguous().transpose(1, 2) + x = self.conv(x) + x = x.contiguous().transpose(1, 2) + return x + + +class DepthwiseConvModule(nn.Module): + def __init__(self, dim: int, kernel_size: int = 7, expansion: int = 4, lrelu_slope: float = 0.3): + super().__init__() + padding = calc_same_padding(kernel_size) + self.depthwise = nn.Conv1d( + dim, + dim * expansion, + kernel_size=kernel_size, + padding=padding[0], + groups=dim, + ) + self.act = nn.LeakyReLU(lrelu_slope) + self.out = nn.Conv1d(dim * expansion, dim, 1, 1, 0) + self.ln = nn.LayerNorm(dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.ln(x) + x = x.permute((0, 2, 1)) + x = self.depthwise(x) + x = self.act(x) + x = self.out(x) + x = x.permute((0, 2, 1)) + return x + + +class AddCoords(nn.Module): + def __init__(self, rank: int, with_r: bool = False): + super().__init__() + self.rank = rank + self.with_r = with_r + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.rank == 1: + batch_size_shape, channel_in_shape, dim_x = x.shape # pylint: disable=unused-variable + xx_range = torch.arange(dim_x, dtype=torch.int32) + xx_channel = xx_range[None, None, :] + + xx_channel = xx_channel.float() / (dim_x - 1) + xx_channel = xx_channel * 2 - 1 + xx_channel = xx_channel.repeat(batch_size_shape, 1, 1) + + xx_channel = xx_channel.to(x.device) + out = torch.cat([x, xx_channel], dim=1) + + if self.with_r: + rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2)) + out = torch.cat([out, rr], dim=1) + + elif self.rank == 2: + batch_size_shape, channel_in_shape, dim_y, dim_x = x.shape + xx_ones = torch.ones([1, 1, 1, dim_x], dtype=torch.int32) + yy_ones = torch.ones([1, 1, 1, dim_y], dtype=torch.int32) + + xx_range = torch.arange(dim_y, dtype=torch.int32) + yy_range = torch.arange(dim_x, dtype=torch.int32) + xx_range = xx_range[None, None, :, None] + yy_range = yy_range[None, None, :, None] + + xx_channel = torch.matmul(xx_range, xx_ones) + yy_channel = torch.matmul(yy_range, yy_ones) + + # transpose y + yy_channel = yy_channel.permute(0, 1, 3, 2) + + xx_channel = xx_channel.float() / (dim_y - 1) + yy_channel = yy_channel.float() / (dim_x - 1) + + xx_channel = xx_channel * 2 - 1 + yy_channel = yy_channel * 2 - 1 + + xx_channel = xx_channel.repeat(batch_size_shape, 1, 1, 1) + yy_channel = yy_channel.repeat(batch_size_shape, 1, 1, 1) + + xx_channel = xx_channel.to(x.device) + yy_channel = yy_channel.to(x.device) + + out = torch.cat([x, xx_channel, yy_channel], dim=1) + + if self.with_r: + rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2)) + out = torch.cat([out, rr], dim=1) + + elif self.rank == 3: + batch_size_shape, channel_in_shape, dim_z, dim_y, 
dim_x = x.shape + xx_ones = torch.ones([1, 1, 1, 1, dim_x], dtype=torch.int32) + yy_ones = torch.ones([1, 1, 1, 1, dim_y], dtype=torch.int32) + zz_ones = torch.ones([1, 1, 1, 1, dim_z], dtype=torch.int32) + + xy_range = torch.arange(dim_y, dtype=torch.int32) + xy_range = xy_range[None, None, None, :, None] + + yz_range = torch.arange(dim_z, dtype=torch.int32) + yz_range = yz_range[None, None, None, :, None] + + zx_range = torch.arange(dim_x, dtype=torch.int32) + zx_range = zx_range[None, None, None, :, None] + + xy_channel = torch.matmul(xy_range, xx_ones) + xx_channel = torch.cat([xy_channel + i for i in range(dim_z)], dim=2) + + yz_channel = torch.matmul(yz_range, yy_ones) + yz_channel = yz_channel.permute(0, 1, 3, 4, 2) + yy_channel = torch.cat([yz_channel + i for i in range(dim_x)], dim=4) + + zx_channel = torch.matmul(zx_range, zz_ones) + zx_channel = zx_channel.permute(0, 1, 4, 2, 3) + zz_channel = torch.cat([zx_channel + i for i in range(dim_y)], dim=3) + + xx_channel = xx_channel.to(x.device) + yy_channel = yy_channel.to(x.device) + zz_channel = zz_channel.to(x.device) + out = torch.cat([x, xx_channel, yy_channel, zz_channel], dim=1) + + if self.with_r: + rr = torch.sqrt( + torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2) + torch.pow(zz_channel - 0.5, 2) + ) + out = torch.cat([out, rr], dim=1) + else: + raise NotImplementedError + + return out + + +class CoordConv1d(nn.modules.conv.Conv1d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + with_r: bool = False, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + self.rank = 1 + self.addcoords = AddCoords(self.rank, with_r) + self.conv = nn.Conv1d( + in_channels + self.rank + int(with_r), + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.addcoords(x) + x = self.conv(x) + return x + + +class CoordConv2d(nn.modules.conv.Conv2d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + with_r: bool = False, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + self.rank = 2 + self.addcoords = AddCoords(self.rank, with_r) + self.conv = nn.Conv2d( + in_channels + self.rank + int(with_r), + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.addcoords(x) + x = self.conv(x) + return x + + +class LVCBlock(torch.nn.Module): + """the location-variable convolutions""" + + def __init__( # pylint: disable=dangerous-default-value + self, + in_channels, + cond_channels, + stride, + dilations=[1, 3, 9, 27], + lReLU_slope=0.2, + conv_kernel_size=3, + cond_hop_length=256, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + ): + super().__init__() + + self.cond_hop_length = cond_hop_length + self.conv_layers = len(dilations) + self.conv_kernel_size = conv_kernel_size + + self.kernel_predictor = KernelPredictor( + cond_channels=cond_channels, + conv_in_channels=in_channels, + conv_out_channels=2 * in_channels, + conv_layers=len(dilations), + conv_kernel_size=conv_kernel_size, + 
kpnet_hidden_channels=kpnet_hidden_channels, + kpnet_conv_size=kpnet_conv_size, + kpnet_dropout=kpnet_dropout, + kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope}, + ) + + self.convt_pre = nn.Sequential( + nn.LeakyReLU(lReLU_slope), + nn.utils.parametrizations.weight_norm( + nn.ConvTranspose1d( + in_channels, + in_channels, + 2 * stride, + stride=stride, + padding=stride // 2 + stride % 2, + output_padding=stride % 2, + ) + ), + ) + + self.conv_blocks = nn.ModuleList() + for dilation in dilations: + self.conv_blocks.append( + nn.Sequential( + nn.LeakyReLU(lReLU_slope), + nn.utils.parametrizations.weight_norm( + nn.Conv1d( + in_channels, + in_channels, + conv_kernel_size, + padding=dilation * (conv_kernel_size - 1) // 2, + dilation=dilation, + ) + ), + nn.LeakyReLU(lReLU_slope), + ) + ) + + def forward(self, x, c): + """forward propagation of the location-variable convolutions. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length) + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + + Returns: + Tensor: the output sequence (batch, in_channels, in_length) + """ + _, in_channels, _ = x.shape # (B, c_g, L') + + x = self.convt_pre(x) # (B, c_g, stride * L') + kernels, bias = self.kernel_predictor(c) + + for i, conv in enumerate(self.conv_blocks): + output = conv(x) # (B, c_g, stride * L') + + k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length) + b = bias[:, i, :, :] # (B, 2 * c_g, cond_length) + + output = self.location_variable_convolution( + output, k, b, hop_size=self.cond_hop_length + ) # (B, 2 * c_g, stride * L'): LVC + x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh( + output[:, in_channels:, :] + ) # (B, c_g, stride * L'): GAU + + return x + + def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256): # pylint: disable=no-self-use + """perform location-variable convolution operation on the input sequence (x) using the local convolution kernl. + Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length). + kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length) + bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length) + dilation (int): the dilation of convolution. + hop_size (int): the hop_size of the conditioning sequence. + Returns: + (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length). 
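+        Example (illustrative shapes): with ``hop_size=256`` an input ``x`` of shape
+            ``(B, 64, 2560)`` implies ``kernel_length = 10``; a ``kernel`` of shape
+            ``(B, 64, 128, 3, 10)`` with ``bias`` of shape ``(B, 128, 10)`` then produces an
+            output of shape ``(B, 128, 2560)``.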
+ """ + batch, _, in_length = x.shape + batch, _, out_channels, kernel_size, kernel_length = kernel.shape + assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched" + + padding = dilation * int((kernel_size - 1) / 2) + x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding) + x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding) + + if hop_size < dilation: + x = F.pad(x, (0, dilation), "constant", 0) + x = x.unfold( + 3, dilation, dilation + ) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation) + x = x[:, :, :, :, :hop_size] + x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation) + x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size) + + o = torch.einsum("bildsk,biokl->bolsd", x, kernel) + o = o.to(memory_format=torch.channels_last_3d) + bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d) + o = o + bias + o = o.contiguous().view(batch, out_channels, -1) + + return o + + def remove_weight_norm(self): + self.kernel_predictor.remove_weight_norm() + parametrize.remove_parametrizations(self.convt_pre[1], "weight") + for block in self.conv_blocks: + parametrize.remove_parametrizations(block[1], "weight") diff --git a/viXTTS/TTS/tts/layers/delightful_tts/encoders.py b/viXTTS/TTS/tts/layers/delightful_tts/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..0878f0677a29d092597a46e8a3b11e4a521769b8 --- /dev/null +++ b/viXTTS/TTS/tts/layers/delightful_tts/encoders.py @@ -0,0 +1,261 @@ +from typing import List, Tuple, Union + +import torch +import torch.nn as nn # pylint: disable=consider-using-from-import +import torch.nn.functional as F + +from TTS.tts.layers.delightful_tts.conformer import ConformerMultiHeadedSelfAttention +from TTS.tts.layers.delightful_tts.conv_layers import CoordConv1d +from TTS.tts.layers.delightful_tts.networks import STL + + +def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor: + batch_size = lengths.shape[0] + max_len = torch.max(lengths).item() + ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1) + mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) + return mask + + +def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor: + return torch.ceil(lens / stride).int() + + +class ReferenceEncoder(nn.Module): + """ + Referance encoder for utterance and phoneme prosody encoders. Reference encoder + made up of convolution and RNN layers. + + Args: + num_mels (int): Number of mel frames to produce. + ref_enc_filters (list[int]): List of channel sizes for encoder layers. + ref_enc_size (int): Size of the kernel for the conv layers. + ref_enc_strides (List[int]): List of strides to use for conv layers. + ref_enc_gru_size (int): Number of hidden features for the gated recurrent unit. + + Inputs: inputs, mask + - **inputs** (batch, dim, time): Tensor containing mel vector + - **lengths** (batch): Tensor containing the mel lengths. + Returns: + - **outputs** (batch, time, dim): Tensor produced by Reference Encoder. 
+ """ + + def __init__( + self, + num_mels: int, + ref_enc_filters: List[Union[int, int, int, int, int, int]], + ref_enc_size: int, + ref_enc_strides: List[Union[int, int, int, int, int]], + ref_enc_gru_size: int, + ): + super().__init__() + + n_mel_channels = num_mels + self.n_mel_channels = n_mel_channels + K = len(ref_enc_filters) + filters = [self.n_mel_channels] + ref_enc_filters + strides = [1] + ref_enc_strides + # Use CoordConv at the first layer to better preserve positional information: https://arxiv.org/pdf/1811.02122.pdf + convs = [ + CoordConv1d( + in_channels=filters[0], + out_channels=filters[0 + 1], + kernel_size=ref_enc_size, + stride=strides[0], + padding=ref_enc_size // 2, + with_r=True, + ) + ] + convs2 = [ + nn.Conv1d( + in_channels=filters[i], + out_channels=filters[i + 1], + kernel_size=ref_enc_size, + stride=strides[i], + padding=ref_enc_size // 2, + ) + for i in range(1, K) + ] + convs.extend(convs2) + self.convs = nn.ModuleList(convs) + + self.norms = nn.ModuleList([nn.InstanceNorm1d(num_features=ref_enc_filters[i], affine=True) for i in range(K)]) + + self.gru = nn.GRU( + input_size=ref_enc_filters[-1], + hidden_size=ref_enc_gru_size, + batch_first=True, + ) + + def forward(self, x: torch.Tensor, mel_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + inputs --- [N, n_mels, timesteps] + outputs --- [N, E//2] + """ + + mel_masks = get_mask_from_lengths(mel_lens).unsqueeze(1) + x = x.masked_fill(mel_masks, 0) + for conv, norm in zip(self.convs, self.norms): + x = conv(x) + x = F.leaky_relu(x, 0.3) # [N, 128, Ty//2^K, n_mels//2^K] + x = norm(x) + + for _ in range(2): + mel_lens = stride_lens(mel_lens) + + mel_masks = get_mask_from_lengths(mel_lens) + + x = x.masked_fill(mel_masks.unsqueeze(1), 0) + x = x.permute((0, 2, 1)) + x = torch.nn.utils.rnn.pack_padded_sequence(x, mel_lens.cpu().int(), batch_first=True, enforce_sorted=False) + + self.gru.flatten_parameters() + x, memory = self.gru(x) # memory --- [N, Ty, E//2], out --- [1, N, E//2] + x, _ = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True) + + return x, memory, mel_masks + + def calculate_channels( # pylint: disable=no-self-use + self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int + ) -> int: + for _ in range(n_convs): + L = (L - kernel_size + 2 * pad) // stride + 1 + return L + + +class UtteranceLevelProsodyEncoder(nn.Module): + def __init__( + self, + num_mels: int, + ref_enc_filters: List[Union[int, int, int, int, int, int]], + ref_enc_size: int, + ref_enc_strides: List[Union[int, int, int, int, int]], + ref_enc_gru_size: int, + dropout: float, + n_hidden: int, + bottleneck_size_u: int, + token_num: int, + ): + """ + Encoder to extract prosody from utterance. it is made up of a reference encoder + with a couple of linear layers and style token layer with dropout. + + Args: + num_mels (int): Number of mel frames to produce. + ref_enc_filters (list[int]): List of channel sizes for ref encoder layers. + ref_enc_size (int): Size of the kernel for the ref encoder conv layers. + ref_enc_strides (List[int]): List of strides to use for teh ref encoder conv layers. + ref_enc_gru_size (int): Number of hidden features for the gated recurrent unit. + dropout (float): Probability of dropout. + n_hidden (int): Size of hidden layers. + bottleneck_size_u (int): Size of the bottle neck layer. + + Inputs: inputs, mask + - **inputs** (batch, dim, time): Tensor containing mel vector + - **lengths** (batch): Tensor containing the mel lengths. 
+ Returns: + - **outputs** (batch, 1, dim): Tensor produced by Utterance Level Prosody Encoder. + """ + super().__init__() + + self.E = n_hidden + self.d_q = self.d_k = n_hidden + bottleneck_size = bottleneck_size_u + + self.encoder = ReferenceEncoder( + ref_enc_filters=ref_enc_filters, + ref_enc_gru_size=ref_enc_gru_size, + ref_enc_size=ref_enc_size, + ref_enc_strides=ref_enc_strides, + num_mels=num_mels, + ) + self.encoder_prj = nn.Linear(ref_enc_gru_size, self.E // 2) + self.stl = STL(n_hidden=n_hidden, token_num=token_num) + self.encoder_bottleneck = nn.Linear(self.E, bottleneck_size) + self.dropout = nn.Dropout(dropout) + + def forward(self, mels: torch.Tensor, mel_lens: torch.Tensor) -> torch.Tensor: + """ + Shapes: + mels: :math: `[B, C, T]` + mel_lens: :math: `[B]` + + out --- [N, seq_len, E] + """ + _, embedded_prosody, _ = self.encoder(mels, mel_lens) + + # Bottleneck + embedded_prosody = self.encoder_prj(embedded_prosody) + + # Style Token + out = self.encoder_bottleneck(self.stl(embedded_prosody)) + out = self.dropout(out) + + out = out.view((-1, 1, out.shape[3])) + return out + + +class PhonemeLevelProsodyEncoder(nn.Module): + def __init__( + self, + num_mels: int, + ref_enc_filters: List[Union[int, int, int, int, int, int]], + ref_enc_size: int, + ref_enc_strides: List[Union[int, int, int, int, int]], + ref_enc_gru_size: int, + dropout: float, + n_hidden: int, + n_heads: int, + bottleneck_size_p: int, + ): + super().__init__() + + self.E = n_hidden + self.d_q = self.d_k = n_hidden + bottleneck_size = bottleneck_size_p + + self.encoder = ReferenceEncoder( + ref_enc_filters=ref_enc_filters, + ref_enc_gru_size=ref_enc_gru_size, + ref_enc_size=ref_enc_size, + ref_enc_strides=ref_enc_strides, + num_mels=num_mels, + ) + self.encoder_prj = nn.Linear(ref_enc_gru_size, n_hidden) + self.attention = ConformerMultiHeadedSelfAttention( + d_model=n_hidden, + num_heads=n_heads, + dropout_p=dropout, + ) + self.encoder_bottleneck = nn.Linear(n_hidden, bottleneck_size) + + def forward( + self, + x: torch.Tensor, + src_mask: torch.Tensor, + mels: torch.Tensor, + mel_lens: torch.Tensor, + encoding: torch.Tensor, + ) -> torch.Tensor: + """ + x --- [N, seq_len, encoder_embedding_dim] + mels --- [N, Ty/r, n_mels*r], r=1 + out --- [N, seq_len, bottleneck_size] + attn --- [N, seq_len, ref_len], Ty/r = ref_len + """ + embedded_prosody, _, mel_masks = self.encoder(mels, mel_lens) + + # Bottleneck + embedded_prosody = self.encoder_prj(embedded_prosody) + + attn_mask = mel_masks.view((mel_masks.shape[0], 1, 1, -1)) + x, _ = self.attention( + query=x, + key=embedded_prosody, + value=embedded_prosody, + mask=attn_mask, + encoding=encoding, + ) + x = self.encoder_bottleneck(x) + x = x.masked_fill(src_mask.unsqueeze(-1), 0.0) + return x diff --git a/viXTTS/TTS/tts/layers/delightful_tts/energy_adaptor.py b/viXTTS/TTS/tts/layers/delightful_tts/energy_adaptor.py new file mode 100644 index 0000000000000000000000000000000000000000..ea0d1e47214d81a42b934bbaaa4b3ebb9f63bcc6 --- /dev/null +++ b/viXTTS/TTS/tts/layers/delightful_tts/energy_adaptor.py @@ -0,0 +1,82 @@ +from typing import Callable, Tuple + +import torch +import torch.nn as nn # pylint: disable=consider-using-from-import + +from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor +from TTS.tts.utils.helpers import average_over_durations + + +class EnergyAdaptor(nn.Module): # pylint: disable=abstract-method + """Variance Adaptor with an added 1D conv layer. Used to + get energy embeddings. 
+ + Args: + channels_in (int): Number of in channels for conv layers. + channels_out (int): Number of out channels. + kernel_size (int): Size the kernel for the conv layers. + dropout (float): Probability of dropout. + lrelu_slope (float): Slope for the leaky relu. + emb_kernel_size (int): Size the kernel for the pitch embedding. + + Inputs: inputs, mask + - **inputs** (batch, time1, dim): Tensor containing input vector + - **target** (batch, 1, time2): Tensor containing the energy target + - **dr** (batch, time1): Tensor containing aligner durations vector + - **mask** (batch, time1): Tensor containing indices to be masked + Returns: + - **energy prediction** (batch, 1, time1): Tensor produced by energy predictor + - **energy embedding** (batch, channels, time1): Tensor produced energy adaptor + - **average energy target(train only)** (batch, 1, time1): Tensor produced after averaging over durations + + """ + + def __init__( + self, + channels_in: int, + channels_hidden: int, + channels_out: int, + kernel_size: int, + dropout: float, + lrelu_slope: float, + emb_kernel_size: int, + ): + super().__init__() + self.energy_predictor = VariancePredictor( + channels_in=channels_in, + channels=channels_hidden, + channels_out=channels_out, + kernel_size=kernel_size, + p_dropout=dropout, + lrelu_slope=lrelu_slope, + ) + self.energy_emb = nn.Conv1d( + 1, + channels_hidden, + kernel_size=emb_kernel_size, + padding=int((emb_kernel_size - 1) / 2), + ) + + def get_energy_embedding_train( + self, x: torch.Tensor, target: torch.Tensor, dr: torch.IntTensor, mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Shapes: + x: :math: `[B, T_src, C]` + target: :math: `[B, 1, T_max2]` + dr: :math: `[B, T_src]` + mask: :math: `[B, T_src]` + """ + energy_pred = self.energy_predictor(x, mask) + energy_pred.unsqueeze_(1) + avg_energy_target = average_over_durations(target, dr) + energy_emb = self.energy_emb(avg_energy_target) + return energy_pred, avg_energy_target, energy_emb + + def get_energy_embedding(self, x: torch.Tensor, mask: torch.Tensor, energy_transform: Callable) -> torch.Tensor: + energy_pred = self.energy_predictor(x, mask) + energy_pred.unsqueeze_(1) + if energy_transform is not None: + energy_pred = energy_transform(energy_pred, (~mask).sum(dim=(1, 2)), self.pitch_mean, self.pitch_std) + energy_emb_pred = self.energy_emb(energy_pred) + return energy_emb_pred, energy_pred diff --git a/viXTTS/TTS/tts/layers/delightful_tts/kernel_predictor.py b/viXTTS/TTS/tts/layers/delightful_tts/kernel_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..96c550b6c2609c52762bb3eaca373e31f0599bf8 --- /dev/null +++ b/viXTTS/TTS/tts/layers/delightful_tts/kernel_predictor.py @@ -0,0 +1,128 @@ +import torch.nn as nn # pylint: disable=consider-using-from-import +from torch.nn.utils import parametrize + + +class KernelPredictor(nn.Module): + """Kernel predictor for the location-variable convolutions + + Args: + cond_channels (int): number of channel for the conditioning sequence, + conv_in_channels (int): number of channel for the input sequence, + conv_out_channels (int): number of channel for the output sequence, + conv_layers (int): number of layers + + """ + + def __init__( # pylint: disable=dangerous-default-value + self, + cond_channels, + conv_in_channels, + conv_out_channels, + conv_layers, + conv_kernel_size=3, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + kpnet_nonlinear_activation="LeakyReLU", + 
kpnet_nonlinear_activation_params={"negative_slope": 0.1}, + ): + super().__init__() + + self.conv_in_channels = conv_in_channels + self.conv_out_channels = conv_out_channels + self.conv_kernel_size = conv_kernel_size + self.conv_layers = conv_layers + + kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w + kpnet_bias_channels = conv_out_channels * conv_layers # l_b + + self.input_conv = nn.Sequential( + nn.utils.parametrizations.weight_norm( + nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True) + ), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + + self.residual_convs = nn.ModuleList() + padding = (kpnet_conv_size - 1) // 2 + for _ in range(3): + self.residual_convs.append( + nn.Sequential( + nn.Dropout(kpnet_dropout), + nn.utils.parametrizations.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_hidden_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + nn.utils.parametrizations.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_hidden_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + ) + self.kernel_conv = nn.utils.parametrizations.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_kernel_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ) + self.bias_conv = nn.utils.parametrizations.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_bias_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ) + + def forward(self, c): + """ + Args: + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + """ + batch, _, cond_length = c.shape + c = self.input_conv(c) + for residual_conv in self.residual_convs: + residual_conv.to(c.device) + c = c + residual_conv(c) + k = self.kernel_conv(c) + b = self.bias_conv(c) + kernels = k.contiguous().view( + batch, + self.conv_layers, + self.conv_in_channels, + self.conv_out_channels, + self.conv_kernel_size, + cond_length, + ) + bias = b.contiguous().view( + batch, + self.conv_layers, + self.conv_out_channels, + cond_length, + ) + + return kernels, bias + + def remove_weight_norm(self): + parametrize.remove_parametrizations(self.input_conv[0], "weight") + parametrize.remove_parametrizations(self.kernel_conv, "weight") + parametrize.remove_parametrizations(self.bias_conv, "weight") + for block in self.residual_convs: + parametrize.remove_parametrizations(block[1], "weight") + parametrize.remove_parametrizations(block[3], "weight") diff --git a/viXTTS/TTS/tts/layers/delightful_tts/networks.py b/viXTTS/TTS/tts/layers/delightful_tts/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..4305022f18cf95565b2da2553740276818fb486c --- /dev/null +++ b/viXTTS/TTS/tts/layers/delightful_tts/networks.py @@ -0,0 +1,219 @@ +import math +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn # pylint: disable=consider-using-from-import +import torch.nn.functional as F + +from TTS.tts.layers.delightful_tts.conv_layers import ConvNorm + + +def initialize_embeddings(shape: Tuple[int]) -> torch.Tensor: + assert len(shape) == 2, "Can only initialize 2-D embedding matrices ..." 
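+    # Descriptive note: the line below draws from a standard normal and scales by
+    # sqrt(2 / shape[1]), i.e. He/Kaiming-style scaling with the embedding dimension
+    # (shape[1]) treated as the fan-in.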
+ # Kaiming initialization + return torch.randn(shape) * np.sqrt(2 / shape[1]) + + +def positional_encoding(d_model: int, length: int, device: torch.device) -> torch.Tensor: + pe = torch.zeros(length, d_model, device=device) + position = torch.arange(0, length, dtype=torch.float, device=device).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2, device=device).float() * -(math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + return pe + + +class BottleneckLayer(nn.Module): + """ + Bottleneck layer for reducing the dimensionality of a tensor. + + Args: + in_dim: The number of input dimensions. + reduction_factor: The factor by which to reduce the number of dimensions. + norm: The normalization method to use. Can be "weightnorm" or "instancenorm". + non_linearity: The non-linearity to use. Can be "relu" or "leakyrelu". + kernel_size: The size of the convolutional kernel. + use_partial_padding: Whether to use partial padding with the convolutional kernel. + + Shape: + - Input: :math:`[N, in_dim]` where `N` is the batch size and `in_dim` is the number of input dimensions. + + - Output: :math:`[N, out_dim]` where `out_dim` is the number of output dimensions. + """ + + def __init__( + self, + in_dim, + reduction_factor, + norm="weightnorm", + non_linearity="relu", + kernel_size=3, + use_partial_padding=False, # pylint: disable=unused-argument + ): + super(BottleneckLayer, self).__init__() # pylint: disable=super-with-arguments + + self.reduction_factor = reduction_factor + reduced_dim = int(in_dim / reduction_factor) + self.out_dim = reduced_dim + if self.reduction_factor > 1: + fn = ConvNorm(in_dim, reduced_dim, kernel_size=kernel_size, use_weight_norm=(norm == "weightnorm")) + if norm == "instancenorm": + fn = nn.Sequential(fn, nn.InstanceNorm1d(reduced_dim, affine=True)) + + self.projection_fn = fn + self.non_linearity = nn.ReLU() + if non_linearity == "leakyrelu": + self.non_linearity = nn.LeakyReLU() + + def forward(self, x): + if self.reduction_factor > 1: + x = self.projection_fn(x) + x = self.non_linearity(x) + return x + + +class GLUActivation(nn.Module): + """Class that implements the Gated Linear Unit (GLU) activation function. + + The GLU activation function is a variant of the Leaky ReLU activation function, + where the output of the activation function is gated by an input tensor. 
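+
+    Example (an illustrative sketch added for documentation; it assumes ``torch`` is in
+    scope and the tensor sizes are arbitrary):
+
+        >>> glu = GLUActivation(slope=0.3)
+        >>> x = torch.randn(2, 8, 100)   # channel dim must be even; it is split into value/gate halves
+        >>> glu(x).shape
+        torch.Size([2, 4, 100])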
+ + """ + + def __init__(self, slope: float): + super().__init__() + self.lrelu = nn.LeakyReLU(slope) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out, gate = x.chunk(2, dim=1) + x = out * self.lrelu(gate) + return x + + +class StyleEmbedAttention(nn.Module): + def __init__(self, query_dim: int, key_dim: int, num_units: int, num_heads: int): + super().__init__() + self.num_units = num_units + self.num_heads = num_heads + self.key_dim = key_dim + + self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False) + self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) + self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) + + def forward(self, query: torch.Tensor, key_soft: torch.Tensor) -> torch.Tensor: + values = self.W_value(key_soft) + split_size = self.num_units // self.num_heads + values = torch.stack(torch.split(values, split_size, dim=2), dim=0) + + out_soft = scores_soft = None + querys = self.W_query(query) # [N, T_q, num_units] + keys = self.W_key(key_soft) # [N, T_k, num_units] + + # [h, N, T_q, num_units/h] + querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0) + # [h, N, T_k, num_units/h] + keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) + # [h, N, T_k, num_units/h] + + # score = softmax(QK^T / (d_k ** 0.5)) + scores_soft = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k] + scores_soft = scores_soft / (self.key_dim**0.5) + scores_soft = F.softmax(scores_soft, dim=3) + + # out = score * V + # [h, N, T_q, num_units/h] + out_soft = torch.matmul(scores_soft, values) + out_soft = torch.cat(torch.split(out_soft, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units] + + return out_soft # , scores_soft + + +class EmbeddingPadded(nn.Module): + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int): + super().__init__() + padding_mult = torch.ones((num_embeddings, 1), dtype=torch.int64) + padding_mult[padding_idx] = 0 + self.register_buffer("padding_mult", padding_mult) + self.embeddings = nn.parameter.Parameter(initialize_embeddings((num_embeddings, embedding_dim))) + + def forward(self, idx: torch.Tensor) -> torch.Tensor: + embeddings_zeroed = self.embeddings * self.padding_mult + x = F.embedding(idx, embeddings_zeroed) + return x + + +class EmbeddingProjBlock(nn.Module): + def __init__(self, embedding_dim: int): + super().__init__() + self.layers = nn.ModuleList( + [ + nn.Linear(embedding_dim, embedding_dim), + nn.LeakyReLU(0.3), + nn.Linear(embedding_dim, embedding_dim), + nn.LeakyReLU(0.3), + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + res = x + for layer in self.layers: + x = layer(x) + x = x + res + return x + + +class LinearNorm(nn.Module): + def __init__(self, in_features: int, out_features: int, bias: bool = False): + super().__init__() + self.linear = nn.Linear(in_features, out_features, bias) + + nn.init.xavier_uniform_(self.linear.weight) + if bias: + nn.init.constant_(self.linear.bias, 0.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear(x) + return x + + +class STL(nn.Module): + """ + A PyTorch module for the Style Token Layer (STL) as described in + "A Style-Based Generator Architecture for Generative Adversarial Networks" + (https://arxiv.org/abs/1812.04948) + + The STL applies a multi-headed attention mechanism over the learned style tokens, + using the text input as the query and the style tokens as the keys and values. 
+ The output of the attention mechanism is used as the text's style embedding. + + Args: + token_num (int): The number of style tokens. + n_hidden (int): Number of hidden dimensions. + """ + + def __init__(self, n_hidden: int, token_num: int): + super(STL, self).__init__() # pylint: disable=super-with-arguments + + num_heads = 1 + E = n_hidden + self.token_num = token_num + self.embed = nn.Parameter(torch.FloatTensor(self.token_num, E // num_heads)) + d_q = E // 2 + d_k = E // num_heads + self.attention = StyleEmbedAttention(query_dim=d_q, key_dim=d_k, num_units=E, num_heads=num_heads) + + torch.nn.init.normal_(self.embed, mean=0, std=0.5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + N = x.size(0) + query = x.unsqueeze(1) # [N, 1, E//2] + + keys_soft = torch.tanh(self.embed).unsqueeze(0).expand(N, -1, -1) # [N, token_num, E // num_heads] + + # Weighted sum + emotion_embed_soft = self.attention(query, keys_soft) + + return emotion_embed_soft diff --git a/viXTTS/TTS/tts/layers/delightful_tts/phoneme_prosody_predictor.py b/viXTTS/TTS/tts/layers/delightful_tts/phoneme_prosody_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..28418f7163361120914f277446f76ac9f0363254 --- /dev/null +++ b/viXTTS/TTS/tts/layers/delightful_tts/phoneme_prosody_predictor.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn # pylint: disable=consider-using-from-import + +from TTS.tts.layers.delightful_tts.conv_layers import ConvTransposed + + +class PhonemeProsodyPredictor(nn.Module): + """Non-parallel Prosody Predictor inspired by: https://arxiv.org/pdf/2102.00851.pdf + It consists of 2 layers of 1D convolutions each followed by a relu activation, layer norm + and dropout, then finally a linear layer. + + Args: + hidden_size (int): Size of hidden channels. + kernel_size (int): Kernel size for the conv layers. + dropout: (float): Probability of dropout. + bottleneck_size (int): bottleneck size for last linear layer. + lrelu_slope (float): Slope of the leaky relu. 
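+
+    Example (an illustrative sketch, assuming ``torch`` is in scope; sizes are arbitrary):
+
+        >>> predictor = PhonemeProsodyPredictor(
+        ...     hidden_size=384, kernel_size=5, dropout=0.1, bottleneck_size=16, lrelu_slope=0.3
+        ... )
+        >>> x = torch.randn(2, 50, 384)        # [B, T, D]
+        >>> mask = torch.zeros(2, 50).bool()   # True marks positions to be zeroed out
+        >>> predictor(x, mask).shape
+        torch.Size([2, 50, 16])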
+ """ + + def __init__( + self, + hidden_size: int, + kernel_size: int, + dropout: float, + bottleneck_size: int, + lrelu_slope: float, + ): + super().__init__() + self.d_model = hidden_size + self.layers = nn.ModuleList( + [ + ConvTransposed( + self.d_model, + self.d_model, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ), + nn.LeakyReLU(lrelu_slope), + nn.LayerNorm(self.d_model), + nn.Dropout(dropout), + ConvTransposed( + self.d_model, + self.d_model, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ), + nn.LeakyReLU(lrelu_slope), + nn.LayerNorm(self.d_model), + nn.Dropout(dropout), + ] + ) + self.predictor_bottleneck = nn.Linear(self.d_model, bottleneck_size) + + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Shapes: + x: :math: `[B, T, D]` + mask: :math: `[B, T]` + """ + mask = mask.unsqueeze(2) + for layer in self.layers: + x = layer(x) + x = x.masked_fill(mask, 0.0) + x = self.predictor_bottleneck(x) + return x diff --git a/viXTTS/TTS/tts/layers/delightful_tts/pitch_adaptor.py b/viXTTS/TTS/tts/layers/delightful_tts/pitch_adaptor.py new file mode 100644 index 0000000000000000000000000000000000000000..9031369e0f019cf115d0d43b288bb97d9db48467 --- /dev/null +++ b/viXTTS/TTS/tts/layers/delightful_tts/pitch_adaptor.py @@ -0,0 +1,88 @@ +from typing import Callable, Tuple + +import torch +import torch.nn as nn # pylint: disable=consider-using-from-import + +from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor +from TTS.tts.utils.helpers import average_over_durations + + +class PitchAdaptor(nn.Module): # pylint: disable=abstract-method + """Module to get pitch embeddings via pitch predictor + + Args: + n_input (int): Number of pitch predictor input channels. + n_hidden (int): Number of pitch predictor hidden channels. + n_out (int): Number of pitch predictor out channels. + kernel size (int): Size of the kernel for conv layers. + emb_kernel_size (int): Size the kernel for the pitch embedding. + p_dropout (float): Probability of dropout. + lrelu_slope (float): Slope for the leaky relu. 
+ + Inputs: inputs, mask + - **inputs** (batch, time1, dim): Tensor containing input vector + - **target** (batch, 1, time2): Tensor containing the pitch target + - **dr** (batch, time1): Tensor containing aligner durations vector + - **mask** (batch, time1): Tensor containing indices to be masked + Returns: + - **pitch prediction** (batch, 1, time1): Tensor produced by pitch predictor + - **pitch embedding** (batch, channels, time1): Tensor produced pitch pitch adaptor + - **average pitch target(train only)** (batch, 1, time1): Tensor produced after averaging over durations + """ + + def __init__( + self, + n_input: int, + n_hidden: int, + n_out: int, + kernel_size: int, + emb_kernel_size: int, + p_dropout: float, + lrelu_slope: float, + ): + super().__init__() + self.pitch_predictor = VariancePredictor( + channels_in=n_input, + channels=n_hidden, + channels_out=n_out, + kernel_size=kernel_size, + p_dropout=p_dropout, + lrelu_slope=lrelu_slope, + ) + self.pitch_emb = nn.Conv1d( + 1, + n_input, + kernel_size=emb_kernel_size, + padding=int((emb_kernel_size - 1) / 2), + ) + + def get_pitch_embedding_train( + self, x: torch.Tensor, target: torch.Tensor, dr: torch.IntTensor, mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Shapes: + x: :math: `[B, T_src, C]` + target: :math: `[B, 1, T_max2]` + dr: :math: `[B, T_src]` + mask: :math: `[B, T_src]` + """ + pitch_pred = self.pitch_predictor(x, mask) # [B, T_src, C_hidden], [B, T_src] --> [B, T_src] + pitch_pred.unsqueeze_(1) # --> [B, 1, T_src] + avg_pitch_target = average_over_durations(target, dr) # [B, 1, T_mel], [B, T_src] --> [B, 1, T_src] + pitch_emb = self.pitch_emb(avg_pitch_target) # [B, 1, T_src] --> [B, C_hidden, T_src] + return pitch_pred, avg_pitch_target, pitch_emb + + def get_pitch_embedding( + self, + x: torch.Tensor, + mask: torch.Tensor, + pitch_transform: Callable, + pitch_mean: torch.Tensor, + pitch_std: torch.Tensor, + ) -> torch.Tensor: + pitch_pred = self.pitch_predictor(x, mask) + if pitch_transform is not None: + pitch_pred = pitch_transform(pitch_pred, (~mask).sum(), pitch_mean, pitch_std) + pitch_pred.unsqueeze_(1) + pitch_emb_pred = self.pitch_emb(pitch_pred) + return pitch_emb_pred, pitch_pred diff --git a/viXTTS/TTS/tts/layers/delightful_tts/variance_predictor.py b/viXTTS/TTS/tts/layers/delightful_tts/variance_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..68303a1bd1148089eab7ee8be12d4f37ddf420e1 --- /dev/null +++ b/viXTTS/TTS/tts/layers/delightful_tts/variance_predictor.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn # pylint: disable=consider-using-from-import + +from TTS.tts.layers.delightful_tts.conv_layers import ConvTransposed + + +class VariancePredictor(nn.Module): + """ + Network is 2-layer 1D convolutions with leaky relu activation and then + followed by layer normalization then a dropout layer and finally an + extra linear layer to project the hidden states into the output sequence. + + Args: + channels_in (int): Number of in channels for conv layers. + channels_out (int): Number of out channels for the last linear layer. + kernel_size (int): Size the kernel for the conv layers. + p_dropout (float): Probability of dropout. + lrelu_slope (float): Slope for the leaky relu. + + Inputs: inputs, mask + - **inputs** (batch, time, dim): Tensor containing input vector + - **mask** (batch, time): Tensor containing indices to be masked + Returns: + - **outputs** (batch, time): Tensor produced by last linear layer. 
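+
+    Example (an illustrative sketch, assuming ``torch`` is in scope; sizes are arbitrary):
+
+        >>> vp = VariancePredictor(
+        ...     channels_in=384, channels=256, channels_out=1,
+        ...     kernel_size=5, p_dropout=0.1, lrelu_slope=0.3,
+        ... )
+        >>> x = torch.randn(2, 50, 384)        # [B, T_src, C]
+        >>> mask = torch.zeros(2, 50).bool()   # True marks padded positions
+        >>> vp(x, mask).shape
+        torch.Size([2, 50])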
+ """ + + def __init__( + self, channels_in: int, channels: int, channels_out: int, kernel_size: int, p_dropout: float, lrelu_slope: float + ): + super().__init__() + + self.layers = nn.ModuleList( + [ + ConvTransposed( + channels_in, + channels, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ), + nn.LeakyReLU(lrelu_slope), + nn.LayerNorm(channels), + nn.Dropout(p_dropout), + ConvTransposed( + channels, + channels, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ), + nn.LeakyReLU(lrelu_slope), + nn.LayerNorm(channels), + nn.Dropout(p_dropout), + ] + ) + + self.linear_layer = nn.Linear(channels, channels_out) + + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Shapes: + x: :math: `[B, T_src, C]` + mask: :math: `[B, T_src]` + """ + for layer in self.layers: + x = layer(x) + x = self.linear_layer(x) + x = x.squeeze(-1) + x = x.masked_fill(mask, 0.0) + return x diff --git a/viXTTS/TTS/tts/layers/feed_forward/__init__.py b/viXTTS/TTS/tts/layers/feed_forward/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/tts/layers/feed_forward/decoder.py b/viXTTS/TTS/tts/layers/feed_forward/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0376e2e3926e65254c3a81d085d48c97df033958 --- /dev/null +++ b/viXTTS/TTS/tts/layers/feed_forward/decoder.py @@ -0,0 +1,228 @@ +import torch +from torch import nn + +from TTS.tts.layers.generic.res_conv_bn import Conv1dBN, Conv1dBNBlock, ResidualConv1dBNBlock +from TTS.tts.layers.generic.transformer import FFTransformerBlock +from TTS.tts.layers.generic.wavenet import WNBlocks +from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer + + +class WaveNetDecoder(nn.Module): + """WaveNet based decoder with a prenet and a postnet. + + prenet: conv1d_1x1 + postnet: 3 x [conv1d_1x1 -> relu] -> conv1d_1x1 + + TODO: Integrate speaker conditioning vector. + + Note: + default wavenet parameters; + params = { + "num_blocks": 12, + "hidden_channels":192, + "kernel_size": 5, + "dilation_rate": 1, + "num_layers": 4, + "dropout_p": 0.05 + } + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + hidden_channels (int): number of hidden channels for prenet and postnet. + params (dict): dictionary for residual convolutional blocks. + """ + + def __init__(self, in_channels, out_channels, hidden_channels, c_in_channels, params): + super().__init__() + # prenet + self.prenet = torch.nn.Conv1d(in_channels, params["hidden_channels"], 1) + # wavenet layers + self.wn = WNBlocks(params["hidden_channels"], c_in_channels=c_in_channels, **params) + # postnet + self.postnet = [ + torch.nn.Conv1d(params["hidden_channels"], hidden_channels, 1), + torch.nn.ReLU(), + torch.nn.Conv1d(hidden_channels, hidden_channels, 1), + torch.nn.ReLU(), + torch.nn.Conv1d(hidden_channels, hidden_channels, 1), + torch.nn.ReLU(), + torch.nn.Conv1d(hidden_channels, out_channels, 1), + ] + self.postnet = nn.Sequential(*self.postnet) + + def forward(self, x, x_mask=None, g=None): + x = self.prenet(x) * x_mask + x = self.wn(x, x_mask, g) + o = self.postnet(x) * x_mask + return o + + +class RelativePositionTransformerDecoder(nn.Module): + """Decoder with Relative Positional Transformer. 
+ + Note: + Default params + params={ + 'hidden_channels_ffn': 128, + 'num_heads': 2, + "kernel_size": 3, + "dropout_p": 0.1, + "num_layers": 8, + "rel_attn_window_size": 4, + "input_length": None + } + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + hidden_channels (int): number of hidden channels including Transformer layers. + params (dict): dictionary for residual convolutional blocks. + """ + + def __init__(self, in_channels, out_channels, hidden_channels, params): + super().__init__() + self.prenet = Conv1dBN(in_channels, hidden_channels, 1, 1) + self.rel_pos_transformer = RelativePositionTransformer(in_channels, out_channels, hidden_channels, **params) + + def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument + o = self.prenet(x) * x_mask + o = self.rel_pos_transformer(o, x_mask) + return o + + +class FFTransformerDecoder(nn.Module): + """Decoder with FeedForwardTransformer. + + Default params + params={ + 'hidden_channels_ffn': 1024, + 'num_heads': 2, + "dropout_p": 0.1, + "num_layers": 6, + } + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + hidden_channels (int): number of hidden channels including Transformer layers. + params (dict): dictionary for residual convolutional blocks. + """ + + def __init__(self, in_channels, out_channels, params): + super().__init__() + self.transformer_block = FFTransformerBlock(in_channels, **params) + self.postnet = nn.Conv1d(in_channels, out_channels, 1) + + def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument + # TODO: handle multi-speaker + x_mask = 1 if x_mask is None else x_mask + o = self.transformer_block(x) * x_mask + o = self.postnet(o) * x_mask + return o + + +class ResidualConv1dBNDecoder(nn.Module): + """Residual Convolutional Decoder as in the original Speedy Speech paper + + TODO: Integrate speaker conditioning vector. + + Note: + Default params + params = { + "kernel_size": 4, + "dilations": 4 * [1, 2, 4, 8] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 17 + } + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + hidden_channels (int): number of hidden channels including ResidualConv1dBNBlock layers. + params (dict): dictionary for residual convolutional blocks. + """ + + def __init__(self, in_channels, out_channels, hidden_channels, params): + super().__init__() + self.res_conv_block = ResidualConv1dBNBlock(in_channels, hidden_channels, hidden_channels, **params) + self.post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1) + self.postnet = nn.Sequential( + Conv1dBNBlock( + hidden_channels, hidden_channels, hidden_channels, params["kernel_size"], 1, num_conv_blocks=2 + ), + nn.Conv1d(hidden_channels, out_channels, 1), + ) + + def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument + o = self.res_conv_block(x, x_mask) + o = self.post_conv(o) + x + return self.postnet(o) * x_mask + + +class Decoder(nn.Module): + """Decodes the expanded phoneme encoding into spectrograms + Args: + out_channels (int): number of output channels. + in_hidden_channels (int): input and hidden channels. Model keeps the input channels for the intermediate layers. + decoder_type (str): decoder layer types. 'transformers' or 'residual_conv_bn'. Default 'residual_conv_bn'. + decoder_params (dict): model parameters for specified decoder type. + c_in_channels (int): number of channels for conditional input. 
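+
+    Example (an illustrative sketch using the default ``residual_conv_bn`` decoder,
+    assuming ``torch`` is in scope; sizes are arbitrary):
+
+        >>> decoder = Decoder(out_channels=80, in_hidden_channels=128)
+        >>> x = torch.randn(2, 128, 50)
+        >>> x_mask = torch.ones(2, 1, 50)
+        >>> decoder(x, x_mask).shape
+        torch.Size([2, 80, 50])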
+ + Shapes: + - input: (B, C, T) + """ + + # pylint: disable=dangerous-default-value + def __init__( + self, + out_channels, + in_hidden_channels, + decoder_type="residual_conv_bn", + decoder_params={ + "kernel_size": 4, + "dilations": 4 * [1, 2, 4, 8] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 17, + }, + c_in_channels=0, + ): + super().__init__() + + if decoder_type.lower() == "relative_position_transformer": + self.decoder = RelativePositionTransformerDecoder( + in_channels=in_hidden_channels, + out_channels=out_channels, + hidden_channels=in_hidden_channels, + params=decoder_params, + ) + elif decoder_type.lower() == "residual_conv_bn": + self.decoder = ResidualConv1dBNDecoder( + in_channels=in_hidden_channels, + out_channels=out_channels, + hidden_channels=in_hidden_channels, + params=decoder_params, + ) + elif decoder_type.lower() == "wavenet": + self.decoder = WaveNetDecoder( + in_channels=in_hidden_channels, + out_channels=out_channels, + hidden_channels=in_hidden_channels, + c_in_channels=c_in_channels, + params=decoder_params, + ) + elif decoder_type.lower() == "fftransformer": + self.decoder = FFTransformerDecoder(in_hidden_channels, out_channels, decoder_params) + else: + raise ValueError(f"[!] Unknown decoder type - {decoder_type}") + + def forward(self, x, x_mask, g=None): # pylint: disable=unused-argument + """ + Args: + x: [B, C, T] + x_mask: [B, 1, T] + g: [B, C_g, 1] + """ + # TODO: implement multi-speaker + o = self.decoder(x, x_mask, g) + return o diff --git a/viXTTS/TTS/tts/layers/feed_forward/duration_predictor.py b/viXTTS/TTS/tts/layers/feed_forward/duration_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..4422648f4337e48aab39671836fcfb5e12ff4be7 --- /dev/null +++ b/viXTTS/TTS/tts/layers/feed_forward/duration_predictor.py @@ -0,0 +1,41 @@ +from torch import nn + +from TTS.tts.layers.generic.res_conv_bn import Conv1dBN + + +class DurationPredictor(nn.Module): + """Speedy Speech duration predictor model. + Predicts phoneme durations from encoder outputs. + + Note: + Outputs interpreted as log(durations) + To get actual durations, do exp transformation + + conv_BN_4x1 -> conv_BN_3x1 -> conv_BN_1x1 -> conv_1x1 + + Args: + hidden_channels (int): number of channels in the inner layers. + """ + + def __init__(self, hidden_channels): + super().__init__() + + self.layers = nn.ModuleList( + [ + Conv1dBN(hidden_channels, hidden_channels, 4, 1), + Conv1dBN(hidden_channels, hidden_channels, 3, 1), + Conv1dBN(hidden_channels, hidden_channels, 1, 1), + nn.Conv1d(hidden_channels, 1, 1), + ] + ) + + def forward(self, x, x_mask): + """ + Shapes: + x: [B, C, T] + x_mask: [B, 1, T] + """ + o = x + for layer in self.layers: + o = layer(o) * x_mask + return o diff --git a/viXTTS/TTS/tts/layers/feed_forward/encoder.py b/viXTTS/TTS/tts/layers/feed_forward/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..caf939ffc73fedac299228e090b2df3bb4cc553c --- /dev/null +++ b/viXTTS/TTS/tts/layers/feed_forward/encoder.py @@ -0,0 +1,162 @@ +from torch import nn + +from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock +from TTS.tts.layers.generic.transformer import FFTransformerBlock +from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer + + +class RelativePositionTransformerEncoder(nn.Module): + """Speedy speech encoder built on Transformer with Relative Position encoding. + + TODO: Integrate speaker conditioning vector. + + Args: + in_channels (int): number of input channels. 
+ out_channels (int): number of output channels. + hidden_channels (int): number of hidden channels + params (dict): dictionary for residual convolutional blocks. + """ + + def __init__(self, in_channels, out_channels, hidden_channels, params): + super().__init__() + self.prenet = ResidualConv1dBNBlock( + in_channels, + hidden_channels, + hidden_channels, + kernel_size=5, + num_res_blocks=3, + num_conv_blocks=1, + dilations=[1, 1, 1], + ) + self.rel_pos_transformer = RelativePositionTransformer(hidden_channels, out_channels, hidden_channels, **params) + + def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument + if x_mask is None: + x_mask = 1 + o = self.prenet(x) * x_mask + o = self.rel_pos_transformer(o, x_mask) + return o + + +class ResidualConv1dBNEncoder(nn.Module): + """Residual Convolutional Encoder as in the original Speedy Speech paper + + TODO: Integrate speaker conditioning vector. + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + hidden_channels (int): number of hidden channels + params (dict): dictionary for residual convolutional blocks. + """ + + def __init__(self, in_channels, out_channels, hidden_channels, params): + super().__init__() + self.prenet = nn.Sequential(nn.Conv1d(in_channels, hidden_channels, 1), nn.ReLU()) + self.res_conv_block = ResidualConv1dBNBlock(hidden_channels, hidden_channels, hidden_channels, **params) + + self.postnet = nn.Sequential( + *[ + nn.Conv1d(hidden_channels, hidden_channels, 1), + nn.ReLU(), + nn.BatchNorm1d(hidden_channels), + nn.Conv1d(hidden_channels, out_channels, 1), + ] + ) + + def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument + if x_mask is None: + x_mask = 1 + o = self.prenet(x) * x_mask + o = self.res_conv_block(o, x_mask) + o = self.postnet(o + x) * x_mask + return o * x_mask + + +class Encoder(nn.Module): + # pylint: disable=dangerous-default-value + """Factory class for Speedy Speech encoder enables different encoder types internally. + + Args: + num_chars (int): number of characters. + out_channels (int): number of output channels. + in_hidden_channels (int): input and hidden channels. Model keeps the input channels for the intermediate layers. + encoder_type (str): encoder layer types. 'transformers' or 'residual_conv_bn'. Default 'residual_conv_bn'. + encoder_params (dict): model parameters for specified encoder type. + c_in_channels (int): number of channels for conditional input. + + Note: + Default encoder_params to be set in config.json... 
+ + ```python + # for 'relative_position_transformer' + encoder_params={ + 'hidden_channels_ffn': 128, + 'num_heads': 2, + "kernel_size": 3, + "dropout_p": 0.1, + "num_layers": 6, + "rel_attn_window_size": 4, + "input_length": None + }, + + # for 'residual_conv_bn' + encoder_params = { + "kernel_size": 4, + "dilations": 4 * [1, 2, 4] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 13 + } + + # for 'fftransformer' + encoder_params = { + "hidden_channels_ffn": 1024 , + "num_heads": 2, + "num_layers": 6, + "dropout_p": 0.1 + } + ``` + """ + + def __init__( + self, + in_hidden_channels, + out_channels, + encoder_type="residual_conv_bn", + encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13}, + c_in_channels=0, + ): + super().__init__() + self.out_channels = out_channels + self.in_channels = in_hidden_channels + self.hidden_channels = in_hidden_channels + self.encoder_type = encoder_type + self.c_in_channels = c_in_channels + + # init encoder + if encoder_type.lower() == "relative_position_transformer": + # text encoder + # pylint: disable=unexpected-keyword-arg + self.encoder = RelativePositionTransformerEncoder( + in_hidden_channels, out_channels, in_hidden_channels, encoder_params + ) + elif encoder_type.lower() == "residual_conv_bn": + self.encoder = ResidualConv1dBNEncoder(in_hidden_channels, out_channels, in_hidden_channels, encoder_params) + elif encoder_type.lower() == "fftransformer": + assert ( + in_hidden_channels == out_channels + ), "[!] must be `in_channels` == `out_channels` when encoder type is 'fftransformer'" + # pylint: disable=unexpected-keyword-arg + self.encoder = FFTransformerBlock(in_hidden_channels, **encoder_params) + else: + raise NotImplementedError(" [!] unknown encoder type.") + + def forward(self, x, x_mask, g=None): # pylint: disable=unused-argument + """ + Shapes: + x: [B, C, T] + x_mask: [B, 1, T] + g: [B, C, 1] + """ + o = self.encoder(x, x_mask) + return o * x_mask diff --git a/viXTTS/TTS/tts/layers/generic/__init__.py b/viXTTS/TTS/tts/layers/generic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/tts/layers/generic/aligner.py b/viXTTS/TTS/tts/layers/generic/aligner.py new file mode 100644 index 0000000000000000000000000000000000000000..baa6f0e9c4879207695b2de1193c9147b5a3fa4b --- /dev/null +++ b/viXTTS/TTS/tts/layers/generic/aligner.py @@ -0,0 +1,92 @@ +from typing import Tuple + +import torch +from torch import nn + + +class AlignmentNetwork(torch.nn.Module): + """Aligner Network for learning alignment between the input text and the model output with Gaussian Attention. + + :: + + query -> conv1d -> relu -> conv1d -> relu -> conv1d -> L2_dist -> softmax -> alignment + key -> conv1d -> relu -> conv1d -----------------------^ + + Args: + in_query_channels (int): Number of channels in the query network. Defaults to 80. + in_key_channels (int): Number of channels in the key network. Defaults to 512. + attn_channels (int): Number of inner channels in the attention layers. Defaults to 80. + temperature (float): Temperature for the softmax. Defaults to 0.0005. 
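+
+    Example (an illustrative sketch, assuming ``torch`` is in scope; lengths are arbitrary):
+
+        >>> aligner = AlignmentNetwork(in_query_channels=80, in_key_channels=512)
+        >>> queries = torch.randn(2, 80, 120)   # e.g. mel frames, [B, C, T_de]
+        >>> keys = torch.randn(2, 512, 40)      # e.g. text encodings, [B, C_emb, T_en]
+        >>> attn, attn_logp = aligner(queries, keys)
+        >>> attn.shape                          # [B, 1, T_de, T_en] for these inputs
+        torch.Size([2, 1, 120, 40])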
+ """ + + def __init__( + self, + in_query_channels=80, + in_key_channels=512, + attn_channels=80, + temperature=0.0005, + ): + super().__init__() + self.temperature = temperature + self.softmax = torch.nn.Softmax(dim=3) + self.log_softmax = torch.nn.LogSoftmax(dim=3) + + self.key_layer = nn.Sequential( + nn.Conv1d( + in_key_channels, + in_key_channels * 2, + kernel_size=3, + padding=1, + bias=True, + ), + torch.nn.ReLU(), + nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True), + ) + + self.query_layer = nn.Sequential( + nn.Conv1d( + in_query_channels, + in_query_channels * 2, + kernel_size=3, + padding=1, + bias=True, + ), + torch.nn.ReLU(), + nn.Conv1d(in_query_channels * 2, in_query_channels, kernel_size=1, padding=0, bias=True), + torch.nn.ReLU(), + nn.Conv1d(in_query_channels, attn_channels, kernel_size=1, padding=0, bias=True), + ) + + self.init_layers() + + def init_layers(self): + torch.nn.init.xavier_uniform_(self.key_layer[0].weight, gain=torch.nn.init.calculate_gain("relu")) + torch.nn.init.xavier_uniform_(self.key_layer[2].weight, gain=torch.nn.init.calculate_gain("linear")) + torch.nn.init.xavier_uniform_(self.query_layer[0].weight, gain=torch.nn.init.calculate_gain("relu")) + torch.nn.init.xavier_uniform_(self.query_layer[2].weight, gain=torch.nn.init.calculate_gain("linear")) + torch.nn.init.xavier_uniform_(self.query_layer[4].weight, gain=torch.nn.init.calculate_gain("linear")) + + def forward( + self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None + ) -> Tuple[torch.tensor, torch.tensor]: + """Forward pass of the aligner encoder. + Shapes: + - queries: :math:`[B, C, T_de]` + - keys: :math:`[B, C_emb, T_en]` + - mask: :math:`[B, T_de]` + Output: + attn (torch.tensor): :math:`[B, 1, T_en, T_de]` soft attention mask. + attn_logp (torch.tensor): :math:`[ßB, 1, T_en , T_de]` log probabilities. + """ + key_out = self.key_layer(keys) + query_out = self.query_layer(queries) + attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2 + attn_logp = -self.temperature * attn_factor.sum(1, keepdim=True) + if attn_prior is not None: + attn_logp = self.log_softmax(attn_logp) + torch.log(attn_prior[:, None] + 1e-8) + + if mask is not None: + attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf")) + + attn = self.softmax(attn_logp) + return attn, attn_logp diff --git a/viXTTS/TTS/tts/layers/generic/gated_conv.py b/viXTTS/TTS/tts/layers/generic/gated_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..9a29c4499f970db538a4b99c3c05cba22576195f --- /dev/null +++ b/viXTTS/TTS/tts/layers/generic/gated_conv.py @@ -0,0 +1,37 @@ +from torch import nn + +from .normalization import LayerNorm + + +class GatedConvBlock(nn.Module): + """Gated convolutional block as in https://arxiv.org/pdf/1612.08083.pdf + Args: + in_out_channels (int): number of input/output channels. + kernel_size (int): convolution kernel size. + dropout_p (float): dropout rate. 
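+        num_layers (int): number of gated convolutional layers.
+
+    Example (an illustrative sketch, assuming ``torch`` is in scope; sizes are arbitrary):
+
+        >>> block = GatedConvBlock(in_out_channels=192, kernel_size=5, dropout_p=0.1, num_layers=3)
+        >>> x = torch.randn(2, 192, 80)
+        >>> x_mask = torch.ones(2, 1, 80)
+        >>> block(x, x_mask).shape
+        torch.Size([2, 192, 80])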
+ """ + + def __init__(self, in_out_channels, kernel_size, dropout_p, num_layers): + super().__init__() + # class arguments + self.dropout_p = dropout_p + self.num_layers = num_layers + # define layers + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.layers = nn.ModuleList() + for _ in range(num_layers): + self.conv_layers += [nn.Conv1d(in_out_channels, 2 * in_out_channels, kernel_size, padding=kernel_size // 2)] + self.norm_layers += [LayerNorm(2 * in_out_channels)] + + def forward(self, x, x_mask): + o = x + res = x + for idx in range(self.num_layers): + o = nn.functional.dropout(o, p=self.dropout_p, training=self.training) + o = self.conv_layers[idx](o * x_mask) + o = self.norm_layers[idx](o) + o = nn.functional.glu(o, dim=1) + o = res + o + res = o + return o diff --git a/viXTTS/TTS/tts/layers/generic/normalization.py b/viXTTS/TTS/tts/layers/generic/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..c0270e405e4246e47b7bc0787e4cd4b069533f92 --- /dev/null +++ b/viXTTS/TTS/tts/layers/generic/normalization.py @@ -0,0 +1,123 @@ +import torch +from torch import nn + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-4): + """Layer norm for the 2nd dimension of the input. + Args: + channels (int): number of channels (2nd dimension) of the input. + eps (float): to prevent 0 division + + Shapes: + - input: (B, C, T) + - output: (B, C, T) + """ + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(1, channels, 1) * 0.1) + self.beta = nn.Parameter(torch.zeros(1, channels, 1)) + + def forward(self, x): + mean = torch.mean(x, 1, keepdim=True) + variance = torch.mean((x - mean) ** 2, 1, keepdim=True) + x = (x - mean) * torch.rsqrt(variance + self.eps) + x = x * self.gamma + self.beta + return x + + +class LayerNorm2(nn.Module): + """Layer norm for the 2nd dimension of the input using torch primitive. + Args: + channels (int): number of channels (2nd dimension) of the input. + eps (float): to prevent 0 division + + Shapes: + - input: (B, C, T) + - output: (B, C, T) + """ + + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = torch.nn.functional.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class TemporalBatchNorm1d(nn.BatchNorm1d): + """Normalize each channel separately over time and batch.""" + + def __init__(self, channels, affine=True, track_running_stats=True, momentum=0.1): + super().__init__(channels, affine=affine, track_running_stats=track_running_stats, momentum=momentum) + + def forward(self, x): + return super().forward(x.transpose(2, 1)).transpose(2, 1) + + +class ActNorm(nn.Module): + """Activation Normalization bijector as an alternative to Batch Norm. It computes + mean and std from a sample data in advance and it uses these values + for normalization at training. + + Args: + channels (int): input channels. + ddi (False): data depended initialization flag. 
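+            If ``True``, ``bias`` and ``logs`` are filled from the statistics of the
+            first batch seen in ``forward`` instead of staying at their zero init.
+
+    Example (an illustrative sketch, assuming ``torch`` is in scope):
+
+        >>> actnorm = ActNorm(channels=80)
+        >>> x = torch.randn(2, 80, 50)
+        >>> z, logdet = actnorm(x)                 # normalize
+        >>> x_back, _ = actnorm(z, reverse=True)   # invert
+        >>> torch.allclose(x, x_back)
+        True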
+ + Shapes: + - inputs: (B, C, T) + - outputs: (B, C, T) + """ + + def __init__(self, channels, ddi=False, **kwargs): # pylint: disable=unused-argument + super().__init__() + self.channels = channels + self.initialized = not ddi + + self.logs = nn.Parameter(torch.zeros(1, channels, 1)) + self.bias = nn.Parameter(torch.zeros(1, channels, 1)) + + def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument + if x_mask is None: + x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype) + x_len = torch.sum(x_mask, [1, 2]) + if not self.initialized: + self.initialize(x, x_mask) + self.initialized = True + + if reverse: + z = (x - self.bias) * torch.exp(-self.logs) * x_mask + logdet = None + else: + z = (self.bias + torch.exp(self.logs) * x) * x_mask + logdet = torch.sum(self.logs) * x_len # [b] + + return z, logdet + + def store_inverse(self): + pass + + def set_ddi(self, ddi): + self.initialized = not ddi + + def initialize(self, x, x_mask): + with torch.no_grad(): + denom = torch.sum(x_mask, [0, 2]) + m = torch.sum(x * x_mask, [0, 2]) / denom + m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom + v = m_sq - (m**2) + logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6)) + + bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype) + logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype) + + self.bias.data.copy_(bias_init) + self.logs.data.copy_(logs_init) diff --git a/viXTTS/TTS/tts/layers/generic/pos_encoding.py b/viXTTS/TTS/tts/layers/generic/pos_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..913add0d14332bf70c3ecd2a95869d0071310bd4 --- /dev/null +++ b/viXTTS/TTS/tts/layers/generic/pos_encoding.py @@ -0,0 +1,69 @@ +import math + +import torch +from torch import nn + + +class PositionalEncoding(nn.Module): + """Sinusoidal positional encoding for non-recurrent neural networks. + Implementation based on "Attention Is All You Need" + + Args: + channels (int): embedding size + dropout_p (float): dropout rate applied to the output. + max_len (int): maximum sequence length. + use_scale (bool): whether to use a learnable scaling coefficient. + """ + + def __init__(self, channels, dropout_p=0.0, max_len=5000, use_scale=False): + super().__init__() + if channels % 2 != 0: + raise ValueError( + "Cannot use sin/cos positional encoding with " "odd channels (got channels={:d})".format(channels) + ) + self.use_scale = use_scale + if use_scale: + self.scale = torch.nn.Parameter(torch.ones(1)) + pe = torch.zeros(max_len, channels) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.pow(10000, torch.arange(0, channels, 2).float() / channels) + pe[:, 0::2] = torch.sin(position.float() * div_term) + pe[:, 1::2] = torch.cos(position.float() * div_term) + pe = pe.unsqueeze(0).transpose(1, 2) + self.register_buffer("pe", pe) + if dropout_p > 0: + self.dropout = nn.Dropout(p=dropout_p) + self.channels = channels + + def forward(self, x, mask=None, first_idx=None, last_idx=None): + """ + Shapes: + x: [B, C, T] + mask: [B, 1, T] + first_idx: int + last_idx: int + """ + + x = x * math.sqrt(self.channels) + if first_idx is None: + if self.pe.size(2) < x.size(2): + raise RuntimeError( + f"Sequence is {x.size(2)} but PositionalEncoding is" + f" limited to {self.pe.size(2)}. See max_len argument." 
+ ) + if mask is not None: + pos_enc = self.pe[:, :, : x.size(2)] * mask + else: + pos_enc = self.pe[:, :, : x.size(2)] + if self.use_scale: + x = x + self.scale * pos_enc + else: + x = x + pos_enc + else: + if self.use_scale: + x = x + self.scale * self.pe[:, :, first_idx:last_idx] + else: + x = x + self.pe[:, :, first_idx:last_idx] + if hasattr(self, "dropout"): + x = self.dropout(x) + return x diff --git a/viXTTS/TTS/tts/layers/generic/res_conv_bn.py b/viXTTS/TTS/tts/layers/generic/res_conv_bn.py new file mode 100644 index 0000000000000000000000000000000000000000..4beda291aa15398024b5b16cd6bf12b88898a0a9 --- /dev/null +++ b/viXTTS/TTS/tts/layers/generic/res_conv_bn.py @@ -0,0 +1,127 @@ +from torch import nn + + +class ZeroTemporalPad(nn.Module): + """Pad sequences to equal lentgh in the temporal dimension""" + + def __init__(self, kernel_size, dilation): + super().__init__() + total_pad = dilation * (kernel_size - 1) + begin = total_pad // 2 + end = total_pad - begin + self.pad_layer = nn.ZeroPad2d((0, 0, begin, end)) + + def forward(self, x): + return self.pad_layer(x) + + +class Conv1dBN(nn.Module): + """1d convolutional with batch norm. + conv1d -> relu -> BN blocks. + + Note: + Batch normalization is applied after ReLU regarding the original implementation. + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + kernel_size (int): kernel size for convolutional filters. + dilation (int): dilation for convolution layers. + """ + + def __init__(self, in_channels, out_channels, kernel_size, dilation): + super().__init__() + padding = dilation * (kernel_size - 1) + pad_s = padding // 2 + pad_e = padding - pad_s + self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation) + self.pad = nn.ZeroPad2d((pad_s, pad_e, 0, 0)) # uneven left and right padding + self.norm = nn.BatchNorm1d(out_channels) + + def forward(self, x): + o = self.conv1d(x) + o = self.pad(o) + o = nn.functional.relu(o) + o = self.norm(o) + return o + + +class Conv1dBNBlock(nn.Module): + """1d convolutional block with batch norm. It is a set of conv1d -> relu -> BN blocks. + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + hidden_channels (int): number of inner convolution channels. + kernel_size (int): kernel size for convolutional filters. + dilation (int): dilation for convolution layers. + num_conv_blocks (int, optional): number of convolutional blocks. Defaults to 2. + """ + + def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation, num_conv_blocks=2): + super().__init__() + self.conv_bn_blocks = [] + for idx in range(num_conv_blocks): + layer = Conv1dBN( + in_channels if idx == 0 else hidden_channels, + out_channels if idx == (num_conv_blocks - 1) else hidden_channels, + kernel_size, + dilation, + ) + self.conv_bn_blocks.append(layer) + self.conv_bn_blocks = nn.Sequential(*self.conv_bn_blocks) + + def forward(self, x): + """ + Shapes: + x: (B, D, T) + """ + return self.conv_bn_blocks(x) + + +class ResidualConv1dBNBlock(nn.Module): + """Residual Convolutional Blocks with BN + Each block has 'num_conv_block' conv layers and 'num_res_blocks' such blocks are connected + with residual connections. + + conv_block = (conv1d -> relu -> bn) x 'num_conv_blocks' + residuak_conv_block = (x -> conv_block -> + ->) x 'num_res_blocks' + ' - - - - - - - - - ^ + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. 
+ hidden_channels (int): number of inner convolution channels. + kernel_size (int): kernel size for convolutional filters. + dilations (list): dilations for each convolution layer. + num_res_blocks (int, optional): number of residual blocks. Defaults to 13. + num_conv_blocks (int, optional): number of convolutional blocks in each residual block. Defaults to 2. + """ + + def __init__( + self, in_channels, out_channels, hidden_channels, kernel_size, dilations, num_res_blocks=13, num_conv_blocks=2 + ): + super().__init__() + assert len(dilations) == num_res_blocks + self.res_blocks = nn.ModuleList() + for idx, dilation in enumerate(dilations): + block = Conv1dBNBlock( + in_channels if idx == 0 else hidden_channels, + out_channels if (idx + 1) == len(dilations) else hidden_channels, + hidden_channels, + kernel_size, + dilation, + num_conv_blocks, + ) + self.res_blocks.append(block) + + def forward(self, x, x_mask=None): + if x_mask is None: + x_mask = 1.0 + o = x * x_mask + for block in self.res_blocks: + res = o + o = block(o) + o = o + res + if x_mask is not None: + o = o * x_mask + return o diff --git a/viXTTS/TTS/tts/layers/generic/time_depth_sep_conv.py b/viXTTS/TTS/tts/layers/generic/time_depth_sep_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..186cea02e75e156c40923de91086c369a9ea02ee --- /dev/null +++ b/viXTTS/TTS/tts/layers/generic/time_depth_sep_conv.py @@ -0,0 +1,84 @@ +import torch +from torch import nn + + +class TimeDepthSeparableConv(nn.Module): + """Time depth separable convolution as in https://arxiv.org/pdf/1904.02619.pdf + It shows competative results with less computation and memory footprint.""" + + def __init__(self, in_channels, hid_channels, out_channels, kernel_size, bias=True): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.hid_channels = hid_channels + self.kernel_size = kernel_size + + self.time_conv = nn.Conv1d( + in_channels, + 2 * hid_channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.norm1 = nn.BatchNorm1d(2 * hid_channels) + self.depth_conv = nn.Conv1d( + hid_channels, + hid_channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=hid_channels, + bias=bias, + ) + self.norm2 = nn.BatchNorm1d(hid_channels) + self.time_conv2 = nn.Conv1d( + hid_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.norm3 = nn.BatchNorm1d(out_channels) + + def forward(self, x): + x_res = x + x = self.time_conv(x) + x = self.norm1(x) + x = nn.functional.glu(x, dim=1) + x = self.depth_conv(x) + x = self.norm2(x) + x = x * torch.sigmoid(x) + x = self.time_conv2(x) + x = self.norm3(x) + x = x_res + x + return x + + +class TimeDepthSeparableConvBlock(nn.Module): + def __init__(self, in_channels, hid_channels, out_channels, num_layers, kernel_size, bias=True): + super().__init__() + assert (kernel_size - 1) % 2 == 0 + assert num_layers > 1 + + self.layers = nn.ModuleList() + layer = TimeDepthSeparableConv( + in_channels, hid_channels, out_channels if num_layers == 1 else hid_channels, kernel_size, bias + ) + self.layers.append(layer) + for idx in range(num_layers - 1): + layer = TimeDepthSeparableConv( + hid_channels, + hid_channels, + out_channels if (idx + 1) == (num_layers - 1) else hid_channels, + kernel_size, + bias, + ) + self.layers.append(layer) + + def forward(self, x, mask): + for layer in self.layers: + x = layer(x * mask) + return x diff --git a/viXTTS/TTS/tts/layers/generic/transformer.py 
b/viXTTS/TTS/tts/layers/generic/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..9b7ecee2bacb68cd330e18630531c97bc6f2e6a3 --- /dev/null +++ b/viXTTS/TTS/tts/layers/generic/transformer.py @@ -0,0 +1,89 @@ +import torch +import torch.nn.functional as F +from torch import nn + + +class FFTransformer(nn.Module): + def __init__(self, in_out_channels, num_heads, hidden_channels_ffn=1024, kernel_size_fft=3, dropout_p=0.1): + super().__init__() + self.self_attn = nn.MultiheadAttention(in_out_channels, num_heads, dropout=dropout_p) + + padding = (kernel_size_fft - 1) // 2 + self.conv1 = nn.Conv1d(in_out_channels, hidden_channels_ffn, kernel_size=kernel_size_fft, padding=padding) + self.conv2 = nn.Conv1d(hidden_channels_ffn, in_out_channels, kernel_size=kernel_size_fft, padding=padding) + + self.norm1 = nn.LayerNorm(in_out_channels) + self.norm2 = nn.LayerNorm(in_out_channels) + + self.dropout1 = nn.Dropout(dropout_p) + self.dropout2 = nn.Dropout(dropout_p) + + def forward(self, src, src_mask=None, src_key_padding_mask=None): + """😦 ugly looking with all the transposing""" + src = src.permute(2, 0, 1) + src2, enc_align = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask) + src = src + self.dropout1(src2) + src = self.norm1(src + src2) + # T x B x D -> B x D x T + src = src.permute(1, 2, 0) + src2 = self.conv2(F.relu(self.conv1(src))) + src2 = self.dropout2(src2) + src = src + src2 + src = src.transpose(1, 2) + src = self.norm2(src) + src = src.transpose(1, 2) + return src, enc_align + + +class FFTransformerBlock(nn.Module): + def __init__(self, in_out_channels, num_heads, hidden_channels_ffn, num_layers, dropout_p): + super().__init__() + self.fft_layers = nn.ModuleList( + [ + FFTransformer( + in_out_channels=in_out_channels, + num_heads=num_heads, + hidden_channels_ffn=hidden_channels_ffn, + dropout_p=dropout_p, + ) + for _ in range(num_layers) + ] + ) + + def forward(self, x, mask=None, g=None): # pylint: disable=unused-argument + """ + TODO: handle multi-speaker + Shapes: + - x: :math:`[B, C, T]` + - mask: :math:`[B, 1, T] or [B, T]` + """ + if mask is not None and mask.ndim == 3: + mask = mask.squeeze(1) + # mask is negated, torch uses 1s and 0s reversely. 
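+            # i.e. the incoming mask marks valid steps with 1, while
+            # nn.MultiheadAttention's `key_padding_mask` expects True at padded steps.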
+ mask = ~mask.bool() + alignments = [] + for layer in self.fft_layers: + x, align = layer(x, src_key_padding_mask=mask) + alignments.append(align.unsqueeze(1)) + alignments = torch.cat(alignments, 1) + return x + + +class FFTDurationPredictor: + def __init__( + self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None + ): # pylint: disable=unused-argument + self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p) + self.proj = nn.Linear(in_channels, 1) + + def forward(self, x, mask=None, g=None): # pylint: disable=unused-argument + """ + Shapes: + - x: :math:`[B, C, T]` + - mask: :math:`[B, 1, T]` + + TODO: Handle the cond input + """ + x = self.fft(x, mask=mask) + x = self.proj(x) + return x diff --git a/viXTTS/TTS/tts/layers/generic/wavenet.py b/viXTTS/TTS/tts/layers/generic/wavenet.py new file mode 100644 index 0000000000000000000000000000000000000000..f8de63b49fea8a7140ddd0493446f0541abe6a0a --- /dev/null +++ b/viXTTS/TTS/tts/layers/generic/wavenet.py @@ -0,0 +1,176 @@ +import torch +from torch import nn +from torch.nn.utils import parametrize + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +class WN(torch.nn.Module): + """Wavenet layers with weight norm and no input conditioning. + + |-----------------------------------------------------------------------------| + | |-> tanh -| | + res -|- conv1d(dilation) -> dropout -> + -| * -> conv1d1x1 -> split -|- + -> res + g -------------------------------------| |-> sigmoid -| | + o --------------------------------------------------------------------------- + --------- o + + Args: + in_channels (int): number of input channels. + hidden_channes (int): number of hidden channels. + kernel_size (int): filter kernel size for the first conv layer. + dilation_rate (int): dilations rate to increase dilation per layer. + If it is 2, dilations are 1, 2, 4, 8 for the next 4 layers. + num_layers (int): number of wavenet layers. + c_in_channels (int): number of channels of conditioning input. + dropout_p (float): dropout rate. + weight_norm (bool): enable/disable weight norm for convolution layers. 
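+
+    Example (an illustrative sketch, assuming ``torch`` is in scope; sizes are arbitrary):
+
+        >>> wn = WN(in_channels=192, hidden_channels=192, kernel_size=5,
+        ...         dilation_rate=1, num_layers=4, dropout_p=0.05)
+        >>> x = torch.randn(2, 192, 100)
+        >>> x_mask = torch.ones(2, 1, 100)
+        >>> wn(x, x_mask).shape
+        torch.Size([2, 192, 100])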
+ """ + + def __init__( + self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + c_in_channels=0, + dropout_p=0, + weight_norm=True, + ): + super().__init__() + assert kernel_size % 2 == 1 + assert hidden_channels % 2 == 0 + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.num_layers = num_layers + self.c_in_channels = c_in_channels + self.dropout_p = dropout_p + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.dropout = nn.Dropout(dropout_p) + + # init conditioning layer + if c_in_channels > 0: + cond_layer = torch.nn.Conv1d(c_in_channels, 2 * hidden_channels * num_layers, 1) + self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight") + # intermediate layers + for i in range(num_layers): + dilation = dilation_rate**i + padding = int((kernel_size * dilation - dilation) / 2) + if i == 0: + in_layer = torch.nn.Conv1d( + in_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding + ) + else: + in_layer = torch.nn.Conv1d( + hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding + ) + in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight") + self.in_layers.append(in_layer) + + if i < num_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight") + self.res_skip_layers.append(res_skip_layer) + # setup weight norm + if not weight_norm: + self.remove_weight_norm() + + def forward(self, x, x_mask=None, g=None, **kwargs): # pylint: disable=unused-argument + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + x_mask = 1.0 if x_mask is None else x_mask + if g is not None: + g = self.cond_layer(g) + for i in range(self.num_layers): + x_in = self.in_layers[i](x) + x_in = self.dropout(x_in) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] + else: + g_l = torch.zeros_like(x_in) + acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.num_layers - 1: + x = (x + res_skip_acts[:, : self.hidden_channels, :]) * x_mask + output = output + res_skip_acts[:, self.hidden_channels :, :] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.c_in_channels != 0: + parametrize.remove_parametrizations(self.cond_layer, "weight") + for l in self.in_layers: + parametrize.remove_parametrizations(l, "weight") + for l in self.res_skip_layers: + parametrize.remove_parametrizations(l, "weight") + + +class WNBlocks(nn.Module): + """Wavenet blocks. + + Note: After each block dilation resets to 1 and it increases in each block + along the dilation rate. + + Args: + in_channels (int): number of input channels. + hidden_channes (int): number of hidden channels. + kernel_size (int): filter kernel size for the first conv layer. + dilation_rate (int): dilations rate to increase dilation per layer. + If it is 2, dilations are 1, 2, 4, 8 for the next 4 layers. + num_blocks (int): number of wavenet blocks. + num_layers (int): number of wavenet layers. 
+ c_in_channels (int): number of channels of conditioning input. + dropout_p (float): dropout rate. + weight_norm (bool): enable/disable weight norm for convolution layers. + """ + + def __init__( + self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_blocks, + num_layers, + c_in_channels=0, + dropout_p=0, + weight_norm=True, + ): + super().__init__() + self.wn_blocks = nn.ModuleList() + for idx in range(num_blocks): + layer = WN( + in_channels=in_channels if idx == 0 else hidden_channels, + hidden_channels=hidden_channels, + kernel_size=kernel_size, + dilation_rate=dilation_rate, + num_layers=num_layers, + c_in_channels=c_in_channels, + dropout_p=dropout_p, + weight_norm=weight_norm, + ) + self.wn_blocks.append(layer) + + def forward(self, x, x_mask=None, g=None): + o = x + for layer in self.wn_blocks: + o = layer(o, x_mask, g) + return o diff --git a/viXTTS/TTS/tts/layers/glow_tts/__init__.py b/viXTTS/TTS/tts/layers/glow_tts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/tts/layers/glow_tts/decoder.py b/viXTTS/TTS/tts/layers/glow_tts/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..61c5174ac5e67885288043885290c2906656c99c --- /dev/null +++ b/viXTTS/TTS/tts/layers/glow_tts/decoder.py @@ -0,0 +1,141 @@ +import torch +from torch import nn + +from TTS.tts.layers.generic.normalization import ActNorm +from TTS.tts.layers.glow_tts.glow import CouplingBlock, InvConvNear + + +def squeeze(x, x_mask=None, num_sqz=2): + """GlowTTS squeeze operation + Increase number of channels and reduce number of time steps + by the same factor. + + Note: + each 's' is a n-dimensional vector. + ``[s1,s2,s3,s4,s5,s6] --> [[s1, s3, s5], [s2, s4, s6]]`` + """ + b, c, t = x.size() + + t = (t // num_sqz) * num_sqz + x = x[:, :, :t] + x_sqz = x.view(b, c, t // num_sqz, num_sqz) + x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * num_sqz, t // num_sqz) + + if x_mask is not None: + x_mask = x_mask[:, :, num_sqz - 1 :: num_sqz] + else: + x_mask = torch.ones(b, 1, t // num_sqz).to(device=x.device, dtype=x.dtype) + return x_sqz * x_mask, x_mask + + +def unsqueeze(x, x_mask=None, num_sqz=2): + """GlowTTS unsqueeze operation (revert the squeeze) + + Note: + each 's' is a n-dimensional vector. + ``[[s1, s3, s5], [s2, s4, s6]] --> [[s1, s3, s5, s2, s4, s6]]`` + """ + b, c, t = x.size() + + x_unsqz = x.view(b, num_sqz, c // num_sqz, t) + x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // num_sqz, t * num_sqz) + + if x_mask is not None: + x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, num_sqz).view(b, 1, t * num_sqz) + else: + x_mask = torch.ones(b, 1, t * num_sqz).to(device=x.device, dtype=x.dtype) + return x_unsqz * x_mask, x_mask + + +class Decoder(nn.Module): + """Stack of Glow Decoder Modules. + + :: + + Squeeze -> ActNorm -> InvertibleConv1x1 -> AffineCoupling -> Unsqueeze + + Args: + in_channels (int): channels of input tensor. + hidden_channels (int): hidden decoder channels. + kernel_size (int): Coupling block kernel size. (Wavenet filter kernel size.) + dilation_rate (int): rate to increase dilation by each layer in a decoder block. + num_flow_blocks (int): number of decoder blocks. + num_coupling_layers (int): number coupling layers. (number of wavenet layers.) + dropout_p (float): wavenet dropout rate. + sigmoid_scale (bool): enable/disable sigmoid scaling in coupling layer. 
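+
+    Example:
+        A minimal forward/inverse sketch; the sizes below are illustrative
+        assumptions rather than the Glow-TTS defaults::
+
+            import torch
+
+            dec = Decoder(in_channels=80, hidden_channels=192, kernel_size=5,
+                          dilation_rate=1, num_flow_blocks=2, num_coupling_layers=2)
+            x = torch.randn(2, 80, 100)     # mel frames [B, C, T]
+            x_mask = torch.ones(2, 1, 100)
+            z, logdet = dec(x, x_mask)               # mel -> latent
+            x_hat, _ = dec(z, x_mask, reverse=True)  # latent -> mel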
+ """ + + def __init__( + self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_flow_blocks, + num_coupling_layers, + dropout_p=0.0, + num_splits=4, + num_squeeze=2, + sigmoid_scale=False, + c_in_channels=0, + ): + super().__init__() + + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.num_flow_blocks = num_flow_blocks + self.num_coupling_layers = num_coupling_layers + self.dropout_p = dropout_p + self.num_splits = num_splits + self.num_squeeze = num_squeeze + self.sigmoid_scale = sigmoid_scale + self.c_in_channels = c_in_channels + + self.flows = nn.ModuleList() + for _ in range(num_flow_blocks): + self.flows.append(ActNorm(channels=in_channels * num_squeeze)) + self.flows.append(InvConvNear(channels=in_channels * num_squeeze, num_splits=num_splits)) + self.flows.append( + CouplingBlock( + in_channels * num_squeeze, + hidden_channels, + kernel_size=kernel_size, + dilation_rate=dilation_rate, + num_layers=num_coupling_layers, + c_in_channels=c_in_channels, + dropout_p=dropout_p, + sigmoid_scale=sigmoid_scale, + ) + ) + + def forward(self, x, x_mask, g=None, reverse=False): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1 ,T]` + - g: :math:`[B, C]` + """ + if not reverse: + flows = self.flows + logdet_tot = 0 + else: + flows = reversed(self.flows) + logdet_tot = None + + if self.num_squeeze > 1: + x, x_mask = squeeze(x, x_mask, self.num_squeeze) + for f in flows: + if not reverse: + x, logdet = f(x, x_mask, g=g, reverse=reverse) + logdet_tot += logdet + else: + x, logdet = f(x, x_mask, g=g, reverse=reverse) + if self.num_squeeze > 1: + x, x_mask = unsqueeze(x, x_mask, self.num_squeeze) + return x, logdet_tot + + def store_inverse(self): + for f in self.flows: + f.store_inverse() diff --git a/viXTTS/TTS/tts/layers/glow_tts/duration_predictor.py b/viXTTS/TTS/tts/layers/glow_tts/duration_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..e766ed6ab5a0348eaca8d1482be124003d8b8c68 --- /dev/null +++ b/viXTTS/TTS/tts/layers/glow_tts/duration_predictor.py @@ -0,0 +1,69 @@ +import torch +from torch import nn + +from ..generic.normalization import LayerNorm + + +class DurationPredictor(nn.Module): + """Glow-TTS duration prediction model. + + :: + + [2 x (conv1d_kxk -> relu -> layer_norm -> dropout)] -> conv1d_1x1 -> durs + + Args: + in_channels (int): Number of channels of the input tensor. + hidden_channels (int): Number of hidden channels of the network. + kernel_size (int): Kernel size for the conv layers. + dropout_p (float): Dropout rate used after each conv layer. 
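+
+    Example:
+        A minimal usage sketch with assumed (illustrative) sizes::
+
+            import torch
+
+            dp = DurationPredictor(in_channels=192, hidden_channels=256,
+                                   kernel_size=3, dropout_p=0.1)
+            x = torch.randn(2, 192, 50)    # encoder outputs [B, C, T]
+            x_mask = torch.ones(2, 1, 50)
+            log_durs = dp(x, x_mask)       # [2, 1, 50] predicted log-durations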
+ """ + + def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None, language_emb_dim=None): + super().__init__() + + # add language embedding dim in the input + if language_emb_dim: + in_channels += language_emb_dim + + # class arguments + self.in_channels = in_channels + self.filter_channels = hidden_channels + self.kernel_size = kernel_size + self.dropout_p = dropout_p + # layers + self.drop = nn.Dropout(dropout_p) + self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2) + self.norm_1 = LayerNorm(hidden_channels) + self.conv_2 = nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2) + self.norm_2 = LayerNorm(hidden_channels) + # output layer + self.proj = nn.Conv1d(hidden_channels, 1, 1) + if cond_channels is not None and cond_channels != 0: + self.cond = nn.Conv1d(cond_channels, in_channels, 1) + + if language_emb_dim != 0 and language_emb_dim is not None: + self.cond_lang = nn.Conv1d(language_emb_dim, in_channels, 1) + + def forward(self, x, x_mask, g=None, lang_emb=None): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + - g: :math:`[B, C, 1]` + """ + if g is not None: + x = x + self.cond(g) + + if lang_emb is not None: + x = x + self.cond_lang(lang_emb) + + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask diff --git a/viXTTS/TTS/tts/layers/glow_tts/encoder.py b/viXTTS/TTS/tts/layers/glow_tts/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..3b43e527f5e9ca2bd0880bf204e04a1526bc8dfb --- /dev/null +++ b/viXTTS/TTS/tts/layers/glow_tts/encoder.py @@ -0,0 +1,179 @@ +import math + +import torch +from torch import nn + +from TTS.tts.layers.generic.gated_conv import GatedConvBlock +from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock +from TTS.tts.layers.generic.time_depth_sep_conv import TimeDepthSeparableConvBlock +from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor +from TTS.tts.layers.glow_tts.glow import ResidualConv1dLayerNormBlock +from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer +from TTS.tts.utils.helpers import sequence_mask + + +class Encoder(nn.Module): + """Glow-TTS encoder module. + + :: + + embedding -> -> encoder_module -> --> proj_mean + | + |-> proj_var + | + |-> concat -> duration_predictor + ↑ + speaker_embed + + Args: + num_chars (int): number of characters. + out_channels (int): number of output channels. + hidden_channels (int): encoder's embedding size. + hidden_channels_ffn (int): transformer's feed-forward channels. + kernel_size (int): kernel size for conv layers and duration predictor. + dropout_p (float): dropout rate for any dropout layer. + mean_only (bool): if True, output only mean values and use constant std. + use_prenet (bool): if True, use pre-convolutional layers before transformer layers. + c_in_channels (int): number of channels in conditional input. + + Shapes: + - input: (B, T, C) + + :: + + suggested encoder params... 
+ + for encoder_type == 'rel_pos_transformer' + encoder_params={ + 'kernel_size':3, + 'dropout_p': 0.1, + 'num_layers': 6, + 'num_heads': 2, + 'hidden_channels_ffn': 768, # 4 times the hidden_channels + 'input_length': None + } + + for encoder_type == 'gated_conv' + encoder_params={ + 'kernel_size':5, + 'dropout_p': 0.1, + 'num_layers': 9, + } + + for encoder_type == 'residual_conv_bn' + encoder_params={ + "kernel_size": 4, + "dilations": [1, 2, 4, 1, 2, 4, 1, 2, 4, 1, 2, 4, 1], + "num_conv_blocks": 2, + "num_res_blocks": 13 + } + + for encoder_type == 'time_depth_separable' + encoder_params={ + "kernel_size": 5, + 'num_layers': 9, + } + """ + + def __init__( + self, + num_chars, + out_channels, + hidden_channels, + hidden_channels_dp, + encoder_type, + encoder_params, + dropout_p_dp=0.1, + mean_only=False, + use_prenet=True, + c_in_channels=0, + ): + super().__init__() + # class arguments + self.num_chars = num_chars + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.hidden_channels_dp = hidden_channels_dp + self.dropout_p_dp = dropout_p_dp + self.mean_only = mean_only + self.use_prenet = use_prenet + self.c_in_channels = c_in_channels + self.encoder_type = encoder_type + # embedding layer + self.emb = nn.Embedding(num_chars, hidden_channels) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) + # init encoder module + if encoder_type.lower() == "rel_pos_transformer": + if use_prenet: + self.prenet = ResidualConv1dLayerNormBlock( + hidden_channels, hidden_channels, hidden_channels, kernel_size=5, num_layers=3, dropout_p=0.5 + ) + self.encoder = RelativePositionTransformer( + hidden_channels, hidden_channels, hidden_channels, **encoder_params + ) + elif encoder_type.lower() == "gated_conv": + self.encoder = GatedConvBlock(hidden_channels, **encoder_params) + elif encoder_type.lower() == "residual_conv_bn": + if use_prenet: + self.prenet = nn.Sequential(nn.Conv1d(hidden_channels, hidden_channels, 1), nn.ReLU()) + self.encoder = ResidualConv1dBNBlock(hidden_channels, hidden_channels, hidden_channels, **encoder_params) + self.postnet = nn.Sequential( + nn.Conv1d(self.hidden_channels, self.hidden_channels, 1), nn.BatchNorm1d(self.hidden_channels) + ) + elif encoder_type.lower() == "time_depth_separable": + if use_prenet: + self.prenet = ResidualConv1dLayerNormBlock( + hidden_channels, hidden_channels, hidden_channels, kernel_size=5, num_layers=3, dropout_p=0.5 + ) + self.encoder = TimeDepthSeparableConvBlock( + hidden_channels, hidden_channels, hidden_channels, **encoder_params + ) + else: + raise ValueError(" [!] 
Unkown encoder type.") + + # final projection layers + self.proj_m = nn.Conv1d(hidden_channels, out_channels, 1) + if not mean_only: + self.proj_s = nn.Conv1d(hidden_channels, out_channels, 1) + # duration predictor + self.duration_predictor = DurationPredictor( + hidden_channels + c_in_channels, hidden_channels_dp, 3, dropout_p_dp + ) + + def forward(self, x, x_lengths, g=None): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_lengths: :math:`[B]` + - g (optional): :math:`[B, 1, T]` + """ + # embedding layer + # [B ,T, D] + x = self.emb(x) * math.sqrt(self.hidden_channels) + # [B, D, T] + x = torch.transpose(x, 1, -1) + # compute input sequence mask + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + # prenet + if hasattr(self, "prenet") and self.use_prenet: + x = self.prenet(x, x_mask) + # encoder + x = self.encoder(x, x_mask) + # postnet + if hasattr(self, "postnet"): + x = self.postnet(x) * x_mask + # set duration predictor input + if g is not None: + g_exp = g.expand(-1, -1, x.size(-1)) + x_dp = torch.cat([x.detach(), g_exp], 1) + else: + x_dp = x.detach() + # final projection layer + x_m = self.proj_m(x) * x_mask + if not self.mean_only: + x_logs = self.proj_s(x) * x_mask + else: + x_logs = torch.zeros_like(x_m) + # duration predictor + logw = self.duration_predictor(x_dp, x_mask) + return x_m, x_logs, logw, x_mask diff --git a/viXTTS/TTS/tts/layers/glow_tts/glow.py b/viXTTS/TTS/tts/layers/glow_tts/glow.py new file mode 100644 index 0000000000000000000000000000000000000000..b02c3118085fbd3305796d4ce7f0d149fa1bf72e --- /dev/null +++ b/viXTTS/TTS/tts/layers/glow_tts/glow.py @@ -0,0 +1,233 @@ +import torch +from packaging.version import Version +from torch import nn +from torch.nn import functional as F + +from TTS.tts.layers.generic.wavenet import WN + +from ..generic.normalization import LayerNorm + + +class ResidualConv1dLayerNormBlock(nn.Module): + """Conv1d with Layer Normalization and residual connection as in GlowTTS paper. + https://arxiv.org/pdf/1811.00002.pdf + + :: + + x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o + |---------------> conv1d_1x1 ------------------| + + Args: + in_channels (int): number of input tensor channels. + hidden_channels (int): number of inner layer channels. + out_channels (int): number of output tensor channels. + kernel_size (int): kernel size of conv1d filter. + num_layers (int): number of blocks. + dropout_p (float): dropout rate for each block. + """ + + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, num_layers, dropout_p): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.num_layers = num_layers + self.dropout_p = dropout_p + assert num_layers > 1, " [!] number of layers should be > 0." + assert kernel_size % 2 == 1, " [!] kernel size should be odd number." 
+ + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + + for idx in range(num_layers): + self.conv_layers.append( + nn.Conv1d( + in_channels if idx == 0 else hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2 + ) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + """ + x_res = x + for i in range(self.num_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x * x_mask) + x = F.dropout(F.relu(x), self.dropout_p, training=self.training) + x = x_res + self.proj(x) + return x * x_mask + + +class InvConvNear(nn.Module): + """Invertible Convolution with input splitting as in GlowTTS paper. + https://arxiv.org/pdf/1811.00002.pdf + + Args: + channels (int): input and output channels. + num_splits (int): number of splits, also H and W of conv layer. + no_jacobian (bool): enable/disable jacobian computations. + + Note: + Split the input into groups of size self.num_splits and + perform 1x1 convolution separately. Cast 1x1 conv operation + to 2d by reshaping the input for efficiency. + """ + + def __init__(self, channels, num_splits=4, no_jacobian=False, **kwargs): # pylint: disable=unused-argument + super().__init__() + assert num_splits % 2 == 0 + self.channels = channels + self.num_splits = num_splits + self.no_jacobian = no_jacobian + self.weight_inv = None + + if Version(torch.__version__) < Version("1.9"): + w_init = torch.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0] + else: + w_init = torch.linalg.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_(), "complete")[0] + + if torch.det(w_init) < 0: + w_init[:, 0] = -1 * w_init[:, 0] + self.weight = nn.Parameter(w_init) + + def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + """ + b, c, t = x.size() + assert c % self.num_splits == 0 + if x_mask is None: + x_mask = 1 + x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t + else: + x_len = torch.sum(x_mask, [1, 2]) + + x = x.view(b, 2, c // self.num_splits, self.num_splits // 2, t) + x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.num_splits, c // self.num_splits, t) + + if reverse: + if self.weight_inv is not None: + weight = self.weight_inv + else: + weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype) + logdet = None + else: + weight = self.weight + if self.no_jacobian: + logdet = 0 + else: + logdet = torch.logdet(self.weight) * (c / self.num_splits) * x_len # [b] + + weight = weight.view(self.num_splits, self.num_splits, 1, 1) + z = F.conv2d(x, weight) + + z = z.view(b, 2, self.num_splits // 2, c // self.num_splits, t) + z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask + return z, logdet + + def store_inverse(self): + weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype) + self.weight_inv = nn.Parameter(weight_inv, requires_grad=False) + + +class CouplingBlock(nn.Module): + """Glow Affine Coupling block as in GlowTTS paper. + https://arxiv.org/pdf/1811.00002.pdf + + :: + + x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o + '-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^ + + Args: + in_channels (int): number of input tensor channels. 
+ hidden_channels (int): number of hidden channels. + kernel_size (int): WaveNet filter kernel size. + dilation_rate (int): rate to increase dilation by each layer in a decoder block. + num_layers (int): number of WaveNet layers. + c_in_channels (int): number of conditioning input channels. + dropout_p (int): wavenet dropout rate. + sigmoid_scale (bool): enable/disable sigmoid scaling for output scale. + + Note: + It does not use the conditional inputs differently from WaveGlow. + """ + + def __init__( + self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + c_in_channels=0, + dropout_p=0, + sigmoid_scale=False, + ): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.num_layers = num_layers + self.c_in_channels = c_in_channels + self.dropout_p = dropout_p + self.sigmoid_scale = sigmoid_scale + # input layer + start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1) + start = torch.nn.utils.parametrizations.weight_norm(start) + self.start = start + # output layer + # Initializing last layer to 0 makes the affine coupling layers + # do nothing at first. This helps with training stability + end = torch.nn.Conv1d(hidden_channels, in_channels, 1) + end.weight.data.zero_() + end.bias.data.zero_() + self.end = end + # coupling layers + self.wn = WN(hidden_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels, dropout_p) + + def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs): # pylint: disable=unused-argument + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + - g: :math:`[B, C, 1]` + """ + if x_mask is None: + x_mask = 1 + x_0, x_1 = x[:, : self.in_channels // 2], x[:, self.in_channels // 2 :] + + x = self.start(x_0) * x_mask + x = self.wn(x, x_mask, g) + out = self.end(x) + + z_0 = x_0 + t = out[:, : self.in_channels // 2, :] + s = out[:, self.in_channels // 2 :, :] + if self.sigmoid_scale: + s = torch.log(1e-6 + torch.sigmoid(s + 2)) + + if reverse: + z_1 = (x_1 - t) * torch.exp(-s) * x_mask + logdet = None + else: + z_1 = (t + torch.exp(s) * x_1) * x_mask + logdet = torch.sum(s * x_mask, [1, 2]) + + z = torch.cat([z_0, z_1], 1) + return z, logdet + + def store_inverse(self): + self.wn.remove_weight_norm() diff --git a/viXTTS/TTS/tts/layers/glow_tts/transformer.py b/viXTTS/TTS/tts/layers/glow_tts/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..02688d611fe41394e8e1fedbc5742845eae85cfd --- /dev/null +++ b/viXTTS/TTS/tts/layers/glow_tts/transformer.py @@ -0,0 +1,432 @@ +import math + +import torch +from torch import nn +from torch.nn import functional as F + +from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2 + + +class RelativePositionMultiHeadAttention(nn.Module): + """Multi-head attention with Relative Positional embedding. + https://arxiv.org/pdf/1809.04281.pdf + + It learns positional embeddings for a window of neighbours. For keys and values, + it learns different set of embeddings. Key embeddings are agregated with the attention + scores and value embeddings are aggregated with the output. + + Note: + Example with relative attention window size 2 + + - input = [a, b, c, d, e] + - rel_attn_embeddings = [e(t-2), e(t-1), e(t+1), e(t+2)] + + So it learns 4 embedding vectors (in total 8) separately for key and value vectors. 
+ + Considering the input c + + - e(t-2) corresponds to c -> a + - e(t-2) corresponds to c -> b + - e(t-2) corresponds to c -> d + - e(t-2) corresponds to c -> e + + These embeddings are shared among different time steps. So input a, b, d and e also uses + the same embeddings. + + Embeddings are ignored when the relative window is out of limit for the first and the last + n items. + + Args: + channels (int): input and inner layer channels. + out_channels (int): output channels. + num_heads (int): number of attention heads. + rel_attn_window_size (int, optional): relation attention window size. + If 4, for each time step next and previous 4 time steps are attended. + If default, relative encoding is disabled and it is a regular transformer. + Defaults to None. + heads_share (bool, optional): [description]. Defaults to True. + dropout_p (float, optional): dropout rate. Defaults to 0.. + input_length (int, optional): intput length for positional encoding. Defaults to None. + proximal_bias (bool, optional): enable/disable proximal bias as in the paper. Defaults to False. + proximal_init (bool, optional): enable/disable poximal init as in the paper. + Init key and query layer weights the same. Defaults to False. + """ + + def __init__( + self, + channels, + out_channels, + num_heads, + rel_attn_window_size=None, + heads_share=True, + dropout_p=0.0, + input_length=None, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % num_heads == 0, " [!] channels should be divisible by num_heads." + # class attributes + self.channels = channels + self.out_channels = out_channels + self.num_heads = num_heads + self.rel_attn_window_size = rel_attn_window_size + self.heads_share = heads_share + self.input_length = input_length + self.proximal_bias = proximal_bias + self.dropout_p = dropout_p + self.attn = None + # query, key, value layers + self.k_channels = channels // num_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + # output layers + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.dropout = nn.Dropout(dropout_p) + # relative positional encoding layers + if rel_attn_window_size is not None: + n_heads_rel = 1 if heads_share else num_heads + rel_stddev = self.k_channels**-0.5 + emb_rel_k = nn.Parameter( + torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, self.k_channels) * rel_stddev + ) + emb_rel_v = nn.Parameter( + torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, self.k_channels) * rel_stddev + ) + self.register_parameter("emb_rel_k", emb_rel_k) + self.register_parameter("emb_rel_v", emb_rel_v) + + # init layers + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + # proximal bias + if proximal_init: + self.conv_k.weight.data.copy_(self.conv_q.weight.data) + self.conv_k.bias.data.copy_(self.conv_q.bias.data) + nn.init.xavier_uniform_(self.conv_v.weight) + + def forward(self, x, c, attn_mask=None): + """ + Shapes: + - x: :math:`[B, C, T]` + - c: :math:`[B, C, T]` + - attn_mask: :math:`[B, 1, T, T]` + """ + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + x, self.attn = self.attention(q, k, v, mask=attn_mask) + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.num_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, 
self.num_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.num_heads, self.k_channels, t_s).transpose(2, 3) + # compute raw attention scores + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) + # relative positional encoding for scores + if self.rel_attn_window_size is not None: + assert t_s == t_t, "Relative attention is only available for self-attention." + # get relative key embeddings + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings) + rel_logits = self._relative_position_to_absolute_position(rel_logits) + scores_local = rel_logits / math.sqrt(self.k_channels) + scores = scores + scores_local + # proximan bias + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attn_proximity_bias(t_s).to(device=scores.device, dtype=scores.dtype) + # attention score masking + if mask is not None: + # add small value to prevent oor error. + scores = scores.masked_fill(mask == 0, -1e4) + if self.input_length is not None: + block_mask = torch.ones_like(scores).triu(-1 * self.input_length).tril(self.input_length) + scores = scores * block_mask + -1e4 * (1 - block_mask) + # attention score normalization + p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] + # apply dropout to attention weights + p_attn = self.dropout(p_attn) + # compute output + output = torch.matmul(p_attn, value) + # relative positional encoding for values + if self.rel_attn_window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn + + @staticmethod + def _matmul_with_relative_values(p_attn, re): + """ + Args: + p_attn (Tensor): attention weights. + re (Tensor): relative value embedding vector. (a_(i,j)^V) + + Shapes: + -p_attn: :math:`[B, H, T, V]` + -re: :math:`[H or 1, V, D]` + -logits: :math:`[B, H, T, D]` + """ + logits = torch.matmul(p_attn, re.unsqueeze(0)) + return logits + + @staticmethod + def _matmul_with_relative_keys(query, re): + """ + Args: + query (Tensor): batch of query vectors. (x*W^Q) + re (Tensor): relative key embedding vector. (a_(i,j)^K) + + Shapes: + - query: :math:`[B, H, T, D]` + - re: :math:`[H or 1, V, D]` + - logits: :math:`[B, H, T, V]` + """ + # logits = torch.einsum('bhld, kmd -> bhlm', [query, re.to(query.dtype)]) + logits = torch.matmul(query, re.unsqueeze(0).transpose(-2, -1)) + return logits + + def _get_relative_embeddings(self, relative_embeddings, length): + """Convert embedding vestors to a tensor of embeddings""" + # Pad first before slice to avoid using cond ops. 
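+        # The table holds 2 * rel_attn_window_size + 1 embeddings. If the
+        # sequence is longer than the window, zero-pad the table so that every
+        # relative offset in [-(length - 1), length - 1] has an entry; if it is
+        # shorter, slice away the unused outer entries. Either way the result
+        # holds 2 * length - 1 embedding vectors.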
+ pad_length = max(length - (self.rel_attn_window_size + 1), 0) + slice_start_position = max((self.rel_attn_window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0]) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position] + return used_relative_embeddings + + @staticmethod + def _relative_position_to_absolute_position(x): + """Converts tensor from relative to absolute indexing for local attention. + Shapes: + x: :math:`[B, C, T, 2 * T - 1]` + Returns: + A Tensor of shape :math:`[B, C, T, T]` + """ + batch, heads, length, _ = x.size() + # Pad to shift from relative to absolute indexing. + x = F.pad(x, [0, 1, 0, 0, 0, 0, 0, 0]) + # Pad extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad(x_flat, [0, length - 1, 0, 0, 0, 0]) + # Reshape and slice out the padded elements. + x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :] + return x_final + + @staticmethod + def _absolute_position_to_relative_position(x): + """ + Shapes: + - x: :math:`[B, C, T, T]` + - ret: :math:`[B, C, T, 2*T-1]` + """ + batch, heads, length, _ = x.size() + # padd along column + x = F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0]) + x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0]) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + @staticmethod + def _attn_proximity_bias(length): + """Produce an attention mask that discourages distant + attention values. + Args: + length (int): an integer scalar. + Returns: + a Tensor with shape :math:`[1, 1, T, T]` + """ + # L + r = torch.arange(length, dtype=torch.float32) + # L x L + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + # scale mask values + diff = -torch.log1p(torch.abs(diff)) + # 1 x 1 x L x L + return diff.unsqueeze(0).unsqueeze(0) + + +class FeedForwardNetwork(nn.Module): + """Feed Forward Inner layers for Transformer. + + Args: + in_channels (int): input tensor channels. + out_channels (int): output tensor channels. + hidden_channels (int): inner layers hidden channels. + kernel_size (int): conv1d filter kernel size. + dropout_p (float, optional): dropout rate. Defaults to 0. 
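+        causal (bool, optional): if True, pad on the left only so that each
+            output step depends on the current and past steps. Defaults to False.
+
+    Example:
+        A minimal sketch with assumed (illustrative) sizes::
+
+            import torch
+
+            ffn = FeedForwardNetwork(in_channels=192, out_channels=192,
+                                     hidden_channels=768, kernel_size=3)
+            x = torch.randn(2, 192, 50)
+            x_mask = torch.ones(2, 1, 50)
+            y = ffn(x, x_mask)  # [2, 192, 50]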
+ """ + + def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dropout_p=0.0, causal=False): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dropout_p = dropout_p + + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size) + self.conv_2 = nn.Conv1d(hidden_channels, out_channels, kernel_size) + self.dropout = nn.Dropout(dropout_p) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + x = torch.relu(x) + x = self.dropout(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask + + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, self._pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, self._pad_shape(padding)) + return x + + @staticmethod + def _pad_shape(padding): + l = padding[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +class RelativePositionTransformer(nn.Module): + """Transformer with Relative Potional Encoding. + https://arxiv.org/abs/1803.02155 + + Args: + in_channels (int): number of channels of the input tensor. + out_chanels (int): number of channels of the output tensor. + hidden_channels (int): model hidden channels. + hidden_channels_ffn (int): hidden channels of FeedForwardNetwork. + num_heads (int): number of attention heads. + num_layers (int): number of transformer layers. + kernel_size (int, optional): kernel size of feed-forward inner layers. Defaults to 1. + dropout_p (float, optional): dropout rate for self-attention and feed-forward inner layers_per_stack. Defaults to 0. + rel_attn_window_size (int, optional): relation attention window size. + If 4, for each time step next and previous 4 time steps are attended. + If default, relative encoding is disabled and it is a regular transformer. + Defaults to None. + input_length (int, optional): input lenght to limit position encoding. Defaults to None. + layer_norm_type (str, optional): type "1" uses torch tensor operations and type "2" uses torch layer_norm + primitive. Use type "2", type "1: is for backward compat. Defaults to "1". 
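+
+    Example:
+        A minimal self-attention encoder sketch; the sizes below are
+        illustrative assumptions, not defaults of any particular model::
+
+            import torch
+
+            encoder = RelativePositionTransformer(
+                in_channels=192, out_channels=192, hidden_channels=192,
+                hidden_channels_ffn=768, num_heads=2, num_layers=6,
+                kernel_size=3, dropout_p=0.1, rel_attn_window_size=4,
+            )
+            x = torch.randn(2, 192, 50)
+            x_mask = torch.ones(2, 1, 50)
+            o = encoder(x, x_mask)  # [2, 192, 50]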
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int, + hidden_channels_ffn: int, + num_heads: int, + num_layers: int, + kernel_size=1, + dropout_p=0.0, + rel_attn_window_size: int = None, + input_length: int = None, + layer_norm_type: str = "1", + ): + super().__init__() + self.hidden_channels = hidden_channels + self.hidden_channels_ffn = hidden_channels_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.kernel_size = kernel_size + self.dropout_p = dropout_p + self.rel_attn_window_size = rel_attn_window_size + + self.dropout = nn.Dropout(dropout_p) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + + for idx in range(self.num_layers): + self.attn_layers.append( + RelativePositionMultiHeadAttention( + hidden_channels if idx != 0 else in_channels, + hidden_channels, + num_heads, + rel_attn_window_size=rel_attn_window_size, + dropout_p=dropout_p, + input_length=input_length, + ) + ) + if layer_norm_type == "1": + self.norm_layers_1.append(LayerNorm(hidden_channels)) + elif layer_norm_type == "2": + self.norm_layers_1.append(LayerNorm2(hidden_channels)) + else: + raise ValueError(" [!] Unknown layer norm type") + + if hidden_channels != out_channels and (idx + 1) == self.num_layers: + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + + self.ffn_layers.append( + FeedForwardNetwork( + hidden_channels, + hidden_channels if (idx + 1) != self.num_layers else out_channels, + hidden_channels_ffn, + kernel_size, + dropout_p=dropout_p, + ) + ) + + if layer_norm_type == "1": + self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels)) + elif layer_norm_type == "2": + self.norm_layers_2.append(LayerNorm2(hidden_channels if (idx + 1) != self.num_layers else out_channels)) + else: + raise ValueError(" [!] Unknown layer norm type") + + def forward(self, x, x_mask): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + """ + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + for i in range(self.num_layers): + x = x * x_mask + y = self.attn_layers[i](x, x, attn_mask) + y = self.dropout(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.dropout(y) + + if (i + 1) == self.num_layers and hasattr(self, "proj"): + x = self.proj(x) + + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x diff --git a/viXTTS/TTS/tts/layers/losses.py b/viXTTS/TTS/tts/layers/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..de5f408c48cf9183dfb14c30a6248a2b300bde4d --- /dev/null +++ b/viXTTS/TTS/tts/layers/losses.py @@ -0,0 +1,889 @@ +import math + +import numpy as np +import torch +from coqpit import Coqpit +from torch import nn +from torch.nn import functional + +from TTS.tts.utils.helpers import sequence_mask +from TTS.tts.utils.ssim import SSIMLoss as _SSIMLoss +from TTS.utils.audio.torch_transforms import TorchSTFT + + +# pylint: disable=abstract-method +# relates https://github.com/pytorch/pytorch/issues/42305 +class L1LossMasked(nn.Module): + def __init__(self, seq_len_norm): + super().__init__() + self.seq_len_norm = seq_len_norm + + def forward(self, x, target, length): + """ + Args: + x: A Variable containing a FloatTensor of size + (batch, max_len, dim) which contains the + unnormalized probability for each class. 
+ target: A Variable containing a LongTensor of size + (batch, max_len, dim) which contains the index of the true + class for each corresponding step. + length: A Variable containing a LongTensor of size (batch,) + which contains the length of each data in a batch. + Shapes: + x: B x T X D + target: B x T x D + length: B + Returns: + loss: An average loss value in range [0, 1] masked by the length. + """ + # mask: (batch, max_len, 1) + target.requires_grad = False + mask = sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2).float() + if self.seq_len_norm: + norm_w = mask / mask.sum(dim=1, keepdim=True) + out_weights = norm_w.div(target.shape[0] * target.shape[2]) + mask = mask.expand_as(x) + loss = functional.l1_loss(x * mask, target * mask, reduction="none") + loss = loss.mul(out_weights.to(loss.device)).sum() + else: + mask = mask.expand_as(x) + loss = functional.l1_loss(x * mask, target * mask, reduction="sum") + loss = loss / mask.sum() + return loss + + +class MSELossMasked(nn.Module): + def __init__(self, seq_len_norm): + super().__init__() + self.seq_len_norm = seq_len_norm + + def forward(self, x, target, length): + """ + Args: + x: A Variable containing a FloatTensor of size + (batch, max_len, dim) which contains the + unnormalized probability for each class. + target: A Variable containing a LongTensor of size + (batch, max_len, dim) which contains the index of the true + class for each corresponding step. + length: A Variable containing a LongTensor of size (batch,) + which contains the length of each data in a batch. + Shapes: + - x: :math:`[B, T, D]` + - target: :math:`[B, T, D]` + - length: :math:`B` + Returns: + loss: An average loss value in range [0, 1] masked by the length. + """ + # mask: (batch, max_len, 1) + target.requires_grad = False + mask = sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2).float() + if self.seq_len_norm: + norm_w = mask / mask.sum(dim=1, keepdim=True) + out_weights = norm_w.div(target.shape[0] * target.shape[2]) + mask = mask.expand_as(x) + loss = functional.mse_loss(x * mask, target * mask, reduction="none") + loss = loss.mul(out_weights.to(loss.device)).sum() + else: + mask = mask.expand_as(x) + loss = functional.mse_loss(x * mask, target * mask, reduction="sum") + loss = loss / mask.sum() + return loss + + +def sample_wise_min_max(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Min-Max normalize tensor through first dimension + Shapes: + - x: :math:`[B, D1, D2]` + - m: :math:`[B, D1, 1]` + """ + maximum = torch.amax(x.masked_fill(~mask, 0), dim=(1, 2), keepdim=True) + minimum = torch.amin(x.masked_fill(~mask, np.inf), dim=(1, 2), keepdim=True) + return (x - minimum) / (maximum - minimum + 1e-8) + + +class SSIMLoss(torch.nn.Module): + """SSIM loss as (1 - SSIM) + SSIM is explained here https://en.wikipedia.org/wiki/Structural_similarity + """ + + def __init__(self): + super().__init__() + self.loss_func = _SSIMLoss() + + def forward(self, y_hat, y, length): + """ + Args: + y_hat (tensor): model prediction values. + y (tensor): target values. + length (tensor): length of each sample in a batch for masking. + + Shapes: + y_hat: B x T X D + y: B x T x D + length: B + + Returns: + loss: An average loss value in range [0, 1] masked by the length. 
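+
+        Example:
+            A minimal sketch with assumed shapes::
+
+                import torch
+
+                criterion = SSIMLoss()
+                y_hat = torch.randn(2, 100, 80)      # predicted frames [B, T, D]
+                y = torch.randn(2, 100, 80)          # target frames
+                lengths = torch.tensor([100, 87])
+                loss = criterion(y_hat, y, lengths)  # scalar in [0, 1]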
+ """ + mask = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2) + y_norm = sample_wise_min_max(y, mask) + y_hat_norm = sample_wise_min_max(y_hat, mask) + ssim_loss = self.loss_func((y_norm * mask).unsqueeze(1), (y_hat_norm * mask).unsqueeze(1)) + + if ssim_loss.item() > 1.0: + print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 1.0") + ssim_loss = torch.tensor(1.0, device=ssim_loss.device) + + if ssim_loss.item() < 0.0: + print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0") + ssim_loss = torch.tensor(0.0, device=ssim_loss.device) + + return ssim_loss + + +class AttentionEntropyLoss(nn.Module): + # pylint: disable=R0201 + def forward(self, align): + """ + Forces attention to be more decisive by penalizing + soft attention weights + """ + entropy = torch.distributions.Categorical(probs=align).entropy() + loss = (entropy / np.log(align.shape[1])).mean() + return loss + + +class BCELossMasked(nn.Module): + """BCE loss with masking. + + Used mainly for stopnet in autoregressive models. + + Args: + pos_weight (float): weight for positive samples. If set < 1, penalize early stopping. Defaults to None. + """ + + def __init__(self, pos_weight: float = None): + super().__init__() + self.register_buffer("pos_weight", torch.tensor([pos_weight])) + + def forward(self, x, target, length): + """ + Args: + x: A Variable containing a FloatTensor of size + (batch, max_len) which contains the + unnormalized probability for each class. + target: A Variable containing a LongTensor of size + (batch, max_len) which contains the index of the true + class for each corresponding step. + length: A Variable containing a LongTensor of size (batch,) + which contains the length of each data in a batch. + Shapes: + x: B x T + target: B x T + length: B + Returns: + loss: An average loss value in range [0, 1] masked by the length. + """ + target.requires_grad = False + if length is not None: + # mask: (batch, max_len, 1) + mask = sequence_mask(sequence_length=length, max_len=target.size(1)) + num_items = mask.sum() + loss = functional.binary_cross_entropy_with_logits( + x.masked_select(mask), + target.masked_select(mask), + pos_weight=self.pos_weight.to(x.device), + reduction="sum", + ) + else: + loss = functional.binary_cross_entropy_with_logits( + x, target, pos_weight=self.pos_weight.to(x.device), reduction="sum" + ) + num_items = torch.numel(x) + loss = loss / num_items + return loss + + +class DifferentialSpectralLoss(nn.Module): + """Differential Spectral Loss + https://arxiv.org/ftp/arxiv/papers/1909/1909.10302.pdf""" + + def __init__(self, loss_func): + super().__init__() + self.loss_func = loss_func + + def forward(self, x, target, length=None): + """ + Shapes: + x: B x T + target: B x T + length: B + Returns: + loss: An average loss value in range [0, 1] masked by the length. 
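+
+        Example:
+            A minimal sketch; wrapping ``L1LossMasked`` here is an illustrative
+            choice, not a requirement::
+
+                import torch
+
+                criterion = DifferentialSpectralLoss(loss_func=L1LossMasked(seq_len_norm=False))
+                x = torch.randn(2, 100, 80)           # predicted frames [B, T, D]
+                target = torch.randn(2, 100, 80)
+                lengths = torch.tensor([100, 90])
+                loss = criterion(x, target, lengths)  # loss over frame-to-frame deltas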
+ """ + x_diff = x[:, 1:] - x[:, :-1] + target_diff = target[:, 1:] - target[:, :-1] + if length is None: + return self.loss_func(x_diff, target_diff) + return self.loss_func(x_diff, target_diff, length - 1) + + +class GuidedAttentionLoss(torch.nn.Module): + def __init__(self, sigma=0.4): + super().__init__() + self.sigma = sigma + + def _make_ga_masks(self, ilens, olens): + B = len(ilens) + max_ilen = max(ilens) + max_olen = max(olens) + ga_masks = torch.zeros((B, max_olen, max_ilen)) + for idx, (ilen, olen) in enumerate(zip(ilens, olens)): + ga_masks[idx, :olen, :ilen] = self._make_ga_mask(ilen, olen, self.sigma) + return ga_masks + + def forward(self, att_ws, ilens, olens): + ga_masks = self._make_ga_masks(ilens, olens).to(att_ws.device) + seq_masks = self._make_masks(ilens, olens).to(att_ws.device) + losses = ga_masks * att_ws + loss = torch.mean(losses.masked_select(seq_masks)) + return loss + + @staticmethod + def _make_ga_mask(ilen, olen, sigma): + grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen)) + grid_x, grid_y = grid_x.float(), grid_y.float() + return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2))) + + @staticmethod + def _make_masks(ilens, olens): + in_masks = sequence_mask(ilens) + out_masks = sequence_mask(olens) + return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2) + + +class Huber(nn.Module): + # pylint: disable=R0201 + def forward(self, x, y, length=None): + """ + Shapes: + x: B x T + y: B x T + length: B + """ + mask = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2).float() + return torch.nn.functional.smooth_l1_loss(x * mask, y * mask, reduction="sum") / mask.sum() + + +class ForwardSumLoss(nn.Module): + def __init__(self, blank_logprob=-1): + super().__init__() + self.log_softmax = torch.nn.LogSoftmax(dim=3) + self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True) + self.blank_logprob = blank_logprob + + def forward(self, attn_logprob, in_lens, out_lens): + key_lens = in_lens + query_lens = out_lens + attn_logprob_padded = torch.nn.functional.pad(input=attn_logprob, pad=(1, 0), value=self.blank_logprob) + + total_loss = 0.0 + for bid in range(attn_logprob.shape[0]): + target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0) + curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[: query_lens[bid], :, : key_lens[bid] + 1] + + curr_logprob = self.log_softmax(curr_logprob[None])[0] + loss = self.ctc_loss( + curr_logprob, + target_seq, + input_lengths=query_lens[bid : bid + 1], + target_lengths=key_lens[bid : bid + 1], + ) + total_loss = total_loss + loss + + total_loss = total_loss / attn_logprob.shape[0] + return total_loss + + +######################## +# MODEL LOSS LAYERS +######################## + + +class TacotronLoss(torch.nn.Module): + """Collection of Tacotron set-up based on provided config.""" + + def __init__(self, c, ga_sigma=0.4): + super().__init__() + self.stopnet_pos_weight = c.stopnet_pos_weight + self.use_capacitron_vae = c.use_capacitron_vae + if self.use_capacitron_vae: + self.capacitron_capacity = c.capacitron_vae.capacitron_capacity + self.capacitron_vae_loss_alpha = c.capacitron_vae.capacitron_VAE_loss_alpha + self.ga_alpha = c.ga_alpha + self.decoder_diff_spec_alpha = c.decoder_diff_spec_alpha + self.postnet_diff_spec_alpha = c.postnet_diff_spec_alpha + self.decoder_alpha = c.decoder_loss_alpha + self.postnet_alpha = c.postnet_loss_alpha + self.decoder_ssim_alpha = c.decoder_ssim_alpha + self.postnet_ssim_alpha = c.postnet_ssim_alpha + self.config = c 
+ + # postnet and decoder loss + if c.loss_masking: + self.criterion = L1LossMasked(c.seq_len_norm) if c.model in ["Tacotron"] else MSELossMasked(c.seq_len_norm) + else: + self.criterion = nn.L1Loss() if c.model in ["Tacotron"] else nn.MSELoss() + # guided attention loss + if c.ga_alpha > 0: + self.criterion_ga = GuidedAttentionLoss(sigma=ga_sigma) + # differential spectral loss + if c.postnet_diff_spec_alpha > 0 or c.decoder_diff_spec_alpha > 0: + self.criterion_diff_spec = DifferentialSpectralLoss(loss_func=self.criterion) + # ssim loss + if c.postnet_ssim_alpha > 0 or c.decoder_ssim_alpha > 0: + self.criterion_ssim = SSIMLoss() + # stopnet loss + # pylint: disable=not-callable + self.criterion_st = BCELossMasked(pos_weight=torch.tensor(self.stopnet_pos_weight)) if c.stopnet else None + + # For dev pruposes only + self.criterion_capacitron_reconstruction_loss = nn.L1Loss(reduction="sum") + + def forward( + self, + postnet_output, + decoder_output, + mel_input, + linear_input, + stopnet_output, + stopnet_target, + stop_target_length, + capacitron_vae_outputs, + output_lens, + decoder_b_output, + alignments, + alignment_lens, + alignments_backwards, + input_lens, + ): + # decoder outputs linear or mel spectrograms for Tacotron and Tacotron2 + # the target should be set acccordingly + postnet_target = linear_input if self.config.model.lower() in ["tacotron"] else mel_input + + return_dict = {} + # remove lengths if no masking is applied + if not self.config.loss_masking: + output_lens = None + # decoder and postnet losses + if self.config.loss_masking: + if self.decoder_alpha > 0: + decoder_loss = self.criterion(decoder_output, mel_input, output_lens) + if self.postnet_alpha > 0: + postnet_loss = self.criterion(postnet_output, postnet_target, output_lens) + else: + if self.decoder_alpha > 0: + decoder_loss = self.criterion(decoder_output, mel_input) + if self.postnet_alpha > 0: + postnet_loss = self.criterion(postnet_output, postnet_target) + loss = self.decoder_alpha * decoder_loss + self.postnet_alpha * postnet_loss + return_dict["decoder_loss"] = decoder_loss + return_dict["postnet_loss"] = postnet_loss + + if self.use_capacitron_vae: + # extract capacitron vae infos + posterior_distribution, prior_distribution, beta = capacitron_vae_outputs + + # KL divergence term between the posterior and the prior + kl_term = torch.mean(torch.distributions.kl_divergence(posterior_distribution, prior_distribution)) + + # Limit the mutual information between the data and latent space by the variational capacity limit + kl_capacity = kl_term - self.capacitron_capacity + + # pass beta through softplus to keep it positive + beta = torch.nn.functional.softplus(beta)[0] + + # This is the term going to the main ADAM optimiser, we detach beta because + # beta is optimised by a separate, SGD optimiser below + capacitron_vae_loss = beta.detach() * kl_capacity + + # normalize the capacitron_vae_loss as in L1Loss or MSELoss. + # After this, both the standard loss and capacitron_vae_loss will be in the same scale. + # For this reason we don't need use L1Loss and MSELoss in "sum" reduction mode. + # Note: the batch is not considered because the L1Loss was calculated in "sum" mode + # divided by the batch size, So not dividing the capacitron_vae_loss by B is legitimate. 
+ + # get B T D dimension from input + B, T, D = mel_input.size() + # normalize + if self.config.loss_masking: + # if mask loss get T using the mask + T = output_lens.sum() / B + + # Only for dev purposes to be able to compare the reconstruction loss with the values in the + # original Capacitron paper + return_dict["capaciton_reconstruction_loss"] = ( + self.criterion_capacitron_reconstruction_loss(decoder_output, mel_input) / decoder_output.size(0) + ) + kl_capacity + + capacitron_vae_loss = capacitron_vae_loss / (T * D) + capacitron_vae_loss = capacitron_vae_loss * self.capacitron_vae_loss_alpha + + # This is the term to purely optimise beta and to pass into the SGD optimizer + beta_loss = torch.negative(beta) * kl_capacity.detach() + + loss += capacitron_vae_loss + + return_dict["capacitron_vae_loss"] = capacitron_vae_loss + return_dict["capacitron_vae_beta_loss"] = beta_loss + return_dict["capacitron_vae_kl_term"] = kl_term + return_dict["capacitron_beta"] = beta + + stop_loss = ( + self.criterion_st(stopnet_output, stopnet_target, stop_target_length) + if self.config.stopnet + else torch.zeros(1) + ) + loss += stop_loss + return_dict["stopnet_loss"] = stop_loss + + # backward decoder loss (if enabled) + if self.config.bidirectional_decoder: + if self.config.loss_masking: + decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1,)), mel_input, output_lens) + else: + decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1,)), mel_input) + decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_b_output, dims=(1,)), decoder_output) + loss += self.decoder_alpha * (decoder_b_loss + decoder_c_loss) + return_dict["decoder_b_loss"] = decoder_b_loss + return_dict["decoder_c_loss"] = decoder_c_loss + + # double decoder consistency loss (if enabled) + if self.config.double_decoder_consistency: + if self.config.loss_masking: + decoder_b_loss = self.criterion(decoder_b_output, mel_input, output_lens) + else: + decoder_b_loss = self.criterion(decoder_b_output, mel_input) + # decoder_c_loss = torch.nn.functional.l1_loss(decoder_b_output, decoder_output) + attention_c_loss = torch.nn.functional.l1_loss(alignments, alignments_backwards) + loss += self.decoder_alpha * (decoder_b_loss + attention_c_loss) + return_dict["decoder_coarse_loss"] = decoder_b_loss + return_dict["decoder_ddc_loss"] = attention_c_loss + + # guided attention loss (if enabled) + if self.config.ga_alpha > 0: + ga_loss = self.criterion_ga(alignments, input_lens, alignment_lens) + loss += ga_loss * self.ga_alpha + return_dict["ga_loss"] = ga_loss + + # decoder differential spectral loss + if self.config.decoder_diff_spec_alpha > 0: + decoder_diff_spec_loss = self.criterion_diff_spec(decoder_output, mel_input, output_lens) + loss += decoder_diff_spec_loss * self.decoder_diff_spec_alpha + return_dict["decoder_diff_spec_loss"] = decoder_diff_spec_loss + + # postnet differential spectral loss + if self.config.postnet_diff_spec_alpha > 0: + postnet_diff_spec_loss = self.criterion_diff_spec(postnet_output, postnet_target, output_lens) + loss += postnet_diff_spec_loss * self.postnet_diff_spec_alpha + return_dict["postnet_diff_spec_loss"] = postnet_diff_spec_loss + + # decoder ssim loss + if self.config.decoder_ssim_alpha > 0: + decoder_ssim_loss = self.criterion_ssim(decoder_output, mel_input, output_lens) + loss += decoder_ssim_loss * self.postnet_ssim_alpha + return_dict["decoder_ssim_loss"] = decoder_ssim_loss + + # postnet ssim loss + if self.config.postnet_ssim_alpha > 0: + postnet_ssim_loss = 
self.criterion_ssim(postnet_output, postnet_target, output_lens) + loss += postnet_ssim_loss * self.postnet_ssim_alpha + return_dict["postnet_ssim_loss"] = postnet_ssim_loss + + return_dict["loss"] = loss + return return_dict + + +class GlowTTSLoss(torch.nn.Module): + def __init__(self): + super().__init__() + self.constant_factor = 0.5 * math.log(2 * math.pi) + + def forward(self, z, means, scales, log_det, y_lengths, o_dur_log, o_attn_dur, x_lengths): + return_dict = {} + # flow loss - neg log likelihood + pz = torch.sum(scales) + 0.5 * torch.sum(torch.exp(-2 * scales) * (z - means) ** 2) + log_mle = self.constant_factor + (pz - torch.sum(log_det)) / (torch.sum(y_lengths) * z.shape[2]) + # duration loss - MSE + loss_dur = torch.sum((o_dur_log - o_attn_dur) ** 2) / torch.sum(x_lengths) + # duration loss - huber loss + # loss_dur = torch.nn.functional.smooth_l1_loss(o_dur_log, o_attn_dur, reduction="sum") / torch.sum(x_lengths) + return_dict["loss"] = log_mle + loss_dur + return_dict["log_mle"] = log_mle + return_dict["loss_dur"] = loss_dur + + # check if any loss is NaN + for key, loss in return_dict.items(): + if torch.isnan(loss): + raise RuntimeError(f" [!] NaN loss with {key}.") + return return_dict + + +def mse_loss_custom(x, y): + """MSE loss using the torch back-end without reduction. + It uses less VRAM than the raw code""" + expanded_x, expanded_y = torch.broadcast_tensors(x, y) + return torch._C._nn.mse_loss(expanded_x, expanded_y, 0) # pylint: disable=protected-access, c-extension-no-member + + +class MDNLoss(nn.Module): + """Mixture of Density Network Loss as described in https://arxiv.org/pdf/2003.01950.pdf.""" + + def forward(self, logp, text_lengths, mel_lengths): # pylint: disable=no-self-use + """ + Shapes: + mu: [B, D, T] + log_sigma: [B, D, T] + mel_spec: [B, D, T] + """ + B, T_seq, T_mel = logp.shape + log_alpha = logp.new_ones(B, T_seq, T_mel) * (-1e4) + log_alpha[:, 0, 0] = logp[:, 0, 0] + for t in range(1, T_mel): + prev_step = torch.cat( + [log_alpha[:, :, t - 1 : t], functional.pad(log_alpha[:, :, t - 1 : t], (0, 0, 1, -1), value=-1e4)], + dim=-1, + ) + log_alpha[:, :, t] = torch.logsumexp(prev_step + 1e-4, dim=-1) + logp[:, :, t] + alpha_last = log_alpha[torch.arange(B), text_lengths - 1, mel_lengths - 1] + mdn_loss = -alpha_last.mean() / T_seq + return mdn_loss # , log_prob_matrix + + +class AlignTTSLoss(nn.Module): + """Modified AlignTTS Loss. + Computes + - L1 and SSIM losses from output spectrograms. + - Huber loss for duration predictor. + - MDNLoss for Mixture of Density Network. + + All loss values are aggregated by a weighted sum of the alpha values. + + Args: + c (dict): TTS model configuration. 
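+
+    Note:
+        The ``phase`` argument of ``forward`` selects the active terms:
+        0 trains only the MDN alignment loss, 1 only the spectrogram
+        (MSE + SSIM) losses, 2 combines the MDN and spectrogram losses,
+        3 trains only the duration predictor, and any later phase enables
+        all terms at once.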
+    """
+
+    def __init__(self, c):
+        super().__init__()
+        self.mdn_loss = MDNLoss()
+        self.spec_loss = MSELossMasked(False)
+        self.ssim = SSIMLoss()
+        self.dur_loss = MSELossMasked(False)
+
+        self.ssim_alpha = c.ssim_alpha
+        self.dur_loss_alpha = c.dur_loss_alpha
+        self.spec_loss_alpha = c.spec_loss_alpha
+        self.mdn_alpha = c.mdn_alpha
+
+    def forward(
+        self, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens, phase
+    ):
+        # ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(step)
+        spec_loss, ssim_loss, dur_loss, mdn_loss = 0, 0, 0, 0
+        if phase == 0:
+            mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens)
+        elif phase == 1:
+            spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
+            ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
+        elif phase == 2:
+            mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens)
+            spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
+            ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
+        elif phase == 3:
+            dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens)
+        else:
+            mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens)
+            spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
+            ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
+            dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens)
+        loss = (
+            self.spec_loss_alpha * spec_loss
+            + self.ssim_alpha * ssim_loss
+            + self.dur_loss_alpha * dur_loss
+            + self.mdn_alpha * mdn_loss
+        )
+        return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss}
+
+
+class VitsGeneratorLoss(nn.Module):
+    def __init__(self, c: Coqpit):
+        super().__init__()
+        self.kl_loss_alpha = c.kl_loss_alpha
+        self.gen_loss_alpha = c.gen_loss_alpha
+        self.feat_loss_alpha = c.feat_loss_alpha
+        self.dur_loss_alpha = c.dur_loss_alpha
+        self.mel_loss_alpha = c.mel_loss_alpha
+        self.spk_encoder_loss_alpha = c.speaker_encoder_loss_alpha
+        self.stft = TorchSTFT(
+            c.audio.fft_size,
+            c.audio.hop_length,
+            c.audio.win_length,
+            sample_rate=c.audio.sample_rate,
+            mel_fmin=c.audio.mel_fmin,
+            mel_fmax=c.audio.mel_fmax,
+            n_mels=c.audio.num_mels,
+            use_mel=True,
+            do_amp_to_db=True,
+        )
+
+    @staticmethod
+    def feature_loss(feats_real, feats_generated):
+        loss = 0
+        for dr, dg in zip(feats_real, feats_generated):
+            for rl, gl in zip(dr, dg):
+                rl = rl.float().detach()
+                gl = gl.float()
+                loss += torch.mean(torch.abs(rl - gl))
+        return loss * 2
+
+    @staticmethod
+    def generator_loss(scores_fake):
+        loss = 0
+        gen_losses = []
+        for dg in scores_fake:
+            dg = dg.float()
+            l = torch.mean((1 - dg) ** 2)
+            gen_losses.append(l)
+            loss += l
+
+        return loss, gen_losses
+
+    @staticmethod
+    def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
+        """
+        z_p, logs_q: [b, h, t_t]
+        m_p, logs_p: [b, h, t_t]
+        """
+        z_p = z_p.float()
+        logs_q = logs_q.float()
+        m_p = m_p.float()
+        logs_p = logs_p.float()
+        z_mask = z_mask.float()
+
+        kl = logs_p - logs_q - 0.5
+        kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
+        kl = torch.sum(kl * z_mask)
+        l = kl / torch.sum(z_mask)
+        return l
+
+    @staticmethod
+    def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
+        return -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
+
+    def forward(
+        self,
+        mel_slice,
+        mel_slice_hat,
+        z_p,
+        logs_q,
+        m_p,
+        logs_p,
z_len, + scores_disc_fake, + feats_disc_fake, + feats_disc_real, + loss_duration, + use_speaker_encoder_as_loss=False, + gt_spk_emb=None, + syn_spk_emb=None, + ): + """ + Shapes: + - mel_slice : :math:`[B, 1, T]` + - mel_slice_hat: :math:`[B, 1, T]` + - z_p: :math:`[B, C, T]` + - logs_q: :math:`[B, C, T]` + - m_p: :math:`[B, C, T]` + - logs_p: :math:`[B, C, T]` + - z_len: :math:`[B]` + - scores_disc_fake[i]: :math:`[B, C]` + - feats_disc_fake[i][j]: :math:`[B, C, T', P]` + - feats_disc_real[i][j]: :math:`[B, C, T', P]` + """ + loss = 0.0 + return_dict = {} + z_mask = sequence_mask(z_len).float() + # compute losses + loss_kl = ( + self.kl_loss(z_p=z_p, logs_q=logs_q, m_p=m_p, logs_p=logs_p, z_mask=z_mask.unsqueeze(1)) + * self.kl_loss_alpha + ) + loss_feat = ( + self.feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake) * self.feat_loss_alpha + ) + loss_gen = self.generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha + loss_mel = torch.nn.functional.l1_loss(mel_slice, mel_slice_hat) * self.mel_loss_alpha + loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha + loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration + + if use_speaker_encoder_as_loss: + loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha + loss = loss + loss_se + return_dict["loss_spk_encoder"] = loss_se + # pass losses to the dict + return_dict["loss_gen"] = loss_gen + return_dict["loss_kl"] = loss_kl + return_dict["loss_feat"] = loss_feat + return_dict["loss_mel"] = loss_mel + return_dict["loss_duration"] = loss_duration + return_dict["loss"] = loss + return return_dict + + +class VitsDiscriminatorLoss(nn.Module): + def __init__(self, c: Coqpit): + super().__init__() + self.disc_loss_alpha = c.disc_loss_alpha + + @staticmethod + def discriminator_loss(scores_real, scores_fake): + loss = 0 + real_losses = [] + fake_losses = [] + for dr, dg in zip(scores_real, scores_fake): + dr = dr.float() + dg = dg.float() + real_loss = torch.mean((1 - dr) ** 2) + fake_loss = torch.mean(dg**2) + loss += real_loss + fake_loss + real_losses.append(real_loss.item()) + fake_losses.append(fake_loss.item()) + return loss, real_losses, fake_losses + + def forward(self, scores_disc_real, scores_disc_fake): + loss = 0.0 + return_dict = {} + loss_disc, loss_disc_real, _ = self.discriminator_loss( + scores_real=scores_disc_real, scores_fake=scores_disc_fake + ) + return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha + loss = loss + return_dict["loss_disc"] + return_dict["loss"] = loss + + for i, ldr in enumerate(loss_disc_real): + return_dict[f"loss_disc_real_{i}"] = ldr + return return_dict + + +class ForwardTTSLoss(nn.Module): + """Generic configurable ForwardTTS loss.""" + + def __init__(self, c): + super().__init__() + if c.spec_loss_type == "mse": + self.spec_loss = MSELossMasked(False) + elif c.spec_loss_type == "l1": + self.spec_loss = L1LossMasked(False) + else: + raise ValueError(" [!] Unknown spec_loss_type {}".format(c.spec_loss_type)) + + if c.duration_loss_type == "mse": + self.dur_loss = MSELossMasked(False) + elif c.duration_loss_type == "l1": + self.dur_loss = L1LossMasked(False) + elif c.duration_loss_type == "huber": + self.dur_loss = Huber() + else: + raise ValueError(" [!] 
Unknown duration_loss_type {}".format(c.duration_loss_type)) + + if c.model_args.use_aligner: + self.aligner_loss = ForwardSumLoss() + self.aligner_loss_alpha = c.aligner_loss_alpha + + if c.model_args.use_pitch: + self.pitch_loss = MSELossMasked(False) + self.pitch_loss_alpha = c.pitch_loss_alpha + + if c.model_args.use_energy: + self.energy_loss = MSELossMasked(False) + self.energy_loss_alpha = c.energy_loss_alpha + + if c.use_ssim_loss: + self.ssim = SSIMLoss() if c.use_ssim_loss else None + self.ssim_loss_alpha = c.ssim_loss_alpha + + self.spec_loss_alpha = c.spec_loss_alpha + self.dur_loss_alpha = c.dur_loss_alpha + self.binary_alignment_loss_alpha = c.binary_align_loss_alpha + + @staticmethod + def _binary_alignment_loss(alignment_hard, alignment_soft): + """Binary loss that forces soft alignments to match the hard alignments as + explained in `https://arxiv.org/pdf/2108.10447.pdf`. + """ + log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum() + return -log_sum / alignment_hard.sum() + + def forward( + self, + decoder_output, + decoder_target, + decoder_output_lens, + dur_output, + dur_target, + pitch_output, + pitch_target, + energy_output, + energy_target, + input_lens, + alignment_logprob=None, + alignment_hard=None, + alignment_soft=None, + binary_loss_weight=None, + ): + loss = 0 + return_dict = {} + if hasattr(self, "ssim_loss") and self.ssim_loss_alpha > 0: + ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) + loss = loss + self.ssim_loss_alpha * ssim_loss + return_dict["loss_ssim"] = self.ssim_loss_alpha * ssim_loss + + if self.spec_loss_alpha > 0: + spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens) + loss = loss + self.spec_loss_alpha * spec_loss + return_dict["loss_spec"] = self.spec_loss_alpha * spec_loss + + if self.dur_loss_alpha > 0: + log_dur_tgt = torch.log(dur_target.float() + 1) + dur_loss = self.dur_loss(dur_output[:, :, None], log_dur_tgt[:, :, None], input_lens) + loss = loss + self.dur_loss_alpha * dur_loss + return_dict["loss_dur"] = self.dur_loss_alpha * dur_loss + + if hasattr(self, "pitch_loss") and self.pitch_loss_alpha > 0: + pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens) + loss = loss + self.pitch_loss_alpha * pitch_loss + return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss + + if hasattr(self, "energy_loss") and self.energy_loss_alpha > 0: + energy_loss = self.energy_loss(energy_output.transpose(1, 2), energy_target.transpose(1, 2), input_lens) + loss = loss + self.energy_loss_alpha * energy_loss + return_dict["loss_energy"] = self.energy_loss_alpha * energy_loss + + if hasattr(self, "aligner_loss") and self.aligner_loss_alpha > 0: + aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens) + loss = loss + self.aligner_loss_alpha * aligner_loss + return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss + + if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None: + binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft) + loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss + if binary_loss_weight: + return_dict["loss_binary_alignment"] = ( + self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight + ) + else: + return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss + + return_dict["loss"] = loss + return return_dict diff --git 
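The masked losses used throughout this file follow one pattern: build a boolean mask from the sequence lengths, zero out padded frames, normalise by the number of valid elements, and combine the individual terms with their alpha weights. The following is a minimal standalone sketch of that pattern with plain torch and illustrative shapes, not the library classes themselves.

```python
import torch


def lengths_to_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """[B, T] boolean mask, True for valid (non-padded) frames."""
    return torch.arange(max_len, device=lengths.device)[None, :] < lengths[:, None]


def masked_l1(pred: torch.Tensor, target: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """L1 over valid frames only; pred/target: [B, T, C], lengths: [B]."""
    mask = lengths_to_mask(lengths, pred.shape[1]).unsqueeze(-1)  # [B, T, 1]
    diff = (pred - target).abs() * mask                           # zero out padded frames
    return diff.sum() / (mask.sum() * pred.shape[-1])             # mean over valid elements


if __name__ == "__main__":
    pred, target = torch.randn(2, 6, 80), torch.randn(2, 6, 80)
    lengths = torch.tensor([6, 4])
    spec_loss = masked_l1(pred, target, lengths)
    # individual terms are then combined as an alpha-weighted sum into one "loss" entry
    total = 1.0 * spec_loss
    print(total.item())
```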
a/viXTTS/TTS/tts/layers/overflow/__init__.py b/viXTTS/TTS/tts/layers/overflow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/tts/layers/overflow/common_layers.py b/viXTTS/TTS/tts/layers/overflow/common_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..b036dd1bda92fb709f0cce796cf5a668a1c081df --- /dev/null +++ b/viXTTS/TTS/tts/layers/overflow/common_layers.py @@ -0,0 +1,323 @@ +from typing import List, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from tqdm.auto import tqdm + +from TTS.tts.layers.tacotron.common_layers import Linear +from TTS.tts.layers.tacotron.tacotron2 import ConvBNBlock + + +class Encoder(nn.Module): + r"""Neural HMM Encoder + + Same as Tacotron 2 encoder but increases the input length by states per phone + + Args: + num_chars (int): Number of characters in the input. + state_per_phone (int): Number of states per phone. + in_out_channels (int): number of input and output channels. + n_convolutions (int): number of convolutional layers. + """ + + def __init__(self, num_chars, state_per_phone, in_out_channels=512, n_convolutions=3): + super().__init__() + + self.state_per_phone = state_per_phone + self.in_out_channels = in_out_channels + + self.emb = nn.Embedding(num_chars, in_out_channels) + self.convolutions = nn.ModuleList() + for _ in range(n_convolutions): + self.convolutions.append(ConvBNBlock(in_out_channels, in_out_channels, 5, "relu")) + self.lstm = nn.LSTM( + in_out_channels, + int(in_out_channels / 2) * state_per_phone, + num_layers=1, + batch_first=True, + bias=True, + bidirectional=True, + ) + self.rnn_state = None + + def forward(self, x: torch.FloatTensor, x_len: torch.LongTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]: + """Forward pass to the encoder. + + Args: + x (torch.FloatTensor): input text indices. + - shape: :math:`(b, T_{in})` + x_len (torch.LongTensor): input text lengths. + - shape: :math:`(b,)` + + Returns: + Tuple[torch.FloatTensor, torch.LongTensor]: encoder outputs and output lengths. + -shape: :math:`((b, T_{in} * states_per_phone, in_out_channels), (b,))` + """ + b, T = x.shape + o = self.emb(x).transpose(1, 2) + for layer in self.convolutions: + o = layer(o) + o = o.transpose(1, 2) + o = nn.utils.rnn.pack_padded_sequence(o, x_len.cpu(), batch_first=True) + self.lstm.flatten_parameters() + o, _ = self.lstm(o) + o, _ = nn.utils.rnn.pad_packed_sequence(o, batch_first=True) + o = o.reshape(b, T * self.state_per_phone, self.in_out_channels) + x_len = x_len * self.state_per_phone + return o, x_len + + def inference(self, x, x_len): + """Inference to the encoder. + + Args: + x (torch.FloatTensor): input text indices. + - shape: :math:`(b, T_{in})` + x_len (torch.LongTensor): input text lengths. + - shape: :math:`(b,)` + + Returns: + Tuple[torch.FloatTensor, torch.LongTensor]: encoder outputs and output lengths. + -shape: :math:`((b, T_{in} * states_per_phone, in_out_channels), (b,))` + """ + b, T = x.shape + o = self.emb(x).transpose(1, 2) + for layer in self.convolutions: + o = layer(o) + o = o.transpose(1, 2) + # self.lstm.flatten_parameters() + o, _ = self.lstm(o) + o = o.reshape(b, T * self.state_per_phone, self.in_out_channels) + x_len = x_len * self.state_per_phone + return o, x_len + + +class ParameterModel(nn.Module): + r"""Main neural network of the outputnet + + Note: Do not put dropout layers here, the model will not converge. 
+ + Args: + outputnet_size (List[int]): the architecture of the parameter model + input_size (int): size of input for the first layer + output_size (int): size of output i.e size of the feature dim + frame_channels (int): feature dim to set the flat start bias + flat_start_params (dict): flat start parameters to set the bias + """ + + def __init__( + self, + outputnet_size: List[int], + input_size: int, + output_size: int, + frame_channels: int, + flat_start_params: dict, + ): + super().__init__() + self.frame_channels = frame_channels + + self.layers = nn.ModuleList( + [Linear(inp, out) for inp, out in zip([input_size] + outputnet_size[:-1], outputnet_size)] + ) + self.last_layer = nn.Linear(outputnet_size[-1], output_size) + self.flat_start_output_layer( + flat_start_params["mean"], flat_start_params["std"], flat_start_params["transition_p"] + ) + + def flat_start_output_layer(self, mean, std, transition_p): + self.last_layer.weight.data.zero_() + self.last_layer.bias.data[0 : self.frame_channels] = mean + self.last_layer.bias.data[self.frame_channels : 2 * self.frame_channels] = OverflowUtils.inverse_softplus(std) + self.last_layer.bias.data[2 * self.frame_channels :] = OverflowUtils.inverse_sigmod(transition_p) + + def forward(self, x): + for layer in self.layers: + x = F.relu(layer(x)) + x = self.last_layer(x) + return x + + +class Outputnet(nn.Module): + r""" + This network takes current state and previous observed values as input + and returns its parameters, mean, standard deviation and probability + of transition to the next state + """ + + def __init__( + self, + encoder_dim: int, + memory_rnn_dim: int, + frame_channels: int, + outputnet_size: List[int], + flat_start_params: dict, + std_floor: float = 1e-2, + ): + super().__init__() + + self.frame_channels = frame_channels + self.flat_start_params = flat_start_params + self.std_floor = std_floor + + input_size = memory_rnn_dim + encoder_dim + output_size = 2 * frame_channels + 1 + + self.parametermodel = ParameterModel( + outputnet_size=outputnet_size, + input_size=input_size, + output_size=output_size, + flat_start_params=flat_start_params, + frame_channels=frame_channels, + ) + + def forward(self, ar_mels, inputs): + r"""Inputs observation and returns the means, stds and transition probability for the current state + + Args: + ar_mel_inputs (torch.FloatTensor): shape (batch, prenet_dim) + states (torch.FloatTensor): (batch, hidden_states, hidden_state_dim) + + Returns: + means: means for the emission observation for each feature + - shape: (B, hidden_states, feature_size) + stds: standard deviations for the emission observation for each feature + - shape: (batch, hidden_states, feature_size) + transition_vectors: transition vector for the current hidden state + - shape: (batch, hidden_states) + """ + batch_size, prenet_dim = ar_mels.shape[0], ar_mels.shape[1] + N = inputs.shape[1] + + ar_mels = ar_mels.unsqueeze(1).expand(batch_size, N, prenet_dim) + ar_mels = torch.cat((ar_mels, inputs), dim=2) + ar_mels = self.parametermodel(ar_mels) + + mean, std, transition_vector = ( + ar_mels[:, :, 0 : self.frame_channels], + ar_mels[:, :, self.frame_channels : 2 * self.frame_channels], + ar_mels[:, :, 2 * self.frame_channels :].squeeze(2), + ) + std = F.softplus(std) + std = self._floor_std(std) + return mean, std, transition_vector + + def _floor_std(self, std): + r""" + It clamps the standard deviation to not to go below some level + This removes the problem when the model tries to cheat for higher likelihoods by converting + one of 
the gaussians to a point mass. + + Args: + std (float Tensor): tensor containing the standard deviation to be + """ + original_tensor = std.clone().detach() + std = torch.clamp(std, min=self.std_floor) + if torch.any(original_tensor != std): + print( + "[*] Standard deviation was floored! The model is preventing overfitting, nothing serious to worry about" + ) + return std + + +class OverflowUtils: + @staticmethod + def get_data_parameters_for_flat_start( + data_loader: torch.utils.data.DataLoader, out_channels: int, states_per_phone: int + ): + """Generates data parameters for flat starting the HMM. + + Args: + data_loader (torch.utils.data.Dataloader): _description_ + out_channels (int): mel spectrogram channels + states_per_phone (_type_): HMM states per phone + """ + + # State related information for transition_p + total_state_len = 0 + total_mel_len = 0 + + # Useful for data mean an std + total_mel_sum = 0 + total_mel_sq_sum = 0 + + for batch in tqdm(data_loader, leave=False): + text_lengths = batch["token_id_lengths"] + mels = batch["mel"] + mel_lengths = batch["mel_lengths"] + + total_state_len += torch.sum(text_lengths) + total_mel_len += torch.sum(mel_lengths) + total_mel_sum += torch.sum(mels) + total_mel_sq_sum += torch.sum(torch.pow(mels, 2)) + + data_mean = total_mel_sum / (total_mel_len * out_channels) + data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2)) + average_num_states = total_state_len / len(data_loader.dataset) + average_mel_len = total_mel_len / len(data_loader.dataset) + average_duration_each_state = average_mel_len / average_num_states + init_transition_prob = 1 / average_duration_each_state + + return data_mean, data_std, (init_transition_prob * states_per_phone) + + @staticmethod + @torch.no_grad() + def update_flat_start_transition(model, transition_p): + model.neural_hmm.output_net.parametermodel.flat_start_output_layer(0.0, 1.0, transition_p) + + @staticmethod + def log_clamped(x, eps=1e-04): + """ + Avoids the log(0) problem + + Args: + x (torch.tensor): input tensor + eps (float, optional): lower bound. Defaults to 1e-04. + + Returns: + torch.tensor: :math:`log(x)` + """ + clamped_x = torch.clamp(x, min=eps) + return torch.log(clamped_x) + + @staticmethod + def inverse_sigmod(x): + r""" + Inverse of the sigmoid function + """ + if not torch.is_tensor(x): + x = torch.tensor(x) + return OverflowUtils.log_clamped(x / (1.0 - x)) + + @staticmethod + def inverse_softplus(x): + r""" + Inverse of the softplus function + """ + if not torch.is_tensor(x): + x = torch.tensor(x) + return OverflowUtils.log_clamped(torch.exp(x) - 1.0) + + @staticmethod + def logsumexp(x, dim): + r""" + Differentiable LogSumExp: Does not creates nan gradients + when all the inputs are -inf yeilds 0 gradients. 
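A quick standalone check of the flat-start bias trick used by ParameterModel above (helper names here are illustrative, plain torch only): since the output layer's weights start at zero, setting the bias to inverse_softplus(std) and inverse_sigmoid(p) makes softplus/sigmoid of the raw outputs reproduce the desired standard deviation and transition probability at the first step.

```python
import torch
import torch.nn.functional as F


def inverse_softplus(y: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    # inverse of softplus(x) = log(1 + exp(x))
    return torch.log(torch.clamp(torch.exp(y) - 1.0, min=eps))


def inverse_sigmoid(p: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    # inverse of sigmoid(x), i.e. the logit function
    return torch.log(torch.clamp(p / (1.0 - p), min=eps))


std = torch.tensor([0.5, 1.0, 2.0])
transition_p = torch.tensor([0.14, 0.5])
assert torch.allclose(F.softplus(inverse_softplus(std)), std, atol=1e-3)
assert torch.allclose(torch.sigmoid(inverse_sigmoid(transition_p)), transition_p, atol=1e-3)
```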
+ Args: + x : torch.Tensor - The input tensor + dim: int - The dimension on which the log sum exp has to be applied + """ + + m, _ = x.max(dim=dim) + mask = m == -float("inf") + s = (x - m.masked_fill_(mask, 0).unsqueeze(dim=dim)).exp().sum(dim=dim) + return s.masked_fill_(mask, 1).log() + m.masked_fill_(mask, -float("inf")) + + @staticmethod + def double_pad(list_of_different_shape_tensors): + r""" + Pads the list of tensors in 2 dimensions + """ + second_dim_lens = [len(a) for a in [i[0] for i in list_of_different_shape_tensors]] + second_dim_max = max(second_dim_lens) + padded_x = [F.pad(x, (0, second_dim_max - len(x[0]))) for x in list_of_different_shape_tensors] + return nn.utils.rnn.pad_sequence(padded_x, batch_first=True) diff --git a/viXTTS/TTS/tts/layers/overflow/decoder.py b/viXTTS/TTS/tts/layers/overflow/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4cd7ae88068cfaffe179f2e61354cc7eb760268c --- /dev/null +++ b/viXTTS/TTS/tts/layers/overflow/decoder.py @@ -0,0 +1,81 @@ +import torch +from torch import nn + +from TTS.tts.layers.glow_tts.decoder import Decoder as GlowDecoder +from TTS.tts.utils.helpers import sequence_mask + + +class Decoder(nn.Module): + """Uses glow decoder with some modifications. + :: + + Squeeze -> ActNorm -> InvertibleConv1x1 -> AffineCoupling -> Unsqueeze + + Args: + in_channels (int): channels of input tensor. + hidden_channels (int): hidden decoder channels. + kernel_size (int): Coupling block kernel size. (Wavenet filter kernel size.) + dilation_rate (int): rate to increase dilation by each layer in a decoder block. + num_flow_blocks (int): number of decoder blocks. + num_coupling_layers (int): number coupling layers. (number of wavenet layers.) + dropout_p (float): wavenet dropout rate. + sigmoid_scale (bool): enable/disable sigmoid scaling in coupling layer. 
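The preprocess step in the decoder just below exists because the glow decoder's squeeze stage stacks num_squeeze neighbouring frames into the channel axis, so the time dimension and the recorded lengths must first be floored to a multiple of that factor. A plain-tensor sketch of the idea follows; shapes are illustrative and the library's exact channel ordering during squeezing may differ.

```python
import torch

num_squeeze = 2
y = torch.randn(1, 80, 37)                                # [B, C, T]
y_lengths = torch.tensor([37])

T_max = (y.shape[2] // num_squeeze) * num_squeeze         # 37 -> 36
y = y[:, :, :T_max]
y_lengths = torch.div(y_lengths, num_squeeze, rounding_mode="floor") * num_squeeze

# squeeze: [B, C, T] -> [B, C * num_squeeze, T // num_squeeze]
B, C, T = y.shape
y_sqz = y.view(B, C, T // num_squeeze, num_squeeze).permute(0, 1, 3, 2).reshape(
    B, C * num_squeeze, T // num_squeeze
)
print(y_sqz.shape, y_lengths)                             # torch.Size([1, 160, 18]) tensor([36])
```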
+ """ + + def __init__( + self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_flow_blocks, + num_coupling_layers, + dropout_p=0.0, + num_splits=4, + num_squeeze=2, + sigmoid_scale=False, + c_in_channels=0, + ): + super().__init__() + + self.glow_decoder = GlowDecoder( + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_flow_blocks, + num_coupling_layers, + dropout_p, + num_splits, + num_squeeze, + sigmoid_scale, + c_in_channels, + ) + self.n_sqz = num_squeeze + + def forward(self, x, x_len, g=None, reverse=False): + """ + Input shapes: + - x: :math:`[B, C, T]` + - x_len :math:`[B]` + - g: :math:`[B, C]` + + Output shapes: + - x: :math:`[B, C, T]` + - x_len :math:`[B]` + - logget_tot :math:`[B]` + """ + x, x_len, x_max_len = self.preprocess(x, x_len, x_len.max()) + x_mask = torch.unsqueeze(sequence_mask(x_len, x_max_len), 1).to(x.dtype) + x, logdet_tot = self.glow_decoder(x, x_mask, g, reverse) + return x, x_len, logdet_tot + + def preprocess(self, y, y_lengths, y_max_length): + if y_max_length is not None: + y_max_length = torch.div(y_max_length, self.n_sqz, rounding_mode="floor") * self.n_sqz + y = y[:, :, :y_max_length] + y_lengths = torch.div(y_lengths, self.n_sqz, rounding_mode="floor") * self.n_sqz + return y, y_lengths, y_max_length + + def store_inverse(self): + self.glow_decoder.store_inverse() diff --git a/viXTTS/TTS/tts/layers/overflow/neural_hmm.py b/viXTTS/TTS/tts/layers/overflow/neural_hmm.py new file mode 100644 index 0000000000000000000000000000000000000000..0631ba98c00029e9871c965e4c7f465aa32bc406 --- /dev/null +++ b/viXTTS/TTS/tts/layers/overflow/neural_hmm.py @@ -0,0 +1,553 @@ +from typing import List + +import torch +import torch.distributions as tdist +import torch.nn.functional as F +from torch import nn +from torch.utils.checkpoint import checkpoint + +from TTS.tts.layers.overflow.common_layers import Outputnet, OverflowUtils +from TTS.tts.layers.tacotron.common_layers import Prenet +from TTS.tts.utils.helpers import sequence_mask + + +class NeuralHMM(nn.Module): + """Autoregressive left to right HMM model primarily used in "Neural HMMs are all you need (for high-quality attention-free TTS)" + + Paper:: + https://arxiv.org/abs/2108.13320 + + Paper abstract:: + Neural sequence-to-sequence TTS has achieved significantly better output quality than statistical speech synthesis using + HMMs. However, neural TTS is generally not probabilistic and uses non-monotonic attention. Attention failures increase + training time and can make synthesis babble incoherently. This paper describes how the old and new paradigms can be + combined to obtain the advantages of both worlds, by replacing attention in neural TTS with an autoregressive left-right + no-skip hidden Markov model defined by a neural network. Based on this proposal, we modify Tacotron 2 to obtain an + HMM-based neural TTS model with monotonic alignment, trained to maximise the full sequence likelihood without + approximation. We also describe how to combine ideas from classical and contemporary TTS for best results. The resulting + example system is smaller and simpler than Tacotron 2, and learns to speak with fewer iterations and less data, whilst + achieving comparable naturalness prior to the post-net. Our approach also allows easy control over speaking rate. + + Args: + frame_channels (int): Output dimension to generate. + ar_order (int): Autoregressive order of the model. 
In ablations of Neural HMM it was found that more autoregression while giving more variation hurts naturalness of the synthesised audio. + deterministic_transition (bool): deterministic duration generation based on duration quantiles as defiend in "S. Ronanki, O. Watts, S. King, and G. E. Henter, “Medianbased generation of synthetic speech durations using a nonparametric approach,” in Proc. SLT, 2016.". Defaults to True. + encoder_dim (int): Channels of encoder input and character embedding tensors. Defaults to 512. + prenet_type (str): `original` or `bn`. `original` sets the default Prenet and `bn` uses Batch Normalization version of the Prenet. + prenet_dim (int): Dimension of the Prenet. + prenet_n_layers (int): Number of layers in the Prenet. + prenet_dropout (float): Dropout probability of the Prenet. + prenet_dropout_at_inference (bool): If True, dropout is applied at inference time. + memory_rnn_dim (int): Size of the memory RNN to process output of prenet. + outputnet_size (List[int]): Size of the output network inside the neural HMM. + flat_start_params (dict): Parameters for the flat start initialization of the neural HMM. + std_floor (float): Floor value for the standard deviation of the neural HMM. Prevents model cheating by putting point mass and getting infinite likelihood at any datapoint. + use_grad_checkpointing (bool, optional): Use gradient checkpointing to save memory. Defaults to True. + """ + + def __init__( + self, + frame_channels: int, + ar_order: int, + deterministic_transition: bool, + encoder_dim: int, + prenet_type: str, + prenet_dim: int, + prenet_n_layers: int, + prenet_dropout: float, + prenet_dropout_at_inference: bool, + memory_rnn_dim: int, + outputnet_size: List[int], + flat_start_params: dict, + std_floor: float, + use_grad_checkpointing: bool = True, + ): + super().__init__() + + self.frame_channels = frame_channels + self.ar_order = ar_order + self.deterministic_transition = deterministic_transition + self.prenet_dim = prenet_dim + self.memory_rnn_dim = memory_rnn_dim + self.use_grad_checkpointing = use_grad_checkpointing + + self.transition_model = TransitionModel() + self.emission_model = EmissionModel() + + assert ar_order > 0, f"AR order must be greater than 0 provided {ar_order}" + + self.ar_order = ar_order + self.prenet = Prenet( + in_features=frame_channels * ar_order, + prenet_type=prenet_type, + prenet_dropout=prenet_dropout, + dropout_at_inference=prenet_dropout_at_inference, + out_features=[self.prenet_dim for _ in range(prenet_n_layers)], + bias=False, + ) + self.memory_rnn = nn.LSTMCell(input_size=prenet_dim, hidden_size=memory_rnn_dim) + self.output_net = Outputnet( + encoder_dim, memory_rnn_dim, frame_channels, outputnet_size, flat_start_params, std_floor + ) + self.register_buffer("go_tokens", torch.zeros(ar_order, 1)) + + def forward(self, inputs, inputs_len, mels, mel_lens): + r"""HMM forward algorithm for training uses logarithmic version of Rabiner (1989) forward algorithm. 
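A small sketch of the autoregressive conditioning described above, with illustrative sizes: a zero go token is prepended to the mel frames, and at decoder step t the previous ar_order frames are flattened into the prenet input that drives the memory LSTM cell.

```python
import torch

B, T, D, ar_order = 2, 6, 80, 1
mels = torch.randn(B, T, D)                               # [B, T, n_mel]
go_tokens = torch.zeros(B, ar_order, D)
ar_inputs = torch.cat((go_tokens, mels), dim=1)[:, :T]    # shift by one, keep length T

t = 3
prenet_input = ar_inputs[:, t : t + ar_order].flatten(1)  # [B, ar_order * D]
print(prenet_input.shape)                                 # torch.Size([2, 80])
```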
+ + Args: + inputs (torch.FloatTensor): Encoder outputs + inputs_len (torch.LongTensor): Encoder output lengths + mels (torch.FloatTensor): Mel inputs + mel_lens (torch.LongTensor): Length of mel inputs + + Shapes: + - inputs: (B, T, D_out_enc) + - inputs_len: (B) + - mels: (B, D_mel, T_mel) + - mel_lens: (B) + + Returns: + log_prob (torch.FloatTensor): Log probability of the sequence + """ + # Get dimensions of inputs + batch_size, N, _ = inputs.shape + T_max = torch.max(mel_lens) + mels = mels.permute(0, 2, 1) + + # Intialize forward algorithm + log_state_priors = self._initialize_log_state_priors(inputs) + log_c, log_alpha_scaled, transition_matrix, means = self._initialize_forward_algorithm_variables(mels, N) + + # Initialize autoregression elements + ar_inputs = self._add_go_token(mels) + h_memory, c_memory = self._init_lstm_states(batch_size, self.memory_rnn_dim, mels) + + for t in range(T_max): + # Process Autoregression + h_memory, c_memory = self._process_ar_timestep(t, ar_inputs, h_memory, c_memory) + # Get mean, std and transition vector from decoder for this timestep + # Note: Gradient checkpointing currently doesn't works with multiple gpus inside a loop + if self.use_grad_checkpointing and self.training: + mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs) + else: + mean, std, transition_vector = self.output_net(h_memory, inputs) + + if t == 0: + log_alpha_temp = log_state_priors + self.emission_model(mels[:, 0], mean, std, inputs_len) + else: + log_alpha_temp = self.emission_model(mels[:, t], mean, std, inputs_len) + self.transition_model( + log_alpha_scaled[:, t - 1, :], transition_vector, inputs_len + ) + log_c[:, t] = torch.logsumexp(log_alpha_temp, dim=1) + log_alpha_scaled[:, t, :] = log_alpha_temp - log_c[:, t].unsqueeze(1) + transition_matrix[:, t] = transition_vector # needed for absorption state calculation + + # Save for plotting + means.append(mean.detach()) + + log_c, log_alpha_scaled = self._mask_lengths(mel_lens, log_c, log_alpha_scaled) + + sum_final_log_c = self.get_absorption_state_scaling_factor( + mel_lens, log_alpha_scaled, inputs_len, transition_matrix + ) + + log_probs = torch.sum(log_c, dim=1) + sum_final_log_c + + return log_probs, log_alpha_scaled, transition_matrix, means + + @staticmethod + def _mask_lengths(mel_lens, log_c, log_alpha_scaled): + """ + Mask the lengths of the forward variables so that the variable lenghts + do not contribute in the loss calculation + Args: + mel_inputs (torch.FloatTensor): (batch, T, frame_channels) + mel_inputs_lengths (torch.IntTensor): (batch) + log_c (torch.FloatTensor): (batch, T) + Returns: + log_c (torch.FloatTensor) : scaled probabilities (batch, T) + log_alpha_scaled (torch.FloatTensor): forward probabilities (batch, T, N) + """ + mask_log_c = sequence_mask(mel_lens) + log_c = log_c * mask_log_c + mask_log_alpha_scaled = mask_log_c.unsqueeze(2) + log_alpha_scaled = log_alpha_scaled * mask_log_alpha_scaled + return log_c, log_alpha_scaled + + def _process_ar_timestep( + self, + t, + ar_inputs, + h_memory, + c_memory, + ): + """ + Process autoregression in timestep + 1. At a specific t timestep + 2. Perform data dropout if applied (we did not use it) + 3. Run the autoregressive frame through the prenet (has dropout) + 4. 
Run the prenet output through the post prenet rnn + + Args: + t (int): mel-spec timestep + ar_inputs (torch.FloatTensor): go-token appended mel-spectrograms + - shape: (b, D_out, T_out) + h_post_prenet (torch.FloatTensor): previous timestep rnn hidden state + - shape: (b, memory_rnn_dim) + c_post_prenet (torch.FloatTensor): previous timestep rnn cell state + - shape: (b, memory_rnn_dim) + + Returns: + h_post_prenet (torch.FloatTensor): rnn hidden state of the current timestep + c_post_prenet (torch.FloatTensor): rnn cell state of the current timestep + """ + prenet_input = ar_inputs[:, t : t + self.ar_order].flatten(1) + memory_inputs = self.prenet(prenet_input) + h_memory, c_memory = self.memory_rnn(memory_inputs, (h_memory, c_memory)) + return h_memory, c_memory + + def _add_go_token(self, mel_inputs): + """Append the go token to create the autoregressive input + Args: + mel_inputs (torch.FloatTensor): (batch_size, T, n_mel_channel) + Returns: + ar_inputs (torch.FloatTensor): (batch_size, T, n_mel_channel) + """ + batch_size, T, _ = mel_inputs.shape + go_tokens = self.go_tokens.unsqueeze(0).expand(batch_size, self.ar_order, self.frame_channels) + ar_inputs = torch.cat((go_tokens, mel_inputs), dim=1)[:, :T] + return ar_inputs + + @staticmethod + def _initialize_forward_algorithm_variables(mel_inputs, N): + r"""Initialize placeholders for forward algorithm variables, to use a stable + version we will use log_alpha_scaled and the scaling constant + + Args: + mel_inputs (torch.FloatTensor): (b, T_max, frame_channels) + N (int): number of states + Returns: + log_c (torch.FloatTensor): Scaling constant (b, T_max) + """ + b, T_max, _ = mel_inputs.shape + log_alpha_scaled = mel_inputs.new_zeros((b, T_max, N)) + log_c = mel_inputs.new_zeros(b, T_max) + transition_matrix = mel_inputs.new_zeros((b, T_max, N)) + + # Saving for plotting later, will not have gradient tapes + means = [] + return log_c, log_alpha_scaled, transition_matrix, means + + @staticmethod + def _init_lstm_states(batch_size, hidden_state_dim, device_tensor): + r""" + Initialize Hidden and Cell states for LSTM Cell + + Args: + batch_size (Int): batch size + hidden_state_dim (Int): dimensions of the h and c + device_tensor (torch.FloatTensor): useful for the device and type + + Returns: + (torch.FloatTensor): shape (batch_size, hidden_state_dim) + can be hidden state for LSTM + (torch.FloatTensor): shape (batch_size, hidden_state_dim) + can be the cell state for LSTM + """ + return ( + device_tensor.new_zeros(batch_size, hidden_state_dim), + device_tensor.new_zeros(batch_size, hidden_state_dim), + ) + + def get_absorption_state_scaling_factor(self, mels_len, log_alpha_scaled, inputs_len, transition_vector): + """Returns the final scaling factor of absorption state + + Args: + mels_len (torch.IntTensor): Input size of mels to + get the last timestep of log_alpha_scaled + log_alpha_scaled (torch.FloatTEnsor): State probabilities + text_lengths (torch.IntTensor): length of the states to + mask the values of states lengths + ( + Useful when the batch has very different lengths, + when the length of an observation is less than + the number of max states, then the log alpha after + the state value is filled with -infs. 
So we mask + those values so that it only consider the states + which are needed for that length + ) + transition_vector (torch.FloatTensor): transtiion vector for each state per timestep + + Shapes: + - mels_len: (batch_size) + - log_alpha_scaled: (batch_size, N, T) + - text_lengths: (batch_size) + - transition_vector: (batch_size, N, T) + + Returns: + sum_final_log_c (torch.FloatTensor): (batch_size) + + """ + N = torch.max(inputs_len) + max_inputs_len = log_alpha_scaled.shape[2] + state_lengths_mask = sequence_mask(inputs_len, max_len=max_inputs_len) + + last_log_alpha_scaled_index = ( + (mels_len - 1).unsqueeze(-1).expand(-1, N).unsqueeze(1) + ) # Batch X Hidden State Size + last_log_alpha_scaled = torch.gather(log_alpha_scaled, 1, last_log_alpha_scaled_index).squeeze(1) + last_log_alpha_scaled = last_log_alpha_scaled.masked_fill(~state_lengths_mask, -float("inf")) + + last_transition_vector = torch.gather(transition_vector, 1, last_log_alpha_scaled_index).squeeze(1) + last_transition_probability = torch.sigmoid(last_transition_vector) + log_probability_of_transitioning = OverflowUtils.log_clamped(last_transition_probability) + + last_transition_probability_index = self.get_mask_for_last_item(inputs_len, inputs_len.device) + log_probability_of_transitioning = log_probability_of_transitioning.masked_fill( + ~last_transition_probability_index, -float("inf") + ) + final_log_c = last_log_alpha_scaled + log_probability_of_transitioning + + # If the length of the mel is less than the number of states it will select the -inf values leading to nan gradients + # Ideally, we should clean the dataset otherwise this is a little hack uncomment the line below + final_log_c = final_log_c.clamp(min=torch.finfo(final_log_c.dtype).min) + + sum_final_log_c = torch.logsumexp(final_log_c, dim=1) + return sum_final_log_c + + @staticmethod + def get_mask_for_last_item(lengths, device, out_tensor=None): + """Returns n-1 mask for the last item in the sequence. + + Args: + lengths (torch.IntTensor): lengths in a batch + device (str, optional): Defaults to "cpu". + out_tensor (torch.Tensor, optional): uses the memory of a specific tensor. + Defaults to None. 
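To make the recursion above concrete, here is a toy, self-contained version of the scaled log-domain forward algorithm for a left-to-right, no-skip HMM. The per-frame scaling constants log_c keep the alphas normalised so nothing underflows, and their sum gives the sequence log-likelihood up to the absorption-state term handled above. The emission log-likelihoods and transition probabilities are random stand-ins for the output net.

```python
import torch

torch.manual_seed(0)
N, T = 4, 7                                        # states, frames
log_emission = torch.randn(T, N)                   # stand-in for EmissionModel log-probs
transition_p = torch.full((T, N), 0.3)             # prob. of moving to the next state
log_stay, log_move = torch.log(1 - transition_p), torch.log(transition_p)

log_alpha = torch.full((T, N), -float("inf"))
log_c = torch.zeros(T)

log_alpha[0, 0] = log_emission[0, 0]               # only state 0 can start
log_c[0] = torch.logsumexp(log_alpha[0], dim=0)
log_alpha[0] -= log_c[0]

for t in range(1, T):
    staying = log_alpha[t - 1] + log_stay[t]
    leaving = (log_alpha[t - 1] + log_move[t]).roll(1)
    leaving[0] = -float("inf")                     # nothing transitions into state 0
    log_alpha[t] = torch.logsumexp(torch.stack([staying, leaving]), dim=0) + log_emission[t]
    log_c[t] = torch.logsumexp(log_alpha[t], dim=0)
    log_alpha[t] -= log_c[t]                       # renormalise: alphas stay scaled

print("sequence log-likelihood (pre-absorption):", log_c.sum().item())
```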
+ + Returns: + - Shape: :math:`(b, max_len)` + """ + max_len = torch.max(lengths).item() + ids = ( + torch.arange(0, max_len, device=device) if out_tensor is None else torch.arange(0, max_len, out=out_tensor) + ) + mask = ids == lengths.unsqueeze(1) - 1 + return mask + + @torch.inference_mode() + def inference( + self, + inputs: torch.FloatTensor, + input_lens: torch.LongTensor, + sampling_temp: float, + max_sampling_time: int, + duration_threshold: float, + ): + """Inference from autoregressive neural HMM + + Args: + inputs (torch.FloatTensor): input states + - shape: :math:`(b, T, d)` + input_lens (torch.LongTensor): input state lengths + - shape: :math:`(b)` + sampling_temp (float): sampling temperature + max_sampling_temp (int): max sampling temperature + duration_threshold (float): duration threshold to switch to next state + - Use this to change the spearking rate of the synthesised audio + """ + + b = inputs.shape[0] + outputs = { + "hmm_outputs": [], + "hmm_outputs_len": [], + "alignments": [], + "input_parameters": [], + "output_parameters": [], + } + for i in range(b): + neural_hmm_outputs, states_travelled, input_parameters, output_parameters = self.sample( + inputs[i : i + 1], input_lens[i], sampling_temp, max_sampling_time, duration_threshold + ) + + outputs["hmm_outputs"].append(neural_hmm_outputs) + outputs["hmm_outputs_len"].append(neural_hmm_outputs.shape[0]) + outputs["alignments"].append(states_travelled) + outputs["input_parameters"].append(input_parameters) + outputs["output_parameters"].append(output_parameters) + + outputs["hmm_outputs"] = nn.utils.rnn.pad_sequence(outputs["hmm_outputs"], batch_first=True) + outputs["hmm_outputs_len"] = torch.tensor( + outputs["hmm_outputs_len"], dtype=input_lens.dtype, device=input_lens.device + ) + return outputs + + @torch.inference_mode() + def sample(self, inputs, input_lens, sampling_temp, max_sampling_time, duration_threshold): + """Samples an output from the parameter models + + Args: + inputs (torch.FloatTensor): input states + - shape: :math:`(1, T, d)` + input_lens (torch.LongTensor): input state lengths + - shape: :math:`(1)` + sampling_temp (float): sampling temperature + max_sampling_time (int): max sampling time + duration_threshold (float): duration threshold to switch to next state + + Returns: + outputs (torch.FloatTensor): Output Observations + - Shape: :math:`(T, output_dim)` + states_travelled (list[int]): Hidden states travelled + - Shape: :math:`(T)` + input_parameters (list[torch.FloatTensor]): Input parameters + output_parameters (list[torch.FloatTensor]): Output parameters + """ + states_travelled, outputs, t = [], [], 0 + + # Sample initial state + current_state = 0 + states_travelled.append(current_state) + + # Prepare autoregression + prenet_input = self.go_tokens.unsqueeze(0).expand(1, self.ar_order, self.frame_channels) + h_memory, c_memory = self._init_lstm_states(1, self.memory_rnn_dim, prenet_input) + + input_parameter_values = [] + output_parameter_values = [] + quantile = 1 + while True: + memory_input = self.prenet(prenet_input.flatten(1).unsqueeze(0)) + # will be 1 while sampling + h_memory, c_memory = self.memory_rnn(memory_input.squeeze(0), (h_memory, c_memory)) + + z_t = inputs[:, current_state].unsqueeze(0) # Add fake time dimension + mean, std, transition_vector = self.output_net(h_memory, z_t) + + transition_probability = torch.sigmoid(transition_vector.flatten()) + staying_probability = torch.sigmoid(-transition_vector.flatten()) + + # Save for plotting + 
input_parameter_values.append([prenet_input, current_state]) + output_parameter_values.append([mean, std, transition_probability]) + + x_t = self.emission_model.sample(mean, std, sampling_temp=sampling_temp) + + # Prepare autoregressive input for next iteration + prenet_input = torch.cat((prenet_input, x_t), dim=1)[:, 1:] + + outputs.append(x_t.flatten()) + + transition_matrix = torch.cat((staying_probability, transition_probability)) + quantile *= staying_probability + if not self.deterministic_transition: + switch = transition_matrix.multinomial(1)[0].item() + else: + switch = quantile < duration_threshold + + if switch: + current_state += 1 + quantile = 1 + + states_travelled.append(current_state) + + if (current_state == input_lens) or (max_sampling_time and t == max_sampling_time - 1): + break + + t += 1 + + return ( + torch.stack(outputs, dim=0), + F.one_hot(input_lens.new_tensor(states_travelled)), + input_parameter_values, + output_parameter_values, + ) + + @staticmethod + def _initialize_log_state_priors(text_embeddings): + """Creates the log pi in forward algorithm. + + Args: + text_embeddings (torch.FloatTensor): used to create the log pi + on current device + + Shapes: + - text_embeddings: (B, T, D_out_enc) + """ + N = text_embeddings.shape[1] + log_state_priors = text_embeddings.new_full([N], -float("inf")) + log_state_priors[0] = 0.0 + return log_state_priors + + +class TransitionModel(nn.Module): + """Transition Model of the HMM, it represents the probability of transitioning + form current state to all other states""" + + def forward(self, log_alpha_scaled, transition_vector, inputs_len): # pylint: disable=no-self-use + r""" + product of the past state with transitional probabilities in log space + + Args: + log_alpha_scaled (torch.Tensor): Multiply previous timestep's alphas by + transition matrix (in log domain) + - shape: (batch size, N) + transition_vector (torch.tensor): transition vector for each state + - shape: (N) + inputs_len (int tensor): Lengths of states in a batch + - shape: (batch) + + Returns: + out (torch.FloatTensor): log probability of transitioning to each state + """ + transition_p = torch.sigmoid(transition_vector) + staying_p = torch.sigmoid(-transition_vector) + + log_staying_probability = OverflowUtils.log_clamped(staying_p) + log_transition_probability = OverflowUtils.log_clamped(transition_p) + + staying = log_alpha_scaled + log_staying_probability + leaving = log_alpha_scaled + log_transition_probability + leaving = leaving.roll(1, dims=1) + leaving[:, 0] = -float("inf") + inputs_len_mask = sequence_mask(inputs_len) + out = OverflowUtils.logsumexp(torch.stack((staying, leaving), dim=2), dim=2) + out = out.masked_fill(~inputs_len_mask, -float("inf")) # There are no states to contribute to the loss + return out + + +class EmissionModel(nn.Module): + """Emission Model of the HMM, it represents the probability of + emitting an observation based on the current state""" + + def __init__(self) -> None: + super().__init__() + self.distribution_function: tdist.Distribution = tdist.normal.Normal + + def sample(self, means, stds, sampling_temp): + return self.distribution_function(means, stds * sampling_temp).sample() if sampling_temp > 0 else means + + def forward(self, x_t, means, stds, state_lengths): + r"""Calculates the log probability of the the given data (x_t) + being observed from states with given means and stds + Args: + x_t (float tensor) : observation at current time step + - shape: (batch, feature_dim) + means (float tensor): means of the 
distributions of hidden states + - shape: (batch, hidden_state, feature_dim) + stds (float tensor): standard deviations of the distributions of the hidden states + - shape: (batch, hidden_state, feature_dim) + state_lengths (int tensor): Lengths of states in a batch + - shape: (batch) + + Returns: + out (float tensor): observation log likelihoods, + expressing the probability of an observation + being generated from a state i + shape: (batch, hidden_state) + """ + emission_dists = self.distribution_function(means, stds) + out = emission_dists.log_prob(x_t.unsqueeze(1)) + state_lengths_mask = sequence_mask(state_lengths).unsqueeze(2) + out = torch.sum(out * state_lengths_mask, dim=2) + return out diff --git a/viXTTS/TTS/tts/layers/overflow/plotting_utils.py b/viXTTS/TTS/tts/layers/overflow/plotting_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a63aeb370a38a29660dc93267f4be138381c7df6 --- /dev/null +++ b/viXTTS/TTS/tts/layers/overflow/plotting_utils.py @@ -0,0 +1,79 @@ +from typing import Any + +import matplotlib.pyplot as plt +import numpy as np +import torch + + +def validate_numpy_array(value: Any): + r""" + Validates the input and makes sure it returns a numpy array (i.e on CPU) + + Args: + value (Any): the input value + + Raises: + TypeError: if the value is not a numpy array or torch tensor + + Returns: + np.ndarray: numpy array of the value + """ + if isinstance(value, np.ndarray): + pass + elif isinstance(value, list): + value = np.array(value) + elif torch.is_tensor(value): + value = value.cpu().numpy() + else: + raise TypeError("Value must be a numpy array, a torch tensor or a list") + + return value + + +def get_spec_from_most_probable_state(log_alpha_scaled, means, decoder=None): + """Get the most probable state means from the log_alpha_scaled. + + Args: + log_alpha_scaled (torch.Tensor): Log alpha scaled values. + - Shape: :math:`(T, N)` + means (torch.Tensor): Means of the states. + - Shape: :math:`(N, T, D_out)` + decoder (torch.nn.Module): Decoder module to decode the latent to melspectrogram. Defaults to None. + """ + max_state_numbers = torch.max(log_alpha_scaled, dim=1)[1] + max_len = means.shape[0] + n_mel_channels = means.shape[2] + max_state_numbers = max_state_numbers.unsqueeze(1).unsqueeze(1).expand(max_len, 1, n_mel_channels) + means = torch.gather(means, 1, max_state_numbers).squeeze(1).to(log_alpha_scaled.dtype) + if decoder is not None: + mel = ( + decoder(means.T.unsqueeze(0), torch.tensor([means.shape[0]], device=means.device), reverse=True)[0] + .squeeze(0) + .T + ) + else: + mel = means + return mel + + +def plot_transition_probabilities_to_numpy(states, transition_probabilities, output_fig=False): + """Generates trainsition probabilities plot for the states and the probability of transition. 
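A compact standalone version of the emission likelihood above, with illustrative sizes: one diagonal Gaussian per hidden state, log-probabilities summed over the mel dimension, and states beyond each text length zeroed out by the mask, as in the module.

```python
import torch
import torch.distributions as tdist

B, N, D = 2, 5, 80                               # batch, hidden states, mel channels
x_t = torch.randn(B, D)                          # observation at the current frame
means = torch.randn(B, N, D)
stds = torch.rand(B, N, D) + 0.5
state_lengths = torch.tensor([5, 3])

dists = tdist.Normal(means, stds)
log_probs = dists.log_prob(x_t.unsqueeze(1))     # [B, N, D] via broadcasting
mask = (torch.arange(N)[None, :] < state_lengths[:, None]).unsqueeze(-1)
out = (log_probs * mask).sum(dim=2)              # [B, N] per-state log-likelihoods
print(out.shape)
```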
+ + Args: + states (torch.IntTensor): the states + transition_probabilities (torch.FloatTensor): the transition probabilities + """ + states = validate_numpy_array(states) + transition_probabilities = validate_numpy_array(transition_probabilities) + + fig, ax = plt.subplots(figsize=(30, 3)) + ax.plot(transition_probabilities, "o") + ax.set_title("Transition probability of state") + ax.set_xlabel("hidden state") + ax.set_ylabel("probability") + ax.set_xticks([i for i in range(len(transition_probabilities))]) # pylint: disable=unnecessary-comprehension + ax.set_xticklabels([int(x) for x in states], rotation=90) + plt.tight_layout() + if not output_fig: + plt.close() + return fig diff --git a/viXTTS/TTS/tts/layers/tacotron/__init__.py b/viXTTS/TTS/tts/layers/tacotron/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/tts/layers/tacotron/attentions.py b/viXTTS/TTS/tts/layers/tacotron/attentions.py new file mode 100644 index 0000000000000000000000000000000000000000..25c3798e6b8f5fbc66224af66c9955e245b94097 --- /dev/null +++ b/viXTTS/TTS/tts/layers/tacotron/attentions.py @@ -0,0 +1,486 @@ +import torch +from scipy.stats import betabinom +from torch import nn +from torch.nn import functional as F + +from TTS.tts.layers.tacotron.common_layers import Linear + + +class LocationLayer(nn.Module): + """Layers for Location Sensitive Attention + + Args: + attention_dim (int): number of channels in the input tensor. + attention_n_filters (int, optional): number of filters in convolution. Defaults to 32. + attention_kernel_size (int, optional): kernel size of convolution filter. Defaults to 31. + """ + + def __init__(self, attention_dim, attention_n_filters=32, attention_kernel_size=31): + super().__init__() + self.location_conv1d = nn.Conv1d( + in_channels=2, + out_channels=attention_n_filters, + kernel_size=attention_kernel_size, + stride=1, + padding=(attention_kernel_size - 1) // 2, + bias=False, + ) + self.location_dense = Linear(attention_n_filters, attention_dim, bias=False, init_gain="tanh") + + def forward(self, attention_cat): + """ + Shapes: + attention_cat: [B, 2, C] + """ + processed_attention = self.location_conv1d(attention_cat) + processed_attention = self.location_dense(processed_attention.transpose(1, 2)) + return processed_attention + + +class GravesAttention(nn.Module): + """Graves Attention as is ref1 with updates from ref2. + ref1: https://arxiv.org/abs/1910.10288 + ref2: https://arxiv.org/pdf/1906.01083.pdf + + Args: + query_dim (int): number of channels in query tensor. + K (int): number of Gaussian heads to be used for computing attention. 
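A standalone sketch of one GMM attention step with illustrative sizes (mu_t here is a stand-in for mu_prev + softplus(k_t), which is what keeps the alignment monotonic): each of the K components defines a sigmoid CDF over encoder positions, the mixture CDFs are summed, and attention weights come from differencing that CDF between neighbouring positions.

```python
import torch

B, T_in, K, enc_dim = 2, 50, 5, 512
eps = 1e-5
inputs = torch.randn(B, T_in, enc_dim)

g_t = torch.softmax(torch.randn(B, K), dim=-1) + eps                 # mixture weights
sig_t = torch.nn.functional.softplus(torch.randn(B, K)) + eps        # component widths
mu_t = torch.rand(B, K) * T_in                                       # stand-in for mu_prev + softplus(k_t)

j = torch.arange(0, T_in + 1.0) + 0.5                                # grid over encoder positions
phi_t = g_t.unsqueeze(-1) * (1 / (1 + torch.sigmoid((mu_t.unsqueeze(-1) - j) / sig_t.unsqueeze(-1))))

alpha_t = torch.sum(phi_t, dim=1)                                    # summed CDF, [B, T_in + 1]
alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1]                           # discretise to weights, [B, T_in]
context = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1)         # [B, enc_dim]
print(alpha_t.shape, context.shape)
```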
+ """ + + COEF = 0.3989422917366028 # numpy.sqrt(1/(2*numpy.pi)) + + def __init__(self, query_dim, K): + super().__init__() + self._mask_value = 1e-8 + self.K = K + # self.attention_alignment = 0.05 + self.eps = 1e-5 + self.J = None + self.N_a = nn.Sequential( + nn.Linear(query_dim, query_dim, bias=True), nn.ReLU(), nn.Linear(query_dim, 3 * K, bias=True) + ) + self.attention_weights = None + self.mu_prev = None + self.init_layers() + + def init_layers(self): + torch.nn.init.constant_(self.N_a[2].bias[(2 * self.K) : (3 * self.K)], 1.0) # bias mean + torch.nn.init.constant_(self.N_a[2].bias[self.K : (2 * self.K)], 10) # bias std + + def init_states(self, inputs): + if self.J is None or inputs.shape[1] + 1 > self.J.shape[-1]: + self.J = torch.arange(0, inputs.shape[1] + 2.0).to(inputs.device) + 0.5 + self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device) + self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device) + + # pylint: disable=R0201 + # pylint: disable=unused-argument + def preprocess_inputs(self, inputs): + return None + + def forward(self, query, inputs, processed_inputs, mask): + """ + Shapes: + query: [B, C_attention_rnn] + inputs: [B, T_in, C_encoder] + processed_inputs: place_holder + mask: [B, T_in] + """ + gbk_t = self.N_a(query) + gbk_t = gbk_t.view(gbk_t.size(0), -1, self.K) + + # attention model parameters + # each B x K + g_t = gbk_t[:, 0, :] + b_t = gbk_t[:, 1, :] + k_t = gbk_t[:, 2, :] + + # dropout to decorrelate attention heads + g_t = torch.nn.functional.dropout(g_t, p=0.5, training=self.training) + + # attention GMM parameters + sig_t = torch.nn.functional.softplus(b_t) + self.eps + + mu_t = self.mu_prev + torch.nn.functional.softplus(k_t) + g_t = torch.softmax(g_t, dim=-1) + self.eps + + j = self.J[: inputs.size(1) + 1] + + # attention weights + phi_t = g_t.unsqueeze(-1) * (1 / (1 + torch.sigmoid((mu_t.unsqueeze(-1) - j) / sig_t.unsqueeze(-1)))) + + # discritize attention weights + alpha_t = torch.sum(phi_t, 1) + alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1] + alpha_t[alpha_t == 0] = 1e-8 + + # apply masking + if mask is not None: + alpha_t.data.masked_fill_(~mask, self._mask_value) + + context = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1) + self.attention_weights = alpha_t + self.mu_prev = mu_t + return context + + +class OriginalAttention(nn.Module): + """Bahdanau Attention with various optional modifications. + - Location sensitive attnetion: https://arxiv.org/abs/1712.05884 + - Forward Attention: https://arxiv.org/abs/1807.06736 + state masking at inference + - Using sigmoid instead of softmax normalization + - Attention windowing at inference time + + Note: + Location Sensitive Attention extends the additive attention mechanism + to use cumulative attention weights from previous decoder time steps with the current time step features. + + Forward attention computes most probable monotonic alignment. The modified attention probabilities at each + timestep are computed recursively by the forward algorithm. + + Transition agent in the forward attention explicitly gates the attention mechanism whether to move forward or + stay at each decoder timestep. + + Attention windowing is a inductive prior that prevents the model from attending to previous and future timesteps + beyond a certain window. + + Args: + query_dim (int): number of channels in the query tensor. + embedding_dim (int): number of channels in the vakue tensor. In general, the value tensor is the output of the encoder layer. 
+ attention_dim (int): number of channels of the inner attention layers. + location_attention (bool): enable/disable location sensitive attention. + attention_location_n_filters (int): number of location attention filters. + attention_location_kernel_size (int): filter size of location attention convolution layer. + windowing (int): window size for attention windowing. if it is 5, for computing the attention, it only considers the time steps [(t-5), ..., (t+5)] of the input. + norm (str): normalization method applied to the attention weights. 'softmax' or 'sigmoid' + forward_attn (bool): enable/disable forward attention. + trans_agent (bool): enable/disable transition agent in the forward attention. + forward_attn_mask (int): enable/disable an explicit masking in forward attention. It is useful to set at especially inference time. + """ + + # Pylint gets confused by PyTorch conventions here + # pylint: disable=attribute-defined-outside-init + def __init__( + self, + query_dim, + embedding_dim, + attention_dim, + location_attention, + attention_location_n_filters, + attention_location_kernel_size, + windowing, + norm, + forward_attn, + trans_agent, + forward_attn_mask, + ): + super().__init__() + self.query_layer = Linear(query_dim, attention_dim, bias=False, init_gain="tanh") + self.inputs_layer = Linear(embedding_dim, attention_dim, bias=False, init_gain="tanh") + self.v = Linear(attention_dim, 1, bias=True) + if trans_agent: + self.ta = nn.Linear(query_dim + embedding_dim, 1, bias=True) + if location_attention: + self.location_layer = LocationLayer( + attention_dim, + attention_location_n_filters, + attention_location_kernel_size, + ) + self._mask_value = -float("inf") + self.windowing = windowing + self.win_idx = None + self.norm = norm + self.forward_attn = forward_attn + self.trans_agent = trans_agent + self.forward_attn_mask = forward_attn_mask + self.location_attention = location_attention + + def init_win_idx(self): + self.win_idx = -1 + self.win_back = 2 + self.win_front = 6 + + def init_forward_attn(self, inputs): + B = inputs.shape[0] + T = inputs.shape[1] + self.alpha = torch.cat([torch.ones([B, 1]), torch.zeros([B, T])[:, :-1] + 1e-7], dim=1).to(inputs.device) + self.u = (0.5 * torch.ones([B, 1])).to(inputs.device) + + def init_location_attention(self, inputs): + B = inputs.size(0) + T = inputs.size(1) + self.attention_weights_cum = torch.zeros([B, T], device=inputs.device) + + def init_states(self, inputs): + B = inputs.size(0) + T = inputs.size(1) + self.attention_weights = torch.zeros([B, T], device=inputs.device) + if self.location_attention: + self.init_location_attention(inputs) + if self.forward_attn: + self.init_forward_attn(inputs) + if self.windowing: + self.init_win_idx() + + def preprocess_inputs(self, inputs): + return self.inputs_layer(inputs) + + def update_location_attention(self, alignments): + self.attention_weights_cum += alignments + + def get_location_attention(self, query, processed_inputs): + attention_cat = torch.cat((self.attention_weights.unsqueeze(1), self.attention_weights_cum.unsqueeze(1)), dim=1) + processed_query = self.query_layer(query.unsqueeze(1)) + processed_attention_weights = self.location_layer(attention_cat) + energies = self.v(torch.tanh(processed_query + processed_attention_weights + processed_inputs)) + energies = energies.squeeze(-1) + return energies, processed_query + + def get_attention(self, query, processed_inputs): + processed_query = self.query_layer(query.unsqueeze(1)) + energies = self.v(torch.tanh(processed_query + 
processed_inputs)) + energies = energies.squeeze(-1) + return energies, processed_query + + def apply_windowing(self, attention, inputs): + back_win = self.win_idx - self.win_back + front_win = self.win_idx + self.win_front + if back_win > 0: + attention[:, :back_win] = -float("inf") + if front_win < inputs.shape[1]: + attention[:, front_win:] = -float("inf") + # this is a trick to solve a special problem. + # but it does not hurt. + if self.win_idx == -1: + attention[:, 0] = attention.max() + # Update the window + self.win_idx = torch.argmax(attention, 1).long()[0].item() + return attention + + def apply_forward_attention(self, alignment): + # forward attention + fwd_shifted_alpha = F.pad(self.alpha[:, :-1].clone().to(alignment.device), (1, 0, 0, 0)) + # compute transition potentials + alpha = ((1 - self.u) * self.alpha + self.u * fwd_shifted_alpha + 1e-8) * alignment + # force incremental alignment + if not self.training and self.forward_attn_mask: + _, n = fwd_shifted_alpha.max(1) + val, _ = alpha.max(1) + for b in range(alignment.shape[0]): + alpha[b, n[b] + 3 :] = 0 + alpha[b, : (n[b] - 1)] = 0 # ignore all previous states to prevent repetition. + alpha[b, (n[b] - 2)] = 0.01 * val[b] # smoothing factor for the prev step + # renormalize attention weights + alpha = alpha / alpha.sum(dim=1, keepdim=True) + return alpha + + def forward(self, query, inputs, processed_inputs, mask): + """ + shapes: + query: [B, C_attn_rnn] + inputs: [B, T_en, D_en] + processed_inputs: [B, T_en, D_attn] + mask: [B, T_en] + """ + if self.location_attention: + attention, _ = self.get_location_attention(query, processed_inputs) + else: + attention, _ = self.get_attention(query, processed_inputs) + # apply masking + if mask is not None: + attention.data.masked_fill_(~mask, self._mask_value) + # apply windowing - only in eval mode + if not self.training and self.windowing: + attention = self.apply_windowing(attention, inputs) + + # normalize attention values + if self.norm == "softmax": + alignment = torch.softmax(attention, dim=-1) + elif self.norm == "sigmoid": + alignment = torch.sigmoid(attention) / torch.sigmoid(attention).sum(dim=1, keepdim=True) + else: + raise ValueError("Unknown value for attention norm type") + + if self.location_attention: + self.update_location_attention(alignment) + + # apply forward attention if enabled + if self.forward_attn: + alignment = self.apply_forward_attention(alignment) + self.alpha = alignment + + context = torch.bmm(alignment.unsqueeze(1), inputs) + context = context.squeeze(1) + self.attention_weights = alignment + + # compute transition agent + if self.forward_attn and self.trans_agent: + ta_input = torch.cat([context, query.squeeze(1)], dim=-1) + self.u = torch.sigmoid(self.ta(ta_input)) + return context + + +class MonotonicDynamicConvolutionAttention(nn.Module): + """Dynamic convolution attention from + https://arxiv.org/pdf/1910.10288.pdf + + + query -> linear -> tanh -> linear ->| + | mask values + v | | + atten_w(t-1) -|-> conv1d_dynamic -> linear -|-> tanh -> + -> softmax -> * -> * -> context + |-> conv1d_static -> linear -| | + |-> conv1d_prior -> log ----------------| + + query: attention rnn output. + + Note: + Dynamic convolution attention is an alternation of the location senstive attention with + dynamically computed convolution filters from the previous attention scores and a set of + constraints to keep the attention alignment diagonal. + DCA is sensitive to mixed precision training and might cause instable training. 
+ + Args: + query_dim (int): number of channels in the query tensor. + embedding_dim (int): number of channels in the value tensor. + static_filter_dim (int): number of channels in the convolution layer computing the static filters. + static_kernel_size (int): kernel size for the convolution layer computing the static filters. + dynamic_filter_dim (int): number of channels in the convolution layer computing the dynamic filters. + dynamic_kernel_size (int): kernel size for the convolution layer computing the dynamic filters. + prior_filter_len (int, optional): [description]. Defaults to 11 from the paper. + alpha (float, optional): [description]. Defaults to 0.1 from the paper. + beta (float, optional): [description]. Defaults to 0.9 from the paper. + """ + + def __init__( + self, + query_dim, + embedding_dim, # pylint: disable=unused-argument + attention_dim, + static_filter_dim, + static_kernel_size, + dynamic_filter_dim, + dynamic_kernel_size, + prior_filter_len=11, + alpha=0.1, + beta=0.9, + ): + super().__init__() + self._mask_value = 1e-8 + self.dynamic_filter_dim = dynamic_filter_dim + self.dynamic_kernel_size = dynamic_kernel_size + self.prior_filter_len = prior_filter_len + self.attention_weights = None + # setup key and query layers + self.query_layer = nn.Linear(query_dim, attention_dim) + self.key_layer = nn.Linear(attention_dim, dynamic_filter_dim * dynamic_kernel_size, bias=False) + self.static_filter_conv = nn.Conv1d( + 1, + static_filter_dim, + static_kernel_size, + padding=(static_kernel_size - 1) // 2, + bias=False, + ) + self.static_filter_layer = nn.Linear(static_filter_dim, attention_dim, bias=False) + self.dynamic_filter_layer = nn.Linear(dynamic_filter_dim, attention_dim) + self.v = nn.Linear(attention_dim, 1, bias=False) + + prior = betabinom.pmf(range(prior_filter_len), prior_filter_len - 1, alpha, beta) + self.register_buffer("prior", torch.FloatTensor(prior).flip(0)) + + # pylint: disable=unused-argument + def forward(self, query, inputs, processed_inputs, mask): + """ + query: [B, C_attn_rnn] + inputs: [B, T_en, D_en] + processed_inputs: place holder. 
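A sketch of the diagonal prior used by the dynamic-convolution attention above, using the paper defaults visible in the constructor (prior_filter_len=11, alpha=0.1, beta=0.9): the beta-binomial filter is convolved with the previous attention weights and added, in the log domain, to the energies, nudging the alignment forward by roughly one token per decoder step. Shapes are illustrative.

```python
import torch
import torch.nn.functional as F
from scipy.stats import betabinom

prior_filter_len, alpha, beta = 11, 0.1, 0.9
prior = torch.FloatTensor(
    betabinom.pmf(range(prior_filter_len), prior_filter_len - 1, alpha, beta)
).flip(0)

B, T_in = 2, 50
attention_weights = torch.zeros(B, T_in)
attention_weights[:, 0] = 1.0                    # alignment starts on the first token

prior_filter = F.conv1d(
    F.pad(attention_weights.unsqueeze(1), (prior_filter_len - 1, 0)), prior.view(1, 1, -1)
)
log_prior = torch.log(prior_filter.clamp_min(1e-6)).squeeze(1)   # [B, T_in], added to the energies
print(log_prior.shape)
```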
+ mask: [B, T_en] + """ + # compute prior filters + prior_filter = F.conv1d( + F.pad(self.attention_weights.unsqueeze(1), (self.prior_filter_len - 1, 0)), self.prior.view(1, 1, -1) + ) + prior_filter = torch.log(prior_filter.clamp_min_(1e-6)).squeeze(1) + G = self.key_layer(torch.tanh(self.query_layer(query))) + # compute dynamic filters + dynamic_filter = F.conv1d( + self.attention_weights.unsqueeze(0), + G.view(-1, 1, self.dynamic_kernel_size), + padding=(self.dynamic_kernel_size - 1) // 2, + groups=query.size(0), + ) + dynamic_filter = dynamic_filter.view(query.size(0), self.dynamic_filter_dim, -1).transpose(1, 2) + # compute static filters + static_filter = self.static_filter_conv(self.attention_weights.unsqueeze(1)).transpose(1, 2) + alignment = ( + self.v( + torch.tanh(self.static_filter_layer(static_filter) + self.dynamic_filter_layer(dynamic_filter)) + ).squeeze(-1) + + prior_filter + ) + # compute attention weights + attention_weights = F.softmax(alignment, dim=-1) + # apply masking + if mask is not None: + attention_weights.data.masked_fill_(~mask, self._mask_value) + self.attention_weights = attention_weights + # compute context + context = torch.bmm(attention_weights.unsqueeze(1), inputs).squeeze(1) + return context + + def preprocess_inputs(self, inputs): # pylint: disable=no-self-use + return None + + def init_states(self, inputs): + B = inputs.size(0) + T = inputs.size(1) + self.attention_weights = torch.zeros([B, T], device=inputs.device) + self.attention_weights[:, 0] = 1.0 + + +def init_attn( + attn_type, + query_dim, + embedding_dim, + attention_dim, + location_attention, + attention_location_n_filters, + attention_location_kernel_size, + windowing, + norm, + forward_attn, + trans_agent, + forward_attn_mask, + attn_K, +): + if attn_type == "original": + return OriginalAttention( + query_dim, + embedding_dim, + attention_dim, + location_attention, + attention_location_n_filters, + attention_location_kernel_size, + windowing, + norm, + forward_attn, + trans_agent, + forward_attn_mask, + ) + if attn_type == "graves": + return GravesAttention(query_dim, attn_K) + if attn_type == "dynamic_convolution": + return MonotonicDynamicConvolutionAttention( + query_dim, + embedding_dim, + attention_dim, + static_filter_dim=8, + static_kernel_size=21, + dynamic_filter_dim=8, + dynamic_kernel_size=21, + prior_filter_len=11, + alpha=0.1, + beta=0.9, + ) + + raise RuntimeError(f" [!] Given Attention Type '{attn_type}' is not exist.") diff --git a/viXTTS/TTS/tts/layers/tacotron/capacitron_layers.py b/viXTTS/TTS/tts/layers/tacotron/capacitron_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..2181ffa7ec4e1f54d86cc5865a8fa7f6b6e362af --- /dev/null +++ b/viXTTS/TTS/tts/layers/tacotron/capacitron_layers.py @@ -0,0 +1,205 @@ +import torch +from torch import nn +from torch.distributions.multivariate_normal import MultivariateNormal as MVN +from torch.nn import functional as F + + +class CapacitronVAE(nn.Module): + """Effective Use of Variational Embedding Capacity for prosody transfer. 
+ + See https://arxiv.org/abs/1906.03402""" + + def __init__( + self, + num_mel, + capacitron_VAE_embedding_dim, + encoder_output_dim=256, + reference_encoder_out_dim=128, + speaker_embedding_dim=None, + text_summary_embedding_dim=None, + ): + super().__init__() + # Init distributions + self.prior_distribution = MVN( + torch.zeros(capacitron_VAE_embedding_dim), torch.eye(capacitron_VAE_embedding_dim) + ) + self.approximate_posterior_distribution = None + # define output ReferenceEncoder dim to the capacitron_VAE_embedding_dim + self.encoder = ReferenceEncoder(num_mel, out_dim=reference_encoder_out_dim) + + # Init beta, the lagrange-like term for the KL distribution + self.beta = torch.nn.Parameter(torch.log(torch.exp(torch.Tensor([1.0])) - 1), requires_grad=True) + mlp_input_dimension = reference_encoder_out_dim + + if text_summary_embedding_dim is not None: + self.text_summary_net = TextSummary(text_summary_embedding_dim, encoder_output_dim=encoder_output_dim) + mlp_input_dimension += text_summary_embedding_dim + if speaker_embedding_dim is not None: + # TODO: Test a multispeaker model! + mlp_input_dimension += speaker_embedding_dim + self.post_encoder_mlp = PostEncoderMLP(mlp_input_dimension, capacitron_VAE_embedding_dim) + + def forward(self, reference_mel_info=None, text_info=None, speaker_embedding=None): + # Use reference + if reference_mel_info is not None: + reference_mels = reference_mel_info[0] # [batch_size, num_frames, num_mels] + mel_lengths = reference_mel_info[1] # [batch_size] + enc_out = self.encoder(reference_mels, mel_lengths) + + # concat speaker_embedding and/or text summary embedding + if text_info is not None: + text_inputs = text_info[0] # [batch_size, num_characters, num_embedding] + input_lengths = text_info[1] + text_summary_out = self.text_summary_net(text_inputs, input_lengths).to(reference_mels.device) + enc_out = torch.cat([enc_out, text_summary_out], dim=-1) + if speaker_embedding is not None: + speaker_embedding = torch.squeeze(speaker_embedding) + enc_out = torch.cat([enc_out, speaker_embedding], dim=-1) + + # Feed the output of the ref encoder and information about text/speaker into + # an MLP to produce the parameteres for the approximate poterior distributions + mu, sigma = self.post_encoder_mlp(enc_out) + # convert to cpu because prior_distribution was created on cpu + mu = mu.cpu() + sigma = sigma.cpu() + + # Sample from the posterior: z ~ q(z|x) + self.approximate_posterior_distribution = MVN(mu, torch.diag_embed(sigma)) + VAE_embedding = self.approximate_posterior_distribution.rsample() + # Infer from the model, bypasses encoding + else: + # Sample from the prior: z ~ p(z) + VAE_embedding = self.prior_distribution.sample().unsqueeze(0) + + # reshape to [batch_size, 1, capacitron_VAE_embedding_dim] + return VAE_embedding.unsqueeze(1), self.approximate_posterior_distribution, self.prior_distribution, self.beta + + +class ReferenceEncoder(nn.Module): + """NN module creating a fixed size prosody embedding from a spectrogram. 
+ + inputs: mel spectrograms [batch_size, num_spec_frames, num_mel] + outputs: [batch_size, embedding_dim] + """ + + def __init__(self, num_mel, out_dim): + super().__init__() + self.num_mel = num_mel + filters = [1] + [32, 32, 64, 64, 128, 128] + num_layers = len(filters) - 1 + convs = [ + nn.Conv2d( + in_channels=filters[i], out_channels=filters[i + 1], kernel_size=(3, 3), stride=(2, 2), padding=(2, 2) + ) + for i in range(num_layers) + ] + self.convs = nn.ModuleList(convs) + self.training = False + self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]]) + + post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 2, num_layers) + self.recurrence = nn.LSTM( + input_size=filters[-1] * post_conv_height, hidden_size=out_dim, batch_first=True, bidirectional=False + ) + + def forward(self, inputs, input_lengths): + batch_size = inputs.size(0) + x = inputs.view(batch_size, 1, -1, self.num_mel) # [batch_size, num_channels==1, num_frames, num_mel] + valid_lengths = input_lengths.float() # [batch_size] + for conv, bn in zip(self.convs, self.bns): + x = conv(x) + x = bn(x) + x = F.relu(x) + + # Create the post conv width mask based on the valid lengths of the output of the convolution. + # The valid lengths for the output of a convolution on varying length inputs is + # ceil(input_length/stride) + 1 for stride=3 and padding=2 + # For example (kernel_size=3, stride=2, padding=2): + # 0 0 x x x x x 0 0 -> Input = 5, 0 is zero padding, x is valid values coming from padding=2 in conv2d + # _____ + # x _____ + # x _____ + # x ____ + # x + # x x x x -> Output valid length = 4 + # Since every example in te batch is zero padded and therefore have separate valid_lengths, + # we need to mask off all the values AFTER the valid length for each example in the batch. 
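+ # A second worked example (illustrative, not from the original comment): 10 valid
+ # input frames give ceil(10 / 2) + 1 = 6 valid output frames for this stride-2 layer,
+ # so the mask built below zeroes everything from index 6 onwards for that example.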
+ # Otherwise, the convolutions create noise and a lot of not real information + valid_lengths = (valid_lengths / 2).float() + valid_lengths = torch.ceil(valid_lengths).to(dtype=torch.int64) + 1 # 2 is stride -- size: [batch_size] + post_conv_max_width = x.size(2) + + mask = torch.arange(post_conv_max_width).to(inputs.device).expand( + len(valid_lengths), post_conv_max_width + ) < valid_lengths.unsqueeze(1) + mask = mask.expand(1, 1, -1, -1).transpose(2, 0).transpose(-1, 2) # [batch_size, 1, post_conv_max_width, 1] + x = x * mask + + x = x.transpose(1, 2) + # x: 4D tensor [batch_size, post_conv_width, + # num_channels==128, post_conv_height] + + post_conv_width = x.size(1) + x = x.contiguous().view(batch_size, post_conv_width, -1) + # x: 3D tensor [batch_size, post_conv_width, + # num_channels*post_conv_height] + + # Routine for fetching the last valid output of a dynamic LSTM with varying input lengths and padding + post_conv_input_lengths = valid_lengths + packed_seqs = nn.utils.rnn.pack_padded_sequence( + x, post_conv_input_lengths.tolist(), batch_first=True, enforce_sorted=False + ) # dynamic rnn sequence padding + self.recurrence.flatten_parameters() + _, (ht, _) = self.recurrence(packed_seqs) + last_output = ht[-1] + + return last_output.to(inputs.device) # [B, 128] + + @staticmethod + def calculate_post_conv_height(height, kernel_size, stride, pad, n_convs): + """Height of spec after n convolutions with fixed kernel/stride/pad.""" + for _ in range(n_convs): + height = (height - kernel_size + 2 * pad) // stride + 1 + return height + + +class TextSummary(nn.Module): + def __init__(self, embedding_dim, encoder_output_dim): + super().__init__() + self.lstm = nn.LSTM( + encoder_output_dim, # text embedding dimension from the text encoder + embedding_dim, # fixed length output summary the lstm creates from the input + batch_first=True, + bidirectional=False, + ) + + def forward(self, inputs, input_lengths): + # Routine for fetching the last valid output of a dynamic LSTM with varying input lengths and padding + packed_seqs = nn.utils.rnn.pack_padded_sequence( + inputs, input_lengths.tolist(), batch_first=True, enforce_sorted=False + ) # dynamic rnn sequence padding + self.lstm.flatten_parameters() + _, (ht, _) = self.lstm(packed_seqs) + last_output = ht[-1] + return last_output + + +class PostEncoderMLP(nn.Module): + def __init__(self, input_size, hidden_size): + super().__init__() + self.hidden_size = hidden_size + modules = [ + nn.Linear(input_size, hidden_size), # Hidden Layer + nn.Tanh(), + nn.Linear(hidden_size, hidden_size * 2), + ] # Output layer twice the size for mean and variance + self.net = nn.Sequential(*modules) + self.softplus = nn.Softplus() + + def forward(self, _input): + mlp_output = self.net(_input) + # The mean parameter is unconstrained + mu = mlp_output[:, : self.hidden_size] + # The standard deviation must be positive. Parameterise with a softplus + sigma = self.softplus(mlp_output[:, self.hidden_size :]) + return mu, sigma diff --git a/viXTTS/TTS/tts/layers/tacotron/common_layers.py b/viXTTS/TTS/tts/layers/tacotron/common_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..f78ff1e75f6c23eb1a0fe827247a1127bc8f9958 --- /dev/null +++ b/viXTTS/TTS/tts/layers/tacotron/common_layers.py @@ -0,0 +1,119 @@ +import torch +from torch import nn +from torch.nn import functional as F + + +class Linear(nn.Module): + """Linear layer with a specific initialization. + + Args: + in_features (int): number of channels in the input tensor. 
+ out_features (int): number of channels in the output tensor. + bias (bool, optional): enable/disable bias in the layer. Defaults to True. + init_gain (str, optional): method to compute the gain in the weight initializtion based on the nonlinear activation used afterwards. Defaults to 'linear'. + """ + + def __init__(self, in_features, out_features, bias=True, init_gain="linear"): + super().__init__() + self.linear_layer = torch.nn.Linear(in_features, out_features, bias=bias) + self._init_w(init_gain) + + def _init_w(self, init_gain): + torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(init_gain)) + + def forward(self, x): + return self.linear_layer(x) + + +class LinearBN(nn.Module): + """Linear layer with Batch Normalization. + + x -> linear -> BN -> o + + Args: + in_features (int): number of channels in the input tensor. + out_features (int ): number of channels in the output tensor. + bias (bool, optional): enable/disable bias in the linear layer. Defaults to True. + init_gain (str, optional): method to set the gain for weight initialization. Defaults to 'linear'. + """ + + def __init__(self, in_features, out_features, bias=True, init_gain="linear"): + super().__init__() + self.linear_layer = torch.nn.Linear(in_features, out_features, bias=bias) + self.batch_normalization = nn.BatchNorm1d(out_features, momentum=0.1, eps=1e-5) + self._init_w(init_gain) + + def _init_w(self, init_gain): + torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(init_gain)) + + def forward(self, x): + """ + Shapes: + x: [T, B, C] or [B, C] + """ + out = self.linear_layer(x) + if len(out.shape) == 3: + out = out.permute(1, 2, 0) + out = self.batch_normalization(out) + if len(out.shape) == 3: + out = out.permute(2, 0, 1) + return out + + +class Prenet(nn.Module): + """Tacotron specific Prenet with an optional Batch Normalization. + + Note: + Prenet with BN improves the model performance significantly especially + if it is enabled after learning a diagonal attention alignment with the original + prenet. However, if the target dataset is high quality then it also works from + the start. It is also suggested to disable dropout if BN is in use. + + prenet_type == "original" + x -> [linear -> ReLU -> Dropout]xN -> o + + prenet_type == "bn" + x -> [linear -> BN -> ReLU -> Dropout]xN -> o + + Args: + in_features (int): number of channels in the input tensor and the inner layers. + prenet_type (str, optional): prenet type "original" or "bn". Defaults to "original". + prenet_dropout (bool, optional): dropout rate. Defaults to True. + dropout_at_inference (bool, optional): use dropout at inference. It leads to a better quality for some models. + out_features (list, optional): List of output channels for each prenet block. + It also defines number of the prenet blocks based on the length of argument list. + Defaults to [256, 256]. + bias (bool, optional): enable/disable bias in prenet linear layers. Defaults to True. 
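+
+ Example:
+ A minimal shape-check sketch (the 80-dim input and batch of 4 are illustrative
+ values, not taken from any particular model config)::
+
+ >>> prenet = Prenet(80, prenet_type="original", out_features=[256, 256])
+ >>> prenet(torch.randn(4, 80)).shape
+ torch.Size([4, 256])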
+ """ + + # pylint: disable=dangerous-default-value + def __init__( + self, + in_features, + prenet_type="original", + prenet_dropout=True, + dropout_at_inference=False, + out_features=[256, 256], + bias=True, + ): + super().__init__() + self.prenet_type = prenet_type + self.prenet_dropout = prenet_dropout + self.dropout_at_inference = dropout_at_inference + in_features = [in_features] + out_features[:-1] + if prenet_type == "bn": + self.linear_layers = nn.ModuleList( + [LinearBN(in_size, out_size, bias=bias) for (in_size, out_size) in zip(in_features, out_features)] + ) + elif prenet_type == "original": + self.linear_layers = nn.ModuleList( + [Linear(in_size, out_size, bias=bias) for (in_size, out_size) in zip(in_features, out_features)] + ) + + def forward(self, x): + for linear in self.linear_layers: + if self.prenet_dropout: + x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training or self.dropout_at_inference) + else: + x = F.relu(linear(x)) + return x diff --git a/viXTTS/TTS/tts/layers/tacotron/gst_layers.py b/viXTTS/TTS/tts/layers/tacotron/gst_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..05dba7084ff5533b68779d46238530f4988db934 --- /dev/null +++ b/viXTTS/TTS/tts/layers/tacotron/gst_layers.py @@ -0,0 +1,149 @@ +import torch +import torch.nn.functional as F +from torch import nn + + +class GST(nn.Module): + """Global Style Token Module for factorizing prosody in speech. + + See https://arxiv.org/pdf/1803.09017""" + + def __init__(self, num_mel, num_heads, num_style_tokens, gst_embedding_dim, embedded_speaker_dim=None): + super().__init__() + self.encoder = ReferenceEncoder(num_mel, gst_embedding_dim) + self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens, gst_embedding_dim, embedded_speaker_dim) + + def forward(self, inputs, speaker_embedding=None): + enc_out = self.encoder(inputs) + # concat speaker_embedding + if speaker_embedding is not None: + enc_out = torch.cat([enc_out, speaker_embedding], dim=-1) + style_embed = self.style_token_layer(enc_out) + + return style_embed + + +class ReferenceEncoder(nn.Module): + """NN module creating a fixed size prosody embedding from a spectrogram. 
+ + inputs: mel spectrograms [batch_size, num_spec_frames, num_mel] + outputs: [batch_size, embedding_dim] + """ + + def __init__(self, num_mel, embedding_dim): + super().__init__() + self.num_mel = num_mel + filters = [1] + [32, 32, 64, 64, 128, 128] + num_layers = len(filters) - 1 + convs = [ + nn.Conv2d( + in_channels=filters[i], out_channels=filters[i + 1], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1) + ) + for i in range(num_layers) + ] + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]]) + + post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 1, num_layers) + self.recurrence = nn.GRU( + input_size=filters[-1] * post_conv_height, hidden_size=embedding_dim // 2, batch_first=True + ) + + def forward(self, inputs): + batch_size = inputs.size(0) + x = inputs.view(batch_size, 1, -1, self.num_mel) + # x: 4D tensor [batch_size, num_channels==1, num_frames, num_mel] + for conv, bn in zip(self.convs, self.bns): + x = conv(x) + x = bn(x) + x = F.relu(x) + + x = x.transpose(1, 2) + # x: 4D tensor [batch_size, post_conv_width, + # num_channels==128, post_conv_height] + post_conv_width = x.size(1) + x = x.contiguous().view(batch_size, post_conv_width, -1) + # x: 3D tensor [batch_size, post_conv_width, + # num_channels*post_conv_height] + self.recurrence.flatten_parameters() + _, out = self.recurrence(x) + # out: 3D tensor [seq_len==1, batch_size, encoding_size=128] + + return out.squeeze(0) + + @staticmethod + def calculate_post_conv_height(height, kernel_size, stride, pad, n_convs): + """Height of spec after n convolutions with fixed kernel/stride/pad.""" + for _ in range(n_convs): + height = (height - kernel_size + 2 * pad) // stride + 1 + return height + + +class StyleTokenLayer(nn.Module): + """NN Module attending to style tokens based on prosody encodings.""" + + def __init__(self, num_heads, num_style_tokens, gst_embedding_dim, d_vector_dim=None): + super().__init__() + + self.query_dim = gst_embedding_dim // 2 + + if d_vector_dim: + self.query_dim += d_vector_dim + + self.key_dim = gst_embedding_dim // num_heads + self.style_tokens = nn.Parameter(torch.FloatTensor(num_style_tokens, self.key_dim)) + nn.init.normal_(self.style_tokens, mean=0, std=0.5) + self.attention = MultiHeadAttention( + query_dim=self.query_dim, key_dim=self.key_dim, num_units=gst_embedding_dim, num_heads=num_heads + ) + + def forward(self, inputs): + batch_size = inputs.size(0) + prosody_encoding = inputs.unsqueeze(1) + # prosody_encoding: 3D tensor [batch_size, 1, encoding_size==128] + tokens = torch.tanh(self.style_tokens).unsqueeze(0).expand(batch_size, -1, -1) + # tokens: 3D tensor [batch_size, num tokens, token embedding size] + style_embed = self.attention(prosody_encoding, tokens) + + return style_embed + + +class MultiHeadAttention(nn.Module): + """ + input: + query --- [N, T_q, query_dim] + key --- [N, T_k, key_dim] + output: + out --- [N, T_q, num_units] + """ + + def __init__(self, query_dim, key_dim, num_units, num_heads): + super().__init__() + self.num_units = num_units + self.num_heads = num_heads + self.key_dim = key_dim + + self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False) + self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) + self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) + + def forward(self, query, key): + queries = self.W_query(query) # [N, T_q, num_units] + keys = self.W_key(key) # [N, T_k, 
num_units] + values = self.W_value(key) + + split_size = self.num_units // self.num_heads + queries = torch.stack(torch.split(queries, split_size, dim=2), dim=0) # [h, N, T_q, num_units/h] + keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] + values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] + + # score = softmax(QK^T / (d_k**0.5)) + scores = torch.matmul(queries, keys.transpose(2, 3)) # [h, N, T_q, T_k] + scores = scores / (self.key_dim**0.5) + scores = F.softmax(scores, dim=3) + + # out = score * V + out = torch.matmul(scores, values) # [h, N, T_q, num_units/h] + out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units] + + return out diff --git a/viXTTS/TTS/tts/layers/tacotron/tacotron.py b/viXTTS/TTS/tts/layers/tacotron/tacotron.py new file mode 100644 index 0000000000000000000000000000000000000000..7a47c35ef67852456d7211f32502ffb84509d61f --- /dev/null +++ b/viXTTS/TTS/tts/layers/tacotron/tacotron.py @@ -0,0 +1,503 @@ +# coding: utf-8 +# adapted from https://github.com/r9y9/tacotron_pytorch + +import torch +from torch import nn + +from .attentions import init_attn +from .common_layers import Prenet + + +class BatchNormConv1d(nn.Module): + r"""A wrapper for Conv1d with BatchNorm. It sets the activation + function between Conv and BatchNorm layers. BatchNorm layer + is initialized with the TF default values for momentum and eps. + + Args: + in_channels: size of each input sample + out_channels: size of each output samples + kernel_size: kernel size of conv filters + stride: stride of conv filters + padding: padding of conv filters + activation: activation function set b/w Conv1d and BatchNorm + + Shapes: + - input: (B, D) + - output: (B, D) + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, activation=None): + super().__init__() + self.padding = padding + self.padder = nn.ConstantPad1d(padding, 0) + self.conv1d = nn.Conv1d( + in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=0, bias=False + ) + # Following tensorflow's default parameters + self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3) + self.activation = activation + # self.init_layers() + + def init_layers(self): + if isinstance(self.activation, torch.nn.ReLU): + w_gain = "relu" + elif isinstance(self.activation, torch.nn.Tanh): + w_gain = "tanh" + elif self.activation is None: + w_gain = "linear" + else: + raise RuntimeError("Unknown activation function") + torch.nn.init.xavier_uniform_(self.conv1d.weight, gain=torch.nn.init.calculate_gain(w_gain)) + + def forward(self, x): + x = self.padder(x) + x = self.conv1d(x) + x = self.bn(x) + if self.activation is not None: + x = self.activation(x) + return x + + +class Highway(nn.Module): + r"""Highway layers as explained in https://arxiv.org/abs/1505.00387 + + Args: + in_features (int): size of each input sample + out_feature (int): size of each output sample + + Shapes: + - input: (B, *, H_in) + - output: (B, *, H_out) + """ + + # TODO: Try GLU layer + def __init__(self, in_features, out_feature): + super().__init__() + self.H = nn.Linear(in_features, out_feature) + self.H.bias.data.zero_() + self.T = nn.Linear(in_features, out_feature) + self.T.bias.data.fill_(-1) + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + # self.init_layers() + + def init_layers(self): + torch.nn.init.xavier_uniform_(self.H.weight, gain=torch.nn.init.calculate_gain("relu")) + torch.nn.init.xavier_uniform_(self.T.weight, 
gain=torch.nn.init.calculate_gain("sigmoid")) + + def forward(self, inputs): + H = self.relu(self.H(inputs)) + T = self.sigmoid(self.T(inputs)) + return H * T + inputs * (1.0 - T) + + +class CBHG(nn.Module): + """CBHG module: a recurrent neural network composed of: + - 1-d convolution banks + - Highway networks + residual connections + - Bidirectional gated recurrent units + + Args: + in_features (int): sample size + K (int): max filter size in conv bank + projections (list): conv channel sizes for conv projections + num_highways (int): number of highways layers + + Shapes: + - input: (B, C, T_in) + - output: (B, T_in, C*2) + """ + + # pylint: disable=dangerous-default-value + def __init__( + self, + in_features, + K=16, + conv_bank_features=128, + conv_projections=[128, 128], + highway_features=128, + gru_features=128, + num_highways=4, + ): + super().__init__() + self.in_features = in_features + self.conv_bank_features = conv_bank_features + self.highway_features = highway_features + self.gru_features = gru_features + self.conv_projections = conv_projections + self.relu = nn.ReLU() + # list of conv1d bank with filter size k=1...K + # TODO: try dilational layers instead + self.conv1d_banks = nn.ModuleList( + [ + BatchNormConv1d( + in_features, + conv_bank_features, + kernel_size=k, + stride=1, + padding=[(k - 1) // 2, k // 2], + activation=self.relu, + ) + for k in range(1, K + 1) + ] + ) + # max pooling of conv bank, with padding + # TODO: try average pooling OR larger kernel size + out_features = [K * conv_bank_features] + conv_projections[:-1] + activations = [self.relu] * (len(conv_projections) - 1) + activations += [None] + # setup conv1d projection layers + layer_set = [] + for in_size, out_size, ac in zip(out_features, conv_projections, activations): + layer = BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1, padding=[1, 1], activation=ac) + layer_set.append(layer) + self.conv1d_projections = nn.ModuleList(layer_set) + # setup Highway layers + if self.highway_features != conv_projections[-1]: + self.pre_highway = nn.Linear(conv_projections[-1], highway_features, bias=False) + self.highways = nn.ModuleList([Highway(highway_features, highway_features) for _ in range(num_highways)]) + # bi-directional GPU layer + self.gru = nn.GRU(gru_features, gru_features, 1, batch_first=True, bidirectional=True) + + def forward(self, inputs): + # (B, in_features, T_in) + x = inputs + # (B, hid_features*K, T_in) + # Concat conv1d bank outputs + outs = [] + for conv1d in self.conv1d_banks: + out = conv1d(x) + outs.append(out) + x = torch.cat(outs, dim=1) + assert x.size(1) == self.conv_bank_features * len(self.conv1d_banks) + for conv1d in self.conv1d_projections: + x = conv1d(x) + x += inputs + x = x.transpose(1, 2) + if self.highway_features != self.conv_projections[-1]: + x = self.pre_highway(x) + # Residual connection + # TODO: try residual scaling as in Deep Voice 3 + # TODO: try plain residual layers + for highway in self.highways: + x = highway(x) + # (B, T_in, hid_features*2) + # TODO: replace GRU with convolution as in Deep Voice 3 + self.gru.flatten_parameters() + outputs, _ = self.gru(x) + return outputs + + +class EncoderCBHG(nn.Module): + r"""CBHG module with Encoder specific arguments""" + + def __init__(self): + super().__init__() + self.cbhg = CBHG( + 128, + K=16, + conv_bank_features=128, + conv_projections=[128, 128], + highway_features=128, + gru_features=128, + num_highways=4, + ) + + def forward(self, x): + return self.cbhg(x) + + +class Encoder(nn.Module): + r"""Stack 
Prenet and CBHG module for encoder + Args: + inputs (FloatTensor): embedding features + + Shapes: + - inputs: (B, T, D_in) + - outputs: (B, T, 128 * 2) + """ + + def __init__(self, in_features): + super().__init__() + self.prenet = Prenet(in_features, out_features=[256, 128]) + self.cbhg = EncoderCBHG() + + def forward(self, inputs): + # B x T x prenet_dim + outputs = self.prenet(inputs) + outputs = self.cbhg(outputs.transpose(1, 2)) + return outputs + + +class PostCBHG(nn.Module): + def __init__(self, mel_dim): + super().__init__() + self.cbhg = CBHG( + mel_dim, + K=8, + conv_bank_features=128, + conv_projections=[256, mel_dim], + highway_features=128, + gru_features=128, + num_highways=4, + ) + + def forward(self, x): + return self.cbhg(x) + + +class Decoder(nn.Module): + """Tacotron decoder. + + Args: + in_channels (int): number of input channels. + frame_channels (int): number of feature frame channels. + r (int): number of outputs per time step (reduction rate). + memory_size (int): size of the past window. if <= 0 memory_size = r + attn_type (string): type of attention used in decoder. + attn_windowing (bool): if true, define an attention window centered to maximum + attention response. It provides more robust attention alignment especially + at interence time. + attn_norm (string): attention normalization function. 'sigmoid' or 'softmax'. + prenet_type (string): 'original' or 'bn'. + prenet_dropout (float): prenet dropout rate. + forward_attn (bool): if true, use forward attention method. https://arxiv.org/abs/1807.06736 + trans_agent (bool): if true, use transition agent. https://arxiv.org/abs/1807.06736 + forward_attn_mask (bool): if true, mask attention values smaller than a threshold. + location_attn (bool): if true, use location sensitive attention. + attn_K (int): number of attention heads for GravesAttention. + separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow. + d_vector_dim (int): size of speaker embedding vector, for multi-speaker training. + max_decoder_steps (int): Maximum number of steps allowed for the decoder. Defaults to 500. 
+ """ + + # Pylint gets confused by PyTorch conventions here + # pylint: disable=attribute-defined-outside-init + + def __init__( + self, + in_channels, + frame_channels, + r, + memory_size, + attn_type, + attn_windowing, + attn_norm, + prenet_type, + prenet_dropout, + forward_attn, + trans_agent, + forward_attn_mask, + location_attn, + attn_K, + separate_stopnet, + max_decoder_steps, + ): + super().__init__() + self.r_init = r + self.r = r + self.in_channels = in_channels + self.max_decoder_steps = max_decoder_steps + self.use_memory_queue = memory_size > 0 + self.memory_size = memory_size if memory_size > 0 else r + self.frame_channels = frame_channels + self.separate_stopnet = separate_stopnet + self.query_dim = 256 + # memory -> |Prenet| -> processed_memory + prenet_dim = frame_channels * self.memory_size if self.use_memory_queue else frame_channels + self.prenet = Prenet(prenet_dim, prenet_type, prenet_dropout, out_features=[256, 128]) + # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State + # attention_rnn generates queries for the attention mechanism + self.attention_rnn = nn.GRUCell(in_channels + 128, self.query_dim) + self.attention = init_attn( + attn_type=attn_type, + query_dim=self.query_dim, + embedding_dim=in_channels, + attention_dim=128, + location_attention=location_attn, + attention_location_n_filters=32, + attention_location_kernel_size=31, + windowing=attn_windowing, + norm=attn_norm, + forward_attn=forward_attn, + trans_agent=trans_agent, + forward_attn_mask=forward_attn_mask, + attn_K=attn_K, + ) + # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input + self.project_to_decoder_in = nn.Linear(256 + in_channels, 256) + # decoder_RNN_input -> |RNN| -> RNN_state + self.decoder_rnns = nn.ModuleList([nn.GRUCell(256, 256) for _ in range(2)]) + # RNN_state -> |Linear| -> mel_spec + self.proj_to_mel = nn.Linear(256, frame_channels * self.r_init) + # learn init values instead of zero init. 
+ self.stopnet = StopNet(256 + frame_channels * self.r_init) + + def set_r(self, new_r): + self.r = new_r + + def _reshape_memory(self, memory): + """ + Reshape the spectrograms for given 'r' + """ + # Grouping multiple frames if necessary + if memory.size(-1) == self.frame_channels: + memory = memory.view(memory.shape[0], memory.size(1) // self.r, -1) + # Time first (T_decoder, B, frame_channels) + memory = memory.transpose(0, 1) + return memory + + def _init_states(self, inputs): + """ + Initialization of decoder states + """ + B = inputs.size(0) + # go frame as zeros matrix + if self.use_memory_queue: + self.memory_input = torch.zeros(1, device=inputs.device).repeat(B, self.frame_channels * self.memory_size) + else: + self.memory_input = torch.zeros(1, device=inputs.device).repeat(B, self.frame_channels) + # decoder states + self.attention_rnn_hidden = torch.zeros(1, device=inputs.device).repeat(B, 256) + self.decoder_rnn_hiddens = [ + torch.zeros(1, device=inputs.device).repeat(B, 256) for idx in range(len(self.decoder_rnns)) + ] + self.context_vec = inputs.data.new(B, self.in_channels).zero_() + # cache attention inputs + self.processed_inputs = self.attention.preprocess_inputs(inputs) + + def _parse_outputs(self, outputs, attentions, stop_tokens): + # Back to batch first + attentions = torch.stack(attentions).transpose(0, 1) + stop_tokens = torch.stack(stop_tokens).transpose(0, 1) + outputs = torch.stack(outputs).transpose(0, 1).contiguous() + outputs = outputs.view(outputs.size(0), -1, self.frame_channels) + outputs = outputs.transpose(1, 2) + return outputs, attentions, stop_tokens + + def decode(self, inputs, mask=None): + # Prenet + processed_memory = self.prenet(self.memory_input) + # Attention RNN + self.attention_rnn_hidden = self.attention_rnn( + torch.cat((processed_memory, self.context_vec), -1), self.attention_rnn_hidden + ) + self.context_vec = self.attention(self.attention_rnn_hidden, inputs, self.processed_inputs, mask) + # Concat RNN output and attention context vector + decoder_input = self.project_to_decoder_in(torch.cat((self.attention_rnn_hidden, self.context_vec), -1)) + + # Pass through the decoder RNNs + for idx, decoder_rnn in enumerate(self.decoder_rnns): + self.decoder_rnn_hiddens[idx] = decoder_rnn(decoder_input, self.decoder_rnn_hiddens[idx]) + # Residual connection + decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input + decoder_output = decoder_input + + # predict mel vectors from decoder vectors + output = self.proj_to_mel(decoder_output) + # output = torch.sigmoid(output) + # predict stop token + stopnet_input = torch.cat([decoder_output, output], -1) + if self.separate_stopnet: + stop_token = self.stopnet(stopnet_input.detach()) + else: + stop_token = self.stopnet(stopnet_input) + output = output[:, : self.r * self.frame_channels] + return output, stop_token, self.attention.attention_weights + + def _update_memory_input(self, new_memory): + if self.use_memory_queue: + if self.memory_size > self.r: + # memory queue size is larger than number of frames per decoder iter + self.memory_input = torch.cat( + [new_memory, self.memory_input[:, : (self.memory_size - self.r) * self.frame_channels].clone()], + dim=-1, + ) + else: + # memory queue size smaller than number of frames per decoder iter + self.memory_input = new_memory[:, : self.memory_size * self.frame_channels] + else: + # use only the last frame prediction + # assert new_memory.shape[-1] == self.r * self.frame_channels + self.memory_input = new_memory[:, self.frame_channels * (self.r - 1) :] 
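+
+ # Illustrative sizes for the memory queue handled above (numbers chosen arbitrarily,
+ # not from the original code): with frame_channels=80, r=2 and memory_size=5,
+ # `new_memory` carries 2 * 80 = 160 values per sample; it is prepended to the first
+ # (5 - 2) * 80 = 240 values of the old queue, keeping the queue at 5 * 80 = 400 values.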
+ + def forward(self, inputs, memory, mask): + """ + Args: + inputs: Encoder outputs. + memory: Decoder memory (autoregression. If None (at eval-time), + decoder outputs are used as decoder inputs. If None, it uses the last + output as the input. + mask: Attention mask for sequence padding. + + Shapes: + - inputs: (B, T, D_out_enc) + - memory: (B, T_mel, D_mel) + """ + # Run greedy decoding if memory is None + memory = self._reshape_memory(memory) + outputs = [] + attentions = [] + stop_tokens = [] + t = 0 + self._init_states(inputs) + self.attention.init_states(inputs) + while len(outputs) < memory.size(0): + if t > 0: + new_memory = memory[t - 1] + self._update_memory_input(new_memory) + + output, stop_token, attention = self.decode(inputs, mask) + outputs += [output] + attentions += [attention] + stop_tokens += [stop_token.squeeze(1)] + t += 1 + return self._parse_outputs(outputs, attentions, stop_tokens) + + def inference(self, inputs): + """ + Args: + inputs: encoder outputs. + Shapes: + - inputs: batch x time x encoder_out_dim + """ + outputs = [] + attentions = [] + stop_tokens = [] + t = 0 + self._init_states(inputs) + self.attention.init_states(inputs) + while True: + if t > 0: + new_memory = outputs[-1] + self._update_memory_input(new_memory) + output, stop_token, attention = self.decode(inputs, None) + stop_token = torch.sigmoid(stop_token.data) + outputs += [output] + attentions += [attention] + stop_tokens += [stop_token] + t += 1 + if t > inputs.shape[1] / 4 and (stop_token > 0.6 or attention[:, -1].item() > 0.6): + break + if t > self.max_decoder_steps: + print(" | > Decoder stopped with 'max_decoder_steps") + break + return self._parse_outputs(outputs, attentions, stop_tokens) + + +class StopNet(nn.Module): + r"""Stopnet signalling decoder to stop inference. + Args: + in_features (int): feature dimension of input. + """ + + def __init__(self, in_features): + super().__init__() + self.dropout = nn.Dropout(0.1) + self.linear = nn.Linear(in_features, 1) + torch.nn.init.xavier_uniform_(self.linear.weight, gain=torch.nn.init.calculate_gain("linear")) + + def forward(self, inputs): + outputs = self.dropout(inputs) + outputs = self.linear(outputs) + return outputs diff --git a/viXTTS/TTS/tts/layers/tacotron/tacotron2.py b/viXTTS/TTS/tts/layers/tacotron/tacotron2.py new file mode 100644 index 0000000000000000000000000000000000000000..c79b70997249efc94cbac630bcc7d6c571f5743e --- /dev/null +++ b/viXTTS/TTS/tts/layers/tacotron/tacotron2.py @@ -0,0 +1,414 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from .attentions import init_attn +from .common_layers import Linear, Prenet + + +# pylint: disable=no-value-for-parameter +# pylint: disable=unexpected-keyword-arg +class ConvBNBlock(nn.Module): + r"""Convolutions with Batch Normalization and non-linear activation. + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + kernel_size (int): convolution kernel size. + activation (str): 'relu', 'tanh', None (linear). 
+ + Shapes: + - input: (B, C_in, T) + - output: (B, C_out, T) + """ + + def __init__(self, in_channels, out_channels, kernel_size, activation=None): + super().__init__() + assert (kernel_size - 1) % 2 == 0 + padding = (kernel_size - 1) // 2 + self.convolution1d = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding) + self.batch_normalization = nn.BatchNorm1d(out_channels, momentum=0.1, eps=1e-5) + self.dropout = nn.Dropout(p=0.5) + if activation == "relu": + self.activation = nn.ReLU() + elif activation == "tanh": + self.activation = nn.Tanh() + else: + self.activation = nn.Identity() + + def forward(self, x): + o = self.convolution1d(x) + o = self.batch_normalization(o) + o = self.activation(o) + o = self.dropout(o) + return o + + +class Postnet(nn.Module): + r"""Tacotron2 Postnet + + Args: + in_out_channels (int): number of output channels. + + Shapes: + - input: (B, C_in, T) + - output: (B, C_in, T) + """ + + def __init__(self, in_out_channels, num_convs=5): + super().__init__() + self.convolutions = nn.ModuleList() + self.convolutions.append(ConvBNBlock(in_out_channels, 512, kernel_size=5, activation="tanh")) + for _ in range(1, num_convs - 1): + self.convolutions.append(ConvBNBlock(512, 512, kernel_size=5, activation="tanh")) + self.convolutions.append(ConvBNBlock(512, in_out_channels, kernel_size=5, activation=None)) + + def forward(self, x): + o = x + for layer in self.convolutions: + o = layer(o) + return o + + +class Encoder(nn.Module): + r"""Tacotron2 Encoder + + Args: + in_out_channels (int): number of input and output channels. + + Shapes: + - input: (B, C_in, T) + - output: (B, C_in, T) + """ + + def __init__(self, in_out_channels=512): + super().__init__() + self.convolutions = nn.ModuleList() + for _ in range(3): + self.convolutions.append(ConvBNBlock(in_out_channels, in_out_channels, 5, "relu")) + self.lstm = nn.LSTM( + in_out_channels, int(in_out_channels / 2), num_layers=1, batch_first=True, bias=True, bidirectional=True + ) + self.rnn_state = None + + def forward(self, x, input_lengths): + o = x + for layer in self.convolutions: + o = layer(o) + o = o.transpose(1, 2) + o = nn.utils.rnn.pack_padded_sequence(o, input_lengths.cpu(), batch_first=True) + self.lstm.flatten_parameters() + o, _ = self.lstm(o) + o, _ = nn.utils.rnn.pad_packed_sequence(o, batch_first=True) + return o + + def inference(self, x): + o = x + for layer in self.convolutions: + o = layer(o) + o = o.transpose(1, 2) + # self.lstm.flatten_parameters() + o, _ = self.lstm(o) + return o + + +# adapted from https://github.com/NVIDIA/tacotron2/ +class Decoder(nn.Module): + """Tacotron2 decoder. We don't use Zoneout but Dropout between RNN layers. + + Args: + in_channels (int): number of input channels. + frame_channels (int): number of feature frame channels. + r (int): number of outputs per time step (reduction rate). + memory_size (int): size of the past window. if <= 0 memory_size = r + attn_type (string): type of attention used in decoder. + attn_win (bool): if true, define an attention window centered to maximum + attention response. It provides more robust attention alignment especially + at interence time. + attn_norm (string): attention normalization function. 'sigmoid' or 'softmax'. + prenet_type (string): 'original' or 'bn'. + prenet_dropout (float): prenet dropout rate. + forward_attn (bool): if true, use forward attention method. https://arxiv.org/abs/1807.06736 + trans_agent (bool): if true, use transition agent. 
https://arxiv.org/abs/1807.06736 + forward_attn_mask (bool): if true, mask attention values smaller than a threshold. + location_attn (bool): if true, use location sensitive attention. + attn_K (int): number of attention heads for GravesAttention. + separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow. + max_decoder_steps (int): Maximum number of steps allowed for the decoder. Defaults to 10000. + """ + + # Pylint gets confused by PyTorch conventions here + # pylint: disable=attribute-defined-outside-init + def __init__( + self, + in_channels, + frame_channels, + r, + attn_type, + attn_win, + attn_norm, + prenet_type, + prenet_dropout, + forward_attn, + trans_agent, + forward_attn_mask, + location_attn, + attn_K, + separate_stopnet, + max_decoder_steps, + ): + super().__init__() + self.frame_channels = frame_channels + self.r_init = r + self.r = r + self.encoder_embedding_dim = in_channels + self.separate_stopnet = separate_stopnet + self.max_decoder_steps = max_decoder_steps + self.stop_threshold = 0.5 + + # model dimensions + self.query_dim = 1024 + self.decoder_rnn_dim = 1024 + self.prenet_dim = 256 + self.attn_dim = 128 + self.p_attention_dropout = 0.1 + self.p_decoder_dropout = 0.1 + + # memory -> |Prenet| -> processed_memory + prenet_dim = self.frame_channels + self.prenet = Prenet( + prenet_dim, prenet_type, prenet_dropout, out_features=[self.prenet_dim, self.prenet_dim], bias=False + ) + + self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_channels, self.query_dim, bias=True) + + self.attention = init_attn( + attn_type=attn_type, + query_dim=self.query_dim, + embedding_dim=in_channels, + attention_dim=128, + location_attention=location_attn, + attention_location_n_filters=32, + attention_location_kernel_size=31, + windowing=attn_win, + norm=attn_norm, + forward_attn=forward_attn, + trans_agent=trans_agent, + forward_attn_mask=forward_attn_mask, + attn_K=attn_K, + ) + + self.decoder_rnn = nn.LSTMCell(self.query_dim + in_channels, self.decoder_rnn_dim, bias=True) + + self.linear_projection = Linear(self.decoder_rnn_dim + in_channels, self.frame_channels * self.r_init) + + self.stopnet = nn.Sequential( + nn.Dropout(0.1), + Linear(self.decoder_rnn_dim + self.frame_channels * self.r_init, 1, bias=True, init_gain="sigmoid"), + ) + self.memory_truncated = None + + def set_r(self, new_r): + self.r = new_r + + def get_go_frame(self, inputs): + B = inputs.size(0) + memory = torch.zeros(1, device=inputs.device).repeat(B, self.frame_channels * self.r) + return memory + + def _init_states(self, inputs, mask, keep_states=False): + B = inputs.size(0) + # T = inputs.size(1) + if not keep_states: + self.query = torch.zeros(1, device=inputs.device).repeat(B, self.query_dim) + self.attention_rnn_cell_state = torch.zeros(1, device=inputs.device).repeat(B, self.query_dim) + self.decoder_hidden = torch.zeros(1, device=inputs.device).repeat(B, self.decoder_rnn_dim) + self.decoder_cell = torch.zeros(1, device=inputs.device).repeat(B, self.decoder_rnn_dim) + self.context = torch.zeros(1, device=inputs.device).repeat(B, self.encoder_embedding_dim) + self.inputs = inputs + self.processed_inputs = self.attention.preprocess_inputs(inputs) + self.mask = mask + + def _reshape_memory(self, memory): + """ + Reshape the spectrograms for given 'r' + """ + # Grouping multiple frames if necessary + if memory.size(-1) == self.frame_channels: + memory = memory.view(memory.shape[0], memory.size(1) // self.r, -1) + # Time first (T_decoder, B, frame_channels) + memory = memory.transpose(0, 1) 
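+ # Illustrative shapes (sizes chosen arbitrarily): with r=3 and frame_channels=80,
+ # a (B, 120, 80) teacher-forcing mel becomes (B, 40, 240) after the view above and
+ # (40, B, 240) after the transpose, i.e. one decoder step per group of r frames.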
+ return memory + + def _parse_outputs(self, outputs, stop_tokens, alignments): + alignments = torch.stack(alignments).transpose(0, 1) + stop_tokens = torch.stack(stop_tokens).transpose(0, 1) + outputs = torch.stack(outputs).transpose(0, 1).contiguous() + outputs = outputs.view(outputs.size(0), -1, self.frame_channels) + outputs = outputs.transpose(1, 2) + return outputs, stop_tokens, alignments + + def _update_memory(self, memory): + if len(memory.shape) == 2: + return memory[:, self.frame_channels * (self.r - 1) :] + return memory[:, :, self.frame_channels * (self.r - 1) :] + + def decode(self, memory): + """ + shapes: + - memory: B x r * self.frame_channels + """ + # self.context: B x D_en + # query_input: B x D_en + (r * self.frame_channels) + query_input = torch.cat((memory, self.context), -1) + # self.query and self.attention_rnn_cell_state : B x D_attn_rnn + self.query, self.attention_rnn_cell_state = self.attention_rnn( + query_input, (self.query, self.attention_rnn_cell_state) + ) + self.query = F.dropout(self.query, self.p_attention_dropout, self.training) + self.attention_rnn_cell_state = F.dropout( + self.attention_rnn_cell_state, self.p_attention_dropout, self.training + ) + # B x D_en + self.context = self.attention(self.query, self.inputs, self.processed_inputs, self.mask) + # B x (D_en + D_attn_rnn) + decoder_rnn_input = torch.cat((self.query, self.context), -1) + # self.decoder_hidden and self.decoder_cell: B x D_decoder_rnn + self.decoder_hidden, self.decoder_cell = self.decoder_rnn( + decoder_rnn_input, (self.decoder_hidden, self.decoder_cell) + ) + self.decoder_hidden = F.dropout(self.decoder_hidden, self.p_decoder_dropout, self.training) + # B x (D_decoder_rnn + D_en) + decoder_hidden_context = torch.cat((self.decoder_hidden, self.context), dim=1) + # B x (self.r * self.frame_channels) + decoder_output = self.linear_projection(decoder_hidden_context) + # B x (D_decoder_rnn + (self.r * self.frame_channels)) + stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1) + if self.separate_stopnet: + stop_token = self.stopnet(stopnet_input.detach()) + else: + stop_token = self.stopnet(stopnet_input) + # select outputs for the reduction rate self.r + decoder_output = decoder_output[:, : self.r * self.frame_channels] + return decoder_output, self.attention.attention_weights, stop_token + + def forward(self, inputs, memories, mask): + r"""Train Decoder with teacher forcing. + Args: + inputs: Encoder outputs. + memories: Feature frames for teacher-forcing. + mask: Attention mask for sequence padding. 
+ + Shapes: + - inputs: (B, T, D_out_enc) + - memory: (B, T_mel, D_mel) + - outputs: (B, T_mel, D_mel) + - alignments: (B, T_in, T_out) + - stop_tokens: (B, T_out) + """ + memory = self.get_go_frame(inputs).unsqueeze(0) + memories = self._reshape_memory(memories) + memories = torch.cat((memory, memories), dim=0) + memories = self._update_memory(memories) + memories = self.prenet(memories) + + self._init_states(inputs, mask=mask) + self.attention.init_states(inputs) + + outputs, stop_tokens, alignments = [], [], [] + while len(outputs) < memories.size(0) - 1: + memory = memories[len(outputs)] + decoder_output, attention_weights, stop_token = self.decode(memory) + outputs += [decoder_output.squeeze(1)] + stop_tokens += [stop_token.squeeze(1)] + alignments += [attention_weights] + + outputs, stop_tokens, alignments = self._parse_outputs(outputs, stop_tokens, alignments) + return outputs, alignments, stop_tokens + + def inference(self, inputs): + r"""Decoder inference without teacher forcing and use + Stopnet to stop decoder. + Args: + inputs: Encoder outputs. + + Shapes: + - inputs: (B, T, D_out_enc) + - outputs: (B, T_mel, D_mel) + - alignments: (B, T_in, T_out) + - stop_tokens: (B, T_out) + """ + memory = self.get_go_frame(inputs) + memory = self._update_memory(memory) + + self._init_states(inputs, mask=None) + self.attention.init_states(inputs) + + outputs, stop_tokens, alignments, t = [], [], [], 0 + while True: + memory = self.prenet(memory) + decoder_output, alignment, stop_token = self.decode(memory) + stop_token = torch.sigmoid(stop_token.data) + outputs += [decoder_output.squeeze(1)] + stop_tokens += [stop_token] + alignments += [alignment] + + if stop_token > self.stop_threshold and t > inputs.shape[0] // 2: + break + if len(outputs) == self.max_decoder_steps: + print(f" > Decoder stopped with `max_decoder_steps` {self.max_decoder_steps}") + break + + memory = self._update_memory(decoder_output) + t += 1 + + outputs, stop_tokens, alignments = self._parse_outputs(outputs, stop_tokens, alignments) + + return outputs, alignments, stop_tokens + + def inference_truncated(self, inputs): + """ + Preserve decoder states for continuous inference + """ + if self.memory_truncated is None: + self.memory_truncated = self.get_go_frame(inputs) + self._init_states(inputs, mask=None, keep_states=False) + else: + self._init_states(inputs, mask=None, keep_states=True) + + self.attention.init_states(inputs) + outputs, stop_tokens, alignments, t = [], [], [], 0 + while True: + memory = self.prenet(self.memory_truncated) + decoder_output, alignment, stop_token = self.decode(memory) + stop_token = torch.sigmoid(stop_token.data) + outputs += [decoder_output.squeeze(1)] + stop_tokens += [stop_token] + alignments += [alignment] + + if stop_token > 0.7: + break + if len(outputs) == self.max_decoder_steps: + print(" | > Decoder stopped with 'max_decoder_steps") + break + + self.memory_truncated = decoder_output + t += 1 + + outputs, stop_tokens, alignments = self._parse_outputs(outputs, stop_tokens, alignments) + + return outputs, alignments, stop_tokens + + def inference_step(self, inputs, t, memory=None): + """ + For debug purposes + """ + if t == 0: + memory = self.get_go_frame(inputs) + self._init_states(inputs, mask=None) + + memory = self.prenet(memory) + decoder_output, stop_token, alignment = self.decode(memory) + stop_token = torch.sigmoid(stop_token.data) + memory = decoder_output + return decoder_output, stop_token, alignment diff --git a/viXTTS/TTS/tts/layers/tortoise/arch_utils.py 
b/viXTTS/TTS/tts/layers/tortoise/arch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dad1814369599f0bc637a92624a73dfab99dc1a1 --- /dev/null +++ b/viXTTS/TTS/tts/layers/tortoise/arch_utils.py @@ -0,0 +1,433 @@ +import functools +import math +import os + +import fsspec +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio +from transformers import LogitsWarper + +from TTS.tts.layers.tortoise.xtransformers import ContinuousTransformerWrapper, RelativePositionBias + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def normalization(channels): + """ + Make a standard normalization layer. + + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + groups = 32 + if channels <= 16: + groups = 8 + elif channels <= 64: + groups = 16 + while channels % groups != 0: + groups = int(groups / 2) + assert groups > 2 + return GroupNorm32(groups, channels) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv, mask=None, rel_pos=None): + """ + Apply QKV attention. + + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards + if rel_pos is not None: + weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape( + bs * self.n_heads, weight.shape[-2], weight.shape[-1] + ) + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + if mask is not None: + # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs. + mask = mask.repeat(self.n_heads, 1).unsqueeze(1) + weight = weight * mask + a = torch.einsum("bts,bcs->bct", weight, v) + + return a.reshape(bs, -1, length) + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
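+
+ Example (an illustrative shape check; channel and head counts chosen arbitrarily)::
+
+ >>> blk = AttentionBlock(channels=64, num_heads=4)
+ >>> blk(torch.randn(2, 64, 100)).shape
+ torch.Size([2, 64, 100])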
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + do_checkpoint=True, + relative_pos_embeddings=False, + ): + super().__init__() + self.channels = channels + self.do_checkpoint = do_checkpoint + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.norm = normalization(channels) + self.qkv = nn.Conv1d(channels, channels * 3, 1) + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(nn.Conv1d(channels, channels, 1)) + if relative_pos_embeddings: + self.relative_pos_embeddings = RelativePositionBias( + scale=(channels // self.num_heads) ** 0.5, + causal=False, + heads=num_heads, + num_buckets=32, + max_distance=64, + ) + else: + self.relative_pos_embeddings = None + + def forward(self, x, mask=None): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv, mask, self.relative_pos_embeddings) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + """ + + def __init__(self, channels, use_conv, out_channels=None, factor=4): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.factor = factor + if use_conv: + ksize = 5 + pad = 2 + self.conv = nn.Conv1d(self.channels, self.out_channels, ksize, padding=pad) + + def forward(self, x): + assert x.shape[1] == self.channels + x = F.interpolate(x, scale_factor=self.factor, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. 
+ """ + + def __init__(self, channels, use_conv, out_channels=None, factor=4, ksize=5, pad=2): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + + stride = factor + if use_conv: + self.op = nn.Conv1d(self.channels, self.out_channels, ksize, stride=stride, padding=pad) + else: + assert self.channels == self.out_channels + self.op = nn.AvgPool1d(kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(nn.Module): + def __init__( + self, + channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + up=False, + down=False, + kernel_size=3, + ): + super().__init__() + self.channels = channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + padding = 1 if kernel_size == 3 else 2 + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False) + self.x_upd = Upsample(channels, False) + elif down: + self.h_upd = Downsample(channels, False) + self.x_upd = Downsample(channels, False) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module(nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding) + else: + self.skip_connection = nn.Conv1d(channels, self.out_channels, 1) + + def forward(self, x): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AudioMiniEncoder(nn.Module): + def __init__( + self, + spec_dim, + embedding_dim, + base_channels=128, + depth=2, + resnet_blocks=2, + attn_blocks=4, + num_attn_heads=4, + dropout=0, + downsample_factor=2, + kernel_size=3, + ): + super().__init__() + self.init = nn.Sequential(nn.Conv1d(spec_dim, base_channels, 3, padding=1)) + ch = base_channels + res = [] + for l in range(depth): + for r in range(resnet_blocks): + res.append(ResBlock(ch, dropout, kernel_size=kernel_size)) + res.append(Downsample(ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor)) + ch *= 2 + self.res = nn.Sequential(*res) + self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1)) + attn = [] + for a in range(attn_blocks): + attn.append( + AttentionBlock( + embedding_dim, + num_attn_heads, + ) + ) + self.attn = nn.Sequential(*attn) + self.dim = embedding_dim + + def forward(self, x): + h = self.init(x) + h = self.res(h) + h = self.final(h) + h = self.attn(h) + return h[:, :, 0] + + +DEFAULT_MEL_NORM_FILE = "https://coqui.gateway.scarf.sh/v0.14.1_models/mel_norms.pth" + + +class TorchMelSpectrogram(nn.Module): + def __init__( + self, + filter_length=1024, + hop_length=256, + win_length=1024, + n_mel_channels=80, + mel_fmin=0, + mel_fmax=8000, + sampling_rate=22050, + normalize=False, + mel_norm_file=DEFAULT_MEL_NORM_FILE, + ): + super().__init__() + # These are the 
default tacotron values for the MEL spectrogram. + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length + self.n_mel_channels = n_mel_channels + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.sampling_rate = sampling_rate + self.mel_stft = torchaudio.transforms.MelSpectrogram( + n_fft=self.filter_length, + hop_length=self.hop_length, + win_length=self.win_length, + power=2, + normalized=normalize, + sample_rate=self.sampling_rate, + f_min=self.mel_fmin, + f_max=self.mel_fmax, + n_mels=self.n_mel_channels, + norm="slaney", + ) + self.mel_norm_file = mel_norm_file + if self.mel_norm_file is not None: + with fsspec.open(self.mel_norm_file) as f: + self.mel_norms = torch.load(f) + else: + self.mel_norms = None + + def forward(self, inp): + if ( + len(inp.shape) == 3 + ): # Automatically squeeze out the channels dimension if it is present (assuming mono-audio) + inp = inp.squeeze(1) + assert len(inp.shape) == 2 + self.mel_stft = self.mel_stft.to(inp.device) + mel = self.mel_stft(inp) + # Perform dynamic range compression + mel = torch.log(torch.clamp(mel, min=1e-5)) + if self.mel_norms is not None: + self.mel_norms = self.mel_norms.to(mel.device) + mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1) + return mel + + +class CheckpointedLayer(nn.Module): + """ + Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses + checkpoint for all other args. + """ + + def __init__(self, wrap): + super().__init__() + self.wrap = wrap + + def forward(self, x, *args, **kwargs): + for k, v in kwargs.items(): + assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing. + partial = functools.partial(self.wrap, **kwargs) + return partial(x, *args) + + +class CheckpointedXTransformerEncoder(nn.Module): + """ + Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid + to channels-last that XTransformer expects. 
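+ # Illustrative sketch, not part of the original file: a minimal standalone
+ # version of what TorchMelSpectrogram.forward computes above, assuming its
+ # default settings (22050 Hz, 80 mel bands, log-clamp compression).
+ import torch
+ import torchaudio
+
+ _mel_xform = torchaudio.transforms.MelSpectrogram(
+     sample_rate=22050, n_fft=1024, win_length=1024, hop_length=256,
+     f_min=0, f_max=8000, n_mels=80, power=2, norm="slaney",
+ )
+ _wav = torch.randn(1, 22050)                                    # one second of fake mono audio
+ _log_mel = torch.log(torch.clamp(_mel_xform(_wav), min=1e-5))   # (1, 80, frames)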
+ """ + + def __init__(self, needs_permute=True, exit_permute=True, checkpoint=True, **xtransformer_kwargs): + super().__init__() + self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs) + self.needs_permute = needs_permute + self.exit_permute = exit_permute + + if not checkpoint: + return + for i in range(len(self.transformer.attn_layers.layers)): + n, b, r = self.transformer.attn_layers.layers[i] + self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r]) + + def forward(self, x, **kwargs): + if self.needs_permute: + x = x.permute(0, 2, 1) + h = self.transformer(x, **kwargs) + if self.exit_permute: + h = h.permute(0, 2, 1) + return h + + +class TypicalLogitsWarper(LogitsWarper): + def __init__( + self, + mass: float = 0.9, + filter_value: float = -float("Inf"), + min_tokens_to_keep: int = 1, + ): + self.filter_value = filter_value + self.mass = mass + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + # calculate entropy + normalized = torch.nn.functional.log_softmax(scores, dim=-1) + p = torch.exp(normalized) + ent = -(normalized * p).nansum(-1, keepdim=True) + + # shift and sort + shifted_scores = torch.abs((-normalized) - ent) + sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False) + sorted_logits = scores.gather(-1, sorted_indices) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + + # Remove tokens with cumulative mass above the threshold + last_ind = (cumulative_probs < self.mass).sum(dim=1) + last_ind[last_ind < 0] = 0 + sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1)) + if self.min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + + scores = scores.masked_fill(indices_to_remove, self.filter_value) + return scores diff --git a/viXTTS/TTS/tts/layers/tortoise/audio_utils.py b/viXTTS/TTS/tts/layers/tortoise/audio_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..70711ed7a485ecd4a8c8eb8ab6c338aa79871de7 --- /dev/null +++ b/viXTTS/TTS/tts/layers/tortoise/audio_utils.py @@ -0,0 +1,177 @@ +import os +from glob import glob +from typing import Dict, List + +import librosa +import numpy as np +import torch +import torchaudio +from scipy.io.wavfile import read + +from TTS.utils.audio.torch_transforms import TorchSTFT + + +def load_wav_to_torch(full_path): + sampling_rate, data = read(full_path) + if data.dtype == np.int32: + norm_fix = 2**31 + elif data.dtype == np.int16: + norm_fix = 2**15 + elif data.dtype == np.float16 or data.dtype == np.float32: + norm_fix = 1.0 + else: + raise NotImplementedError(f"Provided data dtype not supported: {data.dtype}") + return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate) + + +def check_audio(audio, audiopath: str): + # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk. + # '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds. + if torch.any(audio > 2) or not torch.any(audio < 0): + print(f"Error with {audiopath}. 
Max={audio.max()} min={audio.min()}") + audio.clip_(-1, 1) + + +def read_audio_file(audiopath: str): + if audiopath[-4:] == ".wav": + audio, lsr = load_wav_to_torch(audiopath) + elif audiopath[-4:] == ".mp3": + audio, lsr = librosa.load(audiopath, sr=None) + audio = torch.FloatTensor(audio) + else: + assert False, f"Unsupported audio format provided: {audiopath[-4:]}" + + # Remove any channel data. + if len(audio.shape) > 1: + if audio.shape[0] < 5: + audio = audio[0] + else: + assert audio.shape[1] < 5 + audio = audio[:, 0] + + return audio, lsr + + +def load_required_audio(audiopath: str): + audio, lsr = read_audio_file(audiopath) + + audios = [torchaudio.functional.resample(audio, lsr, sampling_rate) for sampling_rate in (22050, 24000)] + for audio in audios: + check_audio(audio, audiopath) + + return [audio.unsqueeze(0) for audio in audios] + + +def load_audio(audiopath, sampling_rate): + audio, lsr = read_audio_file(audiopath) + + if lsr != sampling_rate: + audio = torchaudio.functional.resample(audio, lsr, sampling_rate) + check_audio(audio, audiopath) + + return audio.unsqueeze(0) + + +TACOTRON_MEL_MAX = 2.3143386840820312 +TACOTRON_MEL_MIN = -11.512925148010254 + + +def denormalize_tacotron_mel(norm_mel): + return ((norm_mel + 1) / 2) * (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN) + TACOTRON_MEL_MIN + + +def normalize_tacotron_mel(mel): + return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1 + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C + + +def get_voices(extra_voice_dirs: List[str] = []): + dirs = extra_voice_dirs + voices: Dict[str, List[str]] = {} + for d in dirs: + subs = os.listdir(d) + for sub in subs: + subj = os.path.join(d, sub) + if os.path.isdir(subj): + voices[sub] = list(glob(f"{subj}/*.wav")) + list(glob(f"{subj}/*.mp3")) + list(glob(f"{subj}/*.pth")) + return voices + + +def load_voice(voice: str, extra_voice_dirs: List[str] = []): + if voice == "random": + return None, None + + voices = get_voices(extra_voice_dirs) + paths = voices[voice] + if len(paths) == 1 and paths[0].endswith(".pth"): + return None, torch.load(paths[0]) + else: + conds = [] + for cond_path in paths: + c = load_required_audio(cond_path) + conds.append(c) + return conds, None + + +def load_voices(voices: List[str], extra_voice_dirs: List[str] = []): + latents = [] + clips = [] + for voice in voices: + if voice == "random": + if len(voices) > 1: + print("Cannot combine a random voice with a non-random voice. Just using a random voice.") + return None, None + clip, latent = load_voice(voice, extra_voice_dirs) + if latent is None: + assert ( + len(latents) == 0 + ), "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this." + clips.extend(clip) + elif clip is None: + assert ( + len(clips) == 0 + ), "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this." 
+ latents.append(latent) + if len(latents) == 0: + return clips, None + else: + latents_0 = torch.stack([l[0] for l in latents], dim=0).mean(dim=0) + latents_1 = torch.stack([l[1] for l in latents], dim=0).mean(dim=0) + latents = (latents_0, latents_1) + return None, latents + + +def wav_to_univnet_mel(wav, do_normalization=False, device="cuda"): + stft = TorchSTFT( + n_fft=1024, + hop_length=256, + win_length=1024, + use_mel=True, + n_mels=100, + sample_rate=24000, + mel_fmin=0, + mel_fmax=12000, + ) + stft = stft.to(device) + mel = stft(wav) + mel = dynamic_range_compression(mel) + if do_normalization: + mel = normalize_tacotron_mel(mel) + return mel diff --git a/viXTTS/TTS/tts/layers/tortoise/autoregressive.py b/viXTTS/TTS/tts/layers/tortoise/autoregressive.py new file mode 100644 index 0000000000000000000000000000000000000000..14d881bc1029ef577f24ae28f9414e431661142a --- /dev/null +++ b/viXTTS/TTS/tts/layers/tortoise/autoregressive.py @@ -0,0 +1,631 @@ +# AGPL: a notification must be added stating that changes have been made to that file. +import functools + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList +from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions + +from TTS.tts.layers.tortoise.arch_utils import AttentionBlock, TypicalLogitsWarper + + +def null_position_embeddings(range, dim): + return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device) + + +def _p(t): + return t and (len(t), len(t[0]), t[0][0].shape) # kv_cache debug + + +class ResBlock(nn.Module): + """ + Basic residual convolutional block that uses GroupNorm. + """ + + def __init__(self, chan): + super().__init__() + self.net = nn.Sequential( + nn.Conv1d(chan, chan, kernel_size=3, padding=1), + nn.GroupNorm(chan // 8, chan), + nn.ReLU(), + nn.Conv1d(chan, chan, kernel_size=3, padding=1), + nn.GroupNorm(chan // 8, chan), + ) + + def forward(self, x): + return F.relu(self.net(x) + x) + + +class GPT2InferenceModel(GPT2PreTrainedModel): + def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache): + super().__init__(config) + self.transformer = gpt + self.text_pos_embedding = text_pos_emb + self.embeddings = embeddings + self.lm_head = nn.Sequential(norm, linear) + self.kv_cache = kv_cache + + def store_mel_emb(self, mel_emb): + self.cached_mel_emb = mel_emb + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) # usually None + if not self.kv_cache: + past_key_values = None + # only last token for inputs_ids if past is defined in kwargs + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + def forward( + self, + input_ids=None, + 
past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + assert self.cached_mel_emb is not None + assert inputs_embeds is None # Not supported by this inference model. + assert labels is None # Training not supported by this inference model. + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Create embedding + mel_len = self.cached_mel_emb.shape[1] + if input_ids.shape[1] != 1: + text_inputs = input_ids[:, mel_len:] + text_emb = self.embeddings(text_inputs) + text_emb = text_emb + self.text_pos_embedding(text_emb) + if self.cached_mel_emb.shape[0] != text_emb.shape[0]: + mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0] // self.cached_mel_emb.shape[0], 0) + else: # this outcome only occurs once per loop in most cases + mel_emb = self.cached_mel_emb + emb = torch.cat([mel_emb, text_emb], dim=1) + else: + emb = self.embeddings(input_ids) + emb = emb + self.text_pos_embedding.get_fixed_embedding( + attention_mask.shape[1] - mel_len, attention_mask.device + ) + + transformer_outputs = self.transformer( + inputs_embeds=emb, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + return (lm_logits,) + transformer_outputs[1:] + + return CausalLMOutputWithCrossAttentions( + loss=None, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache(past, beam_idx): + """ + This function is used to re-order the :obj:`past_key_values` cache if + :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is + called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. 
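+ # Illustrative sketch, not part of the original file: the effect of the
+ # _reorder_cache below. Each layer's cached keys/values are re-indexed along
+ # the batch dimension so they follow the beams kept at this generation step.
+ import torch
+
+ _past = tuple((torch.randn(4, 2, 7, 8), torch.randn(4, 2, 7, 8)) for _ in range(2))
+ _beam_idx = torch.tensor([0, 0, 3, 2])
+ _reordered = tuple(
+     tuple(_state.index_select(0, _beam_idx) for _state in _layer) for _layer in _past
+ )
+ assert _reordered[0][0].shape == (4, 2, 7, 8)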
+ """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) + + +class ConditioningEncoder(nn.Module): + def __init__( + self, + spec_dim, + embedding_dim, + attn_blocks=6, + num_attn_heads=4, + do_checkpointing=False, + mean=False, + ): + super().__init__() + attn = [] + self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1) + for a in range(attn_blocks): + attn.append(AttentionBlock(embedding_dim, num_attn_heads)) + self.attn = nn.Sequential(*attn) + self.dim = embedding_dim + self.do_checkpointing = do_checkpointing + self.mean = mean + + def forward(self, x): + h = self.init(x) + h = self.attn(h) + if self.mean: + return h.mean(dim=2) + else: + return h[:, :, 0] + + +class LearnedPositionEmbeddings(nn.Module): + def __init__(self, seq_len, model_dim, init=0.02): + super().__init__() + self.emb = nn.Embedding(seq_len, model_dim) + # Initializing this way is standard for GPT-2 + self.emb.weight.data.normal_(mean=0.0, std=init) + + def forward(self, x): + sl = x.shape[1] + return self.emb(torch.arange(0, sl, device=x.device)) + + def get_fixed_embedding(self, ind, dev): + return self.emb(torch.arange(0, ind, device=dev))[ind - 1 : ind] + + +def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing): + """ + GPT-2 implemented by the HuggingFace library. + """ + from transformers import GPT2Config, GPT2Model + + gpt_config = GPT2Config( + vocab_size=256, # Unused. + n_positions=max_mel_seq_len + max_text_seq_len, + n_ctx=max_mel_seq_len + max_text_seq_len, + n_embd=model_dim, + n_layer=layers, + n_head=heads, + gradient_checkpointing=checkpointing, + use_cache=not checkpointing, + ) + gpt = GPT2Model(gpt_config) + # Override the built in positional embeddings + del gpt.wpe # TODO: figure out relevance in fixing exported model definition: Embedding(1012, 1024) + gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) + # Built-in token embeddings are unused. 
+ del gpt.wte + return ( + gpt, + LearnedPositionEmbeddings(max_mel_seq_len, model_dim), + LearnedPositionEmbeddings(max_text_seq_len, model_dim), + None, + None, + ) + + +class MelEncoder(nn.Module): + def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2): + super().__init__() + self.channels = channels + self.encoder = nn.Sequential( + nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1), + nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]), + nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(channels // 16, channels // 2), + nn.ReLU(), + nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]), + nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(channels // 8, channels), + nn.ReLU(), + nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]), + ) + self.reduction = 4 + + def forward(self, x): + for e in self.encoder: + x = e(x) + return x.permute(0, 2, 1) + + +class UnifiedVoice(nn.Module): + def __init__( + self, + layers=8, + model_dim=512, + heads=8, + max_text_tokens=120, + max_mel_tokens=250, + max_conditioning_inputs=1, + mel_length_compression=1024, + number_text_tokens=256, + start_text_token=None, + number_mel_codes=8194, + start_mel_token=8192, + stop_mel_token=8193, + train_solo_embeddings=False, + use_mel_codes_as_input=True, + checkpointing=True, + types=1, + ): + """ + Args: + layers: Number of layers in transformer stack. + model_dim: Operating dimensions of the transformer + heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64 + max_text_tokens: Maximum number of text tokens that will be encountered by model. + max_mel_tokens: Maximum number of MEL tokens that will be encountered by model. + max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s). + mel_length_compression: The factor between and . Used to compute MEL code padding given wav input length. 
+ number_text_tokens: + start_text_token: + stop_text_token: + number_mel_codes: + start_mel_token: + stop_mel_token: + train_solo_embeddings: + use_mel_codes_as_input: + checkpointing: + """ + super().__init__() + + self.number_text_tokens = number_text_tokens + self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token + self.stop_text_token = 0 + self.number_mel_codes = number_mel_codes + self.start_mel_token = start_mel_token + self.stop_mel_token = stop_mel_token + self.layers = layers + self.heads = heads + self.max_mel_tokens = max_mel_tokens + self.max_text_tokens = max_text_tokens + self.model_dim = model_dim + self.max_conditioning_inputs = max_conditioning_inputs + self.mel_length_compression = mel_length_compression + self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) + self.text_embedding = nn.Embedding(self.number_text_tokens * types + 1, model_dim) + if use_mel_codes_as_input: + self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim) + else: + self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1) + ( + self.gpt, + self.mel_pos_embedding, + self.text_pos_embedding, + self.mel_layer_pos_embedding, + self.text_layer_pos_embedding, + ) = build_hf_gpt_transformer( + layers, + model_dim, + heads, + self.max_mel_tokens + 2 + self.max_conditioning_inputs, + self.max_text_tokens + 2, + checkpointing, + ) + if train_solo_embeddings: + self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True) + self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True) + else: + self.mel_solo_embedding = 0 + self.text_solo_embedding = 0 + + self.final_norm = nn.LayerNorm(model_dim) + self.text_head = nn.Linear(model_dim, self.number_text_tokens * types + 1) + self.mel_head = nn.Linear(model_dim, self.number_mel_codes) + + # Initialize the embeddings per the GPT-2 scheme + embeddings = [self.text_embedding] + if use_mel_codes_as_input: + embeddings.append(self.mel_embedding) + for module in embeddings: + module.weight.data.normal_(mean=0.0, std=0.02) + + def post_init_gpt2_config(self, kv_cache=True): + seq_length = self.max_mel_tokens + self.max_text_tokens + 2 + gpt_config = GPT2Config( + vocab_size=self.max_mel_tokens, + n_positions=seq_length, + n_ctx=seq_length, + n_embd=self.model_dim, + n_layer=self.layers, + n_head=self.heads, + gradient_checkpointing=False, + use_cache=True, + ) + self.inference_model = GPT2InferenceModel( + gpt_config, + self.gpt, + self.mel_pos_embedding, + self.mel_embedding, + self.final_norm, + self.mel_head, + kv_cache=kv_cache, + ) + # self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head) + self.gpt.wte = self.mel_embedding + # self.inference_model.save_pretrained("") + + def build_aligned_inputs_and_targets(self, input, start_token, stop_token): + inp = F.pad(input, (1, 0), value=start_token) + tar = F.pad(input, (0, 1), value=stop_token) + return inp, tar + + def set_mel_padding(self, mel_input_tokens, wav_lengths): + """ + Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in + that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required + preformatting to create a working TTS model. + """ + # Set padding areas within MEL (currently it is coded with the MEL code for ). 
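+ # Illustrative sketch, not part of the original file, with made-up wav lengths:
+ # mel code lengths are the wav lengths divided by mel_length_compression
+ # (1024 samples per mel code by default), and every position past that length
+ # is overwritten with stop_mel_token by the code below.
+ import torch
+
+ _wav_lengths = torch.tensor([256000, 199680])
+ _mel_lengths = torch.div(_wav_lengths, 1024, rounding_mode="trunc")
+ assert _mel_lengths.tolist() == [250, 195]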
+ mel_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode="trunc") + for b in range(len(mel_lengths)): + actual_end = ( + mel_lengths[b] + 1 + ) # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token. + if actual_end < mel_input_tokens.shape[-1]: + mel_input_tokens[b, actual_end:] = self.stop_mel_token + return mel_input_tokens + + def get_logits( + self, + speech_conditioning_inputs, + first_inputs, + first_head, + second_inputs=None, + second_head=None, + get_attns=False, + return_latent=False, + ): + if second_inputs is not None: + emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1) + else: + emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1) + + gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns) + if get_attns: + return gpt_out.attentions + + enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input + enc = self.final_norm(enc) + + if return_latent: + return ( + enc[ + :, + speech_conditioning_inputs.shape[1] : speech_conditioning_inputs.shape[1] + first_inputs.shape[1], + ], + enc[:, -second_inputs.shape[1] :], + ) + + first_logits = enc[:, : first_inputs.shape[1]] + first_logits = first_head(first_logits) + first_logits = first_logits.permute(0, 2, 1) + if second_inputs is not None: + second_logits = enc[:, -second_inputs.shape[1] :] + second_logits = second_head(second_logits) + second_logits = second_logits.permute(0, 2, 1) + return first_logits, second_logits + else: + return first_logits + + def get_conditioning(self, speech_conditioning_input): + speech_conditioning_input = ( + speech_conditioning_input.unsqueeze(1) + if len(speech_conditioning_input.shape) == 3 + else speech_conditioning_input + ) + conds = [] + for j in range(speech_conditioning_input.shape[1]): + conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds = torch.stack(conds, dim=1) + conds = conds.mean(dim=1) + return conds + + def forward( + self, + speech_conditioning_latent, + text_inputs, + text_lengths, + mel_codes, + wav_lengths, + types=None, + text_first=True, + raw_mels=None, + return_attentions=False, + return_latent=False, + clip_inputs=True, + ): + """ + Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode + (actuated by `text_first`). + + speech_conditioning_input: MEL float tensor, (b,1024) + text_inputs: long tensor, (b,t) + text_lengths: long tensor, (b,) + mel_inputs: long tensor, (b,m) + wav_lengths: long tensor, (b,) + raw_mels: MEL float tensor (b,80,s) + + If return_attentions is specified, only logits are returned. + If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned. + If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality. + """ + # Types are expressed by expanding the text embedding space. + if types is not None: + text_inputs = text_inputs * (1 + types).unsqueeze(-1) + + if clip_inputs: + # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by + # chopping the inputs by the maximum actual length. 
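+ # Illustrative sketch, not part of the original file: the clipping below keeps
+ # only up to the longest real length in the micro-batch, so padding beyond the
+ # longest element never passes through the transformer.
+ import torch
+
+ _text = torch.randint(0, 255, (2, 120))
+ _text_lengths = torch.tensor([32, 57])
+ _clipped = _text[:, : _text_lengths.max()]
+ assert _clipped.shape == (2, 57)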
+ max_text_len = text_lengths.max() + text_inputs = text_inputs[:, :max_text_len] + max_mel_len = wav_lengths.max() // self.mel_length_compression + mel_codes = mel_codes[:, :max_mel_len] + if raw_mels is not None: + raw_mels = raw_mels[:, :, : max_mel_len * 4] + mel_codes = self.set_mel_padding(mel_codes, wav_lengths) + text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) + mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token) + + conds = speech_conditioning_latent.unsqueeze(1) + text_inputs, text_targets = self.build_aligned_inputs_and_targets( + text_inputs, self.start_text_token, self.stop_text_token + ) + text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + mel_codes, mel_targets = self.build_aligned_inputs_and_targets( + mel_codes, self.start_mel_token, self.stop_mel_token + ) + if raw_mels is not None: + mel_inp = F.pad(raw_mels, (0, 8)) + else: + mel_inp = mel_codes + mel_emb = self.mel_embedding(mel_inp) + mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + + if text_first: + text_logits, mel_logits = self.get_logits( + conds, + text_emb, + self.text_head, + mel_emb, + self.mel_head, + get_attns=return_attentions, + return_latent=return_latent, + ) + if return_latent: + return mel_logits[ + :, :-2 + ] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. + else: + mel_logits, text_logits = self.get_logits( + conds, + mel_emb, + self.mel_head, + text_emb, + self.text_head, + get_attns=return_attentions, + return_latent=return_latent, + ) + if return_latent: + return text_logits[ + :, :-2 + ] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. + + if return_attentions: + return mel_logits + loss_text = F.cross_entropy(text_logits, text_targets.long()) + loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) + return loss_text.mean(), loss_mel.mean(), mel_logits + + def inference_speech( + self, + speech_conditioning_latent, + text_inputs, + input_tokens=None, + num_return_sequences=1, + max_generate_length=None, + typical_sampling=False, + typical_mass=0.9, + **hf_generate_kwargs, + ): + text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) + text_inputs, text_targets = self.build_aligned_inputs_and_targets( + text_inputs, self.start_text_token, self.stop_text_token + ) + text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + + conds = speech_conditioning_latent.unsqueeze(1) + emb = torch.cat([conds, text_emb], dim=1) + self.inference_model.store_mel_emb(emb) + + fake_inputs = torch.full( + ( + emb.shape[0], + conds.shape[1] + emb.shape[1], + ), + fill_value=1, + dtype=torch.long, + device=text_inputs.device, + ) + fake_inputs[:, -1] = self.start_mel_token + trunc_index = fake_inputs.shape[1] + if input_tokens is None: + inputs = fake_inputs + else: + assert ( + num_return_sequences % input_tokens.shape[0] == 0 + ), "The number of return sequences must be divisible by the number of input sequences" + fake_inputs = fake_inputs.repeat(num_return_sequences, 1) + input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1) + inputs = torch.cat([fake_inputs, input_tokens], dim=1) + + logits_processor = ( + LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList() + ) # TODO disable this + max_length = ( + trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length + ) + gen = 
self.inference_model.generate( + inputs, + bos_token_id=self.start_mel_token, + pad_token_id=self.stop_mel_token, + eos_token_id=self.stop_mel_token, + max_length=max_length, + logits_processor=logits_processor, + num_return_sequences=num_return_sequences, + **hf_generate_kwargs, + ) + return gen[:, trunc_index:] + + +if __name__ == "__main__": + gpt = UnifiedVoice( + model_dim=256, + heads=4, + train_solo_embeddings=True, + use_mel_codes_as_input=True, + max_conditioning_inputs=4, + ) + l = gpt( + torch.randn(2, 3, 80, 800), + torch.randint(high=120, size=(2, 120)), + torch.tensor([32, 120]), + torch.randint(high=8192, size=(2, 250)), + torch.tensor([250 * 256, 195 * 256]), + ) + gpt.text_forward( + torch.randn(2, 80, 800), + torch.randint(high=50, size=(2, 80)), + torch.tensor([32, 80]), + ) diff --git a/viXTTS/TTS/tts/layers/tortoise/classifier.py b/viXTTS/TTS/tts/layers/tortoise/classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..8764bb070b5ad8267ee2992ccc33f5bb65bad005 --- /dev/null +++ b/viXTTS/TTS/tts/layers/tortoise/classifier.py @@ -0,0 +1,144 @@ +import torch +import torch.nn as nn + +from TTS.tts.layers.tortoise.arch_utils import AttentionBlock, Downsample, Upsample, normalization, zero_module + + +class ResBlock(nn.Module): + def __init__( + self, + channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + up=False, + down=False, + kernel_size=3, + do_checkpoint=True, + ): + super().__init__() + self.channels = channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + self.do_checkpoint = do_checkpoint + padding = 1 if kernel_size == 3 else 2 + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module(nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, kernel_size, padding=padding) + else: + self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, 1) + + def forward(self, x): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AudioMiniEncoder(nn.Module): + def __init__( + self, + spec_dim, + embedding_dim, + base_channels=128, + depth=2, + resnet_blocks=2, + attn_blocks=4, + num_attn_heads=4, + dropout=0, + downsample_factor=2, + kernel_size=3, + ): + super().__init__() + self.init = nn.Sequential(nn.Conv1d(spec_dim, base_channels, 3, padding=1)) + ch = base_channels + res = [] + self.layers = depth + for l in range(depth): + for r in range(resnet_blocks): + res.append(ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size)) + res.append(Downsample(ch, use_conv=True, out_channels=ch 
* 2, factor=downsample_factor)) + ch *= 2 + self.res = nn.Sequential(*res) + self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1)) + attn = [] + for a in range(attn_blocks): + attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False)) + self.attn = nn.Sequential(*attn) + self.dim = embedding_dim + + def forward(self, x): + h = self.init(x) + h = self.res(h) + h = self.final(h) + for blk in self.attn: + h = blk(h) + return h[:, :, 0] + + +class AudioMiniEncoderWithClassifierHead(nn.Module): + def __init__(self, classes, distribute_zero_label=True, **kwargs): + super().__init__() + self.enc = AudioMiniEncoder(**kwargs) + self.head = nn.Linear(self.enc.dim, classes) + self.num_classes = classes + self.distribute_zero_label = distribute_zero_label + + def forward(self, x, labels=None): + h = self.enc(x) + logits = self.head(h) + if labels is None: + return logits + else: + if self.distribute_zero_label: + oh_labels = nn.functional.one_hot(labels, num_classes=self.num_classes) + zeros_indices = (labels == 0).unsqueeze(-1) + # Distribute 20% of the probability mass on all classes when zero is specified, to compensate for dataset noise. + zero_extra_mass = torch.full_like( + oh_labels, + dtype=torch.float, + fill_value=0.2 / (self.num_classes - 1), + ) + zero_extra_mass[:, 0] = -0.2 + zero_extra_mass = zero_extra_mass * zeros_indices + oh_labels = oh_labels + zero_extra_mass + else: + oh_labels = labels + loss = nn.functional.cross_entropy(logits, oh_labels) + return loss diff --git a/viXTTS/TTS/tts/layers/tortoise/clvp.py b/viXTTS/TTS/tts/layers/tortoise/clvp.py new file mode 100644 index 0000000000000000000000000000000000000000..69b8c17c3fe71f55be12b728fa3c8f0e85cefb89 --- /dev/null +++ b/viXTTS/TTS/tts/layers/tortoise/clvp.py @@ -0,0 +1,159 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum + +from TTS.tts.layers.tortoise.arch_utils import CheckpointedXTransformerEncoder +from TTS.tts.layers.tortoise.transformer import Transformer +from TTS.tts.layers.tortoise.xtransformers import Encoder + + +def exists(val): + return val is not None + + +def masked_mean(t, mask, dim=1): + t = t.masked_fill(~mask[:, :, None], 0.0) + return t.sum(dim=1) / mask.sum(dim=1)[..., None] + + +class CLVP(nn.Module): + """ + CLIP model retrofitted for performing contrastive evaluation between tokenized audio data and the corresponding + transcribed text. 
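+ # Illustrative sketch, not part of the original file: the CLIP-style objective
+ # mirrored by forward() below, using L2-normalised text/speech latents, a
+ # temperature-scaled similarity matrix, and a symmetric cross-entropy.
+ import torch
+ import torch.nn.functional as F
+
+ _text_lat = F.normalize(torch.randn(4, 512), p=2, dim=-1)
+ _speech_lat = F.normalize(torch.randn(4, 512), p=2, dim=-1)
+ _sim = _text_lat @ _speech_lat.t() * torch.tensor(1.0).exp()
+ _labels = torch.arange(4)
+ _loss = (F.cross_entropy(_sim, _labels) + F.cross_entropy(_sim.t(), _labels)) / 2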
+ + Originally from https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py + """ + + def __init__( + self, + *, + dim_text=512, + dim_speech=512, + dim_latent=512, + num_text_tokens=256, + text_enc_depth=6, + text_seq_len=120, + text_heads=8, + num_speech_tokens=8192, + speech_enc_depth=6, + speech_heads=8, + speech_seq_len=250, + text_mask_percentage=0, + voice_mask_percentage=0, + wav_token_compression=1024, + use_xformers=False, + ): + super().__init__() + self.text_emb = nn.Embedding(num_text_tokens, dim_text) + self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False) + + self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech) + self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False) + + if use_xformers: + self.text_transformer = CheckpointedXTransformerEncoder( + needs_permute=False, + exit_permute=False, + max_seq_len=-1, + attn_layers=Encoder( + dim=dim_text, + depth=text_enc_depth, + heads=text_heads, + ff_dropout=0.1, + ff_mult=2, + attn_dropout=0.1, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + ), + ) + self.speech_transformer = CheckpointedXTransformerEncoder( + needs_permute=False, + exit_permute=False, + max_seq_len=-1, + attn_layers=Encoder( + dim=dim_speech, + depth=speech_enc_depth, + heads=speech_heads, + ff_dropout=0.1, + ff_mult=2, + attn_dropout=0.1, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + ), + ) + else: + self.text_transformer = Transformer( + causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth, heads=text_heads + ) + self.speech_transformer = Transformer( + causal=False, seq_len=speech_seq_len, dim=dim_speech, depth=speech_enc_depth, heads=speech_heads + ) + + self.temperature = nn.Parameter(torch.tensor(1.0)) + self.text_mask_percentage = text_mask_percentage + self.voice_mask_percentage = voice_mask_percentage + self.wav_token_compression = wav_token_compression + self.xformers = use_xformers + if not use_xformers: + self.text_pos_emb = nn.Embedding(text_seq_len, dim_text) + self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech) + + def forward(self, text, speech_tokens, return_loss=False): + b, device = text.shape[0], text.device + if self.training: + text_mask = torch.rand_like(text.float()) > self.text_mask_percentage + voice_mask = torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage + else: + text_mask = torch.ones_like(text.float()).bool() + voice_mask = torch.ones_like(speech_tokens.float()).bool() + + text_emb = self.text_emb(text) + speech_emb = self.speech_emb(speech_tokens) + + if not self.xformers: + text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device)) + speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device)) + + enc_text = self.text_transformer(text_emb, mask=text_mask) + enc_speech = self.speech_transformer(speech_emb, mask=voice_mask) + + text_latents = masked_mean(enc_text, text_mask, dim=1) + speech_latents = masked_mean(enc_speech, voice_mask, dim=1) + + text_latents = self.to_text_latent(text_latents) + speech_latents = self.to_speech_latent(speech_latents) + + text_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (text_latents, speech_latents)) + + temp = self.temperature.exp() + + if not return_loss: + sim = einsum("n d, n d -> n", text_latents, speech_latents) * temp + return sim + + sim = einsum("i d, j d -> i j", text_latents, speech_latents) * temp + labels = torch.arange(b, device=device) + loss = (F.cross_entropy(sim, labels) + 
F.cross_entropy(sim.t(), labels)) / 2 + return loss + + +if __name__ == "__main__": + clip = CLVP(text_mask_percentage=0.2, voice_mask_percentage=0.2) + clip( + torch.randint(0, 256, (2, 120)), + torch.tensor([50, 100]), + torch.randint(0, 8192, (2, 250)), + torch.tensor([101, 102]), + return_loss=True, + ) + nonloss = clip( + torch.randint(0, 256, (2, 120)), + torch.tensor([50, 100]), + torch.randint(0, 8192, (2, 250)), + torch.tensor([101, 102]), + return_loss=False, + ) + print(nonloss.shape) diff --git a/viXTTS/TTS/tts/layers/tortoise/diffusion.py b/viXTTS/TTS/tts/layers/tortoise/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..7bea02ca08a46cb474406014c690b7973e33d55d --- /dev/null +++ b/viXTTS/TTS/tts/layers/tortoise/diffusion.py @@ -0,0 +1,1242 @@ +""" +This is an almost carbon copy of gaussian_diffusion.py from OpenAI's ImprovedDiffusion repo, which itself: + +This code started out as a PyTorch port of Ho et al's diffusion models: +https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py + +Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules. +""" + +import enum +import math + +import numpy as np +import torch +import torch as th +from tqdm import tqdm + +from TTS.tts.layers.tortoise.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper + +try: + from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral + + K_DIFFUSION_SAMPLERS = {"k_euler_a": sample_euler_ancestral, "dpm++2m": sample_dpmpp_2m} +except ImportError: + K_DIFFUSION_SAMPLERS = None + + +SAMPLERS = ["dpm++2m", "p", "ddim"] + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)] + + return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2)) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). 
+ """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif schedule_name == "cosine": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = "previous_x" # the model predicts x_{t-1} + START_X = "start_x" # the model predicts x_0 + EPSILON = "epsilon" # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. 
+ """ + + LEARNED = "learned" + FIXED_SMALL = "fixed_small" + FIXED_LARGE = "fixed_large" + LEARNED_RANGE = "learned_range" + + +class LossType(enum.Enum): + MSE = "mse" # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = "rescaled_mse" # use raw MSE loss (with RESCALED_KL when learning variances) + KL = "kl" # use the variational lower-bound + RESCALED_KL = "rescaled_kl" # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + + Ported directly from here, and then adapted over time to further experimentation. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + :param model_mean_type: a ModelMeanType determining what the model outputs. + :param model_var_type: a ModelVarType determining how variance is output. + :param loss_type: a LossType determining the loss function to use. + :param rescale_timesteps: if True, pass floating point timesteps into the + model so that they are always scaled like in the + original paper (0 to 1000). + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type, + rescale_timesteps=False, + conditioning_free=False, + conditioning_free_k=1, + ramp_conditioning_free=True, + sampler="p", + ): + self.sampler = sampler + self.model_mean_type = ModelMeanType(model_mean_type) + self.model_var_type = ModelVarType(model_var_type) + self.loss_type = LossType(loss_type) + self.rescale_timesteps = rescale_timesteps + self.conditioning_free = conditioning_free + self.conditioning_free_k = conditioning_free_k + self.ramp_conditioning_free = ramp_conditioning_free + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + # log calculation clipped because the posterior variance is 0 at the + # beginning of the diffusion chain. + self.posterior_log_variance_clipped = np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:])) + self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). 
+ + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + + In other words, sample from q(x_t | x_0). + + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + + q(x_{t-1} | x_t, x_0) + + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. 
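+ # Illustrative sketch, not part of the original file, with made-up log-variance
+ # bounds: how LEARNED_RANGE maps the model's second output channel (in [-1, 1])
+ # to a log-variance interpolated between the clipped posterior log-variance and
+ # log(beta_t), as done in the body below.
+ import torch
+
+ _min_log = torch.tensor(-11.0)   # posterior_log_variance_clipped[t] (example value)
+ _max_log = torch.tensor(-4.0)    # log(beta_t) (example value)
+ _var_values = torch.tensor(0.5)  # model output in [-1, 1]
+ _frac = (_var_values + 1) / 2
+ _model_log_variance = _frac * _max_log + (1 - _frac) * _min_log   # -> -5.75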
+ """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B,) + model_output = model(x, self._scale_timesteps(t), **model_kwargs) + if self.conditioning_free: + model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs) + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + if self.conditioning_free: + model_output_no_conditioning, _ = th.split(model_output_no_conditioning, C, dim=1) + if self.model_var_type == ModelVarType.LEARNED: + model_log_variance = model_var_values + model_variance = th.exp(model_log_variance) + else: + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + if self.conditioning_free: + if self.ramp_conditioning_free: + assert t.shape[0] == 1 # This should only be used in inference. 
+ cfk = self.conditioning_free_k * (1 - self._scale_timesteps(t)[0].item() / self.num_timesteps) + else: + cfk = self.conditioning_free_k + model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + pred_xstart = process_xstart(self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)) + model_mean = model_output + elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + else: + raise NotImplementedError(self.model_mean_type) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_xstart_from_xprev(self, x_t, t, xprev): + assert x_t.shape == xprev.shape + return ( # (xprev - coef2*x_t) / coef1 + _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev + - _extract_into_tensor(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape) * x_t + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * (1000.0 / self.num_timesteps) + return t + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + + See condition_mean() for details on cond_fn. + + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). 
+ """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, self._scale_timesteps(t), **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def k_diffusion_sample_loop( + self, + k_sampler, + pbar, + model, + shape, + noise=None, # all given + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + device=None, # ALL UNUSED + model_kwargs=None, # {'precomputed_aligned_embeddings': precomputed_embeddings}, + progress=False, # unused as well + ): + assert isinstance(model_kwargs, dict) + if device is None: + device = next(model.parameters()).device + s_in = noise.new_ones([noise.shape[0]]) + + def model_split(*args, **kwargs): + model_output = model(*args, **kwargs) + model_epsilon, model_var = th.split(model_output, model_output.shape[1] // 2, dim=1) + return model_epsilon, model_var + + # + """ + print(self.betas) + print(th.tensor(self.betas)) + noise_schedule = NoiseScheduleVP(schedule='discrete', betas=th.tensor(self.betas)) + """ + noise_schedule = NoiseScheduleVP(schedule="linear", continuous_beta_0=0.1 / 4, continuous_beta_1=20.0 / 4) + + def model_fn_prewrap(x, t, *args, **kwargs): + """ + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + print(t) + print(self.timestep_map) + exit() + """ + """ + model_output = model(x, self._scale_timesteps(t*4000), **model_kwargs) + out = self.p_mean_variance(model, x, t*4000, model_kwargs=model_kwargs) + return out['pred_xstart'] + """ + x, _ = x.chunk(2) + t, _ = (t * 1000).chunk(2) + res = torch.cat( + [ + model_split(x, t, conditioning_free=True, **model_kwargs)[0], + model_split(x, t, **model_kwargs)[0], + ] + ) + pbar.update(1) + return res + + model_fn = model_wrapper( + model_fn_prewrap, + noise_schedule, + model_type="noise", # "noise" or "x_start" or "v" or "score" + model_kwargs=model_kwargs, + guidance_type="classifier-free", + condition=th.Tensor(1), + unconditional_condition=th.Tensor(1), + guidance_scale=self.conditioning_free_k, + ) + dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + x_sample = dpm_solver.sample( + noise, + steps=self.num_timesteps, + order=2, + skip_type="time_uniform", + method="multistep", + ) + #''' + return x_sample + + def sample_loop(self, *args, **kwargs): + s = self.sampler + if s == "p": + return self.p_sample_loop(*args, **kwargs) + elif s == "ddim": + return self.ddim_sample_loop(*args, **kwargs) + elif s == "dpm++2m": + if self.conditioning_free is not True: + raise RuntimeError("cond_free must be true") + with tqdm(total=self.num_timesteps) as pbar: + if K_DIFFUSION_SAMPLERS is None: + raise ModuleNotFoundError("Install k_diffusion for using k_diffusion samplers") + return self.k_diffusion_sample_loop(K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs) + else: + raise RuntimeError("sampler not impl") + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. 
+ :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). 
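For intuition, here is the reparameterised update that `p_sample` performs at each step, pulled out into a hypothetical standalone helper with toy inputs. It only shows the `mean + mask * exp(0.5 * log_variance) * noise` step, not the full `p_sample_loop_progressive` loop.

```python
# Illustration of the reparameterised ancestral step used in p_sample.
# `mean` and `log_variance` stand in for the p_mean_variance output; values are toy.
import torch as th

def ancestral_step(mean, log_variance, t):
    noise = th.randn_like(mean)
    # No noise is added at t == 0, exactly as in p_sample above.
    nonzero_mask = (t != 0).float().view(-1, *([1] * (mean.dim() - 1)))
    return mean + nonzero_mask * th.exp(0.5 * log_variance) * noise

mean = th.zeros(2, 4)
log_variance = th.full((2, 4), -2.0)
t = th.tensor([3, 0])                 # second batch element is at the final step
sample = ancestral_step(mean, log_variance, t)
print(sample[1])                      # deterministic: equals the mean because t == 0
```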
+ """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + for i in tqdm(indices, disable=not progress): + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) + # Equation 12. + noise = th.randn_like(x) + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + + Same usage as p_sample_loop(). 
+ """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices, disable=not progress) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None): + """ + Get a term for the variational lower-bound. + + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t) + out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs) + kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): + """ + Compute training losses for a single timestep. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + # TODO: support multiple model outputs for this mode. 
+ terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs) + if isinstance(model_outputs, tuple): + model_output = model_outputs[0] + terms["extra_outputs"] = model_outputs[1:] + else: + model_output = model_outputs + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C * 2, *x_t.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0] + x_start_pred = torch.zeros(x_start) # Not supported. + elif self.model_mean_type == ModelMeanType.START_X: + target = x_start + x_start_pred = model_output + elif self.model_mean_type == ModelMeanType.EPSILON: + target = noise + x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output) + else: + raise NotImplementedError(self.model_mean_type) + assert model_output.shape == target.shape == x_start.shape + terms["mse"] = mean_flat((target - model_output) ** 2) + terms["x_start_predicted"] = x_start_pred + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def autoregressive_training_losses( + self, model, x_start, t, model_output_keys, gd_out_key, model_kwargs=None, noise=None + ): + """ + Compute training losses for a single timestep. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + terms = {} + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + assert False # not currently supported for this type of diffusion. 
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs) + terms.update({k: o for k, o in zip(model_output_keys, model_outputs)}) + model_output = terms[gd_out_key] + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C, 2, *x_t.shape[2:]) + model_output, model_var_values = model_output[:, :, 0], model_output[:, :, 1] + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0] + x_start_pred = torch.zeros(x_start) # Not supported. + elif self.model_mean_type == ModelMeanType.START_X: + target = x_start + x_start_pred = model_output + elif self.model_mean_type == ModelMeanType.EPSILON: + target = noise + x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output) + else: + raise NotImplementedError(self.model_mean_type) + assert model_output.shape == target.shape == x_start.shape + terms["mse"] = mean_flat((target - model_output) ** 2) + terms["x_start_predicted"] = x_start_pred + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + + This term can't be optimized, as it only depends on the encoder. + + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. 
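The bits-per-dim bookkeeping in `_vb_terms_bpd`, `_prior_bpd` and `calc_bpd_loop` boils down to Gaussian KL terms converted from nats to bits by dividing by log 2. A minimal standalone illustration follows; `normal_kl_nats` re-states the standard diagonal-Gaussian KL for this example and is not imported from the codebase.

```python
# KL between two diagonal Gaussians, converted from nats to bits as in _vb_terms_bpd.
import numpy as np
import torch as th

def normal_kl_nats(mean1, logvar1, mean2, logvar2):
    return 0.5 * (-1.0 + logvar2 - logvar1
                  + th.exp(logvar1 - logvar2)
                  + (mean1 - mean2) ** 2 * th.exp(-logvar2))

mean1, logvar1 = th.zeros(4), th.zeros(4)
mean2, logvar2 = th.full((4,), 0.5), th.zeros(4)
kl_bits = normal_kl_nats(mean1, logvar1, mean2, logvar2).mean() / np.log(2.0)
print(float(kl_bits))   # 0.5 * 0.25 nats per dim, roughly 0.18 bits per dim
```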
+ """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def autoregressive_training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().autoregressive_training_losses(self._wrap_model(model, True), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model, autoregressive=False): + if isinstance(model, _WrappedModel) or isinstance(model, _WrappedAutoregressiveModel): + return model + mod = _WrappedAutoregressiveModel if autoregressive else _WrappedModel + return mod(model, self.timestep_map, self.rescale_timesteps, self.original_num_steps) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. 
+ + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride") + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError(f"cannot divide section of {size} steps into {section_count}") + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class _WrappedModel: + def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): + self.model = model + self.timestep_map = timestep_map + self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + if self.rescale_timesteps: + new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + model_output = self.model(x, new_ts, **kwargs) + return model_output + + +class _WrappedAutoregressiveModel: + def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): + self.model = model + self.timestep_map = timestep_map + self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, x0, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + if self.rescale_timesteps: + new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, x0, new_ts, **kwargs) + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
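Two small helpers close out this file. `space_timesteps(300, [10, 15, 20])`, for example, keeps 10, 15 and 20 steps from the three 100-step sections (45 retained steps in total), while `space_timesteps(300, "ddim25")` keeps every 12th step. The broadcasting rule of `_extract_into_tensor` is easy to misread, so here is a local re-implementation exercised with a toy schedule.

```python
# Broadcasting behaviour of _extract_into_tensor, re-implemented locally for illustration.
import numpy as np
import torch as th

def extract_into_tensor(arr, timesteps, broadcast_shape):
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)

alphas_cumprod = np.linspace(1.0, 0.01, 100)    # toy schedule
t = th.tensor([0, 50, 99])                      # one timestep per batch element
coeff = extract_into_tensor(alphas_cumprod, t, (3, 2, 8))
print(coeff.shape)   # torch.Size([3, 2, 8]); each batch row holds one scalar, broadcast over the rest
```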
+ """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res.expand(broadcast_shape) diff --git a/viXTTS/TTS/tts/layers/tortoise/diffusion_decoder.py b/viXTTS/TTS/tts/layers/tortoise/diffusion_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0d3cf7698a7334b4cfc8d9bdd0f5f6ee3059189d --- /dev/null +++ b/viXTTS/TTS/tts/layers/tortoise/diffusion_decoder.py @@ -0,0 +1,415 @@ +import math +import random +from abc import abstractmethod + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import autocast + +from TTS.tts.layers.tortoise.arch_utils import AttentionBlock, normalization + + +def is_latent(t): + return t.dtype == torch.float + + +def is_sequence(t): + return t.dtype == torch.long + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=timesteps.device + ) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +class TimestepBlock(nn.Module): + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. 
+ """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + def forward(self, x, emb): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + else: + x = layer(x) + return x + + +class ResBlock(TimestepBlock): + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + dims=2, + kernel_size=3, + efficient_config=True, + use_scale_shift_norm=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_scale_shift_norm = use_scale_shift_norm + padding = {1: 0, 3: 1, 5: 2}[kernel_size] + eff_kernel = 1 if efficient_config else 3 + eff_padding = 0 if efficient_config else 1 + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding), + ) + + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding) + + def forward(self, x, emb): + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = torch.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class DiffusionLayer(TimestepBlock): + def __init__(self, model_channels, dropout, num_heads): + super().__init__() + self.resblk = ResBlock( + model_channels, + model_channels, + dropout, + model_channels, + dims=1, + use_scale_shift_norm=True, + ) + self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True) + + def forward(self, x, time_emb): + y = self.resblk(x, time_emb) + return self.attn(y) + + +class DiffusionTts(nn.Module): + def __init__( + self, + model_channels=512, + num_layers=8, + in_channels=100, + in_latent_channels=512, + in_tokens=8193, + out_channels=200, # mean and variance + dropout=0, + use_fp16=False, + num_heads=16, + # Parameters for regularization. + layer_drop=0.1, + unconditioned_percentage=0.1, # This implements a mechanism similar to what is used in classifier-free training. + ): + super().__init__() + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.dropout = dropout + self.num_heads = num_heads + self.unconditioned_percentage = unconditioned_percentage + self.enable_fp16 = use_fp16 + self.layer_drop = layer_drop + + self.inp_block = nn.Conv1d(in_channels, model_channels, 3, 1, 1) + self.time_embed = nn.Sequential( + nn.Linear(model_channels, model_channels), + nn.SiLU(), + nn.Linear(model_channels, model_channels), + ) + + # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed. 
+ # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally + # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive + # transformer network. + self.code_embedding = nn.Embedding(in_tokens, model_channels) + self.code_converter = nn.Sequential( + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + ) + self.code_norm = normalization(model_channels) + self.latent_conditioner = nn.Sequential( + nn.Conv1d(in_latent_channels, model_channels, 3, padding=1), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + ) + self.contextual_embedder = nn.Sequential( + nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2), + nn.Conv1d(model_channels, model_channels * 2, 3, padding=1, stride=2), + AttentionBlock( + model_channels * 2, + num_heads, + relative_pos_embeddings=True, + do_checkpoint=False, + ), + AttentionBlock( + model_channels * 2, + num_heads, + relative_pos_embeddings=True, + do_checkpoint=False, + ), + AttentionBlock( + model_channels * 2, + num_heads, + relative_pos_embeddings=True, + do_checkpoint=False, + ), + AttentionBlock( + model_channels * 2, + num_heads, + relative_pos_embeddings=True, + do_checkpoint=False, + ), + AttentionBlock( + model_channels * 2, + num_heads, + relative_pos_embeddings=True, + do_checkpoint=False, + ), + ) + self.unconditioned_embedding = nn.Parameter(torch.randn(1, model_channels, 1)) + self.conditioning_timestep_integrator = TimestepEmbedSequential( + DiffusionLayer(model_channels, dropout, num_heads), + DiffusionLayer(model_channels, dropout, num_heads), + DiffusionLayer(model_channels, dropout, num_heads), + ) + + self.integrating_conv = nn.Conv1d(model_channels * 2, model_channels, kernel_size=1) + self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1) + + self.layers = nn.ModuleList( + [DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] + + [ + ResBlock( + model_channels, + model_channels, + dropout, + dims=1, + use_scale_shift_norm=True, + ) + for _ in range(3) + ] + ) + + self.out = nn.Sequential( + normalization(model_channels), + nn.SiLU(), + nn.Conv1d(model_channels, out_channels, 3, padding=1), + ) + + def get_grad_norm_parameter_groups(self): + groups = { + "minicoder": list(self.contextual_embedder.parameters()), + "layers": list(self.layers.parameters()), + "code_converters": list(self.code_embedding.parameters()) + + list(self.code_converter.parameters()) + + list(self.latent_conditioner.parameters()) + + list(self.latent_conditioner.parameters()), + "timestep_integrator": list(self.conditioning_timestep_integrator.parameters()) + + list(self.integrating_conv.parameters()), + "time_embed": list(self.time_embed.parameters()), + } + return groups + + def get_conditioning(self, conditioning_input): + speech_conditioning_input = ( + conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input + ) + conds = [] + for j in range(speech_conditioning_input.shape[1]): + conds.append(self.contextual_embedder(speech_conditioning_input[:, 
j])) + conds = torch.cat(conds, dim=-1) + conds = conds.mean(dim=-1) + return conds + + def timestep_independent( + self, + aligned_conditioning, + conditioning_latent, + expected_seq_len, + return_code_pred, + ): + # Shuffle aligned_latent to BxCxS format + if is_latent(aligned_conditioning): + aligned_conditioning = aligned_conditioning.permute(0, 2, 1) + + cond_scale, cond_shift = torch.chunk(conditioning_latent, 2, dim=1) + if is_latent(aligned_conditioning): + code_emb = self.latent_conditioner(aligned_conditioning) + else: + code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1) + code_emb = self.code_converter(code_emb) + code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1) + + unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device) + # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. + if self.training and self.unconditioned_percentage > 0: + unconditioned_batches = ( + torch.rand((code_emb.shape[0], 1, 1), device=code_emb.device) < self.unconditioned_percentage + ) + code_emb = torch.where( + unconditioned_batches, + self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1), + code_emb, + ) + expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode="nearest") + + if not return_code_pred: + return expanded_code_emb + else: + mel_pred = self.mel_head(expanded_code_emb) + # Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches. This is because we don't want that gradient being used to train parameters through the codes_embedder as it unbalances contributions to that network from the MSE loss. + mel_pred = mel_pred * unconditioned_batches.logical_not() + return expanded_code_emb, mel_pred + + def forward( + self, + x, + timesteps, + aligned_conditioning=None, + conditioning_latent=None, + precomputed_aligned_embeddings=None, + conditioning_free=False, + return_code_pred=False, + ): + """ + Apply the model to an input batch. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced. + :param conditioning_latent: a pre-computed conditioning latent; see get_conditioning(). + :param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent() + :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered. + :return: an [N x C x ...] Tensor of outputs. + """ + assert precomputed_aligned_embeddings is not None or ( + aligned_conditioning is not None and conditioning_latent is not None + ) + assert not ( + return_code_pred and precomputed_aligned_embeddings is not None + ) # These two are mutually exclusive. 
+ + unused_params = [] + if conditioning_free: + code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) + unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) + unused_params.extend(list(self.latent_conditioner.parameters())) + else: + if precomputed_aligned_embeddings is not None: + code_emb = precomputed_aligned_embeddings + else: + code_emb, mel_pred = self.timestep_independent( + aligned_conditioning, conditioning_latent, x.shape[-1], True + ) + if is_latent(aligned_conditioning): + unused_params.extend( + list(self.code_converter.parameters()) + list(self.code_embedding.parameters()) + ) + else: + unused_params.extend(list(self.latent_conditioner.parameters())) + + unused_params.append(self.unconditioned_embedding) + + time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + code_emb = self.conditioning_timestep_integrator(code_emb, time_emb) + x = self.inp_block(x) + x = torch.cat([x, code_emb], dim=1) + x = self.integrating_conv(x) + for i, lyr in enumerate(self.layers): + # Do layer drop where applicable. Do not drop first and last layers. + if ( + self.training + and self.layer_drop > 0 + and i != 0 + and i != (len(self.layers) - 1) + and random.random() < self.layer_drop + ): + unused_params.extend(list(lyr.parameters())) + else: + # First and last blocks will have autocast disabled for improved precision. + with autocast(x.device.type, enabled=self.enable_fp16 and i != 0): + x = lyr(x, time_emb) + + x = x.float() + out = self.out(x) + + # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. + extraneous_addition = 0 + for p in unused_params: + extraneous_addition = extraneous_addition + p.mean() + out = out + extraneous_addition * 0 + + if return_code_pred: + return out, mel_pred + return out + + +if __name__ == "__main__": + clip = torch.randn(2, 100, 400) + aligned_latent = torch.randn(2, 388, 512) + aligned_sequence = torch.randint(0, 8192, (2, 100)) + cond = torch.randn(2, 100, 400) + ts = torch.LongTensor([600, 600]) + model = DiffusionTts(512, layer_drop=0.3, unconditioned_percentage=0.5) + # Test with latent aligned conditioning + # o = model(clip, ts, aligned_latent, cond) + # Test with sequence aligned conditioning + o = model(clip, ts, aligned_sequence, cond) diff --git a/viXTTS/TTS/tts/layers/tortoise/dpm_solver.py b/viXTTS/TTS/tts/layers/tortoise/dpm_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..c70888df42063e65dabf50eadb9a78813effa4e9 --- /dev/null +++ b/viXTTS/TTS/tts/layers/tortoise/dpm_solver.py @@ -0,0 +1,1562 @@ +import math + +import torch + + +class NoiseScheduleVP: + def __init__( + self, + schedule="discrete", + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20.0, + dtype=torch.float32, + ): + """Create a wrapper class for the forward SDE (VP type). + + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. 
For t in [0, T], we have: + + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + + t = self.inverse_lambda(lambda_t) + + =============================================================== + + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + + 1. For discrete-time DPMs: + + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + + Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + + + 2. For continuous-time DPMs: + + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + + =============================================================== + + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + + Example: + + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + + """ + + if schedule not in ["discrete", "linear", "cosine"]: + raise ValueError( + "Unsupported noise schedule {}. 
The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( + schedule + ) + ) + + self.schedule = schedule + if schedule == "discrete": + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1.0 + self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) + self.log_alpha_array = log_alphas.reshape( + ( + 1, + -1, + ) + ).to(dtype=dtype) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999.0 + self.cosine_t_max = ( + math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi) + * 2.0 + * (1.0 + self.cosine_s) + / math.pi + - self.cosine_s + ) + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)) + self.schedule = schedule + if schedule == "cosine": + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1.0 + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == "discrete": + return interpolate_fn( + t.reshape((-1, 1)), + self.t_array.to(t.device), + self.log_alpha_array.to(t.device), + ).reshape((-1)) + elif self.schedule == "linear": + return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == "cosine": + + def log_alpha_fn(s): + return torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)) + + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. 
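For the linear VP schedule the marginals and `inverse_lambda` have closed forms. The standalone sketch below re-derives them with the default coefficients (`continuous_beta_0=0.1`, `continuous_beta_1=20.0`) and checks that `inverse_lambda(marginal_lambda(t))` round-trips; the three functions are local restatements of the formulas above, not imports from the module.

```python
# Closed-form marginals of the linear VP schedule and the lambda <-> t round trip.
import torch

beta_0, beta_1 = 0.1, 20.0

def marginal_log_mean_coeff(t):              # log(alpha_t)
    return -0.25 * t**2 * (beta_1 - beta_0) - 0.5 * t * beta_0

def marginal_lambda(t):                      # log(alpha_t) - log(sigma_t)
    log_alpha = marginal_log_mean_coeff(t)
    log_sigma = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_alpha))
    return log_alpha - log_sigma

def inverse_lambda(lamb):                    # same algebra as NoiseScheduleVP.inverse_lambda
    tmp = 2.0 * (beta_1 - beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros_like(lamb))
    Delta = beta_0**2 + tmp
    return tmp / (torch.sqrt(Delta) + beta_0) / (beta_1 - beta_0)

t = torch.linspace(1e-3, 1.0, 5)
assert torch.allclose(inverse_lambda(marginal_lambda(t)), t, atol=1e-4)
```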
+ """ + if self.schedule == "linear": + tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == "discrete": + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb) + t = interpolate_fn( + log_alpha.reshape((-1, 1)), + torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1]), + ) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) + + def t_fn(log_alpha_t): + return ( + torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) + * 2.0 + * (1.0 + self.cosine_s) + / math.pi + - self.cosine_s + ) + + t = t_fn(log_alpha) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1.0, + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + + We support four types of the diffusion model by setting `model_type`: + + 1. "noise": noise prediction model. (Trained by predicting noise). + + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). 
+ + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + + =============================================================== + + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == "discrete": + return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0 + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return (x - alpha_t * output) / sigma_t + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return alpha_t * output + sigma_t * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + return -sigma_t * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. 
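The "classifier-free" branch of `model_fn` batches the conditional and unconditional passes through the model and blends the two noise predictions. Below is a toy standalone version of that blend, with a stand-in `noise_pred_fn` in place of the wrapped diffusion model.

```python
# Classifier-free guidance blend as performed in model_wrapper's "classifier-free" branch.
import torch

def noise_pred_fn(x, t, cond):
    # Stand-in for the wrapped diffusion model; any callable with this shape works.
    return x * 0.1 + cond

def guided_noise(x, t, condition, unconditional_condition, guidance_scale):
    x_in = torch.cat([x] * 2)
    t_in = torch.cat([t] * 2)
    c_in = torch.cat([unconditional_condition, condition])
    noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
    return noise_uncond + guidance_scale * (noise - noise_uncond)

x = torch.randn(2, 8)
t = torch.full((2,), 0.5)
cond, uncond = torch.ones(2, 8), torch.zeros(2, 8)
eps = guided_noise(x, t, cond, uncond, guidance_scale=2.0)
```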
+ """ + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * sigma_t * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1.0 or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v", "score"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__( + self, + model_fn, + noise_schedule, + algorithm_type="dpmsolver++", + correcting_x0_fn=None, + correcting_xt_fn=None, + thresholding_max_val=1.0, + dynamic_thresholding_ratio=0.995, + ): + """Construct a DPM-Solver. + + We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`). + + We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you + can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the + dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space + DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space + DPMs (such as stable-diffusion). + + To support advanced algorithms in image-to-image applications, we also support corrector functions for + both x0 and xt. + + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++". + correcting_x0_fn: A `str` or a function with the following format: + ``` + def correcting_x0_fn(x0, t): + x0_new = ... + return x0_new + ``` + This function is to correct the outputs of the data prediction model at each sampling step. e.g., + ``` + x0_pred = data_pred_model(xt, t) + if correcting_x0_fn is not None: + x0_pred = correcting_x0_fn(x0_pred, t) + xt_1 = update(x0_pred, xt, t) + ``` + If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1]. + correcting_xt_fn: A function with the following format: + ``` + def correcting_xt_fn(xt, t, step): + x_new = ... + return x_new + ``` + This function is to correct the intermediate samples xt at each sampling step. e.g., + ``` + xt = ... + xt = correcting_xt_fn(xt, t, step) + ``` + thresholding_max_val: A `float`. The max value for thresholding. + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details). + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. 
+ + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, + Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models + with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) + self.noise_schedule = noise_schedule + assert algorithm_type in ["dpmsolver", "dpmsolver++"] + self.algorithm_type = algorithm_type + if correcting_x0_fn == "dynamic_thresholding": + self.correcting_x0_fn = self.dynamic_thresholding_fn + else: + self.correcting_x0_fn = correcting_x0_fn + self.correcting_xt_fn = correcting_xt_fn + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.thresholding_max_val = thresholding_max_val + + def dynamic_thresholding_fn(self, x0, t): + """ + The dynamic thresholding method. + """ + dims = x0.dim() + p = self.dynamic_thresholding_ratio + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims( + torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), + dims, + ) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with corrector). + """ + noise = self.noise_prediction_fn(x, t) + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - sigma_t * noise) / alpha_t + if self.correcting_x0_fn is not None: + x0 = self.correcting_x0_fn(x0, t) + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.algorithm_type == "dpmsolver++": + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). 
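The three `skip_type` options only differ in how the `N + 1` boundaries are spaced between `t_T` and `t_0`. The uniform and quadratic variants are one-liners, reproduced below with toy endpoints; the `logSNR` variant needs the noise schedule and is omitted.

```python
# Time-step spacing used by get_time_steps for two of the three skip types.
import torch

t_T, t_0, N = 1.0, 1e-3, 10

time_uniform = torch.linspace(t_T, t_0, N + 1)
t_order = 2
time_quadratic = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order)

print(time_uniform[:3])     # evenly spaced over [t_0, t_T]
print(time_quadratic[:3])   # concentrates steps near small t compared with the uniform spacing
```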
+ """ + if skip_type == "logSNR": + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == "time_uniform": + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == "time_quadratic": + t_order = 2 + t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError( + "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type) + ) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. 
+ """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [ + 3, + ] * ( + K - 2 + ) + [2, 1] + elif steps % 3 == 1: + orders = [ + 3, + ] * ( + K - 1 + ) + [1] + else: + orders = [ + 3, + ] * ( + K - 1 + ) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [ + 2, + ] * K + else: + K = steps // 2 + 1 + orders = [ + 2, + ] * ( + K - 1 + ) + [1] + elif order == 1: + K = 1 + orders = [ + 1, + ] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == "logSNR": + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ + torch.cumsum( + torch.tensor( + [ + 0, + ] + + orders + ), + 0, + ).to(device) + ] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = sigma_t / sigma_s * x - alpha_t * phi_1 * model_s + if return_intermediate: + return x_t, {"model_s": model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = torch.exp(log_alpha_t - log_alpha_s) * x - (sigma_t * phi_1) * model_s + if return_intermediate: + return x_t, {"model_s": model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update( + self, + x, + s, + t, + r1=0.5, + model_s=None, + return_intermediate=False, + solver_type="dpmsolver", + ): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. 
The approximated solution at time `t`. + """ + if solver_type not in ["dpmsolver", "taylor"]: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ( + ns.marginal_log_mean_coeff(s), + ns.marginal_log_mean_coeff(s1), + ns.marginal_log_mean_coeff(t), + ) + sigma_s, sigma_s1, sigma_t = ( + ns.marginal_std(s), + ns.marginal_std(s1), + ns.marginal_std(t), + ) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + if solver_type == "dpmsolver": + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == "taylor": + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1.0 / r1) * (alpha_t * (phi_1 / h + 1.0)) * (model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + if solver_type == "dpmsolver": + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == "taylor": + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (1.0 / r1) * (sigma_t * (phi_1 / h - 1.0)) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {"model_s": model_s, "model_s1": model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update( + self, + x, + s, + t, + r1=1.0 / 3.0, + r2=2.0 / 3.0, + model_s=None, + model_s1=None, + return_intermediate=False, + solver_type="dpmsolver", + ): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. 
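+
+        Note: the model value at the second intermediate time `s2` is always computed inside this update,
+        so a single call costs between one and three model evaluations depending on whether `model_s`
+        and `model_s1` are supplied.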
+ """ + if solver_type not in ["dpmsolver", "taylor"]: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1.0 / 3.0 + if r2 is None: + r2 = 2.0 / 3.0 + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ( + ns.marginal_log_mean_coeff(s), + ns.marginal_log_mean_coeff(s1), + ns.marginal_log_mean_coeff(s2), + ns.marginal_log_mean_coeff(t), + ) + sigma_s, sigma_s1, sigma_s2, sigma_t = ( + ns.marginal_std(s), + ns.marginal_std(s1), + ns.marginal_std(s2), + ns.marginal_std(t), + ) + alpha_s1, alpha_s2, alpha_t = ( + torch.exp(log_alpha_s1), + torch.exp(log_alpha_s2), + torch.exp(log_alpha_t), + ) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0 + phi_2 = phi_1 / h + 1.0 + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (sigma_s2 / sigma_s) * x + - (alpha_s2 * phi_12) * model_s + + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == "dpmsolver": + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1.0 / r2) * (alpha_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == "taylor": + D1_0 = (1.0 / r1) * (model_s1 - model_s) + D1_1 = (1.0 / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0 + phi_2 = phi_1 / h - 1.0 + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (torch.exp(log_alpha_s2 - log_alpha_s)) * x + - (sigma_s2 * phi_12) * model_s + - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == "dpmsolver": + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (1.0 / r2) * (sigma_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == "taylor": + D1_0 = (1.0 / r1) * (model_s1 - model_s) + D1_1 = (1.0 / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + + if return_intermediate: + return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. 
The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ["dpmsolver", "taylor"]: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1] + t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1] + lambda_prev_1, lambda_prev_0, lambda_t = ( + ns.marginal_lambda(t_prev_1), + ns.marginal_lambda(t_prev_0), + ns.marginal_lambda(t), + ) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if solver_type == "dpmsolver": + x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 - 0.5 * (alpha_t * phi_1) * D1_0 + elif solver_type == "taylor": + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * (phi_1 / h + 1.0)) * D1_0 + ) + else: + phi_1 = torch.expm1(h) + if solver_type == "dpmsolver": + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - 0.5 * (sigma_t * phi_1) * D1_0 + ) + elif solver_type == "taylor": + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * (phi_1 / h - 1.0)) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. 
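+
+        Note: `model_prev_list` and `t_prev_list` are expected to contain exactly the three most recent
+        model values and times, since they are unpacked directly in the implementation below.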
+ """ + ns = self.noise_schedule + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ( + ns.marginal_lambda(t_prev_2), + ns.marginal_lambda(t_prev_1), + ns.marginal_lambda(t_prev_0), + ns.marginal_lambda(t), + ) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1) + D1_1 = (1.0 / r1) * (model_prev_1 - model_prev_2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + phi_2 = phi_1 / h + 1.0 + phi_3 = phi_2 / h - 0.5 + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_1 = torch.expm1(h) + phi_2 = phi_1 / h - 1.0 + phi_3 = phi_2 / h - 0.5 + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + return x_t + + def singlestep_dpm_solver_update( + self, + x, + s, + t, + order, + return_intermediate=False, + solver_type="dpmsolver", + r1=None, + r2=None, + ): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update( + x, + s, + t, + return_intermediate=return_intermediate, + solver_type=solver_type, + r1=r1, + ) + elif order == 3: + return self.singlestep_dpm_solver_third_update( + x, + s, + t, + return_intermediate=return_intermediate, + solver_type=solver_type, + r1=r1, + r2=r2, + ) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpmsolver' or 'taylor'. 
The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def dpm_solver_adaptive( + self, + x, + order, + t_T, + t_0, + h_init=0.05, + atol=0.0078, + rtol=0.05, + theta=0.9, + t_err=1e-5, + solver_type="dpmsolver", + ): + """ + The adaptive step size solver based on singlestep DPM-Solver. + + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. 
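+
+        Note: this routine is normally reached via `sample(..., method='adaptive')`, which ignores `steps`;
+        the loop instead stops once the current time is within `t_err` of `t_0`, and the total NFE used
+        is printed at the end.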
+ """ + ns = self.noise_schedule + s = t_T * torch.ones((1,)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + + def lower_update(x, s, t): + return self.dpm_solver_first_update(x, s, t, return_intermediate=True) + + def higher_update(x, s, t, **kwargs): + return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs) + + elif order == 3: + r1, r2 = 1.0 / 3.0, 2.0 / 3.0 + + def lower_update(x, s, t): + return self.singlestep_dpm_solver_second_update( + x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type + ) + + def higher_update(x, s, t, **kwargs): + return self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs) + + else: + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max( + torch.ones_like(x).to(x) * atol, + rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)), + ) + + def norm_fn(v): + return torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.0): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min( + theta * h * torch.float_power(E, -1.0 / order).float(), + lambda_0 - lambda_s, + ) + nfe += order + print("adaptive solver nfe", nfe) + return x + + def add_noise(self, x, t, noise=None): + """ + Compute the noised input xt = alpha_t * x + sigma_t * noise. + + Args: + x: A `torch.Tensor` with shape `(batch_size, *shape)`. + t: A `torch.Tensor` with shape `(t_size,)`. + Returns: + xt with shape `(t_size, batch_size, *shape)`. + """ + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + if noise is None: + noise = torch.randn((t.shape[0], *x.shape), device=x.device) + x = x.reshape((-1, *x.shape)) + xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise + if t.shape[0] == 1: + return xt.squeeze(0) + else: + return xt + + def inverse( + self, + x, + steps=20, + t_start=None, + t_end=None, + order=2, + skip_type="time_uniform", + method="multistep", + lower_order_final=True, + denoise_to_zero=False, + solver_type="dpmsolver", + atol=0.0078, + rtol=0.05, + return_intermediate=False, + ): + """ + Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver. + For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training. + """ + t_0 = 1.0 / self.noise_schedule.total_N if t_start is None else t_start + t_T = self.noise_schedule.T if t_end is None else t_end + assert ( + t_0 > 0 and t_T > 0 + ), "Time range needs to be greater than 0. 
For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + return self.sample( + x, + steps=steps, + t_start=t_0, + t_end=t_T, + order=order, + skip_type=skip_type, + method=method, + lower_order_final=lower_order_final, + denoise_to_zero=denoise_to_zero, + solver_type=solver_type, + atol=atol, + rtol=rtol, + return_intermediate=return_intermediate, + ) + + def sample( + self, + x, + steps=20, + t_start=None, + t_end=None, + order=2, + skip_type="time_uniform", + method="multistep", + lower_order_final=True, + denoise_to_zero=False, + solver_type="dpmsolver", + atol=0.0078, + rtol=0.05, + return_intermediate=False, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + + ===================================================== + + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. 
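+
+        As a concrete NFE accounting example with `steps=6` and `order=3`: 'singlestep' runs DPM-Solver-3,
+        DPM-Solver-2 and DPM-Solver-1 once each (6 NFE), while 'multistep' runs 1 step of DPM-Solver-1,
+        1 step of multistep DPM-Solver-2 and 4 steps of multistep DPM-Solver-3 (also 6 NFE).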
+ + ===================================================== + + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g., DPM-Solver: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + e.g., DPM-Solver++: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. + + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`. + atol: A `float`. 
The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + return_intermediate: A `bool`. Whether to save the xt at each step. + When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + + """ + t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + assert ( + t_0 > 0 and t_T > 0 + ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + if return_intermediate: + assert method in [ + "multistep", + "singlestep", + "singlestep_fixed", + ], "Cannot use adaptive solver when saving intermediate values" + if self.correcting_xt_fn is not None: + assert method in [ + "multistep", + "singlestep", + "singlestep_fixed", + ], "Cannot use adaptive solver when correcting_xt_fn is not None" + device = x.device + intermediates = [] + with torch.no_grad(): + if method == "adaptive": + x = self.dpm_solver_adaptive( + x, + order=order, + t_T=t_T, + t_0=t_0, + atol=atol, + rtol=rtol, + solver_type=solver_type, + ) + elif method == "multistep": + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + # Init the initial values. + step = 0 + t = timesteps[step] + t_prev_list = [t] + model_prev_list = [self.model_fn(x, t)] + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + # Init the first `order` values by lower order multistep DPM-Solver. + for step in range(1, order): + t = timesteps[step] + x = self.multistep_dpm_solver_update( + x, + model_prev_list, + t_prev_list, + t, + step, + solver_type=solver_type, + ) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + t_prev_list.append(t) + model_prev_list.append(self.model_fn(x, t)) + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in range(order, steps + 1): + t = timesteps[step] + # We only use lower order for steps < 10 + if lower_order_final and steps < 10: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update( + x, + model_prev_list, + t_prev_list, + t, + step_order, + solver_type=solver_type, + ) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = t + # We do not need to evaluate the final model value. 
+ if step < steps: + model_prev_list[-1] = self.model_fn(x, t) + elif method in ["singlestep", "singlestep_fixed"]: + if method == "singlestep": + ( + timesteps_outer, + orders, + ) = self.get_orders_and_timesteps_for_singlestep_solver( + steps=steps, + order=order, + skip_type=skip_type, + t_T=t_T, + t_0=t_0, + device=device, + ) + elif method == "singlestep_fixed": + K = steps // order + orders = [ + order, + ] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for step, order in enumerate(orders): + s, t = timesteps_outer[step], timesteps_outer[step + 1] + timesteps_inner = self.get_time_steps( + skip_type=skip_type, + t_T=s.item(), + t_0=t.item(), + N=order, + device=device, + ) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + else: + raise ValueError("Got wrong method {}".format(method)) + if denoise_to_zero: + t = torch.ones((1,)).to(device) * t_0 + x = self.denoise_to_zero_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step + 1) + if return_intermediate: + intermediates.append(x) + if return_intermediate: + return x, intermediates + else: + return x + + +############################################################# +# other utility functions +############################################################# + + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. 
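+
+        For example, with `xp = [[0., 1., 2.]]` and `yp = [[0., 10., 20.]]`, an input `x = [[0.5]]`
+        yields `[[5.0]]`, and `x = [[3.0]]` extrapolates along the last segment to `[[30.0]]`.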
+ """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,) * (dims - 1)] diff --git a/viXTTS/TTS/tts/layers/tortoise/random_latent_generator.py b/viXTTS/TTS/tts/layers/tortoise/random_latent_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..9b39c1e4b22ee5a9ad84a1711a08a8530c4d76b7 --- /dev/null +++ b/viXTTS/TTS/tts/layers/tortoise/random_latent_generator.py @@ -0,0 +1,55 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2**0.5): + if bias is not None: + rest_dim = [1] * (input.ndim - bias.ndim - 1) + return ( + F.leaky_relu( + input + bias.view(1, bias.shape[0], *rest_dim), + negative_slope=negative_slope, + ) + * scale + ) + else: + return F.leaky_relu(input, negative_slope=0.2) * scale + + +class EqualLinear(nn.Module): + def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1): + super().__init__() + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + else: + self.bias = None + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + return out + + +class RandomLatentConverter(nn.Module): + def __init__(self, channels): + super().__init__() + self.layers = nn.Sequential( + *[EqualLinear(channels, channels, lr_mul=0.1) for _ in range(5)], nn.Linear(channels, channels) + ) + self.channels = channels + + def forward(self, ref): + r = torch.randn(ref.shape[0], self.channels, device=ref.device) + y = self.layers(r) + return y + + +if __name__ == "__main__": + model = RandomLatentConverter(512) + model(torch.randn(5, 512)) diff --git a/viXTTS/TTS/tts/layers/tortoise/tokenizer.py b/viXTTS/TTS/tts/layers/tortoise/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..d243d6558d0dfcbfee59769f991ce5ad9a603678 --- /dev/null +++ 
b/viXTTS/TTS/tts/layers/tortoise/tokenizer.py @@ -0,0 +1,37 @@ +import os + +import torch +from tokenizers import Tokenizer + +from TTS.tts.utils.text.cleaners import english_cleaners + +DEFAULT_VOCAB_FILE = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/tokenizer.json" +) + + +class VoiceBpeTokenizer: + def __init__(self, vocab_file=DEFAULT_VOCAB_FILE, vocab_str=None): + self.tokenizer = None + if vocab_file is not None: + self.tokenizer = Tokenizer.from_file(vocab_file) + if vocab_str is not None: + self.tokenizer = Tokenizer.from_str(vocab_str) + + def preprocess_text(self, txt): + txt = english_cleaners(txt) + return txt + + def encode(self, txt): + txt = self.preprocess_text(txt) + txt = txt.replace(" ", "[SPACE]") + return self.tokenizer.encode(txt).ids + + def decode(self, seq): + if isinstance(seq, torch.Tensor): + seq = seq.cpu().numpy() + txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(" ", "") + txt = txt.replace("[SPACE]", " ") + txt = txt.replace("[STOP]", "") + txt = txt.replace("[UNK]", "") + return txt diff --git a/viXTTS/TTS/tts/layers/tortoise/transformer.py b/viXTTS/TTS/tts/layers/tortoise/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..70d46aa3e03626d8123700a5c2541d2d1a7314b4 --- /dev/null +++ b/viXTTS/TTS/tts/layers/tortoise/transformer.py @@ -0,0 +1,229 @@ +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn + +# helpers + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +def cast_tuple(val, depth=1): + if isinstance(val, list): + val = tuple(val) + return val if isinstance(val, tuple) else (val,) * depth + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def stable_softmax(t, dim=-1, alpha=32**2): + t = t / alpha + t = t - torch.amax(t, dim=dim, keepdim=True).detach() + return (t * alpha).softmax(dim=dim) + + +def route_args(router, args, depth): + routed_args = [(dict(), dict()) for _ in range(depth)] + matched_keys = [key for key in args.keys() if key in router] + + for key in matched_keys: + val = args[key] + for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])): + new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes) + routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args}) + return routed_args + + +# classes +class SequentialSequence(nn.Module): + def __init__(self, layers, args_route={}, layer_dropout=0.0): + super().__init__() + assert all( + len(route) == len(layers) for route in args_route.values() + ), "each argument route map must have the same depth as the number of sequential layers" + self.layers = layers + self.args_route = args_route + self.layer_dropout = layer_dropout + + def forward(self, x, **kwargs): + args = route_args(self.args_route, kwargs, len(self.layers)) + layers_and_args = list(zip(self.layers, args)) + + for (f, g), (f_args, g_args) in layers_and_args: + x = x + f(x, **f_args) + x = x + g(x, **g_args) + return x + + +class DivideMax(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + maxes = x.amax(dim=self.dim, keepdim=True).detach() + return x / maxes + + +# https://arxiv.org/abs/2103.17239 +class LayerScale(nn.Module): + def __init__(self, dim, depth, fn): + super().__init__() + if depth <= 18: + init_eps = 0.1 + elif depth > 18 and depth <= 24: + init_eps = 1e-5 + else: + init_eps = 1e-6 + 
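+        # LayerScale (see the CaiT reference above): each residual branch is multiplied by a learnable
+        # per-channel scale initialised to a small constant, so the block starts close to the identity
+        # mapping; deeper stacks get a smaller initial value.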
+ scale = torch.zeros(1, 1, dim).fill_(init_eps) + self.scale = nn.Parameter(scale) + self.fn = fn + + def forward(self, x, **kwargs): + return self.fn(x, **kwargs) * self.scale + + +# layer norm + + +class PreNorm(nn.Module): + def __init__(self, dim, fn, sandwich=False): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity() + self.fn = fn + + def forward(self, x, **kwargs): + x = self.norm(x) + x = self.fn(x, **kwargs) + return self.norm_out(x) + + +# feed forward + + +class GEGLU(nn.Module): + def forward(self, x): + x, gates = x.chunk(2, dim=-1) + return x * F.gelu(gates) + + +class FeedForward(nn.Module): + def __init__(self, dim, dropout=0.0, mult=4.0): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, dim * mult * 2), + GEGLU(), + nn.Dropout(dropout), + nn.Linear(dim * mult, dim), + ) + + def forward(self, x): + return self.net(x) + + +# Attention + + +class Attention(nn.Module): + def __init__(self, dim, seq_len, causal=True, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.seq_len = seq_len + self.scale = dim_head**-0.5 + + self.causal = causal + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) + + def forward(self, x, mask=None): + b, n, _, h, device = *x.shape, self.heads, x.device + softmax = torch.softmax + + qkv = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv) + + q = q * self.scale + + dots = torch.einsum("b h i d, b h j d -> b h i j", q, k) + mask_value = max_neg_value(dots) + + if exists(mask): + mask = rearrange(mask, "b j -> b () () j") + dots.masked_fill_(~mask, mask_value) + del mask + + if self.causal: + i, j = dots.shape[-2:] + mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool() + dots.masked_fill_(mask, mask_value) + + attn = softmax(dots, dim=-1) + + out = torch.einsum("b h i j, b h j d -> b h i d", attn, v) + out = rearrange(out, "b h n d -> b n (h d)") + out = self.to_out(out) + return out + + +# main transformer class +class Transformer(nn.Module): + def __init__( + self, + *, + dim, + depth, + seq_len, + causal=True, + heads=8, + dim_head=64, + ff_mult=4, + attn_dropout=0.0, + ff_dropout=0.0, + sparse_attn=False, + sandwich_norm=False, + ): + super().__init__() + layers = nn.ModuleList([]) + sparse_layer = cast_tuple(sparse_attn, depth) + + for ind, sparse_attn in zip(range(depth), sparse_layer): + attn = Attention( + dim, + causal=causal, + seq_len=seq_len, + heads=heads, + dim_head=dim_head, + dropout=attn_dropout, + ) + + ff = FeedForward(dim, mult=ff_mult, dropout=ff_dropout) + + layers.append( + nn.ModuleList( + [ + LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich=sandwich_norm)), + LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich=sandwich_norm)), + ] + ) + ) + + execute_type = SequentialSequence + route_attn = ((True, False),) * depth + attn_route_map = {"mask": route_attn} + + self.layers = execute_type(layers, args_route=attn_route_map) + + def forward(self, x, **kwargs): + return self.layers(x, **kwargs) diff --git a/viXTTS/TTS/tts/layers/tortoise/utils.py b/viXTTS/TTS/tts/layers/tortoise/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..810a9e7f7a8ab4a6a48974367020961f9a9967f4 --- /dev/null +++ b/viXTTS/TTS/tts/layers/tortoise/utils.py @@ -0,0 +1,46 @@ +import os +from urllib import request + +from tqdm import tqdm + 
+DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tortoise", "models") +MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR) +MODELS_DIR = "/data/speech_synth/models/" +MODELS = { + "autoregressive.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth", + "classifier.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth", + "clvp2.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/clvp2.pth", + "diffusion_decoder.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/diffusion_decoder.pth", + "vocoder.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth", + "rlg_auto.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth", + "rlg_diffuser.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth", +} + + +def download_models(specific_models=None): + """ + Call to download all the models that Tortoise uses. + """ + os.makedirs(MODELS_DIR, exist_ok=True) + for model_name, url in MODELS.items(): + if specific_models is not None and model_name not in specific_models: + continue + model_path = os.path.join(MODELS_DIR, model_name) + if os.path.exists(model_path): + continue + print(f"Downloading {model_name} from {url}...") + with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: + request.urlretrieve(url, model_path, lambda nb, bs, fs, t=t: t.update(nb * bs - t.n)) + print("Done.") + + +def get_model_path(model_name, models_dir=MODELS_DIR): + """ + Get path to given model, download it if it doesn't exist. + """ + if model_name not in MODELS: + raise ValueError(f"Model {model_name} not found in available models.") + model_path = os.path.join(models_dir, model_name) + if not os.path.exists(model_path) and models_dir == MODELS_DIR: + download_models([model_name]) + return model_path diff --git a/viXTTS/TTS/tts/layers/tortoise/vocoder.py b/viXTTS/TTS/tts/layers/tortoise/vocoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a5200c26738b55273a74c86e4308a6bd6783f5d7 --- /dev/null +++ b/viXTTS/TTS/tts/layers/tortoise/vocoder.py @@ -0,0 +1,405 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Callable, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.utils.parametrize as parametrize + +MAX_WAV_VALUE = 32768.0 + + +class KernelPredictor(torch.nn.Module): + """Kernel predictor for the location-variable convolutions""" + + def __init__( + self, + cond_channels, + conv_in_channels, + conv_out_channels, + conv_layers, + conv_kernel_size=3, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + kpnet_nonlinear_activation="LeakyReLU", + kpnet_nonlinear_activation_params={"negative_slope": 0.1}, + ): + """ + Args: + cond_channels (int): number of channel for the conditioning sequence, + conv_in_channels (int): number of channel for the input sequence, + conv_out_channels (int): number of channel for the output sequence, + conv_layers (int): number of layers + """ + super().__init__() + + self.conv_in_channels = conv_in_channels + self.conv_out_channels = conv_out_channels + self.conv_kernel_size = conv_kernel_size + self.conv_layers = conv_layers + + kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w + kpnet_bias_channels = conv_out_channels * conv_layers # l_b + + 
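+        # Illustrative sizes, using the values passed in by LVCBlock/UnivNetGenerator below
+        # (conv_in_channels=32, conv_out_channels=64, conv_kernel_size=3, conv_layers=4):
+        # the predictor emits 32 * 64 * 3 * 4 = 24576 kernel values and 64 * 4 = 256 biases
+        # per conditioning frame.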
self.input_conv = nn.Sequential( + nn.utils.parametrizations.weight_norm( + nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True) + ), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + + self.residual_convs = nn.ModuleList() + padding = (kpnet_conv_size - 1) // 2 + for _ in range(3): + self.residual_convs.append( + nn.Sequential( + nn.Dropout(kpnet_dropout), + nn.utils.parametrizations.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_hidden_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + nn.utils.parametrizations.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_hidden_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + ) + self.kernel_conv = nn.utils.parametrizations.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_kernel_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ) + self.bias_conv = nn.utils.parametrizations.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_bias_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ) + + def forward(self, c): + """ + Args: + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + """ + batch, _, cond_length = c.shape + c = self.input_conv(c) + for residual_conv in self.residual_convs: + residual_conv.to(c.device) + c = c + residual_conv(c) + k = self.kernel_conv(c) + b = self.bias_conv(c) + kernels = k.contiguous().view( + batch, + self.conv_layers, + self.conv_in_channels, + self.conv_out_channels, + self.conv_kernel_size, + cond_length, + ) + bias = b.contiguous().view( + batch, + self.conv_layers, + self.conv_out_channels, + cond_length, + ) + + return kernels, bias + + def remove_weight_norm(self): + parametrize.remove_parametrizations(self.input_conv[0], "weight") + parametrize.remove_parametrizations(self.kernel_conv, "weight") + parametrize.remove_parametrizations(self.bias_conv) + for block in self.residual_convs: + parametrize.remove_parametrizations(block[1], "weight") + parametrize.remove_parametrizations(block[3], "weight") + + +class LVCBlock(torch.nn.Module): + """the location-variable convolutions""" + + def __init__( + self, + in_channels, + cond_channels, + stride, + dilations=[1, 3, 9, 27], + lReLU_slope=0.2, + conv_kernel_size=3, + cond_hop_length=256, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + ): + super().__init__() + + self.cond_hop_length = cond_hop_length + self.conv_layers = len(dilations) + self.conv_kernel_size = conv_kernel_size + + self.kernel_predictor = KernelPredictor( + cond_channels=cond_channels, + conv_in_channels=in_channels, + conv_out_channels=2 * in_channels, + conv_layers=len(dilations), + conv_kernel_size=conv_kernel_size, + kpnet_hidden_channels=kpnet_hidden_channels, + kpnet_conv_size=kpnet_conv_size, + kpnet_dropout=kpnet_dropout, + kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope}, + ) + + self.convt_pre = nn.Sequential( + nn.LeakyReLU(lReLU_slope), + nn.utils.parametrizations.weight_norm( + nn.ConvTranspose1d( + in_channels, + in_channels, + 2 * stride, + stride=stride, + padding=stride // 2 + stride % 2, + output_padding=stride % 2, + ) + ), + ) + + self.conv_blocks = nn.ModuleList() + for dilation in dilations: + self.conv_blocks.append( + nn.Sequential( + nn.LeakyReLU(lReLU_slope), + 
nn.utils.parametrizations.weight_norm( + nn.Conv1d( + in_channels, + in_channels, + conv_kernel_size, + padding=dilation * (conv_kernel_size - 1) // 2, + dilation=dilation, + ) + ), + nn.LeakyReLU(lReLU_slope), + ) + ) + + def forward(self, x, c): + """forward propagation of the location-variable convolutions. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length) + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + + Returns: + Tensor: the output sequence (batch, in_channels, in_length) + """ + _, in_channels, _ = x.shape # (B, c_g, L') + + x = self.convt_pre(x) # (B, c_g, stride * L') + kernels, bias = self.kernel_predictor(c) + + for i, conv in enumerate(self.conv_blocks): + output = conv(x) # (B, c_g, stride * L') + + k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length) + b = bias[:, i, :, :] # (B, 2 * c_g, cond_length) + + output = self.location_variable_convolution( + output, k, b, hop_size=self.cond_hop_length + ) # (B, 2 * c_g, stride * L'): LVC + x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh( + output[:, in_channels:, :] + ) # (B, c_g, stride * L'): GAU + + return x + + def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256): + """perform location-variable convolution operation on the input sequence (x) using the local convolution kernl. + Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length). + kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length) + bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length) + dilation (int): the dilation of convolution. + hop_size (int): the hop_size of the conditioning sequence. + Returns: + (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length). + """ + batch, _, in_length = x.shape + batch, _, out_channels, kernel_size, kernel_length = kernel.shape + assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched" + + padding = dilation * int((kernel_size - 1) / 2) + x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding) + x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding) + + if hop_size < dilation: + x = F.pad(x, (0, dilation), "constant", 0) + x = x.unfold( + 3, dilation, dilation + ) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation) + x = x[:, :, :, :, :hop_size] + x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation) + x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size) + + o = torch.einsum("bildsk,biokl->bolsd", x, kernel) + o = o.to(memory_format=torch.channels_last_3d) + bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d) + o = o + bias + o = o.contiguous().view(batch, out_channels, -1) + + return o + + def remove_weight_norm(self): + self.kernel_predictor.remove_weight_norm() + parametrize.remove_parametrizations(self.convt_pre[1], "weight") + for block in self.conv_blocks: + parametrize.remove_parametrizations(block[1], "weight") + + +class UnivNetGenerator(nn.Module): + """ + UnivNet Generator + + Originally from https://github.com/mindslab-ai/univnet/blob/master/model/generator.py. 
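+
+    Maps a mel-spectrogram `c` of shape (batch, n_mel_channels, T) and a noise tensor `z` of shape
+    (batch, noise_dim, T) to a waveform of shape (batch, 1, T * prod(strides)); with the default
+    strides [8, 8, 4] the upsampling factor is 256, matching `hop_length`.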
+ """ + + def __init__( + self, + noise_dim=64, + channel_size=32, + dilations=[1, 3, 9, 27], + strides=[8, 8, 4], + lReLU_slope=0.2, + kpnet_conv_size=3, + # Below are MEL configurations options that this generator requires. + hop_length=256, + n_mel_channels=100, + ): + super(UnivNetGenerator, self).__init__() + self.mel_channel = n_mel_channels + self.noise_dim = noise_dim + self.hop_length = hop_length + channel_size = channel_size + kpnet_conv_size = kpnet_conv_size + + self.res_stack = nn.ModuleList() + hop_length = 1 + for stride in strides: + hop_length = stride * hop_length + self.res_stack.append( + LVCBlock( + channel_size, + n_mel_channels, + stride=stride, + dilations=dilations, + lReLU_slope=lReLU_slope, + cond_hop_length=hop_length, + kpnet_conv_size=kpnet_conv_size, + ) + ) + + self.conv_pre = nn.utils.parametrizations.weight_norm( + nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect") + ) + + self.conv_post = nn.Sequential( + nn.LeakyReLU(lReLU_slope), + nn.utils.parametrizations.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")), + nn.Tanh(), + ) + + def forward(self, c, z): + """ + Args: + c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length) + z (Tensor): the noise sequence (batch, noise_dim, in_length) + + """ + z = self.conv_pre(z) # (B, c_g, L) + + for res_block in self.res_stack: + res_block.to(z.device) + z = res_block(z, c) # (B, c_g, L * s_0 * ... * s_i) + + z = self.conv_post(z) # (B, 1, L * 256) + + return z + + def eval(self, inference=False): + super(UnivNetGenerator, self).eval() + # don't remove weight norm while validation in training loop + if inference: + self.remove_weight_norm() + + def remove_weight_norm(self): + parametrize.remove_parametrizations(self.conv_pre, "weight") + + for layer in self.conv_post: + if len(layer.state_dict()) != 0: + parametrize.remove_parametrizations(layer, "weight") + + for res_block in self.res_stack: + res_block.remove_weight_norm() + + def inference(self, c, z=None): + # pad input mel with zeros to cut artifact + # see https://github.com/seungwonpark/melgan/issues/8 + zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device) + mel = torch.cat((c, zero), dim=2) + + if z is None: + z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device) + + audio = self.forward(mel, z) + audio = audio[:, :, : -(self.hop_length * 10)] + audio = audio.clamp(min=-1, max=1) + return audio + + +@dataclass +class VocType: + constructor: Callable[[], nn.Module] + model_path: str + subkey: Optional[str] = None + + def optionally_index(self, model_dict): + if self.subkey is not None: + return model_dict[self.subkey] + return model_dict + + +class VocConf(Enum): + Univnet = VocType(UnivNetGenerator, "vocoder.pth", "model_g") + + +if __name__ == "__main__": + model = UnivNetGenerator() + + c = torch.randn(3, 100, 10) + z = torch.randn(3, 64, 10) + print(c.shape) + + y = model(c, z) + print(y.shape) + assert y.shape == torch.Size([3, 1, 2560]) + + pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(pytorch_total_params) diff --git a/viXTTS/TTS/tts/layers/tortoise/wav2vec_alignment.py b/viXTTS/TTS/tts/layers/tortoise/wav2vec_alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..47456cc5ac41b7ed9522fe543affc8482218730c --- /dev/null +++ b/viXTTS/TTS/tts/layers/tortoise/wav2vec_alignment.py @@ -0,0 +1,150 @@ +import torch +import torchaudio +from transformers import 
Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2ForCTC + + +def max_alignment(s1, s2, skip_character="~", record=None): + """ + A clever function that aligns s1 to s2 as best it can. Wherever a character from s1 is not found in s2, a '~' is + used to replace that character. + + Finally got to use my DP skills! + """ + if record is None: + record = {} + assert skip_character not in s1, f"Found the skip character {skip_character} in the provided string, {s1}" + if len(s1) == 0: + return "" + if len(s2) == 0: + return skip_character * len(s1) + if s1 == s2: + return s1 + if s1[0] == s2[0]: + return s1[0] + max_alignment(s1[1:], s2[1:], skip_character, record) + + take_s1_key = (len(s1), len(s2) - 1) + if take_s1_key in record: + take_s1, take_s1_score = record[take_s1_key] + else: + take_s1 = max_alignment(s1, s2[1:], skip_character, record) + take_s1_score = len(take_s1.replace(skip_character, "")) + record[take_s1_key] = (take_s1, take_s1_score) + + take_s2_key = (len(s1) - 1, len(s2)) + if take_s2_key in record: + take_s2, take_s2_score = record[take_s2_key] + else: + take_s2 = max_alignment(s1[1:], s2, skip_character, record) + take_s2_score = len(take_s2.replace(skip_character, "")) + record[take_s2_key] = (take_s2, take_s2_score) + + return take_s1 if take_s1_score > take_s2_score else skip_character + take_s2 + + +class Wav2VecAlignment: + """ + Uses wav2vec2 to perform audio<->text alignment. + """ + + def __init__(self, device="cuda"): + self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu() + self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-960h") + self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("jbetker/tacotron-symbols") + self.device = device + + def align(self, audio, expected_text, audio_sample_rate=24000): + orig_len = audio.shape[-1] + + with torch.no_grad(): + self.model = self.model.to(self.device) + audio = audio.to(self.device) + audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000) + clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7) + logits = self.model(clip_norm).logits + self.model = self.model.cpu() + + logits = logits[0] + pred_string = self.tokenizer.decode(logits.argmax(-1).tolist()) + + fixed_expectation = max_alignment(expected_text.lower(), pred_string) + w2v_compression = orig_len // logits.shape[0] + expected_tokens = self.tokenizer.encode(fixed_expectation) + expected_chars = list(fixed_expectation) + if len(expected_tokens) == 1: + return [0] # The alignment is simple; there is only one token. + expected_tokens.pop(0) # The first token is a given. + expected_chars.pop(0) + + alignments = [0] + + def pop_till_you_win(): + if len(expected_tokens) == 0: + return None + popped = expected_tokens.pop(0) + popped_char = expected_chars.pop(0) + while popped_char == "~": + alignments.append(-1) + if len(expected_tokens) == 0: + return None + popped = expected_tokens.pop(0) + popped_char = expected_chars.pop(0) + return popped + + next_expected_token = pop_till_you_win() + for i, logit in enumerate(logits): + top = logit.argmax() + if next_expected_token == top: + alignments.append(i * w2v_compression) + if len(expected_tokens) > 0: + next_expected_token = pop_till_you_win() + else: + break + + pop_till_you_win() + if not (len(expected_tokens) == 0 and len(alignments) == len(expected_text)): + torch.save([audio, expected_text], "alignment_debug.pth") + assert False, ( + "Something went wrong with the alignment algorithm. 
I've dumped a file, 'alignment_debug.pth' to" + "your current working directory. Please report this along with the file so it can get fixed." + ) + + # Now fix up alignments. Anything with -1 should be interpolated. + alignments.append(orig_len) # This'll get removed but makes the algorithm below more readable. + for i in range(len(alignments)): + if alignments[i] == -1: + for j in range(i + 1, len(alignments)): + if alignments[j] != -1: + next_found_token = j + break + for j in range(i, next_found_token): + gap = alignments[next_found_token] - alignments[i - 1] + alignments[j] = (j - i + 1) * gap // (next_found_token - i + 1) + alignments[i - 1] + + return alignments[:-1] + + def redact(self, audio, expected_text, audio_sample_rate=24000): + if "[" not in expected_text: + return audio + splitted = expected_text.split("[") + fully_split = [splitted[0]] + for spl in splitted[1:]: + assert "]" in spl, 'Every "[" character must be paired with a "]" with no nesting.' + fully_split.extend(spl.split("]")) + + # At this point, fully_split is a list of strings, with every other string being something that should be redacted. + non_redacted_intervals = [] + last_point = 0 + for i in range(len(fully_split)): + if i % 2 == 0: + end_interval = max(0, last_point + len(fully_split[i]) - 1) + non_redacted_intervals.append((last_point, end_interval)) + last_point += len(fully_split[i]) + + bare_text = "".join(fully_split) + alignments = self.align(audio, bare_text, audio_sample_rate) + + output_audio = [] + for nri in non_redacted_intervals: + start, stop = nri + output_audio.append(audio[:, alignments[start] : alignments[stop]]) + return torch.cat(output_audio, dim=-1) diff --git a/viXTTS/TTS/tts/layers/tortoise/xtransformers.py b/viXTTS/TTS/tts/layers/tortoise/xtransformers.py new file mode 100644 index 0000000000000000000000000000000000000000..1eb3f77269c0e7b718d350217796ec704543c681 --- /dev/null +++ b/viXTTS/TTS/tts/layers/tortoise/xtransformers.py @@ -0,0 +1,1259 @@ +import math +from collections import namedtuple +from functools import partial +from inspect import isfunction + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import einsum, nn + +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"]) + +LayerIntermediates = namedtuple( + "Intermediates", + [ + "hiddens", + "attn_intermediates", + "past_key_values", + ], +) + + +# helpers + + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def cast_tuple(val, depth): + return val if isinstance(val, tuple) else (val,) * depth + + +class always: + def __init__(self, val): + self.val = val + + def __call__(self, *args, **kwargs): + return self.val + + +class not_equals: + def __init__(self, val): + self.val = val + + def __call__(self, x, *args, **kwargs): + return x != self.val + + +class equals: + def __init__(self, val): + self.val = val + + def __call__(self, x, *args, **kwargs): + return x == self.val + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +def l2norm(t): + return F.normalize(t, p=2, dim=-1) + + +# init helpers + + +def init_zero_(layer): + nn.init.constant_(layer.weight, 0.0) + if exists(layer.bias): + nn.init.constant_(layer.bias, 0.0) + + +# keyword argument helpers + + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, 
d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) + kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items()))) + return kwargs_without_prefix, kwargs + + +# activations + + +class ReluSquared(nn.Module): + def forward(self, x): + return F.relu(x) ** 2 + + +# positional embeddings + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.scale = dim**-0.5 + self.emb = nn.Embedding(max_seq_len, dim) + + def forward(self, x): + n = torch.arange(x.shape[1], device=x.device) + pos_emb = self.emb(n) + pos_emb = rearrange(pos_emb, "n d -> () n d") + return pos_emb * self.scale + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + sinusoid_inp = torch.einsum("i , j -> i j", t, self.inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + return rearrange(emb, "n d -> () n d") + + +class RelativePositionBias(nn.Module): + def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8): + super().__init__() + self.scale = scale + self.causal = causal + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + if not causal: + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + else: + n = torch.max(n, torch.zeros_like(n)) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = ( + max_exact + + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long() + ) + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, qk_dots): + i, j, device = *qk_dots.shape[-2:], qk_dots.device + q_pos = torch.arange(i, dtype=torch.long, device=device) + k_pos = torch.arange(j, dtype=torch.long, device=device) + rel_pos = k_pos[None, :] - q_pos[:, None] + rp_bucket = self._relative_position_bucket( + rel_pos, causal=self.causal, num_buckets=self.num_buckets, max_distance=self.max_distance + ) + values = self.relative_attention_bias(rp_bucket) + bias = rearrange(values, "i j h -> () h i j") + return qk_dots + (bias * self.scale) + + +class AlibiPositionalBias(nn.Module): + def __init__(self, heads, **kwargs): + super().__init__() + self.heads = heads + slopes = torch.Tensor(self._get_slopes(heads)) + slopes = rearrange(slopes, "h -> () h () ()") + self.register_buffer("slopes", slopes, persistent=False) + self.register_buffer("bias", None, persistent=False) + + @staticmethod + def _get_slopes(heads): + def get_slopes_power_of_2(n): 
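+            # comment added for clarity (not in the upstream source): for a power-of-two head
+            # count n this returns the geometric slope sequence 2^(-8/n), 2^(-16/n), ...,
+            # since start == 2 ** (-8 / n) and ratio == start, as prescribed by ALiBi.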
+ start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(heads).is_integer(): + return get_slopes_power_of_2(heads) + + closest_power_of_2 = 2 ** math.floor(math.log2(heads)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][: heads - closest_power_of_2] + ) + + def forward(self, qk_dots): + h, i, j, device = *qk_dots.shape[-3:], qk_dots.device + + if exists(self.bias) and self.bias.shape[-1] >= j: + return qk_dots + self.bias[..., :j] + + bias = torch.arange(j, device=device) + bias = rearrange(bias, "j -> () () () j") + bias = bias * self.slopes + + num_heads_unalibied = h - bias.shape[1] + bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied)) + + self.register_buffer("bias", bias, persistent=False) + return qk_dots + self.bias + + +class LearnedAlibiPositionalBias(AlibiPositionalBias): + def __init__(self, heads, bidirectional=False): + super().__init__(heads) + los_slopes = torch.log(self.slopes) + self.learned_logslopes = nn.Parameter(los_slopes) + + self.bidirectional = bidirectional + if self.bidirectional: + self.learned_logslopes_future = nn.Parameter(los_slopes) + + def forward(self, qk_dots): + h, i, j, device = *qk_dots.shape[-3:], qk_dots.device + + def get_slopes(param): + return F.pad(param.exp(), (0, 0, 0, 0, 0, h - param.shape[1])) + + if exists(self.bias) and self.bias.shape[-1] >= j: + bias = self.bias[..., :i, :j] + else: + i_arange = torch.arange(i, device=device) + j_arange = torch.arange(j, device=device) + bias = rearrange(j_arange, "j -> 1 1 1 j") - rearrange(i_arange, "i -> 1 1 i 1") + self.register_buffer("bias", bias, persistent=False) + + if self.bidirectional: + past_slopes = get_slopes(self.learned_logslopes) + future_slopes = get_slopes(self.learned_logslopes_future) + bias = torch.tril(bias * past_slopes) + torch.triu(bias * future_slopes) + else: + slopes = get_slopes(self.learned_logslopes) + bias = bias * slopes + + return qk_dots + bias + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, max_seq_len, device): + t = torch.arange(max_seq_len, device=device).type_as(self.inv_freq) + freqs = torch.einsum("i , j -> i j", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return rearrange(emb, "n d -> () () n d") + + +def rotate_half(x): + x = rearrange(x, "... (j d) -> ... 
j d", j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(t, freqs): + seq_len = t.shape[-2] + freqs = freqs[:, :, -seq_len:] + return (t * freqs.cos()) + (rotate_half(t) * freqs.sin()) + + +# norms + + +class Scale(nn.Module): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + out = self.fn(x, **kwargs) + scale_fn = lambda t: t * self.value + + if not isinstance(out, tuple): + return scale_fn(out) + + return (scale_fn(out[0]), *out[1:]) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x, **kwargs): + out = self.fn(x, **kwargs) + rezero_fn = lambda t: t * self.g + + if not isinstance(out, tuple): + return rezero_fn(out) + + return (rezero_fn(out[0]), *out[1:]) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim**-0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim**-0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSScaleShiftNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim**-0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + self.scale_shift_process = nn.Linear(dim * 2, dim * 2) + + def forward(self, x, norm_scale_shift_inp): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + norm = x / norm.clamp(min=self.eps) * self.g + + ss_emb = self.scale_shift_process(norm_scale_shift_inp) + scale, shift = torch.chunk(ss_emb, 2, dim=1) + h = norm * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + return h + + +# residual and residual gates + + +class Residual(nn.Module): + def __init__(self, dim, scale_residual=False): + super().__init__() + self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None + + def forward(self, x, residual): + if exists(self.residual_scale): + residual = residual * self.residual_scale + + return x + residual + + +class GRUGating(nn.Module): + def __init__(self, dim, scale_residual=False): + super().__init__() + self.gru = nn.GRUCell(dim, dim) + self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None + + def forward(self, x, residual): + if exists(self.residual_scale): + residual = residual * self.residual_scale + + gated_output = self.gru(rearrange(x, "b n d -> (b n) d"), rearrange(residual, "b n d -> (b n) d")) + + return gated_output.reshape_as(x) + + +# token shifting + + +def shift(t, amount, mask=None): + if amount == 0: + return t + + if exists(mask): + t = t.masked_fill(~mask[..., None], 0.0) + + return F.pad(t, (0, 0, amount, -amount), value=0.0) + + +class ShiftTokens(nn.Module): + def __init__(self, shifts, fn): + super().__init__() + self.fn = fn + self.shifts = tuple(shifts) + + def forward(self, x, **kwargs): + mask = kwargs.get("mask", None) + shifts = self.shifts + segments = len(shifts) + feats_per_shift = x.shape[-1] // segments + splitted = x.split(feats_per_shift, dim=-1) + segments_to_shift, rest = splitted[:segments], splitted[segments:] + segments_to_shift = 
list(map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts))) + x = torch.cat((*segments_to_shift, *rest), dim=-1) + return self.fn(x, **kwargs) + + +# feedforward + + +class GLU(nn.Module): + def __init__(self, dim_in, dim_out, activation): + super().__init__() + self.act = activation + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * self.act(gate) + + +class FeedForward(nn.Module): + def __init__( + self, + dim, + dim_out=None, + mult=4, + glu=False, + relu_squared=False, + post_act_ln=False, + dropout=0.0, + zero_init_output=False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + activation = ReluSquared() if relu_squared else nn.GELU() + + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), activation) if not glu else GLU(dim, inner_dim, activation) + ) + + self.net = nn.Sequential( + project_in, + nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(), + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out), + ) + + # init last linear layer to 0 + if zero_init_output: + init_zero_(self.net[-1]) + + def forward(self, x): + return self.net(x) + + +# attention. + + +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + talking_heads=False, + head_scale=False, + collab_heads=False, + collab_compression=0.3, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0.0, + on_attn=False, + gate_values=False, + zero_init_output=False, + max_attend_past=None, + qk_norm=False, + scale_init_value=None, + rel_pos_bias=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + ): + super().__init__() + self.scale = dim_head**-0.5 + + self.heads = heads + self.causal = causal + self.max_attend_past = max_attend_past + + qk_dim = v_dim = dim_head * heads + + # collaborative heads + self.collab_heads = collab_heads + if self.collab_heads: + qk_dim = int(collab_compression * qk_dim) + self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim)) + + self.to_q = nn.Linear(dim, qk_dim, bias=False) + self.to_k = nn.Linear(dim, qk_dim, bias=False) + self.to_v = nn.Linear(dim, v_dim, bias=False) + + self.dropout = nn.Dropout(dropout) + + # add GLU gating for aggregated values, from alphafold2 + self.to_v_gate = None + if gate_values: + self.to_v_gate = nn.Linear(dim, v_dim) + nn.init.constant_(self.to_v_gate.weight, 0) + nn.init.constant_(self.to_v_gate.bias, 1) + + # cosine sim attention + self.qk_norm = qk_norm + if qk_norm: + scale_init_value = default( + scale_init_value, -3 + ) # if not provided, initialize as though it were sequence length of 1024 + self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + + # head scaling + self.head_scale = head_scale + if head_scale: + self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + # entmax + self.attn_fn = F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = 
nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, dim) + + self.rel_pos_bias = rel_pos_bias + if rel_pos_bias: + assert ( + rel_pos_num_buckets <= rel_pos_max_distance + ), "number of relative position buckets must be less than the relative position max distance" + self.rel_pos = RelativePositionBias( + scale=dim_head**0.5, + causal=causal, + heads=heads, + num_buckets=rel_pos_num_buckets, + max_distance=rel_pos_max_distance, + ) + + # init output projection 0 + if zero_init_output: + init_zero_(self.to_out) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + attn_mask=None, + sinusoidal_emb=None, + rotary_pos_emb=None, + prev_attn=None, + mem=None, + layer_past=None, + ): + b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = ( + *x.shape, + self.heads, + self.talking_heads, + self.collab_heads, + self.head_scale, + self.scale, + x.device, + exists(context), + ) + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + if not collab_heads: + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + else: + q = einsum("b i d, h d -> b h i d", q, self.collab_mixing) + k = rearrange(k, "b n d -> b () n d") + v = rearrange(v, "b n (h d) -> b h n d", h=h) + + if layer_past is not None: + past_key, past_value = layer_past + k = torch.cat([past_key, k], dim=-2) + v = torch.cat([past_value, v], dim=-2) + k_cache = k + v_cache = v + + if exists(rotary_pos_emb) and not has_context: + l = rotary_pos_emb.shape[-1] + (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v)) + ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl)) + q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr))) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + q_mask = rearrange(q_mask, "b i -> b () i ()") + k_mask = rearrange(k_mask, "b j -> b () () j") + input_mask = q_mask * k_mask + + if self.num_mem_kv > 0: + mem_k, mem_v = map(lambda t: repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v)) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) + if exists(input_mask): + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + + if collab_heads: + k = k.expand(-1, h, -1, -1) + + if self.qk_norm: + q, k = map(l2norm, (q, k)) + scale = 1 / (self.scale.exp().clamp(min=1e-2)) + + dots = einsum("b h i d, b h j d -> b h i j", q, k) * scale + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots.clone() + + if talking_heads: + dots = einsum("b h i j, h k -> b k i j", dots, self.pre_softmax_proj).contiguous() + + if self.rel_pos_bias: + dots = self.rel_pos(dots) + + if exists(input_mask): + 
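+            # comment added for clarity (not in the upstream source): masked (e.g. padded)
+            # key/query positions are filled with the most negative finite value for this
+            # dtype so they contribute near-zero probability after the softmax.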
dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if exists(attn_mask): + assert ( + 2 <= attn_mask.ndim <= 4 + ), "attention mask must have greater than 2 dimensions but less than or equal to 4" + if attn_mask.ndim == 2: + attn_mask = rearrange(attn_mask, "i j -> () () i j") + elif attn_mask.ndim == 3: + attn_mask = rearrange(attn_mask, "h i j -> () h i j") + dots.masked_fill_(~attn_mask, mask_value) + + if exists(self.max_attend_past): + i, j = dots.shape[-2:] + range_q = torch.arange(j - i, j, device=device) + range_k = torch.arange(j, device=device) + dist = rearrange(range_q, "i -> () () i ()") - rearrange(range_k, "j -> () () () j") + mask = dist > self.max_attend_past + dots.masked_fill_(mask, mask_value) + del mask + + if self.causal: + i, j = dots.shape[-2:] + r = torch.arange(i, device=device) + mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j") + mask = F.pad(mask, (j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, dim=-1) + post_softmax_attn = attn.clone() + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum("b h i j, h k -> b k i j", attn, self.post_softmax_proj).contiguous() + + out = einsum("b h i j, b h j d -> b h i d", attn, v) + + if head_scale: + out = out * self.head_scale_params + + out = rearrange(out, "b h n d -> b n (h d)") + + if exists(self.to_v_gate): + gates = self.to_v_gate(x) + out = out * gates.sigmoid() + + intermediates = Intermediates(pre_softmax_attn=pre_softmax_attn, post_softmax_attn=post_softmax_attn) + + return self.to_out(out), intermediates, k_cache, v_cache + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rms_scaleshift_norm=False, + use_rmsnorm=False, + use_rezero=False, + alibi_pos_bias=False, + alibi_num_heads=None, + alibi_learned=False, + position_infused_attn=False, + rotary_pos_emb=False, + rotary_emb_dim=None, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + scale_residual=False, + shift_tokens=0, + sandwich_norm=False, + use_qk_norm_attn=False, + qk_norm_attn_seq_len=None, + zero_init_branch_output=False, + **kwargs, + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim("ff_", kwargs) + attn_kwargs, _ = groupby_prefix_and_trim("attn_", kwargs) + + dim_head = attn_kwargs.get("dim_head", DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + self.causal = causal + + rel_pos_bias = "rel_pos_bias" in attn_kwargs + self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb + self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + + rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32) + self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None + + assert not ( + alibi_pos_bias and rel_pos_bias + ), "you can only choose Alibi positional bias or T5 relative positional bias, not both" + + if alibi_pos_bias: + alibi_num_heads = default(alibi_num_heads, heads) + assert alibi_num_heads <= heads, "number of ALiBi heads must be less 
than the total number of heads" + alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias + self.rel_pos = alibi_pos_klass(heads=alibi_num_heads, bidirectional=not causal) + else: + self.rel_pos = None + + assert not (not pre_norm and sandwich_norm), "sandwich norm cannot be used when not using prenorm" + self.pre_norm = pre_norm + self.sandwich_norm = sandwich_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + self.cross_attend = cross_attend + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_class = RMSScaleShiftNorm if use_rms_scaleshift_norm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ("a", "c", "f") + elif cross_attend and only_cross: + default_block = ("c", "f") + else: + default_block = ("a", "f") + + if macaron: + default_block = ("f",) + default_block + + # qk normalization + + if use_qk_norm_attn: + attn_scale_init_value = ( + -math.log(math.log2(qk_norm_attn_seq_len**2 - qk_norm_attn_seq_len)) + if exists(qk_norm_attn_seq_len) + else None + ) + attn_kwargs = {**attn_kwargs, "qk_norm": True, "scale_init_value": attn_scale_init_value} + + # zero init + + if zero_init_branch_output: + attn_kwargs = {**attn_kwargs, "zero_init_output": True} + ff_kwargs = {**ff_kwargs, "zero_init_output": True} + + # calculate layer block order + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, "par ratio out of range" + default_block = tuple(filter(not_equals("f"), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, "default block is too large for par_ratio" + par_block = default_block + ("f",) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ("f",) * (par_depth - len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, "sandwich coefficient should be less than the depth" + layer_types = ("a",) * sandwich_coef + default_block * (depth - sandwich_coef) + ("f",) * sandwich_coef + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals("a"), layer_types))) + + # calculate token shifting + + shift_tokens = cast_tuple(shift_tokens, len(layer_types)) + + # iterate and construct layers + + for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)): + is_last_layer = ind == (len(self.layer_types) - 1) + + if layer_type == "a": + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == "c": + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == "f": + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f"invalid layer type {layer_type}") + + if layer_shift_tokens > 0: + shift_range_upper = layer_shift_tokens + 1 + shift_range_lower = -layer_shift_tokens if not causal else 0 + layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer) + + if exists(branch_fn): + layer = 
branch_fn(layer) + + residual_fn = GRUGating if gate_residual else Residual + residual = residual_fn(dim, scale_residual=scale_residual) + + layer_uses_qk_norm = use_qk_norm_attn and layer_type in ("a", "c") + + pre_branch_norm = norm_fn() if pre_norm and not layer_uses_qk_norm else None + post_branch_norm = norm_fn() if sandwich_norm or layer_uses_qk_norm else None + post_main_norm = norm_fn() if not pre_norm and not is_last_layer else None + + norms = nn.ModuleList([pre_branch_norm, post_branch_norm, post_main_norm]) + + self.layers.append(nn.ModuleList([norms, layer, residual])) + + def forward( + self, + x, + context=None, + full_context=None, # for passing a list of hidden states from an encoder + mask=None, + context_mask=None, + attn_mask=None, + mems=None, + return_hiddens=False, + norm_scale_shift_inp=None, + past_key_values=None, + expected_seq_len=None, + ): + assert not ( + self.cross_attend ^ (exists(context) or exists(full_context)) + ), "context must be passed in if cross_attend is set to True" + assert context is None or full_context is None, "only one of full_context or context can be provided" + + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + norm_args = {} + if exists(norm_scale_shift_inp): + norm_args["norm_scale_shift_inp"] = norm_scale_shift_inp + + rotary_pos_emb = None + if exists(self.rotary_pos_emb): + if not self.training and self.causal: + assert ( + expected_seq_len is not None + ), "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`" + elif expected_seq_len is None: + expected_seq_len = 0 + seq_len = x.shape[1] + if past_key_values is not None: + seq_len += past_key_values[0][0].shape[-2] + max_rotary_emb_length = max( + list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len] + ) + rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device) + + present_key_values = [] + cross_attn_count = 0 + for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + if layer_type == "a": + layer_mem = mems.pop(0) if mems else None + + residual = x + + pre_branch_norm, post_branch_norm, post_main_norm = norm + + if exists(pre_branch_norm): + x = pre_branch_norm(x, **norm_args) + + if layer_type == "a" or layer_type == "c": + if past_key_values is not None: + layer_kv = past_key_values.pop(0) + layer_past = tuple(s.to(x.device) for s in layer_kv) + else: + layer_past = None + + if layer_type == "a": + out, inter, k, v = block( + x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, prev_attn, layer_mem, layer_past + ) + elif layer_type == "c": + if exists(full_context): + out, inter, k, v = block( + x, + full_context[cross_attn_count], + mask, + context_mask, + None, + None, + None, + prev_attn, + None, + layer_past, + ) + else: + out, inter, k, v = block( + x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past + ) + elif layer_type == "f": + out = block(x) + + if layer_type == "a" or layer_type == "c" and present_key_values is not None: + present_key_values.append((k.detach(), v.detach())) + + if exists(post_branch_norm): + out = post_branch_norm(out, **norm_args) + + x = residual_fn(out, residual) + + if layer_type in ("a", "c"): + intermediates.append(inter) + + if layer_type == "a" and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == "c" and self.cross_residual_attn: + 
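+                # comment added for clarity (not in the upstream source): keep the pre-softmax
+                # cross-attention scores around for residual cross-attention between layers.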
prev_cross_attn = inter.pre_softmax_attn + + if exists(post_main_norm): + x = post_main_norm(x, **norm_args) + + if layer_type == "c": + cross_attn_count += 1 + + if layer_type == "f": + hiddens.append(x) + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens=hiddens, attn_intermediates=intermediates, past_key_values=present_key_values + ) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert "causal" not in kwargs, "cannot set causality on encoder" + super().__init__(causal=False, **kwargs) + + +class Decoder(AttentionLayers): + def __init__(self, **kwargs): + assert "causal" not in kwargs, "cannot set causality on decoder" + super().__init__(causal=True, **kwargs) + + +class CrossAttender(AttentionLayers): + def __init__(self, **kwargs): + super().__init__(cross_attend=True, only_cross=True, **kwargs) + + +class ViTransformerWrapper(nn.Module): + def __init__(self, *, image_size, patch_size, attn_layers, num_classes=None, dropout=0.0, emb_dropout=0.0): + super().__init__() + assert isinstance(attn_layers, Encoder), "attention layers must be an Encoder" + assert image_size % patch_size == 0, "image dimensions must be divisible by the patch size" + dim = attn_layers.dim + num_patches = (image_size // patch_size) ** 2 + patch_dim = 3 * patch_size**2 + + self.patch_size = patch_size + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) + self.patch_to_embedding = nn.Linear(patch_dim, dim) + self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) + self.dropout = nn.Dropout(emb_dropout) + + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + self.mlp_head = FeedForward(dim, dim_out=num_classes, dropout=dropout) if exists(num_classes) else None + + def forward(self, img, return_embeddings=False): + p = self.patch_size + + x = rearrange(img, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p, p2=p) + x = self.patch_to_embedding(x) + b, n, _ = x.shape + + cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b) + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embedding[:, : (n + 1)] + x = self.dropout(x) + + x = self.attn_layers(x) + x = self.norm(x) + + if not exists(self.mlp_head) or return_embeddings: + return x + + return self.mlp_head(x[:, 0]) + + +class TransformerWrapper(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0.0, + shift_mem_down=0, + emb_dropout=0.0, + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True, + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), "attention layers must be one of Encoder or Decoder" + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + self.shift_mem_down = shift_mem_down + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = ( + AbsolutePositionalEmbedding(emb_dim, max_seq_len) + if (use_pos_emb and not attn_layers.has_pos_emb) + else always(0) + ) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.init_() + + self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if 
num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + def init_(self): + nn.init.kaiming_normal_(self.token_emb.weight) + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_hiddens=False, + return_attn=False, + mems=None, + use_cache=False, + **kwargs, + ): + b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + x = self.token_emb(x) + x = x + self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + + if num_mem > 0: + mem = repeat(self.memory_tokens, "n d -> b n d", b=b) + x = torch.cat((mem, x), dim=1) + + # auto-handle masking after appending memory tokens + if exists(mask): + mask = F.pad(mask, (num_mem, 0), value=True) + + if self.shift_mem_down and exists(mems): + mems_l, mems_r = mems[: self.shift_mem_down], mems[self.shift_mem_down :] + mems = [*mems_r, *mems_l] + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + mem, x = x[:, :num_mem], x[:, num_mem:] + + out = self.to_logits(x) if not return_embeddings else x + + if return_hiddens: + hiddens = intermediates.hiddens + return out, hiddens + + res = [out] + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + res.append(attn_maps) + if use_cache: + res.append(intermediates.past_key_values) + + if len(res) > 1: + return tuple(res) + return res[0] + + +class ContinuousTransformerWrapper(nn.Module): + def __init__( + self, *, max_seq_len, attn_layers, dim_in=None, dim_out=None, emb_dim=None, emb_dropout=0.0, use_pos_emb=True + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), "attention layers must be one of Encoder or Decoder" + + dim = attn_layers.dim + + self.max_seq_len = max_seq_len + + self.pos_emb = ( + AbsolutePositionalEmbedding(dim, max_seq_len) + if (use_pos_emb and not attn_layers.has_pos_emb) + else always(0) + ) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity() + + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity() + + def forward(self, x, return_embeddings=False, mask=None, return_attn=False, mems=None, use_cache=False, **kwargs): + b, n, _, device = *x.shape, x.device + + x = self.project_in(x) + x = x + self.pos_emb(x) + x = self.emb_dropout(x) + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + out = self.project_out(x) if not return_embeddings else x + + res = [out] + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + res.append(attn_maps) + if use_cache: + res.append(intermediates.past_key_values) + + if len(res) > 1: + return tuple(res) + return res[0] diff --git a/viXTTS/TTS/tts/layers/vits/discriminator.py b/viXTTS/TTS/tts/layers/vits/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..c27d11bef632d02169aba7db9a9b38eba32c76a5 --- /dev/null +++ b/viXTTS/TTS/tts/layers/vits/discriminator.py @@ -0,0 +1,89 @@ +import torch +from torch import nn +from torch.nn.modules.conv import Conv1d + +from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP, MultiPeriodDiscriminator + + +class DiscriminatorS(torch.nn.Module): + """HiFiGAN Scale Discriminator. Channel sizes are different from the original HiFiGAN. 
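+
+    Note (added commentary, not part of the upstream docstring): the discriminator consumes a raw
+    waveform shaped ``(B, 1, T)`` and returns a flattened score tensor plus the intermediate
+    feature maps, e.g.::
+
+        d = DiscriminatorS()
+        scores, feats = d(torch.randn(4, 1, 16000))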
+
+    Args:
+        use_spectral_norm (bool): if `True` switch to spectral norm instead of weight norm.
+    """
+
+    def __init__(self, use_spectral_norm=False):
+        super().__init__()
+        norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
+        self.convs = nn.ModuleList(
+            [
+                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
+                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
+                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
+                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
+                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
+                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+            ]
+        )
+        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+    def forward(self, x):
+        """
+        Args:
+            x (Tensor): input waveform.
+
+        Returns:
+            Tensor: discriminator scores.
+            List[Tensor]: list of features from the convolutional layers.
+        """
+        feat = []
+        for l in self.convs:
+            x = l(x)
+            x = torch.nn.functional.leaky_relu(x, 0.1)
+            feat.append(x)
+        x = self.conv_post(x)
+        feat.append(x)
+        x = torch.flatten(x, 1, -1)
+        return x, feat
+
+
+class VitsDiscriminator(nn.Module):
+    """VITS discriminator wrapping one Scale Discriminator and a stack of Period Discriminators.
+
+    ::
+        waveform -> ScaleDiscriminator() -> scores_sd, feats_sd --> append() -> scores, feats
+               |--> MultiPeriodDiscriminator() -> scores_mpd, feats_mpd ^
+
+    Args:
+        use_spectral_norm (bool): if `True` switch to spectral norm instead of weight norm.
+    """
+
+    def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False):
+        super().__init__()
+        self.nets = nn.ModuleList()
+        self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm))
+        self.nets.extend([DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods])
+
+    def forward(self, x, x_hat=None):
+        """
+        Args:
+            x (Tensor): ground truth waveform.
+            x_hat (Tensor): predicted waveform.
+
+        Returns:
+            List[Tensor]: discriminator scores.
+            List[List[Tensor]]: list of lists of features from each layer of each discriminator.
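+
+        Note (added commentary, not part of the upstream docstring): ``x`` and ``x_hat`` are
+        expected as ``[B, 1, T]`` waveforms; when ``x_hat`` is ``None`` the corresponding score
+        and feature lists are returned as ``None``.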
+ """ + x_scores = [] + x_hat_scores = [] if x_hat is not None else None + x_feats = [] + x_hat_feats = [] if x_hat is not None else None + for net in self.nets: + x_score, x_feat = net(x) + x_scores.append(x_score) + x_feats.append(x_feat) + if x_hat is not None: + x_hat_score, x_hat_feat = net(x_hat) + x_hat_scores.append(x_hat_score) + x_hat_feats.append(x_hat_feat) + return x_scores, x_feats, x_hat_scores, x_hat_feats diff --git a/viXTTS/TTS/tts/layers/vits/networks.py b/viXTTS/TTS/tts/layers/vits/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..f97b584fe6ed311127a8c01a089b159946219cb2 --- /dev/null +++ b/viXTTS/TTS/tts/layers/vits/networks.py @@ -0,0 +1,288 @@ +import math + +import torch +from torch import nn + +from TTS.tts.layers.glow_tts.glow import WN +from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer +from TTS.tts.utils.helpers import sequence_mask + +LRELU_SLOPE = 0.1 + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class TextEncoder(nn.Module): + def __init__( + self, + n_vocab: int, + out_channels: int, + hidden_channels: int, + hidden_channels_ffn: int, + num_heads: int, + num_layers: int, + kernel_size: int, + dropout_p: float, + language_emb_dim: int = None, + ): + """Text Encoder for VITS model. + + Args: + n_vocab (int): Number of characters for the embedding layer. + out_channels (int): Number of channels for the output. + hidden_channels (int): Number of channels for the hidden layers. + hidden_channels_ffn (int): Number of channels for the convolutional layers. + num_heads (int): Number of attention heads for the Transformer layers. + num_layers (int): Number of Transformer layers. + kernel_size (int): Kernel size for the FFN layers in Transformer network. + dropout_p (float): Dropout rate for the Transformer layers. 
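+            language_emb_dim (int, optional): Dimension of the language embedding concatenated to
+                the character embeddings (doc entry added for completeness, inferred from the
+                constructor signature; not in the upstream docstring). Defaults to None.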
+ """ + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + + self.emb = nn.Embedding(n_vocab, hidden_channels) + + nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) + + if language_emb_dim: + hidden_channels += language_emb_dim + + self.encoder = RelativePositionTransformer( + in_channels=hidden_channels, + out_channels=hidden_channels, + hidden_channels=hidden_channels, + hidden_channels_ffn=hidden_channels_ffn, + num_heads=num_heads, + num_layers=num_layers, + kernel_size=kernel_size, + dropout_p=dropout_p, + layer_norm_type="2", + rel_attn_window_size=4, + ) + + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, lang_emb=None): + """ + Shapes: + - x: :math:`[B, T]` + - x_length: :math:`[B]` + """ + assert x.shape[0] == x_lengths.shape[0] + x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] + + # concat the lang emb in embedding chars + if lang_emb is not None: + x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1) + + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) # [b, 1, t] + + x = self.encoder(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + + m, logs = torch.split(stats, self.out_channels, dim=1) + return x, m, logs, x_mask + + +class ResidualCouplingBlock(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + dropout_p=0, + cond_channels=0, + mean_only=False, + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.half_channels = channels // 2 + self.mean_only = mean_only + # input layer + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + # coupling layers + self.enc = WN( + hidden_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + dropout_p=dropout_p, + c_in_channels=cond_channels, + ) + # output layer + # Initializing last layer to 0 makes the affine coupling layers + # do nothing at first. This helps with training stability + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + """ + Note: + Set `reverse` to True for inference. + + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + - g: :math:`[B, C, 1]` + """ + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, log_scale = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + log_scale = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(log_scale) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(log_scale, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-log_scale) * x_mask + x = torch.cat([x0, x1], 1) + return x + + +class ResidualCouplingBlocks(nn.Module): + def __init__( + self, + channels: int, + hidden_channels: int, + kernel_size: int, + dilation_rate: int, + num_layers: int, + num_flows=4, + cond_channels=0, + ): + """Redisual Coupling blocks for VITS flow layers. + + Args: + channels (int): Number of input and output tensor channels. + hidden_channels (int): Number of hidden network channels. + kernel_size (int): Kernel size of the WaveNet layers. + dilation_rate (int): Dilation rate of the WaveNet layers. 
+ num_layers (int): Number of the WaveNet layers. + num_flows (int, optional): Number of Residual Coupling blocks. Defaults to 4. + cond_channels (int, optional): Number of channels of the conditioning tensor. Defaults to 0. + """ + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.num_layers = num_layers + self.num_flows = num_flows + self.cond_channels = cond_channels + + self.flows = nn.ModuleList() + for _ in range(num_flows): + self.flows.append( + ResidualCouplingBlock( + channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + cond_channels=cond_channels, + mean_only=True, + ) + ) + + def forward(self, x, x_mask, g=None, reverse=False): + """ + Note: + Set `reverse` to True for inference. + + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + - g: :math:`[B, C, 1]` + """ + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + x = torch.flip(x, [1]) + else: + for flow in reversed(self.flows): + x = torch.flip(x, [1]) + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class PosteriorEncoder(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int, + kernel_size: int, + dilation_rate: int, + num_layers: int, + cond_channels=0, + ): + """Posterior Encoder of VITS model. + + :: + x -> conv1x1() -> WaveNet() (non-causal) -> conv1x1() -> split() -> [m, s] -> sample(m, s) -> z + + Args: + in_channels (int): Number of input tensor channels. + out_channels (int): Number of output tensor channels. + hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size of the WaveNet convolution layers. + dilation_rate (int): Dilation rate of the WaveNet layers. + num_layers (int): Number of the WaveNet layers. + cond_channels (int, optional): Number of conditioning tensor channels. Defaults to 0. 
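+
+        Note (added commentary, not part of the upstream docstring): ``forward`` returns the
+        sampled latent ``z = (mean + noise * exp(log_scale)) * x_mask`` together with ``mean``,
+        ``log_scale`` and ``x_mask``. A minimal sketch with illustrative (not prescribed)
+        channel sizes::
+
+            enc = PosteriorEncoder(in_channels=513, out_channels=192, hidden_channels=192,
+                                   kernel_size=5, dilation_rate=1, num_layers=16)
+            z, m, logs, mask = enc(torch.randn(2, 513, 100), torch.tensor([100, 80]))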
+ """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.num_layers = num_layers + self.cond_channels = cond_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = WN( + hidden_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels=cond_channels + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_lengths: :math:`[B, 1]` + - g: :math:`[B, C, 1]` + """ + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + mean, log_scale = torch.split(stats, self.out_channels, dim=1) + z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask + return z, mean, log_scale, x_mask diff --git a/viXTTS/TTS/tts/layers/vits/stochastic_duration_predictor.py b/viXTTS/TTS/tts/layers/vits/stochastic_duration_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..98dbf0935ca0f6cd6e92fe6ecf063dde2ee4138f --- /dev/null +++ b/viXTTS/TTS/tts/layers/vits/stochastic_duration_predictor.py @@ -0,0 +1,294 @@ +import math + +import torch +from torch import nn +from torch.nn import functional as F + +from TTS.tts.layers.generic.normalization import LayerNorm2 +from TTS.tts.layers.vits.transforms import piecewise_rational_quadratic_transform + + +class DilatedDepthSeparableConv(nn.Module): + def __init__(self, channels, kernel_size, num_layers, dropout_p=0.0) -> torch.tensor: + """Dilated Depth-wise Separable Convolution module. + + :: + x |-> DDSConv(x) -> LayerNorm(x) -> GeLU(x) -> Conv1x1(x) -> LayerNorm(x) -> GeLU(x) -> + -> o + |-------------------------------------------------------------------------------------^ + + Args: + channels ([type]): [description] + kernel_size ([type]): [description] + num_layers ([type]): [description] + dropout_p (float, optional): [description]. Defaults to 0.0. + + Returns: + torch.tensor: Network output masked by the input sequence mask. + """ + super().__init__() + self.num_layers = num_layers + + self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(num_layers): + dilation = kernel_size**i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append( + nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding) + ) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm2(channels)) + self.norms_2.append(LayerNorm2(channels)) + self.dropout = nn.Dropout(dropout_p) + + def forward(self, x, x_mask, g=None): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + """ + if g is not None: + x = x + g + for i in range(self.num_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.dropout(y) + x = x + y + return x * x_mask + + +class ElementwiseAffine(nn.Module): + """Element-wise affine transform like no-population stats BatchNorm alternative. + + Args: + channels (int): Number of input tensor channels. 
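+
+    Note (added commentary, not part of the upstream docstring): the forward pass applies
+    ``y = (x * exp(log_scale) + translation) * x_mask`` and returns the log-determinant
+    ``sum(log_scale * x_mask)``; with ``reverse=True`` the affine map is inverted and only
+    the transformed ``x`` is returned.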
+ """ + + def __init__(self, channels): + super().__init__() + self.translation = nn.Parameter(torch.zeros(channels, 1)) + self.log_scale = nn.Parameter(torch.zeros(channels, 1)) + + def forward(self, x, x_mask, reverse=False, **kwargs): # pylint: disable=unused-argument + if not reverse: + y = (x * torch.exp(self.log_scale) + self.translation) * x_mask + logdet = torch.sum(self.log_scale * x_mask, [1, 2]) + return y, logdet + x = (x - self.translation) * torch.exp(-self.log_scale) * x_mask + return x + + +class ConvFlow(nn.Module): + """Dilated depth separable convolutional based spline flow. + + Args: + in_channels (int): Number of input tensor channels. + hidden_channels (int): Number of in network channels. + kernel_size (int): Convolutional kernel size. + num_layers (int): Number of convolutional layers. + num_bins (int, optional): Number of spline bins. Defaults to 10. + tail_bound (float, optional): Tail bound for PRQT. Defaults to 5.0. + """ + + def __init__( + self, + in_channels: int, + hidden_channels: int, + kernel_size: int, + num_layers: int, + num_bins=10, + tail_bound=5.0, + ): + super().__init__() + self.num_bins = num_bins + self.tail_bound = tail_bound + self.hidden_channels = hidden_channels + self.half_channels = in_channels // 2 + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers, dropout_p=0.0) + self.proj = nn.Conv1d(hidden_channels, self.half_channels * (num_bins * 3 - 1), 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) + h = self.convs(h, x_mask, g=g) + h = self.proj(h) * x_mask + + b, c, t = x0.shape + h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] + + unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.hidden_channels) + unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.hidden_channels) + unnormalized_derivatives = h[..., 2 * self.num_bins :] + + x1, logabsdet = piecewise_rational_quadratic_transform( + x1, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=reverse, + tails="linear", + tail_bound=self.tail_bound, + ) + + x = torch.cat([x0, x1], 1) * x_mask + logdet = torch.sum(logabsdet * x_mask, [1, 2]) + if not reverse: + return x, logdet + return x + + +class StochasticDurationPredictor(nn.Module): + """Stochastic duration predictor with Spline Flows. + + It applies Variational Dequantization and Variational Data Augmentation. + + Paper: + SDP: https://arxiv.org/pdf/2106.06103.pdf + Spline Flow: https://arxiv.org/abs/1906.04032 + + :: + ## Inference + + x -> TextCondEncoder() -> Flow() -> dr_hat + noise ----------------------^ + + ## Training + |---------------------| + x -> TextCondEncoder() -> + -> PosteriorEncoder() -> split() -> z_u, z_v -> (d - z_u) -> concat() -> Flow() -> noise + d -> DurCondEncoder() -> ^ | + |------------------------------------------------------------------------------| + + Args: + in_channels (int): Number of input tensor channels. + hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size of convolutional layers. + dropout_p (float): Dropout rate. + num_flows (int, optional): Number of flow blocks. Defaults to 4. + cond_channels (int, optional): Number of channels of conditioning tensor. Defaults to 0. 
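+        language_emb_dim (int, optional): Dimension of the language embedding added to the input
+            channels (doc entry added for completeness, inferred from the constructor; not in the
+            upstream docstring). Defaults to 0.
+
+    Note (added commentary, not part of the upstream docstring): with ``reverse=False`` and
+    ground-truth durations ``dr`` the module returns the duration negative log-likelihood used as
+    a training loss; with ``reverse=True`` it runs the inverted flow on noise scaled by
+    ``noise_scale`` to produce the duration prediction (``dr_hat`` in the diagram above). A
+    minimal inference sketch, with illustrative (not prescribed) sizes, given text-encoder
+    outputs ``x: [B, C, T]`` and ``x_mask: [B, 1, T]``::
+
+        sdp = StochasticDurationPredictor(in_channels=192, hidden_channels=192,
+                                          kernel_size=3, dropout_p=0.5)
+        dr_hat = sdp(x, x_mask, reverse=True, noise_scale=1.0)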
+ """ + + def __init__( + self, + in_channels: int, + hidden_channels: int, + kernel_size: int, + dropout_p: float, + num_flows=4, + cond_channels=0, + language_emb_dim=0, + ): + super().__init__() + + # add language embedding dim in the input + if language_emb_dim: + in_channels += language_emb_dim + + # condition encoder text + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p) + self.proj = nn.Conv1d(hidden_channels, hidden_channels, 1) + + # posterior encoder + self.flows = nn.ModuleList() + self.flows.append(ElementwiseAffine(2)) + self.flows += [ConvFlow(2, hidden_channels, kernel_size, num_layers=3) for _ in range(num_flows)] + + # condition encoder duration + self.post_pre = nn.Conv1d(1, hidden_channels, 1) + self.post_convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p) + self.post_proj = nn.Conv1d(hidden_channels, hidden_channels, 1) + + # flow layers + self.post_flows = nn.ModuleList() + self.post_flows.append(ElementwiseAffine(2)) + self.post_flows += [ConvFlow(2, hidden_channels, kernel_size, num_layers=3) for _ in range(num_flows)] + + if cond_channels != 0 and cond_channels is not None: + self.cond = nn.Conv1d(cond_channels, hidden_channels, 1) + + if language_emb_dim != 0 and language_emb_dim is not None: + self.cond_lang = nn.Conv1d(language_emb_dim, hidden_channels, 1) + + def forward(self, x, x_mask, dr=None, g=None, lang_emb=None, reverse=False, noise_scale=1.0): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + - dr: :math:`[B, 1, T]` + - g: :math:`[B, C]` + """ + # condition encoder text + x = self.pre(x) + if g is not None: + x = x + self.cond(g) + + if lang_emb is not None: + x = x + self.cond_lang(lang_emb) + + x = self.convs(x, x_mask) + x = self.proj(x) * x_mask + + if not reverse: + flows = self.flows + assert dr is not None + + # condition encoder duration + h = self.post_pre(dr) + h = self.post_convs(h, x_mask) + h = self.post_proj(h) * x_mask + noise = torch.randn(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask + z_q = noise + + # posterior encoder + logdet_tot_q = 0.0 + for idx, flow in enumerate(self.post_flows): + z_q, logdet_q = flow(z_q, x_mask, g=(x + h)) + logdet_tot_q = logdet_tot_q + logdet_q + if idx > 0: + z_q = torch.flip(z_q, [1]) + + z_u, z_v = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (dr - u) * x_mask + + # posterior encoder - neg log likelihood + logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]) + nll_posterior_encoder = ( + torch.sum(-0.5 * (math.log(2 * math.pi) + (noise**2)) * x_mask, [1, 2]) - logdet_tot_q + ) + + z0 = torch.log(torch.clamp_min(z0, 1e-5)) * x_mask + logdet_tot = torch.sum(-z0, [1, 2]) + z = torch.cat([z0, z_v], 1) + + # flow layers + for idx, flow in enumerate(flows): + z, logdet = flow(z, x_mask, g=x, reverse=reverse) + logdet_tot = logdet_tot + logdet + if idx > 0: + z = torch.flip(z, [1]) + + # flow layers - neg log likelihood + nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot + return nll_flow_layers + nll_posterior_encoder + + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale + for flow in flows: + z = torch.flip(z, [1]) + z = flow(z, x_mask, g=x, reverse=reverse) + + z0, _ = 
torch.split(z, [1, 1], 1) + logw = z0 + return logw diff --git a/viXTTS/TTS/tts/layers/vits/transforms.py b/viXTTS/TTS/tts/layers/vits/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..3cac1b8d6d12fe98123ca554899978782cf3b4c5 --- /dev/null +++ b/viXTTS/TTS/tts/layers/vits/transforms.py @@ -0,0 +1,202 @@ +# adopted from https://github.com/bayesiains/nflows + +import numpy as np +import torch +from torch.nn import functional as F + +DEFAULT_MIN_BIN_WIDTH = 1e-3 +DEFAULT_MIN_BIN_HEIGHT = 1e-3 +DEFAULT_MIN_DERIVATIVE = 1e-3 + + +def piecewise_rational_quadratic_transform( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails=None, + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if tails is None: + spline_fn = rational_quadratic_spline + spline_kwargs = {} + else: + spline_fn = unconstrained_rational_quadratic_spline + spline_kwargs = {"tails": tails, "tail_bound": tail_bound} + + outputs, logabsdet = spline_fn( + inputs=inputs, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + **spline_kwargs, + ) + return outputs, logabsdet + + +def searchsorted(bin_locations, inputs, eps=1e-6): + bin_locations[..., -1] += eps + return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 + + +def unconstrained_rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails="linear", + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = torch.zeros_like(inputs) + logabsdet = torch.zeros_like(inputs) + + if tails == "linear": + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) + constant = np.log(np.exp(1 - min_derivative) - 1) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + outputs[outside_interval_mask] = inputs[outside_interval_mask] + logabsdet[outside_interval_mask] = 0 + else: + raise RuntimeError("{} tails are not implemented.".format(tails)) + + outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + unnormalized_widths=unnormalized_widths[inside_interval_mask, :], + unnormalized_heights=unnormalized_heights[inside_interval_mask, :], + unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], + inverse=inverse, + left=-tail_bound, + right=tail_bound, + bottom=-tail_bound, + top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + ) + + return outputs, logabsdet + + +def rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + left=0.0, + right=1.0, + bottom=0.0, + top=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if torch.min(inputs) < left or torch.max(inputs) > right: + raise ValueError("Input to a transform is not within its domain") + + 
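    # Inverting a rational-quadratic segment reduces to solving a quadratic
    # a * theta^2 + b * theta + c = 0 for theta in [0, 1]; the numerically stable root
    # 2c / (-b - sqrt(b^2 - 4ac)) is used in the inverse branch below, and log|det J|
    # comes from the closed-form derivative of the segment (negated in the inverse direction).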
num_bins = unnormalized_widths.shape[-1] + + if min_bin_width * num_bins > 1.0: + raise ValueError("Minimal bin width too large for the number of bins") + if min_bin_height * num_bins > 1.0: + raise ValueError("Minimal bin height too large for the number of bins") + + widths = F.softmax(unnormalized_widths, dim=-1) + widths = min_bin_width + (1 - min_bin_width * num_bins) * widths + cumwidths = torch.cumsum(widths, dim=-1) + cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) + cumwidths = (right - left) * cumwidths + left + cumwidths[..., 0] = left + cumwidths[..., -1] = right + widths = cumwidths[..., 1:] - cumwidths[..., :-1] + + derivatives = min_derivative + F.softplus(unnormalized_derivatives) + + heights = F.softmax(unnormalized_heights, dim=-1) + heights = min_bin_height + (1 - min_bin_height * num_bins) * heights + cumheights = torch.cumsum(heights, dim=-1) + cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) + cumheights = (top - bottom) * cumheights + bottom + cumheights[..., 0] = bottom + cumheights[..., -1] = top + heights = cumheights[..., 1:] - cumheights[..., :-1] + + if inverse: + bin_idx = searchsorted(cumheights, inputs)[..., None] + else: + bin_idx = searchsorted(cumwidths, inputs)[..., None] + + input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] + input_bin_widths = widths.gather(-1, bin_idx)[..., 0] + + input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] + delta = heights / widths + input_delta = delta.gather(-1, bin_idx)[..., 0] + + input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] + input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] + + input_heights = heights.gather(-1, bin_idx)[..., 0] + + if inverse: + a = (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + input_heights * (input_delta - input_derivatives) + b = input_heights * input_derivatives - (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + c = -input_delta * (inputs - input_cumheights) + + discriminant = b.pow(2) - 4 * a * c + assert (discriminant >= 0).all() + + root = (2 * c) / (-b - torch.sqrt(discriminant)) + outputs = root * input_bin_widths + input_cumwidths + + theta_one_minus_theta = root * (1 - root) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta + ) + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, -logabsdet + else: + theta = (inputs - input_cumwidths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + + numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta + ) + outputs = input_cumheights + numerator / denominator + + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, logabsdet diff --git a/viXTTS/TTS/tts/layers/xtts/__pycache__/gpt.cpython-310.pyc 
b/viXTTS/TTS/tts/layers/xtts/__pycache__/gpt.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6aa644796358d4b902005fe7625b884d27626a06 Binary files /dev/null and b/viXTTS/TTS/tts/layers/xtts/__pycache__/gpt.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/layers/xtts/__pycache__/gpt_inference.cpython-310.pyc b/viXTTS/TTS/tts/layers/xtts/__pycache__/gpt_inference.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06164822bcd24559872cc11b359cfa29a7060f0c Binary files /dev/null and b/viXTTS/TTS/tts/layers/xtts/__pycache__/gpt_inference.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/layers/xtts/__pycache__/hifigan_decoder.cpython-310.pyc b/viXTTS/TTS/tts/layers/xtts/__pycache__/hifigan_decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0756acf453ce82d3ae05114d8e455a6ac27c77f1 Binary files /dev/null and b/viXTTS/TTS/tts/layers/xtts/__pycache__/hifigan_decoder.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/layers/xtts/__pycache__/latent_encoder.cpython-310.pyc b/viXTTS/TTS/tts/layers/xtts/__pycache__/latent_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f925b658c81ec15d00948cadee500929095e1f89 Binary files /dev/null and b/viXTTS/TTS/tts/layers/xtts/__pycache__/latent_encoder.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/layers/xtts/__pycache__/perceiver_encoder.cpython-310.pyc b/viXTTS/TTS/tts/layers/xtts/__pycache__/perceiver_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6552b2ccf7a54d7ca3a576dc330a5ee6e6c6a739 Binary files /dev/null and b/viXTTS/TTS/tts/layers/xtts/__pycache__/perceiver_encoder.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/layers/xtts/__pycache__/stream_generator.cpython-310.pyc b/viXTTS/TTS/tts/layers/xtts/__pycache__/stream_generator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfc4a6b5ccb4cefa4c2496f3e15d1100733d7375 Binary files /dev/null and b/viXTTS/TTS/tts/layers/xtts/__pycache__/stream_generator.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/layers/xtts/__pycache__/tokenizer.cpython-310.pyc b/viXTTS/TTS/tts/layers/xtts/__pycache__/tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07e5f45a86fc2a97ffbbcefcd30b962198fe4736 Binary files /dev/null and b/viXTTS/TTS/tts/layers/xtts/__pycache__/tokenizer.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/layers/xtts/__pycache__/xtts_manager.cpython-310.pyc b/viXTTS/TTS/tts/layers/xtts/__pycache__/xtts_manager.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..feaabf92450d22fb41acd1f3eba2be0299ff47f9 Binary files /dev/null and b/viXTTS/TTS/tts/layers/xtts/__pycache__/xtts_manager.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/layers/xtts/__pycache__/zh_num2words.cpython-310.pyc b/viXTTS/TTS/tts/layers/xtts/__pycache__/zh_num2words.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72bdf95df7d30f307c71008d05c7949e87e4dcf8 Binary files /dev/null and b/viXTTS/TTS/tts/layers/xtts/__pycache__/zh_num2words.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/layers/xtts/dvae.py b/viXTTS/TTS/tts/layers/xtts/dvae.py new file mode 100644 index 0000000000000000000000000000000000000000..bdd7a9d09f44cc8dae102a053c365462dc416b6d --- /dev/null +++ b/viXTTS/TTS/tts/layers/xtts/dvae.py @@ -0,0 +1,393 @@ +import functools +from math import sqrt + +import torch +import 
torch.distributed as distributed +import torch.nn as nn +import torch.nn.functional as F +import torchaudio +from einops import rearrange + + +def default(val, d): + return val if val is not None else d + + +def eval_decorator(fn): + def inner(model, *args, **kwargs): + was_training = model.training + model.eval() + out = fn(model, *args, **kwargs) + model.train(was_training) + return out + + return inner + + +def dvae_wav_to_mel( + wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu") +): + mel_stft = torchaudio.transforms.MelSpectrogram( + n_fft=1024, + hop_length=256, + win_length=1024, + power=2, + normalized=False, + sample_rate=22050, + f_min=0, + f_max=8000, + n_mels=80, + norm="slaney", + ).to(device) + wav = wav.to(device) + mel = mel_stft(wav) + mel = torch.log(torch.clamp(mel, min=1e-5)) + if mel_norms is None: + mel_norms = torch.load(mel_norms_file, map_location=device) + mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1) + return mel + + +class Quantize(nn.Module): + def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, balancing_heuristic=False, new_return_order=False): + super().__init__() + + self.dim = dim + self.n_embed = n_embed + self.decay = decay + self.eps = eps + + self.balancing_heuristic = balancing_heuristic + self.codes = None + self.max_codes = 64000 + self.codes_full = False + self.new_return_order = new_return_order + + embed = torch.randn(dim, n_embed) + self.register_buffer("embed", embed) + self.register_buffer("cluster_size", torch.zeros(n_embed)) + self.register_buffer("embed_avg", embed.clone()) + + def forward(self, input, return_soft_codes=False): + if self.balancing_heuristic and self.codes_full: + h = torch.histc(self.codes, bins=self.n_embed, min=0, max=self.n_embed) / len(self.codes) + mask = torch.logical_or(h > 0.9, h < 0.01).unsqueeze(1) + ep = self.embed.permute(1, 0) + ea = self.embed_avg.permute(1, 0) + rand_embed = torch.randn_like(ep) * mask + self.embed = (ep * ~mask + rand_embed).permute(1, 0) + self.embed_avg = (ea * ~mask + rand_embed).permute(1, 0) + self.cluster_size = self.cluster_size * ~mask.squeeze() + if torch.any(mask): + print(f"Reset {torch.sum(mask)} embedding codes.") + self.codes = None + self.codes_full = False + + flatten = input.reshape(-1, self.dim) + dist = flatten.pow(2).sum(1, keepdim=True) - 2 * flatten @ self.embed + self.embed.pow(2).sum(0, keepdim=True) + soft_codes = -dist + _, embed_ind = soft_codes.max(1) + embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype) + embed_ind = embed_ind.view(*input.shape[:-1]) + quantize = self.embed_code(embed_ind) + + if self.balancing_heuristic: + if self.codes is None: + self.codes = embed_ind.flatten() + else: + self.codes = torch.cat([self.codes, embed_ind.flatten()]) + if len(self.codes) > self.max_codes: + self.codes = self.codes[-self.max_codes :] + self.codes_full = True + + if self.training: + embed_onehot_sum = embed_onehot.sum(0) + embed_sum = flatten.transpose(0, 1) @ embed_onehot + + if distributed.is_initialized() and distributed.get_world_size() > 1: + distributed.all_reduce(embed_onehot_sum) + distributed.all_reduce(embed_sum) + + self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay) + self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay) + n = self.cluster_size.sum() + cluster_size = (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n + embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) + 
self.embed.data.copy_(embed_normalized) + + diff = (quantize.detach() - input).pow(2).mean() + quantize = input + (quantize - input).detach() + + if return_soft_codes: + return quantize, diff, embed_ind, soft_codes.view(input.shape[:-1] + (-1,)) + elif self.new_return_order: + return quantize, embed_ind, diff + else: + return quantize, diff, embed_ind + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.embed.transpose(0, 1)) + + +# Fits a soft-discretized input to a normal-PDF across the specified dimension. +# In other words, attempts to force the discretization function to have a mean equal utilization across all discrete +# values with the specified expected variance. +class DiscretizationLoss(nn.Module): + def __init__(self, discrete_bins, dim, expected_variance, store_past=0): + super().__init__() + self.discrete_bins = discrete_bins + self.dim = dim + self.dist = torch.distributions.Normal(0, scale=expected_variance) + if store_past > 0: + self.record_past = True + self.register_buffer("accumulator_index", torch.zeros(1, dtype=torch.long, device="cpu")) + self.register_buffer("accumulator_filled", torch.zeros(1, dtype=torch.long, device="cpu")) + self.register_buffer("accumulator", torch.zeros(store_past, discrete_bins)) + else: + self.record_past = False + + def forward(self, x): + other_dims = set(range(len(x.shape))) - set([self.dim]) + averaged = x.sum(dim=tuple(other_dims)) / x.sum() + averaged = averaged - averaged.mean() + + if self.record_past: + acc_count = self.accumulator.shape[0] + avg = averaged.detach().clone() + if self.accumulator_filled > 0: + averaged = torch.mean(self.accumulator, dim=0) * (acc_count - 1) / acc_count + averaged / acc_count + + # Also push averaged into the accumulator. + self.accumulator[self.accumulator_index] = avg + self.accumulator_index += 1 + if self.accumulator_index >= acc_count: + self.accumulator_index *= 0 + if self.accumulator_filled <= 0: + self.accumulator_filled += 1 + + return torch.sum(-self.dist.log_prob(averaged)) + + +class ResBlock(nn.Module): + def __init__(self, chan, conv, activation): + super().__init__() + self.net = nn.Sequential( + conv(chan, chan, 3, padding=1), + activation(), + conv(chan, chan, 3, padding=1), + activation(), + conv(chan, chan, 1), + ) + + def forward(self, x): + return self.net(x) + x + + +class UpsampledConv(nn.Module): + def __init__(self, conv, *args, **kwargs): + super().__init__() + assert "stride" in kwargs.keys() + self.stride = kwargs["stride"] + del kwargs["stride"] + self.conv = conv(*args, **kwargs) + + def forward(self, x): + up = nn.functional.interpolate(x, scale_factor=self.stride, mode="nearest") + return self.conv(up) + + +# DiscreteVAE partially derived from lucidrains DALLE implementation +# Credit: https://github.com/lucidrains/DALLE-pytorch +class DiscreteVAE(nn.Module): + def __init__( + self, + positional_dims=2, + num_tokens=512, + codebook_dim=512, + num_layers=3, + num_resnet_blocks=0, + hidden_dim=64, + channels=3, + stride=2, + kernel_size=4, + use_transposed_convs=True, + encoder_norm=False, + activation="relu", + smooth_l1_loss=False, + straight_through=False, + normalization=None, # ((0.5,) * 3, (0.5,) * 3), + record_codes=False, + discretization_loss_averaging_steps=100, + lr_quantizer_args={}, + ): + super().__init__() + has_resblocks = num_resnet_blocks > 0 + + self.num_tokens = num_tokens + self.num_layers = num_layers + self.straight_through = straight_through + self.positional_dims = positional_dims + self.discrete_loss = DiscretizationLoss( + 
num_tokens, 2, 1 / (num_tokens * 2), discretization_loss_averaging_steps + ) + + assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now. + if positional_dims == 2: + conv = nn.Conv2d + conv_transpose = nn.ConvTranspose2d + else: + conv = nn.Conv1d + conv_transpose = nn.ConvTranspose1d + if not use_transposed_convs: + conv_transpose = functools.partial(UpsampledConv, conv) + + if activation == "relu": + act = nn.ReLU + elif activation == "silu": + act = nn.SiLU + else: + assert NotImplementedError() + + enc_layers = [] + dec_layers = [] + + if num_layers > 0: + enc_chans = [hidden_dim * 2**i for i in range(num_layers)] + dec_chans = list(reversed(enc_chans)) + + enc_chans = [channels, *enc_chans] + + dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0] + dec_chans = [dec_init_chan, *dec_chans] + + enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans)) + + pad = (kernel_size - 1) // 2 + for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io): + enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride=stride, padding=pad), act())) + if encoder_norm: + enc_layers.append(nn.GroupNorm(8, enc_out)) + dec_layers.append( + nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride=stride, padding=pad), act()) + ) + dec_out_chans = dec_chans[-1] + innermost_dim = dec_chans[0] + else: + enc_layers.append(nn.Sequential(conv(channels, hidden_dim, 1), act())) + dec_out_chans = hidden_dim + innermost_dim = hidden_dim + + for _ in range(num_resnet_blocks): + dec_layers.insert(0, ResBlock(innermost_dim, conv, act)) + enc_layers.append(ResBlock(innermost_dim, conv, act)) + + if num_resnet_blocks > 0: + dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1)) + + enc_layers.append(conv(innermost_dim, codebook_dim, 1)) + dec_layers.append(conv(dec_out_chans, channels, 1)) + + self.encoder = nn.Sequential(*enc_layers) + self.decoder = nn.Sequential(*dec_layers) + + self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss + self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True) + + # take care of normalization within class + self.normalization = normalization + self.record_codes = record_codes + if record_codes: + self.codes = torch.zeros((1228800,), dtype=torch.long) + self.code_ind = 0 + self.total_codes = 0 + self.internal_step = 0 + + def norm(self, images): + if not self.normalization is not None: + return images + + means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization) + arrange = "c -> () c () ()" if self.positional_dims == 2 else "c -> () c ()" + means, stds = map(lambda t: rearrange(t, arrange), (means, stds)) + images = images.clone() + images.sub_(means).div_(stds) + return images + + def get_debug_values(self, step, __): + if self.record_codes and self.total_codes > 0: + # Report annealing schedule + return {"histogram_codes": self.codes[: self.total_codes]} + else: + return {} + + @torch.no_grad() + @eval_decorator + def get_codebook_indices(self, images): + img = self.norm(images) + logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1)) + sampled, codes, _ = self.codebook(logits) + self.log_codes(codes) + return codes + + def decode(self, img_seq): + self.log_codes(img_seq) + if hasattr(self.codebook, "embed_code"): + image_embeds = self.codebook.embed_code(img_seq) + else: + image_embeds = F.embedding(img_seq, self.codebook.codebook) + b, n, d = image_embeds.shape + + 
kwargs = {} + if self.positional_dims == 1: + arrange = "b n d -> b d n" + else: + h = w = int(sqrt(n)) + arrange = "b (h w) d -> b d h w" + kwargs = {"h": h, "w": w} + image_embeds = rearrange(image_embeds, arrange, **kwargs) + images = [image_embeds] + for layer in self.decoder: + images.append(layer(images[-1])) + return images[-1], images[-2] + + def infer(self, img): + img = self.norm(img) + logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1)) + sampled, codes, commitment_loss = self.codebook(logits) + return self.decode(codes) + + # Note: This module is not meant to be run in forward() except while training. It has special logic which performs + # evaluation using quantized values when it detects that it is being run in eval() mode, which will be substantially + # more lossy (but useful for determining network performance). + def forward(self, img): + img = self.norm(img) + logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1)) + sampled, codes, commitment_loss = self.codebook(logits) + sampled = sampled.permute((0, 3, 1, 2) if len(img.shape) == 4 else (0, 2, 1)) + + if self.training: + out = sampled + for d in self.decoder: + out = d(out) + self.log_codes(codes) + else: + # This is non-differentiable, but gives a better idea of how the network is actually performing. + out, _ = self.decode(codes) + + # reconstruction loss + recon_loss = self.loss_fn(img, out, reduction="none") + + return recon_loss, commitment_loss, out + + def log_codes(self, codes): + # This is so we can debug the distribution of codes being learned. + if self.record_codes and self.internal_step % 10 == 0: + codes = codes.flatten() + l = codes.shape[0] + i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l + self.codes[i : i + l] = codes.cpu() + self.code_ind = self.code_ind + l + if self.code_ind >= self.codes.shape[0]: + self.code_ind = 0 + self.total_codes += 1 + self.internal_step += 1 diff --git a/viXTTS/TTS/tts/layers/xtts/gpt.py b/viXTTS/TTS/tts/layers/xtts/gpt.py new file mode 100644 index 0000000000000000000000000000000000000000..e7b186b858a4dd8baec24a7214ae7bc573097fed --- /dev/null +++ b/viXTTS/TTS/tts/layers/xtts/gpt.py @@ -0,0 +1,611 @@ +# ported from: https://github.com/neonbjb/tortoise-tts + +import functools +import math +import random + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import GPT2Config + +from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel +from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder +from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler + + +def null_position_embeddings(range, dim): + return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device) + + +class LearnedPositionEmbeddings(nn.Module): + def __init__(self, seq_len, model_dim, init=0.02, relative=False): + super().__init__() + # nn.Embedding + self.emb = torch.nn.Embedding(seq_len, model_dim) + # Initializing this way is standard for GPT-2 + self.emb.weight.data.normal_(mean=0.0, std=init) + self.relative = relative + self.seq_len = seq_len + + def forward(self, x): + sl = x.shape[1] + if self.relative: + start = random.randint(sl, self.seq_len) - sl + return self.emb(torch.arange(start, start + sl, device=x.device)) + else: + return self.emb(torch.arange(0, sl, device=x.device)) + + def get_fixed_embedding(self, ind, dev): + return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0) + + +def 
build_hf_gpt_transformer( + layers, + model_dim, + heads, + max_mel_seq_len, + max_text_seq_len, + max_prompt_len, + checkpointing, +): + """ + GPT-2 implemented by the HuggingFace library. + """ + from transformers import GPT2Config, GPT2Model + + gpt_config = GPT2Config( + vocab_size=256, # Unused. + n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len, + n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len, + n_embd=model_dim, + n_layer=layers, + n_head=heads, + gradient_checkpointing=checkpointing, + use_cache=not checkpointing, + ) + gpt = GPT2Model(gpt_config) + # Override the built in positional embeddings + del gpt.wpe + gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) + # Built-in token embeddings are unused. + del gpt.wte + + mel_pos_emb = ( + LearnedPositionEmbeddings(max_mel_seq_len, model_dim) + if max_mel_seq_len != -1 + else functools.partial(null_position_embeddings, dim=model_dim) + ) + text_pos_emb = ( + LearnedPositionEmbeddings(max_text_seq_len, model_dim) + if max_mel_seq_len != -1 + else functools.partial(null_position_embeddings, dim=model_dim) + ) + # gpt = torch.compile(gpt, mode="reduce-overhead", fullgraph=True) + return gpt, mel_pos_emb, text_pos_emb, None, None + + +class GPT(nn.Module): + def __init__( + self, + start_text_token=261, + stop_text_token=0, + layers=8, + model_dim=512, + heads=8, + max_text_tokens=120, + max_mel_tokens=250, + max_prompt_tokens=70, + max_conditioning_inputs=1, + code_stride_len=1024, + number_text_tokens=256, + num_audio_tokens=8194, + start_audio_token=8192, + stop_audio_token=8193, + train_solo_embeddings=False, + checkpointing=False, + average_conditioning_embeddings=False, + label_smoothing=0.0, + use_perceiver_resampler=False, + perceiver_cond_length_compression=256, + ): + """ + Args: + + """ + super().__init__() + + self.label_smoothing = label_smoothing + self.number_text_tokens = number_text_tokens + self.start_text_token = start_text_token + self.stop_text_token = stop_text_token + self.num_audio_tokens = num_audio_tokens + self.start_audio_token = start_audio_token + self.stop_audio_token = stop_audio_token + self.start_prompt_token = start_audio_token + self.stop_prompt_token = stop_audio_token + self.layers = layers + self.heads = heads + self.model_dim = model_dim + self.max_conditioning_inputs = max_conditioning_inputs + self.max_gen_mel_tokens = max_mel_tokens - self.max_conditioning_inputs - 2 + self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens + 2 + self.max_conditioning_inputs + self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2 + self.max_prompt_tokens = max_prompt_tokens + self.code_stride_len = code_stride_len + self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) + self.conditioning_dropout = nn.Dropout1d(0.1) + self.average_conditioning_embeddings = average_conditioning_embeddings + self.use_perceiver_resampler = use_perceiver_resampler + self.perceiver_cond_length_compression = perceiver_cond_length_compression + + self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim) + self.mel_embedding = nn.Embedding(self.num_audio_tokens, model_dim) + + ( + self.gpt, + self.mel_pos_embedding, + self.text_pos_embedding, + self.mel_layer_pos_embedding, + self.text_layer_pos_embedding, + ) = build_hf_gpt_transformer( + layers, + model_dim, + heads, + self.max_mel_tokens, + self.max_text_tokens, + self.max_prompt_tokens, + checkpointing, + ) + if train_solo_embeddings: + self.mel_solo_embedding = 
nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True) + self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True) + else: + self.mel_solo_embedding = 0 + self.text_solo_embedding = 0 + + self.final_norm = nn.LayerNorm(model_dim) + self.text_head = nn.Linear(model_dim, self.number_text_tokens) + self.mel_head = nn.Linear(model_dim, self.num_audio_tokens) + + if self.use_perceiver_resampler: + # XTTS v2 + self.conditioning_perceiver = PerceiverResampler( + dim=model_dim, + depth=2, + dim_context=model_dim, + num_latents=32, + dim_head=64, + heads=8, + ff_mult=4, + use_flash_attn=False, + ) + else: + # XTTS v1 + self.prompt_embedding = nn.Embedding(self.num_audio_tokens, model_dim) + self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, model_dim) + + def get_grad_norm_parameter_groups(self): + return { + "conditioning_encoder": list(self.conditioning_encoder.parameters()), + "conditioning_perceiver": list(self.conditioning_perceiver.parameters()) + if self.use_perceiver_resampler + else None, + "gpt": list(self.gpt.parameters()), + "heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()), + } + + def init_gpt_for_inference(self, kv_cache=True, use_deepspeed=False): + seq_length = self.max_prompt_tokens + self.max_mel_tokens + self.max_text_tokens + 1 + gpt_config = GPT2Config( + vocab_size=self.max_mel_tokens, + n_positions=seq_length, + n_ctx=seq_length, + n_embd=self.model_dim, + n_layer=self.layers, + n_head=self.heads, + gradient_checkpointing=False, + use_cache=True, + ) + self.gpt_inference = GPT2InferenceModel( + gpt_config, + self.gpt, + self.mel_pos_embedding, + self.mel_embedding, + self.final_norm, + self.mel_head, + kv_cache=kv_cache, + ) + self.gpt.wte = self.mel_embedding + + if use_deepspeed: + import deepspeed + + self.ds_engine = deepspeed.init_inference( + model=self.gpt_inference.half(), # Transformers models + mp_size=1, # Number of GPU + dtype=torch.float32, # desired data type of output + replace_method="auto", # Lets DS autmatically identify the layer to replace + replace_with_kernel_inject=True, # replace the model with the kernel injector + ) + self.gpt_inference = self.ds_engine.module.eval() + + def set_inputs_and_targets(self, input, start_token, stop_token): + inp = F.pad(input, (1, 0), value=start_token) + tar = F.pad(input, (0, 1), value=stop_token) + return inp, tar + + def set_mel_padding(self, mel_input_tokens, code_lengths): + """ + Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in + that audio clip, reformats the tokens with stop_audio_token in place of the zero padding. This is required + preformatting to create a working TTS model. + """ + # Set padding areas within MEL (currently it is coded with the MEL code for ). 
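        # Everything beyond each item's true code length is overwritten with stop_audio_token;
        # note that `mel_input_tokens` is modified in place and also returned.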
+ for b in range(len(code_lengths)): + actual_end = code_lengths[b] + if actual_end < mel_input_tokens.shape[-1]: + mel_input_tokens[b, actual_end:] = self.stop_audio_token + return mel_input_tokens + + def get_logits( + self, + first_inputs, + first_head, + second_inputs=None, + second_head=None, + prompt=None, + get_attns=False, + return_latent=False, + attn_mask_cond=None, + attn_mask_text=None, + attn_mask_mel=None, + ): + if prompt is not None: + offset = prompt.shape[1] + if second_inputs is not None: + emb = torch.cat([prompt, first_inputs, second_inputs], dim=1) + else: + emb = torch.cat([prompt, first_inputs], dim=1) + + # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): + attn_mask = None + if attn_mask_text is not None: + attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1) + if prompt is not None: + attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device) + attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1) + + gpt_out = self.gpt( + inputs_embeds=emb, + return_dict=True, + output_attentions=get_attns, + attention_mask=attn_mask, + ) + + if get_attns: + return gpt_out.attentions + + enc = gpt_out.last_hidden_state[:, offset:] + enc = self.final_norm(enc) + + if return_latent: + return enc[:, : first_inputs.shape[1]], enc[:, -second_inputs.shape[1] :] + + first_logits = enc[:, : first_inputs.shape[1]] + first_logits = first_head(first_logits) + first_logits = first_logits.permute(0, 2, 1) + if second_inputs is not None: + second_logits = enc[:, -second_inputs.shape[1] :] + second_logits = second_head(second_logits) + second_logits = second_logits.permute(0, 2, 1) + return first_logits, second_logits + else: + return first_logits + + def get_conditioning(self, speech_conditioning_input): + speech_conditioning_input = ( + speech_conditioning_input.unsqueeze(1) + if len(speech_conditioning_input.shape) == 3 + else speech_conditioning_input + ) + conds = [] + for j in range(speech_conditioning_input.shape[1]): + conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds = torch.stack(conds, dim=1) + conds = conds.mean(dim=1) + return conds + + def get_prompts(self, prompt_codes): + """ + Create a prompt from the mel codes. This is used to condition the model on the mel codes. + Pad the prompt with start and stop mel tokens. 
+ """ + prompt = prompt_codes + if self.training: + lengths = [] + # Compute the real prompt length based on the first encounter with the token 83 used for padding + for i in range(prompt_codes.shape[0]): + length = 0 + for j in range(prompt_codes.shape[1]): + if prompt_codes[i, j] == 83: + break + else: + length += 1 + lengths.append(length) + + # prompt_len = random.randint(1, 9) # in secs + prompt_len = 3 + prompt_len = prompt_len * 24 # in frames + if prompt_codes.shape[-1] >= prompt_len: + for i in range(prompt_codes.shape[0]): + if lengths[i] < prompt_len: + start = 0 + else: + start = random.randint(0, lengths[i] - prompt_len) + prompt = prompt_codes[:, start : start + prompt_len] + + # add start and stop tokens + prompt = F.pad(prompt, (1, 0), value=self.start_prompt_token) + prompt = F.pad(prompt, (0, 1), value=self.stop_prompt_token) + return prompt + + def get_style_emb(self, cond_input, return_latent=False): + """ + cond_input: (b, 80, s) or (b, 1, 80, s) + conds: (b, 1024, s) + """ + conds = None + if not return_latent: + if cond_input.ndim == 4: + cond_input = cond_input.squeeze(1) + conds = self.conditioning_encoder(cond_input) # (b, d, s) + if self.use_perceiver_resampler: + conds = self.conditioning_perceiver(conds.permute(0, 2, 1)).transpose(1, 2) # (b, d, 32) + else: + # already computed + conds = cond_input.unsqueeze(1) + return conds + + def forward( + self, + text_inputs, + text_lengths, + audio_codes, + wav_lengths, + cond_mels=None, + cond_idxs=None, + cond_lens=None, + cond_latents=None, + return_attentions=False, + return_latent=False, + ): + """ + Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode + (actuated by `text_first`). + + text_inputs: long tensor, (b,t) + text_lengths: long tensor, (b,) + mel_inputs: long tensor, (b,m) + wav_lengths: long tensor, (b,) + cond_mels: MEL float tensor, (b, 1, 80,s) + cond_idxs: cond start and end indexs, (b, 2) + + If return_attentions is specified, only logits are returned. + If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned. + """ + # ❗ FIXIT + if self.max_conditioning_inputs == 0: + assert cond_mels is None, " ❗ cond_mels is not None, but max_conditioning_inputs == 0" + + max_text_len = text_lengths.max() + code_lengths = torch.ceil(wav_lengths / self.code_stride_len).long() + 3 + + if cond_lens is not None: + if self.use_perceiver_resampler: + cond_lens = cond_lens // self.perceiver_cond_length_compression + else: + cond_lens = cond_lens // self.code_stride_len + + if cond_idxs is not None: + # recompute cond idxs for mel lengths + for idx in range(cond_idxs.size(0)): + if self.use_perceiver_resampler: + cond_idxs[idx] = cond_idxs[idx] // self.perceiver_cond_length_compression + else: + cond_idxs[idx] = cond_idxs[idx] // self.code_stride_len + + # ensure that the cond_mel does not have padding + # if cond_lens is not None and cond_idxs is None: + # min_cond_len = torch.min(cond_lens) + # cond_mels = cond_mels[:, :, :, :min_cond_len] + + # If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes. 
+ max_mel_len = code_lengths.max() + + if max_mel_len > audio_codes.shape[-1]: + audio_codes = F.pad(audio_codes, (0, max_mel_len - audio_codes.shape[-1])) + + # 💖 Lovely assertions + assert ( + max_mel_len <= audio_codes.shape[-1] + ), f" ❗ max_mel_len ({max_mel_len}) > audio_codes.shape[-1] ({audio_codes.shape[-1]})" + assert ( + max_text_len <= text_inputs.shape[-1] + ), f" ❗ max_text_len ({max_text_len}) > text_inputs.shape[-1] ({text_inputs.shape[-1]})" + + # Append stop token to text inputs + text_inputs = F.pad(text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token) + + # Append silence token to mel codes + audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_audio_token) + + # Pad mel codes with stop_audio_token + audio_codes = self.set_mel_padding( + audio_codes, code_lengths - 3 + ) # -3 to get the real code lengths without consider start and stop tokens that was not added yet + + # Build input and target tensors + # Prepend start token to inputs and append stop token to targets + text_inputs, text_targets = self.set_inputs_and_targets( + text_inputs, self.start_text_token, self.stop_text_token + ) + audio_codes, mel_targets = self.set_inputs_and_targets( + audio_codes, self.start_audio_token, self.stop_audio_token + ) + + # Set attn_mask + attn_mask_cond = None + attn_mask_text = None + attn_mask_mel = None + if not return_latent: + attn_mask_cond = torch.ones( + cond_mels.shape[0], + cond_mels.shape[-1], + dtype=torch.bool, + device=text_inputs.device, + ) + attn_mask_text = torch.ones( + text_inputs.shape[0], + text_inputs.shape[1], + dtype=torch.bool, + device=text_inputs.device, + ) + attn_mask_mel = torch.ones( + audio_codes.shape[0], + audio_codes.shape[1], + dtype=torch.bool, + device=audio_codes.device, + ) + + if cond_idxs is not None: + # use masking approach + for idx, r in enumerate(cond_idxs): + l = r[1] - r[0] + attn_mask_cond[idx, l:] = 0.0 + elif cond_lens is not None: + for idx, l in enumerate(cond_lens): + attn_mask_cond[idx, l:] = 0.0 + + for idx, l in enumerate(text_lengths): + attn_mask_text[idx, l + 1 :] = 0.0 + + for idx, l in enumerate(code_lengths): + attn_mask_mel[idx, l + 1 :] = 0.0 + + # Compute text embeddings + positional embeddings + text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + + # Compute mel embeddings + positional embeddings + mel_emb = self.mel_embedding(audio_codes) + self.mel_pos_embedding(audio_codes) + + # Compute speech conditioning input + if cond_latents is None: + cond_latents = self.get_style_emb(cond_mels).transpose(1, 2) + + # Get logits + sub = -5 # don't ask me why 😄 + if self.training: + sub = -1 + + text_logits, mel_logits = self.get_logits( + text_emb, + self.text_head, + mel_emb, + self.mel_head, + prompt=cond_latents, + get_attns=return_attentions, + return_latent=return_latent, + attn_mask_cond=attn_mask_cond, + attn_mask_text=attn_mask_text, + attn_mask_mel=attn_mask_mel, + ) + if return_latent: + return mel_logits[:, :sub] # sub to prevent bla. + + if return_attentions: + return mel_logits + + # Set paddings to -1 to ignore them in loss + for idx, l in enumerate(text_lengths): + text_targets[idx, l + 1 :] = -1 + + for idx, l in enumerate(code_lengths): + mel_targets[idx, l + 1 :] = -1 + + # check if stoptoken is in every row of mel_targets + assert (mel_targets == self.stop_audio_token).sum() >= mel_targets.shape[ + 0 + ], f" ❗ mel_targets does not contain stop token ({self.stop_audio_token}) in every row." 
+ + # ignore the loss for the segment used for conditioning + # coin flip for the segment to be ignored + if cond_idxs is not None: + cond_start = cond_idxs[idx, 0] + cond_end = cond_idxs[idx, 1] + mel_targets[idx, cond_start:cond_end] = -1 + + # Compute losses + loss_text = F.cross_entropy( + text_logits, text_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing + ) + loss_mel = F.cross_entropy( + mel_logits, mel_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing + ) + return loss_text.mean(), loss_mel.mean(), mel_logits + + def inference(self, cond_latents, text_inputs, **hf_generate_kwargs): + self.compute_embeddings(cond_latents, text_inputs) + return self.generate(cond_latents, text_inputs, **hf_generate_kwargs) + + def compute_embeddings( + self, + cond_latents, + text_inputs, + ): + text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) + text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token) + emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + emb = torch.cat([cond_latents, emb], dim=1) + self.gpt_inference.store_prefix_emb(emb) + gpt_inputs = torch.full( + ( + emb.shape[0], + emb.shape[1] + 1, # +1 for the start_audio_token + ), + fill_value=1, + dtype=torch.long, + device=text_inputs.device, + ) + gpt_inputs[:, -1] = self.start_audio_token + return gpt_inputs + + def generate( + self, + cond_latents, + text_inputs, + **hf_generate_kwargs, + ): + gpt_inputs = self.compute_embeddings(cond_latents, text_inputs) + gen = self.gpt_inference.generate( + gpt_inputs, + bos_token_id=self.start_audio_token, + pad_token_id=self.stop_audio_token, + eos_token_id=self.stop_audio_token, + max_length=self.max_gen_mel_tokens + gpt_inputs.shape[-1], + **hf_generate_kwargs, + ) + if "return_dict_in_generate" in hf_generate_kwargs: + return gen.sequences[:, gpt_inputs.shape[1] :], gen + return gen[:, gpt_inputs.shape[1] :] + + def get_generator(self, fake_inputs, **hf_generate_kwargs): + return self.gpt_inference.generate_stream( + fake_inputs, + bos_token_id=self.start_audio_token, + pad_token_id=self.stop_audio_token, + eos_token_id=self.stop_audio_token, + max_length=self.max_gen_mel_tokens + fake_inputs.shape[-1], + do_stream=True, + **hf_generate_kwargs, + ) diff --git a/viXTTS/TTS/tts/layers/xtts/gpt_inference.py b/viXTTS/TTS/tts/layers/xtts/gpt_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..d44bd3decd2eb14a5bed14e5d2a8232386ef7076 --- /dev/null +++ b/viXTTS/TTS/tts/layers/xtts/gpt_inference.py @@ -0,0 +1,136 @@ +import math + +import torch +from torch import nn +from transformers import GPT2PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions + + +class GPT2InferenceModel(GPT2PreTrainedModel): + """Override GPT2LMHeadModel to allow for prefix conditioning.""" + + def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache): + super().__init__(config) + self.transformer = gpt + self.pos_embedding = pos_emb + self.embeddings = embeddings + self.final_norm = norm + self.lm_head = nn.Sequential(norm, linear) + self.kv_cache = kv_cache + + def store_prefix_emb(self, prefix_emb): + self.cached_prefix_emb = prefix_emb + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) # usually None + if not self.kv_cache: + past_key_values = None + + # only last token for inputs_ids if past is defined in kwargs + if past_key_values is not None: 
+ input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values is not None: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + assert self.cached_prefix_emb is not None + assert inputs_embeds is None # Not supported by this inference model. + assert labels is None # Training not supported by this inference model. + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # assert len(past_key_values) + len(input_ids) == attention_mask.shape[1] + + # Create embedding + prefix_len = self.cached_prefix_emb.shape[1] + if input_ids.shape[1] != 1: + gen_inputs = input_ids[:, prefix_len:] + gen_emb = self.embeddings(gen_inputs) + gen_emb = gen_emb + self.pos_embedding(gen_emb) + if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]: + prefix_emb = self.cached_prefix_emb.repeat_interleave( + gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0 + ) + else: + prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype) + emb = torch.cat([prefix_emb, gen_emb], dim=1) + else: + emb = self.embeddings(input_ids) + emb = emb + self.pos_embedding.get_fixed_embedding( + attention_mask.shape[1] - (prefix_len + 1), attention_mask.device + ) + transformer_outputs = self.transformer( + inputs_embeds=emb, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + return (lm_logits,) + transformer_outputs[1:] + + return CausalLMOutputWithCrossAttentions( + loss=None, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache(past, beam_idx): + """ + This function is used to re-order the :obj:`past_key_values` cache if + :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is + called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. 
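
        A sketch of the underlying operation for a single layer (added for illustration; shapes
        are arbitrary)::

            import torch

            past_state = torch.randn(4, 2, 10, 64)       # [batch * beams, heads, seq, head_dim]
            beam_idx = torch.tensor([2, 2, 0, 1])        # beams kept at this generation step
            reordered = past_state.index_select(0, beam_idx)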
+ """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) diff --git a/viXTTS/TTS/tts/layers/xtts/hifigan_decoder.py b/viXTTS/TTS/tts/layers/xtts/hifigan_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..9add7826e694e40296493b0830a6920ca9b36f1a --- /dev/null +++ b/viXTTS/TTS/tts/layers/xtts/hifigan_decoder.py @@ -0,0 +1,732 @@ +import torch +import torchaudio +from torch import nn +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn import functional as F +from torch.nn.utils.parametrizations import weight_norm +from torch.nn.utils.parametrize import remove_parametrizations + +from TTS.utils.io import load_fsspec + +LRELU_SLOPE = 0.1 + + +def get_padding(k, d): + return int((k * d - d) / 2) + + +class ResBlock1(torch.nn.Module): + """Residual Block Type 1. It has 3 convolutional layers in each convolutional block. + + Network:: + + x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o + |--------------------------------------------------------------------------------------------------| + + + Args: + channels (int): number of hidden channels for the convolutional layers. + kernel_size (int): size of the convolution filter in each layer. + dilations (list): list of dilation value for each conv layer in a block. + """ + + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super().__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + + def forward(self, x): + """ + Args: + x (Tensor): input tensor. + Returns: + Tensor: output tensor. + Shapes: + x: [B, C, T] + """ + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_parametrizations(l, "weight") + for l in self.convs2: + remove_parametrizations(l, "weight") + + +class ResBlock2(torch.nn.Module): + """Residual Block Type 2. It has 1 convolutional layers in each convolutional block. + + Network:: + + x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o + |---------------------------------------------------| + + + Args: + channels (int): number of hidden channels for the convolutional layers. + kernel_size (int): size of the convolution filter in each layer. + dilations (list): list of dilation value for each conv layer in a block. 
+ """ + + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super().__init__() + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_parametrizations(l, "weight") + + +class HifiganGenerator(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + resblock_type, + resblock_dilation_sizes, + resblock_kernel_sizes, + upsample_kernel_sizes, + upsample_initial_channel, + upsample_factors, + inference_padding=5, + cond_channels=0, + conv_pre_weight_norm=True, + conv_post_weight_norm=True, + conv_post_bias=True, + cond_in_each_up_layer=False, + ): + r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF) + + Network: + x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o + .. -> zI ---| + resblockN_kNx1 -> zN ---' + + Args: + in_channels (int): number of input tensor channels. + out_channels (int): number of output tensor channels. + resblock_type (str): type of the `ResBlock`. '1' or '2'. + resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`. + resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`. + upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution. + upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2 + for each consecutive upsampling layer. + upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer. + inference_padding (int): constant padding applied to the input at inference time. Defaults to 5. 
+ """ + super().__init__() + self.inference_padding = inference_padding + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_factors) + self.cond_in_each_up_layer = cond_in_each_up_layer + + # initial upsampling layers + self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if resblock_type == "1" else ResBlock2 + # upsampling layers + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + # MRF blocks + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + # post convolution layer + self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)) + if cond_channels > 0: + self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1) + + if not conv_pre_weight_norm: + remove_parametrizations(self.conv_pre, "weight") + + if not conv_post_weight_norm: + remove_parametrizations(self.conv_post, "weight") + + if self.cond_in_each_up_layer: + self.conds = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + self.conds.append(nn.Conv1d(cond_channels, ch, 1)) + + def forward(self, x, g=None): + """ + Args: + x (Tensor): feature input tensor. + g (Tensor): global conditioning input tensor. + + Returns: + Tensor: output waveform. + + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + o = self.conv_pre(x) + if hasattr(self, "cond_layer"): + o = o + self.cond_layer(g) + for i in range(self.num_upsamples): + o = F.leaky_relu(o, LRELU_SLOPE) + o = self.ups[i](o) + + if self.cond_in_each_up_layer: + o = o + self.conds[i](g) + + z_sum = None + for j in range(self.num_kernels): + if z_sum is None: + z_sum = self.resblocks[i * self.num_kernels + j](o) + else: + z_sum += self.resblocks[i * self.num_kernels + j](o) + o = z_sum / self.num_kernels + o = F.leaky_relu(o) + o = self.conv_post(o) + o = torch.tanh(o) + return o + + @torch.no_grad() + def inference(self, c): + """ + Args: + x (Tensor): conditioning input tensor. + + Returns: + Tensor: output waveform. 
+ + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + c = c.to(self.conv_pre.weight.device) + c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate") + return self.forward(c) + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_parametrizations(l, "weight") + for l in self.resblocks: + l.remove_weight_norm() + remove_parametrizations(self.conv_pre, "weight") + remove_parametrizations(self.conv_post, "weight") + + def load_checkpoint( + self, config, checkpoint_path, eval=False, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + self.remove_weight_norm() + + +class SELayer(nn.Module): + def __init__(self, channel, reduction=8): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), + nn.Sigmoid(), + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +class SEBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8): + super(SEBasicBlock, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.se = SELayer(planes, reduction) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.relu(out) + out = self.bn1(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + return out + + +def set_init_dict(model_dict, checkpoint_state, c): + # Partial initialization: if there is a mismatch with new and old layer, it is skipped. + for k, v in checkpoint_state.items(): + if k not in model_dict: + print(" | > Layer missing in the model definition: {}".format(k)) + # 1. filter out unnecessary keys + pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict} + # 2. filter out different size layers + pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()} + # 3. skip reinit layers + if c.has("reinit_layers") and c.reinit_layers is not None: + for reinit_layer_name in c.reinit_layers: + pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} + # 4. 
overwrite entries in the existing state dict + model_dict.update(pretrained_dict) + print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict))) + return model_dict + + +class PreEmphasis(nn.Module): + def __init__(self, coefficient=0.97): + super().__init__() + self.coefficient = coefficient + self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0)) + + def forward(self, x): + assert len(x.size()) == 2 + + x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect") + return torch.nn.functional.conv1d(x, self.filter).squeeze(1) + + +class ResNetSpeakerEncoder(nn.Module): + """This is copied from 🐸TTS to remove it from the dependencies.""" + + # pylint: disable=W0102 + def __init__( + self, + input_dim=64, + proj_dim=512, + layers=[3, 4, 6, 3], + num_filters=[32, 64, 128, 256], + encoder_type="ASP", + log_input=False, + use_torch_spec=False, + audio_config=None, + ): + super(ResNetSpeakerEncoder, self).__init__() + + self.encoder_type = encoder_type + self.input_dim = input_dim + self.log_input = log_input + self.use_torch_spec = use_torch_spec + self.audio_config = audio_config + self.proj_dim = proj_dim + + self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1) + self.relu = nn.ReLU(inplace=True) + self.bn1 = nn.BatchNorm2d(num_filters[0]) + + self.inplanes = num_filters[0] + self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0]) + self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2)) + self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2)) + self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2)) + + self.instancenorm = nn.InstanceNorm1d(input_dim) + + if self.use_torch_spec: + self.torch_spec = torch.nn.Sequential( + PreEmphasis(audio_config["preemphasis"]), + torchaudio.transforms.MelSpectrogram( + sample_rate=audio_config["sample_rate"], + n_fft=audio_config["fft_size"], + win_length=audio_config["win_length"], + hop_length=audio_config["hop_length"], + window_fn=torch.hamming_window, + n_mels=audio_config["num_mels"], + ), + ) + + else: + self.torch_spec = None + + outmap_size = int(self.input_dim / 8) + + self.attention = nn.Sequential( + nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1), + nn.Softmax(dim=2), + ) + + if self.encoder_type == "SAP": + out_dim = num_filters[3] * outmap_size + elif self.encoder_type == "ASP": + out_dim = num_filters[3] * outmap_size * 2 + else: + raise ValueError("Undefined encoder") + + self.fc = nn.Linear(out_dim, proj_dim) + + self._init_layers() + + def _init_layers(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def create_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return 
nn.Sequential(*layers) + + # pylint: disable=R0201 + def new_parameter(self, *size): + out = nn.Parameter(torch.FloatTensor(*size)) + nn.init.xavier_normal_(out) + return out + + def forward(self, x, l2_norm=False): + """Forward pass of the model. + + Args: + x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True` + to compute the spectrogram on-the-fly. + l2_norm (bool): Whether to L2-normalize the outputs. + + Shapes: + - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})` + """ + x.squeeze_(1) + # if you torch spec compute it otherwise use the mel spec computed by the AP + if self.use_torch_spec: + x = self.torch_spec(x) + + if self.log_input: + x = (x + 1e-6).log() + x = self.instancenorm(x).unsqueeze(1) + + x = self.conv1(x) + x = self.relu(x) + x = self.bn1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = x.reshape(x.size()[0], -1, x.size()[-1]) + + w = self.attention(x) + + if self.encoder_type == "SAP": + x = torch.sum(x * w, dim=2) + elif self.encoder_type == "ASP": + mu = torch.sum(x * w, dim=2) + sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5)) + x = torch.cat((mu, sg), 1) + + x = x.view(x.size()[0], -1) + x = self.fc(x) + + if l2_norm: + x = torch.nn.functional.normalize(x, p=2, dim=1) + return x + + def load_checkpoint( + self, + checkpoint_path: str, + eval: bool = False, + use_cuda: bool = False, + criterion=None, + cache=False, + ): + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + try: + self.load_state_dict(state["model"]) + print(" > Model fully restored. ") + except (KeyError, RuntimeError) as error: + # If eval raise the error + if eval: + raise error + + print(" > Partial model initialization.") + model_dict = self.state_dict() + model_dict = set_init_dict(model_dict, state["model"]) + self.load_state_dict(model_dict) + del model_dict + + # load the criterion for restore_path + if criterion is not None and "criterion" in state: + try: + criterion.load_state_dict(state["criterion"]) + except (KeyError, RuntimeError) as error: + print(" > Criterion load ignored because of:", error) + + if use_cuda: + self.cuda() + if criterion is not None: + criterion = criterion.cuda() + + if eval: + self.eval() + assert not self.training + + if not eval: + return criterion, state["step"] + return criterion + + +class HifiDecoder(torch.nn.Module): + def __init__( + self, + input_sample_rate=22050, + output_sample_rate=24000, + output_hop_length=256, + ar_mel_length_compression=1024, + decoder_input_dim=1024, + resblock_type_decoder="1", + resblock_dilation_sizes_decoder=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + resblock_kernel_sizes_decoder=[3, 7, 11], + upsample_rates_decoder=[8, 8, 2, 2], + upsample_initial_channel_decoder=512, + upsample_kernel_sizes_decoder=[16, 16, 4, 4], + d_vector_dim=512, + cond_d_vector_in_each_upsampling_layer=True, + speaker_encoder_audio_config={ + "fft_size": 512, + "win_length": 400, + "hop_length": 160, + "sample_rate": 16000, + "preemphasis": 0.97, + "num_mels": 64, + }, + ): + super().__init__() + self.input_sample_rate = input_sample_rate + self.output_sample_rate = output_sample_rate + self.output_hop_length = output_hop_length + self.ar_mel_length_compression = ar_mel_length_compression + self.speaker_encoder_audio_config = speaker_encoder_audio_config + self.waveform_decoder = HifiganGenerator( + decoder_input_dim, + 1, + resblock_type_decoder, + resblock_dilation_sizes_decoder, + 
resblock_kernel_sizes_decoder, + upsample_kernel_sizes_decoder, + upsample_initial_channel_decoder, + upsample_rates_decoder, + inference_padding=0, + cond_channels=d_vector_dim, + conv_pre_weight_norm=False, + conv_post_weight_norm=False, + conv_post_bias=False, + cond_in_each_up_layer=cond_d_vector_in_each_upsampling_layer, + ) + self.speaker_encoder = ResNetSpeakerEncoder( + input_dim=64, + proj_dim=512, + log_input=True, + use_torch_spec=True, + audio_config=speaker_encoder_audio_config, + ) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, latents, g=None): + """ + Args: + x (Tensor): feature input tensor (GPT latent). + g (Tensor): global conditioning input tensor. + + Returns: + Tensor: output waveform. + + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + + z = torch.nn.functional.interpolate( + latents.transpose(1, 2), + scale_factor=[self.ar_mel_length_compression / self.output_hop_length], + mode="linear", + ).squeeze(1) + # upsample to the right sr + if self.output_sample_rate != self.input_sample_rate: + z = torch.nn.functional.interpolate( + z, + scale_factor=[self.output_sample_rate / self.input_sample_rate], + mode="linear", + ).squeeze(0) + o = self.waveform_decoder(z, g=g) + return o + + @torch.no_grad() + def inference(self, c, g): + """ + Args: + x (Tensor): feature input tensor (GPT latent). + g (Tensor): global conditioning input tensor. + + Returns: + Tensor: output waveform. + + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + return self.forward(c, g=g) + + def load_checkpoint(self, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + # remove unused keys + state = state["model"] + states_keys = list(state.keys()) + for key in states_keys: + if "waveform_decoder." not in key and "speaker_encoder." not in key: + del state[key] + + self.load_state_dict(state) + if eval: + self.eval() + assert not self.training + self.waveform_decoder.remove_weight_norm() diff --git a/viXTTS/TTS/tts/layers/xtts/latent_encoder.py b/viXTTS/TTS/tts/layers/xtts/latent_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f9d62a36f1529ddd1e9e6fdd92afc5c9f224f827 --- /dev/null +++ b/viXTTS/TTS/tts/layers/xtts/latent_encoder.py @@ -0,0 +1,141 @@ +# ported from: Originally ported from: https://github.com/neonbjb/tortoise-tts + +import math + +import torch +from torch import nn +from torch.nn import functional as F + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def normalization(channels): + groups = 32 + if channels <= 16: + groups = 8 + elif channels <= 64: + groups = 16 + while channels % groups != 0: + groups = int(groups / 2) + assert groups > 2 + return GroupNorm32(groups, channels) + + +def zero_module(module): + for p in module.parameters(): + p.detach().zero_() + return module + + +class QKVAttention(nn.Module): + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv, mask=None, qk_bias=0): + """ + Apply QKV attention. + + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
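+
+        Example (shape sketch only; the sizes are illustrative assumptions):
+        with ``n_heads=4`` and a per-head width of ``C=64``, ``qkv`` needs
+        ``4 * 3 * 64 = 768`` channels and the output has ``4 * 64 = 256``::
+
+            attn = QKVAttention(n_heads=4)
+            out = attn(torch.randn(2, 768, 100))  # [N, H*3*C, T] -> [2, 256, 100]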
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards + weight = weight + qk_bias + if mask is not None: + mask = mask.repeat(self.n_heads, 1, 1) + weight[mask.logical_not()] = -torch.inf + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + a = torch.einsum("bts,bcs->bct", weight, v) + + return a.reshape(bs, -1, length) + + +class AttentionBlock(nn.Module): + """An attention block that allows spatial positions to attend to each other.""" + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + out_channels=None, + do_activation=False, + ): + super().__init__() + self.channels = channels + out_channels = channels if out_channels is None else out_channels + self.do_activation = do_activation + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, out_channels * 3, 1) + self.attention = QKVAttention(self.num_heads) + + self.x_proj = nn.Identity() if out_channels == channels else conv_nd(1, channels, out_channels, 1) + self.proj_out = zero_module(conv_nd(1, out_channels, out_channels, 1)) + + def forward(self, x, mask=None, qk_bias=0): + b, c, *spatial = x.shape + if mask is not None: + if len(mask.shape) == 2: + mask = mask.unsqueeze(0).repeat(x.shape[0], 1, 1) + if mask.shape[1] != x.shape[-1]: + mask = mask[:, : x.shape[-1], : x.shape[-1]] + + x = x.reshape(b, c, -1) + x = self.norm(x) + if self.do_activation: + x = F.silu(x, inplace=True) + qkv = self.qkv(x) + h = self.attention(qkv, mask=mask, qk_bias=qk_bias) + h = self.proj_out(h) + xp = self.x_proj(x) + return (xp + h).reshape(b, xp.shape[1], *spatial) + + +class ConditioningEncoder(nn.Module): + def __init__( + self, + spec_dim, + embedding_dim, + attn_blocks=6, + num_attn_heads=4, + ): + super().__init__() + attn = [] + self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1) + for a in range(attn_blocks): + attn.append(AttentionBlock(embedding_dim, num_attn_heads)) + self.attn = nn.Sequential(*attn) + self.dim = embedding_dim + + def forward(self, x): + """ + x: (b, 80, s) + """ + h = self.init(x) + h = self.attn(h) + return h diff --git a/viXTTS/TTS/tts/layers/xtts/perceiver_encoder.py b/viXTTS/TTS/tts/layers/xtts/perceiver_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..7b7ee79b5018c80ad04c5766e7cd446862097c09 --- /dev/null +++ b/viXTTS/TTS/tts/layers/xtts/perceiver_encoder.py @@ -0,0 +1,319 @@ +# Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532 + +from collections import namedtuple +from functools import wraps + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +from packaging import version +from torch import einsum, nn + + +def exists(val): + return val is not None + + +def once(fn): + called = False + + @wraps(fn) + def inner(x): + nonlocal called + if called: + return + called = True + return fn(x) + + return inner 
+ + +print_once = once(print) + +# main class + + +class Attend(nn.Module): + def __init__(self, dropout=0.0, causal=False, use_flash=False): + super().__init__() + self.dropout = dropout + self.attn_dropout = nn.Dropout(dropout) + + self.causal = causal + self.register_buffer("mask", None, persistent=False) + + self.use_flash = use_flash + assert not ( + use_flash and version.parse(torch.__version__) < version.parse("2.0.0") + ), "in order to use flash attention, you must be using pytorch 2.0 or above" + + # determine efficient attention configs for cuda and cpu + self.config = namedtuple("EfficientAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]) + self.cpu_config = self.config(True, True, True) + self.cuda_config = None + + if not torch.cuda.is_available() or not use_flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device("cuda")) + + if device_properties.major == 8 and device_properties.minor == 0: + print_once("A100 GPU detected, using flash attention if input tensor is on cuda") + self.cuda_config = self.config(True, False, False) + else: + print_once("Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda") + self.cuda_config = self.config(False, True, True) + + def get_mask(self, n, device): + if exists(self.mask) and self.mask.shape[-1] >= n: + return self.mask[:n, :n] + + mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1) + self.register_buffer("mask", mask, persistent=False) + return mask + + def flash_attn(self, q, k, v, mask=None): + _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda + + # Recommended for multi-query single-key-value attention by Tri Dao + # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64]) + + if k.ndim == 3: + k = rearrange(k, "b ... -> b 1 ...").expand_as(q) + + if v.ndim == 3: + v = rearrange(v, "b ... 
-> b 1 ...").expand_as(q) + + # Check if mask exists and expand to compatible shape + # The mask is B L, so it would have to be expanded to B H N L + + if exists(mask): + mask = rearrange(mask, "b j -> b 1 1 j") + mask = mask.expand(-1, heads, q_len, -1) + + # Check if there is a compatible device for flash attention + + config = self.cuda_config if is_cuda else self.cpu_config + + # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale + + with torch.backends.cuda.sdp_kernel(**config._asdict()): + out = F.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=self.dropout if self.training else 0.0, is_causal=self.causal + ) + + return out + + def forward(self, q, k, v, mask=None): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + n, device = q.shape[-2], q.device + + scale = q.shape[-1] ** -0.5 + + if self.use_flash: + return self.flash_attn(q, k, v, mask=mask) + + kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d" + + # similarity + + sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale + + # key padding mask + + if exists(mask): + mask = rearrange(mask, "b j -> b 1 1 j") + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + + # causal mask + + if self.causal: + causal_mask = self.get_mask(n, device) + sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) + + # attention + + attn = sim.softmax(dim=-1) + attn = self.attn_dropout(attn) + + # aggregate values + + out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v) + + return out + + +def Sequential(*mods): + return nn.Sequential(*filter(exists, mods)) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + + +class RMSNorm(nn.Module): + def __init__(self, dim, scale=True, dim_cond=None): + super().__init__() + self.cond = exists(dim_cond) + self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None + + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(dim)) if scale else None + + def forward(self, x, cond=None): + gamma = default(self.gamma, 1) + out = F.normalize(x, dim=-1) * self.scale * gamma + + if not self.cond: + return out + + assert exists(cond) + gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1) + gamma, beta = map(lambda t: rearrange(t, "b d -> b 1 d"), (gamma, beta)) + return out * gamma + beta + + +class CausalConv1d(nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + (kernel_size,) = self.kernel_size + (dilation,) = self.dilation + (stride,) = self.stride + + assert stride == 1 + self.causal_padding = dilation * (kernel_size - 1) + + def forward(self, x): + causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0) + return super().forward(causal_padded_x) + + +class GEGLU(nn.Module): + def forward(self, x): + x, gate = x.chunk(2, dim=-1) + return F.gelu(gate) * x + + +def FeedForward(dim, mult=4, causal_conv=False): + dim_inner = int(dim * mult * 2 / 3) + + conv = None + if causal_conv: + conv = nn.Sequential( + Rearrange("b n d -> b d n"), + CausalConv1d(dim_inner, dim_inner, 3), + Rearrange("b d n -> b n d"), + ) + + return Sequential(nn.Linear(dim, dim_inner * 2), GEGLU(), conv, nn.Linear(dim_inner, dim)) + + +class PerceiverResampler(nn.Module): + def __init__( + self, + *, + dim, + depth=2, + dim_context=None, + num_latents=32, + dim_head=64, + heads=8, + ff_mult=4, + use_flash_attn=False, + ): + 
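+        """Perceiver-style resampler: a fixed set of learned latent vectors
+        cross-attends over the (variable-length) input sequence, so the output
+        always has `num_latents` vectors of size `dim`.
+
+        A hedged usage sketch (dimensions are illustrative, not from a config)::
+
+            resampler = PerceiverResampler(dim=1024, num_latents=32)
+            ctx = torch.randn(2, 250, 1024)   # [batch, frames, dim]
+            lat = resampler(ctx)              # -> [2, 32, 1024]
+        """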
super().__init__() + dim_context = default(dim_context, dim) + + self.proj_context = nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity() + + self.latents = nn.Parameter(torch.randn(num_latents, dim)) + nn.init.normal_(self.latents, std=0.02) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + Attention( + dim=dim, + dim_head=dim_head, + heads=heads, + use_flash=use_flash_attn, + cross_attn_include_queries=True, + ), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + self.norm = RMSNorm(dim) + + def forward(self, x, mask=None): + batch = x.shape[0] + + x = self.proj_context(x) + + latents = repeat(self.latents, "n d -> b n d", b=batch) + + for attn, ff in self.layers: + latents = attn(latents, x, mask=mask) + latents + latents = ff(latents) + latents + + return self.norm(latents) + + +class Attention(nn.Module): + def __init__( + self, + dim, + *, + dim_context=None, + causal=False, + dim_head=64, + heads=8, + dropout=0.0, + use_flash=False, + cross_attn_include_queries=False, + ): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + self.cross_attn_include_queries = cross_attn_include_queries + + dim_inner = dim_head * heads + dim_context = default(dim_context, dim) + + self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash) + self.to_q = nn.Linear(dim, dim_inner, bias=False) + self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False) + self.to_out = nn.Linear(dim_inner, dim, bias=False) + + def forward(self, x, context=None, mask=None): + h, has_context = self.heads, exists(context) + + context = default(context, x) + + if has_context and self.cross_attn_include_queries: + context = torch.cat((x, context), dim=-2) + + q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1)) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + out = self.attend(q, k, v, mask=mask) + + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) diff --git a/viXTTS/TTS/tts/layers/xtts/stream_generator.py b/viXTTS/TTS/tts/layers/xtts/stream_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..e12f8995cf16de94b971036f6b88d255cd42ec6e --- /dev/null +++ b/viXTTS/TTS/tts/layers/xtts/stream_generator.py @@ -0,0 +1,930 @@ +# Adapted from: https://github.com/LowinLi/transformers-stream-generator + +import copy +import inspect +import random +import warnings +from typing import Callable, List, Optional, Union + +import numpy as np +import torch +import torch.distributed as dist +from torch import nn +from transformers import ( + BeamSearchScorer, + ConstrainedBeamSearchScorer, + DisjunctiveConstraint, + GenerationConfig, + GenerationMixin, + LogitsProcessorList, + PhrasalConstraint, + PreTrainedModel, + StoppingCriteriaList, +) +from transformers.generation.utils import GenerateOutput, SampleOutput, logger + + +def setup_seed(seed): + if seed == -1: + return + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + + +class StreamGenerationConfig(GenerationConfig): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.do_stream = kwargs.pop("do_stream", False) + + +class NewGenerationMixin(GenerationMixin): + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[StreamGenerationConfig] = None, + logits_processor: 
Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + synced_gpus: Optional[bool] = False, + seed=0, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + r""" + + Generates sequences of token ids for models with a language modeling head. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Parameters: + inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the + method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` + should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of + `input_ids`, `input_values`, `input_features`, or `pixel_values`. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which had the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): + If provided, this function constraints the beam search to allowed tokens only at each step. If not + provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and + `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned + on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful + for constrained generation conditioned on the prefix, as described in [Autoregressive Entity + Retrieval](https://arxiv.org/abs/2010.00904). + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + kwargs: + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. 
If the model is an encoder-decoder model, encoder + specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. + + Return: + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. + + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchDecoderOnlyOutput`], + - [`~generation.SampleDecoderOnlyOutput`], + - [`~generation.BeamSearchDecoderOnlyOutput`], + - [`~generation.BeamSampleDecoderOnlyOutput`] + + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchEncoderDecoderOutput`], + - [`~generation.SampleEncoderDecoderOutput`], + - [`~generation.BeamSearchEncoderDecoderOutput`], + - [`~generation.BeamSampleEncoderDecoderOutput`] + """ + # setup_seed(seed) + # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call + self._validate_model_class() + + # priority: `generation_config` argument > `model.generation_config` (the default generation config) + if generation_config is None: + # legacy: users may modify the model configuration to control generation -- update the generation config + # model attribute accordingly, if it was created from the model config + if self.generation_config._from_model_config: + new_generation_config = StreamGenerationConfig.from_model_config(self.config) + if new_generation_config != self.generation_config: + warnings.warn( + "You have modified the pretrained model configuration to control generation. This is a" + " deprecated strategy to control generation and will be removed soon, in a future version." + " Please use a generation configuration file (see" + " https://huggingface.co/docs/transformers/main_classes/text_generation)" + ) + self.generation_config = new_generation_config + generation_config = self.generation_config + + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs + # self._validate_model_kwargs(model_kwargs.copy()) + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: + if model_kwargs.get("attention_mask", None) is None: + logger.warning( + "The attention mask and the pad token id were not set. As a consequence, you may observe " + "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." + ) + eos_token_id = generation_config.eos_token_id + if isinstance(eos_token_id, list): + eos_token_id = eos_token_id[0] + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") + generation_config.pad_token_id = eos_token_id + + # 3. 
Define model inputs + # inputs_tensor has to be defined + # model_input_name is defined if model-specific keyword input is passed + # otherwise model_input_name is None + # all model-specific keyword inputs are removed from `model_kwargs` + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + batch_size = inputs_tensor.shape[0] + + # 4. Define other model kwargs + model_kwargs["output_attentions"] = generation_config.output_attentions + model_kwargs["output_hidden_states"] = generation_config.output_hidden_states + model_kwargs["use_cache"] = generation_config.use_cache + + accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) + requires_attention_mask = "encoder_outputs" not in model_kwargs + + if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( + inputs_tensor, + generation_config.pad_token_id, + generation_config.eos_token_id, + ) + + # decoder-only models should use left-padding for generation + if not self.config.is_encoder_decoder: + if ( + generation_config.pad_token_id is not None + and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0 + ): + logger.warning( + "A decoder-only architecture is being used, but right-padding was detected! For correct " + "generation results, please set `padding_side='left'` when initializing the tokenizer." + ) + + if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: + # if model is encoder decoder encoder_outputs are created + # and added to `model_kwargs` + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, model_kwargs, model_input_name + ) + + # 5. Prepare `input_ids` which will be used for auto-regressive generation + if self.config.is_encoder_decoder: + input_ids = self._prepare_decoder_input_ids_for_generation( + batch_size, + decoder_start_token_id=generation_config.decoder_start_token_id, + bos_token_id=generation_config.bos_token_id, + model_kwargs=model_kwargs, + device=inputs_tensor.device, + ) + else: + # if decoder-only then inputs_tensor has to be `input_ids` + input_ids = inputs_tensor + + # 6. Prepare `max_length` depending on other stopping criteria. + input_ids_seq_length = input_ids.shape[-1] + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to" + f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the" + " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif has_default_max_length and generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + elif not has_default_max_length and generation_config.max_new_tokens is not None: + raise ValueError( + "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a" + " limit to the generated output length. Remove one of those arguments. Please refer to the" + " documentation for more information. 
" + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" + ) + + if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: + raise ValueError( + f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" + f" the maximum length ({generation_config.max_length})" + ) + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) + + # 7. determine generation mode + is_constraint_gen_mode = ( + generation_config.constraints is not None or generation_config.force_words_ids is not None + ) + + is_contrastive_search_gen_mode = ( + generation_config.top_k is not None + and generation_config.top_k > 1 + and generation_config.do_sample is False + and generation_config.penalty_alpha is not None + and generation_config.penalty_alpha > 0 + ) + + is_greedy_gen_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_sample_gen_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is True + and generation_config.do_stream is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_sample_gen_stream_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_stream is True + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_beam_gen_mode = ( + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_beam_sample_gen_mode = ( + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is True + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_group_beam_gen_mode = ( + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups > 1) + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + + if generation_config.num_beam_groups > generation_config.num_beams: + raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") + if is_group_beam_gen_mode and generation_config.do_sample is True: + raise ValueError( + "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." + ) + + if self.device.type != input_ids.device.type: + warnings.warn( + "You are calling .generate() with the `input_ids` being on a device type different" + f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" + f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." + " Please make sure that you have put `input_ids` to the" + f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" + " running `.generate()`.", + UserWarning, + ) + # 8. 
prepare distribution pre_processing samplers + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=inputs_tensor, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + # 9. prepare stopping criteria + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + # 10. go into different generation modes + if is_greedy_gen_mode: + if generation_config.num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " greedy search." + ) + + # 11. run greedy search + return self.greedy_search( + input_ids, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_contrastive_search_gen_mode: + if generation_config.num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " contrastive search." + ) + + return self.contrastive_search( + input_ids, + top_k=generation_config.top_k, + penalty_alpha=generation_config.penalty_alpha, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_sample_gen_mode: + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config) + + # 12. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 13. run sample + return self.sample( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + elif is_sample_gen_stream_mode: + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config) + + # 12. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 13. 
run sample + return self.sample_stream( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + elif is_beam_gen_mode: + if generation_config.num_return_sequences > generation_config.num_beams: + raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") + + if stopping_criteria.max_length is None: + raise ValueError("`max_length` needs to be a stopping_criteria for now.") + + # 11. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + return self.beam_search( + input_ids, + beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_beam_sample_gen_mode: + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config) + + if stopping_criteria.max_length is None: + raise ValueError("`max_length` needs to be a stopping_criteria for now.") + # 12. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size * generation_config.num_return_sequences, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + ) + + # 13. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams * generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 14. 
run beam sample + return self.beam_sample( + input_ids, + beam_scorer, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_group_beam_gen_mode: + if generation_config.num_return_sequences > generation_config.num_beams: + raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") + + if generation_config.num_beams % generation_config.num_beam_groups != 0: + raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.") + + if stopping_criteria.max_length is None: + raise ValueError("`max_length` needs to be a stopping_criteria for now.") + + has_default_typical_p = kwargs.get("typical_p") is None and generation_config.typical_p == 1.0 + if not has_default_typical_p: + raise ValueError("Decoder argument `typical_p` is not supported with beam groups.") + + # 11. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=generation_config.num_beams, + max_length=stopping_criteria.max_length, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + num_beam_groups=generation_config.num_beam_groups, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + return self.group_beam_search( + input_ids, + beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_constraint_gen_mode: + if generation_config.num_return_sequences > generation_config.num_beams: + raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") + + if stopping_criteria.max_length is None: + raise ValueError("`max_length` needs to be a stopping_criteria for now.") + + if generation_config.num_beams <= 1: + raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.") + + if generation_config.do_sample: + raise ValueError("`do_sample` needs to be false for constrained generation.") + + if generation_config.num_beam_groups is not None and generation_config.num_beam_groups > 1: + raise ValueError("`num_beam_groups` not supported yet for constrained generation.") + + final_constraints = [] + if generation_config.constraints is not None: + final_constraints = generation_config.constraints + + if generation_config.force_words_ids is not None: + + def typeerror(): + raise ValueError( + "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`" + f"of positive integers, but is {generation_config.force_words_ids}." 
+ ) + + if ( + not isinstance(generation_config.force_words_ids, list) + or len(generation_config.force_words_ids) == 0 + ): + typeerror() + + for word_ids in generation_config.force_words_ids: + if isinstance(word_ids[0], list): + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any(not isinstance(token_ids, list) for token_ids in word_ids): + typeerror() + if any( + any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids) + for token_ids in word_ids + ): + typeerror() + + constraint = DisjunctiveConstraint(word_ids) + else: + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids): + typeerror() + + constraint = PhrasalConstraint(word_ids) + final_constraints.append(constraint) + + # 11. prepare beam search scorer + constrained_beam_scorer = ConstrainedBeamSearchScorer( + constraints=final_constraints, + batch_size=batch_size, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + return self.constrained_beam_search( + input_ids, + constrained_beam_scorer=constrained_beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + @torch.no_grad() + def sample_stream( + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + **model_kwargs, + ) -> Union[SampleOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + + + In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead. + For an overview of generation strategies and code examples, check the [following + guide](./generation_strategies). + + + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. 
+ stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + logits_warper (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used + to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. + max_length (`int`, *optional*, defaults to 20): + **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated + tokens. The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.SampleDecoderOnlyOutput`], [`~generation.SampleEncoderDecoderOutput`] or `torch.LongTensor`: + A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.SampleEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForCausalLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... TopKLogitsWarper, + ... TemperatureLogitsWarper, + ... StoppingCriteriaList, + ... MaxLengthCriteria, + ... ) + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + + >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token + >>> model.config.pad_token_id = model.config.eos_token_id + >>> model.generation_config.pad_token_id = model.config.eos_token_id + + >>> input_prompt = "Today is a beautiful day, and" + >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [ + ... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id), + ... ] + ... ) + >>> # instantiate logits processors + >>> logits_warper = LogitsProcessorList( + ... [ + ... TopKLogitsWarper(50), + ... TemperatureLogitsWarper(0.7), + ... ] + ... 
) + + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) + + >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT + >>> outputs = model.sample( + ... input_ids, + ... logits_processor=logits_processor, + ... logits_warper=logits_warper, + ... stopping_criteria=stopping_criteria, + ... ) + + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] + ```""" + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # keep track of which sequences are already finished + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + + this_peer_finished = False # used by synced_gpus only + # auto-regressive generation + while True: + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? 
the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) + ) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + yield next_tokens, self.final_norm(outputs.hidden_states[-1][:, -1]) + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id is not None: + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + +def init_stream_support(): + """Overload PreTrainedModel for streaming.""" + PreTrainedModel.generate_stream = NewGenerationMixin.generate + PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream + + +if __name__ == "__main__": + from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel + + PreTrainedModel.generate = NewGenerationMixin.generate + PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream + model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.float16) + + tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m") + model = model.to("cuda:0") + model = model.eval() + prompt_text = "hello? 
\n" + input_ids = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).input_ids + input_ids = input_ids.to("cuda:0") + + with torch.no_grad(): + result = model.generate( + input_ids, + max_new_tokens=200, + do_sample=True, + top_k=30, + top_p=0.85, + temperature=0.35, + repetition_penalty=1.2, + early_stopping=True, + seed=0, + ) + print(tokenizer.decode(result, skip_special_tokens=True)) + generator = model.generate( + input_ids, + max_new_tokens=200, + do_sample=True, + top_k=30, + top_p=0.85, + temperature=0.35, + repetition_penalty=1.2, + early_stopping=True, + seed=0, + do_stream=True, + ) + stream_result = "" + for x in generator: + chunk = tokenizer.decode(x, skip_special_tokens=True) + stream_result += chunk + print(stream_result) diff --git a/viXTTS/TTS/tts/layers/xtts/tokenizer.py b/viXTTS/TTS/tts/layers/xtts/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..be33fff35b48651ba1c66f3404d9dabe1dd5d5b3 --- /dev/null +++ b/viXTTS/TTS/tts/layers/xtts/tokenizer.py @@ -0,0 +1,846 @@ +import os +import re +import textwrap +from functools import cached_property + +import pypinyin +import torch +from hangul_romanize import Transliter +from hangul_romanize.rule import academic +from num2words import num2words +from spacy.lang.ar import Arabic +from spacy.lang.en import English +from spacy.lang.es import Spanish +from spacy.lang.ja import Japanese +from spacy.lang.zh import Chinese +from tokenizers import Tokenizer + +from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words + + +def get_spacy_lang(lang): + if lang == "zh": + return Chinese() + elif lang == "ja": + return Japanese() + elif lang == "ar": + return Arabic() + elif lang == "es": + return Spanish() + else: + # For most languages, Enlish does the job + return English() + + +def split_sentence(text, lang, text_split_length=250): + """Preprocess the input text""" + text_splits = [] + if text_split_length is not None and len(text) >= text_split_length: + text_splits.append("") + nlp = get_spacy_lang(lang) + nlp.add_pipe("sentencizer") + doc = nlp(text) + for sentence in doc.sents: + if len(text_splits[-1]) + len(str(sentence)) <= text_split_length: + # if the last sentence + the current sentence is less than the text_split_length + # then add the current sentence to the last sentence + text_splits[-1] += " " + str(sentence) + text_splits[-1] = text_splits[-1].lstrip() + elif len(str(sentence)) > text_split_length: + # if the current sentence is greater than the text_split_length + for line in textwrap.wrap( + str(sentence), + width=text_split_length, + drop_whitespace=True, + break_on_hyphens=False, + tabsize=1, + ): + text_splits.append(str(line)) + else: + text_splits.append(str(sentence)) + + if len(text_splits) > 1: + if text_splits[0] == "": + del text_splits[0] + else: + text_splits = [text.lstrip()] + + return text_splits + + +_whitespace_re = re.compile(r"\s+") + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = { + "en": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] + ], + "es": [ + (re.compile("\\b%s\\." 
% x[0], re.IGNORECASE), x[1]) + for x in [ + ("sra", "señora"), + ("sr", "señor"), + ("dr", "doctor"), + ("dra", "doctora"), + ("st", "santo"), + ("co", "compañía"), + ("jr", "junior"), + ("ltd", "limitada"), + ] + ], + "fr": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("mme", "madame"), + ("mr", "monsieur"), + ("dr", "docteur"), + ("st", "saint"), + ("co", "compagnie"), + ("jr", "junior"), + ("ltd", "limitée"), + ] + ], + "de": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("fr", "frau"), + ("dr", "doktor"), + ("st", "sankt"), + ("co", "firma"), + ("jr", "junior"), + ] + ], + "pt": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("sra", "senhora"), + ("sr", "senhor"), + ("dr", "doutor"), + ("dra", "doutora"), + ("st", "santo"), + ("co", "companhia"), + ("jr", "júnior"), + ("ltd", "limitada"), + ] + ], + "it": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + # ("sig.ra", "signora"), + ("sig", "signore"), + ("dr", "dottore"), + ("st", "santo"), + ("co", "compagnia"), + ("jr", "junior"), + ("ltd", "limitata"), + ] + ], + "pl": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("p", "pani"), + ("m", "pan"), + ("dr", "doktor"), + ("sw", "święty"), + ("jr", "junior"), + ] + ], + "ar": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + # There are not many common abbreviations in Arabic as in English. + ] + ], + "zh": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + # Chinese doesn't typically use abbreviations in the same way as Latin-based scripts. + ] + ], + "cs": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("dr", "doktor"), # doctor + ("ing", "inženýr"), # engineer + ("p", "pan"), # Could also map to pani for woman but no easy way to do it + # Other abbreviations would be specialized and not as common. + ] + ], + "ru": [ + (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1]) + for x in [ + ("г-жа", "госпожа"), # Mrs. + ("г-н", "господин"), # Mr. + ("д-р", "доктор"), # doctor + # Other abbreviations are less common or specialized. + ] + ], + "nl": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("dhr", "de heer"), # Mr. + ("mevr", "mevrouw"), # Mrs. + ("dr", "dokter"), # doctor + ("jhr", "jonkheer"), # young lord or nobleman + # Dutch uses more abbreviations, but these are the most common ones. + ] + ], + "tr": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("b", "bay"), # Mr. + ("byk", "büyük"), # büyük + ("dr", "doktor"), # doctor + # Add other Turkish abbreviations here if needed. + ] + ], + "hu": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("dr", "doktor"), # doctor + ("b", "bácsi"), # Mr. + ("nőv", "nővér"), # nurse + # Add other Hungarian abbreviations here if needed. + ] + ], + "ko": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + # Korean doesn't typically use abbreviations in the same way as Latin-based scripts. 
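+ # (An empty list here means expand_abbreviations_multilingual is effectively a no-op for Korean,
+ # just like the Arabic and Chinese entries above.)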
+ ] + ], +} + + +def expand_abbreviations_multilingual(text, lang="en"): + for regex, replacement in _abbreviations[lang]: + text = re.sub(regex, replacement, text) + return text + + +_symbols_multilingual = { + "en": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " and "), + ("@", " at "), + ("%", " percent "), + ("#", " hash "), + ("$", " dollar "), + ("£", " pound "), + ("°", " degree "), + ] + ], + "es": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " y "), + ("@", " arroba "), + ("%", " por ciento "), + ("#", " numeral "), + ("$", " dolar "), + ("£", " libra "), + ("°", " grados "), + ] + ], + "fr": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " et "), + ("@", " arobase "), + ("%", " pour cent "), + ("#", " dièse "), + ("$", " dollar "), + ("£", " livre "), + ("°", " degrés "), + ] + ], + "de": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " und "), + ("@", " at "), + ("%", " prozent "), + ("#", " raute "), + ("$", " dollar "), + ("£", " pfund "), + ("°", " grad "), + ] + ], + "pt": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " e "), + ("@", " arroba "), + ("%", " por cento "), + ("#", " cardinal "), + ("$", " dólar "), + ("£", " libra "), + ("°", " graus "), + ] + ], + "it": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " e "), + ("@", " chiocciola "), + ("%", " per cento "), + ("#", " cancelletto "), + ("$", " dollaro "), + ("£", " sterlina "), + ("°", " gradi "), + ] + ], + "pl": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " i "), + ("@", " małpa "), + ("%", " procent "), + ("#", " krzyżyk "), + ("$", " dolar "), + ("£", " funt "), + ("°", " stopnie "), + ] + ], + "ar": [ + # Arabic + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " و "), + ("@", " على "), + ("%", " في المئة "), + ("#", " رقم "), + ("$", " دولار "), + ("£", " جنيه "), + ("°", " درجة "), + ] + ], + "zh": [ + # Chinese + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " 和 "), + ("@", " 在 "), + ("%", " 百分之 "), + ("#", " 号 "), + ("$", " 美元 "), + ("£", " 英镑 "), + ("°", " 度 "), + ] + ], + "cs": [ + # Czech + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " a "), + ("@", " na "), + ("%", " procento "), + ("#", " křížek "), + ("$", " dolar "), + ("£", " libra "), + ("°", " stupně "), + ] + ], + "ru": [ + # Russian + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " и "), + ("@", " собака "), + ("%", " процентов "), + ("#", " номер "), + ("$", " доллар "), + ("£", " фунт "), + ("°", " градус "), + ] + ], + "nl": [ + # Dutch + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " en "), + ("@", " bij "), + ("%", " procent "), + ("#", " hekje "), + ("$", " dollar "), + ("£", " pond "), + ("°", " graden "), + ] + ], + "tr": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " ve "), + ("@", " at "), + ("%", " yüzde "), + ("#", " diyez "), + ("$", " dolar "), + ("£", " sterlin "), + ("°", " derece "), + ] + ], + "hu": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " és "), + ("@", " kukac "), + ("%", " százalék "), + ("#", " kettőskereszt "), + ("$", " dollár "), + ("£", " font "), + ("°", " fok "), + ] + ], + "ko": [ + # Korean + 
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " 그리고 "), + ("@", " 에 "), + ("%", " 퍼센트 "), + ("#", " 번호 "), + ("$", " 달러 "), + ("£", " 파운드 "), + ("°", " 도 "), + ] + ], +} + + +def expand_symbols_multilingual(text, lang="en"): + for regex, replacement in _symbols_multilingual[lang]: + text = re.sub(regex, replacement, text) + text = text.replace(" ", " ") # Ensure there are no double spaces + return text.strip() + + +_ordinal_re = { + "en": re.compile(r"([0-9]+)(st|nd|rd|th)"), + "es": re.compile(r"([0-9]+)(º|ª|er|o|a|os|as)"), + "fr": re.compile(r"([0-9]+)(º|ª|er|re|e|ème)"), + "de": re.compile(r"([0-9]+)(st|nd|rd|th|º|ª|\.(?=\s|$))"), + "pt": re.compile(r"([0-9]+)(º|ª|o|a|os|as)"), + "it": re.compile(r"([0-9]+)(º|°|ª|o|a|i|e)"), + "pl": re.compile(r"([0-9]+)(º|ª|st|nd|rd|th)"), + "ar": re.compile(r"([0-9]+)(ون|ين|ث|ر|ى)"), + "cs": re.compile(r"([0-9]+)\.(?=\s|$)"), # In Czech, a dot is often used after the number to indicate ordinals. + "ru": re.compile(r"([0-9]+)(-й|-я|-е|-ое|-ье|-го)"), + "nl": re.compile(r"([0-9]+)(de|ste|e)"), + "tr": re.compile(r"([0-9]+)(\.|inci|nci|uncu|üncü|\.)"), + "hu": re.compile(r"([0-9]+)(\.|adik|edik|odik|edik|ödik|ödike|ik)"), + "ko": re.compile(r"([0-9]+)(번째|번|차|째)"), +} +_number_re = re.compile(r"[0-9]+") +_currency_re = { + "USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"), + "GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"), + "EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"), +} + +_comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b") +_dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b") +_decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)") + + +def _remove_commas(m): + text = m.group(0) + if "," in text: + text = text.replace(",", "") + return text + + +def _remove_dots(m): + text = m.group(0) + if "." 
in text: + text = text.replace(".", "") + return text + + +def _expand_decimal_point(m, lang="en"): + amount = m.group(1).replace(",", ".") + return num2words(float(amount), lang=lang if lang != "cs" else "cz") + + +def _expand_currency(m, lang="en", currency="USD"): + amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", ".")))) + full_amount = num2words(amount, to="currency", currency=currency, lang=lang if lang != "cs" else "cz") + + and_equivalents = { + "en": ", ", + "es": " con ", + "fr": " et ", + "de": " und ", + "pt": " e ", + "it": " e ", + "pl": ", ", + "cs": ", ", + "ru": ", ", + "nl": ", ", + "ar": ", ", + "tr": ", ", + "hu": ", ", + "ko": ", ", + } + + if amount.is_integer(): + last_and = full_amount.rfind(and_equivalents[lang]) + if last_and != -1: + full_amount = full_amount[:last_and] + + return full_amount + + +def _expand_ordinal(m, lang="en"): + return num2words(int(m.group(1)), ordinal=True, lang=lang if lang != "cs" else "cz") + + +def _expand_number(m, lang="en"): + return num2words(int(m.group(0)), lang=lang if lang != "cs" else "cz") + + +def expand_numbers_multilingual(text, lang="en"): + if lang == "zh": + text = zh_num2words()(text) + else: + if lang in ["en", "ru"]: + text = re.sub(_comma_number_re, _remove_commas, text) + else: + text = re.sub(_dot_number_re, _remove_dots, text) + try: + text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text) + text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text) + text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text) + except: + pass + if lang != "tr": + text = re.sub(_decimal_number_re, lambda m: _expand_decimal_point(m, lang), text) + text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text) + text = re.sub(_number_re, lambda m: _expand_number(m, lang), text) + return text + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, " ", text) + + +def multilingual_cleaners(text, lang): + text = text.replace('"', "") + if lang == "tr": + text = text.replace("İ", "i") + text = text.replace("Ö", "ö") + text = text.replace("Ü", "ü") + text = lowercase(text) + text = expand_numbers_multilingual(text, lang) + text = expand_abbreviations_multilingual(text, lang) + text = expand_symbols_multilingual(text, lang=lang) + text = collapse_whitespace(text) + return text + + +def basic_cleaners(text): + """Basic pipeline that lowercases and collapses whitespace without transliteration.""" + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def chinese_transliterate(text): + return "".join( + [p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)] + ) + + +def japanese_cleaners(text, katsu): + text = katsu.romaji(text) + text = lowercase(text) + return text + + +def korean_transliterate(text): + r = Transliter(academic) + return r.translit(text) + + +DEFAULT_VOCAB_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../data/tokenizer.json") + + +class VoiceBpeTokenizer: + def __init__(self, vocab_file=None): + self.tokenizer = None + if vocab_file is not None: + self.tokenizer = Tokenizer.from_file(vocab_file) + self.char_limits = { + "en": 250, + "de": 253, + "fr": 273, + "es": 239, + "it": 213, + "pt": 203, + "pl": 224, + "zh": 82, + "ar": 166, + "cs": 186, + "ru": 182, + "nl": 251, + "tr": 226, + "ja": 71, + "hu": 224, + "ko": 95, + "vi": 250, + } + + @cached_property + 
def katsu(self): + import cutlet + + return cutlet.Cutlet() + + def check_input_length(self, txt, lang): + lang = lang.split("-")[0] # remove the region + limit = self.char_limits.get(lang, 250) + if len(txt) > limit: + print( + f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio." + ) + + def preprocess_text(self, txt, lang): + if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "ko"}: + txt = multilingual_cleaners(txt, lang) + if lang == "zh": + txt = chinese_transliterate(txt) + if lang == "ko": + txt = korean_transliterate(txt) + elif lang == "ja": + txt = japanese_cleaners(txt, self.katsu) + elif lang == "hi": + # @manmay will implement this + txt = basic_cleaners(txt) + elif lang == "vi": + txt = basic_cleaners(txt) + else: + raise NotImplementedError(f"Language '{lang}' is not supported.") + return txt + + def encode(self, txt, lang): + lang = lang.split("-")[0] # remove the region + self.check_input_length(txt, lang) + txt = self.preprocess_text(txt, lang) + lang = "zh-cn" if lang == "zh" else lang + txt = f"[{lang}]{txt}" + txt = txt.replace(" ", "[SPACE]") + return self.tokenizer.encode(txt).ids + + def decode(self, seq): + if isinstance(seq, torch.Tensor): + seq = seq.cpu().numpy() + txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(" ", "") + txt = txt.replace("[SPACE]", " ") + txt = txt.replace("[STOP]", "") + txt = txt.replace("[UNK]", "") + return txt + + def __len__(self): + return self.tokenizer.get_vocab_size() + + def get_number_tokens(self): + return max(self.tokenizer.get_vocab().values()) + 1 + + +def test_expand_numbers_multilingual(): + test_cases = [ + # English + ("In 12.5 seconds.", "In twelve point five seconds.", "en"), + ("There were 50 soldiers.", "There were fifty soldiers.", "en"), + ("This is a 1st test", "This is a first test", "en"), + ("That will be $20 sir.", "That will be twenty dollars sir.", "en"), + ("That will be 20€ sir.", "That will be twenty euro sir.", "en"), + ("That will be 20.15€ sir.", "That will be twenty euro, fifteen cents sir.", "en"), + ("That's 100,000.5.", "That's one hundred thousand point five.", "en"), + # French + ("En 12,5 secondes.", "En douze virgule cinq secondes.", "fr"), + ("Il y avait 50 soldats.", "Il y avait cinquante soldats.", "fr"), + ("Ceci est un 1er test", "Ceci est un premier test", "fr"), + ("Cela vous fera $20 monsieur.", "Cela vous fera vingt dollars monsieur.", "fr"), + ("Cela vous fera 20€ monsieur.", "Cela vous fera vingt euros monsieur.", "fr"), + ("Cela vous fera 20,15€ monsieur.", "Cela vous fera vingt euros et quinze centimes monsieur.", "fr"), + ("Ce sera 100.000,5.", "Ce sera cent mille virgule cinq.", "fr"), + # German + ("In 12,5 Sekunden.", "In zwölf Komma fünf Sekunden.", "de"), + ("Es gab 50 Soldaten.", "Es gab fünfzig Soldaten.", "de"), + ("Dies ist ein 1. 
Test", "Dies ist ein erste Test", "de"), # Issue with gender + ("Das macht $20 Herr.", "Das macht zwanzig Dollar Herr.", "de"), + ("Das macht 20€ Herr.", "Das macht zwanzig Euro Herr.", "de"), + ("Das macht 20,15€ Herr.", "Das macht zwanzig Euro und fünfzehn Cent Herr.", "de"), + # Spanish + ("En 12,5 segundos.", "En doce punto cinco segundos.", "es"), + ("Había 50 soldados.", "Había cincuenta soldados.", "es"), + ("Este es un 1er test", "Este es un primero test", "es"), + ("Eso le costará $20 señor.", "Eso le costará veinte dólares señor.", "es"), + ("Eso le costará 20€ señor.", "Eso le costará veinte euros señor.", "es"), + ("Eso le costará 20,15€ señor.", "Eso le costará veinte euros con quince céntimos señor.", "es"), + # Italian + ("In 12,5 secondi.", "In dodici virgola cinque secondi.", "it"), + ("C'erano 50 soldati.", "C'erano cinquanta soldati.", "it"), + ("Questo è un 1° test", "Questo è un primo test", "it"), + ("Ti costerà $20 signore.", "Ti costerà venti dollari signore.", "it"), + ("Ti costerà 20€ signore.", "Ti costerà venti euro signore.", "it"), + ("Ti costerà 20,15€ signore.", "Ti costerà venti euro e quindici centesimi signore.", "it"), + # Portuguese + ("Em 12,5 segundos.", "Em doze vírgula cinco segundos.", "pt"), + ("Havia 50 soldados.", "Havia cinquenta soldados.", "pt"), + ("Este é um 1º teste", "Este é um primeiro teste", "pt"), + ("Isso custará $20 senhor.", "Isso custará vinte dólares senhor.", "pt"), + ("Isso custará 20€ senhor.", "Isso custará vinte euros senhor.", "pt"), + ( + "Isso custará 20,15€ senhor.", + "Isso custará vinte euros e quinze cêntimos senhor.", + "pt", + ), # "cêntimos" should be "centavos" num2words issue + # Polish + ("W 12,5 sekundy.", "W dwanaście przecinek pięć sekundy.", "pl"), + ("Było 50 żołnierzy.", "Było pięćdziesiąt żołnierzy.", "pl"), + ("To będzie kosztować 20€ panie.", "To będzie kosztować dwadzieścia euro panie.", "pl"), + ("To będzie kosztować 20,15€ panie.", "To będzie kosztować dwadzieścia euro, piętnaście centów panie.", "pl"), + # Arabic + ("في الـ 12,5 ثانية.", "في الـ اثنا عشر , خمسون ثانية.", "ar"), + ("كان هناك 50 جنديًا.", "كان هناك خمسون جنديًا.", "ar"), + # ("ستكون النتيجة $20 يا سيد.", 'ستكون النتيجة عشرون دولار يا سيد.', 'ar'), # $ and € are mising from num2words + # ("ستكون النتيجة 20€ يا سيد.", 'ستكون النتيجة عشرون يورو يا سيد.', 'ar'), + # Czech + ("Za 12,5 vteřiny.", "Za dvanáct celá pět vteřiny.", "cs"), + ("Bylo tam 50 vojáků.", "Bylo tam padesát vojáků.", "cs"), + ("To bude stát 20€ pane.", "To bude stát dvacet euro pane.", "cs"), + ("To bude 20.15€ pane.", "To bude dvacet euro, patnáct centů pane.", "cs"), + # Russian + ("Через 12.5 секунды.", "Через двенадцать запятая пять секунды.", "ru"), + ("Там было 50 солдат.", "Там было пятьдесят солдат.", "ru"), + ("Это будет 20.15€ сэр.", "Это будет двадцать евро, пятнадцать центов сэр.", "ru"), + ("Это будет стоить 20€ господин.", "Это будет стоить двадцать евро господин.", "ru"), + # Dutch + ("In 12,5 seconden.", "In twaalf komma vijf seconden.", "nl"), + ("Er waren 50 soldaten.", "Er waren vijftig soldaten.", "nl"), + ("Dat wordt dan $20 meneer.", "Dat wordt dan twintig dollar meneer.", "nl"), + ("Dat wordt dan 20€ meneer.", "Dat wordt dan twintig euro meneer.", "nl"), + # Chinese (Simplified) + ("在12.5秒内", "在十二点五秒内", "zh"), + ("有50名士兵", "有五十名士兵", "zh"), + # ("那将是$20先生", '那将是二十美元先生', 'zh'), currency doesn't work + # ("那将是20€先生", '那将是二十欧元先生', 'zh'), + # Turkish + # ("12,5 saniye içinde.", 'On iki virgül beş saniye içinde.', 'tr'), # decimal doesn't work for TR + 
("50 asker vardı.", "elli asker vardı.", "tr"), + ("Bu 1. test", "Bu birinci test", "tr"), + # ("Bu 100.000,5.", 'Bu yüz bin virgül beş.', 'tr'), + # Hungarian + ("12,5 másodperc alatt.", "tizenkettő egész öt tized másodperc alatt.", "hu"), + ("50 katona volt.", "ötven katona volt.", "hu"), + ("Ez az 1. teszt", "Ez az első teszt", "hu"), + # Korean + ("12.5 초 안에.", "십이 점 다섯 초 안에.", "ko"), + ("50 명의 병사가 있었다.", "오십 명의 병사가 있었다.", "ko"), + ("이것은 1 번째 테스트입니다", "이것은 첫 번째 테스트입니다", "ko"), + ] + for a, b, lang in test_cases: + out = expand_numbers_multilingual(a, lang=lang) + assert out == b, f"'{out}' vs '{b}'" + + +def test_abbreviations_multilingual(): + test_cases = [ + # English + ("Hello Mr. Smith.", "Hello mister Smith.", "en"), + ("Dr. Jones is here.", "doctor Jones is here.", "en"), + # Spanish + ("Hola Sr. Garcia.", "Hola señor Garcia.", "es"), + ("La Dra. Martinez es muy buena.", "La doctora Martinez es muy buena.", "es"), + # French + ("Bonjour Mr. Dupond.", "Bonjour monsieur Dupond.", "fr"), + ("Mme. Moreau est absente aujourd'hui.", "madame Moreau est absente aujourd'hui.", "fr"), + # German + ("Frau Dr. Müller ist sehr klug.", "Frau doktor Müller ist sehr klug.", "de"), + # Portuguese + ("Olá Sr. Silva.", "Olá senhor Silva.", "pt"), + ("Dra. Costa, você está disponível?", "doutora Costa, você está disponível?", "pt"), + # Italian + ("Buongiorno, Sig. Rossi.", "Buongiorno, signore Rossi.", "it"), + # ("Sig.ra Bianchi, posso aiutarti?", 'signora Bianchi, posso aiutarti?', 'it'), # Issue with matching that pattern + # Polish + ("Dzień dobry, P. Kowalski.", "Dzień dobry, pani Kowalski.", "pl"), + ("M. Nowak, czy mogę zadać pytanie?", "pan Nowak, czy mogę zadać pytanie?", "pl"), + # Czech + ("P. Novák", "pan Novák", "cs"), + ("Dr. Vojtěch", "doktor Vojtěch", "cs"), + # Dutch + ("Dhr. Jansen", "de heer Jansen", "nl"), + ("Mevr. de Vries", "mevrouw de Vries", "nl"), + # Russian + ("Здравствуйте Г-н Иванов.", "Здравствуйте господин Иванов.", "ru"), + ("Д-р Смирнов здесь, чтобы увидеть вас.", "доктор Смирнов здесь, чтобы увидеть вас.", "ru"), + # Turkish + ("Merhaba B. Yılmaz.", "Merhaba bay Yılmaz.", "tr"), + ("Dr. Ayşe burada.", "doktor Ayşe burada.", "tr"), + # Hungarian + ("Dr. 
Szabó itt van.", "doktor Szabó itt van.", "hu"), + ] + + for a, b, lang in test_cases: + out = expand_abbreviations_multilingual(a, lang=lang) + assert out == b, f"'{out}' vs '{b}'" + + +def test_symbols_multilingual(): + test_cases = [ + ("I have 14% battery", "I have 14 percent battery", "en"), + ("Te veo @ la fiesta", "Te veo arroba la fiesta", "es"), + ("J'ai 14° de fièvre", "J'ai 14 degrés de fièvre", "fr"), + ("Die Rechnung beträgt £ 20", "Die Rechnung beträgt pfund 20", "de"), + ("O meu email é ana&joao@gmail.com", "O meu email é ana e joao arroba gmail.com", "pt"), + ("linguaggio di programmazione C#", "linguaggio di programmazione C cancelletto", "it"), + ("Moja temperatura to 36.6°", "Moja temperatura to 36.6 stopnie", "pl"), + ("Mám 14% baterie", "Mám 14 procento baterie", "cs"), + ("Těším se na tebe @ party", "Těším se na tebe na party", "cs"), + ("У меня 14% заряда", "У меня 14 процентов заряда", "ru"), + ("Я буду @ дома", "Я буду собака дома", "ru"), + ("Ik heb 14% batterij", "Ik heb 14 procent batterij", "nl"), + ("Ik zie je @ het feest", "Ik zie je bij het feest", "nl"), + ("لدي 14% في البطارية", "لدي 14 في المئة في البطارية", "ar"), + ("我的电量为 14%", "我的电量为 14 百分之", "zh"), + ("Pilim %14 dolu.", "Pilim yüzde 14 dolu.", "tr"), + ("Az akkumulátorom töltöttsége 14%", "Az akkumulátorom töltöttsége 14 százalék", "hu"), + ("배터리 잔량이 14%입니다.", "배터리 잔량이 14 퍼센트입니다.", "ko"), + ] + + for a, b, lang in test_cases: + out = expand_symbols_multilingual(a, lang=lang) + assert out == b, f"'{out}' vs '{b}'" + + +if __name__ == "__main__": + test_expand_numbers_multilingual() + test_abbreviations_multilingual() + test_symbols_multilingual() diff --git a/viXTTS/TTS/tts/layers/xtts/trainer/dataset.py b/viXTTS/TTS/tts/layers/xtts/trainer/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2f958cb5a5a66e1b7714887f1784a549200e479b --- /dev/null +++ b/viXTTS/TTS/tts/layers/xtts/trainer/dataset.py @@ -0,0 +1,239 @@ +import os +import random +import sys + +import torch +import torch.nn.functional as F +import torch.utils.data + +from TTS.tts.models.xtts import load_audio + +torch.set_num_threads(1) + + +def key_samples_by_col(samples, col): + """Returns a dictionary of samples keyed by language.""" + samples_by_col = {} + for sample in samples: + col_val = sample[col] + assert isinstance(col_val, str) + if col_val not in samples_by_col: + samples_by_col[col_val] = [] + samples_by_col[col_val].append(sample) + return samples_by_col + + +def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate, is_eval=False): + rel_clip = load_audio(gt_path, sample_rate) + # if eval uses a middle size sample when it is possible to be more reproducible + if is_eval: + sample_length = int((min_sample_length + max_sample_length) / 2) + else: + sample_length = random.randint(min_sample_length, max_sample_length) + gap = rel_clip.shape[-1] - sample_length + if gap < 0: + sample_length = rel_clip.shape[-1] // 2 + gap = rel_clip.shape[-1] - sample_length + + # if eval start always from the position 0 to be more reproducible + if is_eval: + rand_start = 0 + else: + rand_start = random.randint(0, gap) + + rand_end = rand_start + sample_length + rel_clip = rel_clip[:, rand_start:rand_end] + rel_clip = F.pad(rel_clip, pad=(0, max_sample_length - rel_clip.shape[-1])) + cond_idxs = [rand_start, rand_end] + return rel_clip, rel_clip.shape[-1], cond_idxs + + +class XTTSDataset(torch.utils.data.Dataset): + def __init__(self, config, samples, tokenizer, sample_rate, is_eval=False): + 
self.config = config + model_args = config.model_args + self.failed_samples = set() + self.debug_failures = model_args.debug_loading_failures + self.max_conditioning_length = model_args.max_conditioning_length + self.min_conditioning_length = model_args.min_conditioning_length + self.is_eval = is_eval + self.tokenizer = tokenizer + self.sample_rate = sample_rate + self.max_wav_len = model_args.max_wav_length + self.max_text_len = model_args.max_text_length + self.use_masking_gt_prompt_approach = model_args.gpt_use_masking_gt_prompt_approach + assert self.max_wav_len is not None and self.max_text_len is not None + + self.samples = samples + if not is_eval: + random.seed(config.training_seed) + # random.shuffle(self.samples) + random.shuffle(self.samples) + # order by language + self.samples = key_samples_by_col(self.samples, "language") + print(" > Sampling by language:", self.samples.keys()) + else: + # for evaluation load and check samples that are corrupted to ensures the reproducibility + self.check_eval_samples() + + def check_eval_samples(self): + print(" > Filtering invalid eval samples!!") + new_samples = [] + for sample in self.samples: + try: + tseq, _, wav, _, _, _ = self.load_item(sample) + except: + continue + # Basically, this audio file is nonexistent or too long to be supported by the dataset. + if ( + wav is None + or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) + or (self.max_text_len is not None and tseq.shape[0] > self.max_text_len) + ): + continue + new_samples.append(sample) + self.samples = new_samples + print(" > Total eval samples after filtering:", len(self.samples)) + + def get_text(self, text, lang): + tokens = self.tokenizer.encode(text, lang) + tokens = torch.IntTensor(tokens) + assert not torch.any(tokens == 1), f"UNK token found in {text} -> {self.tokenizer.decode(tokens)}" + # The stop token should always be sacred. + assert not torch.any(tokens == 0), f"Stop token found in {text}" + return tokens + + def load_item(self, sample): + text = str(sample["text"]) + tseq = self.get_text(text, sample["language"]) + audiopath = sample["audio_file"] + wav = load_audio(audiopath, self.sample_rate) + if text is None or len(text.strip()) == 0: + raise ValueError + if wav is None or wav.shape[-1] < (0.5 * self.sample_rate): + # Ultra short clips are also useless (and can cause problems within some models). 
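+ # (0.5 * self.sample_rate is half a second of audio; anything shorter is rejected here, and the
+ # caller either skips the sample during eval filtering or records it in `failed_samples` during
+ # training, the same way unreadable files are handled.)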
+ raise ValueError + + if self.use_masking_gt_prompt_approach: + # get a slice from GT to condition the model + cond, _, cond_idxs = get_prompt_slice( + audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval + ) + # if use masking do not use cond_len + cond_len = torch.nan + else: + ref_sample = ( + sample["reference_path"] + if "reference_path" in sample and sample["reference_path"] is not None + else audiopath + ) + cond, cond_len, _ = get_prompt_slice( + ref_sample, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval + ) + # if do not use masking use cond_len + cond_idxs = torch.nan + + return tseq, audiopath, wav, cond, cond_len, cond_idxs + + def __getitem__(self, index): + if self.is_eval: + sample = self.samples[index] + sample_id = str(index) + else: + # select a random language + lang = random.choice(list(self.samples.keys())) + # select random sample + index = random.randint(0, len(self.samples[lang]) - 1) + sample = self.samples[lang][index] + # a unique id for each sampel to deal with fails + sample_id = lang + "_" + str(index) + + # ignore samples that we already know that is not valid ones + if sample_id in self.failed_samples: + if self.debug_failures: + print(f"Ignoring sample {sample['audio_file']} because it was already ignored before !!") + # call get item again to get other sample + return self[1] + + # try to load the sample, if fails added it to the failed samples list + try: + tseq, audiopath, wav, cond, cond_len, cond_idxs = self.load_item(sample) + except: + if self.debug_failures: + print(f"error loading {sample['audio_file']} {sys.exc_info()}") + self.failed_samples.add(sample_id) + return self[1] + + # check if the audio and text size limits and if it out of the limits, added it failed_samples + if ( + wav is None + or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) + or (self.max_text_len is not None and tseq.shape[0] > self.max_text_len) + ): + # Basically, this audio file is nonexistent or too long to be supported by the dataset. + # It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result. 
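+ # (The fallback below mirrors the load-failure branch above: record the bad sample id and return
+ # `self[1]` instead; in training mode the index is ignored and a fresh random sample is drawn,
+ # which keeps the DataLoader running at the cost of the slight skew noted in the comment above.)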
+ if self.debug_failures and wav is not None and tseq is not None: + print( + f"error loading {sample['audio_file']}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}" + ) + self.failed_samples.add(sample_id) + return self[1] + + res = { + # 'real_text': text, + "text": tseq, + "text_lengths": torch.tensor(tseq.shape[0], dtype=torch.long), + "wav": wav, + "wav_lengths": torch.tensor(wav.shape[-1], dtype=torch.long), + "filenames": audiopath, + "conditioning": cond.unsqueeze(1), + "cond_lens": torch.tensor(cond_len, dtype=torch.long) + if cond_len is not torch.nan + else torch.tensor([cond_len]), + "cond_idxs": torch.tensor(cond_idxs) if cond_idxs is not torch.nan else torch.tensor([cond_idxs]), + } + return res + + def __len__(self): + if self.is_eval: + return len(self.samples) + return sum([len(v) for v in self.samples.values()]) + + def collate_fn(self, batch): + # convert list of dicts to dict of lists + B = len(batch) + + batch = {k: [dic[k] for dic in batch] for k in batch[0]} + + # stack for features that already have the same shape + batch["wav_lengths"] = torch.stack(batch["wav_lengths"]) + batch["text_lengths"] = torch.stack(batch["text_lengths"]) + batch["conditioning"] = torch.stack(batch["conditioning"]) + batch["cond_lens"] = torch.stack(batch["cond_lens"]) + batch["cond_idxs"] = torch.stack(batch["cond_idxs"]) + + if torch.any(batch["cond_idxs"].isnan()): + batch["cond_idxs"] = None + + if torch.any(batch["cond_lens"].isnan()): + batch["cond_lens"] = None + + max_text_len = batch["text_lengths"].max() + max_wav_len = batch["wav_lengths"].max() + + # create padding tensors + text_padded = torch.IntTensor(B, max_text_len) + wav_padded = torch.FloatTensor(B, 1, max_wav_len) + + # initialize tensors for zero padding + text_padded = text_padded.zero_() + wav_padded = wav_padded.zero_() + for i in range(B): + text = batch["text"][i] + text_padded[i, : batch["text_lengths"][i]] = torch.IntTensor(text) + wav = batch["wav"][i] + wav_padded[i, :, : batch["wav_lengths"][i]] = torch.FloatTensor(wav) + + batch["wav"] = wav_padded + batch["padded_text"] = text_padded + return batch diff --git a/viXTTS/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/viXTTS/TTS/tts/layers/xtts/trainer/gpt_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..9a7a1d77835e87fa92b59b18e8d8439a50550d8b --- /dev/null +++ b/viXTTS/TTS/tts/layers/xtts/trainer/gpt_trainer.py @@ -0,0 +1,504 @@ +from dataclasses import dataclass, field +from typing import Dict, List, Tuple, Union + +import torch +import torch.nn as nn +import torchaudio +from coqpit import Coqpit +from torch.nn import functional as F +from torch.utils.data import DataLoader +from trainer.torch import DistributedSampler +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.tts.configs.xtts_config import XttsConfig +from TTS.tts.datasets.dataset import TTSDataset +from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram +from TTS.tts.layers.xtts.dvae import DiscreteVAE +from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer +from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig +from TTS.utils.io import load_fsspec + + +@dataclass +class GPTTrainerConfig(XttsConfig): + lr: float = 5e-06 + training_seed: int = 1 + optimizer_wd_only_on_weights: bool = False + weighted_loss_attrs: dict = field(default_factory=lambda: {}) + weighted_loss_multipliers: dict = 
field(default_factory=lambda: {}) + test_sentences: List[dict] = field(default_factory=lambda: []) + + +@dataclass +class XttsAudioConfig(XttsAudioConfig): + dvae_sample_rate: int = 22050 + + +@dataclass +class GPTArgs(XttsArgs): + min_conditioning_length: int = 66150 + max_conditioning_length: int = 132300 + gpt_loss_text_ce_weight: float = 0.01 + gpt_loss_mel_ce_weight: float = 1.0 + gpt_num_audio_tokens: int = 8194 + debug_loading_failures: bool = False + max_wav_length: int = 255995 # ~11.6 seconds + max_text_length: int = 200 + tokenizer_file: str = "" + mel_norm_file: str = "https://coqui.gateway.scarf.sh/v0.14.0_models/mel_norms.pth" + dvae_checkpoint: str = "" + xtts_checkpoint: str = "" + gpt_checkpoint: str = "" # if defined it will replace the gpt weights on xtts model + vocoder: str = "" # overide vocoder key on the config to avoid json write issues + + +def callback_clearml_load_save(operation_type, model_info): + # return None means skip the file upload/log, returning model_info will continue with the log/upload + # you can also change the upload destination file name model_info.upload_filename or check the local file size with Path(model_info.local_model_path).stat().st_size + assert operation_type in ("load", "save") + # print(operation_type, model_info.__dict__) + + if "similarities.pth" in model_info.__dict__["local_model_path"]: + return None + + return model_info + + +class GPTTrainer(BaseTTS): + def __init__(self, config: Coqpit): + """ + Tortoise GPT training class + """ + super().__init__(config, ap=None, tokenizer=None) + self.config = config + # init XTTS model + self.xtts = Xtts(self.config) + # create the tokenizer with the target vocabulary + self.xtts.tokenizer = VoiceBpeTokenizer(self.args.tokenizer_file) + # init gpt encoder and hifigan decoder + self.xtts.init_models() + + if self.args.xtts_checkpoint: + self.load_checkpoint(self.config, self.args.xtts_checkpoint, eval=False, strict=False) + + # set mel stats + if self.args.mel_norm_file: + self.xtts.mel_stats = load_fsspec(self.args.mel_norm_file) + + # load GPT if available + if self.args.gpt_checkpoint: + gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu")) + # deal with coqui Trainer exported model + if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys(): + print("Coqui Trainer checkpoint detected! Converting it!") + gpt_checkpoint = gpt_checkpoint["model"] + states_keys = list(gpt_checkpoint.keys()) + for key in states_keys: + if "gpt." 
in key: + new_key = key.replace("gpt.", "") + gpt_checkpoint[new_key] = gpt_checkpoint[key] + del gpt_checkpoint[key] + else: + del gpt_checkpoint[key] + + # edit checkpoint if the number of tokens is changed to ensures the better transfer learning possible + if ( + "text_embedding.weight" in gpt_checkpoint + and gpt_checkpoint["text_embedding.weight"].shape != self.xtts.gpt.text_embedding.weight.shape + ): + num_new_tokens = ( + self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0] + ) + print(f" > Loading checkpoint with {num_new_tokens} additional tokens.") + + # add new tokens to a linear layer (text_head) + emb_g = gpt_checkpoint["text_embedding.weight"] + new_row = torch.randn(num_new_tokens, emb_g.shape[1]) + start_token_row = emb_g[-1, :] + emb_g = torch.cat([emb_g, new_row], axis=0) + emb_g[-1, :] = start_token_row + gpt_checkpoint["text_embedding.weight"] = emb_g + + # add new weights to the linear layer (text_head) + text_head_weight = gpt_checkpoint["text_head.weight"] + start_token_row = text_head_weight[-1, :] + new_entry = torch.randn(num_new_tokens, self.xtts.gpt.text_head.weight.shape[1]) + text_head_weight = torch.cat([text_head_weight, new_entry], axis=0) + text_head_weight[-1, :] = start_token_row + gpt_checkpoint["text_head.weight"] = text_head_weight + + # add new biases to the linear layer (text_head) + text_head_bias = gpt_checkpoint["text_head.bias"] + start_token_row = text_head_bias[-1] + new_bias_entry = torch.zeros(num_new_tokens) + text_head_bias = torch.cat([text_head_bias, new_bias_entry], axis=0) + text_head_bias[-1] = start_token_row + gpt_checkpoint["text_head.bias"] = text_head_bias + + self.xtts.gpt.load_state_dict(gpt_checkpoint, strict=True) + print(">> GPT weights restored from:", self.args.gpt_checkpoint) + + # Mel spectrogram extractor for conditioning + if self.args.gpt_use_perceiver_resampler: + self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram( + filter_length=2048, + hop_length=256, + win_length=1024, + normalize=False, + sampling_rate=config.audio.sample_rate, + mel_fmin=0, + mel_fmax=8000, + n_mel_channels=80, + mel_norm_file=self.args.mel_norm_file, + ) + else: + self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram( + filter_length=4096, + hop_length=1024, + win_length=4096, + normalize=False, + sampling_rate=config.audio.sample_rate, + mel_fmin=0, + mel_fmax=8000, + n_mel_channels=80, + mel_norm_file=self.args.mel_norm_file, + ) + + # Load DVAE + self.dvae = DiscreteVAE( + channels=80, + normalization=None, + positional_dims=1, + num_tokens=self.args.gpt_num_audio_tokens - 2, + codebook_dim=512, + hidden_dim=512, + num_resnet_blocks=3, + kernel_size=3, + num_layers=2, + use_transposed_convs=False, + ) + + self.dvae.eval() + if self.args.dvae_checkpoint: + dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu")) + self.dvae.load_state_dict(dvae_checkpoint, strict=False) + print(">> DVAE weights restored from:", self.args.dvae_checkpoint) + else: + raise RuntimeError( + "You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!" 
+ ) + + # Mel spectrogram extractor for DVAE + self.torch_mel_spectrogram_dvae = TorchMelSpectrogram( + mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.dvae_sample_rate + ) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs, cond_lens): + """ + Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode + (actuated by `text_first`). + + text_inputs: long tensor, (b,t) + text_lengths: long tensor, (b,) + mel_inputs: long tensor, (b,m) + wav_lengths: long tensor, (b,) + cond_mels: MEL float tensor, (b, num_samples, 80,t_m) + cond_idxs: cond start and end indexs, (b, 2) + cond_lens: long tensor, (b,) + """ + losses = self.xtts.gpt( + text_inputs, + text_lengths, + audio_codes, + wav_lengths, + cond_mels=cond_mels, + cond_idxs=cond_idxs, + cond_lens=cond_lens, + ) + return losses + + @torch.no_grad() + def test_run(self, assets) -> Tuple[Dict, Dict]: # pylint: disable=W0613 + test_audios = {} + if self.config.test_sentences: + # init gpt for inference mode + self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False) + self.xtts.gpt.eval() + print(" | > Synthesizing test sentences.") + for idx, s_info in enumerate(self.config.test_sentences): + wav = self.xtts.synthesize( + s_info["text"], + self.config, + s_info["speaker_wav"], + s_info["language"], + gpt_cond_len=3, + )["wav"] + test_audios["{}-audio".format(idx)] = wav + + # delete inference layers + del self.xtts.gpt.gpt_inference + del self.xtts.gpt.gpt.wte + return {"audios": test_audios} + + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: + logger.test_audios(steps, outputs["audios"], self.args.output_sample_rate) + + def format_batch(self, batch: Dict) -> Dict: + return batch + + @torch.no_grad() # torch no grad to avoid gradients from the pre-processing and DVAE codes extraction + def format_batch_on_device(self, batch): + """Compute spectrograms on the device.""" + batch["text_lengths"] = batch["text_lengths"] + batch["wav_lengths"] = batch["wav_lengths"] + batch["text_inputs"] = batch["padded_text"] + batch["cond_idxs"] = batch["cond_idxs"] + # compute conditioning mel specs + # transform waves from torch.Size([B, num_cond_samples, 1, T] to torch.Size([B * num_cond_samples, 1, T] because if is faster than iterate the tensor + B, num_cond_samples, C, T = batch["conditioning"].size() + conditioning_reshaped = batch["conditioning"].view(B * num_cond_samples, C, T) + paired_conditioning_mel = self.torch_mel_spectrogram_style_encoder(conditioning_reshaped) + # transform torch.Size([B * num_cond_samples, n_mel, T_mel]) in torch.Size([B, num_cond_samples, n_mel, T_mel]) + n_mel = self.torch_mel_spectrogram_style_encoder.n_mel_channels # paired_conditioning_mel.size(1) + T_mel = paired_conditioning_mel.size(2) + paired_conditioning_mel = paired_conditioning_mel.view(B, num_cond_samples, n_mel, T_mel) + # get the conditioning embeddings + batch["cond_mels"] = paired_conditioning_mel + # compute codes using DVAE + if self.config.audio.sample_rate != self.config.audio.dvae_sample_rate: + dvae_wav = torchaudio.functional.resample( + batch["wav"], + orig_freq=self.config.audio.sample_rate, + new_freq=self.config.audio.dvae_sample_rate, + lowpass_filter_width=64, + rolloff=0.9475937167399596, + resampling_method="kaiser_window", + beta=14.769656459379492, + ) + 
else: + dvae_wav = batch["wav"] + dvae_mel_spec = self.torch_mel_spectrogram_dvae(dvae_wav) + codes = self.dvae.get_codebook_indices(dvae_mel_spec) + + batch["audio_codes"] = codes + # delete useless batch tensors + del batch["padded_text"] + del batch["wav"] + del batch["conditioning"] + return batch + + def train_step(self, batch, criterion): + loss_dict = {} + cond_mels = batch["cond_mels"] + text_inputs = batch["text_inputs"] + text_lengths = batch["text_lengths"] + audio_codes = batch["audio_codes"] + wav_lengths = batch["wav_lengths"] + cond_idxs = batch["cond_idxs"] + cond_lens = batch["cond_lens"] + + loss_text, loss_mel, _ = self.forward( + text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs, cond_lens + ) + loss_dict["loss_text_ce"] = loss_text * self.args.gpt_loss_text_ce_weight + loss_dict["loss_mel_ce"] = loss_mel * self.args.gpt_loss_mel_ce_weight + loss_dict["loss"] = loss_dict["loss_text_ce"] + loss_dict["loss_mel_ce"] + return {"model_outputs": None}, loss_dict + + def eval_step(self, batch, criterion): + # ignore masking for more consistent evaluation + batch["cond_idxs"] = None + return self.train_step(batch, criterion) + + def on_train_epoch_start(self, trainer): + trainer.model.eval() # the whole model to eval + # put gpt model in training mode + if hasattr(trainer.model, "module") and hasattr(trainer.model.module, "xtts"): + trainer.model.module.xtts.gpt.train() + else: + trainer.model.xtts.gpt.train() + + def on_init_end(self, trainer): # pylint: disable=W0613 + # ignore similarities.pth on clearml save/upload + if self.config.dashboard_logger.lower() == "clearml": + from clearml.binding.frameworks import WeightsFileHandler + + WeightsFileHandler.add_pre_callback(callback_clearml_load_save) + + @torch.no_grad() + def inference( + self, + x, + aux_input=None, + ): # pylint: disable=dangerous-default-value + return None + + @staticmethod + def get_criterion(): + return None + + def get_sampler(self, dataset: TTSDataset, num_gpus=1): + # sampler for DDP + batch_sampler = DistributedSampler(dataset) if num_gpus > 1 else None + return batch_sampler + + def get_data_loader( + self, + config: Coqpit, + assets: Dict, + is_eval: bool, + samples: Union[List[Dict], List[List]], + verbose: bool, + num_gpus: int, + rank: int = None, + ) -> "DataLoader": # pylint: disable=W0613 + if is_eval and not config.run_eval: + loader = None + else: + # init dataloader + dataset = XTTSDataset(self.config, samples, self.xtts.tokenizer, config.audio.sample_rate, is_eval) + + # wait all the DDP process to be ready + if num_gpus > 1: + torch.distributed.barrier() + + # sort input sequences from short to long + # dataset.preprocess_samples() + + # get samplers + sampler = self.get_sampler(dataset, num_gpus) + + # ignore sampler when is eval because if we changed the sampler parameter we will not be able to compare previous runs + if sampler is None or is_eval: + loader = DataLoader( + dataset, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + shuffle=False, + drop_last=False, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + else: + loader = DataLoader( + dataset, + sampler=sampler, + batch_size = config.eval_batch_size if is_eval else config.batch_size, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + return loader + + def get_optimizer(self) -> List: + 
"""Initiate and return the optimizer based on the config parameters.""" + # ToDo: deal with multi GPU training + if self.config.optimizer_wd_only_on_weights: + # parameters to only GPT model + net = self.xtts.gpt + + # normalizations + norm_modules = ( + nn.BatchNorm2d, + nn.InstanceNorm2d, + nn.BatchNorm1d, + nn.InstanceNorm1d, + nn.BatchNorm3d, + nn.InstanceNorm3d, + nn.GroupNorm, + nn.LayerNorm, + ) + # nn.Embedding + emb_modules = (nn.Embedding, nn.EmbeddingBag) + + param_names_notweights = set() + all_param_names = set() + param_map = {} + for mn, m in net.named_modules(): + for k, v in m.named_parameters(): + v.is_bias = k.endswith(".bias") + v.is_weight = k.endswith(".weight") + v.is_norm = isinstance(m, norm_modules) + v.is_emb = isinstance(m, emb_modules) + + fpn = "%s.%s" % (mn, k) if mn else k # full param name + all_param_names.add(fpn) + param_map[fpn] = v + if v.is_bias or v.is_norm or v.is_emb: + param_names_notweights.add(fpn) + + params_names_notweights = sorted(list(param_names_notweights)) + params_notweights = [param_map[k] for k in params_names_notweights] + params_names_weights = sorted(list(all_param_names ^ param_names_notweights)) + params_weights = [param_map[k] for k in params_names_weights] + + groups = [ + {"params": params_weights, "weight_decay": self.config.optimizer_params["weight_decay"]}, + {"params": params_notweights, "weight_decay": 0}, + ] + # torch.optim.AdamW + opt = get_optimizer( + self.config.optimizer, + self.config.optimizer_params, + self.config.lr, + parameters=groups, + ) + opt._group_names = [params_names_weights, params_names_notweights] + return opt + + return get_optimizer( + self.config.optimizer, + self.config.optimizer_params, + self.config.lr, + # optimize only for the GPT model + parameters=self.xtts.gpt.parameters(), + ) + + def get_scheduler(self, optimizer) -> List: + """Set the scheduler for the optimizer. + + Args: + optimizer: `torch.optim.Optimizer`. + """ + return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer) + + def load_checkpoint( + self, + config, + checkpoint_path, + eval=False, + strict=True, + cache_storage="/tmp/tts_cache", + target_protocol="s3", + target_options={"anon": True}, + ): # pylint: disable=unused-argument, disable=W0201, disable=W0102, redefined-builtin + """Load the model checkpoint and setup for training or inference""" + + state = self.xtts.get_compatible_checkpoint_state_dict(checkpoint_path) + + # load the model weights + self.xtts.load_state_dict(state, strict=strict) + + if eval: + self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False) + self.eval() + assert not self.training + + @staticmethod + def init_from_config(config: "GPTTrainerConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (GPTTrainerConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. 
+ """ + return GPTTrainer(config) diff --git a/viXTTS/TTS/tts/layers/xtts/xtts_manager.py b/viXTTS/TTS/tts/layers/xtts/xtts_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..3e7d0f6c914fa3a4e706a5e28bbd745afcaa4d67 --- /dev/null +++ b/viXTTS/TTS/tts/layers/xtts/xtts_manager.py @@ -0,0 +1,34 @@ +import torch + +class SpeakerManager(): + def __init__(self, speaker_file_path=None): + self.speakers = torch.load(speaker_file_path) + + @property + def name_to_id(self): + return self.speakers.keys() + + @property + def num_speakers(self): + return len(self.name_to_id) + + @property + def speaker_names(self): + return list(self.name_to_id.keys()) + + +class LanguageManager(): + def __init__(self, config): + self.langs = config["languages"] + + @property + def name_to_id(self): + return self.langs + + @property + def num_languages(self): + return len(self.name_to_id) + + @property + def language_names(self): + return list(self.name_to_id) diff --git a/viXTTS/TTS/tts/layers/xtts/zh_num2words.py b/viXTTS/TTS/tts/layers/xtts/zh_num2words.py new file mode 100644 index 0000000000000000000000000000000000000000..e59ccb66309aaebf67f4db972fc83058421d5ed8 --- /dev/null +++ b/viXTTS/TTS/tts/layers/xtts/zh_num2words.py @@ -0,0 +1,1209 @@ +# Authors: +# 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git) +# 2019.9 - 2022 Jiayu DU + +import argparse +import csv +import os +import re +import string +import sys + +# fmt: off + +# ================================================================================ # +# basic constant +# ================================================================================ # +CHINESE_DIGIS = "零一二三四五六七八九" +BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖" +BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖" +SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万" +SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬" +LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载" +LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載" +SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万" +SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬" + +ZERO_ALT = "〇" +ONE_ALT = "幺" +TWO_ALTS = ["两", "兩"] + +POSITIVE = ["正", "正"] +NEGATIVE = ["负", "負"] +POINT = ["点", "點"] +# PLUS = [u'加', u'加'] +# SIL = [u'杠', u'槓'] + +FILLER_CHARS = ["呃", "啊"] + +ER_WHITELIST = ( + "(儿女|儿子|儿孙|女儿|儿媳|妻儿|" + "胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|" + "儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|" + "佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)" +) +ER_WHITELIST_PATTERN = re.compile(ER_WHITELIST) + +# 中文数字系统类型 +NUMBERING_TYPES = ["low", "mid", "high"] + +CURRENCY_NAMES = "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|" "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)" +CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)" +COM_QUANTIFIERS = ( + "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|" + "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|" + "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|" + "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|" + "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|" + "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)" +) + + +# Punctuation information are based on Zhon project (https://github.com/tsroten/zhon.git) +CN_PUNCS_STOP = "!?。。" +CN_PUNCS_NONSTOP = ""#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏·〈〉-" +CN_PUNCS = CN_PUNCS_STOP + CN_PUNCS_NONSTOP + +PUNCS = CN_PUNCS + string.punctuation +PUNCS_TRANSFORM = str.maketrans(PUNCS, "," * 
len(PUNCS), "") # replace puncs with English comma + + +# https://zh.wikipedia.org/wiki/全行和半行 +QJ2BJ = { + " ": " ", + "!": "!", + """: '"', + "#": "#", + "$": "$", + "%": "%", + "&": "&", + "'": "'", + "(": "(", + ")": ")", + "*": "*", + "+": "+", + ",": ",", + "-": "-", + ".": ".", + "/": "/", + "0": "0", + "1": "1", + "2": "2", + "3": "3", + "4": "4", + "5": "5", + "6": "6", + "7": "7", + "8": "8", + "9": "9", + ":": ":", + ";": ";", + "<": "<", + "=": "=", + ">": ">", + "?": "?", + "@": "@", + "A": "A", + "B": "B", + "C": "C", + "D": "D", + "E": "E", + "F": "F", + "G": "G", + "H": "H", + "I": "I", + "J": "J", + "K": "K", + "L": "L", + "M": "M", + "N": "N", + "O": "O", + "P": "P", + "Q": "Q", + "R": "R", + "S": "S", + "T": "T", + "U": "U", + "V": "V", + "W": "W", + "X": "X", + "Y": "Y", + "Z": "Z", + "[": "[", + "\": "\\", + "]": "]", + "^": "^", + "_": "_", + "`": "`", + "a": "a", + "b": "b", + "c": "c", + "d": "d", + "e": "e", + "f": "f", + "g": "g", + "h": "h", + "i": "i", + "j": "j", + "k": "k", + "l": "l", + "m": "m", + "n": "n", + "o": "o", + "p": "p", + "q": "q", + "r": "r", + "s": "s", + "t": "t", + "u": "u", + "v": "v", + "w": "w", + "x": "x", + "y": "y", + "z": "z", + "{": "{", + "|": "|", + "}": "}", + "~": "~", +} +QJ2BJ_TRANSFORM = str.maketrans("".join(QJ2BJ.keys()), "".join(QJ2BJ.values()), "") + + +# 2013 China National Standard: https://zh.wikipedia.org/wiki/通用规范汉字表, raw resources: +# https://github.com/mozillazg/pinyin-data/blob/master/kMandarin_8105.txt with 8105 chinese chars in total +CN_CHARS_COMMON = ( + "一丁七万丈三上下不与丏丐丑专且丕世丘丙业丛东丝丞丢两严丧个丫中丰串临丸丹为主丽举" + "乂乃久么义之乌乍乎乏乐乒乓乔乖乘乙乜九乞也习乡书乩买乱乳乸乾了予争事二亍于亏云互" + "亓五井亘亚些亟亡亢交亥亦产亨亩享京亭亮亲亳亵亶亸亹人亿什仁仂仃仄仅仆仇仉今介仍从" + "仑仓仔仕他仗付仙仝仞仟仡代令以仨仪仫们仰仲仳仵件价任份仿企伈伉伊伋伍伎伏伐休众优" + "伙会伛伞伟传伢伣伤伥伦伧伪伫伭伯估伲伴伶伸伺似伽伾佁佃但位低住佐佑体何佖佗佘余佚" + "佛作佝佞佟你佣佤佥佩佬佯佰佳佴佶佸佺佻佼佽佾使侁侂侃侄侈侉例侍侏侑侔侗侘供依侠侣" + "侥侦侧侨侩侪侬侮侯侴侵侹便促俄俅俊俍俎俏俐俑俗俘俙俚俜保俞俟信俣俦俨俩俪俫俭修俯" + "俱俳俵俶俸俺俾倌倍倏倒倓倔倕倘候倚倜倞借倡倥倦倧倨倩倪倬倭倮倴债倻值倾偁偃假偈偌" + "偎偏偓偕做停偡健偬偭偰偲偶偷偻偾偿傀傃傅傈傉傍傒傕傣傥傧储傩催傲傺傻僇僎像僔僖僚" + "僦僧僬僭僮僰僳僵僻儆儇儋儒儡儦儳儴儿兀允元兄充兆先光克免兑兔兕兖党兜兢入全八公六" + "兮兰共关兴兵其具典兹养兼兽冀冁内冈冉册再冏冒冔冕冗写军农冠冢冤冥冬冮冯冰冱冲决况" + "冶冷冻冼冽净凄准凇凉凋凌减凑凓凘凛凝几凡凤凫凭凯凰凳凶凸凹出击凼函凿刀刁刃分切刈" + "刊刍刎刑划刖列刘则刚创初删判刨利别刬刭刮到刳制刷券刹刺刻刽刿剀剁剂剃剅削剋剌前剐" + "剑剔剕剖剜剞剟剡剥剧剩剪副割剽剿劁劂劄劈劐劓力劝办功加务劢劣动助努劫劬劭励劲劳劼" + "劾势勃勇勉勋勍勐勒勔勖勘勚募勠勤勰勺勾勿匀包匆匈匍匏匐匕化北匙匜匝匠匡匣匦匪匮匹" + "区医匼匾匿十千卅升午卉半华协卑卒卓单卖南博卜卞卟占卡卢卣卤卦卧卫卬卮卯印危即却卵" + "卷卸卺卿厂厄厅历厉压厌厍厕厖厘厚厝原厢厣厥厦厨厩厮去厾县叁参叆叇又叉及友双反发叔" + "叕取受变叙叚叛叟叠口古句另叨叩只叫召叭叮可台叱史右叵叶号司叹叻叼叽吁吃各吆合吉吊" + "同名后吏吐向吒吓吕吖吗君吝吞吟吠吡吣否吧吨吩含听吭吮启吱吲吴吵吸吹吻吼吽吾呀呃呆" + "呇呈告呋呐呒呓呔呕呖呗员呙呛呜呢呣呤呦周呱呲味呵呶呷呸呻呼命咀咂咄咆咇咉咋和咍咎" + "咏咐咒咔咕咖咙咚咛咝咡咣咤咥咦咧咨咩咪咫咬咯咱咳咴咸咺咻咽咿哀品哂哃哄哆哇哈哉哌" + "响哎哏哐哑哒哓哔哕哗哙哚哝哞哟哢哥哦哧哨哩哪哭哮哱哲哳哺哼哽哿唁唆唇唉唏唐唑唔唛" + "唝唠唢唣唤唧唪唬售唯唰唱唳唵唷唼唾唿啁啃啄商啉啊啐啕啖啜啡啤啥啦啧啪啫啬啭啮啰啴" + "啵啶啷啸啻啼啾喀喁喂喃善喆喇喈喉喊喋喏喑喔喘喙喜喝喟喤喧喱喳喵喷喹喻喽喾嗄嗅嗉嗌" + "嗍嗐嗑嗒嗓嗔嗖嗜嗝嗞嗟嗡嗣嗤嗥嗦嗨嗪嗫嗬嗯嗲嗳嗵嗷嗽嗾嘀嘁嘈嘉嘌嘎嘏嘘嘚嘛嘞嘟嘡" + "嘣嘤嘧嘬嘭嘱嘲嘴嘶嘹嘻嘿噀噂噇噌噍噎噔噗噘噙噜噢噤器噩噪噫噬噱噶噻噼嚄嚅嚆嚎嚏嚓" + "嚚嚣嚭嚯嚷嚼囊囔囚四回囟因囡团囤囫园困囱围囵囷囹固国图囿圃圄圆圈圉圊圌圐圙圜土圢" + "圣在圩圪圫圬圭圮圯地圲圳圹场圻圾址坂均坉坊坋坌坍坎坏坐坑坒块坚坛坜坝坞坟坠坡坤坥" + "坦坨坩坪坫坬坭坯坰坳坷坻坼坽垂垃垄垆垈型垌垍垎垏垒垓垕垙垚垛垞垟垠垡垢垣垤垦垧垩" + "垫垭垮垯垱垲垴垵垸垺垾垿埂埃埆埇埋埌城埏埒埔埕埗埘埙埚埝域埠埤埪埫埭埯埴埵埸培基" + "埼埽堂堃堆堇堉堋堌堍堎堐堑堕堙堞堠堡堤堧堨堪堰堲堵堼堽堾塄塅塆塌塍塑塔塘塝塞塥填" + "塬塱塾墀墁境墅墈墉墐墒墓墕墘墙墚增墟墡墣墦墨墩墼壁壅壑壕壤士壬壮声壳壶壸壹处备复" + "夏夐夔夕外夙多夜够夤夥大天太夫夬夭央夯失头夷夸夹夺夼奁奂奄奇奈奉奋奎奏契奓奔奕奖" + "套奘奚奠奡奢奥奭女奴奶奸她好妁如妃妄妆妇妈妊妍妒妓妖妗妘妙妞妣妤妥妧妨妩妪妫妭妮" + "妯妲妹妻妾姆姈姊始姐姑姒姓委姗姘姚姜姝姞姣姤姥姨姬姮姱姶姹姻姽姿娀威娃娄娅娆娇娈" + "娉娌娑娓娘娜娟娠娣娥娩娱娲娴娵娶娼婀婆婉婊婌婍婕婘婚婞婠婢婤婧婪婫婳婴婵婶婷婺婻" + "婼婿媂媄媆媒媓媖媚媛媞媪媭媱媲媳媵媸媾嫁嫂嫄嫉嫌嫒嫔嫕嫖嫘嫚嫜嫠嫡嫣嫦嫩嫪嫫嫭嫱" + "嫽嬉嬖嬗嬛嬥嬬嬴嬷嬿孀孅子孑孓孔孕孖字存孙孚孛孜孝孟孢季孤孥学孩孪孬孰孱孳孵孺孽" + "宁它宄宅宇守安宋完宏宓宕宗官宙定宛宜宝实宠审客宣室宥宦宧宪宫宬宰害宴宵家宸容宽宾" + "宿寁寂寄寅密寇富寐寒寓寝寞察寡寤寥寨寮寰寸对寺寻导寿封射将尉尊小少尔尕尖尘尚尜尝" + "尢尤尥尧尨尪尬就尴尸尹尺尻尼尽尾尿局屁层屃居屈屉届屋屎屏屐屑展屙属屠屡屣履屦屯山" + "屹屺屼屾屿岁岂岈岊岌岍岐岑岔岖岗岘岙岚岛岜岞岠岢岣岨岩岫岬岭岱岳岵岷岸岽岿峁峂峃" + "峄峋峒峗峘峙峛峡峣峤峥峦峧峨峪峭峰峱峻峿崀崁崂崃崄崆崇崌崎崒崔崖崚崛崞崟崡崤崦崧" + "崩崭崮崴崶崽崾崿嵁嵅嵇嵊嵋嵌嵎嵖嵘嵚嵛嵝嵩嵫嵬嵯嵲嵴嶂嶅嶍嶒嶓嶙嶝嶟嶦嶲嶷巅巇巉" + 
"巍川州巡巢工左巧巨巩巫差巯己已巳巴巷巽巾币市布帅帆师希帏帐帑帔帕帖帘帙帚帛帜帝帡" + "带帧帨席帮帱帷常帻帼帽幂幄幅幌幔幕幖幛幞幡幢幪干平年并幸幺幻幼幽广庄庆庇床庋序庐" + "庑库应底庖店庙庚府庞废庠庤庥度座庭庱庳庵庶康庸庹庼庾廆廉廊廋廑廒廓廖廙廛廨廪延廷" + "建廿开弁异弃弄弆弇弈弊弋式弑弓引弗弘弛弟张弢弥弦弧弨弩弭弯弱弶弸弹强弼彀归当录彖" + "彗彘彝彟形彤彦彧彩彪彬彭彰影彳彷役彻彼往征徂径待徇很徉徊律徐徒徕得徘徙徛徜御徨循" + "徭微徵德徼徽心必忆忉忌忍忏忐忑忒忖志忘忙忝忞忠忡忤忧忪快忭忮忱忳念忸忺忻忽忾忿怀" + "态怂怃怄怅怆怊怍怎怏怒怔怕怖怙怛怜思怠怡急怦性怨怩怪怫怯怵总怼怿恁恂恃恋恍恐恒恓" + "恔恕恙恚恝恢恣恤恧恨恩恪恫恬恭息恰恳恶恸恹恺恻恼恽恿悃悄悆悈悉悌悍悒悔悖悚悛悝悟" + "悠悢患悦您悫悬悭悯悰悱悲悴悸悻悼情惆惇惊惋惎惑惔惕惘惙惚惛惜惝惟惠惦惧惨惩惫惬惭" + "惮惯惰想惴惶惹惺愀愁愃愆愈愉愍愎意愐愔愕愚感愠愣愤愦愧愫愭愿慆慈慊慌慎慑慕慝慢慥" + "慧慨慬慭慰慵慷憋憎憔憕憙憧憨憩憬憭憷憺憾懂懈懊懋懑懒懔懦懵懿戆戈戊戋戌戍戎戏成我" + "戒戕或戗战戚戛戟戡戢戣戤戥截戬戭戮戳戴户戽戾房所扁扂扃扅扆扇扈扉扊手才扎扑扒打扔" + "托扛扞扣扦执扩扪扫扬扭扮扯扰扳扶批扺扼扽找承技抃抄抉把抑抒抓抔投抖抗折抚抛抟抠抡" + "抢护报抨披抬抱抵抹抻押抽抿拂拃拄担拆拇拈拉拊拌拍拎拐拒拓拔拖拗拘拙招拜拟拢拣拤拥" + "拦拧拨择括拭拮拯拱拳拴拶拷拼拽拾拿持挂指挈按挎挑挓挖挚挛挝挞挟挠挡挣挤挥挦挨挪挫" + "振挲挹挺挽捂捃捅捆捉捋捌捍捎捏捐捕捞损捡换捣捧捩捭据捯捶捷捺捻捽掀掂掇授掉掊掌掎" + "掏掐排掖掘掞掠探掣接控推掩措掬掭掮掰掳掴掷掸掺掼掾揄揆揉揍描提插揕揖揠握揣揩揪揭" + "揳援揶揸揽揿搀搁搂搅搋搌搏搐搒搓搔搛搜搞搠搡搦搪搬搭搴携搽摁摄摅摆摇摈摊摏摒摔摘" + "摛摞摧摩摭摴摸摹摽撂撄撅撇撑撒撕撖撙撞撤撩撬播撮撰撵撷撸撺撼擀擂擅操擎擐擒擘擞擢" + "擤擦擿攀攉攒攘攥攫攮支收攸改攻攽放政故效敉敌敏救敔敕敖教敛敝敞敢散敦敩敫敬数敲整" + "敷文斋斌斐斑斓斗料斛斜斝斟斠斡斤斥斧斩斫断斯新斶方於施旁旃旄旅旆旋旌旎族旐旒旖旗" + "旞无既日旦旧旨早旬旭旮旯旰旱旴旵时旷旸旺旻旿昀昂昃昄昆昇昈昉昊昌明昏昒易昔昕昙昝" + "星映昡昣昤春昧昨昪昫昭是昱昳昴昵昶昺昼昽显晁晃晅晊晋晌晏晐晒晓晔晕晖晗晙晚晞晟晡" + "晢晤晦晨晪晫普景晰晱晴晶晷智晾暂暄暅暇暌暑暕暖暗暝暧暨暮暲暴暵暶暹暾暿曈曌曙曛曜" + "曝曦曩曰曲曳更曷曹曼曾替最月有朋服朏朐朓朔朕朗望朝期朦木未末本札术朱朳朴朵朸机朽" + "杀杂权杄杆杈杉杌李杏材村杓杕杖杙杜杞束杠条来杧杨杩杪杭杯杰杲杳杵杷杻杼松板极构枅" + "枇枉枋枍析枕林枘枚果枝枞枢枣枥枧枨枪枫枭枯枰枲枳枵架枷枸枹柁柃柄柈柊柏某柑柒染柔" + "柖柘柙柚柜柝柞柠柢查柩柬柯柰柱柳柴柷柽柿栀栅标栈栉栊栋栌栎栏栐树栒栓栖栗栝栟校栩" + "株栲栳栴样核根栻格栽栾桀桁桂桃桄桅框案桉桊桌桎桐桑桓桔桕桠桡桢档桤桥桦桧桨桩桫桯" + "桲桴桶桷桹梁梃梅梆梌梏梓梗梠梢梣梦梧梨梭梯械梳梴梵梼梽梾梿检棁棂棉棋棍棐棒棓棕棘" + "棚棠棣棤棨棪棫棬森棰棱棵棹棺棻棼棽椀椁椅椆椋植椎椐椑椒椓椟椠椤椪椭椰椴椸椹椽椿楂" + "楒楔楗楙楚楝楞楠楣楦楩楪楫楮楯楷楸楹楼概榃榄榅榆榇榈榉榍榑榔榕榖榛榜榧榨榫榭榰榱" + "榴榷榻槁槃槊槌槎槐槔槚槛槜槟槠槭槱槲槽槿樊樗樘樟模樨横樯樱樵樽樾橄橇橐橑橘橙橛橞" + "橡橥橦橱橹橼檀檄檎檐檑檗檞檠檩檫檬櫆欂欠次欢欣欤欧欲欸欹欺欻款歃歅歆歇歉歌歙止正" + "此步武歧歪歹死歼殁殂殃殄殆殇殉殊残殍殒殓殖殚殛殡殣殪殳殴段殷殿毁毂毅毋毌母每毐毒" + "毓比毕毖毗毙毛毡毪毫毯毳毵毹毽氅氆氇氍氏氐民氓气氕氖氘氙氚氛氟氡氢氤氦氧氨氩氪氮" + "氯氰氲水永氾氿汀汁求汆汇汈汉汊汋汐汔汕汗汛汜汝汞江池污汤汧汨汩汪汫汭汰汲汴汶汹汽" + "汾沁沂沃沄沅沆沇沈沉沌沏沐沓沔沘沙沚沛沟没沣沤沥沦沧沨沩沪沫沭沮沱河沸油沺治沼沽" + "沾沿泂泃泄泅泇泉泊泌泐泓泔法泖泗泙泚泛泜泞泠泡波泣泥注泪泫泮泯泰泱泳泵泷泸泺泻泼" + "泽泾洁洄洇洈洋洌洎洑洒洓洗洘洙洚洛洞洢洣津洧洨洪洫洭洮洱洲洳洴洵洸洹洺活洼洽派洿" + "流浃浅浆浇浈浉浊测浍济浏浐浑浒浓浔浕浙浚浛浜浞浟浠浡浣浥浦浩浪浬浭浮浯浰浲浴海浸" + "浼涂涄涅消涉涌涍涎涐涑涓涔涕涘涛涝涞涟涠涡涢涣涤润涧涨涩涪涫涮涯液涴涵涸涿淀淄淅" + "淆淇淋淌淏淑淖淘淙淜淝淞淟淠淡淤淦淫淬淮淯深淳淴混淹添淼清渊渌渍渎渐渑渔渗渚渝渟" + "渠渡渣渤渥温渫渭港渰渲渴游渺渼湃湄湉湍湎湑湓湔湖湘湛湜湝湟湣湫湮湲湴湾湿溁溃溅溆" + "溇溉溍溏源溘溚溜溞溟溠溢溥溦溧溪溯溱溲溴溵溶溷溹溺溻溽滁滂滃滆滇滉滋滍滏滑滓滔滕" + "滗滘滚滞滟滠满滢滤滥滦滧滨滩滪滫滴滹漂漆漈漉漋漏漓演漕漖漠漤漦漩漪漫漭漯漱漳漴漶" + "漷漹漻漼漾潆潇潋潍潏潖潘潜潞潟潢潦潩潭潮潲潴潵潸潺潼潽潾澂澄澈澉澌澍澎澛澜澡澥澧" + "澪澭澳澴澶澹澼澽激濂濉濋濑濒濞濠濡濩濮濯瀌瀍瀑瀔瀚瀛瀣瀱瀵瀹瀼灈灌灏灞火灭灯灰灵" + "灶灸灼灾灿炀炅炆炉炊炌炎炒炔炕炖炘炙炜炝炟炣炫炬炭炮炯炱炳炷炸点炻炼炽烀烁烂烃烈" + "烊烔烘烙烛烜烝烟烠烤烦烧烨烩烫烬热烯烶烷烹烺烻烽焆焉焊焌焐焓焕焖焗焘焙焚焜焞焦焯" + "焰焱然煁煃煅煊煋煌煎煓煜煞煟煤煦照煨煮煲煳煴煸煺煽熄熇熊熏熔熘熙熛熜熟熠熥熨熬熵" + "熹熻燃燊燋燎燏燔燕燚燠燥燧燮燹爆爇爔爚爝爟爨爪爬爰爱爵父爷爸爹爻爽爿牁牂片版牌牍" + "牒牖牙牚牛牝牟牡牢牤牥牦牧物牮牯牲牵特牺牻牾牿犀犁犄犇犊犋犍犏犒犟犨犬犯犰犴状犷" + "犸犹狁狂狃狄狈狉狍狎狐狒狗狙狝狞狠狡狨狩独狭狮狯狰狱狲狳狴狷狸狺狻狼猁猃猄猇猊猎" + "猕猖猗猛猜猝猞猡猢猥猩猪猫猬献猯猰猱猴猷猹猺猾猿獍獐獒獗獠獬獭獯獴獾玃玄率玉王玎" + "玑玒玓玕玖玘玙玚玛玞玟玠玡玢玤玥玦玩玫玭玮环现玱玲玳玶玷玹玺玻玼玿珀珂珅珇珈珉珊" + "珋珌珍珏珐珑珒珕珖珙珛珝珞珠珢珣珥珦珧珩珪珫班珰珲珵珷珸珹珺珽琀球琄琅理琇琈琉琊" + "琎琏琐琔琚琛琟琡琢琤琥琦琨琪琫琬琭琮琯琰琲琳琴琵琶琼瑀瑁瑂瑃瑄瑅瑆瑑瑓瑔瑕瑖瑗瑙" + "瑚瑛瑜瑝瑞瑟瑢瑧瑨瑬瑭瑰瑱瑳瑶瑷瑾璀璁璃璆璇璈璋璎璐璒璘璜璞璟璠璥璧璨璩璪璬璮璱" + "璲璺瓀瓒瓖瓘瓜瓞瓠瓢瓣瓤瓦瓮瓯瓴瓶瓷瓻瓿甄甍甏甑甓甗甘甚甜生甡甥甦用甩甪甫甬甭甯" + "田由甲申电男甸町画甾畀畅畈畋界畎畏畔畖留畚畛畜畤略畦番畬畯畲畴畸畹畿疁疃疆疍疏疐" + "疑疔疖疗疙疚疝疟疠疡疢疣疤疥疫疬疭疮疯疰疱疲疳疴疵疸疹疼疽疾痂痃痄病症痈痉痊痍痒" + "痓痔痕痘痛痞痢痣痤痦痧痨痪痫痰痱痴痹痼痿瘀瘁瘃瘅瘆瘊瘌瘐瘕瘗瘘瘙瘛瘟瘠瘢瘤瘥瘦瘩" + "瘪瘫瘭瘰瘳瘴瘵瘸瘼瘾瘿癀癃癌癍癔癖癗癜癞癣癫癯癸登白百癿皂的皆皇皈皋皎皑皓皕皖皙" + "皛皞皤皦皭皮皱皲皴皿盂盅盆盈盉益盍盎盏盐监盒盔盖盗盘盛盟盥盦目盯盱盲直盷相盹盼盾" + "省眄眇眈眉眊看眍眙眚真眠眢眦眨眩眬眭眯眵眶眷眸眺眼着睁睃睄睇睎睐睑睚睛睡睢督睥睦" + "睨睫睬睹睽睾睿瞀瞄瞅瞋瞌瞍瞎瞑瞒瞟瞠瞢瞥瞧瞩瞪瞫瞬瞭瞰瞳瞵瞻瞽瞿矍矗矛矜矞矢矣知" + "矧矩矫矬短矮矰石矶矸矻矼矾矿砀码砂砄砆砉砌砍砑砒研砖砗砘砚砜砝砟砠砣砥砧砫砬砭砮" + "砰破砵砷砸砹砺砻砼砾础硁硅硇硊硌硍硎硐硒硔硕硖硗硙硚硝硪硫硬硭确硼硿碃碇碈碉碌碍" + "碎碏碑碓碗碘碚碛碜碟碡碣碥碧碨碰碱碲碳碴碶碹碾磁磅磉磊磋磏磐磔磕磙磜磡磨磬磲磴磷" + "磹磻礁礅礌礓礞礴礵示礼社祀祁祃祆祇祈祉祊祋祎祏祐祓祕祖祗祚祛祜祝神祟祠祢祥祧票祭" + "祯祲祷祸祺祼祾禀禁禄禅禊禋福禒禔禘禚禛禤禧禳禹禺离禽禾秀私秃秆秉秋种科秒秕秘租秣" + "秤秦秧秩秫秬秭积称秸移秽秾稀稂稃稆程稌稍税稑稔稗稙稚稞稠稣稳稷稹稻稼稽稿穄穆穑穗" + "穙穜穟穰穴究穷穸穹空穿窀突窃窄窅窈窊窍窎窑窒窕窖窗窘窜窝窟窠窣窥窦窨窬窭窳窸窿立" + "竑竖竘站竞竟章竣童竦竫竭端竹竺竽竿笃笄笆笈笊笋笏笑笔笕笙笛笞笠笤笥符笨笪笫第笮笯" + "笱笳笸笺笼笾筀筅筇等筋筌筏筐筑筒答策筘筚筛筜筝筠筢筤筥筦筮筱筲筵筶筷筹筻筼签简箅" + "箍箐箓箔箕箖算箜管箢箦箧箨箩箪箫箬箭箱箴箸篁篆篇篌篑篓篙篚篝篡篥篦篪篮篯篱篷篼篾" + "簃簇簉簋簌簏簕簖簝簟簠簧簪簰簸簿籀籁籍籥米籴类籼籽粉粑粒粕粗粘粜粝粞粟粢粤粥粪粮" + "粱粲粳粹粼粽精粿糁糅糇糈糊糌糍糒糕糖糗糙糜糟糠糨糯糵系紊素索紧紫累絜絮絷綦綮縠縢" + 
"縻繁繄繇纂纛纠纡红纣纤纥约级纨纩纪纫纬纭纮纯纰纱纲纳纴纵纶纷纸纹纺纻纼纽纾线绀绁" + "绂练组绅细织终绉绊绋绌绍绎经绐绑绒结绔绕绖绗绘给绚绛络绝绞统绠绡绢绣绤绥绦继绨绩" + "绪绫续绮绯绰绱绲绳维绵绶绷绸绹绺绻综绽绾绿缀缁缂缃缄缅缆缇缈缉缊缌缎缐缑缒缓缔缕" + "编缗缘缙缚缛缜缝缞缟缠缡缢缣缤缥缦缧缨缩缪缫缬缭缮缯缰缱缲缳缴缵缶缸缺罂罄罅罍罐" + "网罔罕罗罘罚罟罡罢罨罩罪置罱署罴罶罹罽罾羁羊羌美羑羓羔羕羖羚羝羞羟羡群羧羯羰羱羲" + "羸羹羼羽羿翀翁翂翃翅翈翊翌翎翔翕翘翙翚翛翟翠翡翥翦翩翮翯翰翱翳翷翻翼翾耀老考耄者" + "耆耇耋而耍耏耐耑耒耔耕耖耗耘耙耜耠耢耤耥耦耧耨耩耪耰耱耳耵耶耷耸耻耽耿聂聃聆聊聋" + "职聍聒联聘聚聩聪聱聿肃肄肆肇肉肋肌肓肖肘肚肛肝肟肠股肢肤肥肩肪肫肭肮肯肱育肴肷肸" + "肺肼肽肾肿胀胁胂胃胄胆胈背胍胎胖胗胙胚胛胜胝胞胠胡胣胤胥胧胨胩胪胫胬胭胯胰胱胲胳" + "胴胶胸胺胼能脂脆脉脊脍脎脏脐脑脒脓脔脖脘脚脞脟脩脬脯脱脲脶脸脾脿腆腈腊腋腌腐腑腒" + "腓腔腕腘腙腚腠腥腧腨腩腭腮腯腰腱腴腹腺腻腼腽腾腿膀膂膈膊膏膑膘膙膛膜膝膦膨膳膺膻" + "臀臂臃臆臊臌臑臜臣臧自臬臭至致臻臼臾舀舁舂舄舅舆舌舍舐舒舔舛舜舞舟舠舢舣舥航舫般" + "舭舯舰舱舲舳舴舵舶舷舸船舻舾艄艅艇艉艋艎艏艘艚艟艨艮良艰色艳艴艺艽艾艿节芃芄芈芊" + "芋芍芎芏芑芒芗芘芙芜芝芟芠芡芣芤芥芦芨芩芪芫芬芭芮芯芰花芳芴芷芸芹芼芽芾苁苄苇苈" + "苉苊苋苌苍苎苏苑苒苓苔苕苗苘苛苜苞苟苠苡苣苤若苦苧苫苯英苴苷苹苻苾茀茁茂范茄茅茆" + "茈茉茋茌茎茏茑茓茔茕茗茚茛茜茝茧茨茫茬茭茯茱茳茴茵茶茸茹茺茼茽荀荁荃荄荆荇草荏荐" + "荑荒荓荔荖荙荚荛荜荞荟荠荡荣荤荥荦荧荨荩荪荫荬荭荮药荷荸荻荼荽莅莆莉莎莒莓莘莙莛" + "莜莝莞莠莨莩莪莫莰莱莲莳莴莶获莸莹莺莼莽莿菀菁菂菅菇菉菊菌菍菏菔菖菘菜菝菟菠菡菥" + "菩菪菰菱菲菹菼菽萁萃萄萆萋萌萍萎萏萑萘萚萜萝萣萤营萦萧萨萩萱萳萸萹萼落葆葎葑葖著" + "葙葚葛葜葡董葩葫葬葭葰葱葳葴葵葶葸葺蒂蒄蒇蒈蒉蒋蒌蒎蒐蒗蒙蒜蒟蒡蒨蒯蒱蒲蒴蒸蒹蒺" + "蒻蒽蒿蓁蓂蓄蓇蓉蓊蓍蓏蓐蓑蓓蓖蓝蓟蓠蓢蓣蓥蓦蓬蓰蓼蓿蔀蔃蔈蔊蔌蔑蔓蔗蔚蔟蔡蔫蔬蔷" + "蔸蔹蔺蔻蔼蔽蕃蕈蕉蕊蕖蕗蕙蕞蕤蕨蕰蕲蕴蕹蕺蕻蕾薁薄薅薇薏薛薜薢薤薨薪薮薯薰薳薷薸" + "薹薿藁藉藏藐藓藕藜藟藠藤藦藨藩藻藿蘅蘑蘖蘘蘧蘩蘸蘼虎虏虐虑虒虓虔虚虞虢虤虫虬虮虱" + "虷虸虹虺虻虼虽虾虿蚀蚁蚂蚄蚆蚊蚋蚌蚍蚓蚕蚜蚝蚣蚤蚧蚨蚩蚪蚬蚯蚰蚱蚲蚴蚶蚺蛀蛃蛄蛆" + "蛇蛉蛊蛋蛎蛏蛐蛑蛔蛘蛙蛛蛞蛟蛤蛩蛭蛮蛰蛱蛲蛳蛴蛸蛹蛾蜀蜂蜃蜇蜈蜉蜊蜍蜎蜐蜒蜓蜕蜗" + "蜘蜚蜜蜞蜡蜢蜣蜥蜩蜮蜱蜴蜷蜻蜾蜿蝇蝈蝉蝌蝎蝓蝗蝘蝙蝠蝣蝤蝥蝮蝰蝲蝴蝶蝻蝼蝽蝾螂螃" + "螅螈螋融螗螟螠螣螨螫螬螭螯螱螳螵螺螽蟀蟆蟊蟋蟏蟑蟒蟛蟠蟥蟪蟫蟮蟹蟾蠃蠊蠋蠓蠕蠖蠡" + "蠢蠲蠹蠼血衃衄衅行衍衎衒衔街衙衠衡衢衣补表衩衫衬衮衰衲衷衽衾衿袁袂袄袅袆袈袋袍袒" + "袖袗袜袢袤袪被袭袯袱袷袼裁裂装裆裈裉裎裒裔裕裘裙裛裟裢裣裤裥裨裰裱裳裴裸裹裼裾褂" + "褊褐褒褓褕褙褚褛褟褡褥褪褫褯褰褴褶襁襄襕襚襜襞襟襦襫襻西要覃覆见观觃规觅视觇览觉" + "觊觋觌觎觏觐觑角觖觚觜觞觟解觥触觫觭觯觱觳觿言訄訇訚訾詈詟詹誉誊誓謇警譬计订讣认" + "讥讦讧讨让讪讫训议讯记讱讲讳讴讵讶讷许讹论讻讼讽设访诀证诂诃评诅识诇诈诉诊诋诌词" + "诎诏诐译诒诓诔试诖诗诘诙诚诛诜话诞诟诠诡询诣诤该详诧诨诩诫诬语诮误诰诱诲诳说诵请" + "诸诹诺读诼诽课诿谀谁谂调谄谅谆谇谈谊谋谌谍谎谏谐谑谒谓谔谕谖谗谙谚谛谜谝谞谟谠谡" + "谢谣谤谥谦谧谨谩谪谫谬谭谮谯谰谱谲谳谴谵谶谷谼谿豁豆豇豉豌豕豚象豢豨豪豫豮豳豸豹" + "豺貂貅貆貉貊貌貔貘贝贞负贡财责贤败账货质贩贪贫贬购贮贯贰贱贲贳贴贵贶贷贸费贺贻贼" + "贽贾贿赀赁赂赃资赅赆赇赈赉赊赋赌赍赎赏赐赑赒赓赔赕赖赗赘赙赚赛赜赝赞赟赠赡赢赣赤" + "赦赧赪赫赭走赳赴赵赶起趁趄超越趋趑趔趟趣趯趱足趴趵趸趺趼趾趿跂跃跄跆跋跌跎跏跐跑" + "跖跗跚跛距跞跟跣跤跨跪跬路跱跳践跶跷跸跹跺跻跽踅踉踊踌踏踒踔踝踞踟踢踣踦踩踪踬踮" + "踯踱踵踶踹踺踽蹀蹁蹂蹄蹅蹇蹈蹉蹊蹋蹐蹑蹒蹙蹚蹜蹢蹦蹩蹬蹭蹯蹰蹲蹴蹶蹼蹽蹾蹿躁躅躇" + "躏躐躔躜躞身躬躯躲躺车轧轨轩轪轫转轭轮软轰轱轲轳轴轵轶轷轸轹轺轻轼载轾轿辀辁辂较" + "辄辅辆辇辈辉辊辋辌辍辎辏辐辑辒输辔辕辖辗辘辙辚辛辜辞辟辣辨辩辫辰辱边辽达辿迁迂迄" + "迅过迈迎运近迓返迕还这进远违连迟迢迤迥迦迨迩迪迫迭迮述迳迷迸迹迺追退送适逃逄逅逆" + "选逊逋逍透逐逑递途逖逗通逛逝逞速造逡逢逦逭逮逯逴逵逶逸逻逼逾遁遂遄遆遇遍遏遐遑遒" + "道遗遘遛遢遣遥遨遭遮遴遵遹遽避邀邂邃邈邋邑邓邕邗邘邙邛邝邠邡邢那邦邨邪邬邮邯邰邱" + "邲邳邴邵邶邸邹邺邻邽邾邿郁郃郄郅郇郈郊郎郏郐郑郓郗郚郛郜郝郡郢郤郦郧部郪郫郭郯郴" + "郸都郾郿鄀鄂鄃鄄鄅鄌鄑鄗鄘鄙鄚鄜鄞鄠鄢鄣鄫鄯鄱鄹酂酃酅酆酉酊酋酌配酎酏酐酒酗酚酝" + "酞酡酢酣酤酥酦酩酪酬酮酯酰酱酲酴酵酶酷酸酹酺酽酾酿醅醇醉醋醌醍醐醑醒醚醛醢醨醪醭" + "醮醯醴醵醺醾采釉释里重野量釐金釜鉴銎銮鋆鋈錾鍪鎏鏊鏖鐾鑫钆钇针钉钊钋钌钍钎钏钐钒" + "钓钔钕钖钗钘钙钚钛钜钝钞钟钠钡钢钣钤钥钦钧钨钩钪钫钬钭钮钯钰钱钲钳钴钵钷钹钺钻钼" + "钽钾钿铀铁铂铃铄铅铆铈铉铊铋铌铍铎铏铐铑铒铕铖铗铘铙铚铛铜铝铞铟铠铡铢铣铤铥铧铨" + "铩铪铫铬铭铮铯铰铱铲铳铴铵银铷铸铹铺铻铼铽链铿销锁锂锃锄锅锆锇锈锉锊锋锌锍锎锏锐" + "锑锒锓锔锕锖锗锘错锚锛锜锝锞锟锡锢锣锤锥锦锧锨锩锪锫锬锭键锯锰锱锲锳锴锵锶锷锸锹" + "锺锻锼锽锾锿镀镁镂镃镄镅镆镇镈镉镊镋镌镍镎镏镐镑镒镓镔镕镖镗镘镚镛镜镝镞镠镡镢镣" + "镤镥镦镧镨镩镪镫镬镭镮镯镰镱镲镳镴镵镶长门闩闪闫闭问闯闰闱闲闳间闵闶闷闸闹闺闻闼" + "闽闾闿阀阁阂阃阄阅阆阇阈阉阊阋阌阍阎阏阐阑阒阔阕阖阗阘阙阚阜队阡阪阮阱防阳阴阵阶" + "阻阼阽阿陀陂附际陆陇陈陉陋陌降陎限陑陔陕陛陞陟陡院除陧陨险陪陬陲陴陵陶陷隃隅隆隈" + "隋隍随隐隔隗隘隙障隧隩隰隳隶隹隺隼隽难雀雁雄雅集雇雉雊雌雍雎雏雒雕雠雨雩雪雯雱雳" + "零雷雹雾需霁霄霅霆震霈霉霍霎霏霓霖霜霞霨霪霭霰露霸霹霾青靓靖静靛非靠靡面靥革靬靰" + "靳靴靶靸靺靼靽靿鞁鞅鞋鞍鞑鞒鞔鞘鞠鞡鞣鞧鞨鞫鞬鞭鞮鞯鞲鞳鞴韂韦韧韨韩韪韫韬韭音韵" + "韶页顶顷顸项顺须顼顽顾顿颀颁颂颃预颅领颇颈颉颊颋颌颍颎颏颐频颓颔颖颗题颙颚颛颜额" + "颞颟颠颡颢颤颥颦颧风飏飐飑飒飓飔飕飗飘飙飞食飧飨餍餐餮饔饕饥饧饨饩饪饫饬饭饮饯饰" + "饱饲饳饴饵饶饷饸饹饺饻饼饽饿馁馃馄馅馆馇馈馉馊馋馌馍馏馐馑馒馓馔馕首馗馘香馝馞馥" + "馧馨马驭驮驯驰驱驲驳驴驵驶驷驸驹驺驻驼驽驾驿骀骁骂骃骄骅骆骇骈骉骊骋验骍骎骏骐骑" + "骒骓骕骖骗骘骙骚骛骜骝骞骟骠骡骢骣骤骥骦骧骨骰骱骶骷骸骺骼髀髁髂髃髅髋髌髎髑髓高" + "髡髢髦髫髭髯髹髻髽鬃鬈鬏鬒鬓鬘鬟鬣鬯鬲鬶鬷鬻鬼魁魂魃魄魅魆魇魈魉魋魍魏魑魔鱼鱽鱾" + "鱿鲀鲁鲂鲃鲅鲆鲇鲈鲉鲊鲋鲌鲍鲎鲏鲐鲑鲒鲔鲕鲖鲗鲘鲙鲚鲛鲜鲝鲞鲟鲠鲡鲢鲣鲤鲥鲦鲧鲨" + "鲩鲪鲫鲬鲭鲮鲯鲰鲱鲲鲳鲴鲵鲷鲸鲹鲺鲻鲼鲽鲾鲿鳀鳁鳂鳃鳄鳅鳇鳈鳉鳊鳌鳍鳎鳏鳐鳑鳒鳓" + "鳔鳕鳖鳗鳘鳙鳚鳛鳜鳝鳞鳟鳠鳡鳢鳣鳤鸟鸠鸡鸢鸣鸤鸥鸦鸧鸨鸩鸪鸫鸬鸭鸮鸯鸰鸱鸲鸳鸵鸶" + "鸷鸸鸹鸺鸻鸼鸽鸾鸿鹀鹁鹂鹃鹄鹅鹆鹇鹈鹉鹊鹋鹌鹍鹎鹏鹐鹑鹒鹔鹕鹖鹗鹘鹙鹚鹛鹜鹝鹞鹟" + "鹠鹡鹢鹣鹤鹦鹧鹨鹩鹪鹫鹬鹭鹮鹯鹰鹱鹲鹳鹴鹾鹿麀麂麇麈麋麑麒麓麖麝麟麦麸麹麻麽麾黄" + "黇黉黍黎黏黑黔默黛黜黝黟黠黡黢黥黧黩黪黯黹黻黼黾鼋鼍鼎鼐鼒鼓鼗鼙鼠鼢鼩鼫鼬鼯鼱鼷" + "鼹鼻鼽鼾齁齇齉齐齑齿龀龁龂龃龄龅龆龇龈龉龊龋龌龙龚龛龟龠龢鿍鿎鿏㑇㑊㕮㘎㙍㙘㙦㛃" + "㛚㛹㟃㠇㠓㤘㥄㧐㧑㧟㫰㬊㬎㬚㭎㭕㮾㰀㳇㳘㳚㴔㵐㶲㸆㸌㺄㻬㽏㿠䁖䂮䃅䃎䅟䌹䎃䎖䏝䏡" + "䏲䐃䓖䓛䓨䓫䓬䗖䗛䗪䗴䜣䝙䢺䢼䣘䥽䦃䲟䲠䲢䴓䴔䴕䴖䴗䴘䴙䶮𠅤𠙶𠳐𡎚𡐓𣗋𣲗𣲘𣸣𤧛𤩽" + "𤫉𥔲𥕢𥖨𥻗𦈡𦒍𦙶𦝼𦭜𦰡𧿹𨐈𨙸𨚕𨟠𨭉𨱇𨱏𨱑𨱔𨺙𩽾𩾃𩾌𪟝𪣻𪤗𪨰𪨶𪩘𪾢𫄧𫄨𫄷𫄸𫇭𫌀𫍣𫍯" + "𫍲𫍽𫐄𫐐𫐓𫑡𫓧𫓯𫓶𫓹𫔍𫔎𫔶𫖮𫖯𫖳𫗧𫗴𫘜𫘝𫘦𫘧𫘨𫘪𫘬𫚕𫚖𫚭𫛭𫞩𫟅𫟦𫟹𫟼𫠆𫠊𫠜𫢸𫫇𫭟" + "𫭢𫭼𫮃𫰛𫵷𫶇𫷷𫸩𬀩𬀪𬂩𬃊𬇕𬇙𬇹𬉼𬊈𬊤𬌗𬍛𬍡𬍤𬒈𬒔𬒗𬕂𬘓𬘘𬘡𬘩𬘫𬘬𬘭𬘯𬙂𬙊𬙋𬜬𬜯𬞟" + "𬟁𬟽𬣙𬣞𬣡𬣳𬤇𬤊𬤝𬨂𬨎𬩽𬪩𬬩𬬭𬬮𬬱𬬸𬬹𬬻𬬿𬭁𬭊𬭎𬭚𬭛𬭤𬭩𬭬𬭯𬭳𬭶𬭸𬭼𬮱𬮿𬯀𬯎𬱖𬱟" + 
"𬳵𬳶𬳽𬳿𬴂𬴃𬴊𬶋𬶍𬶏𬶐𬶟𬶠𬶨𬶭𬶮𬷕𬸘𬸚𬸣𬸦𬸪𬹼𬺈𬺓" +) +CN_CHARS_EXT = "吶诶屌囧飚屄" + +CN_CHARS = CN_CHARS_COMMON + CN_CHARS_EXT +IN_CH_CHARS = {c: True for c in CN_CHARS} + +EN_CHARS = string.ascii_letters + string.digits +IN_EN_CHARS = {c: True for c in EN_CHARS} + +VALID_CHARS = CN_CHARS + EN_CHARS + " " +IN_VALID_CHARS = {c: True for c in VALID_CHARS} + + +# ================================================================================ # +# basic class +# ================================================================================ # +class ChineseChar(object): + """ + 中文字符 + 每个字符对应简体和繁体, + e.g. 简体 = '负', 繁体 = '負' + 转换时可转换为简体或繁体 + """ + + def __init__(self, simplified, traditional): + self.simplified = simplified + self.traditional = traditional + # self.__repr__ = self.__str__ + + def __str__(self): + return self.simplified or self.traditional or None + + def __repr__(self): + return self.__str__() + + +class ChineseNumberUnit(ChineseChar): + """ + 中文数字/数位字符 + 每个字符除繁简体外还有一个额外的大写字符 + e.g. '陆' 和 '陸' + """ + + def __init__(self, power, simplified, traditional, big_s, big_t): + super(ChineseNumberUnit, self).__init__(simplified, traditional) + self.power = power + self.big_s = big_s + self.big_t = big_t + + def __str__(self): + return "10^{}".format(self.power) + + @classmethod + def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False): + if small_unit: + return ChineseNumberUnit( + power=index + 1, simplified=value[0], traditional=value[1], big_s=value[1], big_t=value[1] + ) + elif numbering_type == NUMBERING_TYPES[0]: + return ChineseNumberUnit( + power=index + 8, simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1] + ) + elif numbering_type == NUMBERING_TYPES[1]: + return ChineseNumberUnit( + power=(index + 2) * 4, simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1] + ) + elif numbering_type == NUMBERING_TYPES[2]: + return ChineseNumberUnit( + power=pow(2, index + 3), simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1] + ) + else: + raise ValueError("Counting type should be in {0} ({1} provided).".format(NUMBERING_TYPES, numbering_type)) + + +class ChineseNumberDigit(ChineseChar): + """ + 中文数字字符 + """ + + def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None): + super(ChineseNumberDigit, self).__init__(simplified, traditional) + self.value = value + self.big_s = big_s + self.big_t = big_t + self.alt_s = alt_s + self.alt_t = alt_t + + def __str__(self): + return str(self.value) + + @classmethod + def create(cls, i, v): + return ChineseNumberDigit(i, v[0], v[1], v[2], v[3]) + + +class ChineseMath(ChineseChar): + """ + 中文数位字符 + """ + + def __init__(self, simplified, traditional, symbol, expression=None): + super(ChineseMath, self).__init__(simplified, traditional) + self.symbol = symbol + self.expression = expression + self.big_s = simplified + self.big_t = traditional + + +CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath + + +class NumberSystem(object): + """ + 中文数字系统 + """ + + pass + + +class MathSymbol(object): + """ + 用于中文数字系统的数学符号 (繁/简体), e.g. 
+ positive = ['正', '正'] + negative = ['负', '負'] + point = ['点', '點'] + """ + + def __init__(self, positive, negative, point): + self.positive = positive + self.negative = negative + self.point = point + + def __iter__(self): + for v in self.__dict__.values(): + yield v + + +# class OtherSymbol(object): +# """ +# 其他符号 +# """ +# +# def __init__(self, sil): +# self.sil = sil +# +# def __iter__(self): +# for v in self.__dict__.values(): +# yield v + + +# ================================================================================ # +# basic utils +# ================================================================================ # +def create_system(numbering_type=NUMBERING_TYPES[1]): + """ + 根据数字系统类型返回创建相应的数字系统,默认为 mid + NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型 + low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc. + mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc. + high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc. + 返回对应的数字系统 + """ + + # chinese number units of '亿' and larger + all_larger_units = zip(LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL) + larger_units = [CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units)] + # chinese number units of '十, 百, 千, 万' + all_smaller_units = zip(SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL) + smaller_units = [CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units)] + # digis + chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS, BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL) + digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)] + digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT + digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT + digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1] + + # symbols + positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x) + negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x) + point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." 
+ str(y))) + # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y))) + system = NumberSystem() + system.units = smaller_units + larger_units + system.digits = digits + system.math = MathSymbol(positive_cn, negative_cn, point_cn) + # system.symbols = OtherSymbol(sil_cn) + return system + + +def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]): + def get_symbol(char, system): + for u in system.units: + if char in [u.traditional, u.simplified, u.big_s, u.big_t]: + return u + for d in system.digits: + if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]: + return d + for m in system.math: + if char in [m.traditional, m.simplified]: + return m + + def string2symbols(chinese_string, system): + int_string, dec_string = chinese_string, "" + for p in [system.math.point.simplified, system.math.point.traditional]: + if p in chinese_string: + int_string, dec_string = chinese_string.split(p) + break + return [get_symbol(c, system) for c in int_string], [get_symbol(c, system) for c in dec_string] + + def correct_symbols(integer_symbols, system): + """ + 一百八 to 一百八十 + 一亿一千三百万 to 一亿 一千万 三百万 + """ + + if integer_symbols and isinstance(integer_symbols[0], CNU): + if integer_symbols[0].power == 1: + integer_symbols = [system.digits[1]] + integer_symbols + + if len(integer_symbols) > 1: + if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU): + integer_symbols.append(CNU(integer_symbols[-2].power - 1, None, None, None, None)) + + result = [] + unit_count = 0 + for s in integer_symbols: + if isinstance(s, CND): + result.append(s) + unit_count = 0 + elif isinstance(s, CNU): + current_unit = CNU(s.power, None, None, None, None) + unit_count += 1 + + if unit_count == 1: + result.append(current_unit) + elif unit_count > 1: + for i in range(len(result)): + if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power: + result[-i - 1] = CNU(result[-i - 1].power + current_unit.power, None, None, None, None) + return result + + def compute_value(integer_symbols): + """ + Compute the value. + When current unit is larger than previous unit, current unit * all previous units will be used as all previous units. + e.g. 
'两千万' = 2000 * 10000 not 2000 + 10000 + """ + value = [0] + last_power = 0 + for s in integer_symbols: + if isinstance(s, CND): + value[-1] = s.value + elif isinstance(s, CNU): + value[-1] *= pow(10, s.power) + if s.power > last_power: + value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1])) + last_power = s.power + value.append(0) + return sum(value) + + system = create_system(numbering_type) + int_part, dec_part = string2symbols(chinese_string, system) + int_part = correct_symbols(int_part, system) + int_str = str(compute_value(int_part)) + dec_str = "".join([str(d.value) for d in dec_part]) + if dec_part: + return "{0}.{1}".format(int_str, dec_str) + else: + return int_str + + +def num2chn( + number_string, + numbering_type=NUMBERING_TYPES[1], + big=False, + traditional=False, + alt_zero=False, + alt_one=False, + alt_two=True, + use_zeros=True, + use_units=True, +): + def get_value(value_string, use_zeros=True): + striped_string = value_string.lstrip("0") + + # record nothing if all zeros + if not striped_string: + return [] + + # record one digits + elif len(striped_string) == 1: + if use_zeros and len(value_string) != len(striped_string): + return [system.digits[0], system.digits[int(striped_string)]] + else: + return [system.digits[int(striped_string)]] + + # recursively record multiple digits + else: + result_unit = next(u for u in reversed(system.units) if u.power < len(striped_string)) + result_string = value_string[: -result_unit.power] + return get_value(result_string) + [result_unit] + get_value(striped_string[-result_unit.power :]) + + system = create_system(numbering_type) + + int_dec = number_string.split(".") + if len(int_dec) == 1: + int_string = int_dec[0] + dec_string = "" + elif len(int_dec) == 2: + int_string = int_dec[0] + dec_string = int_dec[1] + else: + raise ValueError("invalid input num string with more than one dot: {}".format(number_string)) + + if use_units and len(int_string) > 1: + result_symbols = get_value(int_string) + else: + result_symbols = [system.digits[int(c)] for c in int_string] + dec_symbols = [system.digits[int(c)] for c in dec_string] + if dec_string: + result_symbols += [system.math.point] + dec_symbols + + if alt_two: + liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t, system.digits[2].big_s, system.digits[2].big_t) + for i, v in enumerate(result_symbols): + if isinstance(v, CND) and v.value == 2: + next_symbol = result_symbols[i + 1] if i < len(result_symbols) - 1 else None + previous_symbol = result_symbols[i - 1] if i > 0 else None + if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))): + if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)): + result_symbols[i] = liang + + # if big is True, '两' will not be used and `alt_two` has no impact on output + if big: + attr_name = "big_" + if traditional: + attr_name += "t" + else: + attr_name += "s" + else: + if traditional: + attr_name = "traditional" + else: + attr_name = "simplified" + + result = "".join([getattr(s, attr_name) for s in result_symbols]) + + # if not use_zeros: + # result = result.strip(getattr(system.digits[0], attr_name)) + + if alt_zero: + result = result.replace(getattr(system.digits[0], attr_name), system.digits[0].alt_s) + + if alt_one: + result = result.replace(getattr(system.digits[1], attr_name), system.digits[1].alt_s) + + for i, p in enumerate(POINT): + if result.startswith(p): + return CHINESE_DIGIS[0] + result + + # ^10, 11, .., 19 + if ( + len(result) >= 2 + and result[1] in 
[SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0], SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]] + and result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]] + ): + result = result[1:] + + return result + + +# ================================================================================ # +# different types of rewriters +# ================================================================================ # +class Cardinal: + """ + CARDINAL类 + """ + + def __init__(self, cardinal=None, chntext=None): + self.cardinal = cardinal + self.chntext = chntext + + def chntext2cardinal(self): + return chn2num(self.chntext) + + def cardinal2chntext(self): + return num2chn(self.cardinal) + + +class Digit: + """ + DIGIT类 + """ + + def __init__(self, digit=None, chntext=None): + self.digit = digit + self.chntext = chntext + + # def chntext2digit(self): + # return chn2num(self.chntext) + + def digit2chntext(self): + return num2chn(self.digit, alt_two=False, use_units=False) + + +class TelePhone: + """ + TELEPHONE类 + """ + + def __init__(self, telephone=None, raw_chntext=None, chntext=None): + self.telephone = telephone + self.raw_chntext = raw_chntext + self.chntext = chntext + + # def chntext2telephone(self): + # sil_parts = self.raw_chntext.split('') + # self.telephone = '-'.join([ + # str(chn2num(p)) for p in sil_parts + # ]) + # return self.telephone + + def telephone2chntext(self, fixed=False): + if fixed: + sil_parts = self.telephone.split("-") + self.raw_chntext = "".join([num2chn(part, alt_two=False, use_units=False) for part in sil_parts]) + self.chntext = self.raw_chntext.replace("", "") + else: + sp_parts = self.telephone.strip("+").split() + self.raw_chntext = "".join([num2chn(part, alt_two=False, use_units=False) for part in sp_parts]) + self.chntext = self.raw_chntext.replace("", "") + return self.chntext + + +class Fraction: + """ + FRACTION类 + """ + + def __init__(self, fraction=None, chntext=None): + self.fraction = fraction + self.chntext = chntext + + def chntext2fraction(self): + denominator, numerator = self.chntext.split("分之") + return chn2num(numerator) + "/" + chn2num(denominator) + + def fraction2chntext(self): + numerator, denominator = self.fraction.split("/") + return num2chn(denominator) + "分之" + num2chn(numerator) + + +class Date: + """ + DATE类 + """ + + def __init__(self, date=None, chntext=None): + self.date = date + self.chntext = chntext + + # def chntext2date(self): + # chntext = self.chntext + # try: + # year, other = chntext.strip().split('年', maxsplit=1) + # year = Digit(chntext=year).digit2chntext() + '年' + # except ValueError: + # other = chntext + # year = '' + # if other: + # try: + # month, day = other.strip().split('月', maxsplit=1) + # month = Cardinal(chntext=month).chntext2cardinal() + '月' + # except ValueError: + # day = chntext + # month = '' + # if day: + # day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1] + # else: + # month = '' + # day = '' + # date = year + month + day + # self.date = date + # return self.date + + def date2chntext(self): + date = self.date + try: + year, other = date.strip().split("年", 1) + year = Digit(digit=year).digit2chntext() + "年" + except ValueError: + other = date + year = "" + if other: + try: + month, day = other.strip().split("月", 1) + month = Cardinal(cardinal=month).cardinal2chntext() + "月" + except ValueError: + day = date + month = "" + if day: + day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1] + else: + month = "" + day = "" + chntext = year + month + day 
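+            # (Illustrative note, not in the original file.) The rewriter classes bottom out in
+            # num2chn()/chn2num(); with the default "mid" numbering system, for example:
+            #   num2chn("1984")                           -> "一千九百八十四"
+            #   chn2num("两千三百零五")                     -> "2305"
+            #   Date(date="1986年8月18日").date2chntext()  -> "一九八六年八月十八日"
+            # (the year is spelled digit by digit via Digit, month/day go through Cardinal).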
+ self.chntext = chntext + return self.chntext + + +class Money: + """ + MONEY类 + """ + + def __init__(self, money=None, chntext=None): + self.money = money + self.chntext = chntext + + # def chntext2money(self): + # return self.money + + def money2chntext(self): + money = self.money + pattern = re.compile(r"(\d+(\.\d+)?)") + matchers = pattern.findall(money) + if matchers: + for matcher in matchers: + money = money.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext()) + self.chntext = money + return self.chntext + + +class Percentage: + """ + PERCENTAGE类 + """ + + def __init__(self, percentage=None, chntext=None): + self.percentage = percentage + self.chntext = chntext + + def chntext2percentage(self): + return chn2num(self.chntext.strip().strip("百分之")) + "%" + + def percentage2chntext(self): + return "百分之" + num2chn(self.percentage.strip().strip("%")) + + +def normalize_nsw(raw_text): + text = "^" + raw_text + "$" + + # 规范化日期 + pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)") + matchers = pattern.findall(text) + if matchers: + # print('date') + for matcher in matchers: + text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1) + + # 规范化金钱 + pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)") + matchers = pattern.findall(text) + if matchers: + # print('money') + for matcher in matchers: + text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1) + + # 规范化固话/手机号码 + # 手机 + # http://www.jihaoba.com/news/show/13680 + # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198 + # 联通:130、131、132、156、155、186、185、176 + # 电信:133、153、189、180、181、177 + pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D") + matchers = pattern.findall(text) + if matchers: + # print('telephone') + for matcher in matchers: + text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1) + # 固话 + pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D") + matchers = pattern.findall(text) + if matchers: + # print('fixed telephone') + for matcher in matchers: + text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1) + + # 规范化分数 + pattern = re.compile(r"(\d+/\d+)") + matchers = pattern.findall(text) + if matchers: + # print('fraction') + for matcher in matchers: + text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1) + + # 规范化百分数 + text = text.replace("%", "%") + pattern = re.compile(r"(\d+(\.\d+)?%)") + matchers = pattern.findall(text) + if matchers: + # print('percentage') + for matcher in matchers: + text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1) + + # 规范化纯数+量词 + pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" 
+ COM_QUANTIFIERS) + matchers = pattern.findall(text) + if matchers: + # print('cardinal+quantifier') + for matcher in matchers: + text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1) + + # 规范化数字编号 + pattern = re.compile(r"(\d{4,32})") + matchers = pattern.findall(text) + if matchers: + # print('digit') + for matcher in matchers: + text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1) + + # 规范化纯数 + pattern = re.compile(r"(\d+(\.\d+)?)") + matchers = pattern.findall(text) + if matchers: + # print('cardinal') + for matcher in matchers: + text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1) + + # restore P2P, O2O, B2C, B2B etc + pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))") + matchers = pattern.findall(text) + if matchers: + # print('particular') + for matcher in matchers: + text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1) + + return text.lstrip("^").rstrip("$") + + +def remove_erhua(text): + """ + 去除儿化音词中的儿: + 他女儿在那边儿 -> 他女儿在那边 + """ + + new_str = "" + while re.search("儿", text): + a = re.search("儿", text).span() + remove_er_flag = 0 + + if ER_WHITELIST_PATTERN.search(text): + b = ER_WHITELIST_PATTERN.search(text).span() + if b[0] <= a[0]: + remove_er_flag = 1 + + if remove_er_flag == 0: + new_str = new_str + text[0 : a[0]] + text = text[a[1] :] + else: + new_str = new_str + text[0 : b[1]] + text = text[b[1] :] + + text = new_str + text + return text + + +def remove_space(text): + tokens = text.split() + new = [] + for k, t in enumerate(tokens): + if k != 0: + if IN_EN_CHARS.get(tokens[k - 1][-1]) and IN_EN_CHARS.get(t[0]): + new.append(" ") + new.append(t) + return "".join(new) + + +class TextNorm: + def __init__( + self, + to_banjiao: bool = False, + to_upper: bool = False, + to_lower: bool = False, + remove_fillers: bool = False, + remove_erhua: bool = False, + check_chars: bool = False, + remove_space: bool = False, + cc_mode: str = "", + ): + self.to_banjiao = to_banjiao + self.to_upper = to_upper + self.to_lower = to_lower + self.remove_fillers = remove_fillers + self.remove_erhua = remove_erhua + self.check_chars = check_chars + self.remove_space = remove_space + + self.cc = None + if cc_mode: + from opencc import OpenCC # Open Chinese Convert: pip install opencc + + self.cc = OpenCC(cc_mode) + + def __call__(self, text): + if self.cc: + text = self.cc.convert(text) + + if self.to_banjiao: + text = text.translate(QJ2BJ_TRANSFORM) + + if self.to_upper: + text = text.upper() + + if self.to_lower: + text = text.lower() + + if self.remove_fillers: + for c in FILLER_CHARS: + text = text.replace(c, "") + + if self.remove_erhua: + text = remove_erhua(text) + + text = normalize_nsw(text) + + text = text.translate(PUNCS_TRANSFORM) + + if self.check_chars: + for c in text: + if not IN_VALID_CHARS.get(c): + print(f"WARNING: illegal char {c} in: {text}", file=sys.stderr) + return "" + + if self.remove_space: + text = remove_space(text) + + return text + + +if __name__ == "__main__": + p = argparse.ArgumentParser() + + # normalizer options + p.add_argument("--to_banjiao", action="store_true", help="convert quanjiao chars to banjiao") + p.add_argument("--to_upper", action="store_true", help="convert to upper case") + p.add_argument("--to_lower", action="store_true", help="convert to lower case") + p.add_argument("--remove_fillers", action="store_true", help='remove filler chars such as "呃, 啊"') + p.add_argument("--remove_erhua", action="store_true", help='remove erhua chars such as "他女儿在那边儿 -> 
他女儿在那边"') + p.add_argument("--check_chars", action="store_true", help="skip sentences containing illegal chars") + p.add_argument("--remove_space", action="store_true", help="remove whitespace") + p.add_argument( + "--cc_mode", choices=["", "t2s", "s2t"], default="", help="convert between traditional to simplified" + ) + + # I/O options + p.add_argument("--log_interval", type=int, default=10000, help="log interval in number of processed lines") + p.add_argument("--has_key", action="store_true", help="will be deprecated, set --format ark instead") + p.add_argument("--format", type=str, choices=["txt", "ark", "tsv"], default="txt", help="input format") + p.add_argument("ifile", help="input filename, assume utf-8 encoding") + p.add_argument("ofile", help="output filename") + + args = p.parse_args() + + if args.has_key: + args.format = "ark" + + normalizer = TextNorm( + to_banjiao=args.to_banjiao, + to_upper=args.to_upper, + to_lower=args.to_lower, + remove_fillers=args.remove_fillers, + remove_erhua=args.remove_erhua, + check_chars=args.check_chars, + remove_space=args.remove_space, + cc_mode=args.cc_mode, + ) + + normalizer = TextNorm( + to_banjiao=args.to_banjiao, + to_upper=args.to_upper, + to_lower=args.to_lower, + remove_fillers=args.remove_fillers, + remove_erhua=args.remove_erhua, + check_chars=args.check_chars, + remove_space=args.remove_space, + cc_mode=args.cc_mode, + ) + + ndone = 0 + with open(args.ifile, "r", encoding="utf8") as istream, open(args.ofile, "w+", encoding="utf8") as ostream: + if args.format == "tsv": + reader = csv.DictReader(istream, delimiter="\t") + assert "TEXT" in reader.fieldnames + print("\t".join(reader.fieldnames), file=ostream) + + for item in reader: + text = item["TEXT"] + + if text: + text = normalizer(text) + + if text: + item["TEXT"] = text + print("\t".join([item[f] for f in reader.fieldnames]), file=ostream) + + ndone += 1 + if ndone % args.log_interval == 0: + print(f"text norm: {ndone} lines done.", file=sys.stderr, flush=True) + else: + for l in istream: + key, text = "", "" + if args.format == "ark": # KALDI archive, line format: "key text" + cols = l.strip().split(maxsplit=1) + key, text = cols[0], cols[1] if len(cols) == 2 else "" + else: + text = l.strip() + + if text: + text = normalizer(text) + + if text: + if args.format == "ark": + print(key + "\t" + text, file=ostream) + else: + print(text, file=ostream) + + ndone += 1 + if ndone % args.log_interval == 0: + print(f"text norm: {ndone} lines done.", file=sys.stderr, flush=True) + print(f"text norm: {ndone} lines done in total.", file=sys.stderr, flush=True) diff --git a/viXTTS/TTS/tts/models/__init__.py b/viXTTS/TTS/tts/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bd2e5f0875a84633e707702cd7d628409b12057 --- /dev/null +++ b/viXTTS/TTS/tts/models/__init__.py @@ -0,0 +1,14 @@ +from typing import Dict, List, Union + +from TTS.utils.generic_utils import find_module + + +def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseTTS": + print(" > Using model: {}".format(config.model)) + # fetch the right model implementation. 
+ if "base_model" in config and config["base_model"] is not None: + MyModel = find_module("TTS.tts.models", config.base_model.lower()) + else: + MyModel = find_module("TTS.tts.models", config.model.lower()) + model = MyModel.init_from_config(config=config, samples=samples) + return model diff --git a/viXTTS/TTS/tts/models/__pycache__/__init__.cpython-310.pyc b/viXTTS/TTS/tts/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6ffefac38d9c081c982d69c1bd34b37b0ec7900 Binary files /dev/null and b/viXTTS/TTS/tts/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/models/__pycache__/base_tts.cpython-310.pyc b/viXTTS/TTS/tts/models/__pycache__/base_tts.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e106b65ce2400bc2536a1cd8a9b762ef1b22b8ce Binary files /dev/null and b/viXTTS/TTS/tts/models/__pycache__/base_tts.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/models/__pycache__/xtts.cpython-310.pyc b/viXTTS/TTS/tts/models/__pycache__/xtts.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cda97a0e8489c9b371c22e2b2d44960c7262f81 Binary files /dev/null and b/viXTTS/TTS/tts/models/__pycache__/xtts.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/models/align_tts.py b/viXTTS/TTS/tts/models/align_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..b2e51de7d6ab37951e3838e6804ca8e9b71338cf --- /dev/null +++ b/viXTTS/TTS/tts/models/align_tts.py @@ -0,0 +1,448 @@ +from dataclasses import dataclass, field +from typing import Dict, List, Union + +import torch +from coqpit import Coqpit +from torch import nn + +from TTS.tts.layers.align_tts.mdn import MDNBlock +from TTS.tts.layers.feed_forward.decoder import Decoder +from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor +from TTS.tts.layers.feed_forward.encoder import Encoder +from TTS.tts.layers.generic.pos_encoding import PositionalEncoding +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.io import load_fsspec + + +@dataclass +class AlignTTSArgs(Coqpit): + """ + Args: + num_chars (int): + number of unique input to characters + out_channels (int): + number of output tensor channels. It is equal to the expected spectrogram size. + hidden_channels (int): + number of channels in all the model layers. + hidden_channels_ffn (int): + number of channels in transformer's conv layers. + hidden_channels_dp (int): + number of channels in duration predictor network. + num_heads (int): + number of attention heads in transformer networks. + num_transformer_layers (int): + number of layers in encoder and decoder transformer blocks. + dropout_p (int): + dropout rate in transformer layers. + length_scale (int, optional): + coefficient to set the speech speed. <1 slower, >1 faster. Defaults to 1. + num_speakers (int, optional): + number of speakers for multi-speaker training. Defaults to 0. + external_c (bool, optional): + enable external speaker embeddings. Defaults to False. + c_in_channels (int, optional): + number of channels in speaker embedding vectors. Defaults to 0. 
+ """ + + num_chars: int = None + out_channels: int = 80 + hidden_channels: int = 256 + hidden_channels_dp: int = 256 + encoder_type: str = "fftransformer" + encoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1} + ) + decoder_type: str = "fftransformer" + decoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1} + ) + length_scale: float = 1.0 + num_speakers: int = 0 + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_dim: int = 0 + + +class AlignTTS(BaseTTS): + """AlignTTS with modified duration predictor. + https://arxiv.org/pdf/2003.01950.pdf + + Encoder -> DurationPredictor -> Decoder + + Check :class:`AlignTTSArgs` for the class arguments. + + Paper Abstract: + Targeting at both high efficiency and performance, we propose AlignTTS to predict the + mel-spectrum in parallel. AlignTTS is based on a Feed-Forward Transformer which generates mel-spectrum from a + sequence of characters, and the duration of each character is determined by a duration predictor.Instead of + adopting the attention mechanism in Transformer TTS to align text to mel-spectrum, the alignment loss is presented + to consider all possible alignments in training by use of dynamic programming. Experiments on the LJSpeech dataset s + how that our model achieves not only state-of-the-art performance which outperforms Transformer TTS by 0.03 in mean + option score (MOS), but also a high efficiency which is more than 50 times faster than real-time. + + Note: + Original model uses a separate character embedding layer for duration predictor. However, it causes the + duration predictor to overfit and prevents learning higher level interactions among characters. Therefore, + we predict durations based on encoder outputs which has higher level information about input characters. This + enables training without phases as in the original paper. + + Original model uses Transormers in encoder and decoder layers. However, here you can set the architecture + differently based on your requirements using ```encoder_type``` and ```decoder_type``` parameters. 
+ + Examples: + >>> from TTS.tts.configs.align_tts_config import AlignTTSConfig + >>> config = AlignTTSConfig() + >>> model = AlignTTS(config) + + """ + + # pylint: disable=dangerous-default-value + + def __init__( + self, + config: "AlignTTSConfig", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + self.speaker_manager = speaker_manager + self.phase = -1 + self.length_scale = ( + float(config.model_args.length_scale) + if isinstance(config.model_args.length_scale, int) + else config.model_args.length_scale + ) + + self.emb = nn.Embedding(self.config.model_args.num_chars, self.config.model_args.hidden_channels) + + self.embedded_speaker_dim = 0 + self.init_multispeaker(config) + + self.pos_encoder = PositionalEncoding(config.model_args.hidden_channels) + self.encoder = Encoder( + config.model_args.hidden_channels, + config.model_args.hidden_channels, + config.model_args.encoder_type, + config.model_args.encoder_params, + self.embedded_speaker_dim, + ) + self.decoder = Decoder( + config.model_args.out_channels, + config.model_args.hidden_channels, + config.model_args.decoder_type, + config.model_args.decoder_params, + ) + self.duration_predictor = DurationPredictor(config.model_args.hidden_channels_dp) + + self.mod_layer = nn.Conv1d(config.model_args.hidden_channels, config.model_args.hidden_channels, 1) + + self.mdn_block = MDNBlock(config.model_args.hidden_channels, 2 * config.model_args.out_channels) + + if self.embedded_speaker_dim > 0 and self.embedded_speaker_dim != config.model_args.hidden_channels: + self.proj_g = nn.Conv1d(self.embedded_speaker_dim, config.model_args.hidden_channels, 1) + + @staticmethod + def compute_log_probs(mu, log_sigma, y): + # pylint: disable=protected-access, c-extension-no-member + y = y.transpose(1, 2).unsqueeze(1) # [B, 1, T1, D] + mu = mu.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D] + log_sigma = log_sigma.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D] + expanded_y, expanded_mu = torch.broadcast_tensors(y, mu) + exponential = -0.5 * torch.mean( + torch._C._nn.mse_loss(expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2), dim=-1 + ) # B, L, T + logp = exponential - 0.5 * log_sigma.mean(dim=-1) + return logp + + def compute_align_path(self, mu, log_sigma, y, x_mask, y_mask): + # find the max alignment path + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + log_p = self.compute_log_probs(mu, log_sigma, y) + # [B, T_en, T_dec] + attn = maximum_path(log_p, attn_mask.squeeze(1)).unsqueeze(1) + dr_mas = torch.sum(attn, -1) + return dr_mas.squeeze(1), log_p + + @staticmethod + def generate_attn(dr, x_mask, y_mask=None): + # compute decode mask from the durations + if y_mask is None: + y_lengths = dr.sum(1).long() + y_lengths[y_lengths < 1] = 1 + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype) + return attn + + def expand_encoder_outputs(self, en, dr, x_mask, y_mask): + """Generate attention alignment map from durations and + expand encoder outputs + + Examples:: + - encoder output: [a,b,c,d] + - durations: [1, 3, 2, 1] + + - expanded: [a, b, b, b, c, c, d] + - attention map: [[0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 1, 1, 0], + [0, 1, 1, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0]] + """ + attn = self.generate_attn(dr, x_mask, y_mask) + o_en_ex = 
torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2) + return o_en_ex, attn + + def format_durations(self, o_dr_log, x_mask): + o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale + o_dr[o_dr < 1] = 1.0 + o_dr = torch.round(o_dr) + return o_dr + + @staticmethod + def _concat_speaker_embedding(o_en, g): + g_exp = g.expand(-1, -1, o_en.size(-1)) # [B, C, T_en] + o_en = torch.cat([o_en, g_exp], 1) + return o_en + + def _sum_speaker_embedding(self, x, g): + # project g to decoder dim. + if hasattr(self, "proj_g"): + g = self.proj_g(g) + + return x + g + + def _forward_encoder(self, x, x_lengths, g=None): + if hasattr(self, "emb_g"): + g = nn.functional.normalize(self.speaker_embedding(g)) # [B, C, 1] + + if g is not None: + g = g.unsqueeze(-1) + + # [B, T, C] + x_emb = self.emb(x) + # [B, C, T] + x_emb = torch.transpose(x_emb, 1, -1) + + # compute sequence masks + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype) + + # encoder pass + o_en = self.encoder(x_emb, x_mask) + + # speaker conditioning for duration predictor + if g is not None: + o_en_dp = self._concat_speaker_embedding(o_en, g) + else: + o_en_dp = o_en + return o_en, o_en_dp, x_mask, g + + def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g): + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) + # expand o_en with durations + o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) + # positional encoding + if hasattr(self, "pos_encoder"): + o_en_ex = self.pos_encoder(o_en_ex, y_mask) + # speaker embedding + if g is not None: + o_en_ex = self._sum_speaker_embedding(o_en_ex, g) + # decoder pass + o_de = self.decoder(o_en_ex, y_mask, g=g) + return o_de, attn.transpose(1, 2) + + def _forward_mdn(self, o_en, y, y_lengths, x_mask): + # MAS potentials and alignment + mu, log_sigma = self.mdn_block(o_en) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) + dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask, y_mask) + return dr_mas, mu, log_sigma, logp + + def forward( + self, x, x_lengths, y, y_lengths, aux_input={"d_vectors": None}, phase=None + ): # pylint: disable=unused-argument + """ + Shapes: + - x: :math:`[B, T_max]` + - x_lengths: :math:`[B]` + - y_lengths: :math:`[B]` + - dr: :math:`[B, T_max]` + - g: :math:`[B, C]` + """ + y = y.transpose(1, 2) + g = aux_input["d_vectors"] if "d_vectors" in aux_input else None + o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp = None, None, None, None, None, None, None + if phase == 0: + # train encoder and MDN + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) + attn = self.generate_attn(dr_mas, x_mask, y_mask) + elif phase == 1: + # train decoder + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + dr_mas, _, _, _ = self._forward_mdn(o_en, y, y_lengths, x_mask) + o_de, attn = self._forward_decoder(o_en.detach(), o_en_dp.detach(), dr_mas.detach(), x_mask, y_lengths, g=g) + elif phase == 2: + # train the whole except duration predictor + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) + o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g) + elif phase == 3: + # train duration predictor + o_en, o_en_dp, x_mask, g = 
self._forward_encoder(x, x_lengths, g) + o_dr_log = self.duration_predictor(x, x_mask) + dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) + o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g) + o_dr_log = o_dr_log.squeeze(1) + else: + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) + dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) + o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g) + o_dr_log = o_dr_log.squeeze(1) + dr_mas_log = torch.log(dr_mas + 1).squeeze(1) + outputs = { + "model_outputs": o_de.transpose(1, 2), + "alignments": attn, + "durations_log": o_dr_log, + "durations_mas_log": dr_mas_log, + "mu": mu, + "log_sigma": log_sigma, + "logp": logp, + } + return outputs + + @torch.no_grad() + def inference(self, x, aux_input={"d_vectors": None}): # pylint: disable=unused-argument + """ + Shapes: + - x: :math:`[B, T_max]` + - x_lengths: :math:`[B]` + - g: :math:`[B, C]` + """ + g = aux_input["d_vectors"] if "d_vectors" in aux_input else None + x_lengths = torch.tensor(x.shape[1:2]).to(x.device) + # pad input to prevent dropping the last word + # x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0) + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + # o_dr_log = self.duration_predictor(x, x_mask) + o_dr_log = self.duration_predictor(o_en_dp, x_mask) + # duration predictor pass + o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) + y_lengths = o_dr.sum(1) + o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g) + outputs = {"model_outputs": o_de.transpose(1, 2), "alignments": attn} + return outputs + + def train_step(self, batch: dict, criterion: nn.Module): + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + + aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids} + outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input, self.phase) + loss_dict = criterion( + outputs["logp"], + outputs["model_outputs"], + mel_input, + mel_lengths, + outputs["durations_log"], + outputs["durations_mas_log"], + text_lengths, + phase=self.phase, + ) + + return outputs, loss_dict + + def _create_logs(self, batch, outputs, ap): # pylint: disable=no-self-use + model_outputs = outputs["model_outputs"] + alignments = outputs["alignments"] + mel_input = batch["mel_input"] + + pred_spec = model_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + # Sample audio + train_audio = ap.inv_melspectrogram(pred_spec.T) + return figures, {"audio": train_audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ) -> None: # pylint: disable=no-self-use + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_step(self, batch: dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def eval_log(self, batch: 
dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def load_checkpoint( + self, config, checkpoint_path, eval=False, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def get_criterion(self): + from TTS.tts.layers.losses import AlignTTSLoss # pylint: disable=import-outside-toplevel + + return AlignTTSLoss(self.config) + + @staticmethod + def _set_phase(config, global_step): + """Decide AlignTTS training phase""" + if isinstance(config.phase_start_steps, list): + vals = [i < global_step for i in config.phase_start_steps] + if not True in vals: + phase = 0 + else: + phase = ( + len(config.phase_start_steps) + - [i < global_step for i in config.phase_start_steps][::-1].index(True) + - 1 + ) + else: + phase = None + return phase + + def on_epoch_start(self, trainer): + """Set AlignTTS training phase on epoch start.""" + self.phase = self._set_phase(trainer.config, trainer.total_steps_done) + + @staticmethod + def init_from_config(config: "AlignTTSConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (AlignTTSConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return AlignTTS(new_config, ap, tokenizer, speaker_manager) diff --git a/viXTTS/TTS/tts/models/bark.py b/viXTTS/TTS/tts/models/bark.py new file mode 100644 index 0000000000000000000000000000000000000000..e5edffd4ef4150b47d1ad7da5a705ab4f44ed889 --- /dev/null +++ b/viXTTS/TTS/tts/models/bark.py @@ -0,0 +1,284 @@ +import os +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from coqpit import Coqpit +from encodec import EncodecModel +from transformers import BertTokenizer + +from TTS.tts.layers.bark.inference_funcs import ( + codec_decode, + generate_coarse, + generate_fine, + generate_text_semantic, + generate_voice, + load_voice, +) +from TTS.tts.layers.bark.load_model import load_model +from TTS.tts.layers.bark.model import GPT +from TTS.tts.layers.bark.model_fine import FineGPT +from TTS.tts.models.base_tts import BaseTTS + + +@dataclass +class BarkAudioConfig(Coqpit): + sample_rate: int = 24000 + output_sample_rate: int = 24000 + + +class Bark(BaseTTS): + def __init__( + self, + config: Coqpit, + tokenizer: BertTokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased"), + ) -> None: + super().__init__(config=config, ap=None, tokenizer=None, speaker_manager=None, language_manager=None) + self.config.num_chars = len(tokenizer) + self.tokenizer = tokenizer + self.semantic_model = GPT(config.semantic_config) + self.coarse_model = GPT(config.coarse_config) + self.fine_model = FineGPT(config.fine_config) + self.encodec = EncodecModel.encodec_model_24khz() + self.encodec.set_target_bandwidth(6.0) + + @property + def device(self): + return next(self.parameters()).device + + def load_bark_models(self): + self.semantic_model, self.config = 
load_model( + ckpt_path=self.config.LOCAL_MODEL_PATHS["text"], device=self.device, config=self.config, model_type="text" + ) + self.coarse_model, self.config = load_model( + ckpt_path=self.config.LOCAL_MODEL_PATHS["coarse"], + device=self.device, + config=self.config, + model_type="coarse", + ) + self.fine_model, self.config = load_model( + ckpt_path=self.config.LOCAL_MODEL_PATHS["fine"], device=self.device, config=self.config, model_type="fine" + ) + + def train_step( + self, + ): + pass + + def text_to_semantic( + self, + text: str, + history_prompt: Optional[str] = None, + temp: float = 0.7, + base=None, + allow_early_stop=True, + **kwargs, + ): + """Generate semantic array from text. + + Args: + text: text to be turned into audio + history_prompt: history choice for audio cloning + temp: generation temperature (1.0 more diverse, 0.0 more conservative) + + Returns: + numpy semantic array to be fed into `semantic_to_waveform` + """ + x_semantic = generate_text_semantic( + text, + self, + history_prompt=history_prompt, + temp=temp, + base=base, + allow_early_stop=allow_early_stop, + **kwargs, + ) + return x_semantic + + def semantic_to_waveform( + self, + semantic_tokens: np.ndarray, + history_prompt: Optional[str] = None, + temp: float = 0.7, + base=None, + ): + """Generate audio array from semantic input. + + Args: + semantic_tokens: semantic token output from `text_to_semantic` + history_prompt: history choice for audio cloning + temp: generation temperature (1.0 more diverse, 0.0 more conservative) + + Returns: + numpy audio array at sample frequency 24khz + """ + x_coarse_gen = generate_coarse( + semantic_tokens, + self, + history_prompt=history_prompt, + temp=temp, + base=base, + ) + x_fine_gen = generate_fine( + x_coarse_gen, + self, + history_prompt=history_prompt, + temp=0.5, + base=base, + ) + audio_arr = codec_decode(x_fine_gen, self) + return audio_arr, x_coarse_gen, x_fine_gen + + def generate_audio( + self, + text: str, + history_prompt: Optional[str] = None, + text_temp: float = 0.7, + waveform_temp: float = 0.7, + base=None, + allow_early_stop=True, + **kwargs, + ): + """Generate audio array from input text. + + Args: + text: text to be turned into audio + history_prompt: history choice for audio cloning + text_temp: generation temperature (1.0 more diverse, 0.0 more conservative) + waveform_temp: generation temperature (1.0 more diverse, 0.0 more conservative) + + Returns: + numpy audio array at sample frequency 24khz + """ + x_semantic = self.text_to_semantic( + text, + history_prompt=history_prompt, + temp=text_temp, + base=base, + allow_early_stop=allow_early_stop, + **kwargs, + ) + audio_arr, c, f = self.semantic_to_waveform( + x_semantic, history_prompt=history_prompt, temp=waveform_temp, base=base + ) + return audio_arr, [x_semantic, c, f] + + def generate_voice(self, audio, speaker_id, voice_dir): + """Generate a voice from the given audio and text. + + Args: + audio (str): Path to the audio file. + speaker_id (str): Speaker name. + voice_dir (str): Path to the directory to save the generate voice. 
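A minimal sketch of driving Bark directly with the methods defined in this file, assuming the matching BarkConfig from TTS.tts.configs.bark_config and a directory that already holds the downloaded checkpoints (the paths and the input sentence are placeholders):

from TTS.tts.configs.bark_config import BarkConfig
from TTS.tts.models.bark import Bark

config = BarkConfig()
model = Bark.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="path/to/bark_checkpoints/", eval=True)
out = model.synthesize("Hello there!", config, speaker_id="random", voice_dirs=None)
# out["wav"] is a numpy waveform at 24 kHz (see generate_audio above)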
+ """ + if voice_dir is not None: + voice_dirs = [voice_dir] + try: + _ = load_voice(speaker_id, voice_dirs) + except (KeyError, FileNotFoundError): + output_path = os.path.join(voice_dir, speaker_id + ".npz") + os.makedirs(voice_dir, exist_ok=True) + generate_voice(audio, self, output_path) + + def _set_voice_dirs(self, voice_dirs): + def_voice_dir = None + if isinstance(self.config.DEF_SPEAKER_DIR, str): + os.makedirs(self.config.DEF_SPEAKER_DIR, exist_ok=True) + if os.path.isdir(self.config.DEF_SPEAKER_DIR): + def_voice_dir = self.config.DEF_SPEAKER_DIR + _voice_dirs = [def_voice_dir] if def_voice_dir is not None else [] + if voice_dirs is not None: + if isinstance(voice_dirs, str): + voice_dirs = [voice_dirs] + _voice_dirs = voice_dirs + _voice_dirs + return _voice_dirs + + # TODO: remove config from synthesize + def synthesize( + self, text, config, speaker_id="random", voice_dirs=None, **kwargs + ): # pylint: disable=unused-argument + """Synthesize speech with the given input text. + + Args: + text (str): Input text. + config (BarkConfig): Config with inference parameters. + speaker_id (str): One of the available speaker names. If `random`, it generates a random speaker. + speaker_wav (str): Path to the speaker audio file for cloning a new voice. It is cloned and saved in + `voice_dirs` with the name `speaker_id`. Defaults to None. + voice_dirs (List[str]): List of paths that host reference audio files for speakers. Defaults to None. + **kwargs: Model specific inference settings used by `generate_audio()` and `TTS.tts.layers.bark.inference_funcs.generate_text_semantic(). + + Returns: + A dictionary of the output values with `wav` as output waveform, `deterministic_seed` as seed used at inference, + `text_input` as text token IDs after tokenizer, `voice_samples` as samples used for cloning, `conditioning_latents` + as latents used at inference. + + """ + speaker_id = "random" if speaker_id is None else speaker_id + voice_dirs = self._set_voice_dirs(voice_dirs) + history_prompt = load_voice(self, speaker_id, voice_dirs) + outputs = self.generate_audio(text, history_prompt=history_prompt, **kwargs) + return_dict = { + "wav": outputs[0], + "text_inputs": text, + } + + return return_dict + + def eval_step(self): + ... + + def forward(self): + ... + + def inference(self): + ... + + @staticmethod + def init_from_config(config: "BarkConfig", **kwargs): # pylint: disable=unused-argument + return Bark(config) + + # pylint: disable=unused-argument, redefined-builtin + def load_checkpoint( + self, + config, + checkpoint_dir, + text_model_path=None, + coarse_model_path=None, + fine_model_path=None, + hubert_model_path=None, + hubert_tokenizer_path=None, + eval=False, + strict=True, + **kwargs, + ): + """Load a model checkpoints from a directory. This model is with multiple checkpoint files and it + expects to have all the files to be under the given `checkpoint_dir` with the rigth names. + If eval is True, set the model to eval mode. + + Args: + config (TortoiseConfig): The model config. + checkpoint_dir (str): The directory where the checkpoints are stored. + ar_checkpoint_path (str, optional): The path to the autoregressive checkpoint. Defaults to None. + diff_checkpoint_path (str, optional): The path to the diffusion checkpoint. Defaults to None. + clvp_checkpoint_path (str, optional): The path to the CLVP checkpoint. Defaults to None. + vocoder_checkpoint_path (str, optional): The path to the vocoder checkpoint. Defaults to None. + eval (bool, optional): Whether to set the model to eval mode. 
Defaults to False. + strict (bool, optional): Whether to load the model strictly. Defaults to True. + """ + text_model_path = text_model_path or os.path.join(checkpoint_dir, "text_2.pt") + coarse_model_path = coarse_model_path or os.path.join(checkpoint_dir, "coarse_2.pt") + fine_model_path = fine_model_path or os.path.join(checkpoint_dir, "fine_2.pt") + hubert_model_path = hubert_model_path or os.path.join(checkpoint_dir, "hubert.pt") + hubert_tokenizer_path = hubert_tokenizer_path or os.path.join(checkpoint_dir, "tokenizer.pth") + + self.config.LOCAL_MODEL_PATHS["text"] = text_model_path + self.config.LOCAL_MODEL_PATHS["coarse"] = coarse_model_path + self.config.LOCAL_MODEL_PATHS["fine"] = fine_model_path + self.config.LOCAL_MODEL_PATHS["hubert"] = hubert_model_path + self.config.LOCAL_MODEL_PATHS["hubert_tokenizer"] = hubert_tokenizer_path + + self.load_bark_models() + + if eval: + self.eval() diff --git a/viXTTS/TTS/tts/models/base_tacotron.py b/viXTTS/TTS/tts/models/base_tacotron.py new file mode 100644 index 0000000000000000000000000000000000000000..f38dace23559d9940b9a57cab479686f42172540 --- /dev/null +++ b/viXTTS/TTS/tts/models/base_tacotron.py @@ -0,0 +1,305 @@ +import copy +from abc import abstractmethod +from typing import Dict, Tuple + +import torch +from coqpit import Coqpit +from torch import nn + +from TTS.tts.layers.losses import TacotronLoss +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.helpers import sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.generic_utils import format_aux_input +from TTS.utils.io import load_fsspec +from TTS.utils.training import gradual_training_scheduler + + +class BaseTacotron(BaseTTS): + """Base class shared by Tacotron and Tacotron2""" + + def __init__( + self, + config: "TacotronConfig", + ap: "AudioProcessor", + tokenizer: "TTSTokenizer", + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + + # pass all config fields as class attributes + for key in config: + setattr(self, key, config[key]) + + # layers + self.embedding = None + self.encoder = None + self.decoder = None + self.postnet = None + + # init tensors + self.embedded_speakers = None + self.embedded_speakers_projected = None + + # global style token + if self.gst and self.use_gst: + self.decoder_in_features += self.gst.gst_embedding_dim # add gst embedding dim + self.gst_layer = None + + # Capacitron + if self.capacitron_vae and self.use_capacitron_vae: + self.decoder_in_features += self.capacitron_vae.capacitron_VAE_embedding_dim # add capacitron embedding dim + self.capacitron_vae_layer = None + + # additional layers + self.decoder_backward = None + self.coarse_decoder = None + + @staticmethod + def _format_aux_input(aux_input: Dict) -> Dict: + """Set missing fields to their default values""" + if aux_input: + return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input) + return None + + ############################# + # INIT FUNCTIONS + ############################# + + def _init_backward_decoder(self): + """Init the backward decoder for Forward-Backward decoding.""" + self.decoder_backward = copy.deepcopy(self.decoder) + + def _init_coarse_decoder(self): + """Init the coarse decoder for Double-Decoder Consistency.""" + self.coarse_decoder = copy.deepcopy(self.decoder) + 
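+        # run the copied decoder with the DDC reduction factor (`ddc_r`), i.e. it
+        # predicts `ddc_r` frames per decoder step for Double-Decoder Consistency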
self.coarse_decoder.r_init = self.ddc_r + self.coarse_decoder.set_r(self.ddc_r) + + ############################# + # CORE FUNCTIONS + ############################# + + @abstractmethod + def forward(self): + pass + + @abstractmethod + def inference(self): + pass + + def load_checkpoint( + self, config, checkpoint_path, eval=False, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + """Load model checkpoint and set up internals. + + Args: + config (Coqpi): model configuration. + checkpoint_path (str): path to checkpoint file. + eval (bool, optional): whether to load model for evaluation. + cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False. + """ + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"]) + # TODO: set r in run-time by taking it from the new config + if "r" in state: + # set r from the state (for compatibility with older checkpoints) + self.decoder.set_r(state["r"]) + elif "config" in state: + # set r from config used at training time (for inference) + self.decoder.set_r(state["config"]["r"]) + else: + # set r from the new config (for new-models) + self.decoder.set_r(config.r) + if eval: + self.eval() + print(f" > Model's reduction rate `r` is set to: {self.decoder.r}") + assert not self.training + + def get_criterion(self) -> nn.Module: + """Get the model criterion used in training.""" + return TacotronLoss(self.config) + + @staticmethod + def init_from_config(config: Coqpit): + """Initialize model from config.""" + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config) + return BaseTacotron(config, ap, tokenizer, speaker_manager) + + ########################## + # TEST AND LOG FUNCTIONS # + ########################## + + def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: + """Generic test run for `tts` models used by `Trainer`. + + You can override this for a different behaviour. + + Args: + assets (dict): A dict of training assets. For `tts` models, it must include `{'audio_processor': ap}`. + + Returns: + Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. 
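+
+            Note that, unlike ``BaseTTS.test_run`` which returns a ``(figures, audios)`` tuple,
+            this override packs both into a single dict. Illustrative sketch, assuming a
+            trained model and its audio processor ``ap``:
+
+            >>> outputs = model.test_run({"audio_processor": ap})
+            >>> figures, audios = outputs["figures"], outputs["audios"]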
+ """ + print(" | > Synthesizing test sentences.") + test_audios = {} + test_figures = {} + test_sentences = self.config.test_sentences + aux_inputs = self._get_test_aux_input() + for idx, sen in enumerate(test_sentences): + outputs_dict = synthesis( + self, + sen, + self.config, + "cuda" in str(next(self.parameters()).device), + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + style_wav=aux_inputs["style_wav"], + use_griffin_lim=True, + do_trim_silence=False, + ) + test_audios["{}-audio".format(idx)] = outputs_dict["wav"] + test_figures["{}-prediction".format(idx)] = plot_spectrogram( + outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False + ) + test_figures["{}-alignment".format(idx)] = plot_alignment( + outputs_dict["outputs"]["alignments"], output_fig=False + ) + return {"figures": test_figures, "audios": test_audios} + + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: + logger.test_audios(steps, outputs["audios"], self.ap.sample_rate) + logger.test_figures(steps, outputs["figures"]) + + ############################# + # COMMON COMPUTE FUNCTIONS + ############################# + + def compute_masks(self, text_lengths, mel_lengths): + """Compute masks against sequence paddings.""" + # B x T_in_max (boolean) + input_mask = sequence_mask(text_lengths) + output_mask = None + if mel_lengths is not None: + max_len = mel_lengths.max() + r = self.decoder.r + max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len + output_mask = sequence_mask(mel_lengths, max_len=max_len) + return input_mask, output_mask + + def _backward_pass(self, mel_specs, encoder_outputs, mask): + """Run backwards decoder""" + decoder_outputs_b, alignments_b, _ = self.decoder_backward( + encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask + ) + decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous() + return decoder_outputs_b, alignments_b + + def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments, input_mask): + """Double Decoder Consistency""" + T = mel_specs.shape[1] + if T % self.coarse_decoder.r > 0: + padding_size = self.coarse_decoder.r - (T % self.coarse_decoder.r) + mel_specs = torch.nn.functional.pad(mel_specs, (0, 0, 0, padding_size, 0, 0)) + decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder( + encoder_outputs.detach(), mel_specs, input_mask + ) + # scale_factor = self.decoder.r_init / self.decoder.r + alignments_backward = torch.nn.functional.interpolate( + alignments_backward.transpose(1, 2), + size=alignments.shape[1], + mode="nearest", + ).transpose(1, 2) + decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2) + decoder_outputs_backward = decoder_outputs_backward[:, :T, :] + return decoder_outputs_backward, alignments_backward + + ############################# + # EMBEDDING FUNCTIONS + ############################# + + def compute_gst(self, inputs, style_input, speaker_embedding=None): + """Compute global style token""" + if isinstance(style_input, dict): + # multiply each style token with a weight + query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs) + if speaker_embedding is not None: + query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1) + + _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens) + gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs) + for k_token, v_amplifier in style_input.items(): + key = 
_GST[int(k_token)].unsqueeze(0).expand(1, -1, -1) + gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key) + gst_outputs = gst_outputs + gst_outputs_att * v_amplifier + elif style_input is None: + # ignore style token and return zero tensor + gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs) + else: + # compute style tokens + gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable + inputs = self._concat_speaker_embedding(inputs, gst_outputs) + return inputs + + def compute_capacitron_VAE_embedding(self, inputs, reference_mel_info, text_info=None, speaker_embedding=None): + """Capacitron Variational Autoencoder""" + ( + VAE_outputs, + posterior_distribution, + prior_distribution, + capacitron_beta, + ) = self.capacitron_vae_layer( + reference_mel_info, + text_info, + speaker_embedding, # pylint: disable=not-callable + ) + + VAE_outputs = VAE_outputs.to(inputs.device) + encoder_output = self._concat_speaker_embedding( + inputs, VAE_outputs + ) # concatenate to the output of the basic tacotron encoder + return ( + encoder_output, + posterior_distribution, + prior_distribution, + capacitron_beta, + ) + + @staticmethod + def _add_speaker_embedding(outputs, embedded_speakers): + embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1) + outputs = outputs + embedded_speakers_ + return outputs + + @staticmethod + def _concat_speaker_embedding(outputs, embedded_speakers): + embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1) + outputs = torch.cat([outputs, embedded_speakers_], dim=-1) + return outputs + + ############################# + # CALLBACKS + ############################# + + def on_epoch_start(self, trainer): + """Callback for setting values wrt gradual training schedule. + + Args: + trainer (TrainerTTS): TTS trainer object that is used to train this model. + """ + if self.gradual_training: + r, trainer.config.batch_size = gradual_training_scheduler(trainer.total_steps_done, trainer.config) + trainer.config.r = r + self.decoder.set_r(r) + if trainer.config.bidirectional_decoder: + trainer.model.decoder_backward.set_r(r) + print(f"\n > Number of output frames: {self.decoder.r}") diff --git a/viXTTS/TTS/tts/models/base_tts.py b/viXTTS/TTS/tts/models/base_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..7871cc38c327f59a70b6695fded64e34ff7d7626 --- /dev/null +++ b/viXTTS/TTS/tts/models/base_tts.py @@ -0,0 +1,459 @@ +import os +import random +from typing import Dict, List, Tuple, Union + +import torch +import torch.distributed as dist +from coqpit import Coqpit +from torch import nn +from torch.utils.data import DataLoader +from torch.utils.data.sampler import WeightedRandomSampler +from trainer.torch import DistributedSampler, DistributedSamplerWrapper + +from TTS.model import BaseTrainerModel +from TTS.tts.datasets.dataset import TTSDataset +from TTS.tts.utils.data import get_length_balancer_weights +from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights +from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights, get_speaker_manager +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram + +# pylint: skip-file + + +class BaseTTS(BaseTrainerModel): + """Base `tts` class. Every new `tts` model must inherit this. + + It defines common `tts` specific functions on top of `Model` implementation. 
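+
+    Example:
+        Illustrative sketch; ``MyTTS`` and ``my_config`` are hypothetical placeholders for a
+        concrete subclass and its config:
+
+        >>> ap = AudioProcessor.init_from_config(my_config)
+        >>> tokenizer, my_config = TTSTokenizer.init_from_config(my_config)
+        >>> model = MyTTS(my_config, ap, tokenizer, speaker_manager=None)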
+ """ + + MODEL_TYPE = "tts" + + def __init__( + self, + config: Coqpit, + ap: "AudioProcessor", + tokenizer: "TTSTokenizer", + speaker_manager: SpeakerManager = None, + language_manager: LanguageManager = None, + ): + super().__init__() + self.config = config + self.ap = ap + self.tokenizer = tokenizer + self.speaker_manager = speaker_manager + self.language_manager = language_manager + self._set_model_args(config) + + def _set_model_args(self, config: Coqpit): + """Setup model args based on the config type (`ModelConfig` or `ModelArgs`). + + `ModelArgs` has all the fields reuqired to initialize the model architecture. + + `ModelConfig` has all the fields required for training, inference and containes `ModelArgs`. + + If the config is for training with a name like "*Config", then the model args are embeded in the + config.model_args + + If the config is for the model with a name like "*Args", then we assign the directly. + """ + # don't use isintance not to import recursively + if "Config" in config.__class__.__name__: + config_num_chars = ( + self.config.model_args.num_chars if hasattr(self.config, "model_args") else self.config.num_chars + ) + num_chars = config_num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars + if "characters" in config: + self.config.num_chars = num_chars + if hasattr(self.config, "model_args"): + config.model_args.num_chars = num_chars + self.args = self.config.model_args + else: + self.config = config + self.args = config.model_args + elif "Args" in config.__class__.__name__: + self.args = config + else: + raise ValueError("config must be either a *Config or *Args") + + def init_multispeaker(self, config: Coqpit, data: List = None): + """Initialize a speaker embedding layer if needen and define expected embedding channel size for defining + `in_channels` size of the connected layers. + + This implementation yields 3 possible outcomes: + + 1. If `config.use_speaker_embedding` and `config.use_d_vector_file are False, do nothing. + 2. If `config.use_d_vector_file` is True, set expected embedding channel size to `config.d_vector_dim` or 512. + 3. If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of + `config.d_vector_dim` or 512. + + You can override this function for new models. + + Args: + config (Coqpit): Model configuration. 
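+
+            Example (illustrative sketch, assuming ``config.use_speaker_embedding=True`` and
+            ``config.use_d_vector_file=False``):
+
+            >>> model.init_multispeaker(config)   # builds `self.speaker_embedding`
+            >>> model.speaker_embedding.embedding_dim == model.embedded_speaker_dim
+            True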
+ """ + # set number of speakers + if self.speaker_manager is not None: + self.num_speakers = self.speaker_manager.num_speakers + elif hasattr(config, "num_speakers"): + self.num_speakers = config.num_speakers + + # set ultimate speaker embedding size + if config.use_speaker_embedding or config.use_d_vector_file: + self.embedded_speaker_dim = ( + config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512 + ) + # init speaker embedding layer + if config.use_speaker_embedding and not config.use_d_vector_file: + print(" > Init speaker_embedding layer.") + self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) + self.speaker_embedding.weight.data.normal_(0, 0.3) + + def get_aux_input(self, **kwargs) -> Dict: + """Prepare and return `aux_input` used by `forward()`""" + return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None} + + def get_aux_input_from_test_sentences(self, sentence_info): + if hasattr(self.config, "model_args"): + config = self.config.model_args + else: + config = self.config + + # extract speaker and language info + text, speaker_name, style_wav, language_name = None, None, None, None + + if isinstance(sentence_info, list): + if len(sentence_info) == 1: + text = sentence_info[0] + elif len(sentence_info) == 2: + text, speaker_name = sentence_info + elif len(sentence_info) == 3: + text, speaker_name, style_wav = sentence_info + elif len(sentence_info) == 4: + text, speaker_name, style_wav, language_name = sentence_info + else: + text = sentence_info + + # get speaker id/d_vector + speaker_id, d_vector, language_id = None, None, None + if self.speaker_manager is not None: + if config.use_d_vector_file: + if speaker_name is None: + d_vector = self.speaker_manager.get_random_embedding() + else: + d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name) + elif config.use_speaker_embedding: + if speaker_name is None: + speaker_id = self.speaker_manager.get_random_id() + else: + speaker_id = self.speaker_manager.name_to_id[speaker_name] + + # get language id + if self.language_manager is not None and config.use_language_embedding and language_name is not None: + language_id = self.language_manager.name_to_id[language_name] + + return { + "text": text, + "speaker_id": speaker_id, + "style_wav": style_wav, + "d_vector": d_vector, + "language_id": language_id, + } + + def format_batch(self, batch: Dict) -> Dict: + """Generic batch formatting for `TTSDataset`. + + You must override this if you use a custom dataset. 
+ + Args: + batch (Dict): [description] + + Returns: + Dict: [description] + """ + # setup input batch + text_input = batch["token_id"] + text_lengths = batch["token_id_lengths"] + speaker_names = batch["speaker_names"] + linear_input = batch["linear"] + mel_input = batch["mel"] + mel_lengths = batch["mel_lengths"] + stop_targets = batch["stop_targets"] + item_idx = batch["item_idxs"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + attn_mask = batch["attns"] + waveform = batch["waveform"] + pitch = batch["pitch"] + energy = batch["energy"] + language_ids = batch["language_ids"] + max_text_length = torch.max(text_lengths.float()) + max_spec_length = torch.max(mel_lengths.float()) + + # compute durations from attention masks + durations = None + if attn_mask is not None: + durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2]) + for idx, am in enumerate(attn_mask): + # compute raw durations + c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1] + # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True) + c_idxs, counts = torch.unique(c_idxs, return_counts=True) + dur = torch.ones([text_lengths[idx]]).to(counts.dtype) + dur[c_idxs] = counts + # smooth the durations and set any 0 duration to 1 + # by cutting off from the largest duration indeces. + extra_frames = dur.sum() - mel_lengths[idx] + largest_idxs = torch.argsort(-dur)[:extra_frames] + dur[largest_idxs] -= 1 + assert ( + dur.sum() == mel_lengths[idx] + ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}" + durations[idx, : text_lengths[idx]] = dur + + # set stop targets wrt reduction factor + stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2) + stop_target_lengths = torch.divide(mel_lengths, self.config.r).ceil_() + + return { + "text_input": text_input, + "text_lengths": text_lengths, + "speaker_names": speaker_names, + "mel_input": mel_input, + "mel_lengths": mel_lengths, + "linear_input": linear_input, + "stop_targets": stop_targets, + "stop_target_lengths": stop_target_lengths, + "attn_mask": attn_mask, + "durations": durations, + "speaker_ids": speaker_ids, + "d_vectors": d_vectors, + "max_text_length": float(max_text_length), + "max_spec_length": float(max_spec_length), + "item_idx": item_idx, + "waveform": waveform, + "pitch": pitch, + "energy": energy, + "language_ids": language_ids, + "audio_unique_names": batch["audio_unique_names"], + } + + def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1): + weights = None + data_items = dataset.samples + + if getattr(config, "use_language_weighted_sampler", False): + alpha = getattr(config, "language_weighted_sampler_alpha", 1.0) + print(" > Using Language weighted sampler with alpha:", alpha) + weights = get_language_balancer_weights(data_items) * alpha + + if getattr(config, "use_speaker_weighted_sampler", False): + alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0) + print(" > Using Speaker weighted sampler with alpha:", alpha) + if weights is not None: + weights += get_speaker_balancer_weights(data_items) * alpha + else: + weights = get_speaker_balancer_weights(data_items) * alpha + + if getattr(config, "use_length_weighted_sampler", False): + alpha = getattr(config, "length_weighted_sampler_alpha", 1.0) + print(" > Using Length weighted sampler with alpha:", alpha) + if weights is not None: + weights += get_length_balancer_weights(data_items) * 
alpha + else: + weights = get_length_balancer_weights(data_items) * alpha + + if weights is not None: + sampler = WeightedRandomSampler(weights, len(weights)) + else: + sampler = None + + # sampler for DDP + if sampler is None: + sampler = DistributedSampler(dataset) if num_gpus > 1 else None + else: # If a sampler is already defined use this sampler and DDP sampler together + sampler = DistributedSamplerWrapper(sampler) if num_gpus > 1 else sampler + + return sampler + + def get_data_loader( + self, + config: Coqpit, + assets: Dict, + is_eval: bool, + samples: Union[List[Dict], List[List]], + verbose: bool, + num_gpus: int, + rank: int = None, + ) -> "DataLoader": + if is_eval and not config.run_eval: + loader = None + else: + # setup multi-speaker attributes + if self.speaker_manager is not None: + if hasattr(config, "model_args"): + speaker_id_mapping = ( + self.speaker_manager.name_to_id if config.model_args.use_speaker_embedding else None + ) + d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None + config.use_d_vector_file = config.model_args.use_d_vector_file + else: + speaker_id_mapping = self.speaker_manager.name_to_id if config.use_speaker_embedding else None + d_vector_mapping = self.speaker_manager.embeddings if config.use_d_vector_file else None + else: + speaker_id_mapping = None + d_vector_mapping = None + + # setup multi-lingual attributes + if self.language_manager is not None: + language_id_mapping = self.language_manager.name_to_id if self.args.use_language_embedding else None + else: + language_id_mapping = None + + # init dataloader + dataset = TTSDataset( + outputs_per_step=config.r if "r" in config else 1, + compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec, + compute_f0=config.get("compute_f0", False), + f0_cache_path=config.get("f0_cache_path", None), + compute_energy=config.get("compute_energy", False), + energy_cache_path=config.get("energy_cache_path", None), + samples=samples, + ap=self.ap, + return_wav=config.return_wav if "return_wav" in config else False, + batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, + min_text_len=config.min_text_len, + max_text_len=config.max_text_len, + min_audio_len=config.min_audio_len, + max_audio_len=config.max_audio_len, + phoneme_cache_path=config.phoneme_cache_path, + precompute_num_workers=config.precompute_num_workers, + use_noise_augment=False if is_eval else config.use_noise_augment, + verbose=verbose, + speaker_id_mapping=speaker_id_mapping, + d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, + tokenizer=self.tokenizer, + start_by_longest=config.start_by_longest, + language_id_mapping=language_id_mapping, + ) + + # wait all the DDP process to be ready + if num_gpus > 1: + dist.barrier() + + # sort input sequences from short to long + dataset.preprocess_samples() + + # get samplers + sampler = self.get_sampler(config, dataset, num_gpus) + + loader = DataLoader( + dataset, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + shuffle=config.shuffle if sampler is None else False, # if there is no other sampler + collate_fn=dataset.collate_fn, + drop_last=config.drop_last, # setting this False might cause issues in AMP training. 
+ sampler=sampler, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + return loader + + def _get_test_aux_input( + self, + ) -> Dict: + d_vector = None + if self.config.use_d_vector_file: + d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings] + d_vector = (random.sample(sorted(d_vector), 1),) + + aux_inputs = { + "speaker_id": None + if not self.config.use_speaker_embedding + else random.sample(sorted(self.speaker_manager.name_to_id.values()), 1), + "d_vector": d_vector, + "style_wav": None, # TODO: handle GST style input + } + return aux_inputs + + def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: + """Generic test run for `tts` models used by `Trainer`. + + You can override this for a different behaviour. + + Args: + assets (dict): A dict of training assets. For `tts` models, it must include `{'audio_processor': ap}`. + + Returns: + Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. + """ + print(" | > Synthesizing test sentences.") + test_audios = {} + test_figures = {} + test_sentences = self.config.test_sentences + aux_inputs = self._get_test_aux_input() + for idx, sen in enumerate(test_sentences): + if isinstance(sen, list): + aux_inputs = self.get_aux_input_from_test_sentences(sen) + sen = aux_inputs["text"] + outputs_dict = synthesis( + self, + sen, + self.config, + "cuda" in str(next(self.parameters()).device), + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + style_wav=aux_inputs["style_wav"], + use_griffin_lim=True, + do_trim_silence=False, + ) + test_audios["{}-audio".format(idx)] = outputs_dict["wav"] + test_figures["{}-prediction".format(idx)] = plot_spectrogram( + outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False + ) + test_figures["{}-alignment".format(idx)] = plot_alignment( + outputs_dict["outputs"]["alignments"], output_fig=False + ) + return test_figures, test_audios + + def on_init_start(self, trainer): + """Save the speaker.pth and language_ids.json at the beginning of the training. 
Also update both paths.""" + if self.speaker_manager is not None: + output_path = os.path.join(trainer.output_path, "speakers.pth") + self.speaker_manager.save_ids_to_file(output_path) + trainer.config.speakers_file = output_path + # some models don't have `model_args` set + if hasattr(trainer.config, "model_args"): + trainer.config.model_args.speakers_file = output_path + trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) + print(f" > `speakers.pth` is saved to {output_path}.") + print(" > `speakers_file` is updated in the config.json.") + + if self.language_manager is not None: + output_path = os.path.join(trainer.output_path, "language_ids.json") + self.language_manager.save_ids_to_file(output_path) + trainer.config.language_ids_file = output_path + if hasattr(trainer.config, "model_args"): + trainer.config.model_args.language_ids_file = output_path + trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) + print(f" > `language_ids.json` is saved to {output_path}.") + print(" > `language_ids_file` is updated in the config.json.") + + +class BaseTTSE2E(BaseTTS): + def _set_model_args(self, config: Coqpit): + self.config = config + if "Config" in config.__class__.__name__: + num_chars = ( + self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars + ) + self.config.model_args.num_chars = num_chars + self.config.num_chars = num_chars + self.args = config.model_args + self.args.num_chars = num_chars + elif "Args" in config.__class__.__name__: + self.args = config + self.args.num_chars = self.args.num_chars + else: + raise ValueError("config must be either a *Config or *Args") diff --git a/viXTTS/TTS/tts/models/delightful_tts.py b/viXTTS/TTS/tts/models/delightful_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..b1cf886bea1f22a0ec1b0524f5a80ea8db0b55f8 --- /dev/null +++ b/viXTTS/TTS/tts/models/delightful_tts.py @@ -0,0 +1,1770 @@ +import os +from dataclasses import dataclass, field +from itertools import chain +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +import torchaudio +from coqpit import Coqpit +from librosa.filters import mel as librosa_mel_fn +from torch import nn +from torch.cuda.amp.autocast_mode import autocast +from torch.nn import functional as F +from torch.utils.data import DataLoader +from torch.utils.data.sampler import WeightedRandomSampler +from trainer.torch import DistributedSampler, DistributedSamplerWrapper +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample +from TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel +from TTS.tts.layers.losses import ForwardSumLoss, VitsDiscriminatorLoss +from TTS.tts.layers.vits.discriminator import VitsDiscriminator +from TTS.tts.models.base_tts import BaseTTSE2E +from TTS.tts.utils.helpers import average_over_durations, compute_attn_prior, rand_segments, segment, sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_pitch, plot_spectrogram +from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0 +from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy +from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy +from 
TTS.utils.audio.processor import AudioProcessor +from TTS.utils.io import load_fsspec +from TTS.vocoder.layers.losses import MultiScaleSTFTLoss +from TTS.vocoder.models.hifigan_generator import HifiganGenerator +from TTS.vocoder.utils.generic_utils import plot_results + + +def id_to_torch(aux_id, cuda=False): + if aux_id is not None: + aux_id = np.asarray(aux_id) + aux_id = torch.from_numpy(aux_id) + if cuda: + return aux_id.cuda() + return aux_id + + +def embedding_to_torch(d_vector, cuda=False): + if d_vector is not None: + d_vector = np.asarray(d_vector) + d_vector = torch.from_numpy(d_vector).float() + d_vector = d_vector.squeeze().unsqueeze(0) + if cuda: + return d_vector.cuda() + return d_vector + + +def numpy_to_torch(np_array, dtype, cuda=False): + if np_array is None: + return None + tensor = torch.as_tensor(np_array, dtype=dtype) + if cuda: + return tensor.cuda() + return tensor + + +def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor: + batch_size = lengths.shape[0] + max_len = torch.max(lengths).item() + ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1) + mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) + return mask + + +def pad(input_ele: List[torch.Tensor], max_len: int) -> torch.Tensor: + out_list = torch.jit.annotate(List[torch.Tensor], []) + for batch in input_ele: + if len(batch.shape) == 1: + one_batch_padded = F.pad(batch, (0, max_len - batch.size(0)), "constant", 0.0) + else: + one_batch_padded = F.pad(batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0) + out_list.append(one_batch_padded) + out_padded = torch.stack(out_list) + return out_padded + + +def init_weights(m: nn.Module, mean: float = 0.0, std: float = 0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor: + return torch.ceil(lens / stride).int() + + +def initialize_embeddings(shape: Tuple[int]) -> torch.Tensor: + assert len(shape) == 2, "Can only initialize 2-D embedding matrices ..." 
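+    # scaled normal initialization: std = sqrt(2 / fan_in), with fan_in = shape[1]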
+ return torch.randn(shape) * np.sqrt(2 / shape[1]) + + +# pylint: disable=redefined-outer-name +def calc_same_padding(kernel_size: int) -> Tuple[int, int]: + pad = kernel_size // 2 + return (pad, pad - (kernel_size + 1) % 2) + + +hann_window = {} +mel_basis = {} + + +@torch.no_grad() +def weights_reset(m: nn.Module): + # check if the current module has reset_parameters and if it is reset the weight + reset_parameters = getattr(m, "reset_parameters", None) + if callable(reset_parameters): + m.reset_parameters() + + +def get_module_weights_sum(mdl: nn.Module): + dict_sums = {} + for name, w in mdl.named_parameters(): + if "weight" in name: + value = w.data.sum().item() + dict_sums[name] = value + return dict_sums + + +def load_audio(file_path: str): + """Load the audio file normalized in [-1, 1] + + Return Shapes: + - x: :math:`[1, T]` + """ + x, sr = torchaudio.load( + file_path, + ) + assert (x > 1).sum() + (x < -1).sum() == 0 + return x, sr + + +def _amp_to_db(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def _db_to_amp(x, C=1): + return torch.exp(x) / C + + +def amp_to_db(magnitudes): + output = _amp_to_db(magnitudes) + return output + + +def db_to_amp(magnitudes): + output = _db_to_amp(magnitudes) + return output + + +def _wav_to_spec(y, n_fft, hop_length, win_length, center=False): + y = y.squeeze(1) + + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global hann_window # pylint: disable=global-statement + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_length) + "_" + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + return spec + + +def wav_to_spec(y, n_fft, hop_length, win_length, center=False): + """ + Args Shapes: + - y : :math:`[B, 1, T]` + + Return Shapes: + - spec : :math:`[B,C,T]` + """ + spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center) + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +def wav_to_energy(y, n_fft, hop_length, win_length, center=False): + spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return torch.norm(spec, dim=1, keepdim=True) + + +def name_mel_basis(spec, n_fft, fmax): + n_fft_len = f"{n_fft}_{fmax}_{spec.dtype}_{spec.device}" + return n_fft_len + + +def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax): + """ + Args Shapes: + - spec : :math:`[B,C,T]` + + Return Shapes: + - mel : :math:`[B,C,T]` + """ + global mel_basis # pylint: disable=global-statement + mel_basis_key = name_mel_basis(spec, n_fft, fmax) + # pylint: disable=too-many-function-args + if mel_basis_key not in mel_basis: + # pylint: disable=missing-kwoa + mel = librosa_mel_fn(sample_rate, n_fft, num_mels, fmin, fmax) + mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + mel = torch.matmul(mel_basis[mel_basis_key], spec) + mel = amp_to_db(mel) + return mel + + +def wav_to_mel(y, n_fft, 
num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False): + """ + Args Shapes: + - y : :math:`[B, 1, T_y]` + + Return Shapes: + - spec : :math:`[B,C,T_spec]` + """ + y = y.squeeze(1) + + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window # pylint: disable=global-statement + mel_basis_key = name_mel_basis(y, n_fft, fmax) + wnsize_dtype_device = str(win_length) + "_" + str(y.dtype) + "_" + str(y.device) + if mel_basis_key not in mel_basis: + # pylint: disable=missing-kwoa + mel = librosa_mel_fn( + sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax + ) # pylint: disable=too-many-function-args + mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + spec = torch.matmul(mel_basis[mel_basis_key], spec) + spec = amp_to_db(spec) + return spec + + +############################## +# DATASET +############################## + + +def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None): + """Create balancer weight for torch WeightedSampler""" + attr_names_samples = np.array([item[attr_name] for item in items]) + unique_attr_names = np.unique(attr_names_samples).tolist() + attr_idx = [unique_attr_names.index(l) for l in attr_names_samples] + attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names]) + weight_attr = 1.0 / attr_count + dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx]) + dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) + if multi_dict is not None: + multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items]) + dataset_samples_weight *= multiplier_samples + return ( + torch.from_numpy(dataset_samples_weight).float(), + unique_attr_names, + np.unique(dataset_samples_weight).tolist(), + ) + + +class ForwardTTSE2eF0Dataset(F0Dataset): + """Override F0Dataset to avoid slow computing of pitches""" + + def __init__( + self, + ap, + samples: Union[List[List], List[Dict]], + verbose=False, + cache_path: str = None, + precompute_num_workers=0, + normalize_f0=True, + ): + super().__init__( + samples=samples, + ap=ap, + verbose=verbose, + cache_path=cache_path, + precompute_num_workers=precompute_num_workers, + normalize_f0=normalize_f0, + ) + + def _compute_and_save_pitch(self, wav_file, pitch_file=None): + wav, _ = load_audio(wav_file) + f0 = compute_f0( + x=wav.numpy()[0], + sample_rate=self.ap.sample_rate, + hop_length=self.ap.hop_length, + pitch_fmax=self.ap.pitch_fmax, + pitch_fmin=self.ap.pitch_fmin, + win_length=self.ap.win_length, + ) + # skip the last F0 value to align with the spectrogram + if wav.shape[1] % self.ap.hop_length != 0: + f0 = f0[:-1] + if pitch_file: + np.save(pitch_file, f0) + return f0 + + def compute_or_load(self, wav_file, audio_name): + """ + compute pitch and return a 
numpy array of pitch values + """ + pitch_file = self.create_pitch_file_path(audio_name, self.cache_path) + if not os.path.exists(pitch_file): + pitch = self._compute_and_save_pitch(wav_file=wav_file, pitch_file=pitch_file) + else: + pitch = np.load(pitch_file) + return pitch.astype(np.float32) + + +class ForwardTTSE2eDataset(TTSDataset): + def __init__(self, *args, **kwargs): + # don't init the default F0Dataset in TTSDataset + compute_f0 = kwargs.pop("compute_f0", False) + kwargs["compute_f0"] = False + self.attn_prior_cache_path = kwargs.pop("attn_prior_cache_path") + + super().__init__(*args, **kwargs) + + self.compute_f0 = compute_f0 + self.pad_id = self.tokenizer.characters.pad_id + self.ap = kwargs["ap"] + + if self.compute_f0: + self.f0_dataset = ForwardTTSE2eF0Dataset( + ap=self.ap, + samples=self.samples, + cache_path=kwargs["f0_cache_path"], + precompute_num_workers=kwargs["precompute_num_workers"], + ) + + if self.attn_prior_cache_path is not None: + os.makedirs(self.attn_prior_cache_path, exist_ok=True) + + def __getitem__(self, idx): + item = self.samples[idx] + + rel_wav_path = Path(item["audio_file"]).relative_to(item["root_path"]).with_suffix("") + rel_wav_path = str(rel_wav_path).replace("/", "_") + + raw_text = item["text"] + wav, _ = load_audio(item["audio_file"]) + wav_filename = os.path.basename(item["audio_file"]) + + try: + token_ids = self.get_token_ids(idx, item["text"]) + except: + print(idx, item) + # pylint: disable=raise-missing-from + raise OSError + f0 = None + if self.compute_f0: + f0 = self.get_f0(idx)["f0"] + + # after phonemization the text length may change + # this is a shameful 🤭 hack to prevent longer phonemes + # TODO: find a better fix + if len(token_ids) > self.max_text_len or wav.shape[1] < self.min_audio_len: + self.rescue_item_idx += 1 + return self.__getitem__(self.rescue_item_idx) + + attn_prior = None + if self.attn_prior_cache_path is not None: + attn_prior = self.load_or_compute_attn_prior(token_ids, wav, rel_wav_path) + + return { + "raw_text": raw_text, + "token_ids": token_ids, + "token_len": len(token_ids), + "wav": wav, + "pitch": f0, + "wav_file": wav_filename, + "speaker_name": item["speaker_name"], + "language_name": item["language"], + "attn_prior": attn_prior, + "audio_unique_name": item["audio_unique_name"], + } + + def load_or_compute_attn_prior(self, token_ids, wav, rel_wav_path): + """Load or compute and save the attention prior.""" + attn_prior_file = os.path.join(self.attn_prior_cache_path, f"{rel_wav_path}.npy") + # pylint: disable=no-else-return + if os.path.exists(attn_prior_file): + return np.load(attn_prior_file) + else: + token_len = len(token_ids) + mel_len = wav.shape[1] // self.ap.hop_length + attn_prior = compute_attn_prior(token_len, mel_len) + np.save(attn_prior_file, attn_prior) + return attn_prior + + @property + def lengths(self): + lens = [] + for item in self.samples: + _, wav_file, *_ = _parse_sample(item) + audio_len = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio + lens.append(audio_len) + return lens + + def collate_fn(self, batch): + """ + Return Shapes: + - tokens: :math:`[B, T]` + - token_lens :math:`[B]` + - token_rel_lens :math:`[B]` + - pitch :math:`[B, T]` + - waveform: :math:`[B, 1, T]` + - waveform_lens: :math:`[B]` + - waveform_rel_lens: :math:`[B]` + - speaker_names: :math:`[B]` + - language_names: :math:`[B]` + - audiofile_paths: :math:`[B]` + - raw_texts: :math:`[B]` + - attn_prior: :math:`[[T_token, T_mel]]` + """ + B = len(batch) + batch = {k: [dic[k] for dic in batch] for k 
in batch[0]} + + max_text_len = max([len(x) for x in batch["token_ids"]]) + token_lens = torch.LongTensor(batch["token_len"]) + token_rel_lens = token_lens / token_lens.max() + + wav_lens = [w.shape[1] for w in batch["wav"]] + wav_lens = torch.LongTensor(wav_lens) + wav_lens_max = torch.max(wav_lens) + wav_rel_lens = wav_lens / wav_lens_max + + pitch_padded = None + if self.compute_f0: + pitch_lens = [p.shape[0] for p in batch["pitch"]] + pitch_lens = torch.LongTensor(pitch_lens) + pitch_lens_max = torch.max(pitch_lens) + pitch_padded = torch.FloatTensor(B, 1, pitch_lens_max) + pitch_padded = pitch_padded.zero_() + self.pad_id + + token_padded = torch.LongTensor(B, max_text_len) + wav_padded = torch.FloatTensor(B, 1, wav_lens_max) + + token_padded = token_padded.zero_() + self.pad_id + wav_padded = wav_padded.zero_() + self.pad_id + + for i in range(B): + token_ids = batch["token_ids"][i] + token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids) + + wav = batch["wav"][i] + wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav) + + if self.compute_f0: + pitch = batch["pitch"][i] + pitch_padded[i, 0, : len(pitch)] = torch.FloatTensor(pitch) + + return { + "text_input": token_padded, + "text_lengths": token_lens, + "text_rel_lens": token_rel_lens, + "pitch": pitch_padded, + "waveform": wav_padded, # (B x T) + "waveform_lens": wav_lens, # (B) + "waveform_rel_lens": wav_rel_lens, + "speaker_names": batch["speaker_name"], + "language_names": batch["language_name"], + "audio_unique_names": batch["audio_unique_name"], + "audio_files": batch["wav_file"], + "raw_text": batch["raw_text"], + "attn_priors": batch["attn_prior"] if batch["attn_prior"][0] is not None else None, + } + + +############################## +# CONFIG DEFINITIONS +############################## + + +@dataclass +class VocoderConfig(Coqpit): + resblock_type_decoder: str = "1" + resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11]) + resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) + upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2]) + upsample_initial_channel_decoder: int = 512 + upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4]) + use_spectral_norm_discriminator: bool = False + upsampling_rates_discriminator: List[int] = field(default_factory=lambda: [4, 4, 4, 4]) + periods_discriminator: List[int] = field(default_factory=lambda: [2, 3, 5, 7, 11]) + pretrained_model_path: Optional[str] = None + + +@dataclass +class DelightfulTtsAudioConfig(Coqpit): + sample_rate: int = 22050 + hop_length: int = 256 + win_length: int = 1024 + fft_size: int = 1024 + mel_fmin: float = 0.0 + mel_fmax: float = 8000 + num_mels: int = 100 + pitch_fmax: float = 640.0 + pitch_fmin: float = 1.0 + resample: bool = False + preemphasis: float = 0.0 + ref_level_db: int = 20 + do_sound_norm: bool = False + log_func: str = "np.log10" + do_trim_silence: bool = True + trim_db: int = 45 + do_rms_norm: bool = False + db_level: float = None + power: float = 1.5 + griffin_lim_iters: int = 60 + spec_gain: int = 20 + do_amp_to_db_linear: bool = True + do_amp_to_db_mel: bool = True + min_level_db: int = -100 + max_norm: float = 4.0 + + +@dataclass +class DelightfulTtsArgs(Coqpit): + num_chars: int = 100 + spec_segment_size: int = 32 + n_hidden_conformer_encoder: int = 512 + n_layers_conformer_encoder: int = 6 + n_heads_conformer_encoder: int = 8 + dropout_conformer_encoder: float = 0.1 + 
kernel_size_conv_mod_conformer_encoder: int = 7 + kernel_size_depthwise_conformer_encoder: int = 7 + lrelu_slope: float = 0.3 + n_hidden_conformer_decoder: int = 512 + n_layers_conformer_decoder: int = 6 + n_heads_conformer_decoder: int = 8 + dropout_conformer_decoder: float = 0.1 + kernel_size_conv_mod_conformer_decoder: int = 11 + kernel_size_depthwise_conformer_decoder: int = 11 + bottleneck_size_p_reference_encoder: int = 4 + bottleneck_size_u_reference_encoder: int = 512 + ref_enc_filters_reference_encoder = [32, 32, 64, 64, 128, 128] + ref_enc_size_reference_encoder: int = 3 + ref_enc_strides_reference_encoder = [1, 2, 1, 2, 1] + ref_enc_pad_reference_encoder = [1, 1] + ref_enc_gru_size_reference_encoder: int = 32 + ref_attention_dropout_reference_encoder: float = 0.2 + token_num_reference_encoder: int = 32 + predictor_kernel_size_reference_encoder: int = 5 + n_hidden_variance_adaptor: int = 512 + kernel_size_variance_adaptor: int = 5 + dropout_variance_adaptor: float = 0.5 + n_bins_variance_adaptor: int = 256 + emb_kernel_size_variance_adaptor: int = 3 + use_speaker_embedding: bool = False + num_speakers: int = 0 + speakers_file: str = None + d_vector_file: str = None + speaker_embedding_channels: int = 384 + use_d_vector_file: bool = False + d_vector_dim: int = 0 + freeze_vocoder: bool = False + freeze_text_encoder: bool = False + freeze_duration_predictor: bool = False + freeze_pitch_predictor: bool = False + freeze_energy_predictor: bool = False + freeze_basis_vectors_predictor: bool = False + freeze_decoder: bool = False + length_scale: float = 1.0 + + +############################## +# MODEL DEFINITION +############################## +class DelightfulTTS(BaseTTSE2E): + """ + Paper:: + https://arxiv.org/pdf/2110.12612.pdf + + Paper Abstract:: + This paper describes the Microsoft end-to-end neural text to speech (TTS) system: DelightfulTTS for Blizzard Challenge 2021. + The goal of this challenge is to synthesize natural and high-quality speech from text, and we approach this goal in two perspectives: + The first is to directly model and generate waveform in 48 kHz sampling rate, which brings higher perception quality than previous systems + with 16 kHz or 24 kHz sampling rate; The second is to model the variation information in speech through a systematic design, which improves + the prosody and naturalness. Specifically, for 48 kHz modeling, we predict 16 kHz mel-spectrogram in acoustic model, and + propose a vocoder called HiFiNet to directly generate 48 kHz waveform from predicted 16 kHz mel-spectrogram, which can better trade off training + efficiency, modelling stability and voice quality. We model variation information systematically from both explicit (speaker ID, language ID, pitch and duration) and + implicit (utterance-level and phoneme-level prosody) perspectives: 1) For speaker and language ID, we use lookup embedding in training and + inference; 2) For pitch and duration, we extract the values from paired text-speech data in training and use two predictors to predict the values in inference; 3) + For utterance-level and phoneme-level prosody, we use two reference encoders to extract the values in training, and use two separate predictors to predict the values in inference. + Additionally, we introduce an improved Conformer block to better model the local and global dependency in acoustic model. 
For task SH1, DelightfulTTS achieves 4.17 mean score in MOS test + and 4.35 in SMOS test, which indicates the effectiveness of our proposed system + + + Model training:: + text --> ForwardTTS() --> spec_hat --> rand_seg_select()--> GANVocoder() --> waveform_seg + spec --------^ + + Examples: + >>> from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e, ForwardTTSE2eConfig + >>> config = ForwardTTSE2eConfig() + >>> model = ForwardTTSE2e(config) + """ + + # pylint: disable=dangerous-default-value + def __init__( + self, + config: Coqpit, + ap, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config=config, ap=ap, tokenizer=tokenizer, speaker_manager=speaker_manager) + self.ap = ap + + self._set_model_args(config) + self.init_multispeaker(config) + self.binary_loss_weight = None + + self.args.out_channels = self.config.audio.num_mels + self.args.num_mels = self.config.audio.num_mels + self.acoustic_model = AcousticModel(args=self.args, tokenizer=tokenizer, speaker_manager=speaker_manager) + + self.waveform_decoder = HifiganGenerator( + self.config.audio.num_mels, + 1, + self.config.vocoder.resblock_type_decoder, + self.config.vocoder.resblock_dilation_sizes_decoder, + self.config.vocoder.resblock_kernel_sizes_decoder, + self.config.vocoder.upsample_kernel_sizes_decoder, + self.config.vocoder.upsample_initial_channel_decoder, + self.config.vocoder.upsample_rates_decoder, + inference_padding=0, + # cond_channels=self.embedded_speaker_dim, + conv_pre_weight_norm=False, + conv_post_weight_norm=False, + conv_post_bias=False, + ) + + if self.config.init_discriminator: + self.disc = VitsDiscriminator( + use_spectral_norm=self.config.vocoder.use_spectral_norm_discriminator, + periods=self.config.vocoder.periods_discriminator, + ) + + @property + def device(self): + return next(self.parameters()).device + + @property + def energy_scaler(self): + return self.acoustic_model.energy_scaler + + @property + def length_scale(self): + return self.acoustic_model.length_scale + + @length_scale.setter + def length_scale(self, value): + self.acoustic_model.length_scale = value + + @property + def pitch_mean(self): + return self.acoustic_model.pitch_mean + + @pitch_mean.setter + def pitch_mean(self, value): + self.acoustic_model.pitch_mean = value + + @property + def pitch_std(self): + return self.acoustic_model.pitch_std + + @pitch_std.setter + def pitch_std(self, value): + self.acoustic_model.pitch_std = value + + @property + def mel_basis(self): + return build_mel_basis( + sample_rate=self.ap.sample_rate, + fft_size=self.ap.fft_size, + num_mels=self.ap.num_mels, + mel_fmax=self.ap.mel_fmax, + mel_fmin=self.ap.mel_fmin, + ) # pylint: disable=function-redefined + + def init_for_training(self) -> None: + self.train_disc = ( # pylint: disable=attribute-defined-outside-init + self.config.steps_to_start_discriminator <= 0 + ) # pylint: disable=attribute-defined-outside-init + self.update_energy_scaler = True # pylint: disable=attribute-defined-outside-init + + def init_multispeaker(self, config: Coqpit): + """Init for multi-speaker training. + + Args: + config (Coqpit): Model configuration. 
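+
+            Sets ``self.num_speakers`` and ``self.args.num_speakers`` from the ``SpeakerManager``
+            when one is available and, depending on ``use_speaker_embedding`` /
+            ``use_d_vector_file``, writes the expected embedding channel size to
+            ``self.args.embedded_speaker_dim``.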
+ """ + self.embedded_speaker_dim = 0 + self.num_speakers = self.args.num_speakers + self.audio_transform = None + + if self.speaker_manager: + self.num_speakers = self.speaker_manager.num_speakers + self.args.num_speakers = self.speaker_manager.num_speakers + + if self.args.use_speaker_embedding: + self._init_speaker_embedding() + + if self.args.use_d_vector_file: + self._init_d_vector() + + def _init_speaker_embedding(self): + # pylint: disable=attribute-defined-outside-init + if self.num_speakers > 0: + print(" > initialization of speaker-embedding layers.") + self.embedded_speaker_dim = self.args.speaker_embedding_channels + self.args.embedded_speaker_dim = self.args.speaker_embedding_channels + + def _init_d_vector(self): + # pylint: disable=attribute-defined-outside-init + if hasattr(self, "emb_g"): + raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.") + self.embedded_speaker_dim = self.args.d_vector_dim + self.args.embedded_speaker_dim = self.args.d_vector_dim + + def _freeze_layers(self): + if self.args.freeze_vocoder: + for param in self.vocoder.paramseters(): + param.requires_grad = False + + if self.args.freeze_text_encoder: + for param in self.text_encoder.parameters(): + param.requires_grad = False + + if self.args.freeze_duration_predictor: + for param in self.durarion_predictor.parameters(): + param.requires_grad = False + + if self.args.freeze_pitch_predictor: + for param in self.pitch_predictor.parameters(): + param.requires_grad = False + + if self.args.freeze_energy_predictor: + for param in self.energy_predictor.parameters(): + param.requires_grad = False + + if self.args.freeze_decoder: + for param in self.decoder.parameters(): + param.requires_grad = False + + def forward( + self, + x: torch.LongTensor, + x_lengths: torch.LongTensor, + spec_lengths: torch.LongTensor, + spec: torch.FloatTensor, + waveform: torch.FloatTensor, + pitch: torch.FloatTensor = None, + energy: torch.FloatTensor = None, + attn_priors: torch.FloatTensor = None, + d_vectors: torch.FloatTensor = None, + speaker_idx: torch.LongTensor = None, + ) -> Dict: + """Model's forward pass. + + Args: + x (torch.LongTensor): Input character sequences. + x_lengths (torch.LongTensor): Input sequence lengths. + spec_lengths (torch.LongTensor): Spectrogram sequnce lengths. Defaults to None. + spec (torch.FloatTensor): Spectrogram frames. Only used when the alignment network is on. Defaults to None. + waveform (torch.FloatTensor): Waveform. Defaults to None. + pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Only used when the pitch predictor is on. Defaults to None. + energy (torch.FloatTensor): Spectral energy values for each spectrogram frame. Only used when the energy predictor is on. Defaults to None. + attn_priors (torch.FloatTentrasor): Attention priors for the aligner network. Defaults to None. + aux_input (Dict): Auxiliary model inputs for multi-speaker training. Defaults to `{"d_vectors": 0, "speaker_ids": None}`. 
+ + Shapes: + - x: :math:`[B, T_max]` + - x_lengths: :math:`[B]` + - spec_lengths: :math:`[B]` + - spec: :math:`[B, T_max2, C_spec]` + - waveform: :math:`[B, 1, T_max2 * hop_length]` + - g: :math:`[B, C]` + - pitch: :math:`[B, 1, T_max2]` + - energy: :math:`[B, 1, T_max2]` + """ + encoder_outputs = self.acoustic_model( + tokens=x, + src_lens=x_lengths, + mel_lens=spec_lengths, + mels=spec, + pitches=pitch, + energies=energy, + attn_priors=attn_priors, + d_vectors=d_vectors, + speaker_idx=speaker_idx, + ) + + # use mel-spec from the decoder + vocoder_input = encoder_outputs["model_outputs"] # [B, T_max2, C_mel] + + vocoder_input_slices, slice_ids = rand_segments( + x=vocoder_input.transpose(1, 2), + x_lengths=spec_lengths, + segment_size=self.args.spec_segment_size, + let_short_samples=True, + pad_short=True, + ) + if encoder_outputs["spk_emb"] is not None: + g = encoder_outputs["spk_emb"].unsqueeze(-1) + else: + g = None + + vocoder_output = self.waveform_decoder(x=vocoder_input_slices.detach(), g=g) + wav_seg = segment( + waveform, + slice_ids * self.ap.hop_length, + self.args.spec_segment_size * self.ap.hop_length, + pad_short=True, + ) + model_outputs = {**encoder_outputs} + model_outputs["acoustic_model_outputs"] = encoder_outputs["model_outputs"] + model_outputs["model_outputs"] = vocoder_output + model_outputs["waveform_seg"] = wav_seg + model_outputs["slice_ids"] = slice_ids + return model_outputs + + @torch.no_grad() + def inference( + self, x, aux_input={"d_vectors": None, "speaker_ids": None}, pitch_transform=None, energy_transform=None + ): + encoder_outputs = self.acoustic_model.inference( + tokens=x, + d_vectors=aux_input["d_vectors"], + speaker_idx=aux_input["speaker_ids"], + pitch_transform=pitch_transform, + energy_transform=energy_transform, + p_control=None, + d_control=None, + ) + vocoder_input = encoder_outputs["model_outputs"].transpose(1, 2) # [B, T_max2, C_mel] -> [B, C_mel, T_max2] + if encoder_outputs["spk_emb"] is not None: + g = encoder_outputs["spk_emb"].unsqueeze(-1) + else: + g = None + + vocoder_output = self.waveform_decoder(x=vocoder_input, g=g) + model_outputs = {**encoder_outputs} + model_outputs["model_outputs"] = vocoder_output + return model_outputs + + @torch.no_grad() + def inference_spec_decoder(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): + encoder_outputs = self.acoustic_model.inference( + tokens=x, + d_vectors=aux_input["d_vectors"], + speaker_idx=aux_input["speaker_ids"], + ) + model_outputs = {**encoder_outputs} + return model_outputs + + def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): + if optimizer_idx == 0: + tokens = batch["text_input"] + token_lenghts = batch["text_lengths"] + mel = batch["mel_input"] + mel_lens = batch["mel_lengths"] + waveform = batch["waveform"] # [B, T, C] -> [B, C, T] + pitch = batch["pitch"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + attn_priors = batch["attn_priors"] + energy = batch["energy"] + + # generator pass + outputs = self.forward( + x=tokens, + x_lengths=token_lenghts, + spec_lengths=mel_lens, + spec=mel, + waveform=waveform, + pitch=pitch, + energy=energy, + attn_priors=attn_priors, + d_vectors=d_vectors, + speaker_idx=speaker_ids, + ) + + # cache tensors for the generator pass + self.model_outputs_cache = outputs # pylint: disable=attribute-defined-outside-init + + if self.train_disc: + # compute scores and features + scores_d_fake, _, scores_d_real, _ = self.disc( + outputs["model_outputs"].detach(), outputs["waveform_seg"] + ) + 
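+                # NOTE: optimizer_idx == 0 is the discriminator pass -- the generator output is
+                # detached above so only the discriminator receives gradients here, while
+                # optimizer_idx == 1 below reuses the cached generator outputs for the
+                # acoustic and vocoder losses.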
+ # compute loss + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion[optimizer_idx]( + scores_disc_fake=scores_d_fake, + scores_disc_real=scores_d_real, + ) + return outputs, loss_dict + return None, None + + if optimizer_idx == 1: + mel = batch["mel_input"] + # compute melspec segment + with autocast(enabled=False): + mel_slice = segment( + mel.float(), self.model_outputs_cache["slice_ids"], self.args.spec_segment_size, pad_short=True + ) + + mel_slice_hat = wav_to_mel( + y=self.model_outputs_cache["model_outputs"].float(), + n_fft=self.ap.fft_size, + sample_rate=self.ap.sample_rate, + num_mels=self.ap.num_mels, + hop_length=self.ap.hop_length, + win_length=self.ap.win_length, + fmin=self.ap.mel_fmin, + fmax=self.ap.mel_fmax, + center=False, + ) + + scores_d_fake = None + feats_d_fake = None + feats_d_real = None + + if self.train_disc: + # compute discriminator scores and features + scores_d_fake, feats_d_fake, _, feats_d_real = self.disc( + self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"] + ) + + # compute losses + with autocast(enabled=True): # use float32 for the criterion + loss_dict = criterion[optimizer_idx]( + mel_output=self.model_outputs_cache["acoustic_model_outputs"].transpose(1, 2), + mel_target=batch["mel_input"], + mel_lens=batch["mel_lengths"], + dur_output=self.model_outputs_cache["dr_log_pred"], + dur_target=self.model_outputs_cache["dr_log_target"].detach(), + pitch_output=self.model_outputs_cache["pitch_pred"], + pitch_target=self.model_outputs_cache["pitch_target"], + energy_output=self.model_outputs_cache["energy_pred"], + energy_target=self.model_outputs_cache["energy_target"], + src_lens=batch["text_lengths"], + waveform=self.model_outputs_cache["waveform_seg"], + waveform_hat=self.model_outputs_cache["model_outputs"], + p_prosody_ref=self.model_outputs_cache["p_prosody_ref"], + p_prosody_pred=self.model_outputs_cache["p_prosody_pred"], + u_prosody_ref=self.model_outputs_cache["u_prosody_ref"], + u_prosody_pred=self.model_outputs_cache["u_prosody_pred"], + aligner_logprob=self.model_outputs_cache["aligner_logprob"], + aligner_hard=self.model_outputs_cache["aligner_mas"], + aligner_soft=self.model_outputs_cache["aligner_soft"], + binary_loss_weight=self.binary_loss_weight, + feats_fake=feats_d_fake, + feats_real=feats_d_real, + scores_fake=scores_d_fake, + spec_slice=mel_slice, + spec_slice_hat=mel_slice_hat, + skip_disc=not self.train_disc, + ) + + loss_dict["avg_text_length"] = batch["text_lengths"].float().mean() + loss_dict["avg_mel_length"] = batch["mel_lengths"].float().mean() + loss_dict["avg_text_batch_occupancy"] = ( + batch["text_lengths"].float() / batch["text_lengths"].float().max() + ).mean() + loss_dict["avg_mel_batch_occupancy"] = ( + batch["mel_lengths"].float() / batch["mel_lengths"].float().max() + ).mean() + + return self.model_outputs_cache, loss_dict + raise ValueError(" [!] 
Unexpected `optimizer_idx`.") + + def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): + return self.train_step(batch, criterion, optimizer_idx) + + def _log(self, batch, outputs, name_prefix="train"): + figures, audios = {}, {} + + # encoder outputs + model_outputs = outputs[1]["acoustic_model_outputs"] + alignments = outputs[1]["alignments"] + mel_input = batch["mel_input"] + + pred_spec = model_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, None, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec.T, None, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + # plot pitch figures + pitch_avg = abs(outputs[1]["pitch_target"][0, 0].data.cpu().numpy()) + pitch_avg_hat = abs(outputs[1]["pitch_pred"][0, 0].data.cpu().numpy()) + chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy()) + pitch_figures = { + "pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False), + "pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False), + } + figures.update(pitch_figures) + + # plot energy figures + energy_avg = abs(outputs[1]["energy_target"][0, 0].data.cpu().numpy()) + energy_avg_hat = abs(outputs[1]["energy_pred"][0, 0].data.cpu().numpy()) + chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy()) + energy_figures = { + "energy_ground_truth": plot_avg_pitch(energy_avg, chars, output_fig=False), + "energy_avg_predicted": plot_avg_pitch(energy_avg_hat, chars, output_fig=False), + } + figures.update(energy_figures) + + # plot the attention mask computed from the predicted durations + alignments_hat = outputs[1]["alignments_dp"][0].data.cpu().numpy() + figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False) + + # Sample audio + encoder_audio = mel_to_wav_numpy( + mel=db_to_amp_numpy(x=pred_spec.T, gain=1, base=None), mel_basis=self.mel_basis, **self.config.audio + ) + audios[f"{name_prefix}/encoder_audio"] = encoder_audio + + # vocoder outputs + y_hat = outputs[1]["model_outputs"] + y = outputs[1]["waveform_seg"] + + vocoder_figures = plot_results(y_hat=y_hat, y=y, ap=self.ap, name_prefix=name_prefix) + figures.update(vocoder_figures) + + sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() + audios[f"{name_prefix}/vocoder_audio"] = sample_voice + return figures, audios + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ): # pylint: disable=no-self-use, unused-argument + """Create visualizations and waveform examples. + + For example, here you can plot spectrograms and generate sample sample waveforms from these spectrograms to + be projected onto Tensorboard. + + Args: + batch (Dict): Model inputs used at the previous training step. + outputs (Dict): Model outputs generated at the previous training step. + + Returns: + Tuple[Dict, np.ndarray]: training plots and output waveform. 
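+
+        Note:
+            For this model ``outputs`` is the per-optimizer list collected by the trainer;
+            :meth:`_log` reads index ``1`` (the generator pass) for plotting.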
+ """ + figures, audios = self._log(batch=batch, outputs=outputs, name_prefix="vocoder/") + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._log(batch=batch, outputs=outputs, name_prefix="vocoder/") + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def get_aux_input_from_test_sentences(self, sentence_info): + if hasattr(self.config, "model_args"): + config = self.config.model_args + else: + config = self.config + + # extract speaker and language info + text, speaker_name, style_wav = None, None, None + + if isinstance(sentence_info, list): + if len(sentence_info) == 1: + text = sentence_info[0] + elif len(sentence_info) == 2: + text, speaker_name = sentence_info + elif len(sentence_info) == 3: + text, speaker_name, style_wav = sentence_info + else: + text = sentence_info + + # get speaker id/d_vector + speaker_id, d_vector = None, None + if hasattr(self, "speaker_manager"): + if config.use_d_vector_file: + if speaker_name is None: + d_vector = self.speaker_manager.get_random_embedding() + else: + d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False) + elif config.use_speaker_embedding: + if speaker_name is None: + speaker_id = self.speaker_manager.get_random_id() + else: + speaker_id = self.speaker_manager.name_to_id[speaker_name] + + return {"text": text, "speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector} + + def plot_outputs(self, text, wav, alignment, outputs): + figures = {} + pitch_avg_pred = outputs["pitch"].cpu() + energy_avg_pred = outputs["energy"].cpu() + spec = wav_to_mel( + y=torch.from_numpy(wav[None, :]), + n_fft=self.ap.fft_size, + sample_rate=self.ap.sample_rate, + num_mels=self.ap.num_mels, + hop_length=self.ap.hop_length, + win_length=self.ap.win_length, + fmin=self.ap.mel_fmin, + fmax=self.ap.mel_fmax, + center=False, + )[0].transpose(0, 1) + pitch = compute_f0( + x=wav[0], + sample_rate=self.ap.sample_rate, + hop_length=self.ap.hop_length, + pitch_fmax=self.ap.pitch_fmax, + ) + input_text = self.tokenizer.ids_to_text(self.tokenizer.text_to_ids(text, language="en")) + input_text = input_text.replace("", "_") + durations = outputs["durations"] + pitch_avg = average_over_durations(torch.from_numpy(pitch)[None, None, :], durations.cpu()) # [1, 1, n_frames] + pitch_avg_pred_denorm = (pitch_avg_pred * self.pitch_std) + self.pitch_mean + figures["alignment"] = plot_alignment(alignment.transpose(1, 2), output_fig=False) + figures["spectrogram"] = plot_spectrogram(spec) + figures["pitch_from_wav"] = plot_pitch(pitch, spec) + figures["pitch_avg_from_wav"] = plot_avg_pitch(pitch_avg.squeeze(), input_text) + figures["pitch_avg_pred"] = plot_avg_pitch(pitch_avg_pred_denorm.squeeze(), input_text) + figures["energy_avg_pred"] = plot_avg_pitch(energy_avg_pred.squeeze(), input_text) + return figures + + def synthesize( + self, + text: str, + speaker_id: str = None, + d_vector: torch.tensor = None, + pitch_transform=None, + **kwargs, + ): # pylint: disable=unused-argument + # TODO: add cloning support with ref_waveform + is_cuda = next(self.parameters()).is_cuda + + # convert text to sequence of token IDs + text_inputs = np.asarray( + self.tokenizer.text_to_ids(text, language=None), + dtype=np.int32, + ) + + # set speaker inputs + _speaker_id = None + if speaker_id is not None and 
self.args.use_speaker_embedding: + if isinstance(speaker_id, str) and self.args.use_speaker_embedding: + # get the speaker id for the speaker embedding layer + _speaker_id = self.speaker_manager.name_to_id[speaker_id] + _speaker_id = id_to_torch(_speaker_id, cuda=is_cuda) + + if speaker_id is not None and self.args.use_d_vector_file: + # get the average d_vector for the speaker + d_vector = self.speaker_manager.get_mean_embedding(speaker_id, num_samples=None, randomize=False) + d_vector = embedding_to_torch(d_vector, cuda=is_cuda) + + text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda) + text_inputs = text_inputs.unsqueeze(0) + + # synthesize voice + outputs = self.inference( + text_inputs, + aux_input={"d_vectors": d_vector, "speaker_ids": _speaker_id}, + pitch_transform=pitch_transform, + # energy_transform=energy_transform + ) + + # collect outputs + wav = outputs["model_outputs"][0].data.cpu().numpy() + alignments = outputs["alignments"] + return_dict = { + "wav": wav, + "alignments": alignments, + "text_inputs": text_inputs, + "outputs": outputs, + } + return return_dict + + def synthesize_with_gl(self, text: str, speaker_id, d_vector): + is_cuda = next(self.parameters()).is_cuda + + # convert text to sequence of token IDs + text_inputs = np.asarray( + self.tokenizer.text_to_ids(text, language=None), + dtype=np.int32, + ) + # pass tensors to backend + if speaker_id is not None: + speaker_id = id_to_torch(speaker_id, cuda=is_cuda) + + if d_vector is not None: + d_vector = embedding_to_torch(d_vector, cuda=is_cuda) + + text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda) + text_inputs = text_inputs.unsqueeze(0) + + # synthesize voice + outputs = self.inference_spec_decoder( + x=text_inputs, + aux_input={"d_vectors": d_vector, "speaker_ids": speaker_id}, + ) + + # collect outputs + S = outputs["model_outputs"].cpu().numpy()[0].T + S = db_to_amp_numpy(x=S, gain=1, base=None) + wav = mel_to_wav_numpy(mel=S, mel_basis=self.mel_basis, **self.config.audio) + alignments = outputs["alignments"] + return_dict = { + "wav": wav[None, :], + "alignments": alignments, + "text_inputs": text_inputs, + "outputs": outputs, + } + return return_dict + + @torch.no_grad() + def test_run(self, assets) -> Tuple[Dict, Dict]: + """Generic test run for `tts` models used by `Trainer`. + + You can override this for a different behaviour. + + Returns: + Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. 
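+
+        Note:
+            Each test sentence is rendered twice: end-to-end through the waveform decoder
+            (:meth:`synthesize`) and from the spectrogram decoder via Griffin-Lim
+            (:meth:`synthesize_with_gl`), so both audio paths can be compared side by side.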
+ """ + print(" | > Synthesizing test sentences.") + test_audios = {} + test_figures = {} + test_sentences = self.config.test_sentences + for idx, s_info in enumerate(test_sentences): + aux_inputs = self.get_aux_input_from_test_sentences(s_info) + outputs = self.synthesize( + aux_inputs["text"], + config=self.config, + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + ) + outputs_gl = self.synthesize_with_gl( + aux_inputs["text"], + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + ) + # speaker_name = self.speaker_manager.speaker_names[aux_inputs["speaker_id"]] + test_audios["{}-audio".format(idx)] = outputs["wav"].T + test_audios["{}-audio_encoder".format(idx)] = outputs_gl["wav"].T + test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False) + return {"figures": test_figures, "audios": test_audios} + + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: + logger.test_audios(steps, outputs["audios"], self.config.audio.sample_rate) + logger.test_figures(steps, outputs["figures"]) + + def format_batch(self, batch: Dict) -> Dict: + """Compute speaker, langugage IDs and d_vector for the batch if necessary.""" + speaker_ids = None + d_vectors = None + + # get numerical speaker ids from speaker names + if self.speaker_manager is not None and self.speaker_manager.speaker_names and self.args.use_speaker_embedding: + speaker_ids = [self.speaker_manager.name_to_id[sn] for sn in batch["speaker_names"]] + + if speaker_ids is not None: + speaker_ids = torch.LongTensor(speaker_ids) + batch["speaker_ids"] = speaker_ids + + # get d_vectors from audio file names + if self.speaker_manager is not None and self.speaker_manager.embeddings and self.args.use_d_vector_file: + d_vector_mapping = self.speaker_manager.embeddings + d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_unique_names"]] + d_vectors = torch.FloatTensor(d_vectors) + + batch["d_vectors"] = d_vectors + batch["speaker_ids"] = speaker_ids + return batch + + def format_batch_on_device(self, batch): + """Compute spectrograms on the device.""" + + ac = self.ap + + # compute spectrograms + batch["mel_input"] = wav_to_mel( + batch["waveform"], + hop_length=ac.hop_length, + win_length=ac.win_length, + n_fft=ac.fft_size, + num_mels=ac.num_mels, + sample_rate=ac.sample_rate, + fmin=ac.mel_fmin, + fmax=ac.mel_fmax, + center=False, + ) + + # TODO: Align pitch properly + # assert ( + # batch["pitch"].shape[2] == batch["mel_input"].shape[2] + # ), f"{batch['pitch'].shape[2]}, {batch['mel_input'].shape[2]}" + batch["pitch"] = batch["pitch"][:, :, : batch["mel_input"].shape[2]] if batch["pitch"] is not None else None + batch["mel_lengths"] = (batch["mel_input"].shape[2] * batch["waveform_rel_lens"]).int() + + # zero the padding frames + batch["mel_input"] = batch["mel_input"] * sequence_mask(batch["mel_lengths"]).unsqueeze(1) + + # format attn priors as we now the max mel length + # TODO: fix 1 diff b/w mel_lengths and attn_priors + + if self.config.use_attn_priors: + attn_priors_np = batch["attn_priors"] + + batch["attn_priors"] = torch.zeros( + batch["mel_input"].shape[0], + batch["mel_lengths"].max(), + batch["text_lengths"].max(), + device=batch["mel_input"].device, + ) + + for i in range(batch["mel_input"].shape[0]): + batch["attn_priors"][i, : attn_priors_np[i].shape[0], : attn_priors_np[i].shape[1]] = torch.from_numpy( + attn_priors_np[i] + ) + + batch["energy"] = None + 
batch["energy"] = wav_to_energy( # [B, 1, T_max2] + batch["waveform"], + hop_length=ac.hop_length, + win_length=ac.win_length, + n_fft=ac.fft_size, + center=False, + ) + batch["energy"] = self.energy_scaler(batch["energy"]) + return batch + + def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1): + weights = None + data_items = dataset.samples + if getattr(config, "use_weighted_sampler", False): + for attr_name, alpha in config.weighted_sampler_attrs.items(): + print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'") + multi_dict = config.weighted_sampler_multipliers.get(attr_name, None) + print(multi_dict) + weights, attr_names, attr_weights = get_attribute_balancer_weights( + attr_name=attr_name, items=data_items, multi_dict=multi_dict + ) + weights = weights * alpha + print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}") + + if weights is not None: + sampler = WeightedRandomSampler(weights, len(weights)) + else: + sampler = None + # sampler for DDP + if sampler is None: + sampler = DistributedSampler(dataset) if num_gpus > 1 else None + else: # If a sampler is already defined use this sampler and DDP sampler together + sampler = DistributedSamplerWrapper(sampler) if num_gpus > 1 else sampler + return sampler + + def get_data_loader( + self, + config: Coqpit, + assets: Dict, + is_eval: bool, + samples: Union[List[Dict], List[List]], + verbose: bool, + num_gpus: int, + rank: int = None, + ) -> "DataLoader": + if is_eval and not config.run_eval: + loader = None + else: + # init dataloader + dataset = ForwardTTSE2eDataset( + samples=samples, + ap=self.ap, + batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, + min_text_len=config.min_text_len, + max_text_len=config.max_text_len, + min_audio_len=config.min_audio_len, + max_audio_len=config.max_audio_len, + phoneme_cache_path=config.phoneme_cache_path, + precompute_num_workers=config.precompute_num_workers, + compute_f0=config.compute_f0, + f0_cache_path=config.f0_cache_path, + attn_prior_cache_path=config.attn_prior_cache_path if config.use_attn_priors else None, + verbose=verbose, + tokenizer=self.tokenizer, + start_by_longest=config.start_by_longest, + ) + + # wait all the DDP process to be ready + if num_gpus > 1: + dist.barrier() + + # sort input sequences ascendingly by length + dataset.preprocess_samples() + + # get samplers + sampler = self.get_sampler(config, dataset, num_gpus) + + loader = DataLoader( + dataset, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + shuffle=False, # shuffle is done in the dataset. + drop_last=False, # setting this False might cause issues in AMP training. + sampler=sampler, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=True, + ) + + # get pitch mean and std + self.pitch_mean = dataset.f0_dataset.mean + self.pitch_std = dataset.f0_dataset.std + return loader + + def get_criterion(self): + return [VitsDiscriminatorLoss(self.config), DelightfulTTSLoss(self.config)] + + def get_optimizer(self) -> List: + """Initiate and return the GAN optimizers based on the config parameters. + It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator. + Returns: + List: optimizers. 
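+
+        Note:
+            The returned order must stay in sync with :meth:`get_criterion` and the
+            ``optimizer_idx`` branches in :meth:`train_step`: index 0 drives the
+            discriminator update, index 1 the generator/acoustic update.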
+ """ + optimizer_disc = get_optimizer( + self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc + ) + gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc.")) + optimizer_gen = get_optimizer( + self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters + ) + return [optimizer_disc, optimizer_gen] + + def get_lr(self) -> List: + """Set the initial learning rates for each optimizer. + + Returns: + List: learning rates for each optimizer. + """ + return [self.config.lr_disc, self.config.lr_gen] + + def get_scheduler(self, optimizer) -> List: + """Set the schedulers for each optimizer. + + Args: + optimizer (List[`torch.optim.Optimizer`]): List of optimizers. + + Returns: + List: Schedulers, one for each optimizer. + """ + scheduler_D = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0]) + scheduler_G = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1]) + return [scheduler_D, scheduler_G] + + def on_epoch_end(self, trainer): # pylint: disable=unused-argument + # stop updating mean and var + # TODO: do the same for F0 + self.energy_scaler.eval() + + @staticmethod + def init_from_config( + config: "DelightfulTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=False + ): # pylint: disable=unused-argument + """Initiate model from config + + Args: + config (ForwardTTSE2eConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config.model_args, samples) + ap = AudioProcessor.init_from_config(config=config) + return DelightfulTTS(config=new_config, tokenizer=tokenizer, speaker_manager=speaker_manager, ap=ap) + + def load_checkpoint(self, config, checkpoint_path, eval=False): + """Load model from a checkpoint created by the 👟""" + # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def get_state_dict(self): + """Custom state dict of the model with all the necessary components for inference.""" + save_state = {"config": self.config.to_dict(), "args": self.args.to_dict(), "model": self.state_dict} + + if hasattr(self, "emb_g"): + save_state["speaker_ids"] = self.speaker_manager.speaker_names + + if self.args.use_d_vector_file: + # TODO: implement saving of d_vectors + ... + return save_state + + def save(self, config, checkpoint_path): + """Save model to a file.""" + save_state = self.get_state_dict(config, checkpoint_path) # pylint: disable=too-many-function-args + save_state["pitch_mean"] = self.pitch_mean + save_state["pitch_std"] = self.pitch_std + torch.save(save_state, checkpoint_path) + + def on_train_step_start(self, trainer) -> None: + """Enable the discriminator training based on `steps_to_start_discriminator` + + Args: + trainer (Trainer): Trainer object. 
+ """ + self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0 + self.train_disc = ( # pylint: disable=attribute-defined-outside-init + trainer.total_steps_done >= self.config.steps_to_start_discriminator + ) + + +class DelightfulTTSLoss(nn.Module): + def __init__(self, config): + super().__init__() + + self.mse_loss = nn.MSELoss() + self.mae_loss = nn.L1Loss() + self.forward_sum_loss = ForwardSumLoss() + self.multi_scale_stft_loss = MultiScaleSTFTLoss(**config.multi_scale_stft_loss_params) + + self.mel_loss_alpha = config.mel_loss_alpha + self.aligner_loss_alpha = config.aligner_loss_alpha + self.pitch_loss_alpha = config.pitch_loss_alpha + self.energy_loss_alpha = config.energy_loss_alpha + self.u_prosody_loss_alpha = config.u_prosody_loss_alpha + self.p_prosody_loss_alpha = config.p_prosody_loss_alpha + self.dur_loss_alpha = config.dur_loss_alpha + self.char_dur_loss_alpha = config.char_dur_loss_alpha + self.binary_alignment_loss_alpha = config.binary_align_loss_alpha + + self.vocoder_mel_loss_alpha = config.vocoder_mel_loss_alpha + self.feat_loss_alpha = config.feat_loss_alpha + self.gen_loss_alpha = config.gen_loss_alpha + self.multi_scale_stft_loss_alpha = config.multi_scale_stft_loss_alpha + + @staticmethod + def _binary_alignment_loss(alignment_hard, alignment_soft): + """Binary loss that forces soft alignments to match the hard alignments as + explained in `https://arxiv.org/pdf/2108.10447.pdf`. + """ + log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum() + return -log_sum / alignment_hard.sum() + + @staticmethod + def feature_loss(feats_real, feats_generated): + loss = 0 + for dr, dg in zip(feats_real, feats_generated): + for rl, gl in zip(dr, dg): + rl = rl.float().detach() + gl = gl.float() + loss += torch.mean(torch.abs(rl - gl)) + return loss * 2 + + @staticmethod + def generator_loss(scores_fake): + loss = 0 + gen_losses = [] + for dg in scores_fake: + dg = dg.float() + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + def forward( + self, + mel_output, + mel_target, + mel_lens, + dur_output, + dur_target, + pitch_output, + pitch_target, + energy_output, + energy_target, + src_lens, + waveform, + waveform_hat, + p_prosody_ref, + p_prosody_pred, + u_prosody_ref, + u_prosody_pred, + aligner_logprob, + aligner_hard, + aligner_soft, + binary_loss_weight=None, + feats_fake=None, + feats_real=None, + scores_fake=None, + spec_slice=None, + spec_slice_hat=None, + skip_disc=False, + ): + """ + Shapes: + - mel_output: :math:`(B, C_mel, T_mel)` + - mel_target: :math:`(B, C_mel, T_mel)` + - mel_lens: :math:`(B)` + - dur_output: :math:`(B, T_src)` + - dur_target: :math:`(B, T_src)` + - pitch_output: :math:`(B, 1, T_src)` + - pitch_target: :math:`(B, 1, T_src)` + - energy_output: :math:`(B, 1, T_src)` + - energy_target: :math:`(B, 1, T_src)` + - src_lens: :math:`(B)` + - waveform: :math:`(B, 1, T_wav)` + - waveform_hat: :math:`(B, 1, T_wav)` + - p_prosody_ref: :math:`(B, T_src, 4)` + - p_prosody_pred: :math:`(B, T_src, 4)` + - u_prosody_ref: :math:`(B, 1, 256) + - u_prosody_pred: :math:`(B, 1, 256) + - aligner_logprob: :math:`(B, 1, T_mel, T_src)` + - aligner_hard: :math:`(B, T_mel, T_src)` + - aligner_soft: :math:`(B, T_mel, T_src)` + - spec_slice: :math:`(B, C_mel, T_mel)` + - spec_slice_hat: :math:`(B, C_mel, T_mel)` + """ + loss_dict = {} + src_mask = sequence_mask(src_lens).to(mel_output.device) # (B, T_src) + mel_mask = 
sequence_mask(mel_lens).to(mel_output.device) # (B, T_mel) + + dur_target.requires_grad = False + mel_target.requires_grad = False + pitch_target.requires_grad = False + + masked_mel_predictions = mel_output.masked_select(mel_mask[:, None]) + mel_targets = mel_target.masked_select(mel_mask[:, None]) + mel_loss = self.mae_loss(masked_mel_predictions, mel_targets) + + p_prosody_ref = p_prosody_ref.detach() + p_prosody_loss = 0.5 * self.mae_loss( + p_prosody_ref.masked_select(src_mask.unsqueeze(-1)), + p_prosody_pred.masked_select(src_mask.unsqueeze(-1)), + ) + + u_prosody_ref = u_prosody_ref.detach() + u_prosody_loss = 0.5 * self.mae_loss(u_prosody_ref, u_prosody_pred) + + duration_loss = self.mse_loss(dur_output, dur_target) + + pitch_output = pitch_output.masked_select(src_mask[:, None]) + pitch_target = pitch_target.masked_select(src_mask[:, None]) + pitch_loss = self.mse_loss(pitch_output, pitch_target) + + energy_output = energy_output.masked_select(src_mask[:, None]) + energy_target = energy_target.masked_select(src_mask[:, None]) + energy_loss = self.mse_loss(energy_output, energy_target) + + forward_sum_loss = self.forward_sum_loss(aligner_logprob, src_lens, mel_lens) + + total_loss = ( + (mel_loss * self.mel_loss_alpha) + + (duration_loss * self.dur_loss_alpha) + + (u_prosody_loss * self.u_prosody_loss_alpha) + + (p_prosody_loss * self.p_prosody_loss_alpha) + + (pitch_loss * self.pitch_loss_alpha) + + (energy_loss * self.energy_loss_alpha) + + (forward_sum_loss * self.aligner_loss_alpha) + ) + + if self.binary_alignment_loss_alpha > 0 and aligner_hard is not None: + binary_alignment_loss = self._binary_alignment_loss(aligner_hard, aligner_soft) + total_loss = total_loss + self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight + if binary_loss_weight: + loss_dict["loss_binary_alignment"] = ( + self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight + ) + else: + loss_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss + + loss_dict["loss_aligner"] = self.aligner_loss_alpha * forward_sum_loss + loss_dict["loss_mel"] = self.mel_loss_alpha * mel_loss + loss_dict["loss_duration"] = self.dur_loss_alpha * duration_loss + loss_dict["loss_u_prosody"] = self.u_prosody_loss_alpha * u_prosody_loss + loss_dict["loss_p_prosody"] = self.p_prosody_loss_alpha * p_prosody_loss + loss_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss + loss_dict["loss_energy"] = self.energy_loss_alpha * energy_loss + loss_dict["loss"] = total_loss + + # vocoder losses + if not skip_disc: + loss_feat = self.feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha + loss_gen = self.generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha + loss_dict["vocoder_loss_feat"] = loss_feat + loss_dict["vocoder_loss_gen"] = loss_gen + loss_dict["loss"] = loss_dict["loss"] + loss_feat + loss_gen + + loss_mel = torch.nn.functional.l1_loss(spec_slice, spec_slice_hat) * self.vocoder_mel_loss_alpha + loss_stft_mg, loss_stft_sc = self.multi_scale_stft_loss(y_hat=waveform_hat, y=waveform) + loss_stft_mg = loss_stft_mg * self.multi_scale_stft_loss_alpha + loss_stft_sc = loss_stft_sc * self.multi_scale_stft_loss_alpha + + loss_dict["vocoder_loss_mel"] = loss_mel + loss_dict["vocoder_loss_stft_mg"] = loss_stft_mg + loss_dict["vocoder_loss_stft_sc"] = loss_stft_sc + + loss_dict["loss"] = loss_dict["loss"] + loss_mel + loss_stft_sc + loss_stft_mg + return loss_dict diff --git a/viXTTS/TTS/tts/models/forward_tts.py 
b/viXTTS/TTS/tts/models/forward_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..b6e9ac8a14d1b92295bfae532364839f993283ee --- /dev/null +++ b/viXTTS/TTS/tts/models/forward_tts.py @@ -0,0 +1,862 @@ +from dataclasses import dataclass, field +from typing import Dict, List, Tuple, Union + +import torch +from coqpit import Coqpit +from torch import nn +from torch.cuda.amp.autocast_mode import autocast + +from TTS.tts.layers.feed_forward.decoder import Decoder +from TTS.tts.layers.feed_forward.encoder import Encoder +from TTS.tts.layers.generic.aligner import AlignmentNetwork +from TTS.tts.layers.generic.pos_encoding import PositionalEncoding +from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram +from TTS.utils.io import load_fsspec + + +@dataclass +class ForwardTTSArgs(Coqpit): + """ForwardTTS Model arguments. + + Args: + + num_chars (int): + Number of characters in the vocabulary. Defaults to 100. + + out_channels (int): + Number of output channels. Defaults to 80. + + hidden_channels (int): + Number of base hidden channels of the model. Defaults to 512. + + use_aligner (bool): + Whether to use aligner network to learn the text to speech alignment or use pre-computed durations. + If set False, durations should be computed by `TTS/bin/compute_attention_masks.py` and path to the + pre-computed durations must be provided to `config.datasets[0].meta_file_attn_mask`. Defaults to True. + + use_pitch (bool): + Use pitch predictor to learn the pitch. Defaults to True. + + use_energy (bool): + Use energy predictor to learn the energy. Defaults to True. + + duration_predictor_hidden_channels (int): + Number of hidden channels in the duration predictor. Defaults to 256. + + duration_predictor_dropout_p (float): + Dropout rate for the duration predictor. Defaults to 0.1. + + duration_predictor_kernel_size (int): + Kernel size of conv layers in the duration predictor. Defaults to 3. + + pitch_predictor_hidden_channels (int): + Number of hidden channels in the pitch predictor. Defaults to 256. + + pitch_predictor_dropout_p (float): + Dropout rate for the pitch predictor. Defaults to 0.1. + + pitch_predictor_kernel_size (int): + Kernel size of conv layers in the pitch predictor. Defaults to 3. + + pitch_embedding_kernel_size (int): + Kernel size of the projection layer in the pitch predictor. Defaults to 3. + + energy_predictor_hidden_channels (int): + Number of hidden channels in the energy predictor. Defaults to 256. + + energy_predictor_dropout_p (float): + Dropout rate for the energy predictor. Defaults to 0.1. + + energy_predictor_kernel_size (int): + Kernel size of conv layers in the energy predictor. Defaults to 3. + + energy_embedding_kernel_size (int): + Kernel size of the projection layer in the energy predictor. Defaults to 3. + + positional_encoding (bool): + Whether to use positional encoding. Defaults to True. + + positional_encoding_use_scale (bool): + Whether to use a learnable scale coeff in the positional encoding. Defaults to True. + + length_scale (int): + Length scale that multiplies the predicted durations. Larger values result slower speech. Defaults to 1.0. 
+ + encoder_type (str): + Type of the encoder module. One of the encoders available in :class:`TTS.tts.layers.feed_forward.encoder`. + Defaults to `fftransformer` as in the paper. + + encoder_params (dict): + Parameters of the encoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}``` + + decoder_type (str): + Type of the decoder module. One of the decoders available in :class:`TTS.tts.layers.feed_forward.decoder`. + Defaults to `fftransformer` as in the paper. + + decoder_params (str): + Parameters of the decoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}``` + + detach_duration_predictor (bool): + Detach the input to the duration predictor from the earlier computation graph so that the duraiton loss + does not pass to the earlier layers. Defaults to True. + + max_duration (int): + Maximum duration accepted by the model. Defaults to 75. + + num_speakers (int): + Number of speakers for the speaker embedding layer. Defaults to 0. + + speakers_file (str): + Path to the speaker mapping file for the Speaker Manager. Defaults to None. + + speaker_embedding_channels (int): + Number of speaker embedding channels. Defaults to 256. + + use_d_vector_file (bool): + Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False. + + d_vector_dim (int): + Number of d-vector channels. Defaults to 0. + + """ + + num_chars: int = None + out_channels: int = 80 + hidden_channels: int = 384 + use_aligner: bool = True + # pitch params + use_pitch: bool = True + pitch_predictor_hidden_channels: int = 256 + pitch_predictor_kernel_size: int = 3 + pitch_predictor_dropout_p: float = 0.1 + pitch_embedding_kernel_size: int = 3 + + # energy params + use_energy: bool = False + energy_predictor_hidden_channels: int = 256 + energy_predictor_kernel_size: int = 3 + energy_predictor_dropout_p: float = 0.1 + energy_embedding_kernel_size: int = 3 + + # duration params + duration_predictor_hidden_channels: int = 256 + duration_predictor_kernel_size: int = 3 + duration_predictor_dropout_p: float = 0.1 + + positional_encoding: bool = True + poisitonal_encoding_use_scale: bool = True + length_scale: int = 1 + encoder_type: str = "fftransformer" + encoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1} + ) + decoder_type: str = "fftransformer" + decoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1} + ) + detach_duration_predictor: bool = False + max_duration: int = 75 + num_speakers: int = 1 + use_speaker_embedding: bool = False + speakers_file: str = None + use_d_vector_file: bool = False + d_vector_dim: int = None + d_vector_file: str = None + + +class ForwardTTS(BaseTTS): + """General forward TTS model implementation that uses an encoder-decoder architecture with an optional alignment + network and a pitch predictor. + + If the alignment network is used, the model learns the text-to-speech alignment + from the data instead of using pre-computed durations. + + If the pitch predictor is used, the model trains a pitch predictor that predicts average pitch value for each + input character as in the FastPitch model. + + `ForwardTTS` can be configured to one of these architectures, + + - FastPitch + - SpeedySpeech + - FastSpeech + - FastSpeech2 (requires average speech energy predictor) + + Args: + config (Coqpit): Model coqpit class. 
+ speaker_manager (SpeakerManager): Speaker manager for multi-speaker training. Only used for multi-speaker models. + Defaults to None. + + Examples: + >>> from TTS.tts.models.fast_pitch import ForwardTTS, ForwardTTSArgs + >>> config = ForwardTTSArgs() + >>> model = ForwardTTS(config) + """ + + # pylint: disable=dangerous-default-value + def __init__( + self, + config: Coqpit, + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + self._set_model_args(config) + + self.init_multispeaker(config) + + self.max_duration = self.args.max_duration + self.use_aligner = self.args.use_aligner + self.use_pitch = self.args.use_pitch + self.use_energy = self.args.use_energy + self.binary_loss_weight = 0.0 + + self.length_scale = ( + float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale + ) + + self.emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels) + + self.encoder = Encoder( + self.args.hidden_channels, + self.args.hidden_channels, + self.args.encoder_type, + self.args.encoder_params, + self.embedded_speaker_dim, + ) + + if self.args.positional_encoding: + self.pos_encoder = PositionalEncoding(self.args.hidden_channels) + + self.decoder = Decoder( + self.args.out_channels, + self.args.hidden_channels, + self.args.decoder_type, + self.args.decoder_params, + ) + + self.duration_predictor = DurationPredictor( + self.args.hidden_channels, + self.args.duration_predictor_hidden_channels, + self.args.duration_predictor_kernel_size, + self.args.duration_predictor_dropout_p, + ) + + if self.args.use_pitch: + self.pitch_predictor = DurationPredictor( + self.args.hidden_channels, + self.args.pitch_predictor_hidden_channels, + self.args.pitch_predictor_kernel_size, + self.args.pitch_predictor_dropout_p, + ) + self.pitch_emb = nn.Conv1d( + 1, + self.args.hidden_channels, + kernel_size=self.args.pitch_embedding_kernel_size, + padding=int((self.args.pitch_embedding_kernel_size - 1) / 2), + ) + + if self.args.use_energy: + self.energy_predictor = DurationPredictor( + self.args.hidden_channels, + self.args.energy_predictor_hidden_channels, + self.args.energy_predictor_kernel_size, + self.args.energy_predictor_dropout_p, + ) + self.energy_emb = nn.Conv1d( + 1, + self.args.hidden_channels, + kernel_size=self.args.energy_embedding_kernel_size, + padding=int((self.args.energy_embedding_kernel_size - 1) / 2), + ) + + if self.args.use_aligner: + self.aligner = AlignmentNetwork( + in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels + ) + + def init_multispeaker(self, config: Coqpit): + """Init for multi-speaker training. + + Args: + config (Coqpit): Model configuration. + """ + self.embedded_speaker_dim = 0 + # init speaker manager + if self.speaker_manager is None and (config.use_d_vector_file or config.use_speaker_embedding): + raise ValueError( + " > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model." 
+ ) + # set number of speakers + if self.speaker_manager is not None: + self.num_speakers = self.speaker_manager.num_speakers + # init d-vector embedding + if config.use_d_vector_file: + self.embedded_speaker_dim = config.d_vector_dim + if self.args.d_vector_dim != self.args.hidden_channels: + #self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1) + self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels) + # init speaker embedding layer + if config.use_speaker_embedding and not config.use_d_vector_file: + print(" > Init speaker_embedding layer.") + self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels) + nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) + + @staticmethod + def generate_attn(dr, x_mask, y_mask=None): + """Generate an attention mask from the durations. + + Shapes + - dr: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_mask: :math:`(B, T_{de})` + """ + # compute decode mask from the durations + if y_mask is None: + y_lengths = dr.sum(1).long() + y_lengths[y_lengths < 1] = 1 + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype) + return attn + + def expand_encoder_outputs(self, en, dr, x_mask, y_mask): + """Generate attention alignment map from durations and + expand encoder outputs + + Shapes: + - en: :math:`(B, D_{en}, T_{en})` + - dr: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_mask: :math:`(B, T_{de})` + + Examples:: + + encoder output: [a,b,c,d] + durations: [1, 3, 2, 1] + + expanded: [a, b, b, b, c, c, d] + attention map: [[0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 1, 1, 0], + [0, 1, 1, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0]] + """ + attn = self.generate_attn(dr, x_mask, y_mask) + o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2).to(en.dtype), en.transpose(1, 2)).transpose(1, 2) + return o_en_ex, attn + + def format_durations(self, o_dr_log, x_mask): + """Format predicted durations. + 1. Convert to linear scale from log scale + 2. Apply the length scale for speed adjustment + 3. Apply masking. + 4. Cast 0 durations to 1. + 5. Round the duration values. + + Args: + o_dr_log: Log scale durations. + x_mask: Input text mask. + + Shapes: + - o_dr_log: :math:`(B, T_{de})` + - x_mask: :math:`(B, T_{en})` + """ + o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale + o_dr[o_dr < 1] = 1.0 + o_dr = torch.round(o_dr) + return o_dr + + def _forward_encoder( + self, x: torch.LongTensor, x_mask: torch.FloatTensor, g: torch.FloatTensor = None + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Encoding forward pass. + + 1. Embed speaker IDs if multi-speaker mode. + 2. Embed character sequences. + 3. Run the encoder network. + 4. Sum encoder outputs and speaker embeddings + + Args: + x (torch.LongTensor): Input sequence IDs. + x_mask (torch.FloatTensor): Input squence mask. + g (torch.FloatTensor, optional): Conditioning vectors. In general speaker embeddings. Defaults to None. 
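+
+        Note:
+            When d-vectors are used and ``d_vector_dim != hidden_channels``, ``g`` is first
+            projected to the encoder width by ``proj_g`` before being added to the encoder
+            output.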
+ + Returns: + Tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor, torch.tensor]: + encoder output, encoder output for the duration predictor, input sequence mask, speaker embeddings, + character embeddings + + Shapes: + - x: :math:`(B, T_{en})` + - x_mask: :math:`(B, 1, T_{en})` + - g: :math:`(B, C)` + """ + if hasattr(self, "emb_g"): + g = g.type(torch.LongTensor) + g = self.emb_g(g) # [B, C, 1] + if g is not None: + g = g.unsqueeze(-1) + # [B, T, C] + x_emb = self.emb(x) + # encoder pass + #o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask) + o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask, g) + # speaker conditioning + # TODO: try different ways of conditioning + if g is not None: + if hasattr(self, "proj_g"): + g = self.proj_g(g.view(g.shape[0], -1)).unsqueeze(-1) + o_en = o_en + g + return o_en, x_mask, g, x_emb + + def _forward_decoder( + self, + o_en: torch.FloatTensor, + dr: torch.IntTensor, + x_mask: torch.FloatTensor, + y_lengths: torch.IntTensor, + g: torch.FloatTensor, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """Decoding forward pass. + + 1. Compute the decoder output mask + 2. Expand encoder output with the durations. + 3. Apply position encoding. + 4. Add speaker embeddings if multi-speaker mode. + 5. Run the decoder. + + Args: + o_en (torch.FloatTensor): Encoder output. + dr (torch.IntTensor): Ground truth durations or alignment network durations. + x_mask (torch.IntTensor): Input sequence mask. + y_lengths (torch.IntTensor): Output sequence lengths. + g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings. + + Returns: + Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations. + """ + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) + # expand o_en with durations + o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) + # positional encoding + if hasattr(self, "pos_encoder"): + o_en_ex = self.pos_encoder(o_en_ex, y_mask) + # decoder pass + o_de = self.decoder(o_en_ex, y_mask, g=g) + return o_de.transpose(1, 2), attn.transpose(1, 2) + + def _forward_pitch_predictor( + self, + o_en: torch.FloatTensor, + x_mask: torch.IntTensor, + pitch: torch.FloatTensor = None, + dr: torch.IntTensor = None, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """Pitch predictor forward pass. + + 1. Predict pitch from encoder outputs. + 2. In training - Compute average pitch values for each input character from the ground truth pitch values. + 3. Embed average pitch values. + + Args: + o_en (torch.FloatTensor): Encoder output. + x_mask (torch.IntTensor): Input sequence mask. + pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None. + dr (torch.IntTensor, optional): Ground truth durations. Defaults to None. + + Returns: + Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction. 
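+            When ground-truth ``pitch`` and durations are passed (training), the averaged
+            target pitch is returned as a third element.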
+ + Shapes: + - o_en: :math:`(B, C, T_{en})` + - x_mask: :math:`(B, 1, T_{en})` + - pitch: :math:`(B, 1, T_{de})` + - dr: :math:`(B, T_{en})` + """ + o_pitch = self.pitch_predictor(o_en, x_mask) + if pitch is not None: + avg_pitch = average_over_durations(pitch, dr) + o_pitch_emb = self.pitch_emb(avg_pitch) + return o_pitch_emb, o_pitch, avg_pitch + o_pitch_emb = self.pitch_emb(o_pitch) + return o_pitch_emb, o_pitch + + def _forward_energy_predictor( + self, + o_en: torch.FloatTensor, + x_mask: torch.IntTensor, + energy: torch.FloatTensor = None, + dr: torch.IntTensor = None, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """Energy predictor forward pass. + + 1. Predict energy from encoder outputs. + 2. In training - Compute average pitch values for each input character from the ground truth pitch values. + 3. Embed average energy values. + + Args: + o_en (torch.FloatTensor): Encoder output. + x_mask (torch.IntTensor): Input sequence mask. + energy (torch.FloatTensor, optional): Ground truth energy values. Defaults to None. + dr (torch.IntTensor, optional): Ground truth durations. Defaults to None. + + Returns: + Tuple[torch.FloatTensor, torch.FloatTensor]: Energy embedding, energy prediction. + + Shapes: + - o_en: :math:`(B, C, T_{en})` + - x_mask: :math:`(B, 1, T_{en})` + - pitch: :math:`(B, 1, T_{de})` + - dr: :math:`(B, T_{en})` + """ + o_energy = self.energy_predictor(o_en, x_mask) + if energy is not None: + avg_energy = average_over_durations(energy, dr) + o_energy_emb = self.energy_emb(avg_energy) + return o_energy_emb, o_energy, avg_energy + o_energy_emb = self.energy_emb(o_energy) + return o_energy_emb, o_energy + + def _forward_aligner( + self, x: torch.FloatTensor, y: torch.FloatTensor, x_mask: torch.IntTensor, y_mask: torch.IntTensor + ) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Aligner forward pass. + + 1. Compute a mask to apply to the attention map. + 2. Run the alignment network. + 3. Apply MAS to compute the hard alignment map. + 4. Compute the durations from the hard alignment map. + + Args: + x (torch.FloatTensor): Input sequence. + y (torch.FloatTensor): Output sequence. + x_mask (torch.IntTensor): Input sequence mask. + y_mask (torch.IntTensor): Output sequence mask. + + Returns: + Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + Durations from the hard alignment map, soft alignment potentials, log scale alignment potentials, + hard alignment map. + + Shapes: + - x: :math:`[B, T_en, C_en]` + - y: :math:`[B, T_de, C_de]` + - x_mask: :math:`[B, 1, T_en]` + - y_mask: :math:`[B, 1, T_de]` + + - o_alignment_dur: :math:`[B, T_en]` + - alignment_soft: :math:`[B, T_en, T_de]` + - alignment_logprob: :math:`[B, 1, T_de, T_en]` + - alignment_mas: :math:`[B, T_en, T_de]` + """ + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, None) + alignment_mas = maximum_path( + alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous() + ) + o_alignment_dur = torch.sum(alignment_mas, -1).int() + alignment_soft = alignment_soft.squeeze(1).transpose(1, 2) + return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas + + def _set_speaker_input(self, aux_input: Dict): + d_vectors = aux_input.get("d_vectors", None) + speaker_ids = aux_input.get("speaker_ids", None) + + if d_vectors is not None and speaker_ids is not None: + raise ValueError("[!] 
Cannot use d-vectors and speaker-ids together.") + + if speaker_ids is not None and not hasattr(self, "emb_g"): + raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.") + + g = speaker_ids if speaker_ids is not None else d_vectors + return g + + def forward( + self, + x: torch.LongTensor, + x_lengths: torch.LongTensor, + y_lengths: torch.LongTensor, + y: torch.FloatTensor = None, + dr: torch.IntTensor = None, + pitch: torch.FloatTensor = None, + energy: torch.FloatTensor = None, + aux_input: Dict = {"d_vectors": None, "speaker_ids": None}, # pylint: disable=unused-argument + ) -> Dict: + """Model's forward pass. + + Args: + x (torch.LongTensor): Input character sequences. + x_lengths (torch.LongTensor): Input sequence lengths. + y_lengths (torch.LongTensor): Output sequnce lengths. Defaults to None. + y (torch.FloatTensor): Spectrogram frames. Only used when the alignment network is on. Defaults to None. + dr (torch.IntTensor): Character durations over the spectrogram frames. Only used when the alignment network is off. Defaults to None. + pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Only used when the pitch predictor is on. Defaults to None. + energy (torch.FloatTensor): energy values for each spectrogram frame. Only used when the energy predictor is on. Defaults to None. + aux_input (Dict): Auxiliary model inputs for multi-speaker training. Defaults to `{"d_vectors": 0, "speaker_ids": None}`. + + Shapes: + - x: :math:`[B, T_max]` + - x_lengths: :math:`[B]` + - y_lengths: :math:`[B]` + - y: :math:`[B, T_max2]` + - dr: :math:`[B, T_max]` + - g: :math:`[B, C]` + - pitch: :math:`[B, 1, T]` + """ + g = self._set_speaker_input(aux_input) + # compute sequence masks + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float() + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float() + # encoder pass + o_en, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g) + # duration predictor pass + if self.args.detach_duration_predictor: + o_dr_log = self.duration_predictor(o_en.detach(), x_mask) + else: + o_dr_log = self.duration_predictor(o_en, x_mask) + o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration) + # generate attn mask from predicted durations + o_attn = self.generate_attn(o_dr.squeeze(1), x_mask) + # aligner + o_alignment_dur = None + alignment_soft = None + alignment_logprob = None + alignment_mas = None + if self.use_aligner: + o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas = self._forward_aligner( + x_emb, y, x_mask, y_mask + ) + alignment_soft = alignment_soft.transpose(1, 2) + alignment_mas = alignment_mas.transpose(1, 2) + dr = o_alignment_dur + # pitch predictor pass + o_pitch = None + avg_pitch = None + if self.args.use_pitch: + o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en, x_mask, pitch, dr) + o_en = o_en + o_pitch_emb + # energy predictor pass + o_energy = None + avg_energy = None + if self.args.use_energy: + o_energy_emb, o_energy, avg_energy = self._forward_energy_predictor(o_en, x_mask, energy, dr) + o_en = o_en + o_energy_emb + # decoder pass + o_de, attn = self._forward_decoder( + o_en, dr, x_mask, y_lengths, g=None + ) # TODO: maybe pass speaker embedding (g) too + outputs = { + "model_outputs": o_de, # [B, T, C] + "durations_log": o_dr_log.squeeze(1), # [B, T] + "durations": o_dr.squeeze(1), # [B, T] + "attn_durations": o_attn, # for visualization [B, T_en, T_de'] + "pitch_avg": o_pitch, + "pitch_avg_gt": avg_pitch, + "energy_avg": o_energy, 
+ "energy_avg_gt": avg_energy, + "alignments": attn, # [B, T_de, T_en] + "alignment_soft": alignment_soft, + "alignment_mas": alignment_mas, + "o_alignment_dur": o_alignment_dur, + "alignment_logprob": alignment_logprob, + "x_mask": x_mask, + "y_mask": y_mask, + } + return outputs + + @torch.no_grad() + def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument + """Model's inference pass. + + Args: + x (torch.LongTensor): Input character sequence. + aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": None, "speaker_ids": None}`. + + Shapes: + - x: [B, T_max] + - x_lengths: [B] + - g: [B, C] + """ + g = self._set_speaker_input(aux_input) + x_lengths = torch.tensor(x.shape[1:2]).to(x.device) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float() + # encoder pass + o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g) + # duration predictor pass + o_dr_log = self.duration_predictor(o_en.squeeze(), x_mask) + o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) + y_lengths = o_dr.sum(1) + + # pitch predictor pass + o_pitch = None + if self.args.use_pitch: + o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask) + o_en = o_en + o_pitch_emb + # energy predictor pass + o_energy = None + if self.args.use_energy: + o_energy_emb, o_energy = self._forward_energy_predictor(o_en, x_mask) + o_en = o_en + o_energy_emb + # decoder pass + o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=None) + outputs = { + "model_outputs": o_de, + "alignments": attn, + "pitch": o_pitch, + "energy": o_energy, + "durations_log": o_dr_log, + } + return outputs + + def train_step(self, batch: dict, criterion: nn.Module): + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + pitch = batch["pitch"] if self.args.use_pitch else None + energy = batch["energy"] if self.args.use_energy else None + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + durations = batch["durations"] + aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids} + + # forward pass + outputs = self.forward( + text_input, + text_lengths, + mel_lengths, + y=mel_input, + dr=durations, + pitch=pitch, + energy=energy, + aux_input=aux_input, + ) + # use aligner's output as the duration target + if self.use_aligner: + durations = outputs["o_alignment_dur"] + # use float32 in AMP + with autocast(enabled=False): + # compute loss + loss_dict = criterion( + decoder_output=outputs["model_outputs"], + decoder_target=mel_input, + decoder_output_lens=mel_lengths, + dur_output=outputs["durations_log"], + dur_target=durations, + pitch_output=outputs["pitch_avg"] if self.use_pitch else None, + pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None, + energy_output=outputs["energy_avg"] if self.use_energy else None, + energy_target=outputs["energy_avg_gt"] if self.use_energy else None, + input_lens=text_lengths, + alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None, + alignment_soft=outputs["alignment_soft"], + alignment_hard=outputs["alignment_mas"], + binary_loss_weight=self.binary_loss_weight, + ) + # compute duration error + durations_pred = outputs["durations"] + duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum() + loss_dict["duration_error"] = duration_error + + return outputs, loss_dict + + def _create_logs(self, batch, outputs, ap): + """Create common 
logger outputs.""" + model_outputs = outputs["model_outputs"] + alignments = outputs["alignments"] + mel_input = batch["mel_input"] + + pred_spec = model_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + # plot pitch figures + if self.args.use_pitch: + pitch_avg = abs(outputs["pitch_avg_gt"][0, 0].data.cpu().numpy()) + pitch_avg_hat = abs(outputs["pitch_avg"][0, 0].data.cpu().numpy()) + chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy()) + pitch_figures = { + "pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False), + "pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False), + } + figures.update(pitch_figures) + + # plot energy figures + if self.args.use_energy: + energy_avg = abs(outputs["energy_avg_gt"][0, 0].data.cpu().numpy()) + energy_avg_hat = abs(outputs["energy_avg"][0, 0].data.cpu().numpy()) + chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy()) + energy_figures = { + "energy_ground_truth": plot_avg_energy(energy_avg, chars, output_fig=False), + "energy_avg_predicted": plot_avg_energy(energy_avg_hat, chars, output_fig=False), + } + figures.update(energy_figures) + + # plot the attention mask computed from the predicted durations + if "attn_durations" in outputs: + alignments_hat = outputs["attn_durations"][0].data.cpu().numpy() + figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False) + + # Sample audio + train_audio = ap.inv_melspectrogram(pred_spec.T) + return figures, {"audio": train_audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ) -> None: # pylint: disable=no-self-use + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_step(self, batch: dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def load_checkpoint( + self, config, checkpoint_path, eval=False, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def get_criterion(self): + from TTS.tts.layers.losses import ForwardTTSLoss # pylint: disable=import-outside-toplevel + + return ForwardTTSLoss(self.config) + + def on_train_step_start(self, trainer): + """Schedule binary loss weight.""" + self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0 + + @staticmethod + def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (ForwardTTSConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. 
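+
+        Example:
+            A minimal sketch (illustrative only; assumes `config` is an already-built
+            Coqpit config for a ForwardTTS variant, e.g. FastPitch, with valid audio
+            and dataset settings):
+
+            >>> from TTS.tts.models.forward_tts import ForwardTTS
+            >>> model = ForwardTTS.init_from_config(config)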
+ """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return ForwardTTS(new_config, ap, tokenizer, speaker_manager) diff --git a/viXTTS/TTS/tts/models/glow_tts.py b/viXTTS/TTS/tts/models/glow_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..bfd1a2b618bd9bfdc7d12dd4eb16a6febcaf8cde --- /dev/null +++ b/viXTTS/TTS/tts/models/glow_tts.py @@ -0,0 +1,557 @@ +import math +from typing import Dict, List, Tuple, Union + +import torch +from coqpit import Coqpit +from torch import nn +from torch.cuda.amp.autocast_mode import autocast +from torch.nn import functional as F + +from TTS.tts.configs.glow_tts_config import GlowTTSConfig +from TTS.tts.layers.glow_tts.decoder import Decoder +from TTS.tts.layers.glow_tts.encoder import Encoder +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.io import load_fsspec + + +class GlowTTS(BaseTTS): + """GlowTTS model. + + Paper:: + https://arxiv.org/abs/2005.11129 + + Paper abstract:: + Recently, text-to-speech (TTS) models such as FastSpeech and ParaNet have been proposed to generate + mel-spectrograms from text in parallel. Despite the advantage, the parallel TTS models cannot be trained + without guidance from autoregressive TTS models as their external aligners. In this work, we propose Glow-TTS, + a flow-based generative model for parallel TTS that does not require any external aligner. By combining the + properties of flows and dynamic programming, the proposed model searches for the most probable monotonic + alignment between text and the latent representation of speech on its own. We demonstrate that enforcing hard + monotonic alignments enables robust TTS, which generalizes to long utterances, and employing generative flows + enables fast, diverse, and controllable speech synthesis. Glow-TTS obtains an order-of-magnitude speed-up over + the autoregressive model, Tacotron 2, at synthesis with comparable speech quality. We further show that our + model can be easily extended to a multi-speaker setting. + + Check :class:`TTS.tts.configs.glow_tts_config.GlowTTSConfig` for class arguments. + + Examples: + Init only model layers. + + >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig + >>> from TTS.tts.models.glow_tts import GlowTTS + >>> config = GlowTTSConfig(num_chars=2) + >>> model = GlowTTS(config) + + Fully init a model ready for action. All the class attributes and class members + (e.g Tokenizer, AudioProcessor, etc.). are initialized internally based on config values. 
+ + >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig + >>> from TTS.tts.models.glow_tts import GlowTTS + >>> config = GlowTTSConfig() + >>> model = GlowTTS.init_from_config(config, verbose=False) + """ + + def __init__( + self, + config: GlowTTSConfig, + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + + # pass all config fields to `self` + # for fewer code change + self.config = config + for key in config: + setattr(self, key, config[key]) + + self.decoder_output_dim = config.out_channels + + # init multi-speaker layers if necessary + self.init_multispeaker(config) + + self.run_data_dep_init = config.data_dep_init_steps > 0 + self.encoder = Encoder( + self.num_chars, + out_channels=self.out_channels, + hidden_channels=self.hidden_channels_enc, + hidden_channels_dp=self.hidden_channels_dp, + encoder_type=self.encoder_type, + encoder_params=self.encoder_params, + mean_only=self.mean_only, + use_prenet=self.use_encoder_prenet, + dropout_p_dp=self.dropout_p_dp, + c_in_channels=self.c_in_channels, + ) + + self.decoder = Decoder( + self.out_channels, + self.hidden_channels_dec, + self.kernel_size_dec, + self.dilation_rate, + self.num_flow_blocks_dec, + self.num_block_layers, + dropout_p=self.dropout_p_dec, + num_splits=self.num_splits, + num_squeeze=self.num_squeeze, + sigmoid_scale=self.sigmoid_scale, + c_in_channels=self.c_in_channels, + ) + + def init_multispeaker(self, config: Coqpit): + """Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding + vector dimension to the encoder layer channel size. If model uses d-vectors, then it only sets + speaker embedding vector dimension to the d-vector dimension from the config. + + Args: + config (Coqpit): Model configuration. + """ + self.embedded_speaker_dim = 0 + # set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager + if self.speaker_manager is not None: + self.num_speakers = self.speaker_manager.num_speakers + # set ultimate speaker embedding size + if config.use_d_vector_file: + self.embedded_speaker_dim = ( + config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512 + ) + if self.speaker_manager is not None: + assert ( + config.d_vector_dim == self.speaker_manager.embedding_dim + ), " [!] d-vector dimension mismatch b/w config and speaker manager." 
+ # init speaker embedding layer + if config.use_speaker_embedding and not config.use_d_vector_file: + print(" > Init speaker_embedding layer.") + self.embedded_speaker_dim = self.hidden_channels_enc + self.emb_g = nn.Embedding(self.num_speakers, self.hidden_channels_enc) + nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) + # set conditioning dimensions + self.c_in_channels = self.embedded_speaker_dim + + @staticmethod + def compute_outputs(attn, o_mean, o_log_scale, x_mask): + """Compute and format the mode outputs with the given alignment map""" + y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose( + 1, 2 + ) # [b, t', t], [b, t, d] -> [b, d, t'] + y_log_scale = torch.matmul(attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(1, 2)).transpose( + 1, 2 + ) # [b, t', t], [b, t, d] -> [b, d, t'] + # compute total duration with adjustment + o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask + return y_mean, y_log_scale, o_attn_dur + + def unlock_act_norm_layers(self): + """Unlock activation normalization layers for data depended initalization.""" + for f in self.decoder.flows: + if getattr(f, "set_ddi", False): + f.set_ddi(True) + + def lock_act_norm_layers(self): + """Lock activation normalization layers.""" + for f in self.decoder.flows: + if getattr(f, "set_ddi", False): + f.set_ddi(False) + + def _set_speaker_input(self, aux_input: Dict): + if aux_input is None: + d_vectors = None + speaker_ids = None + else: + d_vectors = aux_input.get("d_vectors", None) + speaker_ids = aux_input.get("speaker_ids", None) + + if d_vectors is not None and speaker_ids is not None: + raise ValueError("[!] Cannot use d-vectors and speaker-ids together.") + + if speaker_ids is not None and not hasattr(self, "emb_g"): + raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.") + + g = speaker_ids if speaker_ids is not None else d_vectors + return g + + def _speaker_embedding(self, aux_input: Dict) -> Union[torch.tensor, None]: + g = self._set_speaker_input(aux_input) + # speaker embedding + if g is not None: + if hasattr(self, "emb_g"): + # use speaker embedding layer + if not g.size(): # if is a scalar + g = g.unsqueeze(0) # unsqueeze + g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1] + else: + # use d-vector + g = F.normalize(g).unsqueeze(-1) # [b, h, 1] + return g + + def forward( + self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None} + ): # pylint: disable=dangerous-default-value + """ + Args: + x (torch.Tensor): + Input text sequence ids. :math:`[B, T_en]` + + x_lengths (torch.Tensor): + Lengths of input text sequences. :math:`[B]` + + y (torch.Tensor): + Target mel-spectrogram frames. :math:`[B, T_de, C_mel]` + + y_lengths (torch.Tensor): + Lengths of target mel-spectrogram frames. :math:`[B]` + + aux_input (Dict): + Auxiliary inputs. `d_vectors` is speaker embedding vectors for a multi-speaker model. + :math:`[B, D_vec]`. `speaker_ids` is speaker ids for a multi-speaker model usind speaker-embedding + layer. 
:math:`B`
+
+        Returns:
+            Dict:
+                - z: :math:`[B, T_de, C]`
+                - logdet: :math:`B`
+                - y_mean: :math:`[B, T_de, C]`
+                - y_log_scale: :math:`[B, T_de, C]`
+                - alignments: :math:`[B, T_en, T_de]`
+                - durations_log: :math:`[B, T_en, 1]`
+                - total_durations_log: :math:`[B, T_en, 1]`
+        """
+        # [B, T, C] -> [B, C, T]
+        y = y.transpose(1, 2)
+        y_max_length = y.size(2)
+        # norm speaker embeddings
+        g = self._speaker_embedding(aux_input)
+        # embedding pass
+        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
+        # drop residual frames wrt num_squeeze and set y_lengths.
+        y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
+        # create masks
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
+        # [B, 1, T_en, T_de]
+        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
+        # decoder pass
+        z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
+        # find the alignment path
+        with torch.no_grad():
+            o_scale = torch.exp(-2 * o_log_scale)
+            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)  # [b, t, 1]
+            logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
+            logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)  # [b, t, d] x [b, d, t'] = [b, t, t']
+            logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
+            logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
+            attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
+        y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
+        attn = attn.squeeze(1).permute(0, 2, 1)
+        outputs = {
+            "z": z.transpose(1, 2),
+            "logdet": logdet,
+            "y_mean": y_mean.transpose(1, 2),
+            "y_log_scale": y_log_scale.transpose(1, 2),
+            "alignments": attn,
+            "durations_log": o_dur_log.transpose(1, 2),
+            "total_durations_log": o_attn_dur.transpose(1, 2),
+        }
+        return outputs
+
+    @torch.no_grad()
+    def inference_with_MAS(
+        self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
+    ):  # pylint: disable=dangerous-default-value
+        """
+        It's similar to the teacher forcing in Tacotron.
+        It was proposed in: https://arxiv.org/abs/2104.05557
+
+        Shapes:
+            - x: :math:`[B, T]`
+            - x_lengths: :math:`B`
+            - y: :math:`[B, T, C]`
+            - y_lengths: :math:`B`
+            - g: :math:`[B, C] or B`
+        """
+        y = y.transpose(1, 2)
+        y_max_length = y.size(2)
+        # norm speaker embeddings
+        g = self._speaker_embedding(aux_input)
+        # embedding pass
+        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
+        # drop residual frames wrt num_squeeze and set y_lengths.
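+        # `self.preprocess` trims the target frames (and `y_lengths`) down to a multiple of
+        # `num_squeeze`, so the flow decoder's squeeze step divides the time axis evenly.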
+ y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None) + # create masks + y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + # decoder pass + z, logdet = self.decoder(y, y_mask, g=g, reverse=False) + # find the alignment path between z and encoder output + o_scale = torch.exp(-2 * o_log_scale) + logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2)) # [b, t, d] x [b, d, t'] = [b, t, t'] + logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t'] + logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp = logp1 + logp2 + logp3 + logp4 # [b, t, t'] + attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() + + y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask) + attn = attn.squeeze(1).permute(0, 2, 1) + + # get predited aligned distribution + z = y_mean * y_mask + + # reverse the decoder and predict using the aligned distribution + y, logdet = self.decoder(z, y_mask, g=g, reverse=True) + outputs = { + "model_outputs": z.transpose(1, 2), + "logdet": logdet, + "y_mean": y_mean.transpose(1, 2), + "y_log_scale": y_log_scale.transpose(1, 2), + "alignments": attn, + "durations_log": o_dur_log.transpose(1, 2), + "total_durations_log": o_attn_dur.transpose(1, 2), + } + return outputs + + @torch.no_grad() + def decoder_inference( + self, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None} + ): # pylint: disable=dangerous-default-value + """ + Shapes: + - y: :math:`[B, T, C]` + - y_lengths: :math:`B` + - g: :math:`[B, C] or B` + """ + y = y.transpose(1, 2) + y_max_length = y.size(2) + g = self._speaker_embedding(aux_input) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(y.dtype) + # decoder pass + z, logdet = self.decoder(y, y_mask, g=g, reverse=False) + # reverse decoder and predict + y, logdet = self.decoder(z, y_mask, g=g, reverse=True) + outputs = {} + outputs["model_outputs"] = y.transpose(1, 2) + outputs["logdet"] = logdet + return outputs + + @torch.no_grad() + def inference( + self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None} + ): # pylint: disable=dangerous-default-value + x_lengths = aux_input["x_lengths"] + g = self._speaker_embedding(aux_input) + # embedding pass + o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) + # compute output durations + w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale + w_ceil = torch.clamp_min(torch.ceil(w), 1) + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_max_length = None + # compute masks + y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + # compute attention mask + attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) + y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask) + + z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * self.inference_noise_scale) * y_mask + # decoder pass + y, logdet = self.decoder(z, y_mask, g=g, reverse=True) + attn = attn.squeeze(1).permute(0, 2, 1) + outputs = { + "model_outputs": y.transpose(1, 2), + "logdet": logdet, + "y_mean": y_mean.transpose(1, 2), + "y_log_scale": 
y_log_scale.transpose(1, 2), + "alignments": attn, + "durations_log": o_dur_log.transpose(1, 2), + "total_durations_log": o_attn_dur.transpose(1, 2), + } + return outputs + + def train_step(self, batch: dict, criterion: nn.Module): + """A single training step. Forward pass and loss computation. Run data depended initialization for the + first `config.data_dep_init_steps` steps. + + Args: + batch (dict): [description] + criterion (nn.Module): [description] + """ + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + + if self.run_data_dep_init and self.training: + # compute data-dependent initialization of activation norm layers + self.unlock_act_norm_layers() + with torch.no_grad(): + _ = self.forward( + text_input, + text_lengths, + mel_input, + mel_lengths, + aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids}, + ) + outputs = None + loss_dict = None + self.lock_act_norm_layers() + else: + # normal training step + outputs = self.forward( + text_input, + text_lengths, + mel_input, + mel_lengths, + aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids}, + ) + + with autocast(enabled=False): # avoid mixed_precision in criterion + loss_dict = criterion( + outputs["z"].float(), + outputs["y_mean"].float(), + outputs["y_log_scale"].float(), + outputs["logdet"].float(), + mel_lengths, + outputs["durations_log"].float(), + outputs["total_durations_log"].float(), + text_lengths, + ) + return outputs, loss_dict + + def _create_logs(self, batch, outputs, ap): + alignments = outputs["alignments"] + text_input = batch["text_input"][:1] if batch["text_input"] is not None else None + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + d_vectors = batch["d_vectors"][:1] if batch["d_vectors"] is not None else None + speaker_ids = batch["speaker_ids"][:1] if batch["speaker_ids"] is not None else None + + # model runs reverse flow to predict spectrograms + pred_outputs = self.inference( + text_input, + aux_input={"x_lengths": text_lengths[:1], "d_vectors": d_vectors, "speaker_ids": speaker_ids}, + ) + model_outputs = pred_outputs["model_outputs"] + + pred_spec = model_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + # Sample audio + train_audio = ap.inv_melspectrogram(pred_spec.T) + return figures, {"audio": train_audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ) -> None: # pylint: disable=no-self-use + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + @torch.no_grad() + def eval_step(self, batch: dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + @torch.no_grad() + def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: + """Generic test run for `tts` models used by `Trainer`. 
+
+        You can override this for a different behaviour.
+
+        Returns:
+            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
+        """
+        print(" | > Synthesizing test sentences.")
+        test_audios = {}
+        test_figures = {}
+        test_sentences = self.config.test_sentences
+        aux_inputs = self._get_test_aux_input()
+        if len(test_sentences) == 0:
+            print(" | [!] No test sentences provided.")
+        else:
+            for idx, sen in enumerate(test_sentences):
+                outputs = synthesis(
+                    self,
+                    sen,
+                    self.config,
+                    "cuda" in str(next(self.parameters()).device),
+                    speaker_id=aux_inputs["speaker_id"],
+                    d_vector=aux_inputs["d_vector"],
+                    style_wav=aux_inputs["style_wav"],
+                    use_griffin_lim=True,
+                    do_trim_silence=False,
+                )
+
+                test_audios["{}-audio".format(idx)] = outputs["wav"]
+                test_figures["{}-prediction".format(idx)] = plot_spectrogram(
+                    outputs["outputs"]["model_outputs"], self.ap, output_fig=False
+                )
+                test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False)
+        return test_figures, test_audios
+
+    def preprocess(self, y, y_lengths, y_max_length, attn=None):
+        if y_max_length is not None:
+            y_max_length = (y_max_length // self.num_squeeze) * self.num_squeeze
+            y = y[:, :, :y_max_length]
+            if attn is not None:
+                attn = attn[:, :, :, :y_max_length]
+        y_lengths = torch.div(y_lengths, self.num_squeeze, rounding_mode="floor") * self.num_squeeze
+        return y, y_lengths, y_max_length, attn
+
+    def store_inverse(self):
+        self.decoder.store_inverse()
+
+    def load_checkpoint(
+        self, config, checkpoint_path, eval=False
+    ):  # pylint: disable=unused-argument, redefined-builtin
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        self.load_state_dict(state["model"])
+        if eval:
+            self.eval()
+            self.store_inverse()
+            assert not self.training
+
+    @staticmethod
+    def get_criterion():
+        from TTS.tts.layers.losses import GlowTTSLoss  # pylint: disable=import-outside-toplevel
+
+        return GlowTTSLoss()
+
+    def on_train_step_start(self, trainer):
+        """Decide on every training step whether to enable/disable data-dependent initialization."""
+        self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps
+
+    @staticmethod
+    def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
+        """Initiate model from config
+
+        Args:
+            config (GlowTTSConfig): Model config.
+            samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
+                Defaults to None.
+            verbose (bool): If True, print init messages. Defaults to True.
+ """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config, verbose) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return GlowTTS(new_config, ap, tokenizer, speaker_manager) diff --git a/viXTTS/TTS/tts/models/neuralhmm_tts.py b/viXTTS/TTS/tts/models/neuralhmm_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..e2414108721571c9a1cf143fdca2fa74174a9684 --- /dev/null +++ b/viXTTS/TTS/tts/models/neuralhmm_tts.py @@ -0,0 +1,385 @@ +import os +from typing import Dict, List, Union + +import torch +from coqpit import Coqpit +from torch import nn +from trainer.logging.tensorboard_logger import TensorboardLogger + +from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils +from TTS.tts.layers.overflow.neural_hmm import NeuralHMM +from TTS.tts.layers.overflow.plotting_utils import ( + get_spec_from_most_probable_state, + plot_transition_probabilities_to_numpy, +) +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.generic_utils import format_aux_input +from TTS.utils.io import load_fsspec + + +class NeuralhmmTTS(BaseTTS): + """Neural HMM TTS model. + + Paper:: + https://arxiv.org/abs/2108.13320 + + Paper abstract:: + Neural sequence-to-sequence TTS has achieved significantly better output quality + than statistical speech synthesis using HMMs.However, neural TTS is generally not probabilistic + and uses non-monotonic attention. Attention failures increase training time and can make + synthesis babble incoherently. This paper describes how the old and new paradigms can be + combined to obtain the advantages of both worlds, by replacing attention in neural TTS with + an autoregressive left-right no-skip hidden Markov model defined by a neural network. + Based on this proposal, we modify Tacotron 2 to obtain an HMM-based neural TTS model with + monotonic alignment, trained to maximise the full sequence likelihood without approximation. + We also describe how to combine ideas from classical and contemporary TTS for best results. + The resulting example system is smaller and simpler than Tacotron 2, and learns to speak with + fewer iterations and less data, whilst achieving comparable naturalness prior to the post-net. + Our approach also allows easy control over speaking rate. Audio examples and code + are available at https://shivammehta25.github.io/Neural-HMM/ . + + Note: + - This is a parameter efficient version of OverFlow (15.3M vs 28.6M). Since it has half the + number of parameters as OverFlow the synthesis output quality is suboptimal (but comparable to Tacotron2 + without Postnet), but it learns to speak with even lesser amount of data and is still significantly faster + than other attention-based methods. + + - Neural HMMs uses flat start initialization i.e it computes the means and std and transition probabilities + of the dataset and uses them to initialize the model. This benefits the model and helps with faster learning + If you change the dataset or want to regenerate the parameters change the `force_generate_statistics` and + `mel_statistics_parameter_path` accordingly. + + - To enable multi-GPU training, set the `use_grad_checkpointing=False` in config. + This will significantly increase the memory usage. 
This is because to compute + the actual data likelihood (not an approximation using MAS/Viterbi) we must use + all the states at the previous time step during the forward pass to decide the + probability distribution at the current step i.e the difference between the forward + algorithm and viterbi approximation. + + Check :class:`TTS.tts.configs.neuralhmm_tts_config.NeuralhmmTTSConfig` for class arguments. + """ + + def __init__( + self, + config: "NeuralhmmTTSConfig", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + + # pass all config fields to `self` + # for fewer code change + self.config = config + for key in config: + setattr(self, key, config[key]) + + self.encoder = Encoder(config.num_chars, config.state_per_phone, config.encoder_in_out_features) + self.neural_hmm = NeuralHMM( + frame_channels=self.out_channels, + ar_order=self.ar_order, + deterministic_transition=self.deterministic_transition, + encoder_dim=self.encoder_in_out_features, + prenet_type=self.prenet_type, + prenet_dim=self.prenet_dim, + prenet_n_layers=self.prenet_n_layers, + prenet_dropout=self.prenet_dropout, + prenet_dropout_at_inference=self.prenet_dropout_at_inference, + memory_rnn_dim=self.memory_rnn_dim, + outputnet_size=self.outputnet_size, + flat_start_params=self.flat_start_params, + std_floor=self.std_floor, + use_grad_checkpointing=self.use_grad_checkpointing, + ) + + self.register_buffer("mean", torch.tensor(0)) + self.register_buffer("std", torch.tensor(1)) + + def update_mean_std(self, statistics_dict: Dict): + self.mean.data = torch.tensor(statistics_dict["mean"]) + self.std.data = torch.tensor(statistics_dict["std"]) + + def preprocess_batch(self, text, text_len, mels, mel_len): + if self.mean.item() == 0 or self.std.item() == 1: + statistics_dict = torch.load(self.mel_statistics_parameter_path) + self.update_mean_std(statistics_dict) + + mels = self.normalize(mels) + return text, text_len, mels, mel_len + + def normalize(self, x): + return x.sub(self.mean).div(self.std) + + def inverse_normalize(self, x): + return x.mul(self.std).add(self.mean) + + def forward(self, text, text_len, mels, mel_len): + """ + Forward pass for training and computing the log likelihood of a given batch. 
+ + Shapes: + Shapes: + text: :math:`[B, T_in]` + text_len: :math:`[B]` + mels: :math:`[B, T_out, C]` + mel_len: :math:`[B]` + """ + text, text_len, mels, mel_len = self.preprocess_batch(text, text_len, mels, mel_len) + encoder_outputs, encoder_output_len = self.encoder(text, text_len) + + log_probs, fwd_alignments, transition_vectors, means = self.neural_hmm( + encoder_outputs, encoder_output_len, mels.transpose(1, 2), mel_len + ) + + outputs = { + "log_probs": log_probs, + "alignments": fwd_alignments, + "transition_vectors": transition_vectors, + "means": means, + } + + return outputs + + @staticmethod + def _training_stats(batch): + stats = {} + stats["avg_text_length"] = batch["text_lengths"].float().mean() + stats["avg_spec_length"] = batch["mel_lengths"].float().mean() + stats["avg_text_batch_occupancy"] = (batch["text_lengths"].float() / batch["text_lengths"].float().max()).mean() + stats["avg_spec_batch_occupancy"] = (batch["mel_lengths"].float() / batch["mel_lengths"].float().max()).mean() + return stats + + def train_step(self, batch: dict, criterion: nn.Module): + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + + outputs = self.forward( + text=text_input, + text_len=text_lengths, + mels=mel_input, + mel_len=mel_lengths, + ) + loss_dict = criterion(outputs["log_probs"] / (mel_lengths.sum() + text_lengths.sum())) + + # for printing useful statistics on terminal + loss_dict.update(self._training_stats(batch)) + return outputs, loss_dict + + def eval_step(self, batch: Dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def _format_aux_input(self, aux_input: Dict, default_input_dict): + """Set missing fields to their default value. + + Args: + aux_inputs (Dict): Dictionary containing the auxiliary inputs. + """ + default_input_dict = default_input_dict.copy() + default_input_dict.update( + { + "sampling_temp": self.sampling_temp, + "max_sampling_time": self.max_sampling_time, + "duration_threshold": self.duration_threshold, + } + ) + if aux_input: + return format_aux_input(default_input_dict, aux_input) + return default_input_dict + + @torch.no_grad() + def inference( + self, + text: torch.Tensor, + aux_input={"x_lengths": None, "sampling_temp": None, "max_sampling_time": None, "duration_threshold": None}, + ): # pylint: disable=dangerous-default-value + """Sampling from the model + + Args: + text (torch.Tensor): :math:`[B, T_in]` + aux_inputs (_type_, optional): _description_. Defaults to None. + + Returns: + outputs: Dictionary containing the following + - mel (torch.Tensor): :math:`[B, T_out, C]` + - hmm_outputs_len (torch.Tensor): :math:`[B]` + - state_travelled (List[List[int]]): List of lists containing the state travelled for each sample in the batch. + - input_parameters (list[torch.FloatTensor]): Input parameters to the neural HMM. + - output_parameters (list[torch.FloatTensor]): Output parameters to the neural HMM. 
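+
+        Example:
+            A minimal sampling sketch (illustrative only; assumes a trained/initialized
+            model and zero-padded integer input ids):
+
+            >>> import torch
+            >>> ids = torch.LongTensor([[3, 14, 25, 9, 0, 0]])  # [B, T_in]
+            >>> outputs = model.inference(ids)
+            >>> outputs["model_outputs"].shape  # [B, T_out, C]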
+ """ + default_input_dict = { + "x_lengths": torch.sum(text != 0, dim=1), + } + aux_input = self._format_aux_input(aux_input, default_input_dict) + encoder_outputs, encoder_output_len = self.encoder.inference(text, aux_input["x_lengths"]) + outputs = self.neural_hmm.inference( + encoder_outputs, + encoder_output_len, + sampling_temp=aux_input["sampling_temp"], + max_sampling_time=aux_input["max_sampling_time"], + duration_threshold=aux_input["duration_threshold"], + ) + mels, mel_outputs_len = outputs["hmm_outputs"], outputs["hmm_outputs_len"] + + mels = self.inverse_normalize(mels) + outputs.update({"model_outputs": mels, "model_outputs_len": mel_outputs_len}) + outputs["alignments"] = OverflowUtils.double_pad(outputs["alignments"]) + return outputs + + @staticmethod + def get_criterion(): + return NLLLoss() + + @staticmethod + def init_from_config(config: "NeuralhmmTTSConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + """Initiate model from config + + Args: + config (VitsConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + verbose (bool): If True, print init messages. Defaults to True. + """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config, verbose) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return NeuralhmmTTS(new_config, ap, tokenizer, speaker_manager) + + def load_checkpoint( + self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def on_init_start(self, trainer): + """If the current dataset does not have normalisation statistics and initialisation transition_probability it computes them otherwise loads.""" + if not os.path.isfile(trainer.config.mel_statistics_parameter_path) or trainer.config.force_generate_statistics: + dataloader = trainer.get_train_dataloader( + training_assets=None, samples=trainer.train_samples, verbose=False + ) + print( + f" | > Data parameters not found for: {trainer.config.mel_statistics_parameter_path}. Computing mel normalization parameters..." + ) + data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start( + dataloader, trainer.config.out_channels, trainer.config.state_per_phone + ) + print( + f" | > Saving data parameters to: {trainer.config.mel_statistics_parameter_path}: value: {data_mean, data_std, init_transition_prob}" + ) + statistics = { + "mean": data_mean.item(), + "std": data_std.item(), + "init_transition_prob": init_transition_prob.item(), + } + torch.save(statistics, trainer.config.mel_statistics_parameter_path) + + else: + print( + f" | > Data parameters found for: {trainer.config.mel_statistics_parameter_path}. Loading mel normalization parameters..." 
+ ) + statistics = torch.load(trainer.config.mel_statistics_parameter_path) + data_mean, data_std, init_transition_prob = ( + statistics["mean"], + statistics["std"], + statistics["init_transition_prob"], + ) + print(f" | > Data parameters loaded with value: {data_mean, data_std, init_transition_prob}") + + trainer.config.flat_start_params["transition_p"] = ( + init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob + ) + OverflowUtils.update_flat_start_transition(trainer.model, init_transition_prob) + trainer.model.update_mean_std(statistics) + + @torch.inference_mode() + def _create_logs(self, batch, outputs, ap): # pylint: disable=no-self-use, unused-argument + alignments, transition_vectors = outputs["alignments"], outputs["transition_vectors"] + means = torch.stack(outputs["means"], dim=1) + + figures = { + "alignment": plot_alignment(alignments[0].exp(), title="Forward alignment", fig_size=(20, 20)), + "log_alignment": plot_alignment( + alignments[0].exp(), title="Forward log alignment", plot_log=True, fig_size=(20, 20) + ), + "transition_vectors": plot_alignment(transition_vectors[0], title="Transition vectors", fig_size=(20, 20)), + "mel_from_most_probable_state": plot_spectrogram( + get_spec_from_most_probable_state(alignments[0], means[0]), fig_size=(12, 3) + ), + "mel_target": plot_spectrogram(batch["mel_input"][0], fig_size=(12, 3)), + } + + # sample one item from the batch -1 will give the smalles item + print(" | > Synthesising audio from the model...") + inference_output = self.inference( + batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)} + ) + figures["synthesised"] = plot_spectrogram(inference_output["model_outputs"][0], fig_size=(12, 3)) + + states = [p[1] for p in inference_output["input_parameters"][0]] + transition_probability_synthesising = [p[2].cpu().numpy() for p in inference_output["output_parameters"][0]] + + for i in range((len(transition_probability_synthesising) // 200) + 1): + start = i * 200 + end = (i + 1) * 200 + figures[f"synthesised_transition_probabilities/{i}"] = plot_transition_probabilities_to_numpy( + states[start:end], transition_probability_synthesising[start:end] + ) + + audio = ap.inv_melspectrogram(inference_output["model_outputs"][0].T.cpu().numpy()) + return figures, {"audios": audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ): # pylint: disable=unused-argument + """Log training progress.""" + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_log( + self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int + ): # pylint: disable=unused-argument + """Compute and log evaluation metrics.""" + # Plot model parameters histograms + if isinstance(logger, TensorboardLogger): + # I don't know if any other loggers supports this + for tag, value in self.named_parameters(): + tag = tag.replace(".", "/") + logger.writer.add_histogram(tag, value.data.cpu().numpy(), steps) + + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: + logger.test_audios(steps, outputs[1], self.ap.sample_rate) + logger.test_figures(steps, outputs[0]) + + +class 
NLLLoss(nn.Module): + """Negative log likelihood loss.""" + + def forward(self, log_prob: torch.Tensor) -> dict: # pylint: disable=no-self-use + """Compute the loss. + + Args: + logits (Tensor): [B, T, D] + + Returns: + Tensor: [1] + + """ + return_dict = {} + return_dict["loss"] = -log_prob.mean() + return return_dict diff --git a/viXTTS/TTS/tts/models/overflow.py b/viXTTS/TTS/tts/models/overflow.py new file mode 100644 index 0000000000000000000000000000000000000000..92b3c767de4cb5180df4a58d6cfdc1ed194caad7 --- /dev/null +++ b/viXTTS/TTS/tts/models/overflow.py @@ -0,0 +1,401 @@ +import os +from typing import Dict, List, Union + +import torch +from coqpit import Coqpit +from torch import nn +from trainer.logging.tensorboard_logger import TensorboardLogger + +from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils +from TTS.tts.layers.overflow.decoder import Decoder +from TTS.tts.layers.overflow.neural_hmm import NeuralHMM +from TTS.tts.layers.overflow.plotting_utils import ( + get_spec_from_most_probable_state, + plot_transition_probabilities_to_numpy, +) +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.generic_utils import format_aux_input +from TTS.utils.io import load_fsspec + + +class Overflow(BaseTTS): + """OverFlow TTS model. + + Paper:: + https://arxiv.org/abs/2211.06892 + + Paper abstract:: + Neural HMMs are a type of neural transducer recently proposed for + sequence-to-sequence modelling in text-to-speech. They combine the best features + of classic statistical speech synthesis and modern neural TTS, requiring less + data and fewer training updates, and are less prone to gibberish output caused + by neural attention failures. In this paper, we combine neural HMM TTS with + normalising flows for describing the highly non-Gaussian distribution of speech + acoustics. The result is a powerful, fully probabilistic model of durations and + acoustics that can be trained using exact maximum likelihood. Compared to + dominant flow-based acoustic models, our approach integrates autoregression for + improved modelling of long-range dependences such as utterance-level prosody. + Experiments show that a system based on our proposal gives more accurate + pronunciations and better subjective speech quality than comparable methods, + whilst retaining the original advantages of neural HMMs. Audio examples and code + are available at https://shivammehta25.github.io/OverFlow/. + + Note: + - Neural HMMs uses flat start initialization i.e it computes the means and std and transition probabilities + of the dataset and uses them to initialize the model. This benefits the model and helps with faster learning + If you change the dataset or want to regenerate the parameters change the `force_generate_statistics` and + `mel_statistics_parameter_path` accordingly. + + - To enable multi-GPU training, set the `use_grad_checkpointing=False` in config. + This will significantly increase the memory usage. This is because to compute + the actual data likelihood (not an approximation using MAS/Viterbi) we must use + all the states at the previous time step during the forward pass to decide the + probability distribution at the current step i.e the difference between the forward + algorithm and viterbi approximation. + + Check :class:`TTS.tts.configs.overflow.OverFlowConfig` for class arguments. 
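+
+    Examples:
+        Init only the model layers (illustrative sketch; assumes the config lives in
+        `TTS.tts.configs.overflow_config` and that its default values are usable as-is):
+
+        >>> from TTS.tts.configs.overflow_config import OverFlowConfig
+        >>> from TTS.tts.models.overflow import Overflow
+        >>> config = OverFlowConfig(num_chars=2)
+        >>> model = Overflow(config)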
+ """ + + def __init__( + self, + config: "OverFlowConfig", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + + # pass all config fields to `self` + # for fewer code change + self.config = config + for key in config: + setattr(self, key, config[key]) + + self.decoder_output_dim = config.out_channels + + self.encoder = Encoder(config.num_chars, config.state_per_phone, config.encoder_in_out_features) + self.neural_hmm = NeuralHMM( + frame_channels=self.out_channels, + ar_order=self.ar_order, + deterministic_transition=self.deterministic_transition, + encoder_dim=self.encoder_in_out_features, + prenet_type=self.prenet_type, + prenet_dim=self.prenet_dim, + prenet_n_layers=self.prenet_n_layers, + prenet_dropout=self.prenet_dropout, + prenet_dropout_at_inference=self.prenet_dropout_at_inference, + memory_rnn_dim=self.memory_rnn_dim, + outputnet_size=self.outputnet_size, + flat_start_params=self.flat_start_params, + std_floor=self.std_floor, + use_grad_checkpointing=self.use_grad_checkpointing, + ) + + self.decoder = Decoder( + self.out_channels, + self.hidden_channels_dec, + self.kernel_size_dec, + self.dilation_rate, + self.num_flow_blocks_dec, + self.num_block_layers, + dropout_p=self.dropout_p_dec, + num_splits=self.num_splits, + num_squeeze=self.num_squeeze, + sigmoid_scale=self.sigmoid_scale, + c_in_channels=self.c_in_channels, + ) + + self.register_buffer("mean", torch.tensor(0)) + self.register_buffer("std", torch.tensor(1)) + + def update_mean_std(self, statistics_dict: Dict): + self.mean.data = torch.tensor(statistics_dict["mean"]) + self.std.data = torch.tensor(statistics_dict["std"]) + + def preprocess_batch(self, text, text_len, mels, mel_len): + if self.mean.item() == 0 or self.std.item() == 1: + statistics_dict = torch.load(self.mel_statistics_parameter_path) + self.update_mean_std(statistics_dict) + + mels = self.normalize(mels) + return text, text_len, mels, mel_len + + def normalize(self, x): + return x.sub(self.mean).div(self.std) + + def inverse_normalize(self, x): + return x.mul(self.std).add(self.mean) + + def forward(self, text, text_len, mels, mel_len): + """ + Forward pass for training and computing the log likelihood of a given batch. 
+ + Shapes: + Shapes: + text: :math:`[B, T_in]` + text_len: :math:`[B]` + mels: :math:`[B, T_out, C]` + mel_len: :math:`[B]` + """ + text, text_len, mels, mel_len = self.preprocess_batch(text, text_len, mels, mel_len) + encoder_outputs, encoder_output_len = self.encoder(text, text_len) + z, z_lengths, logdet = self.decoder(mels.transpose(1, 2), mel_len) + log_probs, fwd_alignments, transition_vectors, means = self.neural_hmm( + encoder_outputs, encoder_output_len, z, z_lengths + ) + + outputs = { + "log_probs": log_probs + logdet, + "alignments": fwd_alignments, + "transition_vectors": transition_vectors, + "means": means, + } + + return outputs + + @staticmethod + def _training_stats(batch): + stats = {} + stats["avg_text_length"] = batch["text_lengths"].float().mean() + stats["avg_spec_length"] = batch["mel_lengths"].float().mean() + stats["avg_text_batch_occupancy"] = (batch["text_lengths"].float() / batch["text_lengths"].float().max()).mean() + stats["avg_spec_batch_occupancy"] = (batch["mel_lengths"].float() / batch["mel_lengths"].float().max()).mean() + return stats + + def train_step(self, batch: dict, criterion: nn.Module): + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + + outputs = self.forward( + text=text_input, + text_len=text_lengths, + mels=mel_input, + mel_len=mel_lengths, + ) + loss_dict = criterion(outputs["log_probs"] / (mel_lengths.sum() + text_lengths.sum())) + + # for printing useful statistics on terminal + loss_dict.update(self._training_stats(batch)) + return outputs, loss_dict + + def eval_step(self, batch: Dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def _format_aux_input(self, aux_input: Dict, default_input_dict): + """Set missing fields to their default value. + + Args: + aux_inputs (Dict): Dictionary containing the auxiliary inputs. + """ + default_input_dict = default_input_dict.copy() + default_input_dict.update( + { + "sampling_temp": self.sampling_temp, + "max_sampling_time": self.max_sampling_time, + "duration_threshold": self.duration_threshold, + } + ) + if aux_input: + return format_aux_input(default_input_dict, aux_input) + return default_input_dict + + @torch.no_grad() + def inference( + self, + text: torch.Tensor, + aux_input={"x_lengths": None, "sampling_temp": None, "max_sampling_time": None, "duration_threshold": None}, + ): # pylint: disable=dangerous-default-value + """Sampling from the model + + Args: + text (torch.Tensor): :math:`[B, T_in]` + aux_inputs (_type_, optional): _description_. Defaults to None. + + Returns: + outputs: Dictionary containing the following + - mel (torch.Tensor): :math:`[B, T_out, C]` + - hmm_outputs_len (torch.Tensor): :math:`[B]` + - state_travelled (List[List[int]]): List of lists containing the state travelled for each sample in the batch. + - input_parameters (list[torch.FloatTensor]): Input parameters to the neural HMM. + - output_parameters (list[torch.FloatTensor]): Output parameters to the neural HMM. 
+ """ + default_input_dict = { + "x_lengths": torch.sum(text != 0, dim=1), + } + aux_input = self._format_aux_input(aux_input, default_input_dict) + encoder_outputs, encoder_output_len = self.encoder.inference(text, aux_input["x_lengths"]) + outputs = self.neural_hmm.inference( + encoder_outputs, + encoder_output_len, + sampling_temp=aux_input["sampling_temp"], + max_sampling_time=aux_input["max_sampling_time"], + duration_threshold=aux_input["duration_threshold"], + ) + + mels, mel_outputs_len, _ = self.decoder( + outputs["hmm_outputs"].transpose(1, 2), outputs["hmm_outputs_len"], reverse=True + ) + mels = self.inverse_normalize(mels.transpose(1, 2)) + outputs.update({"model_outputs": mels, "model_outputs_len": mel_outputs_len}) + outputs["alignments"] = OverflowUtils.double_pad(outputs["alignments"]) + return outputs + + @staticmethod + def get_criterion(): + return NLLLoss() + + @staticmethod + def init_from_config(config: "OverFlowConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + """Initiate model from config + + Args: + config (VitsConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + verbose (bool): If True, print init messages. Defaults to True. + """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config, verbose) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return Overflow(new_config, ap, tokenizer, speaker_manager) + + def load_checkpoint( + self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + self.decoder.store_inverse() + assert not self.training + + def on_init_start(self, trainer): + """If the current dataset does not have normalisation statistics and initialisation transition_probability it computes them otherwise loads.""" + if not os.path.isfile(trainer.config.mel_statistics_parameter_path) or trainer.config.force_generate_statistics: + dataloader = trainer.get_train_dataloader( + training_assets=None, samples=trainer.train_samples, verbose=False + ) + print( + f" | > Data parameters not found for: {trainer.config.mel_statistics_parameter_path}. Computing mel normalization parameters..." + ) + data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start( + dataloader, trainer.config.out_channels, trainer.config.state_per_phone + ) + print( + f" | > Saving data parameters to: {trainer.config.mel_statistics_parameter_path}: value: {data_mean, data_std, init_transition_prob}" + ) + statistics = { + "mean": data_mean.item(), + "std": data_std.item(), + "init_transition_prob": init_transition_prob.item(), + } + torch.save(statistics, trainer.config.mel_statistics_parameter_path) + + else: + print( + f" | > Data parameters found for: {trainer.config.mel_statistics_parameter_path}. Loading mel normalization parameters..." 
+ ) + statistics = torch.load(trainer.config.mel_statistics_parameter_path) + data_mean, data_std, init_transition_prob = ( + statistics["mean"], + statistics["std"], + statistics["init_transition_prob"], + ) + print(f" | > Data parameters loaded with value: {data_mean, data_std, init_transition_prob}") + + trainer.config.flat_start_params["transition_p"] = ( + init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob + ) + OverflowUtils.update_flat_start_transition(trainer.model, init_transition_prob) + trainer.model.update_mean_std(statistics) + + @torch.inference_mode() + def _create_logs(self, batch, outputs, ap): # pylint: disable=no-self-use, unused-argument + alignments, transition_vectors = outputs["alignments"], outputs["transition_vectors"] + means = torch.stack(outputs["means"], dim=1) + + figures = { + "alignment": plot_alignment(alignments[0].exp(), title="Forward alignment", fig_size=(20, 20)), + "log_alignment": plot_alignment( + alignments[0].exp(), title="Forward log alignment", plot_log=True, fig_size=(20, 20) + ), + "transition_vectors": plot_alignment(transition_vectors[0], title="Transition vectors", fig_size=(20, 20)), + "mel_from_most_probable_state": plot_spectrogram( + get_spec_from_most_probable_state(alignments[0], means[0], self.decoder), fig_size=(12, 3) + ), + "mel_target": plot_spectrogram(batch["mel_input"][0], fig_size=(12, 3)), + } + + # sample one item from the batch -1 will give the smalles item + print(" | > Synthesising audio from the model...") + inference_output = self.inference( + batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)} + ) + figures["synthesised"] = plot_spectrogram(inference_output["model_outputs"][0], fig_size=(12, 3)) + + states = [p[1] for p in inference_output["input_parameters"][0]] + transition_probability_synthesising = [p[2].cpu().numpy() for p in inference_output["output_parameters"][0]] + + for i in range((len(transition_probability_synthesising) // 200) + 1): + start = i * 200 + end = (i + 1) * 200 + figures[f"synthesised_transition_probabilities/{i}"] = plot_transition_probabilities_to_numpy( + states[start:end], transition_probability_synthesising[start:end] + ) + + audio = ap.inv_melspectrogram(inference_output["model_outputs"][0].T.cpu().numpy()) + return figures, {"audios": audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ): # pylint: disable=unused-argument + """Log training progress.""" + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_log( + self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int + ): # pylint: disable=unused-argument + """Compute and log evaluation metrics.""" + # Plot model parameters histograms + if isinstance(logger, TensorboardLogger): + # I don't know if any other loggers supports this + for tag, value in self.named_parameters(): + tag = tag.replace(".", "/") + logger.writer.add_histogram(tag, value.data.cpu().numpy(), steps) + + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: + logger.test_audios(steps, outputs[1], self.ap.sample_rate) + logger.test_figures(steps, outputs[0]) 
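+
+# `Overflow.get_criterion` returns the `NLLLoss` module defined below. In `train_step`
+# the summed HMM log-likelihood (which already includes the flow log-determinant, see
+# `forward`) is divided by the total number of mel frames and text tokens, so the value
+# reaching the criterion is a per-frame/per-token log-likelihood rather than a raw
+# sequence sum.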
+ + +class NLLLoss(nn.Module): + """Negative log likelihood loss.""" + + def forward(self, log_prob: torch.Tensor) -> dict: # pylint: disable=no-self-use + """Compute the loss. + + Args: + logits (Tensor): [B, T, D] + + Returns: + Tensor: [1] + + """ + return_dict = {} + return_dict["loss"] = -log_prob.mean() + return return_dict diff --git a/viXTTS/TTS/tts/models/tacotron.py b/viXTTS/TTS/tts/models/tacotron.py new file mode 100644 index 0000000000000000000000000000000000000000..474ec4641d0a569fc1938442ab9f7ce4bb980119 --- /dev/null +++ b/viXTTS/TTS/tts/models/tacotron.py @@ -0,0 +1,409 @@ +# coding: utf-8 + +from typing import Dict, List, Tuple, Union + +import torch +from torch import nn +from torch.cuda.amp.autocast_mode import autocast +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE +from TTS.tts.layers.tacotron.gst_layers import GST +from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG +from TTS.tts.models.base_tacotron import BaseTacotron +from TTS.tts.utils.measures import alignment_diagonal_score +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.capacitron_optimizer import CapacitronOptimizer + + +class Tacotron(BaseTacotron): + """Tacotron as in https://arxiv.org/abs/1703.10135 + It's an autoregressive encoder-attention-decoder-postnet architecture. + Check `TacotronConfig` for the arguments. + + Args: + config (TacotronConfig): Configuration for the Tacotron model. + speaker_manager (SpeakerManager): Speaker manager to handle multi-speaker settings. Only use if the model is + a multi-speaker model. Defaults to None. + """ + + def __init__( + self, + config: "TacotronConfig", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + + # pass all config fields to `self` + # for fewer code change + for key in config: + setattr(self, key, config[key]) + + # set speaker embedding channel size for determining `in_channels` for the connected layers. + # `init_multispeaker` needs to be called once more in training to initialize the speaker embedding layer based + # on the number of speakers infered from the dataset. 
+ if self.use_speaker_embedding or self.use_d_vector_file: + self.init_multispeaker(config) + self.decoder_in_features += self.embedded_speaker_dim # add speaker embedding dim + + if self.use_gst: + self.decoder_in_features += self.gst.gst_embedding_dim + + if self.use_capacitron_vae: + self.decoder_in_features += self.capacitron_vae.capacitron_VAE_embedding_dim + + # embedding layer + self.embedding = nn.Embedding(self.num_chars, 256, padding_idx=0) + self.embedding.weight.data.normal_(0, 0.3) + + # base model layers + self.encoder = Encoder(self.encoder_in_features) + self.decoder = Decoder( + self.decoder_in_features, + self.decoder_output_dim, + self.r, + self.memory_size, + self.attention_type, + self.windowing, + self.attention_norm, + self.prenet_type, + self.prenet_dropout, + self.use_forward_attn, + self.transition_agent, + self.forward_attn_mask, + self.location_attn, + self.attention_heads, + self.separate_stopnet, + self.max_decoder_steps, + ) + self.postnet = PostCBHG(self.decoder_output_dim) + self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, self.out_channels) + + # setup prenet dropout + self.decoder.prenet.dropout_at_inference = self.prenet_dropout_at_inference + + # global style token layers + if self.gst and self.use_gst: + self.gst_layer = GST( + num_mel=self.decoder_output_dim, + num_heads=self.gst.gst_num_heads, + num_style_tokens=self.gst.gst_num_style_tokens, + gst_embedding_dim=self.gst.gst_embedding_dim, + ) + + # Capacitron layers + if self.capacitron_vae and self.use_capacitron_vae: + self.capacitron_vae_layer = CapacitronVAE( + num_mel=self.decoder_output_dim, + encoder_output_dim=self.encoder_in_features, + capacitron_VAE_embedding_dim=self.capacitron_vae.capacitron_VAE_embedding_dim, + speaker_embedding_dim=self.embedded_speaker_dim + if self.use_speaker_embedding and self.capacitron_vae.capacitron_use_speaker_embedding + else None, + text_summary_embedding_dim=self.capacitron_vae.capacitron_text_summary_embedding_dim + if self.capacitron_vae.capacitron_use_text_summary_embeddings + else None, + ) + + # backward pass decoder + if self.bidirectional_decoder: + self._init_backward_decoder() + # setup DDC + if self.double_decoder_consistency: + self.coarse_decoder = Decoder( + self.decoder_in_features, + self.decoder_output_dim, + self.ddc_r, + self.memory_size, + self.attention_type, + self.windowing, + self.attention_norm, + self.prenet_type, + self.prenet_dropout, + self.use_forward_attn, + self.transition_agent, + self.forward_attn_mask, + self.location_attn, + self.attention_heads, + self.separate_stopnet, + self.max_decoder_steps, + ) + + def forward( # pylint: disable=dangerous-default-value + self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input={"speaker_ids": None, "d_vectors": None} + ): + """ + Shapes: + text: [B, T_in] + text_lengths: [B] + mel_specs: [B, T_out, C] + mel_lengths: [B] + aux_input: 'speaker_ids': [B, 1] and 'd_vectors':[B, C] + """ + aux_input = self._format_aux_input(aux_input) + outputs = {"alignments_backward": None, "decoder_outputs_backward": None} + inputs = self.embedding(text) + input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths) + # B x T_in x encoder_in_features + encoder_outputs = self.encoder(inputs) + # sequence masking + encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs) + # global style token + if self.gst and self.use_gst: + # B x gst_dim + encoder_outputs = self.compute_gst(encoder_outputs, mel_specs) + # speaker embedding + if 
self.use_speaker_embedding or self.use_d_vector_file: + if not self.use_d_vector_file: + # B x 1 x speaker_embed_dim + embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None] + else: + # B x 1 x speaker_embed_dim + embedded_speakers = torch.unsqueeze(aux_input["d_vectors"], 1) + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers) + # Capacitron + if self.capacitron_vae and self.use_capacitron_vae: + # B x capacitron_VAE_embedding_dim + encoder_outputs, *capacitron_vae_outputs = self.compute_capacitron_VAE_embedding( + encoder_outputs, + reference_mel_info=[mel_specs, mel_lengths], + text_info=[inputs, text_lengths] + if self.capacitron_vae.capacitron_use_text_summary_embeddings + else None, + speaker_embedding=embedded_speakers if self.capacitron_vae.capacitron_use_speaker_embedding else None, + ) + else: + capacitron_vae_outputs = None + # decoder_outputs: B x decoder_in_features x T_out + # alignments: B x T_in x encoder_in_features + # stop_tokens: B x T_in + decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask) + # sequence masking + if output_mask is not None: + decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs) + # B x T_out x decoder_in_features + postnet_outputs = self.postnet(decoder_outputs) + # sequence masking + if output_mask is not None: + postnet_outputs = postnet_outputs * output_mask.unsqueeze(2).expand_as(postnet_outputs) + # B x T_out x posnet_dim + postnet_outputs = self.last_linear(postnet_outputs) + # B x T_out x decoder_in_features + decoder_outputs = decoder_outputs.transpose(1, 2).contiguous() + if self.bidirectional_decoder: + decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask) + outputs["alignments_backward"] = alignments_backward + outputs["decoder_outputs_backward"] = decoder_outputs_backward + if self.double_decoder_consistency: + decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass( + mel_specs, encoder_outputs, alignments, input_mask + ) + outputs["alignments_backward"] = alignments_backward + outputs["decoder_outputs_backward"] = decoder_outputs_backward + outputs.update( + { + "model_outputs": postnet_outputs, + "decoder_outputs": decoder_outputs, + "alignments": alignments, + "stop_tokens": stop_tokens, + "capacitron_vae_outputs": capacitron_vae_outputs, + } + ) + return outputs + + @torch.no_grad() + def inference(self, text_input, aux_input=None): + aux_input = self._format_aux_input(aux_input) + inputs = self.embedding(text_input) + encoder_outputs = self.encoder(inputs) + if self.gst and self.use_gst: + # B x gst_dim + encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"]) + if self.capacitron_vae and self.use_capacitron_vae: + if aux_input["style_text"] is not None: + style_text_embedding = self.embedding(aux_input["style_text"]) + style_text_length = torch.tensor([style_text_embedding.size(1)], dtype=torch.int64).to( + encoder_outputs.device + ) # pylint: disable=not-callable + reference_mel_length = ( + torch.tensor([aux_input["style_mel"].size(1)], dtype=torch.int64).to(encoder_outputs.device) + if aux_input["style_mel"] is not None + else None + ) # pylint: disable=not-callable + # B x capacitron_VAE_embedding_dim + encoder_outputs, *_ = self.compute_capacitron_VAE_embedding( + encoder_outputs, + reference_mel_info=[aux_input["style_mel"], reference_mel_length] + if aux_input["style_mel"] is not 
None + else None, + text_info=[style_text_embedding, style_text_length] if aux_input["style_text"] is not None else None, + speaker_embedding=aux_input["d_vectors"] + if self.capacitron_vae.capacitron_use_speaker_embedding + else None, + ) + if self.num_speakers > 1: + if not self.use_d_vector_file: + # B x 1 x speaker_embed_dim + embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"]) + # reshape embedded_speakers + if embedded_speakers.ndim == 1: + embedded_speakers = embedded_speakers[None, None, :] + elif embedded_speakers.ndim == 2: + embedded_speakers = embedded_speakers[None, :] + else: + # B x 1 x speaker_embed_dim + embedded_speakers = torch.unsqueeze(aux_input["d_vectors"], 1) + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers) + decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs) + postnet_outputs = self.postnet(decoder_outputs) + postnet_outputs = self.last_linear(postnet_outputs) + decoder_outputs = decoder_outputs.transpose(1, 2) + outputs = { + "model_outputs": postnet_outputs, + "decoder_outputs": decoder_outputs, + "alignments": alignments, + "stop_tokens": stop_tokens, + } + return outputs + + def before_backward_pass(self, loss_dict, optimizer) -> None: + # Extracting custom training specific operations for capacitron + # from the trainer + if self.use_capacitron_vae: + loss_dict["capacitron_vae_beta_loss"].backward() + optimizer.first_step() + + def train_step(self, batch: Dict, criterion: torch.nn.Module) -> Tuple[Dict, Dict]: + """Perform a single training step by fetching the right set of samples from the batch. + + Args: + batch ([Dict]): A dictionary of input tensors. + criterion ([torch.nn.Module]): Callable criterion to compute model loss. + """ + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + linear_input = batch["linear_input"] + stop_targets = batch["stop_targets"] + stop_target_lengths = batch["stop_target_lengths"] + speaker_ids = batch["speaker_ids"] + d_vectors = batch["d_vectors"] + + aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors} + outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input) + + # set the [alignment] lengths wrt reduction factor for guided attention + if mel_lengths.max() % self.decoder.r != 0: + alignment_lengths = ( + mel_lengths + (self.decoder.r - (mel_lengths.max() % self.decoder.r)) + ) // self.decoder.r + else: + alignment_lengths = mel_lengths // self.decoder.r + + # compute loss + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion( + outputs["model_outputs"].float(), + outputs["decoder_outputs"].float(), + mel_input.float(), + linear_input.float(), + outputs["stop_tokens"].float(), + stop_targets.float(), + stop_target_lengths, + outputs["capacitron_vae_outputs"] if self.capacitron_vae else None, + mel_lengths, + None if outputs["decoder_outputs_backward"] is None else outputs["decoder_outputs_backward"].float(), + outputs["alignments"].float(), + alignment_lengths, + None if outputs["alignments_backward"] is None else outputs["alignments_backward"].float(), + text_lengths, + ) + + # compute alignment error (the lower the better ) + align_error = 1 - alignment_diagonal_score(outputs["alignments"]) + loss_dict["align_error"] = align_error + return outputs, loss_dict + + def get_optimizer(self) -> List: + if self.use_capacitron_vae: + return CapacitronOptimizer(self.config, 
self.named_parameters()) + return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self) + + def get_scheduler(self, optimizer: object): + opt = optimizer.primary_optimizer if self.use_capacitron_vae else optimizer + return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, opt) + + def before_gradient_clipping(self): + if self.use_capacitron_vae: + # Capacitron model specific gradient clipping + model_params_to_clip = [] + for name, param in self.named_parameters(): + if param.requires_grad: + if name != "capacitron_vae_layer.beta": + model_params_to_clip.append(param) + torch.nn.utils.clip_grad_norm_(model_params_to_clip, self.capacitron_vae.capacitron_grad_clip) + + def _create_logs(self, batch, outputs, ap): + postnet_outputs = outputs["model_outputs"] + decoder_outputs = outputs["decoder_outputs"] + alignments = outputs["alignments"] + alignments_backward = outputs["alignments_backward"] + mel_input = batch["mel_input"] + linear_input = batch["linear_input"] + + pred_linear_spec = postnet_outputs[0].data.cpu().numpy() + pred_mel_spec = decoder_outputs[0].data.cpu().numpy() + gt_linear_spec = linear_input[0].data.cpu().numpy() + gt_mel_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "pred_linear_spec": plot_spectrogram(pred_linear_spec, ap, output_fig=False), + "real_linear_spec": plot_spectrogram(gt_linear_spec, ap, output_fig=False), + "pred_mel_spec": plot_spectrogram(pred_mel_spec, ap, output_fig=False), + "real_mel_spec": plot_spectrogram(gt_mel_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + if self.bidirectional_decoder or self.double_decoder_consistency: + figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False) + + # Sample audio + audio = ap.inv_spectrogram(pred_linear_spec.T) + return figures, {"audio": audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ) -> None: # pylint: disable=no-self-use + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_step(self, batch: dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + @staticmethod + def init_from_config(config: "TacotronConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (TacotronConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. 
+        """
+        from TTS.utils.audio import AudioProcessor
+
+        ap = AudioProcessor.init_from_config(config)
+        tokenizer, new_config = TTSTokenizer.init_from_config(config)
+        speaker_manager = SpeakerManager.init_from_config(config, samples)
+        return Tacotron(new_config, ap, tokenizer, speaker_manager)
diff --git a/viXTTS/TTS/tts/models/tacotron2.py b/viXTTS/TTS/tts/models/tacotron2.py
new file mode 100644
index 0000000000000000000000000000000000000000..71ab1eac37aa70900a795cf8aa3df7a9ce77c49c
--- /dev/null
+++ b/viXTTS/TTS/tts/models/tacotron2.py
@@ -0,0 +1,433 @@
+# coding: utf-8
+
+from typing import Dict, List, Union
+
+import torch
+from torch import nn
+from torch.cuda.amp.autocast_mode import autocast
+from trainer.trainer_utils import get_optimizer, get_scheduler
+
+from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE
+from TTS.tts.layers.tacotron.gst_layers import GST
+from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
+from TTS.tts.models.base_tacotron import BaseTacotron
+from TTS.tts.utils.measures import alignment_diagonal_score
+from TTS.tts.utils.speakers import SpeakerManager
+from TTS.tts.utils.text.tokenizer import TTSTokenizer
+from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
+from TTS.utils.capacitron_optimizer import CapacitronOptimizer
+
+
+class Tacotron2(BaseTacotron):
+    """Tacotron2 model implementation inherited from :class:`TTS.tts.models.base_tacotron.BaseTacotron`.
+
+    Paper::
+        https://arxiv.org/abs/1712.05884
+
+    Paper abstract::
+        This paper describes Tacotron 2, a neural network architecture for speech synthesis directly from text.
+        The system is composed of a recurrent sequence-to-sequence feature prediction network that maps character
+        embeddings to mel-scale spectrograms, followed by a modified WaveNet model acting as a vocoder to synthesize
+        time-domain waveforms from those spectrograms. Our model achieves a mean opinion score (MOS) of 4.53 comparable
+        to a MOS of 4.58 for professionally recorded speech. To validate our design choices, we present ablation
+        studies of key components of our system and evaluate the impact of using mel spectrograms as the input to
+        WaveNet instead of linguistic, duration, and F0 features. We further demonstrate that using a compact acoustic
+        intermediate representation enables significant simplification of the WaveNet architecture.
+
+    Check :class:`TTS.tts.configs.tacotron2_config.Tacotron2Config` for model arguments.
+
+    Args:
+        config (Tacotron2Config):
+            Configuration for the Tacotron2 model.
+        speaker_manager (SpeakerManager):
+            Speaker manager for multi-speaker training. Use only for multi-speaker training. Defaults to None.
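+
+    Examples:
+        Illustrative usage sketch (added for documentation; it mirrors the ``init_from_config`` helper defined
+        below and assumes default ``Tacotron2Config`` values rather than a tuned training recipe):
+
+        >>> from TTS.tts.configs.tacotron2_config import Tacotron2Config
+        >>> from TTS.tts.models.tacotron2 import Tacotron2
+        >>> config = Tacotron2Config()
+        >>> model = Tacotron2.init_from_config(config)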
+ """ + + def __init__( + self, + config: "Tacotron2Config", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + + self.decoder_output_dim = config.out_channels + + # pass all config fields to `self` + # for fewer code change + for key in config: + setattr(self, key, config[key]) + + # init multi-speaker layers + if self.use_speaker_embedding or self.use_d_vector_file: + self.init_multispeaker(config) + self.decoder_in_features += self.embedded_speaker_dim # add speaker embedding dim + + if self.use_gst: + self.decoder_in_features += self.gst.gst_embedding_dim + + if self.use_capacitron_vae: + self.decoder_in_features += self.capacitron_vae.capacitron_VAE_embedding_dim + + # embedding layer + self.embedding = nn.Embedding(self.num_chars, 512, padding_idx=0) + + # base model layers + self.encoder = Encoder(self.encoder_in_features) + + self.decoder = Decoder( + self.decoder_in_features, + self.decoder_output_dim, + self.r, + self.attention_type, + self.attention_win, + self.attention_norm, + self.prenet_type, + self.prenet_dropout, + self.use_forward_attn, + self.transition_agent, + self.forward_attn_mask, + self.location_attn, + self.attention_heads, + self.separate_stopnet, + self.max_decoder_steps, + ) + self.postnet = Postnet(self.out_channels) + + # setup prenet dropout + self.decoder.prenet.dropout_at_inference = self.prenet_dropout_at_inference + + # global style token layers + if self.gst and self.use_gst: + self.gst_layer = GST( + num_mel=self.decoder_output_dim, + num_heads=self.gst.gst_num_heads, + num_style_tokens=self.gst.gst_num_style_tokens, + gst_embedding_dim=self.gst.gst_embedding_dim, + ) + + # Capacitron VAE Layers + if self.capacitron_vae and self.use_capacitron_vae: + self.capacitron_vae_layer = CapacitronVAE( + num_mel=self.decoder_output_dim, + encoder_output_dim=self.encoder_in_features, + capacitron_VAE_embedding_dim=self.capacitron_vae.capacitron_VAE_embedding_dim, + speaker_embedding_dim=self.embedded_speaker_dim + if self.capacitron_vae.capacitron_use_speaker_embedding + else None, + text_summary_embedding_dim=self.capacitron_vae.capacitron_text_summary_embedding_dim + if self.capacitron_vae.capacitron_use_text_summary_embeddings + else None, + ) + + # backward pass decoder + if self.bidirectional_decoder: + self._init_backward_decoder() + # setup DDC + if self.double_decoder_consistency: + self.coarse_decoder = Decoder( + self.decoder_in_features, + self.decoder_output_dim, + self.ddc_r, + self.attention_type, + self.attention_win, + self.attention_norm, + self.prenet_type, + self.prenet_dropout, + self.use_forward_attn, + self.transition_agent, + self.forward_attn_mask, + self.location_attn, + self.attention_heads, + self.separate_stopnet, + self.max_decoder_steps, + ) + + @staticmethod + def shape_outputs(mel_outputs, mel_outputs_postnet, alignments): + """Final reshape of the model output tensors.""" + mel_outputs = mel_outputs.transpose(1, 2) + mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2) + return mel_outputs, mel_outputs_postnet, alignments + + def forward( # pylint: disable=dangerous-default-value + self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input={"speaker_ids": None, "d_vectors": None} + ): + """Forward pass for training with Teacher Forcing. 
+ + Shapes: + text: :math:`[B, T_in]` + text_lengths: :math:`[B]` + mel_specs: :math:`[B, T_out, C]` + mel_lengths: :math:`[B]` + aux_input: 'speaker_ids': :math:`[B, 1]` and 'd_vectors': :math:`[B, C]` + """ + aux_input = self._format_aux_input(aux_input) + outputs = {"alignments_backward": None, "decoder_outputs_backward": None} + # compute mask for padding + # B x T_in_max (boolean) + input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths) + # B x D_embed x T_in_max + embedded_inputs = self.embedding(text).transpose(1, 2) + # B x T_in_max x D_en + encoder_outputs = self.encoder(embedded_inputs, text_lengths) + if self.gst and self.use_gst: + # B x gst_dim + encoder_outputs = self.compute_gst(encoder_outputs, mel_specs) + + if self.use_speaker_embedding or self.use_d_vector_file: + if not self.use_d_vector_file: + # B x 1 x speaker_embed_dim + embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None] + else: + # B x 1 x speaker_embed_dim + embedded_speakers = torch.unsqueeze(aux_input["d_vectors"], 1) + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers) + + # capacitron + if self.capacitron_vae and self.use_capacitron_vae: + # B x capacitron_VAE_embedding_dim + encoder_outputs, *capacitron_vae_outputs = self.compute_capacitron_VAE_embedding( + encoder_outputs, + reference_mel_info=[mel_specs, mel_lengths], + text_info=[embedded_inputs.transpose(1, 2), text_lengths] + if self.capacitron_vae.capacitron_use_text_summary_embeddings + else None, + speaker_embedding=embedded_speakers if self.capacitron_vae.capacitron_use_speaker_embedding else None, + ) + else: + capacitron_vae_outputs = None + + encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs) + + # B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r + decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask) + # sequence masking + if mel_lengths is not None: + decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs) + # B x mel_dim x T_out + postnet_outputs = self.postnet(decoder_outputs) + postnet_outputs = decoder_outputs + postnet_outputs + # sequence masking + if output_mask is not None: + postnet_outputs = postnet_outputs * output_mask.unsqueeze(1).expand_as(postnet_outputs) + # B x T_out x mel_dim -- B x T_out x mel_dim -- B x T_out//r x T_in + decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments) + if self.bidirectional_decoder: + decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask) + outputs["alignments_backward"] = alignments_backward + outputs["decoder_outputs_backward"] = decoder_outputs_backward + if self.double_decoder_consistency: + decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass( + mel_specs, encoder_outputs, alignments, input_mask + ) + outputs["alignments_backward"] = alignments_backward + outputs["decoder_outputs_backward"] = decoder_outputs_backward + outputs.update( + { + "model_outputs": postnet_outputs, + "decoder_outputs": decoder_outputs, + "alignments": alignments, + "stop_tokens": stop_tokens, + "capacitron_vae_outputs": capacitron_vae_outputs, + } + ) + return outputs + + @torch.no_grad() + def inference(self, text, aux_input=None): + """Forward pass for inference with no Teacher-Forcing. 
+ + Shapes: + text: :math:`[B, T_in]` + text_lengths: :math:`[B]` + """ + aux_input = self._format_aux_input(aux_input) + embedded_inputs = self.embedding(text).transpose(1, 2) + encoder_outputs = self.encoder.inference(embedded_inputs) + + if self.gst and self.use_gst: + # B x gst_dim + encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"]) + + if self.capacitron_vae and self.use_capacitron_vae: + if aux_input["style_text"] is not None: + style_text_embedding = self.embedding(aux_input["style_text"]) + style_text_length = torch.tensor([style_text_embedding.size(1)], dtype=torch.int64).to( + encoder_outputs.device + ) # pylint: disable=not-callable + reference_mel_length = ( + torch.tensor([aux_input["style_mel"].size(1)], dtype=torch.int64).to(encoder_outputs.device) + if aux_input["style_mel"] is not None + else None + ) # pylint: disable=not-callable + # B x capacitron_VAE_embedding_dim + encoder_outputs, *_ = self.compute_capacitron_VAE_embedding( + encoder_outputs, + reference_mel_info=[aux_input["style_mel"], reference_mel_length] + if aux_input["style_mel"] is not None + else None, + text_info=[style_text_embedding, style_text_length] if aux_input["style_text"] is not None else None, + speaker_embedding=aux_input["d_vectors"] + if self.capacitron_vae.capacitron_use_speaker_embedding + else None, + ) + + if self.num_speakers > 1: + if not self.use_d_vector_file: + embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[None] + # reshape embedded_speakers + if embedded_speakers.ndim == 1: + embedded_speakers = embedded_speakers[None, None, :] + elif embedded_speakers.ndim == 2: + embedded_speakers = embedded_speakers[None, :] + else: + embedded_speakers = aux_input["d_vectors"] + + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers) + + decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs) + postnet_outputs = self.postnet(decoder_outputs) + postnet_outputs = decoder_outputs + postnet_outputs + decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments) + outputs = { + "model_outputs": postnet_outputs, + "decoder_outputs": decoder_outputs, + "alignments": alignments, + "stop_tokens": stop_tokens, + } + return outputs + + def before_backward_pass(self, loss_dict, optimizer) -> None: + # Extracting custom training specific operations for capacitron + # from the trainer + if self.use_capacitron_vae: + loss_dict["capacitron_vae_beta_loss"].backward() + optimizer.first_step() + + def train_step(self, batch: Dict, criterion: torch.nn.Module): + """A single training step. Forward pass and loss computation. + + Args: + batch ([Dict]): A dictionary of input tensors. + criterion ([type]): Callable criterion to compute model loss. 
+ """ + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + stop_targets = batch["stop_targets"] + stop_target_lengths = batch["stop_target_lengths"] + speaker_ids = batch["speaker_ids"] + d_vectors = batch["d_vectors"] + + aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors} + outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input) + + # set the [alignment] lengths wrt reduction factor for guided attention + if mel_lengths.max() % self.decoder.r != 0: + alignment_lengths = ( + mel_lengths + (self.decoder.r - (mel_lengths.max() % self.decoder.r)) + ) // self.decoder.r + else: + alignment_lengths = mel_lengths // self.decoder.r + + # compute loss + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion( + outputs["model_outputs"].float(), + outputs["decoder_outputs"].float(), + mel_input.float(), + None, + outputs["stop_tokens"].float(), + stop_targets.float(), + stop_target_lengths, + outputs["capacitron_vae_outputs"] if self.capacitron_vae else None, + mel_lengths, + None if outputs["decoder_outputs_backward"] is None else outputs["decoder_outputs_backward"].float(), + outputs["alignments"].float(), + alignment_lengths, + None if outputs["alignments_backward"] is None else outputs["alignments_backward"].float(), + text_lengths, + ) + + # compute alignment error (the lower the better ) + align_error = 1 - alignment_diagonal_score(outputs["alignments"]) + loss_dict["align_error"] = align_error + return outputs, loss_dict + + def get_optimizer(self) -> List: + if self.use_capacitron_vae: + return CapacitronOptimizer(self.config, self.named_parameters()) + return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self) + + def get_scheduler(self, optimizer: object): + opt = optimizer.primary_optimizer if self.use_capacitron_vae else optimizer + return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, opt) + + def before_gradient_clipping(self): + if self.use_capacitron_vae: + # Capacitron model specific gradient clipping + model_params_to_clip = [] + for name, param in self.named_parameters(): + if param.requires_grad: + if name != "capacitron_vae_layer.beta": + model_params_to_clip.append(param) + torch.nn.utils.clip_grad_norm_(model_params_to_clip, self.capacitron_vae.capacitron_grad_clip) + + def _create_logs(self, batch, outputs, ap): + """Create dashboard log information.""" + postnet_outputs = outputs["model_outputs"] + alignments = outputs["alignments"] + alignments_backward = outputs["alignments_backward"] + mel_input = batch["mel_input"] + + pred_spec = postnet_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + if self.bidirectional_decoder or self.double_decoder_consistency: + figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False) + + # Sample audio + audio = ap.inv_melspectrogram(pred_spec.T) + return figures, {"audio": audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ) -> None: # pylint: disable=no-self-use + """Log training progress.""" + figures, audios = 
self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_step(self, batch: dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + @staticmethod + def init_from_config(config: "Tacotron2Config", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (Tacotron2Config): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(new_config, samples) + return Tacotron2(new_config, ap, tokenizer, speaker_manager) diff --git a/viXTTS/TTS/tts/models/tortoise.py b/viXTTS/TTS/tts/models/tortoise.py new file mode 100644 index 0000000000000000000000000000000000000000..16644ff95eee6799f5e78603e2011f63b05a1011 --- /dev/null +++ b/viXTTS/TTS/tts/models/tortoise.py @@ -0,0 +1,911 @@ +import os +import random +from contextlib import contextmanager +from dataclasses import dataclass +from time import time + +import torch +import torch.nn.functional as F +import torchaudio +from coqpit import Coqpit +from tqdm import tqdm + +from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram +from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, load_voice, wav_to_univnet_mel +from TTS.tts.layers.tortoise.autoregressive import UnifiedVoice +from TTS.tts.layers.tortoise.classifier import AudioMiniEncoderWithClassifierHead +from TTS.tts.layers.tortoise.clvp import CLVP +from TTS.tts.layers.tortoise.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps +from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts +from TTS.tts.layers.tortoise.random_latent_generator import RandomLatentConverter +from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer +from TTS.tts.layers.tortoise.vocoder import VocConf, VocType +from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment +from TTS.tts.models.base_tts import BaseTTS + + +def pad_or_truncate(t, length): + """ + Utility function for forcing to have the specified sequence length, whether by clipping it or padding it with 0s. + """ + tp = t[..., :length] + if t.shape[-1] == length: + tp = t + elif t.shape[-1] < length: + tp = F.pad(t, (0, length - t.shape[-1])) + return tp + + +def deterministic_state(seed=None): + """ + Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be + reproduced. + """ + seed = int(time()) if seed is None else seed + torch.manual_seed(seed) + random.seed(seed) + # Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary. + # torch.use_deterministic_algorithms(True) + + return seed + + +def load_discrete_vocoder_diffuser( + trained_diffusion_steps=4000, + desired_diffusion_steps=200, + cond_free=True, + cond_free_k=1, + sampler="ddim", +): + """ + Helper function to load a GaussianDiffusion instance configured for use as a vocoder. 
+ """ + return SpacedDiffusion( + use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), + model_mean_type="epsilon", + model_var_type="learned_range", + loss_type="mse", + betas=get_named_beta_schedule("linear", trained_diffusion_steps), + conditioning_free=cond_free, + conditioning_free_k=cond_free_k, + sampler=sampler, + ) + + +def format_conditioning(clip, cond_length=132300, device="cuda", **kwargs): + """ + Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models. + """ + gap = clip.shape[-1] - cond_length + if gap < 0: + clip = F.pad(clip, pad=(0, abs(gap))) + elif gap > 0: + rand_start = random.randint(0, gap) + clip = clip[:, rand_start : rand_start + cond_length] + mel_clip = TorchMelSpectrogram(**kwargs)(clip.unsqueeze(0)).squeeze(0) + return mel_clip.unsqueeze(0).to(device) + + +def fix_autoregressive_output(codes, stop_token, complain=True): + """ + This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was + trained on and what the autoregressive code generator creates (which has no padding or end). + This is highly specific to the DVAE being used, so this particular coding will not necessarily work if used with + a different DVAE. This can be inferred by feeding a audio clip padded with lots of zeros on the end through the DVAE + and copying out the last few codes. + + Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar. + """ + # Strip off the autoregressive stop token and add padding. + stop_token_indices = (codes == stop_token).nonzero() + if len(stop_token_indices) == 0: + if complain: + print( + "No stop tokens found in one of the generated voice clips. This typically means the spoken audio is " + "too long. In some cases, the output will still be good, though. Listen to it and if it is missing words, " + "try breaking up your input text." + ) + return codes + codes[stop_token_indices] = 83 + stm = stop_token_indices.min().item() + codes[stm:] = 83 + if stm - 3 < codes.shape[0]: + codes[-3] = 45 + codes[-2] = 45 + codes[-1] = 248 + return codes + + +def do_spectrogram_diffusion( + diffusion_model, + diffuser, + latents, + conditioning_latents, + temperature=1, + verbose=True, +): + """ + Uses the specified diffusion model to convert discrete codes into a spectrogram. + """ + with torch.no_grad(): + output_seq_len = ( + latents.shape[1] * 4 * 24000 // 22050 + ) # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal. + output_shape = (latents.shape[0], 100, output_seq_len) + precomputed_embeddings = diffusion_model.timestep_independent( + latents, conditioning_latents, output_seq_len, False + ) + + noise = torch.randn(output_shape, device=latents.device) * temperature + mel = diffuser.sample_loop( + diffusion_model, + output_shape, + noise=noise, + model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings}, + progress=verbose, + ) + return denormalize_tacotron_mel(mel)[:, :, :output_seq_len] + + +def classify_audio_clip(clip, model_dir): + """ + Returns whether or not Tortoises' classifier thinks the given clip came from Tortoise. + :param clip: torch tensor containing audio waveform data (get it from load_audio) + :return: True if the clip was classified as coming from Tortoise and false if it was classified as real. 
+ """ + classifier = AudioMiniEncoderWithClassifierHead( + 2, + spec_dim=1, + embedding_dim=512, + depth=5, + downsample_factor=4, + resnet_blocks=2, + attn_blocks=4, + num_attn_heads=4, + base_channels=32, + dropout=0, + kernel_size=5, + distribute_zero_label=False, + ) + classifier.load_state_dict(torch.load(os.path.join(model_dir, "classifier.pth"), map_location=torch.device("cpu"))) + clip = clip.cpu().unsqueeze(0) + results = F.softmax(classifier(clip), dim=-1) + return results[0][0] + + +def pick_best_batch_size_for_gpu(): + """ + Tries to pick a batch size that will fit in your GPU. These sizes aren't guaranteed to work, but they should give + you a good shot. + """ + if torch.cuda.is_available(): + _, available = torch.cuda.mem_get_info() + availableGb = available / (1024**3) + batch_size = 1 + if availableGb > 14: + batch_size = 16 + elif availableGb > 10: + batch_size = 8 + elif availableGb > 7: + batch_size = 4 + return batch_size + + +@dataclass +class TortoiseAudioConfig(Coqpit): + sample_rate: int = 22050 + diffusion_sample_rate: int = 24000 + output_sample_rate: int = 24000 + + +@dataclass +class TortoiseArgs(Coqpit): + """A dataclass to represent Tortoise model arguments that define the model structure. + + Args: + autoregressive_batch_size (int): The size of the auto-regressive batch. + enable_redaction (bool, optional): Whether to enable redaction. Defaults to True. + high_vram (bool, optional): Whether to use high VRAM. Defaults to False. + kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True. + ar_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None. + clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None. + diff_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None. + num_chars (int, optional): The maximum number of characters to generate. Defaults to 255. + vocoder (VocType, optional): The vocoder to use for synthesis. Defaults to VocConf.Univnet. + + For UnifiedVoice model: + ar_max_mel_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604. + ar_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402. + ar_max_conditioning_inputs (int, optional): The maximum conditioning inputs for the autoregressive model. Defaults to 2. + ar_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30. + ar_model_dim (int, optional): The model dimension for the autoregressive model. Defaults to 1024. + ar_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16. + ar_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to 255. + ar_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to 255. + ar_checkpointing (bool, optional): Whether to use checkpointing for the autoregressive model. Defaults to False. + ar_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False. + + For DiffTTS model: + diff_model_channels (int, optional): The number of channels for the DiffTTS model. Defaults to 1024. + diff_num_layers (int, optional): The number of layers for the DiffTTS model. Defaults to 10. + diff_in_channels (int, optional): The input channels for the DiffTTS model. Defaults to 100. 
+ diff_out_channels (int, optional): The output channels for the DiffTTS model. Defaults to 200. + diff_in_latent_channels (int, optional): The input latent channels for the DiffTTS model. Defaults to 1024. + diff_in_tokens (int, optional): The input tokens for the DiffTTS model. Defaults to 8193. + diff_dropout (int, optional): The dropout percentage for the DiffTTS model. Defaults to 0. + diff_use_fp16 (bool, optional): Whether to use fp16 for the DiffTTS model. Defaults to False. + diff_num_heads (int, optional): The number of heads for the DiffTTS model. Defaults to 16. + diff_layer_drop (int, optional): The layer dropout percentage for the DiffTTS model. Defaults to 0. + diff_unconditioned_percentage (int, optional): The percentage of unconditioned inputs for the DiffTTS model. Defaults to 0. + + For ConditionalLatentVariablePerseq model: + clvp_dim_text (int): The dimension of the text input for the CLVP module. Defaults to 768. + clvp_dim_speech (int): The dimension of the speech input for the CLVP module. Defaults to 768. + clvp_dim_latent (int): The dimension of the latent representation for the CLVP module. Defaults to 768. + clvp_num_text_tokens (int): The number of text tokens used by the CLVP module. Defaults to 256. + clvp_text_enc_depth (int): The depth of the text encoder in the CLVP module. Defaults to 20. + clvp_text_seq_len (int): The maximum sequence length of the text input for the CLVP module. Defaults to 350. + clvp_text_heads (int): The number of attention heads used by the text encoder in the CLVP module. Defaults to 12. + clvp_num_speech_tokens (int): The number of speech tokens used by the CLVP module. Defaults to 8192. + clvp_speech_enc_depth (int): The depth of the speech encoder in the CLVP module. Defaults to 20. + clvp_speech_heads (int): The number of attention heads used by the speech encoder in the CLVP module. Defaults to 12. + clvp_speech_seq_len (int): The maximum sequence length of the speech input for the CLVP module. Defaults to 430. + clvp_use_xformers (bool): A flag indicating whether the model uses transformers in the CLVP module. Defaults to True. + duration_const (int): A constant value used in the model. Defaults to 102400. 
+ """ + + autoregressive_batch_size: int = 1 + enable_redaction: bool = False + high_vram: bool = False + kv_cache: bool = True + ar_checkpoint: str = None + clvp_checkpoint: str = None + diff_checkpoint: str = None + num_chars: int = 255 + vocoder: VocType = VocConf.Univnet + + # UnifiedVoice params + ar_max_mel_tokens: int = 604 + ar_max_text_tokens: int = 402 + ar_max_conditioning_inputs: int = 2 + ar_layers: int = 30 + ar_model_dim: int = 1024 + ar_heads: int = 16 + ar_number_text_tokens: int = 255 + ar_start_text_token: int = 255 + ar_checkpointing: bool = False + ar_train_solo_embeddings: bool = False + + # DiffTTS params + diff_model_channels: int = 1024 + diff_num_layers: int = 10 + diff_in_channels: int = 100 + diff_out_channels: int = 200 + diff_in_latent_channels: int = 1024 + diff_in_tokens: int = 8193 + diff_dropout: int = 0 + diff_use_fp16: bool = False + diff_num_heads: int = 16 + diff_layer_drop: int = 0 + diff_unconditioned_percentage: int = 0 + + # clvp params + clvp_dim_text: int = 768 + clvp_dim_speech: int = 768 + clvp_dim_latent: int = 768 + clvp_num_text_tokens: int = 256 + clvp_text_enc_depth: int = 20 + clvp_text_seq_len: int = 350 + clvp_text_heads: int = 12 + clvp_num_speech_tokens: int = 8192 + clvp_speech_enc_depth: int = 20 + clvp_speech_heads: int = 12 + clvp_speech_seq_len: int = 430 + clvp_use_xformers: bool = True + # constants + duration_const: int = 102400 + + +class Tortoise(BaseTTS): + """Tortoise model class. + + Currently only supports inference. + + Examples: + >>> from TTS.tts.configs.tortoise_config import TortoiseConfig + >>> from TTS.tts.models.tortoise import Tortoise + >>> config = TortoiseConfig() + >>> model = Tortoise.inif_from_config(config) + >>> model.load_checkpoint(config, checkpoint_dir="paths/to/models_dir/", eval=True) + """ + + def __init__(self, config: Coqpit): + super().__init__(config, ap=None, tokenizer=None) + self.mel_norm_path = None + self.config = config + self.ar_checkpoint = self.args.ar_checkpoint + self.diff_checkpoint = self.args.diff_checkpoint # TODO: check if this is even needed + self.models_dir = config.model_dir + self.autoregressive_batch_size = ( + pick_best_batch_size_for_gpu() + if self.args.autoregressive_batch_size is None + else self.args.autoregressive_batch_size + ) + self.enable_redaction = self.args.enable_redaction + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if self.enable_redaction: + self.aligner = Wav2VecAlignment() + + self.tokenizer = VoiceBpeTokenizer() + + self.autoregressive = UnifiedVoice( + max_mel_tokens=self.args.ar_max_mel_tokens, + max_text_tokens=self.args.ar_max_text_tokens, + max_conditioning_inputs=self.args.ar_max_conditioning_inputs, + layers=self.args.ar_layers, + model_dim=self.args.ar_model_dim, + heads=self.args.ar_heads, + number_text_tokens=self.args.ar_number_text_tokens, + start_text_token=self.args.ar_start_text_token, + checkpointing=self.args.ar_checkpointing, + train_solo_embeddings=self.args.ar_train_solo_embeddings, + ).cpu() + + self.diffusion = DiffusionTts( + model_channels=self.args.diff_model_channels, + num_layers=self.args.diff_num_layers, + in_channels=self.args.diff_in_channels, + out_channels=self.args.diff_out_channels, + in_latent_channels=self.args.diff_in_latent_channels, + in_tokens=self.args.diff_in_tokens, + dropout=self.args.diff_dropout, + use_fp16=self.args.diff_use_fp16, + num_heads=self.args.diff_num_heads, + layer_drop=self.args.diff_layer_drop, + 
unconditioned_percentage=self.args.diff_unconditioned_percentage, + ).cpu() + + self.clvp = CLVP( + dim_text=self.args.clvp_dim_text, + dim_speech=self.args.clvp_dim_speech, + dim_latent=self.args.clvp_dim_latent, + num_text_tokens=self.args.clvp_num_text_tokens, + text_enc_depth=self.args.clvp_text_enc_depth, + text_seq_len=self.args.clvp_text_seq_len, + text_heads=self.args.clvp_text_heads, + num_speech_tokens=self.args.clvp_num_speech_tokens, + speech_enc_depth=self.args.clvp_speech_enc_depth, + speech_heads=self.args.clvp_speech_heads, + speech_seq_len=self.args.clvp_speech_seq_len, + use_xformers=self.args.clvp_use_xformers, + ).cpu() + + self.vocoder = self.args.vocoder.value.constructor().cpu() + + # Random latent generators (RLGs) are loaded lazily. + self.rlg_auto = None + self.rlg_diffusion = None + + if self.args.high_vram: + self.autoregressive = self.autoregressive.to(self.device) + self.diffusion = self.diffusion.to(self.device) + self.clvp = self.clvp.to(self.device) + self.vocoder = self.vocoder.to(self.device) + self.high_vram = self.args.high_vram + + @contextmanager + def temporary_cuda(self, model): + if self.high_vram: + yield model + else: + m = model.to(self.device) + yield m + m = model.cpu() + + def get_conditioning_latents( + self, + voice_samples, + return_mels=False, + latent_averaging_mode=0, + original_tortoise=False, + ): + """ + Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent). + These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic + properties. + :param voice_samples: List of arbitrary reference clips, which should be *pairs* of torch tensors containing arbitrary kHz waveform data. + :param latent_averaging_mode: 0/1/2 for following modes: + 0 - latents will be generated as in original tortoise, using ~4.27s from each voice sample, averaging latent across all samples + 1 - latents will be generated using (almost) entire voice samples, averaged across all the ~4.27s chunks + 2 - latents will be generated using (almost) entire voice samples, averaged per voice sample + """ + assert latent_averaging_mode in [ + 0, + 1, + 2, + ], "latent_averaging mode has to be one of (0, 1, 2)" + + with torch.no_grad(): + voice_samples = [[v.to(self.device) for v in ls] for ls in voice_samples] + + auto_conds = [] + for ls in voice_samples: + auto_conds.append(format_conditioning(ls[0], device=self.device, mel_norm_file=self.mel_norm_path)) + auto_conds = torch.stack(auto_conds, dim=1) + with self.temporary_cuda(self.autoregressive) as ar: + auto_latent = ar.get_conditioning(auto_conds) + + diffusion_conds = [] + + DURS_CONST = self.args.duration_const + for ls in voice_samples: + # The diffuser operates at a sample rate of 24000 (except for the latent inputs) + sample = torchaudio.functional.resample(ls[0], 22050, 24000) if original_tortoise else ls[1] + if latent_averaging_mode == 0: + sample = pad_or_truncate(sample, DURS_CONST) + cond_mel = wav_to_univnet_mel( + sample.to(self.device), + do_normalization=False, + device=self.device, + ) + diffusion_conds.append(cond_mel) + else: + from math import ceil + + if latent_averaging_mode == 2: + temp_diffusion_conds = [] + for chunk in range(ceil(sample.shape[1] / DURS_CONST)): + current_sample = sample[:, chunk * DURS_CONST : (chunk + 1) * DURS_CONST] + current_sample = pad_or_truncate(current_sample, DURS_CONST) + cond_mel = wav_to_univnet_mel( + current_sample.to(self.device), + 
do_normalization=False, + device=self.device, + ) + if latent_averaging_mode == 1: + diffusion_conds.append(cond_mel) + elif latent_averaging_mode == 2: + temp_diffusion_conds.append(cond_mel) + if latent_averaging_mode == 2: + diffusion_conds.append(torch.stack(temp_diffusion_conds).mean(0)) + diffusion_conds = torch.stack(diffusion_conds, dim=1) + + with self.temporary_cuda(self.diffusion) as diffusion: + diffusion_latent = diffusion.get_conditioning(diffusion_conds) + + if return_mels: + return auto_latent, diffusion_latent, auto_conds, diffusion_conds + return auto_latent, diffusion_latent + + def get_random_conditioning_latents(self): + # Lazy-load the RLG models. + if self.rlg_auto is None: + self.rlg_auto = RandomLatentConverter(1024).eval() + self.rlg_auto.load_state_dict( + torch.load( + os.path.join(self.models_dir, "rlg_auto.pth"), + map_location=torch.device("cpu"), + ) + ) + self.rlg_diffusion = RandomLatentConverter(2048).eval() + self.rlg_diffusion.load_state_dict( + torch.load( + os.path.join(self.models_dir, "rlg_diffuser.pth"), + map_location=torch.device("cpu"), + ) + ) + with torch.no_grad(): + return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0])) + + def synthesize(self, text, config, speaker_id="random", voice_dirs=None, **kwargs): + """Synthesize speech with the given input text. + + Args: + text (str): Input text. + config (TortoiseConfig): Config with inference parameters. + speaker_id (str): One of the available speaker names. If `random`, it generates a random speaker. + voice_dirs (List[str]): List of paths that host reference audio files for speakers. Defaults to None. + **kwargs: Inference settings. See `inference()`. + + Returns: + A dictionary of the output values with `wav` as output waveform, `deterministic_seed` as seed used at inference, + `text_input` as text token IDs after tokenizer, `voice_samples` as samples used for cloning, `conditioning_latents` + as latents used at inference. + + """ + + speaker_id = "random" if speaker_id is None else speaker_id + + if voice_dirs is not None: + voice_dirs = [voice_dirs] + voice_samples, conditioning_latents = load_voice(speaker_id, voice_dirs) + + else: + voice_samples, conditioning_latents = load_voice(speaker_id) + + outputs = self.inference_with_config( + text, config, voice_samples=voice_samples, conditioning_latents=conditioning_latents, **kwargs + ) + + return_dict = { + "wav": outputs["wav"], + "deterministic_seed": outputs["deterministic_seed"], + "text_inputs": outputs["text"], + "voice_samples": outputs["voice_samples"], + "conditioning_latents": outputs["conditioning_latents"], + } + + return return_dict + + def inference_with_config(self, text, config, **kwargs): + """ + inference with config + #TODO describe in detail + """ + # Use generally found best tuning knobs for generation. + settings = { + "temperature": config.temperature, + "length_penalty": config.length_penalty, + "repetition_penalty": config.repetition_penalty, + "top_p": config.top_p, + "cond_free_k": config.cond_free_k, + "diffusion_temperature": config.diffusion_temperature, + "sampler": config.sampler, + } + # Presets are defined here. 
+ presets = { + "single_sample": { + "num_autoregressive_samples": 8, + "diffusion_iterations": 10, + "sampler": "ddim", + }, + "ultra_fast": { + "num_autoregressive_samples": 16, + "diffusion_iterations": 10, + "sampler": "ddim", + }, + "ultra_fast_old": { + "num_autoregressive_samples": 16, + "diffusion_iterations": 30, + "cond_free": False, + }, + "very_fast": { + "num_autoregressive_samples": 32, + "diffusion_iterations": 30, + "sampler": "dpm++2m", + }, + "fast": { + "num_autoregressive_samples": 5, + "diffusion_iterations": 50, + "sampler": "ddim", + }, + "fast_old": {"num_autoregressive_samples": 96, "diffusion_iterations": 80}, + "standard": { + "num_autoregressive_samples": 5, + "diffusion_iterations": 200, + }, + "high_quality": { + "num_autoregressive_samples": 256, + "diffusion_iterations": 400, + }, + } + if "preset" in kwargs: + settings.update(presets[kwargs["preset"]]) + kwargs.pop("preset") + settings.update(kwargs) # allow overriding of preset settings with kwargs + return self.inference(text, **settings) + + def inference( + self, + text, + voice_samples=None, + conditioning_latents=None, + k=1, + verbose=True, + use_deterministic_seed=None, + return_deterministic_state=False, + latent_averaging_mode=0, + # autoregressive generation parameters follow + num_autoregressive_samples=16, + temperature=0.8, + length_penalty=1, + repetition_penalty=2.0, + top_p=0.8, + max_mel_tokens=500, + # diffusion generation parameters follow + diffusion_iterations=100, + cond_free=True, + cond_free_k=2, + diffusion_temperature=1.0, + sampler="ddim", + half=True, + original_tortoise=False, + **hf_generate_kwargs, + ): + """ + This function produces an audio clip of the given text being spoken with the given reference voice. + + Args: + text: (str) Text to be spoken. + voice_samples: (List[Tuple[torch.Tensor]]) List of an arbitrary number of reference clips, which should be tuple-pairs + of torch tensors containing arbitrary kHz waveform data. + conditioning_latents: (Tuple[autoregressive_conditioning_latent, diffusion_conditioning_latent]) A tuple of + (autoregressive_conditioning_latent, diffusion_conditioning_latent), which can be provided in lieu + of voice_samples. This is ignored unless `voice_samples=None`. Conditioning latents can be retrieved + via `get_conditioning_latents()`. + k: (int) The number of returned clips. The most likely (as determined by Tortoises' CLVP model) clips are returned. + latent_averaging_mode: (int) 0/1/2 for following modes: + 0 - latents will be generated as in original tortoise, using ~4.27s from each voice sample, averaging latent across all samples + 1 - latents will be generated using (almost) entire voice samples, averaged across all the ~4.27s chunks + 2 - latents will be generated using (almost) entire voice samples, averaged per voice sample + verbose: (bool) Whether or not to print log messages indicating the progress of creating a clip. Default=true. + num_autoregressive_samples: (int) Number of samples taken from the autoregressive model, all of which are filtered using CLVP. + As Tortoise is a probabilistic model, more samples means a higher probability of creating something "great". + temperature: (float) The softmax temperature of the autoregressive model. + length_penalty: (float) A length penalty applied to the autoregressive decoder. Higher settings causes the model to produce more terse outputs. + repetition_penalty: (float) A penalty that prevents the autoregressive decoder from repeating itself during decoding. 
Can be used to reduce + the incidence of long silences or "uhhhhhhs", etc. + top_p: (float) P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely" (aka boring) outputs. + max_mel_tokens: (int) Restricts the output length. (0,600] integer. Each unit is 1/20 of a second. + typical_sampling: (bool) Turns typical sampling on or off. This sampling mode is discussed in this paper: https://arxiv.org/abs/2202.00666 + I was interested in the premise, but the results were not as good as I was hoping. This is off by default, but could use some tuning. + typical_mass: (float) The typical_mass parameter from the typical_sampling algorithm. + diffusion_iterations: (int) Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively + refine the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better, however. + cond_free: (bool) Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion performs two forward passes for + each diffusion step: one with the outputs of the autoregressive model and one with no conditioning priors. The output of the two + is blended according to the cond_free_k value below. Conditioning-free diffusion is the real deal, and dramatically improves realism. + cond_free_k: (float) Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf]. + As cond_free_k increases, the output becomes dominated by the conditioning-free signal. + diffusion_temperature: (float) Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0 + are the "mean" prediction of the diffusion network and will sound bland and smeared. + hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive transformer. + Extra keyword args fed to this function get forwarded directly to that API. Documentation + here: https://huggingface.co/docs/transformers/internal/generation_utils + + Returns: + Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length. + Sample rate is 24kHz. + """ + deterministic_seed = deterministic_state(seed=use_deterministic_seed) + + text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device) + text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. + assert ( + text_tokens.shape[-1] < 400 + ), "Too much text provided. Break the text up into separate segments and re-try inference." 
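For orientation, the sketch below shows how this `inference()` path is usually reached from user code, via `synthesize()` and `inference_with_config()` with one of the presets defined above. It is illustrative only: the checkpoint directory is a placeholder, and it assumes default `TortoiseConfig` values together with the `ultra_fast` preset.

    # Illustrative sketch only; checkpoint_dir is a placeholder path.
    from TTS.tts.configs.tortoise_config import TortoiseConfig
    from TTS.tts.models.tortoise import Tortoise

    config = TortoiseConfig()
    model = Tortoise.init_from_config(config)
    model.load_checkpoint(config, checkpoint_dir="/path/to/tortoise_models/", eval=True)

    # synthesize() resolves the speaker to reference clips (or random latents) and forwards
    # extra kwargs such as `preset` to inference_with_config(), which expands the preset into
    # num_autoregressive_samples / diffusion_iterations / sampler settings before calling inference().
    outputs = model.synthesize(
        "Hello from Tortoise.",
        config,
        speaker_id="random",
        preset="ultra_fast",
    )
    wav = outputs["wav"]  # 24 kHz waveform tensor (see the Returns note in the docstring above)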
+ + if voice_samples is not None: + ( + auto_conditioning, + diffusion_conditioning, + _, + _, + ) = self.get_conditioning_latents( + voice_samples, + return_mels=True, + latent_averaging_mode=latent_averaging_mode, + original_tortoise=original_tortoise, + ) + elif conditioning_latents is not None: + auto_conditioning, diffusion_conditioning = conditioning_latents + else: + ( + auto_conditioning, + diffusion_conditioning, + ) = self.get_random_conditioning_latents() + auto_conditioning = auto_conditioning.to(self.device) + diffusion_conditioning = diffusion_conditioning.to(self.device) + + diffuser = load_discrete_vocoder_diffuser( + desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k, sampler=sampler + ) + + # in the case of single_sample, + orig_batch_size = self.autoregressive_batch_size + while num_autoregressive_samples % self.autoregressive_batch_size: + self.autoregressive_batch_size //= 2 + with torch.no_grad(): + samples = [] + num_batches = num_autoregressive_samples // self.autoregressive_batch_size + stop_mel_token = self.autoregressive.stop_mel_token + calm_token = ( + 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output" + ) + self.autoregressive = self.autoregressive.to(self.device) + if verbose: + print("Generating autoregressive samples..") + with self.temporary_cuda(self.autoregressive) as autoregressive, torch.autocast( + device_type="cuda", dtype=torch.float16, enabled=half + ): + for b in tqdm(range(num_batches), disable=not verbose): + codes = autoregressive.inference_speech( + auto_conditioning, + text_tokens, + do_sample=True, + top_p=top_p, + temperature=temperature, + num_return_sequences=self.autoregressive_batch_size, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + max_generate_length=max_mel_tokens, + **hf_generate_kwargs, + ) + padding_needed = max_mel_tokens - codes.shape[1] + codes = F.pad(codes, (0, padding_needed), value=stop_mel_token) + samples.append(codes) + self.autoregressive_batch_size = orig_batch_size # in the case of single_sample + + clip_results = [] + with self.temporary_cuda(self.clvp) as clvp, torch.autocast( + device_type="cuda", dtype=torch.float16, enabled=half + ): + for batch in tqdm(samples, disable=not verbose): + for i in range(batch.shape[0]): + batch[i] = fix_autoregressive_output(batch[i], stop_mel_token) + clvp_res = clvp( + text_tokens.repeat(batch.shape[0], 1), + batch, + return_loss=False, + ) + clip_results.append(clvp_res) + + clip_results = torch.cat(clip_results, dim=0) + samples = torch.cat(samples, dim=0) + best_results = samples[torch.topk(clip_results, k=k).indices] + del samples + + # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning + # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these + # results, but will increase memory usage. 
+ with self.temporary_cuda(self.autoregressive) as autoregressive: + best_latents = autoregressive( + auto_conditioning.repeat(k, 1), + text_tokens.repeat(k, 1), + torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), + best_results, + torch.tensor( + [best_results.shape[-1] * self.autoregressive.mel_length_compression], + device=text_tokens.device, + ), + return_latent=True, + clip_inputs=False, + ) + del auto_conditioning + + if verbose: + print("Transforming autoregressive outputs into audio..") + wav_candidates = [] + for b in range(best_results.shape[0]): + codes = best_results[b].unsqueeze(0) + latents = best_latents[b].unsqueeze(0) + + # Find the first occurrence of the "calm" token and trim the codes to that. + ctokens = 0 + for code in range(codes.shape[-1]): + if codes[0, code] == calm_token: + ctokens += 1 + else: + ctokens = 0 + if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech. + latents = latents[:, :code] + break + with self.temporary_cuda(self.diffusion) as diffusion: + mel = do_spectrogram_diffusion( + diffusion, + diffuser, + latents, + diffusion_conditioning, + temperature=diffusion_temperature, + verbose=verbose, + ) + with self.temporary_cuda(self.vocoder) as vocoder: + wav = vocoder.inference(mel) + wav_candidates.append(wav.cpu()) + + def potentially_redact(clip, text): + if self.enable_redaction: + return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1) + return clip + + wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates] + + if len(wav_candidates) > 1: + res = wav_candidates + else: + res = wav_candidates[0] + + return_dict = { + "wav": res, + "deterministic_seed": None, + "text": None, + "voice_samples": None, + "conditioning_latents": None, + } + if return_deterministic_state: + return_dict = { + "wav": res, + "deterministic_seed": deterministic_seed, + "text": text, + "voice_samples": voice_samples, + "conditioning_latents": conditioning_latents, + } + return return_dict + + def forward(self): + raise NotImplementedError("Tortoise Training is not implemented") + + def eval_step(self): + raise NotImplementedError("Tortoise Training is not implemented") + + @staticmethod + def init_from_config(config: "TortoiseConfig", **kwargs): # pylint: disable=unused-argument + return Tortoise(config) + + def load_checkpoint( + self, + config, + checkpoint_dir, + ar_checkpoint_path=None, + diff_checkpoint_path=None, + clvp_checkpoint_path=None, + vocoder_checkpoint_path=None, + eval=False, + strict=True, + **kwargs, + ): # pylint: disable=unused-argument, redefined-builtin + """Load a model checkpoints from a directory. This model is with multiple checkpoint files and it + expects to have all the files to be under the given `checkpoint_dir` with the rigth names. + If eval is True, set the model to eval mode. + + Args: + config (TortoiseConfig): The model config. + checkpoint_dir (str): The directory where the checkpoints are stored. + ar_checkpoint_path (str, optional): The path to the autoregressive checkpoint. Defaults to None. + diff_checkpoint_path (str, optional): The path to the diffusion checkpoint. Defaults to None. + clvp_checkpoint_path (str, optional): The path to the CLVP checkpoint. Defaults to None. + vocoder_checkpoint_path (str, optional): The path to the vocoder checkpoint. Defaults to None. + eval (bool, optional): Whether to set the model to eval mode. Defaults to False. + strict (bool, optional): Whether to load the model strictly. Defaults to True. 
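+
+        Example:
+            A minimal sketch, assuming `config` is a `TortoiseConfig` instance and the (hypothetical)
+            directory contains `autoregressive.pth`, `diffusion_decoder.pth`, `clvp2.pth`,
+            `vocoder.pth` and `mel_norms.pth`:
+
+            >>> model = Tortoise.init_from_config(config)
+            >>> model.load_checkpoint(config, checkpoint_dir="/path/to/tortoise_checkpoints/", eval=True)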
+ """ + if self.models_dir is None: + self.models_dir = checkpoint_dir + ar_path = ar_checkpoint_path or os.path.join(checkpoint_dir, "autoregressive.pth") + diff_path = diff_checkpoint_path or os.path.join(checkpoint_dir, "diffusion_decoder.pth") + clvp_path = clvp_checkpoint_path or os.path.join(checkpoint_dir, "clvp2.pth") + vocoder_checkpoint_path = vocoder_checkpoint_path or os.path.join(checkpoint_dir, "vocoder.pth") + self.mel_norm_path = os.path.join(checkpoint_dir, "mel_norms.pth") + + if os.path.exists(ar_path): + # remove keys from the checkpoint that are not in the model + checkpoint = torch.load(ar_path, map_location=torch.device("cpu")) + + # strict set False + # due to removed `bias` and `masked_bias` changes in Transformers + self.autoregressive.load_state_dict(checkpoint, strict=False) + + if os.path.exists(diff_path): + self.diffusion.load_state_dict(torch.load(diff_path), strict=strict) + + if os.path.exists(clvp_path): + self.clvp.load_state_dict(torch.load(clvp_path), strict=strict) + + if os.path.exists(vocoder_checkpoint_path): + self.vocoder.load_state_dict( + config.model_args.vocoder.value.optionally_index( + torch.load( + vocoder_checkpoint_path, + map_location=torch.device("cpu"), + ) + ) + ) + + if eval: + self.autoregressive.post_init_gpt2_config(self.args.kv_cache) + self.autoregressive.eval() + self.diffusion.eval() + self.clvp.eval() + self.vocoder.eval() + + def train_step(self): + raise NotImplementedError("Tortoise Training is not implemented") diff --git a/viXTTS/TTS/tts/models/vits.py b/viXTTS/TTS/tts/models/vits.py new file mode 100644 index 0000000000000000000000000000000000000000..d9b1f59618ab7e0ebc4cabc9a9ac40fdb9843d99 --- /dev/null +++ b/viXTTS/TTS/tts/models/vits.py @@ -0,0 +1,1999 @@ +import math +import os +from dataclasses import dataclass, field, replace +from itertools import chain +from typing import Dict, List, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +import torchaudio +from coqpit import Coqpit +from librosa.filters import mel as librosa_mel_fn +from torch import nn +from torch.cuda.amp.autocast_mode import autocast +from torch.nn import functional as F +from torch.utils.data import DataLoader +from torch.utils.data.sampler import WeightedRandomSampler +from trainer.torch import DistributedSampler, DistributedSamplerWrapper +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.tts.configs.shared_configs import CharactersConfig +from TTS.tts.datasets.dataset import TTSDataset, _parse_sample +from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor +from TTS.tts.layers.vits.discriminator import VitsDiscriminator +from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder +from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.fairseq import rehash_fairseq_vits_checkpoint +from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask +from TTS.tts.utils.languages import LanguageManager +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment +from TTS.utils.io import load_fsspec +from TTS.utils.samplers import BucketBatchSampler 
+from TTS.vocoder.models.hifigan_generator import HifiganGenerator +from TTS.vocoder.utils.generic_utils import plot_results + +############################## +# IO / Feature extraction +############################## + +# pylint: disable=global-statement +hann_window = {} +mel_basis = {} + + +@torch.no_grad() +def weights_reset(m: nn.Module): + # check if the current module has reset_parameters and if it is reset the weight + reset_parameters = getattr(m, "reset_parameters", None) + if callable(reset_parameters): + m.reset_parameters() + + +def get_module_weights_sum(mdl: nn.Module): + dict_sums = {} + for name, w in mdl.named_parameters(): + if "weight" in name: + value = w.data.sum().item() + dict_sums[name] = value + return dict_sums + + +def load_audio(file_path): + """Load the audio file normalized in [-1, 1] + + Return Shapes: + - x: :math:`[1, T]` + """ + x, sr = torchaudio.load(file_path) + assert (x > 1).sum() + (x < -1).sum() == 0 + return x, sr + + +def _amp_to_db(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def _db_to_amp(x, C=1): + return torch.exp(x) / C + + +def amp_to_db(magnitudes): + output = _amp_to_db(magnitudes) + return output + + +def db_to_amp(magnitudes): + output = _db_to_amp(magnitudes) + return output + + +def wav_to_spec(y, n_fft, hop_length, win_length, center=False): + """ + Args Shapes: + - y : :math:`[B, 1, T]` + + Return Shapes: + - spec : :math:`[B,C,T]` + """ + y = y.squeeze(1) + + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_length) + "_" + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax): + """ + Args Shapes: + - spec : :math:`[B,C,T]` + + Return Shapes: + - mel : :math:`[B,C,T]` + """ + global mel_basis + dtype_device = str(spec.dtype) + "_" + str(spec.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + mel = torch.matmul(mel_basis[fmax_dtype_device], spec) + mel = amp_to_db(mel) + return mel + + +def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False): + """ + Args Shapes: + - y : :math:`[B, 1, T]` + + Return Shapes: + - spec : :math:`[B,C,T]` + """ + y = y.squeeze(1) + + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + wnsize_dtype_device = str(win_length) + "_" + dtype_device + if fmax_dtype_device not in 
mel_basis: + mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = amp_to_db(spec) + return spec + + +############################# +# CONFIGS +############################# + + +@dataclass +class VitsAudioConfig(Coqpit): + fft_size: int = 1024 + sample_rate: int = 22050 + win_length: int = 1024 + hop_length: int = 256 + num_mels: int = 80 + mel_fmin: int = 0 + mel_fmax: int = None + + +############################## +# DATASET +############################## + + +def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None): + """Create inverse frequency weights for balancing the dataset. + Use `multi_dict` to scale relative weights.""" + attr_names_samples = np.array([item[attr_name] for item in items]) + unique_attr_names = np.unique(attr_names_samples).tolist() + attr_idx = [unique_attr_names.index(l) for l in attr_names_samples] + attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names]) + weight_attr = 1.0 / attr_count + dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx]) + dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) + if multi_dict is not None: + # check if all keys are in the multi_dict + for k in multi_dict: + assert k in unique_attr_names, f"{k} not in {unique_attr_names}" + # scale weights + multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items]) + dataset_samples_weight *= multiplier_samples + return ( + torch.from_numpy(dataset_samples_weight).float(), + unique_attr_names, + np.unique(dataset_samples_weight).tolist(), + ) + + +class VitsDataset(TTSDataset): + def __init__(self, model_args, *args, **kwargs): + super().__init__(*args, **kwargs) + self.pad_id = self.tokenizer.characters.pad_id + self.model_args = model_args + + def __getitem__(self, idx): + item = self.samples[idx] + raw_text = item["text"] + + wav, _ = load_audio(item["audio_file"]) + if self.model_args.encoder_sample_rate is not None: + if wav.size(1) % self.model_args.encoder_sample_rate != 0: + wav = wav[:, : -int(wav.size(1) % self.model_args.encoder_sample_rate)] + + wav_filename = os.path.basename(item["audio_file"]) + + token_ids = self.get_token_ids(idx, item["text"]) + + # after phonemization the text length may change + # this is a shameful 🤭 hack to prevent longer phonemes + # TODO: find a better fix + if len(token_ids) > self.max_text_len or wav.shape[1] < self.min_audio_len: + self.rescue_item_idx += 1 + return self.__getitem__(self.rescue_item_idx) + + return { + "raw_text": raw_text, + "token_ids": token_ids, + "token_len": len(token_ids), + "wav": wav, + "wav_file": wav_filename, + "speaker_name": item["speaker_name"], + "language_name": item["language"], + 
"audio_unique_name": item["audio_unique_name"], + } + + @property + def lengths(self): + lens = [] + for item in self.samples: + _, wav_file, *_ = _parse_sample(item) + audio_len = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio + lens.append(audio_len) + return lens + + def collate_fn(self, batch): + """ + Return Shapes: + - tokens: :math:`[B, T]` + - token_lens :math:`[B]` + - token_rel_lens :math:`[B]` + - waveform: :math:`[B, 1, T]` + - waveform_lens: :math:`[B]` + - waveform_rel_lens: :math:`[B]` + - speaker_names: :math:`[B]` + - language_names: :math:`[B]` + - audiofile_paths: :math:`[B]` + - raw_texts: :math:`[B]` + - audio_unique_names: :math:`[B]` + """ + # convert list of dicts to dict of lists + B = len(batch) + batch = {k: [dic[k] for dic in batch] for k in batch[0]} + + _, ids_sorted_decreasing = torch.sort( + torch.LongTensor([x.size(1) for x in batch["wav"]]), dim=0, descending=True + ) + + max_text_len = max([len(x) for x in batch["token_ids"]]) + token_lens = torch.LongTensor(batch["token_len"]) + token_rel_lens = token_lens / token_lens.max() + + wav_lens = [w.shape[1] for w in batch["wav"]] + wav_lens = torch.LongTensor(wav_lens) + wav_lens_max = torch.max(wav_lens) + wav_rel_lens = wav_lens / wav_lens_max + + token_padded = torch.LongTensor(B, max_text_len) + wav_padded = torch.FloatTensor(B, 1, wav_lens_max) + token_padded = token_padded.zero_() + self.pad_id + wav_padded = wav_padded.zero_() + self.pad_id + for i in range(len(ids_sorted_decreasing)): + token_ids = batch["token_ids"][i] + token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids) + + wav = batch["wav"][i] + wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav) + + return { + "tokens": token_padded, + "token_lens": token_lens, + "token_rel_lens": token_rel_lens, + "waveform": wav_padded, # (B x T) + "waveform_lens": wav_lens, # (B) + "waveform_rel_lens": wav_rel_lens, + "speaker_names": batch["speaker_name"], + "language_names": batch["language_name"], + "audio_files": batch["wav_file"], + "raw_text": batch["raw_text"], + "audio_unique_names": batch["audio_unique_name"], + } + + +############################## +# MODEL DEFINITION +############################## + + +@dataclass +class VitsArgs(Coqpit): + """VITS model arguments. + + Args: + + num_chars (int): + Number of characters in the vocabulary. Defaults to 100. + + out_channels (int): + Number of output channels of the decoder. Defaults to 513. + + spec_segment_size (int): + Decoder input segment size. Defaults to 32 `(32 * hoplength = waveform length)`. + + hidden_channels (int): + Number of hidden channels of the model. Defaults to 192. + + hidden_channels_ffn_text_encoder (int): + Number of hidden channels of the feed-forward layers of the text encoder transformer. Defaults to 256. + + num_heads_text_encoder (int): + Number of attention heads of the text encoder transformer. Defaults to 2. + + num_layers_text_encoder (int): + Number of transformer layers in the text encoder. Defaults to 6. + + kernel_size_text_encoder (int): + Kernel size of the text encoder transformer FFN layers. Defaults to 3. + + dropout_p_text_encoder (float): + Dropout rate of the text encoder. Defaults to 0.1. + + dropout_p_duration_predictor (float): + Dropout rate of the duration predictor. Defaults to 0.1. + + kernel_size_posterior_encoder (int): + Kernel size of the posterior encoder's WaveNet layers. Defaults to 5. + + dilatation_posterior_encoder (int): + Dilation rate of the posterior encoder's WaveNet layers. Defaults to 1. 
+
+        num_layers_posterior_encoder (int):
+            Number of posterior encoder's WaveNet layers. Defaults to 16.
+
+        kernel_size_flow (int):
+            Kernel size of the Residual Coupling layers of the flow network. Defaults to 5.
+
+        dilation_rate_flow (int):
+            Dilation rate of the Residual Coupling WaveNet layers of the flow network. Defaults to 1.
+
+        num_layers_flow (int):
+            Number of Residual Coupling WaveNet layers of the flow network. Defaults to 4.
+
+        resblock_type_decoder (str):
+            Type of the residual block in the decoder network. Defaults to "1".
+
+        resblock_kernel_sizes_decoder (List[int]):
+            Kernel sizes of the residual blocks in the decoder network. Defaults to `[3, 7, 11]`.
+
+        resblock_dilation_sizes_decoder (List[List[int]]):
+            Dilation sizes of the residual blocks in the decoder network. Defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`.
+
+        upsample_rates_decoder (List[int]):
+            Upsampling rates for each consecutive upsampling layer in the decoder network. The product of these
+            values must be equal to the hop length used for computing spectrograms. Defaults to `[8, 8, 2, 2]`.
+
+        upsample_initial_channel_decoder (int):
+            Number of hidden channels of the first upsampling convolution layer of the decoder network. Defaults to 512.
+
+        upsample_kernel_sizes_decoder (List[int]):
+            Kernel sizes for each upsampling layer of the decoder network. Defaults to `[16, 16, 4, 4]`.
+
+        periods_multi_period_discriminator (List[int]):
+            Period values for the Vits Multi-Period Discriminator. Defaults to `[2, 3, 5, 7, 11]`.
+
+        use_sdp (bool):
+            Use Stochastic Duration Predictor. Defaults to True.
+
+        noise_scale (float):
+            Noise scale used for the sample noise tensor in training. Defaults to 1.0.
+
+        inference_noise_scale (float):
+            Noise scale used for the sample noise tensor in inference. Defaults to 0.667.
+
+        length_scale (float):
+            Scale factor for the predicted duration values. Smaller values result in faster speech. Defaults to 1.
+
+        noise_scale_dp (float):
+            Noise scale used by the Stochastic Duration Predictor sample noise in training. Defaults to 1.0.
+
+        inference_noise_scale_dp (float):
+            Noise scale for the Stochastic Duration Predictor in inference. Defaults to 1.0.
+
+        max_inference_len (int):
+            Maximum inference length to limit the memory use. Defaults to None.
+
+        init_discriminator (bool):
+            Initialize the discriminator network if set True. Set False for inference. Defaults to True.
+
+        use_spectral_norm_disriminator (bool):
+            Use spectral normalization over weight norm in the discriminator. Defaults to False.
+
+        use_speaker_embedding (bool):
+            Enable/Disable speaker embedding for multi-speaker models. Defaults to False.
+
+        num_speakers (int):
+            Number of speakers for the speaker embedding layer. Defaults to 0.
+
+        speakers_file (str):
+            Path to the speaker mapping file for the Speaker Manager. Defaults to None.
+
+        speaker_embedding_channels (int):
+            Number of speaker embedding channels. Defaults to 256.
+
+        use_d_vector_file (bool):
+            Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False.
+
+        d_vector_file (List[str]):
+            List of paths to the files including pre-computed speaker embeddings. Defaults to None.
+
+        d_vector_dim (int):
+            Number of d-vector channels. Defaults to 0.
+
+        detach_dp_input (bool):
+            Detach duration predictor's input from the network for stopping the gradients. Defaults to True.
+
+        use_language_embedding (bool):
+            Enable/Disable language embedding for multilingual models. Defaults to False.
+
+        embedded_language_dim (int):
+            Number of language embedding channels. Defaults to 4.
+
+        num_languages (int):
+            Number of languages for the language embedding layer. Defaults to 0.
+
+        language_ids_file (str):
+            Path to the language mapping file for the Language Manager. Defaults to None.
+
+        use_speaker_encoder_as_loss (bool):
+            Enable/Disable Speaker Consistency Loss (SCL). Defaults to False.
+
+        speaker_encoder_config_path (str):
+            Path to the speaker encoder config file, used for SCL. Defaults to "".
+
+        speaker_encoder_model_path (str):
+            Path to the speaker encoder checkpoint file, used for SCL. Defaults to "".
+
+        condition_dp_on_speaker (bool):
+            Condition the duration predictor on the speaker embedding. Defaults to True.
+
+        freeze_encoder (bool):
+            Freeze the encoder weights during training. Defaults to False.
+
+        freeze_DP (bool):
+            Freeze the duration predictor weights during training. Defaults to False.
+
+        freeze_PE (bool):
+            Freeze the posterior encoder weights during training. Defaults to False.
+
+        freeze_flow_decoder (bool):
+            Freeze the flow decoder weights during training. Defaults to False.
+
+        freeze_waveform_decoder (bool):
+            Freeze the waveform decoder weights during training. Defaults to False.
+
+        encoder_sample_rate (int):
+            If not None, this sample rate will be used for training the Posterior Encoder,
+            flow, text_encoder and duration predictor. The decoder part (vocoder) will be
+            trained with the `config.audio.sample_rate`. Defaults to None.
+
+        interpolate_z (bool):
+            If `encoder_sample_rate` is not None and this parameter is True, nearest interpolation
+            will be used to upsample the latent variable z from `encoder_sample_rate`
+            to `config.audio.sample_rate`. If it is False you will need to add extra
+            `upsample_rates_decoder` values to match the shape. Defaults to True.
+ + """ + + num_chars: int = 100 + out_channels: int = 513 + spec_segment_size: int = 32 + hidden_channels: int = 192 + hidden_channels_ffn_text_encoder: int = 768 + num_heads_text_encoder: int = 2 + num_layers_text_encoder: int = 6 + kernel_size_text_encoder: int = 3 + dropout_p_text_encoder: float = 0.1 + dropout_p_duration_predictor: float = 0.5 + kernel_size_posterior_encoder: int = 5 + dilation_rate_posterior_encoder: int = 1 + num_layers_posterior_encoder: int = 16 + kernel_size_flow: int = 5 + dilation_rate_flow: int = 1 + num_layers_flow: int = 4 + resblock_type_decoder: str = "1" + resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11]) + resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) + upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2]) + upsample_initial_channel_decoder: int = 512 + upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4]) + periods_multi_period_discriminator: List[int] = field(default_factory=lambda: [2, 3, 5, 7, 11]) + use_sdp: bool = True + noise_scale: float = 1.0 + inference_noise_scale: float = 0.667 + length_scale: float = 1 + noise_scale_dp: float = 1.0 + inference_noise_scale_dp: float = 1.0 + max_inference_len: int = None + init_discriminator: bool = True + use_spectral_norm_disriminator: bool = False + use_speaker_embedding: bool = False + num_speakers: int = 0 + speakers_file: str = None + d_vector_file: List[str] = None + speaker_embedding_channels: int = 256 + use_d_vector_file: bool = False + d_vector_dim: int = 0 + detach_dp_input: bool = True + use_language_embedding: bool = False + embedded_language_dim: int = 4 + num_languages: int = 0 + language_ids_file: str = None + use_speaker_encoder_as_loss: bool = False + speaker_encoder_config_path: str = "" + speaker_encoder_model_path: str = "" + condition_dp_on_speaker: bool = True + freeze_encoder: bool = False + freeze_DP: bool = False + freeze_PE: bool = False + freeze_flow_decoder: bool = False + freeze_waveform_decoder: bool = False + encoder_sample_rate: int = None + interpolate_z: bool = True + reinit_DP: bool = False + reinit_text_encoder: bool = False + + +class Vits(BaseTTS): + """VITS TTS model + + Paper:: + https://arxiv.org/pdf/2106.06103.pdf + + Paper Abstract:: + Several recent end-to-end text-to-speech (TTS) models enabling single-stage training and parallel + sampling have been proposed, but their sample quality does not match that of two-stage TTS systems. + In this work, we present a parallel endto-end TTS method that generates more natural sounding audio than + current two-stage models. Our method adopts variational inference augmented with normalizing flows and + an adversarial training process, which improves the expressive power of generative modeling. We also propose a + stochastic duration predictor to synthesize speech with diverse rhythms from input text. With the + uncertainty modeling over latent variables and the stochastic duration predictor, our method expresses the + natural one-to-many relationship in which a text input can be spoken in multiple ways + with different pitches and rhythms. A subjective human evaluation (mean opinion score, or MOS) + on the LJ Speech, a single speaker dataset, shows that our method outperforms the best publicly + available TTS systems and achieves a MOS comparable to ground truth. + + Check :class:`TTS.tts.configs.vits_config.VitsConfig` for class arguments. 
+ + Examples: + >>> from TTS.tts.configs.vits_config import VitsConfig + >>> from TTS.tts.models.vits import Vits + >>> config = VitsConfig() + >>> model = Vits(config) + """ + + def __init__( + self, + config: Coqpit, + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + language_manager: LanguageManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager, language_manager) + + self.init_multispeaker(config) + self.init_multilingual(config) + self.init_upsampling() + + self.length_scale = self.args.length_scale + self.noise_scale = self.args.noise_scale + self.inference_noise_scale = self.args.inference_noise_scale + self.inference_noise_scale_dp = self.args.inference_noise_scale_dp + self.noise_scale_dp = self.args.noise_scale_dp + self.max_inference_len = self.args.max_inference_len + self.spec_segment_size = self.args.spec_segment_size + + self.text_encoder = TextEncoder( + self.args.num_chars, + self.args.hidden_channels, + self.args.hidden_channels, + self.args.hidden_channels_ffn_text_encoder, + self.args.num_heads_text_encoder, + self.args.num_layers_text_encoder, + self.args.kernel_size_text_encoder, + self.args.dropout_p_text_encoder, + language_emb_dim=self.embedded_language_dim, + ) + + self.posterior_encoder = PosteriorEncoder( + self.args.out_channels, + self.args.hidden_channels, + self.args.hidden_channels, + kernel_size=self.args.kernel_size_posterior_encoder, + dilation_rate=self.args.dilation_rate_posterior_encoder, + num_layers=self.args.num_layers_posterior_encoder, + cond_channels=self.embedded_speaker_dim, + ) + + self.flow = ResidualCouplingBlocks( + self.args.hidden_channels, + self.args.hidden_channels, + kernel_size=self.args.kernel_size_flow, + dilation_rate=self.args.dilation_rate_flow, + num_layers=self.args.num_layers_flow, + cond_channels=self.embedded_speaker_dim, + ) + + if self.args.use_sdp: + self.duration_predictor = StochasticDurationPredictor( + self.args.hidden_channels, + 192, + 3, + self.args.dropout_p_duration_predictor, + 4, + cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0, + language_emb_dim=self.embedded_language_dim, + ) + else: + self.duration_predictor = DurationPredictor( + self.args.hidden_channels, + 256, + 3, + self.args.dropout_p_duration_predictor, + cond_channels=self.embedded_speaker_dim, + language_emb_dim=self.embedded_language_dim, + ) + + self.waveform_decoder = HifiganGenerator( + self.args.hidden_channels, + 1, + self.args.resblock_type_decoder, + self.args.resblock_dilation_sizes_decoder, + self.args.resblock_kernel_sizes_decoder, + self.args.upsample_kernel_sizes_decoder, + self.args.upsample_initial_channel_decoder, + self.args.upsample_rates_decoder, + inference_padding=0, + cond_channels=self.embedded_speaker_dim, + conv_pre_weight_norm=False, + conv_post_weight_norm=False, + conv_post_bias=False, + ) + + if self.args.init_discriminator: + self.disc = VitsDiscriminator( + periods=self.args.periods_multi_period_discriminator, + use_spectral_norm=self.args.use_spectral_norm_disriminator, + ) + + @property + def device(self): + return next(self.parameters()).device + + def init_multispeaker(self, config: Coqpit): + """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer + or with external `d_vectors` computed from a speaker encoder model. + + You must provide a `speaker_manager` at initialization to set up the multi-speaker modules. 
+ + Args: + config (Coqpit): Model configuration. + data (List, optional): Dataset items to infer number of speakers. Defaults to None. + """ + self.embedded_speaker_dim = 0 + self.num_speakers = self.args.num_speakers + self.audio_transform = None + + if self.speaker_manager: + self.num_speakers = self.speaker_manager.num_speakers + + if self.args.use_speaker_embedding: + self._init_speaker_embedding() + + if self.args.use_d_vector_file: + self._init_d_vector() + + # TODO: make this a function + if self.args.use_speaker_encoder_as_loss: + if self.speaker_manager.encoder is None and ( + not self.args.speaker_encoder_model_path or not self.args.speaker_encoder_config_path + ): + raise RuntimeError( + " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!" + ) + + self.speaker_manager.encoder.eval() + print(" > External Speaker Encoder Loaded !!") + + if ( + hasattr(self.speaker_manager.encoder, "audio_config") + and self.config.audio.sample_rate != self.speaker_manager.encoder.audio_config["sample_rate"] + ): + self.audio_transform = torchaudio.transforms.Resample( + orig_freq=self.config.audio.sample_rate, + new_freq=self.speaker_manager.encoder.audio_config["sample_rate"], + ) + + def _init_speaker_embedding(self): + # pylint: disable=attribute-defined-outside-init + if self.num_speakers > 0: + print(" > initialization of speaker-embedding layers.") + self.embedded_speaker_dim = self.args.speaker_embedding_channels + self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) + + def _init_d_vector(self): + # pylint: disable=attribute-defined-outside-init + if hasattr(self, "emb_g"): + raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.") + self.embedded_speaker_dim = self.args.d_vector_dim + + def init_multilingual(self, config: Coqpit): + """Initialize multilingual modules of a model. + + Args: + config (Coqpit): Model configuration. + """ + if self.args.language_ids_file is not None: + self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file) + + if self.args.use_language_embedding and self.language_manager: + print(" > initialization of language-embedding layers.") + self.num_languages = self.language_manager.num_languages + self.embedded_language_dim = self.args.embedded_language_dim + self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim) + torch.nn.init.xavier_uniform_(self.emb_l.weight) + else: + self.embedded_language_dim = 0 + + def init_upsampling(self): + """ + Initialize upsampling modules of a model. 
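+
+        If `encoder_sample_rate` is set, this computes `self.interpolate_factor` as
+        `config.audio.sample_rate / encoder_sample_rate` and creates a `torchaudio` resampler that
+        downsamples input waveforms to the encoder rate; the latent `z` is later upsampled back to
+        the output rate (see `upsampling_z`) before the waveform decoder.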
+ """ + if self.args.encoder_sample_rate: + self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate + self.audio_resampler = torchaudio.transforms.Resample( + orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate + ) # pylint: disable=W0201 + + def on_epoch_start(self, trainer): # pylint: disable=W0613 + """Freeze layers at the beginning of an epoch""" + self._freeze_layers() + # set the device of speaker encoder + if self.args.use_speaker_encoder_as_loss: + self.speaker_manager.encoder = self.speaker_manager.encoder.to(self.device) + + def on_init_end(self, trainer): # pylint: disable=W0613 + """Reinit layes if needed""" + if self.args.reinit_DP: + before_dict = get_module_weights_sum(self.duration_predictor) + # Applies weights_reset recursively to every submodule of the duration predictor + self.duration_predictor.apply(fn=weights_reset) + after_dict = get_module_weights_sum(self.duration_predictor) + for key, value in after_dict.items(): + if value == before_dict[key]: + raise RuntimeError(" [!] The weights of Duration Predictor was not reinit check it !") + print(" > Duration Predictor was reinit.") + + if self.args.reinit_text_encoder: + before_dict = get_module_weights_sum(self.text_encoder) + # Applies weights_reset recursively to every submodule of the duration predictor + self.text_encoder.apply(fn=weights_reset) + after_dict = get_module_weights_sum(self.text_encoder) + for key, value in after_dict.items(): + if value == before_dict[key]: + raise RuntimeError(" [!] The weights of Text Encoder was not reinit check it !") + print(" > Text Encoder was reinit.") + + def get_aux_input(self, aux_input: Dict): + sid, g, lid, _ = self._set_cond_input(aux_input) + return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid} + + def _freeze_layers(self): + if self.args.freeze_encoder: + for param in self.text_encoder.parameters(): + param.requires_grad = False + + if hasattr(self, "emb_l"): + for param in self.emb_l.parameters(): + param.requires_grad = False + + if self.args.freeze_PE: + for param in self.posterior_encoder.parameters(): + param.requires_grad = False + + if self.args.freeze_DP: + for param in self.duration_predictor.parameters(): + param.requires_grad = False + + if self.args.freeze_flow_decoder: + for param in self.flow.parameters(): + param.requires_grad = False + + if self.args.freeze_waveform_decoder: + for param in self.waveform_decoder.parameters(): + param.requires_grad = False + + @staticmethod + def _set_cond_input(aux_input: Dict): + """Set the speaker conditioning input based on the multi-speaker mode.""" + sid, g, lid, durations = None, None, None, None + if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None: + sid = aux_input["speaker_ids"] + if sid.ndim == 0: + sid = sid.unsqueeze_(0) + if "d_vectors" in aux_input and aux_input["d_vectors"] is not None: + g = F.normalize(aux_input["d_vectors"]).unsqueeze(-1) + if g.ndim == 2: + g = g.unsqueeze_(0) + + if "language_ids" in aux_input and aux_input["language_ids"] is not None: + lid = aux_input["language_ids"] + if lid.ndim == 0: + lid = lid.unsqueeze_(0) + + if "durations" in aux_input and aux_input["durations"] is not None: + durations = aux_input["durations"] + + return sid, g, lid, durations + + def _set_speaker_input(self, aux_input: Dict): + d_vectors = aux_input.get("d_vectors", None) + speaker_ids = aux_input.get("speaker_ids", None) + + if d_vectors is not None and speaker_ids is not None: + raise 
ValueError("[!] Cannot use d-vectors and speaker-ids together.") + + if speaker_ids is not None and not hasattr(self, "emb_g"): + raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.") + + g = speaker_ids if speaker_ids is not None else d_vectors + return g + + def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb): + # find the alignment path + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + with torch.no_grad(): + o_scale = torch.exp(-2 * logs_p) + logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1] + logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)]) + logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p]) + logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp = logp2 + logp3 + logp1 + logp4 + attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t, t'] + + # duration predictor + attn_durations = attn.sum(3) + if self.args.use_sdp: + loss_duration = self.duration_predictor( + x.detach() if self.args.detach_dp_input else x, + x_mask, + attn_durations, + g=g.detach() if self.args.detach_dp_input and g is not None else g, + lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, + ) + loss_duration = loss_duration / torch.sum(x_mask) + else: + attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask + log_durations = self.duration_predictor( + x.detach() if self.args.detach_dp_input else x, + x_mask, + g=g.detach() if self.args.detach_dp_input and g is not None else g, + lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, + ) + loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask) + outputs["loss_duration"] = loss_duration + return outputs, attn + + def upsampling_z(self, z, slice_ids=None, y_lengths=None, y_mask=None): + spec_segment_size = self.spec_segment_size + if self.args.encoder_sample_rate: + # recompute the slices and spec_segment_size if needed + slice_ids = slice_ids * int(self.interpolate_factor) if slice_ids is not None else slice_ids + spec_segment_size = spec_segment_size * int(self.interpolate_factor) + # interpolate z if needed + if self.args.interpolate_z: + z = torch.nn.functional.interpolate(z, scale_factor=[self.interpolate_factor], mode="linear").squeeze(0) + # recompute the mask if needed + if y_lengths is not None and y_mask is not None: + y_mask = ( + sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1) + ) # [B, 1, T_dec_resampled] + + return z, spec_segment_size, slice_ids, y_mask + + def forward( # pylint: disable=dangerous-default-value + self, + x: torch.tensor, + x_lengths: torch.tensor, + y: torch.tensor, + y_lengths: torch.tensor, + waveform: torch.tensor, + aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}, + ) -> Dict: + """Forward pass of the model. + + Args: + x (torch.tensor): Batch of input character sequence IDs. + x_lengths (torch.tensor): Batch of input character sequence lengths. + y (torch.tensor): Batch of input spectrograms. + y_lengths (torch.tensor): Batch of input spectrogram lengths. + waveform (torch.tensor): Batch of ground truth waveforms per sample. + aux_input (dict, optional): Auxiliary inputs for multi-speaker and multi-lingual training. + Defaults to {"d_vectors": None, "speaker_ids": None, "language_ids": None}. 
+ + Returns: + Dict: model outputs keyed by the output name. + + Shapes: + - x: :math:`[B, T_seq]` + - x_lengths: :math:`[B]` + - y: :math:`[B, C, T_spec]` + - y_lengths: :math:`[B]` + - waveform: :math:`[B, 1, T_wav]` + - d_vectors: :math:`[B, C, 1]` + - speaker_ids: :math:`[B]` + - language_ids: :math:`[B]` + + Return Shapes: + - model_outputs: :math:`[B, 1, T_wav]` + - alignments: :math:`[B, T_seq, T_dec]` + - z: :math:`[B, C, T_dec]` + - z_p: :math:`[B, C, T_dec]` + - m_p: :math:`[B, C, T_dec]` + - logs_p: :math:`[B, C, T_dec]` + - m_q: :math:`[B, C, T_dec]` + - logs_q: :math:`[B, C, T_dec]` + - waveform_seg: :math:`[B, 1, spec_seg_size * hop_length]` + - gt_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]` + - syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]` + """ + outputs = {} + sid, g, lid, _ = self._set_cond_input(aux_input) + # speaker embedding + if self.args.use_speaker_embedding and sid is not None: + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + + # language embedding + lang_emb = None + if self.args.use_language_embedding and lid is not None: + lang_emb = self.emb_l(lid).unsqueeze(-1) + + x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb) + + # posterior encoder + z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g) + + # flow layers + z_p = self.flow(z, y_mask, g=g) + + # duration predictor + outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb) + + # expand prior + m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p]) + logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p]) + + # select a random feature segment for the waveform decoder + z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True) + + # interpolate z if needed + z_slice, spec_segment_size, slice_ids, _ = self.upsampling_z(z_slice, slice_ids=slice_ids) + + o = self.waveform_decoder(z_slice, g=g) + + wav_seg = segment( + waveform, + slice_ids * self.config.audio.hop_length, + spec_segment_size * self.config.audio.hop_length, + pad_short=True, + ) + + if self.args.use_speaker_encoder_as_loss and self.speaker_manager.encoder is not None: + # concate generated and GT waveforms + wavs_batch = torch.cat((wav_seg, o), dim=0) + + # resample audio to speaker encoder sample_rate + # pylint: disable=W0105 + if self.audio_transform is not None: + wavs_batch = self.audio_transform(wavs_batch) + + pred_embs = self.speaker_manager.encoder.forward(wavs_batch, l2_norm=True) + + # split generated and GT speaker embeddings + gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0) + else: + gt_spk_emb, syn_spk_emb = None, None + + outputs.update( + { + "model_outputs": o, + "alignments": attn.squeeze(1), + "m_p": m_p, + "logs_p": logs_p, + "z": z, + "z_p": z_p, + "m_q": m_q, + "logs_q": logs_q, + "waveform_seg": wav_seg, + "gt_spk_emb": gt_spk_emb, + "syn_spk_emb": syn_spk_emb, + "slice_ids": slice_ids, + } + ) + return outputs + + @staticmethod + def _set_x_lengths(x, aux_input): + if "x_lengths" in aux_input and aux_input["x_lengths"] is not None: + return aux_input["x_lengths"] + return torch.tensor(x.shape[1:2]).to(x.device) + + @torch.no_grad() + def inference( + self, + x, + aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None}, + ): # pylint: disable=dangerous-default-value + """ + Note: + To run in batch mode, provide `x_lengths` else model assumes that the batch size is 1. 
+ + Shapes: + - x: :math:`[B, T_seq]` + - x_lengths: :math:`[B]` + - d_vectors: :math:`[B, C]` + - speaker_ids: :math:`[B]` + + Return Shapes: + - model_outputs: :math:`[B, 1, T_wav]` + - alignments: :math:`[B, T_seq, T_dec]` + - z: :math:`[B, C, T_dec]` + - z_p: :math:`[B, C, T_dec]` + - m_p: :math:`[B, C, T_dec]` + - logs_p: :math:`[B, C, T_dec]` + """ + sid, g, lid, durations = self._set_cond_input(aux_input) + x_lengths = self._set_x_lengths(x, aux_input) + + # speaker embedding + if self.args.use_speaker_embedding and sid is not None: + g = self.emb_g(sid).unsqueeze(-1) + + # language embedding + lang_emb = None + if self.args.use_language_embedding and lid is not None: + lang_emb = self.emb_l(lid).unsqueeze(-1) + + x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb) + + if durations is None: + if self.args.use_sdp: + logw = self.duration_predictor( + x, + x_mask, + g=g if self.args.condition_dp_on_speaker else None, + reverse=True, + noise_scale=self.inference_noise_scale_dp, + lang_emb=lang_emb, + ) + else: + logw = self.duration_predictor( + x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb + ) + w = torch.exp(logw) * x_mask * self.length_scale + else: + assert durations.shape[-1] == x.shape[-1] + w = durations.unsqueeze(0) + + w_ceil = torch.ceil(w) + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype).unsqueeze(1) # [B, 1, T_dec] + + attn_mask = x_mask * y_mask.transpose(1, 2) # [B, 1, T_enc] * [B, T_dec, 1] + attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2)) + + m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2) + logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2) + + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale + z = self.flow(z_p, y_mask, g=g, reverse=True) + + # upsampling if needed + z, _, _, y_mask = self.upsampling_z(z, y_lengths=y_lengths, y_mask=y_mask) + + o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g) + + outputs = { + "model_outputs": o, + "alignments": attn.squeeze(1), + "durations": w_ceil, + "z": z, + "z_p": z_p, + "m_p": m_p, + "logs_p": logs_p, + "y_mask": y_mask, + } + return outputs + + @torch.no_grad() + def inference_voice_conversion( + self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None + ): + """Inference for voice conversion + + Args: + reference_wav (Tensor): Reference wavform. Tensor of shape [B, T] + speaker_id (Tensor): speaker_id of the target speaker. Tensor of shape [B] + d_vector (Tensor): d_vector embedding of target speaker. Tensor of shape `[B, C]` + reference_speaker_id (Tensor): speaker_id of the reference_wav speaker. Tensor of shape [B] + reference_d_vector (Tensor): d_vector embedding of the reference_wav speaker. 
Tensor of shape `[B, C]`
+        """
+        # compute spectrograms
+        y = wav_to_spec(
+            reference_wav,
+            self.config.audio.fft_size,
+            self.config.audio.hop_length,
+            self.config.audio.win_length,
+            center=False,
+        )
+        y_lengths = torch.tensor([y.size(-1)]).to(y.device)
+        speaker_cond_src = reference_speaker_id if reference_speaker_id is not None else reference_d_vector
+        speaker_cond_tgt = speaker_id if speaker_id is not None else d_vector
+        wav, _, _ = self.voice_conversion(y, y_lengths, speaker_cond_src, speaker_cond_tgt)
+        return wav
+
+    def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt):
+        """Forward pass for voice conversion
+
+        TODO: create an end-point for voice conversion
+
+        Args:
+            y (Tensor): Reference spectrograms. Tensor of shape [B, T, C]
+            y_lengths (Tensor): Length of each reference spectrogram. Tensor of shape [B]
+            speaker_cond_src (Tensor): Reference speaker ID. Tensor of shape [B,]
+            speaker_cond_tgt (Tensor): Target speaker ID. Tensor of shape [B,]
+        """
+        assert self.num_speakers > 0, "num_speakers has to be larger than 0."
+        # speaker embedding
+        if self.args.use_speaker_embedding and not self.args.use_d_vector_file:
+            g_src = self.emb_g(torch.from_numpy((np.array(speaker_cond_src))).unsqueeze(0)).unsqueeze(-1)
+            g_tgt = self.emb_g(torch.from_numpy((np.array(speaker_cond_tgt))).unsqueeze(0)).unsqueeze(-1)
+        elif not self.args.use_speaker_embedding and self.args.use_d_vector_file:
+            g_src = F.normalize(speaker_cond_src).unsqueeze(-1)
+            g_tgt = F.normalize(speaker_cond_tgt).unsqueeze(-1)
+        else:
+            raise RuntimeError(" [!] Voice conversion is only supported on multi-speaker models.")
+
+        z, _, _, y_mask = self.posterior_encoder(y, y_lengths, g=g_src)
+        z_p = self.flow(z, y_mask, g=g_src)
+        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
+        o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
+        return o_hat, y_mask, (z, z_p, z_hat)
+
+    def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
+        """Perform a single training step. Run the model forward pass and compute losses.
+
+        Args:
+            batch (Dict): Input tensors.
+            criterion (nn.Module): Loss layer designed for the model.
+            optimizer_idx (int): Index of the optimizer to use. 0 for the discriminator and 1 for the generator networks.
+
+        Returns:
+            Tuple[Dict, Dict]: Model outputs and computed losses.
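+
+        Note:
+            The generator forward pass runs in the `optimizer_idx == 0` branch and its outputs are
+            cached in `self.model_outputs_cache`; the `optimizer_idx == 1` branch reuses that cache
+            to compute the generator losses without a second forward pass.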
+ """ + + spec_lens = batch["spec_lens"] + + if optimizer_idx == 0: + tokens = batch["tokens"] + token_lenghts = batch["token_lens"] + spec = batch["spec"] + + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + language_ids = batch["language_ids"] + waveform = batch["waveform"] + + # generator pass + outputs = self.forward( + tokens, + token_lenghts, + spec, + spec_lens, + waveform, + aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids}, + ) + + # cache tensors for the generator pass + self.model_outputs_cache = outputs # pylint: disable=attribute-defined-outside-init + + # compute scores and features + scores_disc_fake, _, scores_disc_real, _ = self.disc( + outputs["model_outputs"].detach(), outputs["waveform_seg"] + ) + + # compute loss + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion[optimizer_idx]( + scores_disc_real, + scores_disc_fake, + ) + return outputs, loss_dict + + if optimizer_idx == 1: + mel = batch["mel"] + + # compute melspec segment + with autocast(enabled=False): + if self.args.encoder_sample_rate: + spec_segment_size = self.spec_segment_size * int(self.interpolate_factor) + else: + spec_segment_size = self.spec_segment_size + + mel_slice = segment( + mel.float(), self.model_outputs_cache["slice_ids"], spec_segment_size, pad_short=True + ) + mel_slice_hat = wav_to_mel( + y=self.model_outputs_cache["model_outputs"].float(), + n_fft=self.config.audio.fft_size, + sample_rate=self.config.audio.sample_rate, + num_mels=self.config.audio.num_mels, + hop_length=self.config.audio.hop_length, + win_length=self.config.audio.win_length, + fmin=self.config.audio.mel_fmin, + fmax=self.config.audio.mel_fmax, + center=False, + ) + + # compute discriminator scores and features + scores_disc_fake, feats_disc_fake, _, feats_disc_real = self.disc( + self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"] + ) + + # compute losses + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion[optimizer_idx]( + mel_slice_hat=mel_slice.float(), + mel_slice=mel_slice_hat.float(), + z_p=self.model_outputs_cache["z_p"].float(), + logs_q=self.model_outputs_cache["logs_q"].float(), + m_p=self.model_outputs_cache["m_p"].float(), + logs_p=self.model_outputs_cache["logs_p"].float(), + z_len=spec_lens, + scores_disc_fake=scores_disc_fake, + feats_disc_fake=feats_disc_fake, + feats_disc_real=feats_disc_real, + loss_duration=self.model_outputs_cache["loss_duration"], + use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss, + gt_spk_emb=self.model_outputs_cache["gt_spk_emb"], + syn_spk_emb=self.model_outputs_cache["syn_spk_emb"], + ) + + return self.model_outputs_cache, loss_dict + + raise ValueError(" [!] 
Unexpected `optimizer_idx`.") + + def _log(self, ap, batch, outputs, name_prefix="train"): # pylint: disable=unused-argument,no-self-use + y_hat = outputs[1]["model_outputs"] + y = outputs[1]["waveform_seg"] + figures = plot_results(y_hat, y, ap, name_prefix) + sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() + audios = {f"{name_prefix}/audio": sample_voice} + + alignments = outputs[1]["alignments"] + align_img = alignments[0].data.cpu().numpy().T + + figures.update( + { + "alignment": plot_alignment(align_img, output_fig=False), + } + ) + return figures, audios + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ): # pylint: disable=no-self-use + """Create visualizations and waveform examples. + + For example, here you can plot spectrograms and generate sample sample waveforms from these spectrograms to + be projected onto Tensorboard. + + Args: + ap (AudioProcessor): audio processor used at training. + batch (Dict): Model inputs used at the previous training step. + outputs (Dict): Model outputs generated at the previoud training step. + + Returns: + Tuple[Dict, np.ndarray]: training plots and output waveform. + """ + figures, audios = self._log(self.ap, batch, outputs, "train") + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + @torch.no_grad() + def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): + return self.train_step(batch, criterion, optimizer_idx) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._log(self.ap, batch, outputs, "eval") + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def get_aux_input_from_test_sentences(self, sentence_info): + if hasattr(self.config, "model_args"): + config = self.config.model_args + else: + config = self.config + + # extract speaker and language info + text, speaker_name, style_wav, language_name = None, None, None, None + + if isinstance(sentence_info, list): + if len(sentence_info) == 1: + text = sentence_info[0] + elif len(sentence_info) == 2: + text, speaker_name = sentence_info + elif len(sentence_info) == 3: + text, speaker_name, style_wav = sentence_info + elif len(sentence_info) == 4: + text, speaker_name, style_wav, language_name = sentence_info + else: + text = sentence_info + + # get speaker id/d_vector + speaker_id, d_vector, language_id = None, None, None + if hasattr(self, "speaker_manager"): + if config.use_d_vector_file: + if speaker_name is None: + d_vector = self.speaker_manager.get_random_embedding() + else: + d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False) + elif config.use_speaker_embedding: + if speaker_name is None: + speaker_id = self.speaker_manager.get_random_id() + else: + speaker_id = self.speaker_manager.name_to_id[speaker_name] + + # get language id + if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None: + language_id = self.language_manager.name_to_id[language_name] + + return { + "text": text, + "speaker_id": speaker_id, + "style_wav": style_wav, + "d_vector": d_vector, + "language_id": language_id, + "language_name": language_name, + } + + @torch.no_grad() + def test_run(self, assets) -> Tuple[Dict, Dict]: + """Generic test run for `tts` models used by `Trainer`. + + You can override this for a different behaviour. 
+ + Returns: + Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. + """ + print(" | > Synthesizing test sentences.") + test_audios = {} + test_figures = {} + test_sentences = self.config.test_sentences + for idx, s_info in enumerate(test_sentences): + aux_inputs = self.get_aux_input_from_test_sentences(s_info) + wav, alignment, _, _ = synthesis( + self, + aux_inputs["text"], + self.config, + "cuda" in str(next(self.parameters()).device), + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + style_wav=aux_inputs["style_wav"], + language_id=aux_inputs["language_id"], + use_griffin_lim=True, + do_trim_silence=False, + ).values() + test_audios["{}-audio".format(idx)] = wav + test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False) + return {"figures": test_figures, "audios": test_audios} + + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: + logger.test_audios(steps, outputs["audios"], self.ap.sample_rate) + logger.test_figures(steps, outputs["figures"]) + + def format_batch(self, batch: Dict) -> Dict: + """Compute speaker, langugage IDs and d_vector for the batch if necessary.""" + speaker_ids = None + language_ids = None + d_vectors = None + + # get numerical speaker ids from speaker names + if self.speaker_manager is not None and self.speaker_manager.name_to_id and self.args.use_speaker_embedding: + speaker_ids = [self.speaker_manager.name_to_id[sn] for sn in batch["speaker_names"]] + + if speaker_ids is not None: + speaker_ids = torch.LongTensor(speaker_ids) + + # get d_vectors from audio file names + if self.speaker_manager is not None and self.speaker_manager.embeddings and self.args.use_d_vector_file: + d_vector_mapping = self.speaker_manager.embeddings + d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_unique_names"]] + d_vectors = torch.FloatTensor(d_vectors) + + # get language ids from language names + if self.language_manager is not None and self.language_manager.name_to_id and self.args.use_language_embedding: + language_ids = [self.language_manager.name_to_id[ln] for ln in batch["language_names"]] + + if language_ids is not None: + language_ids = torch.LongTensor(language_ids) + + batch["language_ids"] = language_ids + batch["d_vectors"] = d_vectors + batch["speaker_ids"] = speaker_ids + return batch + + def format_batch_on_device(self, batch): + """Compute spectrograms on the device.""" + ac = self.config.audio + + if self.args.encoder_sample_rate: + wav = self.audio_resampler(batch["waveform"]) + else: + wav = batch["waveform"] + + # compute spectrograms + batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False) + + if self.args.encoder_sample_rate: + # recompute spec with high sampling rate to the loss + spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False) + # remove extra stft frames if needed + if spec_mel.size(2) > int(batch["spec"].size(2) * self.interpolate_factor): + spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)] + else: + batch["spec"] = batch["spec"][:, :, : int(spec_mel.size(2) / self.interpolate_factor)] + else: + spec_mel = batch["spec"] + + batch["mel"] = spec_to_mel( + spec=spec_mel, + n_fft=ac.fft_size, + num_mels=ac.num_mels, + sample_rate=ac.sample_rate, + fmin=ac.mel_fmin, + fmax=ac.mel_fmax, + ) + + if self.args.encoder_sample_rate: + assert batch["spec"].shape[2] == int( + 
batch["mel"].shape[2] / self.interpolate_factor + ), f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}" + else: + assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}" + + # compute spectrogram frame lengths + batch["spec_lens"] = (batch["spec"].shape[2] * batch["waveform_rel_lens"]).int() + batch["mel_lens"] = (batch["mel"].shape[2] * batch["waveform_rel_lens"]).int() + + if self.args.encoder_sample_rate: + assert (batch["spec_lens"] - (batch["mel_lens"] / self.interpolate_factor).int()).sum() == 0 + else: + assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0 + + # zero the padding frames + batch["spec"] = batch["spec"] * sequence_mask(batch["spec_lens"]).unsqueeze(1) + batch["mel"] = batch["mel"] * sequence_mask(batch["mel_lens"]).unsqueeze(1) + return batch + + def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1, is_eval=False): + weights = None + data_items = dataset.samples + if getattr(config, "use_weighted_sampler", False): + for attr_name, alpha in config.weighted_sampler_attrs.items(): + print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'") + multi_dict = config.weighted_sampler_multipliers.get(attr_name, None) + print(multi_dict) + weights, attr_names, attr_weights = get_attribute_balancer_weights( + attr_name=attr_name, items=data_items, multi_dict=multi_dict + ) + weights = weights * alpha + print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}") + + # input_audio_lenghts = [os.path.getsize(x["audio_file"]) for x in data_items] + + if weights is not None: + w_sampler = WeightedRandomSampler(weights, len(weights)) + batch_sampler = BucketBatchSampler( + w_sampler, + data=data_items, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + sort_key=lambda x: os.path.getsize(x["audio_file"]), + drop_last=True, + ) + else: + batch_sampler = None + # sampler for DDP + if batch_sampler is None: + batch_sampler = DistributedSampler(dataset) if num_gpus > 1 else None + else: # If a sampler is already defined use this sampler and DDP sampler together + batch_sampler = ( + DistributedSamplerWrapper(batch_sampler) if num_gpus > 1 else batch_sampler + ) # TODO: check batch_sampler with multi-gpu + return batch_sampler + + def get_data_loader( + self, + config: Coqpit, + assets: Dict, + is_eval: bool, + samples: Union[List[Dict], List[List]], + verbose: bool, + num_gpus: int, + rank: int = None, + ) -> "DataLoader": + if is_eval and not config.run_eval: + loader = None + else: + # init dataloader + dataset = VitsDataset( + model_args=self.args, + samples=samples, + batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, + min_text_len=config.min_text_len, + max_text_len=config.max_text_len, + min_audio_len=config.min_audio_len, + max_audio_len=config.max_audio_len, + phoneme_cache_path=config.phoneme_cache_path, + precompute_num_workers=config.precompute_num_workers, + verbose=verbose, + tokenizer=self.tokenizer, + start_by_longest=config.start_by_longest, + ) + + # wait all the DDP process to be ready + if num_gpus > 1: + dist.barrier() + + # sort input sequences from short to long + dataset.preprocess_samples() + + # get samplers + sampler = self.get_sampler(config, dataset, num_gpus) + if sampler is None: + loader = DataLoader( + dataset, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + shuffle=False, # shuffle is done in the dataset. 
+ collate_fn=dataset.collate_fn, + drop_last=False, # setting this False might cause issues in AMP training. + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + else: + if num_gpus > 1: + loader = DataLoader( + dataset, + sampler=sampler, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + else: + loader = DataLoader( + dataset, + batch_sampler=sampler, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + return loader + + def get_optimizer(self) -> List: + """Initiate and return the GAN optimizers based on the config parameters. + It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator. + Returns: + List: optimizers. + """ + # select generator parameters + optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc) + + gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc.")) + optimizer1 = get_optimizer( + self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters + ) + return [optimizer0, optimizer1] + + def get_lr(self) -> List: + """Set the initial learning rates for each optimizer. + + Returns: + List: learning rates for each optimizer. + """ + return [self.config.lr_disc, self.config.lr_gen] + + def get_scheduler(self, optimizer) -> List: + """Set the schedulers for each optimizer. + + Args: + optimizer (List[`torch.optim.Optimizer`]): List of optimizers. + + Returns: + List: Schedulers, one for each optimizer. + """ + scheduler_D = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[0]) + scheduler_G = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[1]) + return [scheduler_D, scheduler_G] + + def get_criterion(self): + """Get criterions for each optimizer. The index in the output list matches the optimizer idx used in + `train_step()`""" + from TTS.tts.layers.losses import ( # pylint: disable=import-outside-toplevel + VitsDiscriminatorLoss, + VitsGeneratorLoss, + ) + + return [VitsDiscriminatorLoss(self.config), VitsGeneratorLoss(self.config)] + + def load_checkpoint( + self, config, checkpoint_path, eval=False, strict=True, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + """Load the model checkpoint and setup for training or inference""" + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + # compat band-aid for the pre-trained models to not use the encoder baked into the model + # TODO: consider baking the speaker encoder into the model and call it from there. + # as it is probably easier for model distribution. 
+ state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k} + + if self.args.encoder_sample_rate is not None and eval: + # audio resampler is not used in inference time + self.audio_resampler = None + + # handle fine-tuning from a checkpoint with additional speakers + if hasattr(self, "emb_g") and state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape: + num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0] + print(f" > Loading checkpoint with {num_new_speakers} additional speakers.") + emb_g = state["model"]["emb_g.weight"] + new_row = torch.randn(num_new_speakers, emb_g.shape[1]) + emb_g = torch.cat([emb_g, new_row], axis=0) + state["model"]["emb_g.weight"] = emb_g + # load the model weights + self.load_state_dict(state["model"], strict=strict) + + if eval: + self.eval() + assert not self.training + + def load_fairseq_checkpoint( + self, config, checkpoint_dir, eval=False, strict=True + ): # pylint: disable=unused-argument, redefined-builtin + """Load VITS checkpoints released by fairseq here: https://github.com/facebookresearch/fairseq/tree/main/examples/mms + Performs some changes for compatibility. + + Args: + config (Coqpit): 🐸TTS model config. + checkpoint_dir (str): Path to the checkpoint directory. + eval (bool, optional): Set to True for evaluation. Defaults to False. + """ + import json + + from TTS.tts.utils.text.cleaners import basic_cleaners + + self.disc = None + # set paths + config_file = os.path.join(checkpoint_dir, "config.json") + checkpoint_file = os.path.join(checkpoint_dir, "G_100000.pth") + vocab_file = os.path.join(checkpoint_dir, "vocab.txt") + # set config params + with open(config_file, "r", encoding="utf-8") as file: + # Load the JSON data as a dictionary + config_org = json.load(file) + self.config.audio.sample_rate = config_org["data"]["sampling_rate"] + # self.config.add_blank = config['add_blank'] + # set tokenizer + vocab = FairseqVocab(vocab_file) + self.text_encoder.emb = nn.Embedding(vocab.num_chars, config.model_args.hidden_channels) + self.tokenizer = TTSTokenizer( + use_phonemes=False, + text_cleaner=basic_cleaners, + characters=vocab, + phonemizer=None, + add_blank=config_org["data"]["add_blank"], + use_eos_bos=False, + ) + # load fairseq checkpoint + new_chk = rehash_fairseq_vits_checkpoint(checkpoint_file) + self.load_state_dict(new_chk, strict=strict) + if eval: + self.eval() + assert not self.training + + @staticmethod + def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True): + """Initiate model from config + + Args: + config (VitsConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio import AudioProcessor + + upsample_rate = torch.prod(torch.as_tensor(config.model_args.upsample_rates_decoder)).item() + + if not config.model_args.encoder_sample_rate: + assert ( + upsample_rate == config.audio.hop_length + ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}" + else: + encoder_to_vocoder_upsampling_factor = config.audio.sample_rate / config.model_args.encoder_sample_rate + effective_hop_length = config.audio.hop_length * encoder_to_vocoder_upsampling_factor + assert ( + upsample_rate == effective_hop_length + ), f" [!] 
Product of upsample rates must be equal to the hop length - {upsample_rate} vs {effective_hop_length}" + + ap = AudioProcessor.init_from_config(config, verbose=verbose) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + language_manager = LanguageManager.init_from_config(config) + + if config.model_args.speaker_encoder_model_path: + speaker_manager.init_encoder( + config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path + ) + return Vits(new_config, ap, tokenizer, speaker_manager, language_manager) + + def export_onnx(self, output_path: str = "coqui_vits.onnx", verbose: bool = True): + """Export model to ONNX format for inference + + Args: + output_path (str): Path to save the exported model. + verbose (bool): Print verbose information. Defaults to True. + """ + + # rollback values + _forward = self.forward + disc = None + if hasattr(self, "disc"): + disc = self.disc + training = self.training + + # set export mode + self.disc = None + self.eval() + + def onnx_inference(text, text_lengths, scales, sid=None, langid=None): + noise_scale = scales[0] + length_scale = scales[1] + noise_scale_dp = scales[2] + self.noise_scale = noise_scale + self.length_scale = length_scale + self.noise_scale_dp = noise_scale_dp + return self.inference( + text, + aux_input={ + "x_lengths": text_lengths, + "d_vectors": None, + "speaker_ids": sid, + "language_ids": langid, + "durations": None, + }, + )["model_outputs"] + + self.forward = onnx_inference + + # set dummy inputs + dummy_input_length = 100 + sequences = torch.randint(low=0, high=2, size=(1, dummy_input_length), dtype=torch.long) + sequence_lengths = torch.LongTensor([sequences.size(1)]) + scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp]) + dummy_input = (sequences, sequence_lengths, scales) + input_names = ["input", "input_lengths", "scales"] + + if self.num_speakers > 0: + speaker_id = torch.LongTensor([0]) + dummy_input += (speaker_id,) + input_names.append("sid") + + if hasattr(self, "num_languages") and self.num_languages > 0 and self.embedded_language_dim > 0: + language_id = torch.LongTensor([0]) + dummy_input += (language_id,) + input_names.append("langid") + + # export to ONNX + torch.onnx.export( + model=self, + args=dummy_input, + opset_version=15, + f=output_path, + verbose=verbose, + input_names=input_names, + output_names=["output"], + dynamic_axes={ + "input": {0: "batch_size", 1: "phonemes"}, + "input_lengths": {0: "batch_size"}, + "output": {0: "batch_size", 1: "time1", 2: "time2"}, + }, + ) + + # rollback + self.forward = _forward + if training: + self.train() + if not disc is None: + self.disc = disc + + def load_onnx(self, model_path: str, cuda=False): + import onnxruntime as ort + + providers = [ + "CPUExecutionProvider" + if cuda is False + else ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}) + ] + sess_options = ort.SessionOptions() + self.onnx_sess = ort.InferenceSession( + model_path, + sess_options=sess_options, + providers=providers, + ) + + def inference_onnx(self, x, x_lengths=None, speaker_id=None, language_id=None): + """ONNX inference""" + + if isinstance(x, torch.Tensor): + x = x.cpu().numpy() + + if x_lengths is None: + x_lengths = np.array([x.shape[1]], dtype=np.int64) + + if isinstance(x_lengths, torch.Tensor): + x_lengths = x_lengths.cpu().numpy() + scales = np.array( + [self.inference_noise_scale, self.length_scale, 
self.inference_noise_scale_dp], + dtype=np.float32, + ) + input_params = {"input": x, "input_lengths": x_lengths, "scales": scales} + if not speaker_id is None: + input_params["sid"] = torch.tensor([speaker_id]).cpu().numpy() + if not language_id is None: + input_params["langid"] = torch.tensor([language_id]).cpu().numpy() + + audio = self.onnx_sess.run( + ["output"], + input_params, + ) + return audio[0][0] + + +################################## +# VITS CHARACTERS +################################## + + +class VitsCharacters(BaseCharacters): + """Characters class for VITs model for compatibility with pre-trained models""" + + def __init__( + self, + graphemes: str = _characters, + punctuations: str = _punctuations, + pad: str = _pad, + ipa_characters: str = _phonemes, + ) -> None: + if ipa_characters is not None: + graphemes += ipa_characters + super().__init__(graphemes, punctuations, pad, None, None, "", is_unique=False, is_sorted=True) + + def _create_vocab(self): + self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank] + self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)} + # pylint: disable=unnecessary-comprehension + self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)} + + @staticmethod + def init_from_config(config: Coqpit): + if config.characters is not None: + _pad = config.characters["pad"] + _punctuations = config.characters["punctuations"] + _letters = config.characters["characters"] + _letters_ipa = config.characters["phonemes"] + return ( + VitsCharacters(graphemes=_letters, ipa_characters=_letters_ipa, punctuations=_punctuations, pad=_pad), + config, + ) + characters = VitsCharacters() + new_config = replace(config, characters=characters.to_config()) + return characters, new_config + + def to_config(self) -> "CharactersConfig": + return CharactersConfig( + characters=self._characters, + punctuations=self._punctuations, + pad=self._pad, + eos=None, + bos=None, + blank=self._blank, + is_unique=False, + is_sorted=True, + ) + + +class FairseqVocab(BaseVocabulary): + def __init__(self, vocab: str): + super(FairseqVocab).__init__() + self.vocab = vocab + + @property + def vocab(self): + """Return the vocabulary dictionary.""" + return self._vocab + + @vocab.setter + def vocab(self, vocab_file): + with open(vocab_file, encoding="utf-8") as f: + self._vocab = [x.replace("\n", "") for x in f.readlines()] + self.blank = self._vocab[0] + self.pad = " " + self._char_to_id = {s: i for i, s in enumerate(self._vocab)} # pylint: disable=unnecessary-comprehension + self._id_to_char = {i: s for i, s in enumerate(self._vocab)} # pylint: disable=unnecessary-comprehension diff --git a/viXTTS/TTS/tts/models/xtts.py b/viXTTS/TTS/tts/models/xtts.py new file mode 100644 index 0000000000000000000000000000000000000000..8e9d6bd38297fa440930318c6b485303ffb7d76d --- /dev/null +++ b/viXTTS/TTS/tts/models/xtts.py @@ -0,0 +1,791 @@ +import os +from dataclasses import dataclass + +import librosa +import torch +import torch.nn.functional as F +import torchaudio +from coqpit import Coqpit + +from TTS.tts.layers.xtts.gpt import GPT +from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder +from TTS.tts.layers.xtts.stream_generator import init_stream_support +from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence +from TTS.tts.layers.xtts.xtts_manager import SpeakerManager, LanguageManager +from TTS.tts.models.base_tts import BaseTTS +from TTS.utils.io import load_fsspec + +init_stream_support() + + +def 
wav_to_mel_cloning( + wav, + mel_norms_file="../experiments/clips_mel_norms.pth", + mel_norms=None, + device=torch.device("cpu"), + n_fft=4096, + hop_length=1024, + win_length=4096, + power=2, + normalized=False, + sample_rate=22050, + f_min=0, + f_max=8000, + n_mels=80, +): + """ + Convert waveform to mel-spectrogram with hard-coded parameters for cloning. + + Args: + wav (torch.Tensor): Input waveform tensor. + mel_norms_file (str): Path to mel-spectrogram normalization file. + mel_norms (torch.Tensor): Mel-spectrogram normalization tensor. + device (torch.device): Device to use for computation. + + Returns: + torch.Tensor: Mel-spectrogram tensor. + """ + mel_stft = torchaudio.transforms.MelSpectrogram( + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + power=power, + normalized=normalized, + sample_rate=sample_rate, + f_min=f_min, + f_max=f_max, + n_mels=n_mels, + norm="slaney", + ).to(device) + wav = wav.to(device) + mel = mel_stft(wav) + mel = torch.log(torch.clamp(mel, min=1e-5)) + if mel_norms is None: + mel_norms = torch.load(mel_norms_file, map_location=device) + mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1) + return mel + + +def load_audio(audiopath, sampling_rate): + # better load setting following: https://github.com/faroit/python_audio_loading_benchmark + + # torchaudio should chose proper backend to load audio depending on platform + audio, lsr = torchaudio.load(audiopath) + + # stereo to mono if needed + if audio.size(0) != 1: + audio = torch.mean(audio, dim=0, keepdim=True) + + if lsr != sampling_rate: + audio = torchaudio.functional.resample(audio, lsr, sampling_rate) + + # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk. + # '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds. + if torch.any(audio > 10) or not torch.any(audio < 0): + print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}") + # clip audio invalid values + audio.clip_(-1, 1) + return audio + + +def pad_or_truncate(t, length): + """ + Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it. + + Args: + t (torch.Tensor): The input tensor to be padded or truncated. + length (int): The desired length of the tensor. + + Returns: + torch.Tensor: The padded or truncated tensor. + """ + tp = t[..., :length] + if t.shape[-1] == length: + tp = t + elif t.shape[-1] < length: + tp = F.pad(t, (0, length - t.shape[-1])) + return tp + + +@dataclass +class XttsAudioConfig(Coqpit): + """ + Configuration class for audio-related parameters in the XTTS model. + + Args: + sample_rate (int): The sample rate in which the GPT operates. + output_sample_rate (int): The sample rate of the output audio waveform. + """ + + sample_rate: int = 22050 + output_sample_rate: int = 24000 + + +@dataclass +class XttsArgs(Coqpit): + """A dataclass to represent XTTS model arguments that define the model structure. + + Args: + gpt_batch_size (int): The size of the auto-regressive batch. + enable_redaction (bool, optional): Whether to enable redaction. Defaults to True. + kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True. + gpt_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None. + clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None. 
+ decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None. + num_chars (int, optional): The maximum number of characters to generate. Defaults to 255. + + For GPT model: + gpt_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604. + gpt_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402. + gpt_max_prompt_tokens (int, optional): The maximum prompt tokens or the autoregressive model. Defaults to 70. + gpt_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30. + gpt_n_model_channels (int, optional): The model dimension for the autoregressive model. Defaults to 1024. + gpt_n_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16. + gpt_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to 255. + gpt_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to 255. + gpt_checkpointing (bool, optional): Whether to use checkpointing for the autoregressive model. Defaults to False. + gpt_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False. + gpt_code_stride_len (int, optional): The hop_size of dvae and consequently of the gpt output. Defaults to 1024. + gpt_use_masking_gt_prompt_approach (bool, optional): If True, it will use ground truth as prompt and it will mask the loss to avoid repetition. Defaults to True. + gpt_use_perceiver_resampler (bool, optional): If True, it will use perceiver resampler from flamingo paper - https://arxiv.org/abs/2204.14198. Defaults to False. + """ + + gpt_batch_size: int = 1 + enable_redaction: bool = False + kv_cache: bool = True + gpt_checkpoint: str = None + clvp_checkpoint: str = None + decoder_checkpoint: str = None + num_chars: int = 255 + + # XTTS GPT Encoder params + tokenizer_file: str = "" + gpt_max_audio_tokens: int = 605 + gpt_max_text_tokens: int = 402 + gpt_max_prompt_tokens: int = 70 + gpt_layers: int = 30 + gpt_n_model_channels: int = 1024 + gpt_n_heads: int = 16 + gpt_number_text_tokens: int = None + gpt_start_text_token: int = None + gpt_stop_text_token: int = None + gpt_num_audio_tokens: int = 8194 + gpt_start_audio_token: int = 8192 + gpt_stop_audio_token: int = 8193 + gpt_code_stride_len: int = 1024 + gpt_use_masking_gt_prompt_approach: bool = True + gpt_use_perceiver_resampler: bool = False + + # HifiGAN Decoder params + input_sample_rate: int = 22050 + output_sample_rate: int = 24000 + output_hop_length: int = 256 + decoder_input_dim: int = 1024 + d_vector_dim: int = 512 + cond_d_vector_in_each_upsampling_layer: bool = True + + # constants + duration_const: int = 102400 + + +class Xtts(BaseTTS): + """ⓍTTS model implementation. + + ❗ Currently it only supports inference. 
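+    Training is handled by a dedicated trainer; `forward()` and `train_step()` on this class raise
+    `NotImplementedError` and point to the XTTS training docs.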
+ + Examples: + >>> from TTS.tts.configs.xtts_config import XttsConfig + >>> from TTS.tts.models.xtts import Xtts + >>> config = XttsConfig() + >>> model = Xtts.inif_from_config(config) + >>> model.load_checkpoint(config, checkpoint_dir="paths/to/models_dir/", eval=True) + """ + + def __init__(self, config: Coqpit): + super().__init__(config, ap=None, tokenizer=None) + self.mel_stats_path = None + self.config = config + self.gpt_checkpoint = self.args.gpt_checkpoint + self.decoder_checkpoint = self.args.decoder_checkpoint # TODO: check if this is even needed + self.models_dir = config.model_dir + self.gpt_batch_size = self.args.gpt_batch_size + + self.tokenizer = VoiceBpeTokenizer() + self.gpt = None + self.init_models() + self.register_buffer("mel_stats", torch.ones(80)) + + def init_models(self): + """Initialize the models. We do it here since we need to load the tokenizer first.""" + if self.tokenizer.tokenizer is not None: + self.args.gpt_number_text_tokens = self.tokenizer.get_number_tokens() + self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]") + self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]") + + if self.args.gpt_number_text_tokens: + self.gpt = GPT( + layers=self.args.gpt_layers, + model_dim=self.args.gpt_n_model_channels, + start_text_token=self.args.gpt_start_text_token, + stop_text_token=self.args.gpt_stop_text_token, + heads=self.args.gpt_n_heads, + max_text_tokens=self.args.gpt_max_text_tokens, + max_mel_tokens=self.args.gpt_max_audio_tokens, + max_prompt_tokens=self.args.gpt_max_prompt_tokens, + number_text_tokens=self.args.gpt_number_text_tokens, + num_audio_tokens=self.args.gpt_num_audio_tokens, + start_audio_token=self.args.gpt_start_audio_token, + stop_audio_token=self.args.gpt_stop_audio_token, + use_perceiver_resampler=self.args.gpt_use_perceiver_resampler, + code_stride_len=self.args.gpt_code_stride_len, + ) + + self.hifigan_decoder = HifiDecoder( + input_sample_rate=self.args.input_sample_rate, + output_sample_rate=self.args.output_sample_rate, + output_hop_length=self.args.output_hop_length, + ar_mel_length_compression=self.args.gpt_code_stride_len, + decoder_input_dim=self.args.decoder_input_dim, + d_vector_dim=self.args.d_vector_dim, + cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer, + ) + + @property + def device(self): + return next(self.parameters()).device + + @torch.inference_mode() + def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int = 6): + """Compute the conditioning latents for the GPT model from the given audio. + + Args: + audio (tensor): audio tensor. + sr (int): Sample rate of the audio. + length (int): Length of the audio in seconds. If < 0, use the whole audio. Defaults to 30. + chunk_length (int): Length of the audio chunks in seconds. When `length == chunk_length`, the whole audio + is being used without chunking. It must be < `length`. Defaults to 6. 
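+
+        Returns:
+            torch.Tensor: GPT conditioning latents computed from the given reference audio.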
+ """ + if sr != 22050: + audio = torchaudio.functional.resample(audio, sr, 22050) + if length > 0: + audio = audio[:, : 22050 * length] + if self.args.gpt_use_perceiver_resampler: + style_embs = [] + for i in range(0, audio.shape[1], 22050 * chunk_length): + audio_chunk = audio[:, i : i + 22050 * chunk_length] + + # if the chunk is too short ignore it + if audio_chunk.size(-1) < 22050 * 0.33: + continue + + mel_chunk = wav_to_mel_cloning( + audio_chunk, + mel_norms=self.mel_stats.cpu(), + n_fft=2048, + hop_length=256, + win_length=1024, + power=2, + normalized=False, + sample_rate=22050, + f_min=0, + f_max=8000, + n_mels=80, + ) + style_emb = self.gpt.get_style_emb(mel_chunk.to(self.device), None) + style_embs.append(style_emb) + + # mean style embedding + cond_latent = torch.stack(style_embs).mean(dim=0) + else: + mel = wav_to_mel_cloning( + audio, + mel_norms=self.mel_stats.cpu(), + n_fft=4096, + hop_length=1024, + win_length=4096, + power=2, + normalized=False, + sample_rate=22050, + f_min=0, + f_max=8000, + n_mels=80, + ) + cond_latent = self.gpt.get_style_emb(mel.to(self.device)) + return cond_latent.transpose(1, 2) + + @torch.inference_mode() + def get_speaker_embedding(self, audio, sr): + audio_16k = torchaudio.functional.resample(audio, sr, 16000) + return ( + self.hifigan_decoder.speaker_encoder.forward(audio_16k.to(self.device), l2_norm=True) + .unsqueeze(-1) + .to(self.device) + ) + + @torch.inference_mode() + def get_conditioning_latents( + self, + audio_path, + max_ref_length=30, + gpt_cond_len=6, + gpt_cond_chunk_len=6, + librosa_trim_db=None, + sound_norm_refs=False, + load_sr=22050, + ): + """Get the conditioning latents for the GPT model from the given audio. + + Args: + audio_path (str or List[str]): Path to reference audio file(s). + max_ref_length (int): Maximum length of each reference audio in seconds. Defaults to 30. + gpt_cond_len (int): Length of the audio used for gpt latents. Defaults to 6. + gpt_cond_chunk_len (int): Chunk length used for gpt latents. It must be <= gpt_conf_len. Defaults to 6. + librosa_trim_db (int, optional): Trim the audio using this value. If None, not trimming. Defaults to None. + sound_norm_refs (bool, optional): Whether to normalize the audio. Defaults to False. + load_sr (int, optional): Sample rate to load the audio. Defaults to 24000. 
+ """ + # deal with multiples references + if not isinstance(audio_path, list): + audio_paths = [audio_path] + else: + audio_paths = audio_path + + speaker_embeddings = [] + audios = [] + speaker_embedding = None + for file_path in audio_paths: + audio = load_audio(file_path, load_sr) + audio = audio[:, : load_sr * max_ref_length].to(self.device) + if sound_norm_refs: + audio = (audio / torch.abs(audio).max()) * 0.75 + if librosa_trim_db is not None: + audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0] + + # compute latents for the decoder + speaker_embedding = self.get_speaker_embedding(audio, load_sr) + speaker_embeddings.append(speaker_embedding) + + audios.append(audio) + + # merge all the audios and compute the latents for the gpt + full_audio = torch.cat(audios, dim=-1) + gpt_cond_latents = self.get_gpt_cond_latents( + full_audio, load_sr, length=gpt_cond_len, chunk_length=gpt_cond_chunk_len + ) # [1, 1024, T] + + if speaker_embeddings: + speaker_embedding = torch.stack(speaker_embeddings) + speaker_embedding = speaker_embedding.mean(dim=0) + + return gpt_cond_latents, speaker_embedding + + def synthesize(self, text, config, speaker_wav, language, speaker_id=None, **kwargs): + """Synthesize speech with the given input text. + + Args: + text (str): Input text. + config (XttsConfig): Config with inference parameters. + speaker_wav (list): List of paths to the speaker audio files to be used for cloning. + language (str): Language ID of the speaker. + **kwargs: Inference settings. See `inference()`. + + Returns: + A dictionary of the output values with `wav` as output waveform, `deterministic_seed` as seed used at inference, + `text_input` as text token IDs after tokenizer, `voice_samples` as samples used for cloning, `conditioning_latents` + as latents used at inference. + + """ + assert ( + "zh-cn" if language == "zh" else language in self.config.languages + ), f" ❗ Language {language} is not supported. Supported languages are {self.config.languages}" + # Use generally found best tuning knobs for generation. + settings = { + "temperature": config.temperature, + "length_penalty": config.length_penalty, + "repetition_penalty": config.repetition_penalty, + "top_k": config.top_k, + "top_p": config.top_p, + } + settings.update(kwargs) # allow overriding of preset settings with kwargs + if speaker_id is not None: + gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values() + return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings) + settings.update({ + "gpt_cond_len": config.gpt_cond_len, + "gpt_cond_chunk_len": config.gpt_cond_chunk_len, + "max_ref_len": config.max_ref_len, + "sound_norm_refs": config.sound_norm_refs, + }) + return self.full_inference(text, speaker_wav, language, **settings) + + @torch.inference_mode() + def full_inference( + self, + text, + ref_audio_path, + language, + # GPT inference + temperature=0.75, + length_penalty=1.0, + repetition_penalty=10.0, + top_k=50, + top_p=0.85, + do_sample=True, + # Cloning + gpt_cond_len=30, + gpt_cond_chunk_len=6, + max_ref_len=10, + sound_norm_refs=False, + **hf_generate_kwargs, + ): + """ + This function produces an audio clip of the given text being spoken with the given reference voice. + + Args: + text: (str) Text to be spoken. + + ref_audio_path: (str) Path to a reference audio file to be used for cloning. This audio file should be >3 + seconds long. + + language: (str) Language of the voice to be generated. 
+ + temperature: (float) The softmax temperature of the autoregressive model. Defaults to 0.65. + + length_penalty: (float) A length penalty applied to the autoregressive decoder. Higher settings causes the + model to produce more terse outputs. Defaults to 1.0. + + repetition_penalty: (float) A penalty that prevents the autoregressive decoder from repeating itself during + decoding. Can be used to reduce the incidence of long silences or "uhhhhhhs", etc. Defaults to 2.0. + + top_k: (int) K value used in top-k sampling. [0,inf]. Lower values mean the decoder produces more "likely" + (aka boring) outputs. Defaults to 50. + + top_p: (float) P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely" + (aka boring) outputs. Defaults to 0.8. + + gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used + else the first `gpt_cond_len` secs is used. Defaults to 30 seconds. + + gpt_cond_chunk_len: (int) Chunk length used for cloning. It must be <= `gpt_cond_len`. + If gpt_cond_len == gpt_cond_chunk_len, no chunking. Defaults to 6 seconds. + + hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive + transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation + here: https://huggingface.co/docs/transformers/internal/generation_utils + + Returns: + Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length. + Sample rate is 24kHz. + """ + (gpt_cond_latent, speaker_embedding) = self.get_conditioning_latents( + audio_path=ref_audio_path, + gpt_cond_len=gpt_cond_len, + gpt_cond_chunk_len=gpt_cond_chunk_len, + max_ref_length=max_ref_len, + sound_norm_refs=sound_norm_refs, + ) + + return self.inference( + text, + language, + gpt_cond_latent, + speaker_embedding, + temperature=temperature, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + top_k=top_k, + top_p=top_p, + do_sample=do_sample, + **hf_generate_kwargs, + ) + + @torch.inference_mode() + def inference( + self, + text, + language, + gpt_cond_latent, + speaker_embedding, + # GPT inference + temperature=0.75, + length_penalty=1.0, + repetition_penalty=10.0, + top_k=50, + top_p=0.85, + do_sample=True, + num_beams=1, + speed=1.0, + enable_text_splitting=False, + **hf_generate_kwargs, + ): + language = language.split("-")[0] # remove the country code + length_scale = 1.0 / max(speed, 0.05) + gpt_cond_latent = gpt_cond_latent.to(self.device) + speaker_embedding = speaker_embedding.to(self.device) + if enable_text_splitting: + text = split_sentence(text, language, self.tokenizer.char_limits[language]) + else: + text = [text] + + wavs = [] + gpt_latents_list = [] + for sent in text: + sent = sent.strip().lower() + text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device) + + assert ( + text_tokens.shape[-1] < self.args.gpt_max_text_tokens + ), " ❗ XTTS can only generate text with a maximum of 400 tokens." 
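+            # The synthesis below is two-stage: `gpt.generate()` autoregressively samples discrete
+            # audio codes from the text tokens and conditioning latents, a second GPT pass with
+            # `return_latent=True` maps those codes to continuous latents, and the HiFi-GAN decoder
+            # turns the latents into a waveform. `length_scale` (derived from `speed`) optionally
+            # stretches the latents in time before decoding.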
+ + with torch.no_grad(): + gpt_codes = self.gpt.generate( + cond_latents=gpt_cond_latent, + text_inputs=text_tokens, + input_tokens=None, + do_sample=do_sample, + top_p=top_p, + top_k=top_k, + temperature=temperature, + num_return_sequences=self.gpt_batch_size, + num_beams=num_beams, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + output_attentions=False, + **hf_generate_kwargs, + ) + expected_output_len = torch.tensor( + [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device + ) + + text_len = torch.tensor([text_tokens.shape[-1]], device=self.device) + gpt_latents = self.gpt( + text_tokens, + text_len, + gpt_codes, + expected_output_len, + cond_latents=gpt_cond_latent, + return_attentions=False, + return_latent=True, + ) + + if length_scale != 1.0: + gpt_latents = F.interpolate( + gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear" + ).transpose(1, 2) + + gpt_latents_list.append(gpt_latents.cpu()) + wavs.append(self.hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze()) + + return { + "wav": torch.cat(wavs, dim=0).numpy(), + "gpt_latents": torch.cat(gpt_latents_list, dim=1).numpy(), + "speaker_embedding": speaker_embedding, + } + + def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len): + """Handle chunk formatting in streaming mode""" + wav_chunk = wav_gen[:-overlap_len] + if wav_gen_prev is not None: + wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len] + if wav_overlap is not None: + # cross fade the overlap section + if overlap_len > len(wav_chunk): + # wav_chunk is smaller than overlap_len, pass on last wav_gen + if wav_gen_prev is not None: + wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) :] + else: + # not expecting will hit here as problem happens on last chunk + wav_chunk = wav_gen[-overlap_len:] + return wav_chunk, wav_gen, None + else: + crossfade_wav = wav_chunk[:overlap_len] + crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device) + wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device) + wav_chunk[:overlap_len] += crossfade_wav + + wav_overlap = wav_gen[-overlap_len:] + wav_gen_prev = wav_gen + return wav_chunk, wav_gen_prev, wav_overlap + + @torch.inference_mode() + def inference_stream( + self, + text, + language, + gpt_cond_latent, + speaker_embedding, + # Streaming + stream_chunk_size=20, + overlap_wav_len=1024, + # GPT inference + temperature=0.75, + length_penalty=1.0, + repetition_penalty=10.0, + top_k=50, + top_p=0.85, + do_sample=True, + speed=1.0, + enable_text_splitting=False, + **hf_generate_kwargs, + ): + language = language.split("-")[0] # remove the country code + length_scale = 1.0 / max(speed, 0.05) + gpt_cond_latent = gpt_cond_latent.to(self.device) + speaker_embedding = speaker_embedding.to(self.device) + if enable_text_splitting: + text = split_sentence(text, language, self.tokenizer.char_limits[language]) + else: + text = [text] + + for sent in text: + sent = sent.strip().lower() + text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device) + + assert ( + text_tokens.shape[-1] < self.args.gpt_max_text_tokens + ), " ❗ XTTS can only generate text with a maximum of 400 tokens." 
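+            # Streaming variant: the GPT generator yields one (token, latent) pair per step. Latents
+            # are buffered and decoded with the HiFi-GAN decoder every `stream_chunk_size` tokens
+            # (and once more at the end); `handle_chunks()` crossfades consecutive chunks over
+            # `overlap_wav_len` samples to avoid audible seams.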
+ + fake_inputs = self.gpt.compute_embeddings( + gpt_cond_latent.to(self.device), + text_tokens, + ) + gpt_generator = self.gpt.get_generator( + fake_inputs=fake_inputs, + top_k=top_k, + top_p=top_p, + temperature=temperature, + do_sample=do_sample, + num_beams=1, + num_return_sequences=1, + length_penalty=float(length_penalty), + repetition_penalty=float(repetition_penalty), + output_attentions=False, + output_hidden_states=True, + **hf_generate_kwargs, + ) + + last_tokens = [] + all_latents = [] + wav_gen_prev = None + wav_overlap = None + is_end = False + + while not is_end: + try: + x, latent = next(gpt_generator) + last_tokens += [x] + all_latents += [latent] + except StopIteration: + is_end = True + + if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size): + gpt_latents = torch.cat(all_latents, dim=0)[None, :] + if length_scale != 1.0: + gpt_latents = F.interpolate( + gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear" + ).transpose(1, 2) + wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device)) + wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks( + wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len + ) + last_tokens = [] + yield wav_chunk + + def forward(self): + raise NotImplementedError( + "XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training" + ) + + def eval_step(self): + raise NotImplementedError( + "XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training" + ) + + @staticmethod + def init_from_config(config: "XttsConfig", **kwargs): # pylint: disable=unused-argument + return Xtts(config) + + def eval(self): # pylint: disable=redefined-builtin + """Sets the model to evaluation mode. Overrides the default eval() method to also set the GPT model to eval mode.""" + self.gpt.init_gpt_for_inference() + super().eval() + + def get_compatible_checkpoint_state_dict(self, model_path): + checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"] + # remove xtts gpt trainer extra keys + ignore_keys = ["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"] + for key in list(checkpoint.keys()): + # check if it is from the coqui Trainer if so convert it + if key.startswith("xtts."): + new_key = key.replace("xtts.", "") + checkpoint[new_key] = checkpoint[key] + del checkpoint[key] + key = new_key + + # remove unused keys + if key.split(".")[0] in ignore_keys: + del checkpoint[key] + + return checkpoint + + def load_checkpoint( + self, + config, + checkpoint_dir=None, + checkpoint_path=None, + vocab_path=None, + eval=True, + strict=True, + use_deepspeed=False, + speaker_file_path=None, + ): + """ + Loads a checkpoint from disk and initializes the model's state and tokenizer. + + Args: + config (dict): The configuration dictionary for the model. + checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None. + checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None. + vocab_path (str, optional): The path to the vocabulary file. Defaults to None. + eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True. + strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True. 
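+            use_deepspeed (bool, optional): Whether to initialize the GPT for inference with DeepSpeed. Defaults to False.
+            speaker_file_path (str, optional): Path to the `speakers_xtts.pth` speaker file. If None and `checkpoint_dir` is given, it is looked for there. Defaults to None.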
+ + Returns: + None + """ + + model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth") + vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json") + + if speaker_file_path is None and checkpoint_dir is not None: + speaker_file_path = os.path.join(checkpoint_dir, "speakers_xtts.pth") + + self.language_manager = LanguageManager(config) + self.speaker_manager = None + if speaker_file_path is not None and os.path.exists(speaker_file_path): + self.speaker_manager = SpeakerManager(speaker_file_path) + + if os.path.exists(vocab_path): + self.tokenizer = VoiceBpeTokenizer(vocab_file=vocab_path) + + self.init_models() + + checkpoint = self.get_compatible_checkpoint_state_dict(model_path) + + # deal with v1 and v1.1. V1 has the init_gpt_for_inference keys, v1.1 do not + try: + self.load_state_dict(checkpoint, strict=strict) + except: + if eval: + self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache) + self.load_state_dict(checkpoint, strict=strict) + + if eval: + self.hifigan_decoder.eval() + self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed) + self.gpt.eval() + + def train_step(self): + raise NotImplementedError( + "XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training" + ) diff --git a/viXTTS/TTS/tts/utils/__init__.py b/viXTTS/TTS/tts/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/tts/utils/__pycache__/__init__.cpython-310.pyc b/viXTTS/TTS/tts/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5729f5002fc4be4f1b6e23fc0a5e8a36910c26c4 Binary files /dev/null and b/viXTTS/TTS/tts/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/utils/__pycache__/data.cpython-310.pyc b/viXTTS/TTS/tts/utils/__pycache__/data.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41deead3c803809cd4b67d3f6feb467752d3efb3 Binary files /dev/null and b/viXTTS/TTS/tts/utils/__pycache__/data.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/utils/__pycache__/helpers.cpython-310.pyc b/viXTTS/TTS/tts/utils/__pycache__/helpers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9cf35c8125f3483997439b0fc23dff390b4134c Binary files /dev/null and b/viXTTS/TTS/tts/utils/__pycache__/helpers.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/utils/__pycache__/languages.cpython-310.pyc b/viXTTS/TTS/tts/utils/__pycache__/languages.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df4f5071151c5233c6209d69a800832f28489ea1 Binary files /dev/null and b/viXTTS/TTS/tts/utils/__pycache__/languages.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/utils/__pycache__/managers.cpython-310.pyc b/viXTTS/TTS/tts/utils/__pycache__/managers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..326614317d9564a9abfc657b34ef563a2957a8e5 Binary files /dev/null and b/viXTTS/TTS/tts/utils/__pycache__/managers.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/utils/__pycache__/speakers.cpython-310.pyc b/viXTTS/TTS/tts/utils/__pycache__/speakers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..413f98cbf392ce40066726fde929bdedca721cf2 Binary files /dev/null and b/viXTTS/TTS/tts/utils/__pycache__/speakers.cpython-310.pyc differ diff --git 
a/viXTTS/TTS/tts/utils/__pycache__/ssim.cpython-310.pyc b/viXTTS/TTS/tts/utils/__pycache__/ssim.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb0682f755d6ba1686d43544ba2b7d5f8ba026ab Binary files /dev/null and b/viXTTS/TTS/tts/utils/__pycache__/ssim.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/utils/__pycache__/synthesis.cpython-310.pyc b/viXTTS/TTS/tts/utils/__pycache__/synthesis.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7aa5f3d4ec70da7859c5c0db668be9225f12860 Binary files /dev/null and b/viXTTS/TTS/tts/utils/__pycache__/synthesis.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/utils/__pycache__/visual.cpython-310.pyc b/viXTTS/TTS/tts/utils/__pycache__/visual.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1002ad3847ba1da9ffc563af92f50498d048fc77 Binary files /dev/null and b/viXTTS/TTS/tts/utils/__pycache__/visual.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/utils/assets/tortoise/tokenizer.json b/viXTTS/TTS/tts/utils/assets/tortoise/tokenizer.json new file mode 100644 index 0000000000000000000000000000000000000000..a128f273053e465a15c488e48d8106e0c8b0898e --- /dev/null +++ b/viXTTS/TTS/tts/utils/assets/tortoise/tokenizer.json @@ -0,0 +1 @@ +{"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[STOP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SPACE]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[STOP]":0,"[UNK]":1,"[SPACE]":2,"!":3,"'":4,"(":5,")":6,",":7,"-":8,".":9,"/":10,":":11,";":12,"?":13,"a":14,"b":15,"c":16,"d":17,"e":18,"f":19,"g":20,"h":21,"i":22,"j":23,"k":24,"l":25,"m":26,"n":27,"o":28,"p":29,"q":30,"r":31,"s":32,"t":33,"u":34,"v":35,"w":36,"x":37,"y":38,"z":39,"th":40,"in":41,"the":42,"an":43,"er":44,"ou":45,"re":46,"on":47,"at":48,"ed":49,"en":50,"to":51,"ing":52,"and":53,"is":54,"as":55,"al":56,"or":57,"of":58,"ar":59,"it":60,"es":61,"he":62,"st":63,"le":64,"om":65,"se":66,"be":67,"ad":68,"ow":69,"ly":70,"ch":71,"wh":72,"that":73,"you":74,"li":75,"ve":76,"ac":77,"ti":78,"ld":79,"me":80,"was":81,"gh":82,"id":83,"ll":84,"wi":85,"ent":86,"for":87,"ay":88,"ro":89,"ver":90,"ic":91,"her":92,"ke":93,"his":94,"no":95,"ut":96,"un":97,"ir":98,"lo":99,"we":100,"ri":101,"ha":102,"with":103,"ght":104,"out":105,"im":106,"ion":107,"all":108,"ab":109,"one":110,"ne":111,"ge":112,"ould":113,"ter":114,"mo":115,"had":116,"ce":117,"she":118,"go":119,"sh":120,"ur":121,"am":122,"so":123,"pe":124,"my":125,"de":126,"are":127,"but":128,"ome":129,"fr":130,"ther":131,"fe":132,"su":133,"do":134,"con":135,"te":136,"ain":137,"ere":138,"po":139,"if":140,"they":141,"us":142,"ag":143,"tr":144,"now":145,"oun":146,"this":147,"have":148,"not":149,"sa":150,"il":151,"up":152,"thing":153,"from":154,"ap":155,"him":156,"ack":157,"ation":158,"ant":159,"our":160,"op":161,"like":162,"ust":163,"ess":164,"bo":165,"ok":166,"ul":167,"ind":168,"ex":169,"com":170,"some":171,"there":172,"ers":173,"co":174,"res":175,"man":176,"ard":177,"pl":178,"wor":179,"way":180,"tion":181,"fo":182,"ca":183,"were":184,"by":185,"ate":186
,"pro":187,"ted":188,"ound":189,"own":190,"would":191,"ts":192,"what":193,"qu":194,"ally":195,"ight":196,"ck":197,"gr":198,"when":199,"ven":200,"can":201,"ough":202,"ine":203,"end":204,"per":205,"ous":206,"od":207,"ide":208,"know":209,"ty":210,"very":211,"si":212,"ak":213,"who":214,"about":215,"ill":216,"them":217,"est":218,"red":219,"ye":220,"could":221,"ong":222,"your":223,"their":224,"em":225,"just":226,"other":227,"into":228,"any":229,"whi":230,"um":231,"tw":232,"ast":233,"der":234,"did":235,"ie":236,"been":237,"ace":238,"ink":239,"ity":240,"back":241,"ting":242,"br":243,"more":244,"ake":245,"pp":246,"then":247,"sp":248,"el":249,"use":250,"bl":251,"said":252,"over":253,"get":254},"merges":["t h","i n","th e","a n","e r","o u","r e","o n","a t","e d","e n","t o","in g","an d","i s","a s","a l","o r","o f","a r","i t","e s","h e","s t","l e","o m","s e","b e","a d","o w","l y","c h","w h","th at","y ou","l i","v e","a c","t i","l d","m e","w as","g h","i d","l l","w i","en t","f or","a y","r o","v er","i c","h er","k e","h is","n o","u t","u n","i r","l o","w e","r i","h a","wi th","gh t","ou t","i m","i on","al l","a b","on e","n e","g e","ou ld","t er","m o","h ad","c e","s he","g o","s h","u r","a m","s o","p e","m y","d e","a re","b ut","om e","f r","the r","f e","s u","d o","c on","t e","a in","er e","p o","i f","the y","u s","a g","t r","n ow","ou n","th is","ha ve","no t","s a","i l","u p","th ing","fr om","a p","h im","ac k","at ion","an t","ou r","o p","li ke","u st","es s","b o","o k","u l","in d","e x","c om","s ome","the re","er s","c o","re s","m an","ar d","p l","w or","w ay","ti on","f o","c a","w ere","b y","at e","p ro","t ed","oun d","ow n","w ould","t s","wh at","q u","al ly","i ght","c k","g r","wh en","v en","c an","ou gh","in e","en d","p er","ou s","o d","id e","k now","t y","ver y","s i","a k","wh o","ab out","i ll","the m","es t","re d","y e","c ould","on g","you r","the ir","e m","j ust","o ther","in to","an y","wh i","u m","t w","as t","d er","d id","i e","be en","ac e","in k","it y","b ack","t ing","b r","mo re","a ke","p p","the n","s p","e l","u se","b l","sa id","o ver","ge t"]}} \ No newline at end of file diff --git a/viXTTS/TTS/tts/utils/data.py b/viXTTS/TTS/tts/utils/data.py new file mode 100644 index 0000000000000000000000000000000000000000..22e46b683adfc7f6c7c8a57fb5b697e422cd915c --- /dev/null +++ b/viXTTS/TTS/tts/utils/data.py @@ -0,0 +1,79 @@ +import bisect + +import numpy as np +import torch + + +def _pad_data(x, length): + _pad = 0 + assert x.ndim == 1 + return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=_pad) + + +def prepare_data(inputs): + max_len = max((len(x) for x in inputs)) + return np.stack([_pad_data(x, max_len) for x in inputs]) + + +def _pad_tensor(x, length): + _pad = 0.0 + assert x.ndim == 2 + x = np.pad(x, [[0, 0], [0, length - x.shape[1]]], mode="constant", constant_values=_pad) + return x + + +def prepare_tensor(inputs, out_steps): + max_len = max((x.shape[1] for x in inputs)) + remainder = max_len % out_steps + pad_len = max_len + (out_steps - remainder) if remainder > 0 else max_len + return np.stack([_pad_tensor(x, pad_len) for x in inputs]) + + +def _pad_stop_target(x: np.ndarray, length: int, pad_val=1) -> np.ndarray: + """Pad stop target array. + + Args: + x (np.ndarray): Stop target array. + length (int): Length after padding. + pad_val (int, optional): Padding value. Defaults to 1. + + Returns: + np.ndarray: Padded stop target array. 
+ """ + assert x.ndim == 1 + return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=pad_val) + + +def prepare_stop_target(inputs, out_steps): + """Pad row vectors with 1.""" + max_len = max((x.shape[0] for x in inputs)) + remainder = max_len % out_steps + pad_len = max_len + (out_steps - remainder) if remainder > 0 else max_len + return np.stack([_pad_stop_target(x, pad_len) for x in inputs]) + + +def pad_per_step(inputs, pad_len): + return np.pad(inputs, [[0, 0], [0, 0], [0, pad_len]], mode="constant", constant_values=0.0) + + +def get_length_balancer_weights(items: list, num_buckets=10): + # get all durations + audio_lengths = np.array([item["audio_length"] for item in items]) + # create the $num_buckets buckets classes based in the dataset max and min length + max_length = int(max(audio_lengths)) + min_length = int(min(audio_lengths)) + step = int((max_length - min_length) / num_buckets) + 1 + buckets_classes = [i + step for i in range(min_length, (max_length - step) + num_buckets + 1, step)] + # add each sample in their respective length bucket + buckets_names = np.array( + [buckets_classes[bisect.bisect_left(buckets_classes, item["audio_length"])] for item in items] + ) + # count and compute the weights_bucket for each sample + unique_buckets_names = np.unique(buckets_names).tolist() + bucket_ids = [unique_buckets_names.index(l) for l in buckets_names] + bucket_count = np.array([len(np.where(buckets_names == l)[0]) for l in unique_buckets_names]) + weight_bucket = 1.0 / bucket_count + dataset_samples_weight = np.array([weight_bucket[l] for l in bucket_ids]) + # normalize + dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) + return torch.from_numpy(dataset_samples_weight).float() diff --git a/viXTTS/TTS/tts/utils/fairseq.py b/viXTTS/TTS/tts/utils/fairseq.py new file mode 100644 index 0000000000000000000000000000000000000000..3d8eec2b4ee0d7b0c79e368616d4b75fb2e551d4 --- /dev/null +++ b/viXTTS/TTS/tts/utils/fairseq.py @@ -0,0 +1,48 @@ +import torch + + +def rehash_fairseq_vits_checkpoint(checkpoint_file): + chk = torch.load(checkpoint_file, map_location=torch.device("cpu"))["model"] + new_chk = {} + for k, v in chk.items(): + if "enc_p." in k: + new_chk[k.replace("enc_p.", "text_encoder.")] = v + elif "dec." in k: + new_chk[k.replace("dec.", "waveform_decoder.")] = v + elif "enc_q." in k: + new_chk[k.replace("enc_q.", "posterior_encoder.")] = v + elif "flow.flows.2." in k: + new_chk[k.replace("flow.flows.2.", "flow.flows.1.")] = v + elif "flow.flows.4." in k: + new_chk[k.replace("flow.flows.4.", "flow.flows.2.")] = v + elif "flow.flows.6." 
in k: + new_chk[k.replace("flow.flows.6.", "flow.flows.3.")] = v + elif "dp.flows.0.m" in k: + new_chk[k.replace("dp.flows.0.m", "duration_predictor.flows.0.translation")] = v + elif "dp.flows.0.logs" in k: + new_chk[k.replace("dp.flows.0.logs", "duration_predictor.flows.0.log_scale")] = v + elif "dp.flows.1" in k: + new_chk[k.replace("dp.flows.1", "duration_predictor.flows.1")] = v + elif "dp.flows.3" in k: + new_chk[k.replace("dp.flows.3", "duration_predictor.flows.2")] = v + elif "dp.flows.5" in k: + new_chk[k.replace("dp.flows.5", "duration_predictor.flows.3")] = v + elif "dp.flows.7" in k: + new_chk[k.replace("dp.flows.7", "duration_predictor.flows.4")] = v + elif "dp.post_flows.0.m" in k: + new_chk[k.replace("dp.post_flows.0.m", "duration_predictor.post_flows.0.translation")] = v + elif "dp.post_flows.0.logs" in k: + new_chk[k.replace("dp.post_flows.0.logs", "duration_predictor.post_flows.0.log_scale")] = v + elif "dp.post_flows.1" in k: + new_chk[k.replace("dp.post_flows.1", "duration_predictor.post_flows.1")] = v + elif "dp.post_flows.3" in k: + new_chk[k.replace("dp.post_flows.3", "duration_predictor.post_flows.2")] = v + elif "dp.post_flows.5" in k: + new_chk[k.replace("dp.post_flows.5", "duration_predictor.post_flows.3")] = v + elif "dp.post_flows.7" in k: + new_chk[k.replace("dp.post_flows.7", "duration_predictor.post_flows.4")] = v + elif "dp." in k: + new_chk[k.replace("dp.", "duration_predictor.")] = v + else: + new_chk[k] = v + return new_chk diff --git a/viXTTS/TTS/tts/utils/helpers.py b/viXTTS/TTS/tts/utils/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..7b37201f8410eb34300d8bb2b1a595d5c5cfc42f --- /dev/null +++ b/viXTTS/TTS/tts/utils/helpers.py @@ -0,0 +1,258 @@ +import numpy as np +import torch +from scipy.stats import betabinom +from torch.nn import functional as F + +try: + from TTS.tts.utils.monotonic_align.core import maximum_path_c + + CYTHON = True +except ModuleNotFoundError: + CYTHON = False + + +class StandardScaler: + """StandardScaler for mean-scale normalization with the given mean and scale values.""" + + def __init__(self, mean: np.ndarray = None, scale: np.ndarray = None) -> None: + self.mean_ = mean + self.scale_ = scale + + def set_stats(self, mean, scale): + self.mean_ = mean + self.scale_ = scale + + def reset_stats(self): + delattr(self, "mean_") + delattr(self, "scale_") + + def transform(self, X): + X = np.asarray(X) + X -= self.mean_ + X /= self.scale_ + return X + + def inverse_transform(self, X): + X = np.asarray(X) + X *= self.scale_ + X += self.mean_ + return X + + +# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1 +def sequence_mask(sequence_length, max_len=None): + """Create a sequence mask for filtering padding in a sequence tensor. + + Args: + sequence_length (torch.tensor): Sequence lengths. + max_len (int, Optional): Maximum sequence length. Defaults to None. + + Shapes: + - mask: :math:`[B, T_max]` + """ + if max_len is None: + max_len = sequence_length.max() + seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device) + # B x T_max + return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1) + + +def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_short=False): + """Segment each sample in a batch based on the provided segment indices + + Args: + x (torch.tensor): Input tensor. + segment_indices (torch.tensor): Segment indices. + segment_size (int): Expected output segment size. 
+ pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size. + """ + # pad the input tensor if it is shorter than the segment size + if pad_short and x.shape[-1] < segment_size: + x = torch.nn.functional.pad(x, (0, segment_size - x.size(2))) + + segments = torch.zeros_like(x[:, :, :segment_size]) + + for i in range(x.size(0)): + index_start = segment_indices[i] + index_end = index_start + segment_size + x_i = x[i] + if pad_short and index_end >= x.size(2): + # pad the sample if it is shorter than the segment size + x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2))) + segments[i] = x_i[:, index_start:index_end] + return segments + + +def rand_segments( + x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False +): + """Create random segments based on the input lengths. + + Args: + x (torch.tensor): Input tensor. + x_lengths (torch.tensor): Input lengths. + segment_size (int): Expected output segment size. + let_short_samples (bool): Allow shorter samples than the segment size. + pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size. + + Shapes: + - x: :math:`[B, C, T]` + - x_lengths: :math:`[B]` + """ + _x_lenghts = x_lengths.clone() + B, _, T = x.size() + if pad_short: + if T < segment_size: + x = torch.nn.functional.pad(x, (0, segment_size - T)) + T = segment_size + if _x_lenghts is None: + _x_lenghts = T + len_diff = _x_lenghts - segment_size + if let_short_samples: + _x_lenghts[len_diff < 0] = segment_size + len_diff = _x_lenghts - segment_size + else: + assert all( + len_diff > 0 + ), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}" + segment_indices = (torch.rand([B]).type_as(x) * (len_diff + 1)).long() + ret = segment(x, segment_indices, segment_size, pad_short=pad_short) + return ret, segment_indices + + +def average_over_durations(values, durs): + """Average values over durations. 
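+ Typically used to pool frame-level features (e.g. pitch or energy) over token durations so that they align with the encoder time axis; the exact usage depends on the calling model.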
+ + Shapes: + - values: :math:`[B, 1, T_de]` + - durs: :math:`[B, T_en]` + - avg: :math:`[B, 1, T_en]` + """ + durs_cums_ends = torch.cumsum(durs, dim=1).long() + durs_cums_starts = torch.nn.functional.pad(durs_cums_ends[:, :-1], (1, 0)) + values_nonzero_cums = torch.nn.functional.pad(torch.cumsum(values != 0.0, dim=2), (1, 0)) + values_cums = torch.nn.functional.pad(torch.cumsum(values, dim=2), (1, 0)) + + bs, l = durs_cums_ends.size() + n_formants = values.size(1) + dcs = durs_cums_starts[:, None, :].expand(bs, n_formants, l) + dce = durs_cums_ends[:, None, :].expand(bs, n_formants, l) + + values_sums = (torch.gather(values_cums, 2, dce) - torch.gather(values_cums, 2, dcs)).float() + values_nelems = (torch.gather(values_nonzero_cums, 2, dce) - torch.gather(values_nonzero_cums, 2, dcs)).float() + + avg = torch.where(values_nelems == 0.0, values_nelems, values_sums / values_nelems) + return avg + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def generate_path(duration, mask): + """ + Shapes: + - duration: :math:`[B, T_en]` + - mask: :math:'[B, T_en, T_de]` + - path: :math:`[B, T_en, T_de]` + """ + b, t_x, t_y = mask.shape + cum_duration = torch.cumsum(duration, 1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path * mask + return path + + +def maximum_path(value, mask): + if CYTHON: + return maximum_path_cython(value, mask) + return maximum_path_numpy(value, mask) + + +def maximum_path_cython(value, mask): + """Cython optimised version. + Shapes: + - value: :math:`[B, T_en, T_de]` + - mask: :math:`[B, T_en, T_de]` + """ + value = value * mask + device = value.device + dtype = value.dtype + value = value.data.cpu().numpy().astype(np.float32) + path = np.zeros_like(value).astype(np.int32) + mask = mask.data.cpu().numpy() + + t_x_max = mask.sum(1)[:, 0].astype(np.int32) + t_y_max = mask.sum(2)[:, 0].astype(np.int32) + maximum_path_c(path, value, t_x_max, t_y_max) + return torch.from_numpy(path).to(device=device, dtype=dtype) + + +def maximum_path_numpy(value, mask, max_neg_val=None): + """ + Monotonic alignment search algorithm + Numpy-friendly version. It's about 4 times faster than torch version. 
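+ Used by ``maximum_path`` as a fallback when the Cython extension (``TTS.tts.utils.monotonic_align.core``) is not available.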
+ value: [b, t_x, t_y] + mask: [b, t_x, t_y] + """ + if max_neg_val is None: + max_neg_val = -np.inf # Patch for Sphinx complaint + value = value * mask + + device = value.device + dtype = value.dtype + value = value.cpu().detach().numpy() + mask = mask.cpu().detach().numpy().astype(bool) + + b, t_x, t_y = value.shape + direction = np.zeros(value.shape, dtype=np.int64) + v = np.zeros((b, t_x), dtype=np.float32) + x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1) + for j in range(t_y): + v0 = np.pad(v, [[0, 0], [1, 0]], mode="constant", constant_values=max_neg_val)[:, :-1] + v1 = v + max_mask = v1 >= v0 + v_max = np.where(max_mask, v1, v0) + direction[:, :, j] = max_mask + + index_mask = x_range <= j + v = np.where(index_mask, v_max + value[:, :, j], max_neg_val) + direction = np.where(mask, direction, 1) + + path = np.zeros(value.shape, dtype=np.float32) + index = mask[:, :, 0].sum(1).astype(np.int64) - 1 + index_range = np.arange(b) + for j in reversed(range(t_y)): + path[index_range, index, j] = 1 + index = index + direction[index_range, index, j] - 1 + path = path * mask.astype(np.float32) + path = torch.from_numpy(path).to(device=device, dtype=dtype) + return path + + +def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=1.0): + P, M = phoneme_count, mel_count + x = np.arange(0, P) + mel_text_probs = [] + for i in range(1, M + 1): + a, b = scaling_factor * i, scaling_factor * (M + 1 - i) + rv = betabinom(P, a, b) + mel_i_prob = rv.pmf(x) + mel_text_probs.append(mel_i_prob) + return np.array(mel_text_probs) + + +def compute_attn_prior(x_len, y_len, scaling_factor=1.0): + """Compute attention priors for the alignment network.""" + attn_prior = beta_binomial_prior_distribution( + x_len, + y_len, + scaling_factor, + ) + return attn_prior # [y_len, x_len] diff --git a/viXTTS/TTS/tts/utils/languages.py b/viXTTS/TTS/tts/utils/languages.py new file mode 100644 index 0000000000000000000000000000000000000000..1e1836b32ce2010ad55a0253849f2e59c61dad82 --- /dev/null +++ b/viXTTS/TTS/tts/utils/languages.py @@ -0,0 +1,125 @@ +import os +from typing import Any, Dict, List + +import fsspec +import numpy as np +import torch +from coqpit import Coqpit + +from TTS.config import check_config_and_model_args +from TTS.tts.utils.managers import BaseIDManager + + +class LanguageManager(BaseIDManager): + """Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information + in a way that can be queried by language. + + Args: + language_ids_file_path (str, optional): Path to the metafile that maps language names to ids used by + TTS models. Defaults to "". + config (Coqpit, optional): Coqpit config that contains the language information in the datasets filed. + Defaults to None. + + Examples: + >>> manager = LanguageManager(language_ids_file_path=language_ids_file_path) + >>> language_id_mapper = manager.language_ids + """ + + def __init__( + self, + language_ids_file_path: str = "", + config: Coqpit = None, + ): + super().__init__(id_file_path=language_ids_file_path) + + if config: + self.set_language_ids_from_config(config) + + @property + def num_languages(self) -> int: + return len(list(self.name_to_id.keys())) + + @property + def language_names(self) -> List: + return list(self.name_to_id.keys()) + + @staticmethod + def parse_language_ids_from_config(c: Coqpit) -> Dict: + """Set language id from config. + + Args: + c (Coqpit): Config + + Returns: + Tuple[Dict, int]: Language ID mapping and the number of languages. 
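+
+ Example (illustrative sketch; assumes ``c.datasets`` holds two datasets tagged with the languages "en" and "vi"):
+     >>> LanguageManager.parse_language_ids_from_config(c)
+     {'en': 0, 'vi': 1}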
+ """ + languages = set({}) + for dataset in c.datasets: + if "language" in dataset: + languages.add(dataset["language"]) + else: + raise ValueError(f"Dataset {dataset['name']} has no language specified.") + return {name: i for i, name in enumerate(sorted(list(languages)))} + + def set_language_ids_from_config(self, c: Coqpit) -> None: + """Set language IDs from config samples. + + Args: + c (Coqpit): Config. + """ + self.name_to_id = self.parse_language_ids_from_config(c) + + @staticmethod + def parse_ids_from_data(items: List, parse_key: str) -> Any: + raise NotImplementedError + + def set_ids_from_data(self, items: List, parse_key: str) -> Any: + raise NotImplementedError + + def save_ids_to_file(self, file_path: str) -> None: + """Save language IDs to a json file. + + Args: + file_path (str): Path to the output file. + """ + self._save_json(file_path, self.name_to_id) + + @staticmethod + def init_from_config(config: Coqpit) -> "LanguageManager": + """Initialize the language manager from a Coqpit config. + + Args: + config (Coqpit): Coqpit config. + """ + language_manager = None + if check_config_and_model_args(config, "use_language_embedding", True): + if config.get("language_ids_file", None): + language_manager = LanguageManager(language_ids_file_path=config.language_ids_file) + language_manager = LanguageManager(config=config) + return language_manager + + +def _set_file_path(path): + """Find the language_ids.json under the given path or the above it. + Intended to band aid the different paths returned in restored and continued training.""" + path_restore = os.path.join(os.path.dirname(path), "language_ids.json") + path_continue = os.path.join(path, "language_ids.json") + fs = fsspec.get_mapper(path).fs + if fs.exists(path_restore): + return path_restore + if fs.exists(path_continue): + return path_continue + return None + + +def get_language_balancer_weights(items: list): + language_names = np.array([item["language"] for item in items]) + unique_language_names = np.unique(language_names).tolist() + language_ids = [unique_language_names.index(l) for l in language_names] + language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names]) + weight_language = 1.0 / language_count + # get weight for each sample + dataset_samples_weight = np.array([weight_language[l] for l in language_ids]) + # normalize + dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) + return torch.from_numpy(dataset_samples_weight).float() diff --git a/viXTTS/TTS/tts/utils/managers.py b/viXTTS/TTS/tts/utils/managers.py new file mode 100644 index 0000000000000000000000000000000000000000..1f94c5332df1e2774955eb263c3b688c5ad6e827 --- /dev/null +++ b/viXTTS/TTS/tts/utils/managers.py @@ -0,0 +1,383 @@ +import json +import random +from typing import Any, Dict, List, Tuple, Union + +import fsspec +import numpy as np +import torch + +from TTS.config import load_config +from TTS.encoder.utils.generic_utils import setup_encoder_model +from TTS.utils.audio import AudioProcessor + + +def load_file(path: str): + if path.endswith(".json"): + with fsspec.open(path, "r") as f: + return json.load(f) + elif path.endswith(".pth"): + with fsspec.open(path, "rb") as f: + return torch.load(f, map_location="cpu") + else: + raise ValueError("Unsupported file type") + + +def save_file(obj: Any, path: str): + if path.endswith(".json"): + with fsspec.open(path, "w") as f: + json.dump(obj, f, indent=4) + elif path.endswith(".pth"): + with fsspec.open(path, "wb") as f: + 
torch.save(obj, f) + else: + raise ValueError("Unsupported file type") + + +class BaseIDManager: + """Base `ID` Manager class. Every new `ID` manager must inherit this. + It defines common `ID` manager specific functions. + """ + + def __init__(self, id_file_path: str = ""): + self.name_to_id = {} + + if id_file_path: + self.load_ids_from_file(id_file_path) + + @staticmethod + def _load_json(json_file_path: str) -> Dict: + with fsspec.open(json_file_path, "r") as f: + return json.load(f) + + @staticmethod + def _save_json(json_file_path: str, data: dict) -> None: + with fsspec.open(json_file_path, "w") as f: + json.dump(data, f, indent=4) + + def set_ids_from_data(self, items: List, parse_key: str) -> None: + """Set IDs from data samples. + + Args: + items (List): Data sampled returned by `load_tts_samples()`. + """ + self.name_to_id = self.parse_ids_from_data(items, parse_key=parse_key) + + def load_ids_from_file(self, file_path: str) -> None: + """Set IDs from a file. + + Args: + file_path (str): Path to the file. + """ + self.name_to_id = load_file(file_path) + + def save_ids_to_file(self, file_path: str) -> None: + """Save IDs to a json file. + + Args: + file_path (str): Path to the output file. + """ + save_file(self.name_to_id, file_path) + + def get_random_id(self) -> Any: + """Get a random embedding. + + Args: + + Returns: + np.ndarray: embedding. + """ + if self.name_to_id: + return self.name_to_id[random.choices(list(self.name_to_id.keys()))[0]] + + return None + + @staticmethod + def parse_ids_from_data(items: List, parse_key: str) -> Tuple[Dict]: + """Parse IDs from data samples retured by `load_tts_samples()`. + + Args: + items (list): Data sampled returned by `load_tts_samples()`. + parse_key (str): The key to being used to parse the data. + Returns: + Tuple[Dict]: speaker IDs. + """ + classes = sorted({item[parse_key] for item in items}) + ids = {name: i for i, name in enumerate(classes)} + return ids + + +class EmbeddingManager(BaseIDManager): + """Base `Embedding` Manager class. Every new `Embedding` manager must inherit this. + It defines common `Embedding` manager specific functions. + + It expects embeddings files in the following format: + + :: + + { + 'audio_file_key':{ + 'name': 'category_name', + 'embedding'[] + }, + ... + } + + `audio_file_key` is a unique key to the audio file in the dataset. It can be the path to the file or any other unique key. + `embedding` is the embedding vector of the audio file. + `name` can be name of the speaker of the audio file. + """ + + def __init__( + self, + embedding_file_path: Union[str, List[str]] = "", + id_file_path: str = "", + encoder_model_path: str = "", + encoder_config_path: str = "", + use_cuda: bool = False, + ): + super().__init__(id_file_path=id_file_path) + + self.embeddings = {} + self.embeddings_by_names = {} + self.clip_ids = [] + self.encoder = None + self.encoder_ap = None + self.use_cuda = use_cuda + + if embedding_file_path: + if isinstance(embedding_file_path, list): + self.load_embeddings_from_list_of_files(embedding_file_path) + else: + self.load_embeddings_from_file(embedding_file_path) + + if encoder_model_path and encoder_config_path: + self.init_encoder(encoder_model_path, encoder_config_path, use_cuda) + + @property + def num_embeddings(self): + """Get number of embeddings.""" + return len(self.embeddings) + + @property + def num_names(self): + """Get number of embeddings.""" + return len(self.embeddings_by_names) + + @property + def embedding_dim(self): + """Dimensionality of embeddings. 
If embeddings are not loaded, returns zero.""" + if self.embeddings: + return len(self.embeddings[list(self.embeddings.keys())[0]]["embedding"]) + return 0 + + @property + def embedding_names(self): + """Get embedding names.""" + return list(self.embeddings_by_names.keys()) + + def save_embeddings_to_file(self, file_path: str) -> None: + """Save embeddings to a json file. + + Args: + file_path (str): Path to the output file. + """ + save_file(self.embeddings, file_path) + + @staticmethod + def read_embeddings_from_file(file_path: str): + """Load embeddings from a json file. + + Args: + file_path (str): Path to the file. + """ + embeddings = load_file(file_path) + speakers = sorted({x["name"] for x in embeddings.values()}) + name_to_id = {name: i for i, name in enumerate(speakers)} + clip_ids = list(set(sorted(clip_name for clip_name in embeddings.keys()))) + # cache embeddings_by_names for fast inference using a bigger speakers.json + embeddings_by_names = {} + for x in embeddings.values(): + if x["name"] not in embeddings_by_names.keys(): + embeddings_by_names[x["name"]] = [x["embedding"]] + else: + embeddings_by_names[x["name"]].append(x["embedding"]) + return name_to_id, clip_ids, embeddings, embeddings_by_names + + def load_embeddings_from_file(self, file_path: str) -> None: + """Load embeddings from a json file. + + Args: + file_path (str): Path to the target json file. + """ + self.name_to_id, self.clip_ids, self.embeddings, self.embeddings_by_names = self.read_embeddings_from_file( + file_path + ) + + def load_embeddings_from_list_of_files(self, file_paths: List[str]) -> None: + """Load embeddings from a list of json files and don't allow duplicate keys. + + Args: + file_paths (List[str]): List of paths to the target json files. + """ + self.name_to_id = {} + self.clip_ids = [] + self.embeddings_by_names = {} + self.embeddings = {} + for file_path in file_paths: + ids, clip_ids, embeddings, embeddings_by_names = self.read_embeddings_from_file(file_path) + # check colliding keys + duplicates = set(self.embeddings.keys()) & set(embeddings.keys()) + if duplicates: + raise ValueError(f" [!] Duplicate embedding names <{duplicates}> in {file_path}") + # store values + self.name_to_id.update(ids) + self.clip_ids.extend(clip_ids) + self.embeddings_by_names.update(embeddings_by_names) + self.embeddings.update(embeddings) + + # reset name_to_id to get the right speaker ids + self.name_to_id = {name: i for i, name in enumerate(self.name_to_id)} + + def get_embedding_by_clip(self, clip_idx: str) -> List: + """Get embedding by clip ID. + + Args: + clip_idx (str): Target clip ID. + + Returns: + List: embedding as a list. + """ + return self.embeddings[clip_idx]["embedding"] + + def get_embeddings_by_name(self, idx: str) -> List[List]: + """Get all embeddings of a speaker. + + Args: + idx (str): Target name. + + Returns: + List[List]: all the embeddings of the given speaker. + """ + return self.embeddings_by_names[idx] + + def get_embeddings_by_names(self) -> Dict: + """Get all embeddings by names. + + Returns: + Dict: all the embeddings of each speaker. + """ + embeddings_by_names = {} + for x in self.embeddings.values(): + if x["name"] not in embeddings_by_names.keys(): + embeddings_by_names[x["name"]] = [x["embedding"]] + else: + embeddings_by_names[x["name"]].append(x["embedding"]) + return embeddings_by_names + + def get_mean_embedding(self, idx: str, num_samples: int = None, randomize: bool = False) -> np.ndarray: + """Get mean embedding of a idx. + + Args: + idx (str): Target name. 
+ num_samples (int, optional): Number of samples to be averaged. Defaults to None. + randomize (bool, optional): Pick random `num_samples` of embeddings. Defaults to False. + + Returns: + np.ndarray: Mean embedding. + """ + embeddings = self.get_embeddings_by_name(idx) + if num_samples is None: + embeddings = np.stack(embeddings).mean(0) + else: + assert len(embeddings) >= num_samples, f" [!] {idx} has number of samples < {num_samples}" + if randomize: + embeddings = np.stack(random.choices(embeddings, k=num_samples)).mean(0) + else: + embeddings = np.stack(embeddings[:num_samples]).mean(0) + return embeddings + + def get_random_embedding(self) -> Any: + """Get a random embedding. + + Args: + + Returns: + np.ndarray: embedding. + """ + if self.embeddings: + return self.embeddings[random.choices(list(self.embeddings.keys()))[0]]["embedding"] + + return None + + def get_clips(self) -> List: + return sorted(self.embeddings.keys()) + + def init_encoder(self, model_path: str, config_path: str, use_cuda=False) -> None: + """Initialize a speaker encoder model. + + Args: + model_path (str): Model file path. + config_path (str): Model config file path. + use_cuda (bool, optional): Use CUDA. Defaults to False. + """ + self.use_cuda = use_cuda + self.encoder_config = load_config(config_path) + self.encoder = setup_encoder_model(self.encoder_config) + self.encoder_criterion = self.encoder.load_checkpoint( + self.encoder_config, model_path, eval=True, use_cuda=use_cuda, cache=True + ) + self.encoder_ap = AudioProcessor(**self.encoder_config.audio) + + def compute_embedding_from_clip(self, wav_file: Union[str, List[str]]) -> list: + """Compute a embedding from a given audio file. + + Args: + wav_file (Union[str, List[str]]): Target file path. + + Returns: + list: Computed embedding. + """ + + def _compute(wav_file: str): + waveform = self.encoder_ap.load_wav(wav_file, sr=self.encoder_ap.sample_rate) + if not self.encoder_config.model_params.get("use_torch_spec", False): + m_input = self.encoder_ap.melspectrogram(waveform) + m_input = torch.from_numpy(m_input) + else: + m_input = torch.from_numpy(waveform) + + if self.use_cuda: + m_input = m_input.cuda() + m_input = m_input.unsqueeze(0) + embedding = self.encoder.compute_embedding(m_input) + return embedding + + if isinstance(wav_file, list): + # compute the mean embedding + embeddings = None + for wf in wav_file: + embedding = _compute(wf) + if embeddings is None: + embeddings = embedding + else: + embeddings += embedding + return (embeddings / len(wav_file))[0].tolist() + embedding = _compute(wav_file) + return embedding[0].tolist() + + def compute_embeddings(self, feats: Union[torch.Tensor, np.ndarray]) -> List: + """Compute embedding from features. + + Args: + feats (Union[torch.Tensor, np.ndarray]): Input features. + + Returns: + List: computed embedding. + """ + if isinstance(feats, np.ndarray): + feats = torch.from_numpy(feats) + if feats.ndim == 2: + feats = feats.unsqueeze(0) + if self.use_cuda: + feats = feats.cuda() + return self.encoder.compute_embedding(feats) diff --git a/viXTTS/TTS/tts/utils/measures.py b/viXTTS/TTS/tts/utils/measures.py new file mode 100644 index 0000000000000000000000000000000000000000..90e862e1190bdb8443933580b3ff47321f70cecd --- /dev/null +++ b/viXTTS/TTS/tts/utils/measures.py @@ -0,0 +1,15 @@ +def alignment_diagonal_score(alignments, binary=False): + """ + Compute how diagonal alignment predictions are. 
It is useful + to measure the alignment consistency of a model + Args: + alignments (torch.Tensor): batch of alignments. + binary (bool): if True, ignore scores and consider attention + as a binary mask. + Shape: + - alignments : :math:`[B, T_de, T_en]` + """ + maxs = alignments.max(dim=1)[0] + if binary: + maxs[maxs > 0] = 1 + return maxs.mean(dim=1).mean(dim=0).item() diff --git a/viXTTS/TTS/tts/utils/monotonic_align/__init__.py b/viXTTS/TTS/tts/utils/monotonic_align/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/tts/utils/monotonic_align/__pycache__/__init__.cpython-310.pyc b/viXTTS/TTS/tts/utils/monotonic_align/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb1ac9979faaa93ac17774a3f0b00d49c61d6b6e Binary files /dev/null and b/viXTTS/TTS/tts/utils/monotonic_align/__pycache__/__init__.cpython-310.pyc differ diff --git a/viXTTS/TTS/tts/utils/monotonic_align/core.pyx b/viXTTS/TTS/tts/utils/monotonic_align/core.pyx new file mode 100644 index 0000000000000000000000000000000000000000..091fcc3a50a51f3d3fee47a70825260757e6d885 --- /dev/null +++ b/viXTTS/TTS/tts/utils/monotonic_align/core.pyx @@ -0,0 +1,47 @@ +import numpy as np + +cimport cython +cimport numpy as np + +from cython.parallel import prange + + +@cython.boundscheck(False) +@cython.wraparound(False) +cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil: + cdef int x + cdef int y + cdef float v_prev + cdef float v_cur + cdef float tmp + cdef int index = t_x - 1 + + for y in range(t_y): + for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + if x == y: + v_cur = max_neg_val + else: + v_cur = value[x, y-1] + if x == 0: + if y == 0: + v_prev = 0. 
+ else: + v_prev = max_neg_val + else: + v_prev = value[x-1, y-1] + value[x, y] = max(v_cur, v_prev) + value[x, y] + + for y in range(t_y - 1, -1, -1): + path[index, y] = 1 + if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): + index = index - 1 + + +@cython.boundscheck(False) +@cython.wraparound(False) +cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: + cdef int b = values.shape[0] + + cdef int i + for i in prange(b, nogil=True): + maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) diff --git a/viXTTS/TTS/tts/utils/monotonic_align/setup.py b/viXTTS/TTS/tts/utils/monotonic_align/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..f22bc6a35a5a04c9e6d7b82040973722c9b770c9 --- /dev/null +++ b/viXTTS/TTS/tts/utils/monotonic_align/setup.py @@ -0,0 +1,7 @@ +# from distutils.core import setup +# from Cython.Build import cythonize +# import numpy + +# setup(name='monotonic_align', +# ext_modules=cythonize("core.pyx"), +# include_dirs=[numpy.get_include()]) diff --git a/viXTTS/TTS/tts/utils/speakers.py b/viXTTS/TTS/tts/utils/speakers.py new file mode 100644 index 0000000000000000000000000000000000000000..e49695268d6a4a3ac8e5f41df8954f07b16b5566 --- /dev/null +++ b/viXTTS/TTS/tts/utils/speakers.py @@ -0,0 +1,222 @@ +import json +import os +from typing import Any, Dict, List, Union + +import fsspec +import numpy as np +import torch +from coqpit import Coqpit + +from TTS.config import get_from_config_or_model_args_with_default +from TTS.tts.utils.managers import EmbeddingManager + + +class SpeakerManager(EmbeddingManager): + """Manage the speakers for multi-speaker 🐸TTS models. Load a datafile and parse the information + in a way that can be queried by speaker or clip. + + There are 3 different scenarios considered: + + 1. Models using speaker embedding layers. The datafile only maps speaker names to ids used by the embedding layer. + 2. Models using d-vectors. The datafile includes a dictionary in the following format. + + :: + + { + 'clip_name.wav':{ + 'name': 'speakerA', + 'embedding'[] + }, + ... + } + + + 3. Computing the d-vectors by the speaker encoder. It loads the speaker encoder model and + computes the d-vectors for a given clip or speaker. + + Args: + d_vectors_file_path (str, optional): Path to the metafile including x vectors. Defaults to "". + speaker_id_file_path (str, optional): Path to the metafile that maps speaker names to ids used by + TTS models. Defaults to "". + encoder_model_path (str, optional): Path to the speaker encoder model file. Defaults to "". + encoder_config_path (str, optional): Path to the spealer encoder config file. Defaults to "". 
+ + Examples: + >>> # load audio processor and speaker encoder + >>> ap = AudioProcessor(**config.audio) + >>> manager = SpeakerManager(encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path) + >>> # load a sample audio and compute embedding + >>> waveform = ap.load_wav(sample_wav_path) + >>> mel = ap.melspectrogram(waveform) + >>> d_vector = manager.compute_embeddings(mel.T) + """ + + def __init__( + self, + data_items: List[List[Any]] = None, + d_vectors_file_path: str = "", + speaker_id_file_path: str = "", + encoder_model_path: str = "", + encoder_config_path: str = "", + use_cuda: bool = False, + ): + super().__init__( + embedding_file_path=d_vectors_file_path, + id_file_path=speaker_id_file_path, + encoder_model_path=encoder_model_path, + encoder_config_path=encoder_config_path, + use_cuda=use_cuda, + ) + + if data_items: + self.set_ids_from_data(data_items, parse_key="speaker_name") + + @property + def num_speakers(self): + return len(self.name_to_id) + + @property + def speaker_names(self): + return list(self.name_to_id.keys()) + + def get_speakers(self) -> List: + return self.name_to_id + + @staticmethod + def init_from_config(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "SpeakerManager": + """Initialize a speaker manager from config + + Args: + config (Coqpit): Config object. + samples (Union[List[List], List[Dict]], optional): List of data samples to parse out the speaker names. + Defaults to None. + + Returns: + SpeakerEncoder: Speaker encoder object. + """ + speaker_manager = None + if get_from_config_or_model_args_with_default(config, "use_speaker_embedding", False): + if samples: + speaker_manager = SpeakerManager(data_items=samples) + if get_from_config_or_model_args_with_default(config, "speaker_file", None): + speaker_manager = SpeakerManager( + speaker_id_file_path=get_from_config_or_model_args_with_default(config, "speaker_file", None) + ) + if get_from_config_or_model_args_with_default(config, "speakers_file", None): + speaker_manager = SpeakerManager( + speaker_id_file_path=get_from_config_or_model_args_with_default(config, "speakers_file", None) + ) + + if get_from_config_or_model_args_with_default(config, "use_d_vector_file", False): + speaker_manager = SpeakerManager() + if get_from_config_or_model_args_with_default(config, "d_vector_file", None): + speaker_manager = SpeakerManager( + d_vectors_file_path=get_from_config_or_model_args_with_default(config, "d_vector_file", None) + ) + return speaker_manager + + +def _set_file_path(path): + """Find the speakers.json under the given path or the above it. + Intended to band aid the different paths returned in restored and continued training.""" + path_restore = os.path.join(os.path.dirname(path), "speakers.json") + path_continue = os.path.join(path, "speakers.json") + fs = fsspec.get_mapper(path).fs + if fs.exists(path_restore): + return path_restore + if fs.exists(path_continue): + return path_continue + raise FileNotFoundError(f" [!] 
`speakers.json` not found in {path}") + + +def load_speaker_mapping(out_path): + """Loads speaker mapping if already present.""" + if os.path.splitext(out_path)[1] == ".json": + json_file = out_path + else: + json_file = _set_file_path(out_path) + with fsspec.open(json_file, "r") as f: + return json.load(f) + + +def save_speaker_mapping(out_path, speaker_mapping): + """Saves speaker mapping if not yet present.""" + if out_path is not None: + speakers_json_path = _set_file_path(out_path) + with fsspec.open(speakers_json_path, "w") as f: + json.dump(speaker_mapping, f, indent=4) + + +def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None, out_path: str = None) -> SpeakerManager: + """Initiate a `SpeakerManager` instance by the provided config. + + Args: + c (Coqpit): Model configuration. + restore_path (str): Path to a previous training folder. + data (List): Data samples used in training to infer speakers from. It must be provided if speaker embedding + layers is used. Defaults to None. + out_path (str, optional): Save the generated speaker IDs to a output path. Defaults to None. + + Returns: + SpeakerManager: initialized and ready to use instance. + """ + speaker_manager = SpeakerManager() + if c.use_speaker_embedding: + if data is not None: + speaker_manager.set_ids_from_data(data, parse_key="speaker_name") + if restore_path: + speakers_file = _set_file_path(restore_path) + # restoring speaker manager from a previous run. + if c.use_d_vector_file: + # restore speaker manager with the embedding file + if not os.path.exists(speakers_file): + print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.d_vector_file") + if not os.path.exists(c.d_vector_file): + raise RuntimeError( + "You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.d_vector_file" + ) + speaker_manager.load_embeddings_from_file(c.d_vector_file) + speaker_manager.load_embeddings_from_file(speakers_file) + elif not c.use_d_vector_file: # restor speaker manager with speaker ID file. + speaker_ids_from_data = speaker_manager.name_to_id + speaker_manager.load_ids_from_file(speakers_file) + assert all( + speaker in speaker_manager.name_to_id for speaker in speaker_ids_from_data + ), " [!] You cannot introduce new speakers to a pre-trained model." + elif c.use_d_vector_file and c.d_vector_file: + # new speaker manager with external speaker embeddings. + speaker_manager.load_embeddings_from_file(c.d_vector_file) + elif c.use_d_vector_file and not c.d_vector_file: + raise "use_d_vector_file is True, so you need pass a external speaker embedding file." + elif c.use_speaker_embedding and "speakers_file" in c and c.speakers_file: + # new speaker manager with speaker IDs file. 
+ speaker_manager.load_ids_from_file(c.speakers_file) + + if speaker_manager.num_speakers > 0: + print( + " > Speaker manager is loaded with {} speakers: {}".format( + speaker_manager.num_speakers, ", ".join(speaker_manager.name_to_id) + ) + ) + + # save file if path is defined + if out_path: + out_file_path = os.path.join(out_path, "speakers.json") + print(f" > Saving `speakers.json` to {out_file_path}.") + if c.use_d_vector_file and c.d_vector_file: + speaker_manager.save_embeddings_to_file(out_file_path) + else: + speaker_manager.save_ids_to_file(out_file_path) + return speaker_manager + + +def get_speaker_balancer_weights(items: list): + speaker_names = np.array([item["speaker_name"] for item in items]) + unique_speaker_names = np.unique(speaker_names).tolist() + speaker_ids = [unique_speaker_names.index(l) for l in speaker_names] + speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names]) + weight_speaker = 1.0 / speaker_count + dataset_samples_weight = np.array([weight_speaker[l] for l in speaker_ids]) + # normalize + dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) + return torch.from_numpy(dataset_samples_weight).float() diff --git a/viXTTS/TTS/tts/utils/ssim.py b/viXTTS/TTS/tts/utils/ssim.py new file mode 100644 index 0000000000000000000000000000000000000000..4bc3befc5bd3fb154cd48b4458184a4d8f3dca78 --- /dev/null +++ b/viXTTS/TTS/tts/utils/ssim.py @@ -0,0 +1,383 @@ +# Adopted from https://github.com/photosynthesis-team/piq + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch.nn.modules.loss import _Loss + + +def _reduce(x: torch.Tensor, reduction: str = "mean") -> torch.Tensor: + r"""Reduce input in batch dimension if needed. + Args: + x: Tensor with shape (N, *). + reduction: Specifies the reduction type: + ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'`` + """ + if reduction == "none": + return x + if reduction == "mean": + return x.mean(dim=0) + if reduction == "sum": + return x.sum(dim=0) + raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}") + + +def _validate_input( + tensors: List[torch.Tensor], + dim_range: Tuple[int, int] = (0, -1), + data_range: Tuple[float, float] = (0.0, -1.0), + # size_dim_range: Tuple[float, float] = (0., -1.), + size_range: Optional[Tuple[int, int]] = None, +) -> None: + r"""Check that input(-s) satisfies the requirements + Args: + tensors: Tensors to check + dim_range: Allowed number of dimensions. (min, max) + data_range: Allowed range of values in tensors. (min, max) + size_range: Dimensions to include in size comparison. 
(start_dim, end_dim + 1) + """ + + if not __debug__: + return + + x = tensors[0] + + for t in tensors: + assert torch.is_tensor(t), f"Expected torch.Tensor, got {type(t)}" + assert t.device == x.device, f"Expected tensors to be on {x.device}, got {t.device}" + + if size_range is None: + assert t.size() == x.size(), f"Expected tensors with same size, got {t.size()} and {x.size()}" + else: + assert ( + t.size()[size_range[0] : size_range[1]] == x.size()[size_range[0] : size_range[1]] + ), f"Expected tensors with same size at given dimensions, got {t.size()} and {x.size()}" + + if dim_range[0] == dim_range[1]: + assert t.dim() == dim_range[0], f"Expected number of dimensions to be {dim_range[0]}, got {t.dim()}" + elif dim_range[0] < dim_range[1]: + assert ( + dim_range[0] <= t.dim() <= dim_range[1] + ), f"Expected number of dimensions to be between {dim_range[0]} and {dim_range[1]}, got {t.dim()}" + + if data_range[0] < data_range[1]: + assert data_range[0] <= t.min(), f"Expected values to be greater or equal to {data_range[0]}, got {t.min()}" + assert t.max() <= data_range[1], f"Expected values to be lower or equal to {data_range[1]}, got {t.max()}" + + +def gaussian_filter(kernel_size: int, sigma: float) -> torch.Tensor: + r"""Returns 2D Gaussian kernel N(0,`sigma`^2) + Args: + size: Size of the kernel + sigma: Std of the distribution + Returns: + gaussian_kernel: Tensor with shape (1, kernel_size, kernel_size) + """ + coords = torch.arange(kernel_size, dtype=torch.float32) + coords -= (kernel_size - 1) / 2.0 + + g = coords**2 + g = (-(g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma**2)).exp() + + g /= g.sum() + return g.unsqueeze(0) + + +def ssim( + x: torch.Tensor, + y: torch.Tensor, + kernel_size: int = 11, + kernel_sigma: float = 1.5, + data_range: Union[int, float] = 1.0, + reduction: str = "mean", + full: bool = False, + downsample: bool = True, + k1: float = 0.01, + k2: float = 0.03, +) -> List[torch.Tensor]: + r"""Interface of Structural Similarity (SSIM) index. + Inputs supposed to be in range ``[0, data_range]``. + To match performance with skimage and tensorflow set ``'downsample' = True``. + + Args: + x: An input tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`. + y: A target tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`. + kernel_size: The side-length of the sliding window used in comparison. Must be an odd value. + kernel_sigma: Sigma of normal distribution. + data_range: Maximum value range of images (usually 1.0 or 255). + reduction: Specifies the reduction type: + ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'`` + full: Return cs map or not. + downsample: Perform average pool before SSIM computation. Default: True + k1: Algorithm parameter, K1 (small constant). + k2: Algorithm parameter, K2 (small constant). + Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. + + Returns: + Value of Structural Similarity (SSIM) index. In case of 5D input tensors, complex value is returned + as a tensor of size 2. + + References: + Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). + Image quality assessment: From error visibility to structural similarity. + IEEE Transactions on Image Processing, 13, 600-612. 
+ https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf, + DOI: `10.1109/TIP.2003.819861` + """ + assert kernel_size % 2 == 1, f"Kernel size must be odd, got [{kernel_size}]" + _validate_input([x, y], dim_range=(4, 5), data_range=(0, data_range)) + + x = x / float(data_range) + y = y / float(data_range) + + # Averagepool image if the size is large enough + f = max(1, round(min(x.size()[-2:]) / 256)) + if (f > 1) and downsample: + x = F.avg_pool2d(x, kernel_size=f) + y = F.avg_pool2d(y, kernel_size=f) + + kernel = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(y) + _compute_ssim_per_channel = _ssim_per_channel_complex if x.dim() == 5 else _ssim_per_channel + ssim_map, cs_map = _compute_ssim_per_channel(x=x, y=y, kernel=kernel, k1=k1, k2=k2) + ssim_val = ssim_map.mean(1) + cs = cs_map.mean(1) + + ssim_val = _reduce(ssim_val, reduction) + cs = _reduce(cs, reduction) + + if full: + return [ssim_val, cs] + + return ssim_val + + +class SSIMLoss(_Loss): + r"""Creates a criterion that measures the structural similarity index error between + each element in the input :math:`x` and target :math:`y`. + + To match performance with skimage and tensorflow set ``'downsample' = True``. + + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + SSIM = \{ssim_1,\dots,ssim_{N \times C}\}\\ + ssim_{l}(x, y) = \frac{(2 \mu_x \mu_y + c_1) (2 \sigma_{xy} + c_2)} + {(\mu_x^2 +\mu_y^2 + c_1)(\sigma_x^2 +\sigma_y^2 + c_2)}, + + where :math:`N` is the batch size, `C` is the channel size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then: + + .. math:: + SSIMLoss(x, y) = + \begin{cases} + \operatorname{mean}(1 - SSIM), & \text{if reduction} = \text{'mean';}\\ + \operatorname{sum}(1 - SSIM), & \text{if reduction} = \text{'sum'.} + \end{cases} + + :math:`x` and :math:`y` are tensors of arbitrary shapes with a total + of :math:`n` elements each. + + The sum operation still operates over all the elements, and divides by :math:`n`. + The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``. + In case of 5D input tensors, complex value is returned as a tensor of size 2. + + Args: + kernel_size: By default, the mean and covariance of a pixel is obtained + by convolution with given filter_size. + kernel_sigma: Standard deviation for Gaussian kernel. + k1: Coefficient related to c1 in the above equation. + k2: Coefficient related to c2 in the above equation. + downsample: Perform average pool before SSIM computation. Default: True + reduction: Specifies the reduction type: + ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'`` + data_range: Maximum value range of images (usually 1.0 or 255). + + Examples: + >>> loss = SSIMLoss() + >>> x = torch.rand(3, 3, 256, 256, requires_grad=True) + >>> y = torch.rand(3, 3, 256, 256) + >>> output = loss(x, y) + >>> output.backward() + + References: + Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). + Image quality assessment: From error visibility to structural similarity. + IEEE Transactions on Image Processing, 13, 600-612. 
+ https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf, + DOI:`10.1109/TIP.2003.819861` + """ + __constants__ = ["kernel_size", "k1", "k2", "sigma", "kernel", "reduction"] + + def __init__( + self, + kernel_size: int = 11, + kernel_sigma: float = 1.5, + k1: float = 0.01, + k2: float = 0.03, + downsample: bool = True, + reduction: str = "mean", + data_range: Union[int, float] = 1.0, + ) -> None: + super().__init__() + + # Generic loss parameters. + self.reduction = reduction + + # Loss-specific parameters. + self.kernel_size = kernel_size + + # This check might look redundant because kernel size is checked within the ssim function anyway. + # However, this check allows to fail fast when the loss is being initialised and training has not been started. + assert kernel_size % 2 == 1, f"Kernel size must be odd, got [{kernel_size}]" + self.kernel_sigma = kernel_sigma + self.k1 = k1 + self.k2 = k2 + self.downsample = downsample + self.data_range = data_range + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + r"""Computation of Structural Similarity (SSIM) index as a loss function. + + Args: + x: An input tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`. + y: A target tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`. + + Returns: + Value of SSIM loss to be minimized, i.e ``1 - ssim`` in [0, 1] range. In case of 5D input tensors, + complex value is returned as a tensor of size 2. + """ + + score = ssim( + x=x, + y=y, + kernel_size=self.kernel_size, + kernel_sigma=self.kernel_sigma, + downsample=self.downsample, + data_range=self.data_range, + reduction=self.reduction, + full=False, + k1=self.k1, + k2=self.k2, + ) + return torch.ones_like(score) - score + + +def _ssim_per_channel( + x: torch.Tensor, + y: torch.Tensor, + kernel: torch.Tensor, + k1: float = 0.01, + k2: float = 0.03, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + r"""Calculate Structural Similarity (SSIM) index for X and Y per channel. + + Args: + x: An input tensor. Shape :math:`(N, C, H, W)`. + y: A target tensor. Shape :math:`(N, C, H, W)`. + kernel: 2D Gaussian kernel. + k1: Algorithm parameter, K1 (small constant, see [1]). + k2: Algorithm parameter, K2 (small constant, see [1]). + Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. + + Returns: + Full Value of Structural Similarity (SSIM) index. + """ + if x.size(-1) < kernel.size(-1) or x.size(-2) < kernel.size(-2): + raise ValueError( + f"Kernel size can't be greater than actual input size. Input size: {x.size()}. " + f"Kernel size: {kernel.size()}" + ) + + c1 = k1**2 + c2 = k2**2 + n_channels = x.size(1) + mu_x = F.conv2d(x, weight=kernel, stride=1, padding=0, groups=n_channels) + mu_y = F.conv2d(y, weight=kernel, stride=1, padding=0, groups=n_channels) + + mu_xx = mu_x**2 + mu_yy = mu_y**2 + mu_xy = mu_x * mu_y + + sigma_xx = F.conv2d(x**2, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_xx + sigma_yy = F.conv2d(y**2, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_yy + sigma_xy = F.conv2d(x * y, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_xy + + # Contrast sensitivity (CS) with alpha = beta = gamma = 1. 
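+ # With the usual simplification c3 = c2 / 2, the contrast and structure terms of SSIM collapse into this single CS term, which is why only c1 and c2 appear below.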
+ cs = (2.0 * sigma_xy + c2) / (sigma_xx + sigma_yy + c2) + + # Structural similarity (SSIM) + ss = (2.0 * mu_xy + c1) / (mu_xx + mu_yy + c1) * cs + + ssim_val = ss.mean(dim=(-1, -2)) + cs = cs.mean(dim=(-1, -2)) + return ssim_val, cs + + +def _ssim_per_channel_complex( + x: torch.Tensor, + y: torch.Tensor, + kernel: torch.Tensor, + k1: float = 0.01, + k2: float = 0.03, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + r"""Calculate Structural Similarity (SSIM) index for Complex X and Y per channel. + + Args: + x: An input tensor. Shape :math:`(N, C, H, W, 2)`. + y: A target tensor. Shape :math:`(N, C, H, W, 2)`. + kernel: 2-D gauss kernel. + k1: Algorithm parameter, K1 (small constant, see [1]). + k2: Algorithm parameter, K2 (small constant, see [1]). + Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. + + Returns: + Full Value of Complex Structural Similarity (SSIM) index. + """ + n_channels = x.size(1) + if x.size(-2) < kernel.size(-1) or x.size(-3) < kernel.size(-2): + raise ValueError( + f"Kernel size can't be greater than actual input size. Input size: {x.size()}. " + f"Kernel size: {kernel.size()}" + ) + + c1 = k1**2 + c2 = k2**2 + + x_real = x[..., 0] + x_imag = x[..., 1] + y_real = y[..., 0] + y_imag = y[..., 1] + + mu1_real = F.conv2d(x_real, weight=kernel, stride=1, padding=0, groups=n_channels) + mu1_imag = F.conv2d(x_imag, weight=kernel, stride=1, padding=0, groups=n_channels) + mu2_real = F.conv2d(y_real, weight=kernel, stride=1, padding=0, groups=n_channels) + mu2_imag = F.conv2d(y_imag, weight=kernel, stride=1, padding=0, groups=n_channels) + + mu1_sq = mu1_real.pow(2) + mu1_imag.pow(2) + mu2_sq = mu2_real.pow(2) + mu2_imag.pow(2) + mu1_mu2_real = mu1_real * mu2_real - mu1_imag * mu2_imag + mu1_mu2_imag = mu1_real * mu2_imag + mu1_imag * mu2_real + + compensation = 1.0 + + x_sq = x_real.pow(2) + x_imag.pow(2) + y_sq = y_real.pow(2) + y_imag.pow(2) + x_y_real = x_real * y_real - x_imag * y_imag + x_y_imag = x_real * y_imag + x_imag * y_real + + sigma1_sq = F.conv2d(x_sq, weight=kernel, stride=1, padding=0, groups=n_channels) - mu1_sq + sigma2_sq = F.conv2d(y_sq, weight=kernel, stride=1, padding=0, groups=n_channels) - mu2_sq + sigma12_real = F.conv2d(x_y_real, weight=kernel, stride=1, padding=0, groups=n_channels) - mu1_mu2_real + sigma12_imag = F.conv2d(x_y_imag, weight=kernel, stride=1, padding=0, groups=n_channels) - mu1_mu2_imag + sigma12 = torch.stack((sigma12_imag, sigma12_real), dim=-1) + mu1_mu2 = torch.stack((mu1_mu2_real, mu1_mu2_imag), dim=-1) + # Set alpha = beta = gamma = 1. 
+ cs_map = (sigma12 * 2 + c2 * compensation) / (sigma1_sq.unsqueeze(-1) + sigma2_sq.unsqueeze(-1) + c2 * compensation) + ssim_map = (mu1_mu2 * 2 + c1 * compensation) / (mu1_sq.unsqueeze(-1) + mu2_sq.unsqueeze(-1) + c1 * compensation) + ssim_map = ssim_map * cs_map + + ssim_val = ssim_map.mean(dim=(-2, -3)) + cs = cs_map.mean(dim=(-2, -3)) + + return ssim_val, cs diff --git a/viXTTS/TTS/tts/utils/synthesis.py b/viXTTS/TTS/tts/utils/synthesis.py new file mode 100644 index 0000000000000000000000000000000000000000..797151c2540fe86b26df2b52bf734f0a93389f72 --- /dev/null +++ b/viXTTS/TTS/tts/utils/synthesis.py @@ -0,0 +1,343 @@ +from typing import Dict + +import numpy as np +import torch +from torch import nn + + +def numpy_to_torch(np_array, dtype, cuda=False, device="cpu"): + if cuda: + device = "cuda" + if np_array is None: + return None + tensor = torch.as_tensor(np_array, dtype=dtype, device=device) + return tensor + + +def compute_style_mel(style_wav, ap, cuda=False, device="cpu"): + if cuda: + device = "cuda" + style_mel = torch.FloatTensor( + ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate)), + device=device, + ).unsqueeze(0) + return style_mel + + +def run_model_torch( + model: nn.Module, + inputs: torch.Tensor, + speaker_id: int = None, + style_mel: torch.Tensor = None, + style_text: str = None, + d_vector: torch.Tensor = None, + language_id: torch.Tensor = None, +) -> Dict: + """Run a torch model for inference. It does not support batch inference. + + Args: + model (nn.Module): The model to run inference. + inputs (torch.Tensor): Input tensor with character ids. + speaker_id (int, optional): Input speaker ids for multi-speaker models. Defaults to None. + style_mel (torch.Tensor, optional): Spectrograms used for voice styling . Defaults to None. + d_vector (torch.Tensor, optional): d-vector for multi-speaker models . Defaults to None. + + Returns: + Dict: model outputs. + """ + input_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) + if hasattr(model, "module"): + _func = model.module.inference + else: + _func = model.inference + outputs = _func( + inputs, + aux_input={ + "x_lengths": input_lengths, + "speaker_ids": speaker_id, + "d_vectors": d_vector, + "style_mel": style_mel, + "style_text": style_text, + "language_ids": language_id, + }, + ) + return outputs + + +def trim_silence(wav, ap): + return wav[: ap.find_endpoint(wav)] + + +def inv_spectrogram(postnet_output, ap, CONFIG): + if CONFIG.model.lower() in ["tacotron"]: + wav = ap.inv_spectrogram(postnet_output.T) + else: + wav = ap.inv_melspectrogram(postnet_output.T) + return wav + + +def id_to_torch(aux_id, cuda=False, device="cpu"): + if cuda: + device = "cuda" + if aux_id is not None: + aux_id = np.asarray(aux_id) + aux_id = torch.from_numpy(aux_id).to(device) + return aux_id + + +def embedding_to_torch(d_vector, cuda=False, device="cpu"): + if cuda: + device = "cuda" + if d_vector is not None: + d_vector = np.asarray(d_vector) + d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor) + d_vector = d_vector.squeeze().unsqueeze(0).to(device) + return d_vector + + +# TODO: perform GL with pytorch for batching +def apply_griffin_lim(inputs, input_lens, CONFIG, ap): + """Apply griffin-lim to each sample iterating throught the first dimension. + Args: + inputs (Tensor or np.Array): Features to be converted by GL. First dimension is the batch size. + input_lens (Tensor or np.Array): 1D array of sample lengths. + CONFIG (Dict): TTS config. + ap (AudioProcessor): TTS audio processor. 
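+
+ Returns:
+     List: Waveforms (np.ndarray), one per input sample, each trimmed to ``(input_lens[idx] - 1) * ap.hop_length`` samples.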
+ """ + wavs = [] + for idx, spec in enumerate(inputs): + wav_len = (input_lens[idx] * ap.hop_length) - ap.hop_length # inverse librosa padding + wav = inv_spectrogram(spec, ap, CONFIG) + # assert len(wav) == wav_len, f" [!] wav lenght: {len(wav)} vs expected: {wav_len}" + wavs.append(wav[:wav_len]) + return wavs + + +def synthesis( + model, + text, + CONFIG, + use_cuda, + speaker_id=None, + style_wav=None, + style_text=None, + use_griffin_lim=False, + do_trim_silence=False, + d_vector=None, + language_id=None, +): + """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to + the vocoder model. + + Args: + model (TTS.tts.models): + The TTS model to synthesize audio with. + + text (str): + The input text to convert to speech. + + CONFIG (Coqpit): + Model configuration. + + use_cuda (bool): + Enable/disable CUDA. + + speaker_id (int): + Speaker ID passed to the speaker embedding layer in multi-speaker model. Defaults to None. + + style_wav (str | Dict[str, float]): + Path or tensor to/of a waveform used for computing the style embedding based on GST or Capacitron. + Defaults to None, meaning that Capacitron models will sample from the prior distribution to + generate random but realistic prosody. + + style_text (str): + Transcription of style_wav for Capacitron models. Defaults to None. + + enable_eos_bos_chars (bool): + enable special chars for end of sentence and start of sentence. Defaults to False. + + do_trim_silence (bool): + trim silence after synthesis. Defaults to False. + + d_vector (torch.Tensor): + d-vector for multi-speaker models in share :math:`[1, D]`. Defaults to None. + + language_id (int): + Language ID passed to the language embedding layer in multi-langual model. Defaults to None. 
+ """ + # device + device = next(model.parameters()).device + if use_cuda: + device = "cuda" + + # GST or Capacitron processing + # TODO: need to handle the case of setting both gst and capacitron to true somewhere + style_mel = None + if CONFIG.has("gst") and CONFIG.gst and style_wav is not None: + if isinstance(style_wav, dict): + style_mel = style_wav + else: + style_mel = compute_style_mel(style_wav, model.ap, device=device) + + if CONFIG.has("capacitron_vae") and CONFIG.use_capacitron_vae and style_wav is not None: + style_mel = compute_style_mel(style_wav, model.ap, device=device) + style_mel = style_mel.transpose(1, 2) # [1, time, depth] + + language_name = None + if language_id is not None: + language = [k for k, v in model.language_manager.name_to_id.items() if v == language_id] + assert len(language) == 1, "language_id must be a valid language" + language_name = language[0] + + # convert text to sequence of token IDs + text_inputs = np.asarray( + model.tokenizer.text_to_ids(text, language=language_name), + dtype=np.int32, + ) + # pass tensors to backend + if speaker_id is not None: + speaker_id = id_to_torch(speaker_id, device=device) + + if d_vector is not None: + d_vector = embedding_to_torch(d_vector, device=device) + + if language_id is not None: + language_id = id_to_torch(language_id, device=device) + + if not isinstance(style_mel, dict): + # GST or Capacitron style mel + style_mel = numpy_to_torch(style_mel, torch.float, device=device) + if style_text is not None: + style_text = np.asarray( + model.tokenizer.text_to_ids(style_text, language=language_id), + dtype=np.int32, + ) + style_text = numpy_to_torch(style_text, torch.long, device=device) + style_text = style_text.unsqueeze(0) + + text_inputs = numpy_to_torch(text_inputs, torch.long, device=device) + text_inputs = text_inputs.unsqueeze(0) + # synthesize voice + outputs = run_model_torch( + model, + text_inputs, + speaker_id, + style_mel, + style_text, + d_vector=d_vector, + language_id=language_id, + ) + model_outputs = outputs["model_outputs"] + model_outputs = model_outputs[0].data.cpu().numpy() + alignments = outputs["alignments"] + + # convert outputs to numpy + # plot results + wav = None + model_outputs = model_outputs.squeeze() + if model_outputs.ndim == 2: # [T, C_spec] + if use_griffin_lim: + wav = inv_spectrogram(model_outputs, model.ap, CONFIG) + # trim silence + if do_trim_silence: + wav = trim_silence(wav, model.ap) + else: # [T,] + wav = model_outputs + return_dict = { + "wav": wav, + "alignments": alignments, + "text_inputs": text_inputs, + "outputs": outputs, + } + return return_dict + + +def transfer_voice( + model, + CONFIG, + use_cuda, + reference_wav, + speaker_id=None, + d_vector=None, + reference_speaker_id=None, + reference_d_vector=None, + do_trim_silence=False, + use_griffin_lim=False, +): + """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to + the vocoder model. + + Args: + model (TTS.tts.models): + The TTS model to synthesize audio with. + + CONFIG (Coqpit): + Model configuration. + + use_cuda (bool): + Enable/disable CUDA. + + reference_wav (str): + Path of reference_wav to be used to voice conversion. + + speaker_id (int): + Speaker ID passed to the speaker embedding layer in multi-speaker model. Defaults to None. + + d_vector (torch.Tensor): + d-vector for multi-speaker models in share :math:`[1, D]`. Defaults to None. 
+ + reference_speaker_id (int): + Reference Speaker ID passed to the speaker embedding layer in multi-speaker model. Defaults to None. + + reference_d_vector (torch.Tensor): + Reference d-vector for multi-speaker models in share :math:`[1, D]`. Defaults to None. + + enable_eos_bos_chars (bool): + enable special chars for end of sentence and start of sentence. Defaults to False. + + do_trim_silence (bool): + trim silence after synthesis. Defaults to False. + """ + # device + device = next(model.parameters()).device + if use_cuda: + device = "cuda" + + # pass tensors to backend + if speaker_id is not None: + speaker_id = id_to_torch(speaker_id, device=device) + + if d_vector is not None: + d_vector = embedding_to_torch(d_vector, device=device) + + if reference_d_vector is not None: + reference_d_vector = embedding_to_torch(reference_d_vector, device=device) + + # load reference_wav audio + reference_wav = embedding_to_torch( + model.ap.load_wav( + reference_wav, sr=model.args.encoder_sample_rate if model.args.encoder_sample_rate else model.ap.sample_rate + ), + device=device, + ) + + if hasattr(model, "module"): + _func = model.module.inference_voice_conversion + else: + _func = model.inference_voice_conversion + model_outputs = _func(reference_wav, speaker_id, d_vector, reference_speaker_id, reference_d_vector) + + # convert outputs to numpy + # plot results + wav = None + model_outputs = model_outputs.squeeze() + if model_outputs.ndim == 2: # [T, C_spec] + if use_griffin_lim: + wav = inv_spectrogram(model_outputs, model.ap, CONFIG) + # trim silence + if do_trim_silence: + wav = trim_silence(wav, model.ap) + else: # [T,] + wav = model_outputs + + return wav diff --git a/viXTTS/TTS/tts/utils/text/__init__.py b/viXTTS/TTS/tts/utils/text/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..593372dc7cb2fba240eb5f08e8e2cfae5a4b4e45 --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/__init__.py @@ -0,0 +1 @@ +from TTS.tts.utils.text.tokenizer import TTSTokenizer diff --git a/viXTTS/TTS/tts/utils/text/bangla/__init__.py b/viXTTS/TTS/tts/utils/text/bangla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/tts/utils/text/bangla/phonemizer.py b/viXTTS/TTS/tts/utils/text/bangla/phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..e15830fe8af6cd08ea03fb4f8e40a9ddbe70f1f6 --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/bangla/phonemizer.py @@ -0,0 +1,121 @@ +import re + +import bangla +from bnnumerizer import numerize +from bnunicodenormalizer import Normalizer + +# initialize +bnorm = Normalizer() + + +attribution_dict = { + "সাঃ": "সাল্লাল্লাহু আলাইহি ওয়া সাল্লাম", + "আঃ": "আলাইহিস সালাম", + "রাঃ": "রাদিআল্লাহু আনহু", + "রহঃ": "রহমাতুল্লাহি আলাইহি", + "রহিঃ": "রহিমাহুল্লাহ", + "হাফিঃ": "হাফিযাহুল্লাহ", + "বায়ান": "বাইআন", + "দাঃবাঃ": "দামাত বারাকাতুহুম,দামাত বারাকাতুল্লাহ", + # "আয়াত" : "আইআত",#আইআত + # "ওয়া" : "ওআ", + # "ওয়াসাল্লাম" : "ওআসাল্লাম", + # "কেন" : "কেনো", + # "কোন" : "কোনো", + # "বল" : "বলো", + # "চল" : "চলো", + # "কর" : "করো", + # "রাখ" : "রাখো", + "’": "", + "‘": "", + # "য়" : "অ", + # "সম্প্রদায়" : "সম্প্রদাই", + # "রয়েছে" : "রইছে", + # "রয়েছ" : "রইছ", + "/": " বাই ", +} + + +def tag_text(text: str): + # remove multiple spaces + text = re.sub(" +", " ", text) + # create start and end + text = "start" + text + "end" + # tag text + parts = re.split("[\u0600-\u06FF]+", text) + # remove non chars + parts = [p for p in parts if p.strip()] 
+ # unique parts + parts = set(parts) + # tag the text + for m in parts: + if len(m.strip()) > 1: + text = text.replace(m, f"{m}") + # clean-tags + text = text.replace("start", "") + text = text.replace("end", "") + return text + + +def normalize(sen): + global bnorm # pylint: disable=global-statement + _words = [bnorm(word)["normalized"] for word in sen.split()] + return " ".join([word for word in _words if word is not None]) + + +def expand_full_attribution(text): + for word, attr in attribution_dict.items(): + if word in text: + text = text.replace(word, normalize(attr)) + return text + + +def collapse_whitespace(text): + # Regular expression matching whitespace: + _whitespace_re = re.compile(r"\s+") + return re.sub(_whitespace_re, " ", text) + + +def bangla_text_to_phonemes(text: str) -> str: + # english numbers to bangla conversion + res = re.search("[0-9]", text) + if res is not None: + text = bangla.convert_english_digit_to_bangla_digit(text) + + # replace ':' in between two bangla numbers with ' এর ' + pattern = r"[০, ১, ২, ৩, ৪, ৫, ৬, ৭, ৮, ৯]:[০, ১, ২, ৩, ৪, ৫, ৬, ৭, ৮, ৯]" + matches = re.findall(pattern, text) + for m in matches: + r = m.replace(":", " এর ") + text = text.replace(m, r) + + # numerize text + text = numerize(text) + + # tag sections + text = tag_text(text) + + # text blocks + # blocks = text.split("") + # blocks = [b for b in blocks if b.strip()] + + # create tuple of (lang,text) + if "" in text: + text = text.replace("", "").replace("", "") + # Split based on sentence ending Characters + bn_text = text.strip() + + sentenceEnders = re.compile("[।!?]") + sentences = sentenceEnders.split(str(bn_text)) + + data = "" + for sent in sentences: + res = re.sub("\n", "", sent) + res = normalize(res) + # expand attributes + res = expand_full_attribution(res) + + res = collapse_whitespace(res) + res += "।" + data += res + return data diff --git a/viXTTS/TTS/tts/utils/text/belarusian/__init__.py b/viXTTS/TTS/tts/utils/text/belarusian/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/tts/utils/text/belarusian/phonemizer.py b/viXTTS/TTS/tts/utils/text/belarusian/phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..1922577e5b479980a8e11ac3ae15549cfeb178db --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/belarusian/phonemizer.py @@ -0,0 +1,37 @@ +import os + +finder = None + + +def init(): + try: + import jpype + import jpype.imports + except ModuleNotFoundError: + raise ModuleNotFoundError( + "Belarusian phonemizer requires to install module 'jpype1' manually. Try `pip install jpype1`." 
+ ) + + try: + jar_path = os.environ["BEL_FANETYKA_JAR"] + except KeyError: + raise KeyError("You need to define 'BEL_FANETYKA_JAR' environment variable as path to the fanetyka.jar file") + + jpype.startJVM(classpath=[jar_path]) + + # import the Java modules + from org.alex73.korpus.base import GrammarDB2, GrammarFinder + + grammar_db = GrammarDB2.initializeFromJar() + global finder + finder = GrammarFinder(grammar_db) + + +def belarusian_text_to_phonemes(text: str) -> str: + # Initialize only on first run + if finder is None: + init() + + from org.alex73.fanetyka.impl import FanetykaText + + return str(FanetykaText(finder, text).ipa) diff --git a/viXTTS/TTS/tts/utils/text/characters.py b/viXTTS/TTS/tts/utils/text/characters.py new file mode 100644 index 0000000000000000000000000000000000000000..8fa45ed84bef4aa7953bd365b025d38f82b717d2 --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/characters.py @@ -0,0 +1,501 @@ +from dataclasses import replace +from typing import Dict + +from TTS.tts.configs.shared_configs import CharactersConfig + + +def parse_symbols(): + return { + "pad": _pad, + "eos": _eos, + "bos": _bos, + "characters": _characters, + "punctuations": _punctuations, + "phonemes": _phonemes, + } + + +# DEFAULT SET OF GRAPHEMES +_pad = "" +_eos = "" +_bos = "" +_blank = "" # TODO: check if we need this alongside with PAD +_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +_punctuations = "!'(),-.:;? " + + +# DEFAULT SET OF IPA PHONEMES +# Phonemes definition (All IPA characters) +_vowels = "iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ" +_non_pulmonic_consonants = "ʘɓǀɗǃʄǂɠǁʛ" +_pulmonic_consonants = "pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ" +_suprasegmentals = "ˈˌːˑ" +_other_symbols = "ʍwɥʜʢʡɕʑɺɧʲ" +_diacrilics = "ɚ˞ɫ" +_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics + + +class BaseVocabulary: + """Base Vocabulary class. + + This class only needs a vocabulary dictionary without specifying the characters. + + Args: + vocab (Dict): A dictionary of characters and their corresponding indices. + """ + + def __init__(self, vocab: Dict, pad: str = None, blank: str = None, bos: str = None, eos: str = None): + self.vocab = vocab + self.pad = pad + self.blank = blank + self.bos = bos + self.eos = eos + + @property + def pad_id(self) -> int: + """Return the index of the padding character. If the padding character is not specified, return the length + of the vocabulary.""" + return self.char_to_id(self.pad) if self.pad else len(self.vocab) + + @property + def blank_id(self) -> int: + """Return the index of the blank character. If the blank character is not specified, return the length of + the vocabulary.""" + return self.char_to_id(self.blank) if self.blank else len(self.vocab) + + @property + def bos_id(self) -> int: + """Return the index of the bos character. If the bos character is not specified, return the length of the + vocabulary.""" + return self.char_to_id(self.bos) if self.bos else len(self.vocab) + + @property + def eos_id(self) -> int: + """Return the index of the eos character. 
If the eos character is not specified, return the length of the
+        vocabulary."""
+        return self.char_to_id(self.eos) if self.eos else len(self.vocab)
+
+    @property
+    def vocab(self):
+        """Return the vocabulary dictionary."""
+        return self._vocab
+
+    @vocab.setter
+    def vocab(self, vocab):
+        """Set the vocabulary dictionary and character mapping dictionaries."""
+        self._vocab, self._char_to_id, self._id_to_char = None, None, None
+        if vocab is not None:
+            self._vocab = vocab
+            self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
+            self._id_to_char = {
+                idx: char for idx, char in enumerate(self._vocab)  # pylint: disable=unnecessary-comprehension
+            }
+
+    @staticmethod
+    def init_from_config(config, **kwargs):
+        """Initialize from the given config."""
+        if config.characters is not None and "vocab_dict" in config.characters and config.characters.vocab_dict:
+            return (
+                BaseVocabulary(
+                    config.characters.vocab_dict,
+                    config.characters.pad,
+                    config.characters.blank,
+                    config.characters.bos,
+                    config.characters.eos,
+                ),
+                config,
+            )
+        return BaseVocabulary(**kwargs), config
+
+    def to_config(self) -> "CharactersConfig":
+        return CharactersConfig(
+            vocab_dict=self._vocab,
+            pad=self.pad,
+            eos=self.eos,
+            bos=self.bos,
+            blank=self.blank,
+            is_unique=False,
+            is_sorted=False,
+        )
+
+    @property
+    def num_chars(self):
+        """Return number of tokens in the vocabulary."""
+        return len(self._vocab)
+
+    def char_to_id(self, char: str) -> int:
+        """Map a character to a token ID."""
+        try:
+            return self._char_to_id[char]
+        except KeyError as e:
+            raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e
+
+    def id_to_char(self, idx: int) -> str:
+        """Map a token ID to a character."""
+        return self._id_to_char[idx]
+
+
+class BaseCharacters:
+    """🐸BaseCharacters class
+
+    Every new character class should inherit from this.
+
+    Characters are ordered as follows ```[PAD, EOS, BOS, BLANK, CHARACTERS, PUNCTUATIONS]```.
+
+    If you need a custom order, you need to inherit from this class and override the ```_create_vocab``` method.
+
+    Args:
+        characters (str):
+            Main set of characters to be used in the vocabulary.
+
+        punctuations (str):
+            Characters to be treated as punctuation.
+
+        pad (str):
+            Special padding character that would be ignored by the model.
+
+        eos (str):
+            End of the sentence character.
+
+        bos (str):
+            Beginning of the sentence character.
+
+        blank (str):
+            Optional character used between characters by some models for better prosody.
+
+        is_unique (bool):
+            Remove duplicates from the provided characters. Defaults to True.
+
+        is_sorted (bool):
+            Sort the characters in alphabetical order. Only applies to `self.characters`. Defaults to True.
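+
+    Example:
+        A small illustrative sketch of how the vocabulary is assembled; the character
+        values below are arbitrary, not project defaults::
+
+            chars = BaseCharacters(
+                characters="abc", punctuations="!? ", pad="<PAD>", eos="<EOS>", bos="<BOS>", blank="<BLNK>"
+            )
+            # ordering follows [PAD, EOS, BOS, BLANK, CHARACTERS, PUNCTUATIONS]
+            print(chars.vocab)  # ['<PAD>', '<EOS>', '<BOS>', '<BLNK>', 'a', 'b', 'c', '!', '?', ' ']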
+ """ + + def __init__( + self, + characters: str = None, + punctuations: str = None, + pad: str = None, + eos: str = None, + bos: str = None, + blank: str = None, + is_unique: bool = False, + is_sorted: bool = True, + ) -> None: + self._characters = characters + self._punctuations = punctuations + self._pad = pad + self._eos = eos + self._bos = bos + self._blank = blank + self.is_unique = is_unique + self.is_sorted = is_sorted + self._create_vocab() + + @property + def pad_id(self) -> int: + return self.char_to_id(self.pad) if self.pad else len(self.vocab) + + @property + def blank_id(self) -> int: + return self.char_to_id(self.blank) if self.blank else len(self.vocab) + + @property + def eos_id(self) -> int: + return self.char_to_id(self.eos) if self.eos else len(self.vocab) + + @property + def bos_id(self) -> int: + return self.char_to_id(self.bos) if self.bos else len(self.vocab) + + @property + def characters(self): + return self._characters + + @characters.setter + def characters(self, characters): + self._characters = characters + self._create_vocab() + + @property + def punctuations(self): + return self._punctuations + + @punctuations.setter + def punctuations(self, punctuations): + self._punctuations = punctuations + self._create_vocab() + + @property + def pad(self): + return self._pad + + @pad.setter + def pad(self, pad): + self._pad = pad + self._create_vocab() + + @property + def eos(self): + return self._eos + + @eos.setter + def eos(self, eos): + self._eos = eos + self._create_vocab() + + @property + def bos(self): + return self._bos + + @bos.setter + def bos(self, bos): + self._bos = bos + self._create_vocab() + + @property + def blank(self): + return self._blank + + @blank.setter + def blank(self, blank): + self._blank = blank + self._create_vocab() + + @property + def vocab(self): + return self._vocab + + @vocab.setter + def vocab(self, vocab): + self._vocab = vocab + self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)} + self._id_to_char = { + idx: char for idx, char in enumerate(self.vocab) # pylint: disable=unnecessary-comprehension + } + + @property + def num_chars(self): + return len(self._vocab) + + def _create_vocab(self): + _vocab = self._characters + if self.is_unique: + _vocab = list(set(_vocab)) + if self.is_sorted: + _vocab = sorted(_vocab) + _vocab = list(_vocab) + _vocab = [self._blank] + _vocab if self._blank is not None and len(self._blank) > 0 else _vocab + _vocab = [self._bos] + _vocab if self._bos is not None and len(self._bos) > 0 else _vocab + _vocab = [self._eos] + _vocab if self._eos is not None and len(self._eos) > 0 else _vocab + _vocab = [self._pad] + _vocab if self._pad is not None and len(self._pad) > 0 else _vocab + self.vocab = _vocab + list(self._punctuations) + if self.is_unique: + duplicates = {x for x in self.vocab if self.vocab.count(x) > 1} + assert ( + len(self.vocab) == len(self._char_to_id) == len(self._id_to_char) + ), f" [!] There are duplicate characters in the character set. {duplicates}" + + def char_to_id(self, char: str) -> int: + try: + return self._char_to_id[char] + except KeyError as e: + raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e + + def id_to_char(self, idx: int) -> str: + return self._id_to_char[idx] + + def print_log(self, level: int = 0): + """ + Prints the vocabulary in a nice format. 
+ """ + indent = "\t" * level + print(f"{indent}| > Characters: {self._characters}") + print(f"{indent}| > Punctuations: {self._punctuations}") + print(f"{indent}| > Pad: {self._pad}") + print(f"{indent}| > EOS: {self._eos}") + print(f"{indent}| > BOS: {self._bos}") + print(f"{indent}| > Blank: {self._blank}") + print(f"{indent}| > Vocab: {self.vocab}") + print(f"{indent}| > Num chars: {self.num_chars}") + + @staticmethod + def init_from_config(config: "Coqpit"): # pylint: disable=unused-argument + """Init your character class from a config. + + Implement this method for your subclass. + """ + # use character set from config + if config.characters is not None: + return BaseCharacters(**config.characters), config + # return default character set + characters = BaseCharacters() + new_config = replace(config, characters=characters.to_config()) + return characters, new_config + + def to_config(self) -> "CharactersConfig": + return CharactersConfig( + characters=self._characters, + punctuations=self._punctuations, + pad=self._pad, + eos=self._eos, + bos=self._bos, + blank=self._blank, + is_unique=self.is_unique, + is_sorted=self.is_sorted, + ) + + +class IPAPhonemes(BaseCharacters): + """🐸IPAPhonemes class to manage `TTS.tts` model vocabulary + + Intended to be used with models using IPAPhonemes as input. + It uses system defaults for the undefined class arguments. + + Args: + characters (str): + Main set of case-sensitive characters to be used in the vocabulary. Defaults to `_phonemes`. + + punctuations (str): + Characters to be treated as punctuation. Defaults to `_punctuations`. + + pad (str): + Special padding character that would be ignored by the model. Defaults to `_pad`. + + eos (str): + End of the sentence character. Defaults to `_eos`. + + bos (str): + Beginning of the sentence character. Defaults to `_bos`. + + blank (str): + Optional character used between characters by some models for better prosody. Defaults to `_blank`. + + is_unique (bool): + Remove duplicates from the provided characters. Defaults to True. + + is_sorted (bool): + Sort the characters in alphabetical order. Defaults to True. + """ + + def __init__( + self, + characters: str = _phonemes, + punctuations: str = _punctuations, + pad: str = _pad, + eos: str = _eos, + bos: str = _bos, + blank: str = _blank, + is_unique: bool = False, + is_sorted: bool = True, + ) -> None: + super().__init__(characters, punctuations, pad, eos, bos, blank, is_unique, is_sorted) + + @staticmethod + def init_from_config(config: "Coqpit"): + """Init a IPAPhonemes object from a model config + + If characters are not defined in the config, it will be set to the default characters and the config + will be updated. 
+ """ + # band-aid for compatibility with old models + if "characters" in config and config.characters is not None: + if "phonemes" in config.characters and config.characters.phonemes is not None: + config.characters["characters"] = config.characters["phonemes"] + return ( + IPAPhonemes( + characters=config.characters["characters"], + punctuations=config.characters["punctuations"], + pad=config.characters["pad"], + eos=config.characters["eos"], + bos=config.characters["bos"], + blank=config.characters["blank"], + is_unique=config.characters["is_unique"], + is_sorted=config.characters["is_sorted"], + ), + config, + ) + # use character set from config + if config.characters is not None: + return IPAPhonemes(**config.characters), config + # return default character set + characters = IPAPhonemes() + new_config = replace(config, characters=characters.to_config()) + return characters, new_config + + +class Graphemes(BaseCharacters): + """🐸Graphemes class to manage `TTS.tts` model vocabulary + + Intended to be used with models using graphemes as input. + It uses system defaults for the undefined class arguments. + + Args: + characters (str): + Main set of case-sensitive characters to be used in the vocabulary. Defaults to `_characters`. + + punctuations (str): + Characters to be treated as punctuation. Defaults to `_punctuations`. + + pad (str): + Special padding character that would be ignored by the model. Defaults to `_pad`. + + eos (str): + End of the sentence character. Defaults to `_eos`. + + bos (str): + Beginning of the sentence character. Defaults to `_bos`. + + is_unique (bool): + Remove duplicates from the provided characters. Defaults to True. + + is_sorted (bool): + Sort the characters in alphabetical order. Defaults to True. + """ + + def __init__( + self, + characters: str = _characters, + punctuations: str = _punctuations, + pad: str = _pad, + eos: str = _eos, + bos: str = _bos, + blank: str = _blank, + is_unique: bool = False, + is_sorted: bool = True, + ) -> None: + super().__init__(characters, punctuations, pad, eos, bos, blank, is_unique, is_sorted) + + @staticmethod + def init_from_config(config: "Coqpit"): + """Init a Graphemes object from a model config + + If characters are not defined in the config, it will be set to the default characters and the config + will be updated. 
+ """ + if config.characters is not None: + # band-aid for compatibility with old models + if "phonemes" in config.characters: + return ( + Graphemes( + characters=config.characters["characters"], + punctuations=config.characters["punctuations"], + pad=config.characters["pad"], + eos=config.characters["eos"], + bos=config.characters["bos"], + blank=config.characters["blank"], + is_unique=config.characters["is_unique"], + is_sorted=config.characters["is_sorted"], + ), + config, + ) + return Graphemes(**config.characters), config + characters = Graphemes() + new_config = replace(config, characters=characters.to_config()) + return characters, new_config + + +if __name__ == "__main__": + gr = Graphemes() + ph = IPAPhonemes() + gr.print_log() + ph.print_log() diff --git a/viXTTS/TTS/tts/utils/text/chinese_mandarin/__init__.py b/viXTTS/TTS/tts/utils/text/chinese_mandarin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/tts/utils/text/chinese_mandarin/numbers.py b/viXTTS/TTS/tts/utils/text/chinese_mandarin/numbers.py new file mode 100644 index 0000000000000000000000000000000000000000..4787ea61007656819eb57d52d5865b38c7afa915 --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/chinese_mandarin/numbers.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Licensed under WTFPL or the Unlicense or CC0. +# This uses Python 3, but it's easy to port to Python 2 by changing +# strings to u'xx'. + +import itertools +import re + + +def _num2chinese(num: str, big=False, simp=True, o=False, twoalt=False) -> str: + """Convert numerical arabic numbers (0->9) to chinese hanzi numbers (〇 -> 九) + + Args: + num (str): arabic number to convert + big (bool, optional): use financial characters. Defaults to False. + simp (bool, optional): use simplified characters instead of tradictional characters. Defaults to True. + o (bool, optional): use 〇 for 'zero'. Defaults to False. + twoalt (bool, optional): use 两/兩 for 'two' when appropriate. Defaults to False. + + Raises: + ValueError: if number is more than 1e48 + ValueError: if 'e' exposent in number + + Returns: + str: converted number as hanzi characters + """ + + # check num first + nd = str(num) + if abs(float(nd)) >= 1e48: + raise ValueError("number out of range") + if "e" in nd: + raise ValueError("scientific notation is not supported") + c_symbol = "正负点" if simp else "正負點" + if o: # formal + twoalt = False + if big: + c_basic = "零壹贰叁肆伍陆柒捌玖" if simp else "零壹貳參肆伍陸柒捌玖" + c_unit1 = "拾佰仟" + c_twoalt = "贰" if simp else "貳" + else: + c_basic = "〇一二三四五六七八九" if o else "零一二三四五六七八九" + c_unit1 = "十百千" + if twoalt: + c_twoalt = "两" if simp else "兩" + else: + c_twoalt = "二" + c_unit2 = "万亿兆京垓秭穰沟涧正载" if simp else "萬億兆京垓秭穰溝澗正載" + revuniq = lambda l: "".join(k for k, g in itertools.groupby(reversed(l))) + nd = str(num) + result = [] + if nd[0] == "+": + result.append(c_symbol[0]) + elif nd[0] == "-": + result.append(c_symbol[1]) + if "." 
in nd: + integer, remainder = nd.lstrip("+-").split(".") + else: + integer, remainder = nd.lstrip("+-"), None + if int(integer): + splitted = [integer[max(i - 4, 0) : i] for i in range(len(integer), 0, -4)] + intresult = [] + for nu, unit in enumerate(splitted): + # special cases + if int(unit) == 0: # 0000 + intresult.append(c_basic[0]) + continue + if nu > 0 and int(unit) == 2: # 0002 + intresult.append(c_twoalt + c_unit2[nu - 1]) + continue + ulist = [] + unit = unit.zfill(4) + for nc, ch in enumerate(reversed(unit)): + if ch == "0": + if ulist: # ???0 + ulist.append(c_basic[0]) + elif nc == 0: + ulist.append(c_basic[int(ch)]) + elif nc == 1 and ch == "1" and unit[1] == "0": + # special case for tens + # edit the 'elif' if you don't like + # 十四, 三千零十四, 三千三百一十四 + ulist.append(c_unit1[0]) + elif nc > 1 and ch == "2": + ulist.append(c_twoalt + c_unit1[nc - 1]) + else: + ulist.append(c_basic[int(ch)] + c_unit1[nc - 1]) + ustr = revuniq(ulist) + if nu == 0: + intresult.append(ustr) + else: + intresult.append(ustr + c_unit2[nu - 1]) + result.append(revuniq(intresult).strip(c_basic[0])) + else: + result.append(c_basic[0]) + if remainder: + result.append(c_symbol[2]) + result.append("".join(c_basic[int(ch)] for ch in remainder)) + return "".join(result) + + +def _number_replace(match) -> str: + """function to apply in a match, transform all numbers in a match by chinese characters + + Args: + match (re.Match): numbers regex matches + + Returns: + str: replaced characters for the numbers + """ + match_str: str = match.group() + return _num2chinese(match_str) + + +def replace_numbers_to_characters_in_text(text: str) -> str: + """Replace all arabic numbers in a text by their equivalent in chinese characters (simplified) + + Args: + text (str): input text to transform + + Returns: + str: output text + """ + text = re.sub(r"[0-9]+", _number_replace, text) + return text diff --git a/viXTTS/TTS/tts/utils/text/chinese_mandarin/phonemizer.py b/viXTTS/TTS/tts/utils/text/chinese_mandarin/phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..727c881e1062badc57df7418aa07e7434d57335c --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/chinese_mandarin/phonemizer.py @@ -0,0 +1,37 @@ +from typing import List + +import jieba +import pypinyin + +from .pinyinToPhonemes import PINYIN_DICT + + +def _chinese_character_to_pinyin(text: str) -> List[str]: + pinyins = pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True) + pinyins_flat_list = [item for sublist in pinyins for item in sublist] + return pinyins_flat_list + + +def _chinese_pinyin_to_phoneme(pinyin: str) -> str: + segment = pinyin[:-1] + tone = pinyin[-1] + phoneme = PINYIN_DICT.get(segment, [""])[0] + return phoneme + tone + + +def chinese_text_to_phonemes(text: str, seperator: str = "|") -> str: + tokenized_text = jieba.cut(text, HMM=False) + tokenized_text = " ".join(tokenized_text) + pinyined_text: List[str] = _chinese_character_to_pinyin(tokenized_text) + + results: List[str] = [] + + for token in pinyined_text: + if token[-1] in "12345": # TODO transform to is_pinyin() + pinyin_phonemes = _chinese_pinyin_to_phoneme(token) + + results += list(pinyin_phonemes) + else: # is ponctuation or other + results += list(token) + + return seperator.join(results) diff --git a/viXTTS/TTS/tts/utils/text/chinese_mandarin/pinyinToPhonemes.py b/viXTTS/TTS/tts/utils/text/chinese_mandarin/pinyinToPhonemes.py new file mode 100644 index 
0000000000000000000000000000000000000000..4e25c3a4c91cddd0bf0e5d6e273262e3dbd3a2dd --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/chinese_mandarin/pinyinToPhonemes.py @@ -0,0 +1,419 @@ +PINYIN_DICT = { + "a": ["a"], + "ai": ["ai"], + "an": ["an"], + "ang": ["ɑŋ"], + "ao": ["aʌ"], + "ba": ["ba"], + "bai": ["bai"], + "ban": ["ban"], + "bang": ["bɑŋ"], + "bao": ["baʌ"], + # "be": ["be"], doesnt exist + "bei": ["bɛi"], + "ben": ["bœn"], + "beng": ["bɵŋ"], + "bi": ["bi"], + "bian": ["biɛn"], + "biao": ["biaʌ"], + "bie": ["bie"], + "bin": ["bin"], + "bing": ["bɨŋ"], + "bo": ["bo"], + "bu": ["bu"], + "ca": ["tsa"], + "cai": ["tsai"], + "can": ["tsan"], + "cang": ["tsɑŋ"], + "cao": ["tsaʌ"], + "ce": ["tsø"], + "cen": ["tsœn"], + "ceng": ["tsɵŋ"], + "cha": ["ʈʂa"], + "chai": ["ʈʂai"], + "chan": ["ʈʂan"], + "chang": ["ʈʂɑŋ"], + "chao": ["ʈʂaʌ"], + "che": ["ʈʂø"], + "chen": ["ʈʂœn"], + "cheng": ["ʈʂɵŋ"], + "chi": ["ʈʂʏ"], + "chong": ["ʈʂoŋ"], + "chou": ["ʈʂou"], + "chu": ["ʈʂu"], + "chua": ["ʈʂua"], + "chuai": ["ʈʂuai"], + "chuan": ["ʈʂuan"], + "chuang": ["ʈʂuɑŋ"], + "chui": ["ʈʂuei"], + "chun": ["ʈʂun"], + "chuo": ["ʈʂuo"], + "ci": ["tsɪ"], + "cong": ["tsoŋ"], + "cou": ["tsou"], + "cu": ["tsu"], + "cuan": ["tsuan"], + "cui": ["tsuei"], + "cun": ["tsun"], + "cuo": ["tsuo"], + "da": ["da"], + "dai": ["dai"], + "dan": ["dan"], + "dang": ["dɑŋ"], + "dao": ["daʌ"], + "de": ["dø"], + "dei": ["dei"], + # "den": ["dœn"], + "deng": ["dɵŋ"], + "di": ["di"], + "dia": ["dia"], + "dian": ["diɛn"], + "diao": ["diaʌ"], + "die": ["die"], + "ding": ["dɨŋ"], + "diu": ["dio"], + "dong": ["doŋ"], + "dou": ["dou"], + "du": ["du"], + "duan": ["duan"], + "dui": ["duei"], + "dun": ["dun"], + "duo": ["duo"], + "e": ["ø"], + "ei": ["ei"], + "en": ["œn"], + # "ng": ["œn"], + # "eng": ["ɵŋ"], + "er": ["er"], + "fa": ["fa"], + "fan": ["fan"], + "fang": ["fɑŋ"], + "fei": ["fei"], + "fen": ["fœn"], + "feng": ["fɵŋ"], + "fo": ["fo"], + "fou": ["fou"], + "fu": ["fu"], + "ga": ["ga"], + "gai": ["gai"], + "gan": ["gan"], + "gang": ["gɑŋ"], + "gao": ["gaʌ"], + "ge": ["gø"], + "gei": ["gei"], + "gen": ["gœn"], + "geng": ["gɵŋ"], + "gong": ["goŋ"], + "gou": ["gou"], + "gu": ["gu"], + "gua": ["gua"], + "guai": ["guai"], + "guan": ["guan"], + "guang": ["guɑŋ"], + "gui": ["guei"], + "gun": ["gun"], + "guo": ["guo"], + "ha": ["xa"], + "hai": ["xai"], + "han": ["xan"], + "hang": ["xɑŋ"], + "hao": ["xaʌ"], + "he": ["xø"], + "hei": ["xei"], + "hen": ["xœn"], + "heng": ["xɵŋ"], + "hong": ["xoŋ"], + "hou": ["xou"], + "hu": ["xu"], + "hua": ["xua"], + "huai": ["xuai"], + "huan": ["xuan"], + "huang": ["xuɑŋ"], + "hui": ["xuei"], + "hun": ["xun"], + "huo": ["xuo"], + "ji": ["dʑi"], + "jia": ["dʑia"], + "jian": ["dʑiɛn"], + "jiang": ["dʑiɑŋ"], + "jiao": ["dʑiaʌ"], + "jie": ["dʑie"], + "jin": ["dʑin"], + "jing": ["dʑɨŋ"], + "jiong": ["dʑioŋ"], + "jiu": ["dʑio"], + "ju": ["dʑy"], + "juan": ["dʑyɛn"], + "jue": ["dʑye"], + "jun": ["dʑyn"], + "ka": ["ka"], + "kai": ["kai"], + "kan": ["kan"], + "kang": ["kɑŋ"], + "kao": ["kaʌ"], + "ke": ["kø"], + "kei": ["kei"], + "ken": ["kœn"], + "keng": ["kɵŋ"], + "kong": ["koŋ"], + "kou": ["kou"], + "ku": ["ku"], + "kua": ["kua"], + "kuai": ["kuai"], + "kuan": ["kuan"], + "kuang": ["kuɑŋ"], + "kui": ["kuei"], + "kun": ["kun"], + "kuo": ["kuo"], + "la": ["la"], + "lai": ["lai"], + "lan": ["lan"], + "lang": ["lɑŋ"], + "lao": ["laʌ"], + "le": ["lø"], + "lei": ["lei"], + "leng": ["lɵŋ"], + "li": ["li"], + "lia": ["lia"], + "lian": ["liɛn"], + "liang": ["liɑŋ"], + "liao": ["liaʌ"], + "lie": ["lie"], + "lin": ["lin"], + 
"ling": ["lɨŋ"], + "liu": ["lio"], + "lo": ["lo"], + "long": ["loŋ"], + "lou": ["lou"], + "lu": ["lu"], + "lv": ["ly"], + "luan": ["luan"], + "lve": ["lye"], + "lue": ["lue"], + "lun": ["lun"], + "luo": ["luo"], + "ma": ["ma"], + "mai": ["mai"], + "man": ["man"], + "mang": ["mɑŋ"], + "mao": ["maʌ"], + "me": ["mø"], + "mei": ["mei"], + "men": ["mœn"], + "meng": ["mɵŋ"], + "mi": ["mi"], + "mian": ["miɛn"], + "miao": ["miaʌ"], + "mie": ["mie"], + "min": ["min"], + "ming": ["mɨŋ"], + "miu": ["mio"], + "mo": ["mo"], + "mou": ["mou"], + "mu": ["mu"], + "na": ["na"], + "nai": ["nai"], + "nan": ["nan"], + "nang": ["nɑŋ"], + "nao": ["naʌ"], + "ne": ["nø"], + "nei": ["nei"], + "nen": ["nœn"], + "neng": ["nɵŋ"], + "ni": ["ni"], + "nia": ["nia"], + "nian": ["niɛn"], + "niang": ["niɑŋ"], + "niao": ["niaʌ"], + "nie": ["nie"], + "nin": ["nin"], + "ning": ["nɨŋ"], + "niu": ["nio"], + "nong": ["noŋ"], + "nou": ["nou"], + "nu": ["nu"], + "nv": ["ny"], + "nuan": ["nuan"], + "nve": ["nye"], + "nue": ["nye"], + "nuo": ["nuo"], + "o": ["o"], + "ou": ["ou"], + "pa": ["pa"], + "pai": ["pai"], + "pan": ["pan"], + "pang": ["pɑŋ"], + "pao": ["paʌ"], + "pe": ["pø"], + "pei": ["pei"], + "pen": ["pœn"], + "peng": ["pɵŋ"], + "pi": ["pi"], + "pian": ["piɛn"], + "piao": ["piaʌ"], + "pie": ["pie"], + "pin": ["pin"], + "ping": ["pɨŋ"], + "po": ["po"], + "pou": ["pou"], + "pu": ["pu"], + "qi": ["tɕi"], + "qia": ["tɕia"], + "qian": ["tɕiɛn"], + "qiang": ["tɕiɑŋ"], + "qiao": ["tɕiaʌ"], + "qie": ["tɕie"], + "qin": ["tɕin"], + "qing": ["tɕɨŋ"], + "qiong": ["tɕioŋ"], + "qiu": ["tɕio"], + "qu": ["tɕy"], + "quan": ["tɕyɛn"], + "que": ["tɕye"], + "qun": ["tɕyn"], + "ran": ["ʐan"], + "rang": ["ʐɑŋ"], + "rao": ["ʐaʌ"], + "re": ["ʐø"], + "ren": ["ʐœn"], + "reng": ["ʐɵŋ"], + "ri": ["ʐʏ"], + "rong": ["ʐoŋ"], + "rou": ["ʐou"], + "ru": ["ʐu"], + "rua": ["ʐua"], + "ruan": ["ʐuan"], + "rui": ["ʐuei"], + "run": ["ʐun"], + "ruo": ["ʐuo"], + "sa": ["sa"], + "sai": ["sai"], + "san": ["san"], + "sang": ["sɑŋ"], + "sao": ["saʌ"], + "se": ["sø"], + "sen": ["sœn"], + "seng": ["sɵŋ"], + "sha": ["ʂa"], + "shai": ["ʂai"], + "shan": ["ʂan"], + "shang": ["ʂɑŋ"], + "shao": ["ʂaʌ"], + "she": ["ʂø"], + "shei": ["ʂei"], + "shen": ["ʂœn"], + "sheng": ["ʂɵŋ"], + "shi": ["ʂʏ"], + "shou": ["ʂou"], + "shu": ["ʂu"], + "shua": ["ʂua"], + "shuai": ["ʂuai"], + "shuan": ["ʂuan"], + "shuang": ["ʂuɑŋ"], + "shui": ["ʂuei"], + "shun": ["ʂun"], + "shuo": ["ʂuo"], + "si": ["sɪ"], + "song": ["soŋ"], + "sou": ["sou"], + "su": ["su"], + "suan": ["suan"], + "sui": ["suei"], + "sun": ["sun"], + "suo": ["suo"], + "ta": ["ta"], + "tai": ["tai"], + "tan": ["tan"], + "tang": ["tɑŋ"], + "tao": ["taʌ"], + "te": ["tø"], + "tei": ["tei"], + "teng": ["tɵŋ"], + "ti": ["ti"], + "tian": ["tiɛn"], + "tiao": ["tiaʌ"], + "tie": ["tie"], + "ting": ["tɨŋ"], + "tong": ["toŋ"], + "tou": ["tou"], + "tu": ["tu"], + "tuan": ["tuan"], + "tui": ["tuei"], + "tun": ["tun"], + "tuo": ["tuo"], + "wa": ["wa"], + "wai": ["wai"], + "wan": ["wan"], + "wang": ["wɑŋ"], + "wei": ["wei"], + "wen": ["wœn"], + "weng": ["wɵŋ"], + "wo": ["wo"], + "wu": ["wu"], + "xi": ["ɕi"], + "xia": ["ɕia"], + "xian": ["ɕiɛn"], + "xiang": ["ɕiɑŋ"], + "xiao": ["ɕiaʌ"], + "xie": ["ɕie"], + "xin": ["ɕin"], + "xing": ["ɕɨŋ"], + "xiong": ["ɕioŋ"], + "xiu": ["ɕio"], + "xu": ["ɕy"], + "xuan": ["ɕyɛn"], + "xue": ["ɕye"], + "xun": ["ɕyn"], + "ya": ["ia"], + "yan": ["iɛn"], + "yang": ["iɑŋ"], + "yao": ["iaʌ"], + "ye": ["ie"], + "yi": ["i"], + "yin": ["in"], + "ying": ["ɨŋ"], + "yo": ["io"], + "yong": ["ioŋ"], + "you": ["io"], + "yu": ["y"], + 
"yuan": ["yɛn"], + "yue": ["ye"], + "yun": ["yn"], + "za": ["dza"], + "zai": ["dzai"], + "zan": ["dzan"], + "zang": ["dzɑŋ"], + "zao": ["dzaʌ"], + "ze": ["dzø"], + "zei": ["dzei"], + "zen": ["dzœn"], + "zeng": ["dzɵŋ"], + "zha": ["dʒa"], + "zhai": ["dʒai"], + "zhan": ["dʒan"], + "zhang": ["dʒɑŋ"], + "zhao": ["dʒaʌ"], + "zhe": ["dʒø"], + # "zhei": ["dʒei"], it doesn't exist + "zhen": ["dʒœn"], + "zheng": ["dʒɵŋ"], + "zhi": ["dʒʏ"], + "zhong": ["dʒoŋ"], + "zhou": ["dʒou"], + "zhu": ["dʒu"], + "zhua": ["dʒua"], + "zhuai": ["dʒuai"], + "zhuan": ["dʒuan"], + "zhuang": ["dʒuɑŋ"], + "zhui": ["dʒuei"], + "zhun": ["dʒun"], + "zhuo": ["dʒuo"], + "zi": ["dzɪ"], + "zong": ["dzoŋ"], + "zou": ["dzou"], + "zu": ["dzu"], + "zuan": ["dzuan"], + "zui": ["dzuei"], + "zun": ["dzun"], + "zuo": ["dzuo"], +} diff --git a/viXTTS/TTS/tts/utils/text/cleaners.py b/viXTTS/TTS/tts/utils/text/cleaners.py new file mode 100644 index 0000000000000000000000000000000000000000..74d3910b516ae1dc856509d62aca52a01cc2088b --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/cleaners.py @@ -0,0 +1,171 @@ +"""Set of default text cleaners""" +# TODO: pick the cleaner for languages dynamically + +import re + +from anyascii import anyascii + +from TTS.tts.utils.text.chinese_mandarin.numbers import replace_numbers_to_characters_in_text + +from .english.abbreviations import abbreviations_en +from .english.number_norm import normalize_numbers as en_normalize_numbers +from .english.time_norm import expand_time_english +from .french.abbreviations import abbreviations_fr + +# Regular expression matching whitespace: +_whitespace_re = re.compile(r"\s+") + + +def expand_abbreviations(text, lang="en"): + if lang == "en": + _abbreviations = abbreviations_en + elif lang == "fr": + _abbreviations = abbreviations_fr + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, " ", text).strip() + + +def convert_to_ascii(text): + return anyascii(text) + + +def remove_aux_symbols(text): + text = re.sub(r"[\<\>\(\)\[\]\"]+", "", text) + return text + + +def replace_symbols(text, lang="en"): + """Replace symbols based on the lenguage tag. + + Args: + text: + Input text. + lang: + Lenguage identifier. ex: "en", "fr", "pt", "ca". 
+ + Returns: + The modified text + example: + input args: + text: "si l'avi cau, diguem-ho" + lang: "ca" + Output: + text: "si lavi cau, diguemho" + """ + text = text.replace(";", ",") + text = text.replace("-", " ") if lang != "ca" else text.replace("-", "") + text = text.replace(":", ",") + if lang == "en": + text = text.replace("&", " and ") + elif lang == "fr": + text = text.replace("&", " et ") + elif lang == "pt": + text = text.replace("&", " e ") + elif lang == "ca": + text = text.replace("&", " i ") + text = text.replace("'", "") + return text + + +def basic_cleaners(text): + """Basic pipeline that lowercases and collapses whitespace without transliteration.""" + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def transliteration_cleaners(text): + """Pipeline for non-English text that transliterates to ASCII.""" + # text = convert_to_ascii(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def basic_german_cleaners(text): + """Pipeline for German text""" + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +# TODO: elaborate it +def basic_turkish_cleaners(text): + """Pipeline for Turkish text""" + text = text.replace("I", "ı") + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def english_cleaners(text): + """Pipeline for English text, including number and abbreviation expansion.""" + # text = convert_to_ascii(text) + text = lowercase(text) + text = expand_time_english(text) + text = en_normalize_numbers(text) + text = expand_abbreviations(text) + text = replace_symbols(text) + text = remove_aux_symbols(text) + text = collapse_whitespace(text) + return text + + +def phoneme_cleaners(text): + """Pipeline for phonemes mode, including number and abbreviation expansion.""" + text = en_normalize_numbers(text) + text = expand_abbreviations(text) + text = replace_symbols(text) + text = remove_aux_symbols(text) + text = collapse_whitespace(text) + return text + + +def french_cleaners(text): + """Pipeline for French text. There is no need to expand numbers, phonemizer already does that""" + text = expand_abbreviations(text, lang="fr") + text = lowercase(text) + text = replace_symbols(text, lang="fr") + text = remove_aux_symbols(text) + text = collapse_whitespace(text) + return text + + +def portuguese_cleaners(text): + """Basic pipeline for Portuguese text. 
There is no need to expand abbreviation and + numbers, phonemizer already does that""" + text = lowercase(text) + text = replace_symbols(text, lang="pt") + text = remove_aux_symbols(text) + text = collapse_whitespace(text) + return text + + +def chinese_mandarin_cleaners(text: str) -> str: + """Basic pipeline for chinese""" + text = replace_numbers_to_characters_in_text(text) + return text + + +def multilingual_cleaners(text): + """Pipeline for multilingual text""" + text = lowercase(text) + text = replace_symbols(text, lang=None) + text = remove_aux_symbols(text) + text = collapse_whitespace(text) + return text + + +def no_cleaners(text): + # remove newline characters + text = text.replace("\n", "") + return text diff --git a/viXTTS/TTS/tts/utils/text/cmudict.py b/viXTTS/TTS/tts/utils/text/cmudict.py new file mode 100644 index 0000000000000000000000000000000000000000..f206fb043be1d478fa6ace36fefdefa30b0acb02 --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/cmudict.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- + +import re + +VALID_SYMBOLS = [ + "AA", + "AA0", + "AA1", + "AA2", + "AE", + "AE0", + "AE1", + "AE2", + "AH", + "AH0", + "AH1", + "AH2", + "AO", + "AO0", + "AO1", + "AO2", + "AW", + "AW0", + "AW1", + "AW2", + "AY", + "AY0", + "AY1", + "AY2", + "B", + "CH", + "D", + "DH", + "EH", + "EH0", + "EH1", + "EH2", + "ER", + "ER0", + "ER1", + "ER2", + "EY", + "EY0", + "EY1", + "EY2", + "F", + "G", + "HH", + "IH", + "IH0", + "IH1", + "IH2", + "IY", + "IY0", + "IY1", + "IY2", + "JH", + "K", + "L", + "M", + "N", + "NG", + "OW", + "OW0", + "OW1", + "OW2", + "OY", + "OY0", + "OY1", + "OY2", + "P", + "R", + "S", + "SH", + "T", + "TH", + "UH", + "UH0", + "UH1", + "UH2", + "UW", + "UW0", + "UW1", + "UW2", + "V", + "W", + "Y", + "Z", + "ZH", +] + + +class CMUDict: + """Thin wrapper around CMUDict data. 
http://www.speech.cs.cmu.edu/cgi-bin/cmudict""" + + def __init__(self, file_or_path, keep_ambiguous=True): + if isinstance(file_or_path, str): + with open(file_or_path, encoding="latin-1") as f: + entries = _parse_cmudict(f) + else: + entries = _parse_cmudict(file_or_path) + if not keep_ambiguous: + entries = {word: pron for word, pron in entries.items() if len(pron) == 1} + self._entries = entries + + def __len__(self): + return len(self._entries) + + def lookup(self, word): + """Returns list of ARPAbet pronunciations of the given word.""" + return self._entries.get(word.upper()) + + @staticmethod + def get_arpabet(word, cmudict, punctuation_symbols): + first_symbol, last_symbol = "", "" + if word and word[0] in punctuation_symbols: + first_symbol = word[0] + word = word[1:] + if word and word[-1] in punctuation_symbols: + last_symbol = word[-1] + word = word[:-1] + arpabet = cmudict.lookup(word) + if arpabet is not None: + return first_symbol + "{%s}" % arpabet[0] + last_symbol + return first_symbol + word + last_symbol + + +_alt_re = re.compile(r"\([0-9]+\)") + + +def _parse_cmudict(file): + cmudict = {} + for line in file: + if line and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"): + parts = line.split(" ") + word = re.sub(_alt_re, "", parts[0]) + pronunciation = _get_pronunciation(parts[1]) + if pronunciation: + if word in cmudict: + cmudict[word].append(pronunciation) + else: + cmudict[word] = [pronunciation] + return cmudict + + +def _get_pronunciation(s): + parts = s.strip().split(" ") + for part in parts: + if part not in VALID_SYMBOLS: + return None + return " ".join(parts) diff --git a/viXTTS/TTS/tts/utils/text/english/__init__.py b/viXTTS/TTS/tts/utils/text/english/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/tts/utils/text/english/abbreviations.py b/viXTTS/TTS/tts/utils/text/english/abbreviations.py new file mode 100644 index 0000000000000000000000000000000000000000..cd93c13c8ecfbc0df2d0c6d2fa348388940c213a --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/english/abbreviations.py @@ -0,0 +1,26 @@ +import re + +# List of (regular expression, replacement) pairs for abbreviations in english: +abbreviations_en = [ + (re.compile("\\b%s\\." 
% x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] +] diff --git a/viXTTS/TTS/tts/utils/text/english/number_norm.py b/viXTTS/TTS/tts/utils/text/english/number_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..e8377ede87ebc9d1bb9cffbbb290aa7787caea4f --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/english/number_norm.py @@ -0,0 +1,97 @@ +""" from https://github.com/keithito/tacotron """ + +import re +from typing import Dict + +import inflect + +_inflect = inflect.engine() +_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") +_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") +_currency_re = re.compile(r"(£|\$|¥)([0-9\,\.]*[0-9]+)") +_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") +_number_re = re.compile(r"-?[0-9]+") + + +def _remove_commas(m): + return m.group(1).replace(",", "") + + +def _expand_decimal_point(m): + return m.group(1).replace(".", " point ") + + +def __expand_currency(value: str, inflection: Dict[float, str]) -> str: + parts = value.replace(",", "").split(".") + if len(parts) > 2: + return f"{value} {inflection[2]}" # Unexpected format + text = [] + integer = int(parts[0]) if parts[0] else 0 + if integer > 0: + integer_unit = inflection.get(integer, inflection[2]) + text.append(f"{integer} {integer_unit}") + fraction = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if fraction > 0: + fraction_unit = inflection.get(fraction / 100, inflection[0.02]) + text.append(f"{fraction} {fraction_unit}") + if len(text) == 0: + return f"zero {inflection[2]}" + return " ".join(text) + + +def _expand_currency(m: "re.Match") -> str: + currencies = { + "$": { + 0.01: "cent", + 0.02: "cents", + 1: "dollar", + 2: "dollars", + }, + "€": { + 0.01: "cent", + 0.02: "cents", + 1: "euro", + 2: "euros", + }, + "£": { + 0.01: "penny", + 0.02: "pence", + 1: "pound sterling", + 2: "pounds sterling", + }, + "¥": { + # TODO rin + 0.02: "sen", + 2: "yen", + }, + } + unit = m.group(1) + currency = currencies[unit] + value = m.group(2) + return __expand_currency(value, currency) + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if 1000 < num < 3000: + if num == 2000: + return "two thousand" + if 2000 < num < 2010: + return "two thousand " + _inflect.number_to_words(num % 100) + if num % 100 == 0: + return _inflect.number_to_words(num // 100) + " hundred" + return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") + return _inflect.number_to_words(num, andword="") + + +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_currency_re, _expand_currency, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text diff --git a/viXTTS/TTS/tts/utils/text/english/time_norm.py b/viXTTS/TTS/tts/utils/text/english/time_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..c8ac09e79db4a239a7f72f101503dbf0d6feb3ae --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/english/time_norm.py @@ -0,0 +1,47 @@ +import re 
+ +import inflect + +_inflect = inflect.engine() + +_time_re = re.compile( + r"""\b + ((0?[0-9])|(1[0-1])|(1[2-9])|(2[0-3])) # hours + : + ([0-5][0-9]) # minutes + \s*(a\\.m\\.|am|pm|p\\.m\\.|a\\.m|p\\.m)? # am/pm + \b""", + re.IGNORECASE | re.X, +) + + +def _expand_num(n: int) -> str: + return _inflect.number_to_words(n) + + +def _expand_time_english(match: "re.Match") -> str: + hour = int(match.group(1)) + past_noon = hour >= 12 + time = [] + if hour > 12: + hour -= 12 + elif hour == 0: + hour = 12 + past_noon = True + time.append(_expand_num(hour)) + + minute = int(match.group(6)) + if minute > 0: + if minute < 10: + time.append("oh") + time.append(_expand_num(minute)) + am_pm = match.group(7) + if am_pm is None: + time.append("p m" if past_noon else "a m") + else: + time.extend(list(am_pm.replace(".", ""))) + return " ".join(time) + + +def expand_time_english(text: str) -> str: + return re.sub(_time_re, _expand_time_english, text) diff --git a/viXTTS/TTS/tts/utils/text/french/__init__.py b/viXTTS/TTS/tts/utils/text/french/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/tts/utils/text/french/abbreviations.py b/viXTTS/TTS/tts/utils/text/french/abbreviations.py new file mode 100644 index 0000000000000000000000000000000000000000..f580dfed7b4576a9f87b0a4145cb729e70050d50 --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/french/abbreviations.py @@ -0,0 +1,48 @@ +import re + +# List of (regular expression, replacement) pairs for abbreviations in french: +abbreviations_fr = [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("M", "monsieur"), + ("Mlle", "mademoiselle"), + ("Mlles", "mesdemoiselles"), + ("Mme", "Madame"), + ("Mmes", "Mesdames"), + ("N.B", "nota bene"), + ("M", "monsieur"), + ("p.c.q", "parce que"), + ("Pr", "professeur"), + ("qqch", "quelque chose"), + ("rdv", "rendez-vous"), + ("max", "maximum"), + ("min", "minimum"), + ("no", "numéro"), + ("adr", "adresse"), + ("dr", "docteur"), + ("st", "saint"), + ("co", "companie"), + ("jr", "junior"), + ("sgt", "sergent"), + ("capt", "capitain"), + ("col", "colonel"), + ("av", "avenue"), + ("av. J.-C", "avant Jésus-Christ"), + ("apr. 
J.-C", "après Jésus-Christ"), + ("art", "article"), + ("boul", "boulevard"), + ("c.-à-d", "c’est-à-dire"), + ("etc", "et cetera"), + ("ex", "exemple"), + ("excl", "exclusivement"), + ("boul", "boulevard"), + ] +] + [ + (re.compile("\\b%s" % x[0]), x[1]) + for x in [ + ("Mlle", "mademoiselle"), + ("Mlles", "mesdemoiselles"), + ("Mme", "Madame"), + ("Mmes", "Mesdames"), + ] +] diff --git a/viXTTS/TTS/tts/utils/text/japanese/__init__.py b/viXTTS/TTS/tts/utils/text/japanese/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/tts/utils/text/japanese/phonemizer.py b/viXTTS/TTS/tts/utils/text/japanese/phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..c3111067e140903b301696c59b67f090995c1f59 --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/japanese/phonemizer.py @@ -0,0 +1,470 @@ +# Convert Japanese text to phonemes which is +# compatible with Julius https://github.com/julius-speech/segmentation-kit + +import re +import unicodedata + +try: + import MeCab +except ImportError as e: + raise ImportError("Japanese requires mecab-python3 and unidic-lite.") from e +from num2words import num2words + +_CONVRULES = [ + # Conversion of 2 letters + "アァ/ a a", + "イィ/ i i", + "イェ/ i e", + "イャ/ y a", + "ウゥ/ u:", + "エェ/ e e", + "オォ/ o:", + "カァ/ k a:", + "キィ/ k i:", + "クゥ/ k u:", + "クャ/ ky a", + "クュ/ ky u", + "クョ/ ky o", + "ケェ/ k e:", + "コォ/ k o:", + "ガァ/ g a:", + "ギィ/ g i:", + "グゥ/ g u:", + "グャ/ gy a", + "グュ/ gy u", + "グョ/ gy o", + "ゲェ/ g e:", + "ゴォ/ g o:", + "サァ/ s a:", + "シィ/ sh i:", + "スゥ/ s u:", + "スャ/ sh a", + "スュ/ sh u", + "スョ/ sh o", + "セェ/ s e:", + "ソォ/ s o:", + "ザァ/ z a:", + "ジィ/ j i:", + "ズゥ/ z u:", + "ズャ/ zy a", + "ズュ/ zy u", + "ズョ/ zy o", + "ゼェ/ z e:", + "ゾォ/ z o:", + "タァ/ t a:", + "チィ/ ch i:", + "ツァ/ ts a", + "ツィ/ ts i", + "ツゥ/ ts u:", + "ツャ/ ch a", + "ツュ/ ch u", + "ツョ/ ch o", + "ツェ/ ts e", + "ツォ/ ts o", + "テェ/ t e:", + "トォ/ t o:", + "ダァ/ d a:", + "ヂィ/ j i:", + "ヅゥ/ d u:", + "ヅャ/ zy a", + "ヅュ/ zy u", + "ヅョ/ zy o", + "デェ/ d e:", + "ドォ/ d o:", + "ナァ/ n a:", + "ニィ/ n i:", + "ヌゥ/ n u:", + "ヌャ/ ny a", + "ヌュ/ ny u", + "ヌョ/ ny o", + "ネェ/ n e:", + "ノォ/ n o:", + "ハァ/ h a:", + "ヒィ/ h i:", + "フゥ/ f u:", + "フャ/ hy a", + "フュ/ hy u", + "フョ/ hy o", + "ヘェ/ h e:", + "ホォ/ h o:", + "バァ/ b a:", + "ビィ/ b i:", + "ブゥ/ b u:", + "フャ/ hy a", + "ブュ/ by u", + "フョ/ hy o", + "ベェ/ b e:", + "ボォ/ b o:", + "パァ/ p a:", + "ピィ/ p i:", + "プゥ/ p u:", + "プャ/ py a", + "プュ/ py u", + "プョ/ py o", + "ペェ/ p e:", + "ポォ/ p o:", + "マァ/ m a:", + "ミィ/ m i:", + "ムゥ/ m u:", + "ムャ/ my a", + "ムュ/ my u", + "ムョ/ my o", + "メェ/ m e:", + "モォ/ m o:", + "ヤァ/ y a:", + "ユゥ/ y u:", + "ユャ/ y a:", + "ユュ/ y u:", + "ユョ/ y o:", + "ヨォ/ y o:", + "ラァ/ r a:", + "リィ/ r i:", + "ルゥ/ r u:", + "ルャ/ ry a", + "ルュ/ ry u", + "ルョ/ ry o", + "レェ/ r e:", + "ロォ/ r o:", + "ワァ/ w a:", + "ヲォ/ o:", + "ディ/ d i", + "デェ/ d e:", + "デャ/ dy a", + "デュ/ dy u", + "デョ/ dy o", + "ティ/ t i", + "テェ/ t e:", + "テャ/ ty a", + "テュ/ ty u", + "テョ/ ty o", + "スィ/ s i", + "ズァ/ z u a", + "ズィ/ z i", + "ズゥ/ z u", + "ズャ/ zy a", + "ズュ/ zy u", + "ズョ/ zy o", + "ズェ/ z e", + "ズォ/ z o", + "キャ/ ky a", + "キュ/ ky u", + "キョ/ ky o", + "シャ/ sh a", + "シュ/ sh u", + "シェ/ sh e", + "ショ/ sh o", + "チャ/ ch a", + "チュ/ ch u", + "チェ/ ch e", + "チョ/ ch o", + "トゥ/ t u", + "トャ/ ty a", + "トュ/ ty u", + "トョ/ ty o", + "ドァ/ d o a", + "ドゥ/ d u", + "ドャ/ dy a", + "ドュ/ dy u", + "ドョ/ dy o", + "ドォ/ d o:", + "ニャ/ ny a", + "ニュ/ ny u", + "ニョ/ ny o", + "ヒャ/ hy a", + "ヒュ/ hy u", + "ヒョ/ hy o", + "ミャ/ my a", + "ミュ/ my u", + "ミョ/ my o", + 
"リャ/ ry a", + "リュ/ ry u", + "リョ/ ry o", + "ギャ/ gy a", + "ギュ/ gy u", + "ギョ/ gy o", + "ヂェ/ j e", + "ヂャ/ j a", + "ヂュ/ j u", + "ヂョ/ j o", + "ジェ/ j e", + "ジャ/ j a", + "ジュ/ j u", + "ジョ/ j o", + "ビャ/ by a", + "ビュ/ by u", + "ビョ/ by o", + "ピャ/ py a", + "ピュ/ py u", + "ピョ/ py o", + "ウァ/ u a", + "ウィ/ w i", + "ウェ/ w e", + "ウォ/ w o", + "ファ/ f a", + "フィ/ f i", + "フゥ/ f u", + "フャ/ hy a", + "フュ/ hy u", + "フョ/ hy o", + "フェ/ f e", + "フォ/ f o", + "ヴァ/ b a", + "ヴィ/ b i", + "ヴェ/ b e", + "ヴォ/ b o", + "ヴュ/ by u", + # Conversion of 1 letter + "ア/ a", + "イ/ i", + "ウ/ u", + "エ/ e", + "オ/ o", + "カ/ k a", + "キ/ k i", + "ク/ k u", + "ケ/ k e", + "コ/ k o", + "サ/ s a", + "シ/ sh i", + "ス/ s u", + "セ/ s e", + "ソ/ s o", + "タ/ t a", + "チ/ ch i", + "ツ/ ts u", + "テ/ t e", + "ト/ t o", + "ナ/ n a", + "ニ/ n i", + "ヌ/ n u", + "ネ/ n e", + "ノ/ n o", + "ハ/ h a", + "ヒ/ h i", + "フ/ f u", + "ヘ/ h e", + "ホ/ h o", + "マ/ m a", + "ミ/ m i", + "ム/ m u", + "メ/ m e", + "モ/ m o", + "ラ/ r a", + "リ/ r i", + "ル/ r u", + "レ/ r e", + "ロ/ r o", + "ガ/ g a", + "ギ/ g i", + "グ/ g u", + "ゲ/ g e", + "ゴ/ g o", + "ザ/ z a", + "ジ/ j i", + "ズ/ z u", + "ゼ/ z e", + "ゾ/ z o", + "ダ/ d a", + "ヂ/ j i", + "ヅ/ z u", + "デ/ d e", + "ド/ d o", + "バ/ b a", + "ビ/ b i", + "ブ/ b u", + "ベ/ b e", + "ボ/ b o", + "パ/ p a", + "ピ/ p i", + "プ/ p u", + "ペ/ p e", + "ポ/ p o", + "ヤ/ y a", + "ユ/ y u", + "ヨ/ y o", + "ワ/ w a", + "ヰ/ i", + "ヱ/ e", + "ヲ/ o", + "ン/ N", + "ッ/ q", + "ヴ/ b u", + "ー/:", + # Try converting broken text + "ァ/ a", + "ィ/ i", + "ゥ/ u", + "ェ/ e", + "ォ/ o", + "ヮ/ w a", + "ォ/ o", + # Symbols + "、/ ,", + "。/ .", + "!/ !", + "?/ ?", + "・/ ,", +] + +_COLON_RX = re.compile(":+") +_REJECT_RX = re.compile("[^ a-zA-Z:,.?]") + + +def _makerulemap(): + l = [tuple(x.split("/")) for x in _CONVRULES] + return tuple({k: v for k, v in l if len(k) == i} for i in (1, 2)) + + +_RULEMAP1, _RULEMAP2 = _makerulemap() + + +def kata2phoneme(text: str) -> str: + """Convert katakana text to phonemes.""" + text = text.strip() + res = "" + while text: + if len(text) >= 2: + x = _RULEMAP2.get(text[:2]) + if x is not None: + text = text[2:] + res += x + continue + x = _RULEMAP1.get(text[0]) + if x is not None: + text = text[1:] + res += x + continue + res += " " + text[0] + text = text[1:] + res = _COLON_RX.sub(":", res) + return res[1:] + + +_KATAKANA = "".join(chr(ch) for ch in range(ord("ァ"), ord("ン") + 1)) +_HIRAGANA = "".join(chr(ch) for ch in range(ord("ぁ"), ord("ん") + 1)) +_HIRA2KATATRANS = str.maketrans(_HIRAGANA, _KATAKANA) + + +def hira2kata(text: str) -> str: + text = text.translate(_HIRA2KATATRANS) + return text.replace("う゛", "ヴ") + + +_SYMBOL_TOKENS = set(list("・、。?!")) +_NO_YOMI_TOKENS = set(list("「」『』―()[][] …")) +_TAGGER = MeCab.Tagger() + + +def text2kata(text: str) -> str: + parsed = _TAGGER.parse(text) + res = [] + for line in parsed.split("\n"): + if line == "EOS": + break + parts = line.split("\t") + + word, yomi = parts[0], parts[1] + if yomi: + res.append(yomi) + else: + if word in _SYMBOL_TOKENS: + res.append(word) + elif word in ("っ", "ッ"): + res.append("ッ") + elif word in _NO_YOMI_TOKENS: + pass + else: + res.append(word) + return hira2kata("".join(res)) + + +_ALPHASYMBOL_YOMI = { + "#": "シャープ", + "%": "パーセント", + "&": "アンド", + "+": "プラス", + "-": "マイナス", + ":": "コロン", + ";": "セミコロン", + "<": "小なり", + "=": "イコール", + ">": "大なり", + "@": "アット", + "a": "エー", + "b": "ビー", + "c": "シー", + "d": "ディー", + "e": "イー", + "f": "エフ", + "g": "ジー", + "h": "エイチ", + "i": "アイ", + "j": "ジェー", + "k": "ケー", + "l": "エル", + "m": "エム", + "n": "エヌ", + "o": "オー", + "p": "ピー", + "q": "キュー", + "r": "アール", + "s": 
"エス", + "t": "ティー", + "u": "ユー", + "v": "ブイ", + "w": "ダブリュー", + "x": "エックス", + "y": "ワイ", + "z": "ゼット", + "α": "アルファ", + "β": "ベータ", + "γ": "ガンマ", + "δ": "デルタ", + "ε": "イプシロン", + "ζ": "ゼータ", + "η": "イータ", + "θ": "シータ", + "ι": "イオタ", + "κ": "カッパ", + "λ": "ラムダ", + "μ": "ミュー", + "ν": "ニュー", + "ξ": "クサイ", + "ο": "オミクロン", + "π": "パイ", + "ρ": "ロー", + "σ": "シグマ", + "τ": "タウ", + "υ": "ウプシロン", + "φ": "ファイ", + "χ": "カイ", + "ψ": "プサイ", + "ω": "オメガ", +} + + +_NUMBER_WITH_SEPARATOR_RX = re.compile("[0-9]{1,3}(,[0-9]{3})+") +_CURRENCY_MAP = {"$": "ドル", "¥": "円", "£": "ポンド", "€": "ユーロ"} +_CURRENCY_RX = re.compile(r"([$¥£€])([0-9.]*[0-9])") +_NUMBER_RX = re.compile(r"[0-9]+(\.[0-9]+)?") + + +def japanese_convert_numbers_to_words(text: str) -> str: + res = _NUMBER_WITH_SEPARATOR_RX.sub(lambda m: m[0].replace(",", ""), text) + res = _CURRENCY_RX.sub(lambda m: m[2] + _CURRENCY_MAP.get(m[1], m[1]), res) + res = _NUMBER_RX.sub(lambda m: num2words(m[0], lang="ja"), res) + return res + + +def japanese_convert_alpha_symbols_to_words(text: str) -> str: + return "".join([_ALPHASYMBOL_YOMI.get(ch, ch) for ch in text.lower()]) + + +def japanese_text_to_phonemes(text: str) -> str: + """Convert Japanese text to phonemes.""" + res = unicodedata.normalize("NFKC", text) + res = japanese_convert_numbers_to_words(res) + res = japanese_convert_alpha_symbols_to_words(res) + res = text2kata(res) + res = kata2phoneme(res) + return res.replace(" ", "") diff --git a/viXTTS/TTS/tts/utils/text/korean/__init__.py b/viXTTS/TTS/tts/utils/text/korean/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/tts/utils/text/korean/ko_dictionary.py b/viXTTS/TTS/tts/utils/text/korean/ko_dictionary.py new file mode 100644 index 0000000000000000000000000000000000000000..9b739339c6750443e7362aab1762147eda33fb53 --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/korean/ko_dictionary.py @@ -0,0 +1,44 @@ +# coding: utf-8 +# Add the word you want to the dictionary. 
+etc_dictionary = {"1+1": "원플러스원", "2+1": "투플러스원"} + + +english_dictionary = { + "KOREA": "코리아", + "IDOL": "아이돌", + "IT": "아이티", + "IQ": "아이큐", + "UP": "업", + "DOWN": "다운", + "PC": "피씨", + "CCTV": "씨씨티비", + "SNS": "에스엔에스", + "AI": "에이아이", + "CEO": "씨이오", + "A": "에이", + "B": "비", + "C": "씨", + "D": "디", + "E": "이", + "F": "에프", + "G": "지", + "H": "에이치", + "I": "아이", + "J": "제이", + "K": "케이", + "L": "엘", + "M": "엠", + "N": "엔", + "O": "오", + "P": "피", + "Q": "큐", + "R": "알", + "S": "에스", + "T": "티", + "U": "유", + "V": "브이", + "W": "더블유", + "X": "엑스", + "Y": "와이", + "Z": "제트", +} diff --git a/viXTTS/TTS/tts/utils/text/korean/korean.py b/viXTTS/TTS/tts/utils/text/korean/korean.py new file mode 100644 index 0000000000000000000000000000000000000000..423aeed377e3ba2d22b574418aa45714c322f670 --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/korean/korean.py @@ -0,0 +1,32 @@ +# coding: utf-8 +# Code based on https://github.com/carpedm20/multi-speaker-tacotron-tensorflow/blob/master/text/korean.py +import re + +from TTS.tts.utils.text.korean.ko_dictionary import english_dictionary, etc_dictionary + + +def normalize(text): + text = text.strip() + text = re.sub("[⺀-⺙⺛-⻳⼀-⿕々〇〡-〩〸-〺〻㐀-䶵一-鿃豈-鶴侮-頻並-龎]", "", text) + text = normalize_with_dictionary(text, etc_dictionary) + text = normalize_english(text) + text = text.lower() + return text + + +def normalize_with_dictionary(text, dic): + if any(key in text for key in dic.keys()): + pattern = re.compile("|".join(re.escape(key) for key in dic.keys())) + return pattern.sub(lambda x: dic[x.group()], text) + return text + + +def normalize_english(text): + def fn(m): + word = m.group() + if word in english_dictionary: + return english_dictionary.get(word) + return word + + text = re.sub("([A-Za-z]+)", fn, text) + return text diff --git a/viXTTS/TTS/tts/utils/text/korean/phonemizer.py b/viXTTS/TTS/tts/utils/text/korean/phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..2c69217c403d491de2d78c8253c8eca9ce83bef5 --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/korean/phonemizer.py @@ -0,0 +1,36 @@ +from jamo import hangul_to_jamo + +from TTS.tts.utils.text.korean.korean import normalize + +g2p = None + + +def korean_text_to_phonemes(text, character: str = "hangeul") -> str: + """ + + The input and output values look the same, but they are different in Unicode. 
+ + example : + + input = '하늘' (Unicode : \ud558\ub298), (하 + 늘) + output = '하늘' (Unicode :\u1112\u1161\u1102\u1173\u11af), (ᄒ + ᅡ + ᄂ + ᅳ + ᆯ) + + """ + global g2p # pylint: disable=global-statement + if g2p is None: + from g2pkk import G2p + + g2p = G2p() + + if character == "english": + from anyascii import anyascii + + text = normalize(text) + text = g2p(text) + text = anyascii(text) + return text + + text = normalize(text) + text = g2p(text) + text = list(hangul_to_jamo(text)) # '하늘' --> ['ᄒ', 'ᅡ', 'ᄂ', 'ᅳ', 'ᆯ'] + return "".join(text) diff --git a/viXTTS/TTS/tts/utils/text/phonemizers/__init__.py b/viXTTS/TTS/tts/utils/text/phonemizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f9a0340c55577bc173107f40c39342009def7276 --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/phonemizers/__init__.py @@ -0,0 +1,79 @@ +from TTS.tts.utils.text.phonemizers.bangla_phonemizer import BN_Phonemizer +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer +from TTS.tts.utils.text.phonemizers.belarusian_phonemizer import BEL_Phonemizer +from TTS.tts.utils.text.phonemizers.espeak_wrapper import ESpeak +from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut +from TTS.tts.utils.text.phonemizers.ko_kr_phonemizer import KO_KR_Phonemizer +from TTS.tts.utils.text.phonemizers.zh_cn_phonemizer import ZH_CN_Phonemizer + +try: + from TTS.tts.utils.text.phonemizers.ja_jp_phonemizer import JA_JP_Phonemizer +except ImportError: + JA_JP_Phonemizer = None + pass + +PHONEMIZERS = {b.name(): b for b in (ESpeak, Gruut, KO_KR_Phonemizer, BN_Phonemizer)} + + +ESPEAK_LANGS = list(ESpeak.supported_languages().keys()) +GRUUT_LANGS = list(Gruut.supported_languages()) + + +# Dict setting default phonemizers for each language +# Add Gruut languages +_ = [Gruut.name()] * len(GRUUT_LANGS) +DEF_LANG_TO_PHONEMIZER = dict(list(zip(GRUUT_LANGS, _))) + + +# Add ESpeak languages and override any existing ones +_ = [ESpeak.name()] * len(ESPEAK_LANGS) +_new_dict = dict(list(zip(list(ESPEAK_LANGS), _))) +DEF_LANG_TO_PHONEMIZER.update(_new_dict) + + +# Force default for some languages +DEF_LANG_TO_PHONEMIZER["en"] = DEF_LANG_TO_PHONEMIZER["en-us"] +DEF_LANG_TO_PHONEMIZER["zh-cn"] = ZH_CN_Phonemizer.name() +DEF_LANG_TO_PHONEMIZER["ko-kr"] = KO_KR_Phonemizer.name() +DEF_LANG_TO_PHONEMIZER["bn"] = BN_Phonemizer.name() +DEF_LANG_TO_PHONEMIZER["be"] = BEL_Phonemizer.name() + + +# JA phonemizer has deal breaking dependencies like MeCab for some systems. +# So we only have it when we have it. +if JA_JP_Phonemizer is not None: + PHONEMIZERS[JA_JP_Phonemizer.name()] = JA_JP_Phonemizer + DEF_LANG_TO_PHONEMIZER["ja-jp"] = JA_JP_Phonemizer.name() + + +def get_phonemizer_by_name(name: str, **kwargs) -> BasePhonemizer: + """Initiate a phonemizer by name + + Args: + name (str): + Name of the phonemizer that should match `phonemizer.name()`. + + kwargs (dict): + Extra keyword arguments that should be passed to the phonemizer. + """ + if name == "espeak": + return ESpeak(**kwargs) + if name == "gruut": + return Gruut(**kwargs) + if name == "zh_cn_phonemizer": + return ZH_CN_Phonemizer(**kwargs) + if name == "ja_jp_phonemizer": + if JA_JP_Phonemizer is None: + raise ValueError(" ❗ You need to install JA phonemizer dependencies. 
Try `pip install TTS[ja]`.") + return JA_JP_Phonemizer(**kwargs) + if name == "ko_kr_phonemizer": + return KO_KR_Phonemizer(**kwargs) + if name == "bn_phonemizer": + return BN_Phonemizer(**kwargs) + if name == "be_phonemizer": + return BEL_Phonemizer(**kwargs) + raise ValueError(f"Phonemizer {name} not found") + + +if __name__ == "__main__": + print(DEF_LANG_TO_PHONEMIZER) diff --git a/viXTTS/TTS/tts/utils/text/phonemizers/bangla_phonemizer.py b/viXTTS/TTS/tts/utils/text/phonemizers/bangla_phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..3c4a35bbfa4fc63a7cbb2a2caa152ef59b4afa63 --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/phonemizers/bangla_phonemizer.py @@ -0,0 +1,62 @@ +from typing import Dict + +from TTS.tts.utils.text.bangla.phonemizer import bangla_text_to_phonemes +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer + +_DEF_ZH_PUNCS = "、.,[]()?!〽~『』「」【】" + + +class BN_Phonemizer(BasePhonemizer): + """🐸TTS bn phonemizer using functions in `TTS.tts.utils.text.bangla.phonemizer` + + Args: + punctuations (str): + Set of characters to be treated as punctuation. Defaults to `_DEF_ZH_PUNCS`. + + keep_puncs (bool): + If True, keep the punctuations after phonemization. Defaults to False. + + Example :: + + "这是,样本中文。" -> `d|ʒ|ø|4| |ʂ|ʏ|4| |,| |i|ɑ|ŋ|4|b|œ|n|3| |d|ʒ|o|ŋ|1|w|œ|n|2| |。` + + TODO: someone with Bangla knowledge should check this implementation + """ + + language = "bn" + + def __init__(self, punctuations=_DEF_ZH_PUNCS, keep_puncs=False, **kwargs): # pylint: disable=unused-argument + super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs) + + @staticmethod + def name(): + return "bn_phonemizer" + + @staticmethod + def phonemize_bn(text: str, separator: str = "|") -> str: # pylint: disable=unused-argument + ph = bangla_text_to_phonemes(text) + return ph + + def _phonemize(self, text, separator): + return self.phonemize_bn(text, separator) + + @staticmethod + def supported_languages() -> Dict: + return {"bn": "Bangla"} + + def version(self) -> str: + return "0.0.1" + + def is_available(self) -> bool: + return True + + +if __name__ == "__main__": + txt = "রাসূলুল্লাহ সাল্লাল্লাহু আলাইহি ওয়া সাল্লাম শিক্ষা দিয়েছেন যে, কেউ যদি কোন খারাপ কিছুর সম্মুখীন হয়, তখনও যেন বলে." + e = BN_Phonemizer() + print(e.supported_languages()) + print(e.version()) + print(e.language) + print(e.name()) + print(e.is_available()) + print("`" + e.phonemize(txt) + "`") diff --git a/viXTTS/TTS/tts/utils/text/phonemizers/base.py b/viXTTS/TTS/tts/utils/text/phonemizers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4fc79874159aad40ce27f34671907a25d7e39a94 --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/phonemizers/base.py @@ -0,0 +1,140 @@ +import abc +from typing import List, Tuple + +from TTS.tts.utils.text.punctuation import Punctuation + + +class BasePhonemizer(abc.ABC): + """Base phonemizer class + + Phonemization follows the following steps: + 1. Preprocessing: + - remove empty lines + - remove punctuation + - keep track of punctuation marks + + 2. Phonemization: + - convert text to phonemes + + 3. Postprocessing: + - join phonemes + - restore punctuation marks + + Args: + language (str): + Language used by the phonemizer. + + punctuations (List[str]): + List of punctuation marks to be preserved. + + keep_puncs (bool): + Whether to preserve punctuation marks or not. 
+ """ + + def __init__(self, language, punctuations=Punctuation.default_puncs(), keep_puncs=False): + # ensure the backend is installed on the system + if not self.is_available(): + raise RuntimeError("{} not installed on your system".format(self.name())) # pragma: nocover + + # ensure the backend support the requested language + self._language = self._init_language(language) + + # setup punctuation processing + self._keep_puncs = keep_puncs + self._punctuator = Punctuation(punctuations) + + def _init_language(self, language): + """Language initialization + + This method may be overloaded in child classes (see Segments backend) + + """ + if not self.is_supported_language(language): + raise RuntimeError(f'language "{language}" is not supported by the ' f"{self.name()} backend") + return language + + @property + def language(self): + """The language code configured to be used for phonemization""" + return self._language + + @staticmethod + @abc.abstractmethod + def name(): + """The name of the backend""" + ... + + @classmethod + @abc.abstractmethod + def is_available(cls): + """Returns True if the backend is installed, False otherwise""" + ... + + @classmethod + @abc.abstractmethod + def version(cls): + """Return the backend version as a tuple (major, minor, patch)""" + ... + + @staticmethod + @abc.abstractmethod + def supported_languages(): + """Return a dict of language codes -> name supported by the backend""" + ... + + def is_supported_language(self, language): + """Returns True if `language` is supported by the backend""" + return language in self.supported_languages() + + @abc.abstractmethod + def _phonemize(self, text, separator): + """The main phonemization method""" + + def _phonemize_preprocess(self, text) -> Tuple[List[str], List]: + """Preprocess the text before phonemization + + 1. remove spaces + 2. remove punctuation + + Override this if you need a different behaviour + """ + text = text.strip() + if self._keep_puncs: + # a tuple (text, punctuation marks) + return self._punctuator.strip_to_restore(text) + return [self._punctuator.strip(text)], [] + + def _phonemize_postprocess(self, phonemized, punctuations) -> str: + """Postprocess the raw phonemized output + + Override this if you need a different behaviour + """ + if self._keep_puncs: + return self._punctuator.restore(phonemized, punctuations)[0] + return phonemized[0] + + def phonemize(self, text: str, separator="|", language: str = None) -> str: # pylint: disable=unused-argument + """Returns the `text` phonemized for the given language + + Args: + text (str): + Text to be phonemized. + + separator (str): + string separator used between phonemes. Default to '_'. 
+ + Returns: + (str): Phonemized text + """ + text, punctuations = self._phonemize_preprocess(text) + phonemized = [] + for t in text: + p = self._phonemize(t, separator) + phonemized.append(p) + phonemized = self._phonemize_postprocess(phonemized, punctuations) + return phonemized + + def print_logs(self, level: int = 0): + indent = "\t" * level + print(f"{indent}| > phoneme language: {self.language}") + print(f"{indent}| > phoneme backend: {self.name()}") diff --git a/viXTTS/TTS/tts/utils/text/phonemizers/belarusian_phonemizer.py b/viXTTS/TTS/tts/utils/text/phonemizers/belarusian_phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..e5fcab6e09b7e9fad06c381ebc7239e36c4cb8db --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/phonemizers/belarusian_phonemizer.py @@ -0,0 +1,55 @@ +from typing import Dict + +from TTS.tts.utils.text.belarusian.phonemizer import belarusian_text_to_phonemes +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer + +_DEF_BE_PUNCS = ",!." # TODO + + +class BEL_Phonemizer(BasePhonemizer): + """🐸TTS be phonemizer using functions in `TTS.tts.utils.text.belarusian.phonemizer` + + Args: + punctuations (str): + Set of characters to be treated as punctuation. Defaults to `_DEF_BE_PUNCS`. + + keep_puncs (bool): + If True, keep the punctuations after phonemization. Defaults to False. + """ + + language = "be" + + def __init__(self, punctuations=_DEF_BE_PUNCS, keep_puncs=True, **kwargs): # pylint: disable=unused-argument + super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs) + + @staticmethod + def name(): + return "be_phonemizer" + + @staticmethod + def phonemize_be(text: str, separator: str = "|") -> str: # pylint: disable=unused-argument + return belarusian_text_to_phonemes(text) + + def _phonemize(self, text, separator): + return self.phonemize_be(text, separator) + + @staticmethod + def supported_languages() -> Dict: + return {"be": "Belarusian"} + + def version(self) -> str: + return "0.0.1" + + def is_available(self) -> bool: + return True + + +if __name__ == "__main__": + txt = "тэст" + e = BEL_Phonemizer() + print(e.supported_languages()) + print(e.version()) + print(e.language) + print(e.name()) + print(e.is_available()) + print("`" + e.phonemize(txt) + "`") diff --git a/viXTTS/TTS/tts/utils/text/phonemizers/espeak_wrapper.py b/viXTTS/TTS/tts/utils/text/phonemizers/espeak_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..328e52f369e641f92d6c550a9ccebf60dbff9a26 --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/phonemizers/espeak_wrapper.py @@ -0,0 +1,264 @@ +import logging +import re +import subprocess +from typing import Dict, List + +from packaging.version import Version + +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer +from TTS.tts.utils.text.punctuation import Punctuation + + +def is_tool(name): + from shutil import which + + return which(name) is not None + + +# Use a regex pattern to match the espeak version, because it may be +# symlinked to espeak-ng, which moves the version bits to another spot. 
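The version lookup that follows reads the match with `match.group("version")`, so the compiled pattern must expose a named `version` group. A minimal sketch of that lookup, assuming typical `espeak --version` output of the form `eSpeak text-to-speech: 1.48.15 ...` (the `_version_pattern` and `read_espeak_version` names here are illustrative, not the module's own):

import re
import subprocess

# The named group is what the caller retrieves via match.group("version").
_version_pattern = re.compile(r"text-to-speech:\s(?P<version>\d+\.\d+(\.\d+)?)")

def read_espeak_version() -> str:
    output = subprocess.getoutput("espeak --version")  # e.g. "eSpeak text-to-speech: 1.48.15 ..."
    match = _version_pattern.search(output)
    return match.group("version") if match else "unknown"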
+espeak_version_pattern = re.compile(r"text-to-speech:\s(?P\d+\.\d+(\.\d+)?)") + + +def get_espeak_version(): + output = subprocess.getoutput("espeak --version") + match = espeak_version_pattern.search(output) + + return match.group("version") + + +def get_espeakng_version(): + output = subprocess.getoutput("espeak-ng --version") + return output.split()[3] + + +# priority: espeakng > espeak +if is_tool("espeak-ng"): + _DEF_ESPEAK_LIB = "espeak-ng" + _DEF_ESPEAK_VER = get_espeakng_version() +elif is_tool("espeak"): + _DEF_ESPEAK_LIB = "espeak" + _DEF_ESPEAK_VER = get_espeak_version() +else: + _DEF_ESPEAK_LIB = None + _DEF_ESPEAK_VER = None + + +def _espeak_exe(espeak_lib: str, args: List, sync=False) -> List[str]: + """Run espeak with the given arguments.""" + cmd = [ + espeak_lib, + "-q", + "-b", + "1", # UTF8 text encoding + ] + cmd.extend(args) + logging.debug("espeakng: executing %s", repr(cmd)) + + with subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) as p: + res = iter(p.stdout.readline, b"") + if not sync: + p.stdout.close() + if p.stderr: + p.stderr.close() + if p.stdin: + p.stdin.close() + return res + res2 = [] + for line in res: + res2.append(line) + p.stdout.close() + if p.stderr: + p.stderr.close() + if p.stdin: + p.stdin.close() + p.wait() + return res2 + + +class ESpeak(BasePhonemizer): + """ESpeak wrapper calling `espeak` or `espeak-ng` from the command-line the perform G2P + + Args: + language (str): + Valid language code for the used backend. + + backend (str): + Name of the backend library to use. `espeak` or `espeak-ng`. If None, set automatically + prefering `espeak-ng` over `espeak`. Defaults to None. + + punctuations (str): + Characters to be treated as punctuation. Defaults to Punctuation.default_puncs(). + + keep_puncs (bool): + If True, keep the punctuations after phonemization. Defaults to True. + + Example: + + >>> from TTS.tts.utils.text.phonemizers import ESpeak + >>> phonemizer = ESpeak("tr") + >>> phonemizer.phonemize("Bu Türkçe, bir örnektir.", separator="|") + 'b|ʊ t|ˈø|r|k|tʃ|ɛ, b|ɪ|r œ|r|n|ˈɛ|c|t|ɪ|r.' + + """ + + _ESPEAK_LIB = _DEF_ESPEAK_LIB + _ESPEAK_VER = _DEF_ESPEAK_VER + + def __init__(self, language: str, backend=None, punctuations=Punctuation.default_puncs(), keep_puncs=True): + if self._ESPEAK_LIB is None: + raise Exception(" [!] No espeak backend found. Install espeak-ng or espeak to your system.") + self.backend = self._ESPEAK_LIB + + # band-aid for backwards compatibility + if language == "en": + language = "en-us" + if language == "zh-cn": + language = "cmn" + + super().__init__(language, punctuations=punctuations, keep_puncs=keep_puncs) + if backend is not None: + self.backend = backend + + @property + def backend(self): + return self._ESPEAK_LIB + + @property + def backend_version(self): + return self._ESPEAK_VER + + @backend.setter + def backend(self, backend): + if backend not in ["espeak", "espeak-ng"]: + raise Exception("Unknown backend: %s" % backend) + self._ESPEAK_LIB = backend + self._ESPEAK_VER = get_espeakng_version() if backend == "espeak-ng" else get_espeak_version() + + def auto_set_espeak_lib(self) -> None: + if is_tool("espeak-ng"): + self._ESPEAK_LIB = "espeak-ng" + self._ESPEAK_VER = get_espeakng_version() + elif is_tool("espeak"): + self._ESPEAK_LIB = "espeak" + self._ESPEAK_VER = get_espeak_version() + else: + raise Exception("Cannot set backend automatically. 
espeak-ng or espeak not found") + + @staticmethod + def name(): + return "espeak" + + def phonemize_espeak(self, text: str, separator: str = "|", tie=False) -> str: + """Convert input text to phonemes. + + Args: + text (str): + Text to be converted to phonemes. + + tie (bool, optional) : When True use a '͡' character between + consecutive characters of a single phoneme. Else separate phoneme + with '_'. This option requires espeak>=1.49. Default to False. + """ + # set arguments + args = ["-v", f"{self._language}"] + # espeak and espeak-ng parses `ipa` differently + if tie: + # use '͡' between phonemes + if self.backend == "espeak": + args.append("--ipa=1") + else: + args.append("--ipa=3") + else: + # split with '_' + if self.backend == "espeak": + if Version(self.backend_version) >= Version("1.48.15"): + args.append("--ipa=1") + else: + args.append("--ipa=3") + else: + args.append("--ipa=1") + if tie: + args.append("--tie=%s" % tie) + + args.append(text) + # compute phonemes + phonemes = "" + for line in _espeak_exe(self._ESPEAK_LIB, args, sync=True): + logging.debug("line: %s", repr(line)) + ph_decoded = line.decode("utf8").strip() + # espeak: + # version 1.48.15: " p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n" + # espeak-ng: + # "p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n" + + # espeak-ng backend can add language flags that need to be removed: + # "sɛʁtˈɛ̃ mˈo kɔm (en)fˈʊtbɔːl(fr) ʒenˈɛʁ de- flˈaɡ də- lˈɑ̃ɡ." + # phonemize needs to remove the language flags of the returned text: + # "sɛʁtˈɛ̃ mˈo kɔm fˈʊtbɔːl ʒenˈɛʁ de- flˈaɡ də- lˈɑ̃ɡ." + ph_decoded = re.sub(r"\(.+?\)", "", ph_decoded) + + phonemes += ph_decoded.strip() + return phonemes.replace("_", separator) + + def _phonemize(self, text, separator=None): + return self.phonemize_espeak(text, separator, tie=False) + + @staticmethod + def supported_languages() -> Dict: + """Get a dictionary of supported languages. + + Returns: + Dict: Dictionary of language codes. + """ + if _DEF_ESPEAK_LIB is None: + return {} + args = ["--voices"] + langs = {} + count = 0 + for line in _espeak_exe(_DEF_ESPEAK_LIB, args, sync=True): + line = line.decode("utf8").strip() + if count > 0: + cols = line.split() + lang_code = cols[1] + lang_name = cols[3] + langs[lang_code] = lang_name + logging.debug("line: %s", repr(line)) + count += 1 + return langs + + def version(self) -> str: + """Get the version of the used backend. + + Returns: + str: Version of the used backend. 
+ """ + args = ["--version"] + for line in _espeak_exe(self.backend, args, sync=True): + version = line.decode("utf8").strip().split()[2] + logging.debug("line: %s", repr(line)) + return version + + @classmethod + def is_available(cls): + """Return true if ESpeak is available else false""" + return is_tool("espeak") or is_tool("espeak-ng") + + +if __name__ == "__main__": + e = ESpeak(language="en-us") + print(e.supported_languages()) + print(e.version()) + print(e.language) + print(e.name()) + print(e.is_available()) + + e = ESpeak(language="en-us", keep_puncs=False) + print("`" + e.phonemize("hello how are you today?") + "`") + + e = ESpeak(language="en-us", keep_puncs=True) + print("`" + e.phonemize("hello how are you today?") + "`") diff --git a/viXTTS/TTS/tts/utils/text/phonemizers/gruut_wrapper.py b/viXTTS/TTS/tts/utils/text/phonemizers/gruut_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..f3e9c9abd4c41935ed07ec10ed883d75b42a6bc8 --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/phonemizers/gruut_wrapper.py @@ -0,0 +1,151 @@ +import importlib +from typing import List + +import gruut +from gruut_ipa import IPA + +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer +from TTS.tts.utils.text.punctuation import Punctuation + +# Table for str.translate to fix gruut/TTS phoneme mismatch +GRUUT_TRANS_TABLE = str.maketrans("g", "ɡ") + + +class Gruut(BasePhonemizer): + """Gruut wrapper for G2P + + Args: + language (str): + Valid language code for the used backend. + + punctuations (str): + Characters to be treated as punctuation. Defaults to `Punctuation.default_puncs()`. + + keep_puncs (bool): + If true, keep the punctuations after phonemization. Defaults to True. + + use_espeak_phonemes (bool): + If true, use espeak lexicons instead of default Gruut lexicons. Defaults to False. + + keep_stress (bool): + If true, keep the stress characters after phonemization. Defaults to False. + + Example: + + >>> from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut + >>> phonemizer = Gruut('en-us') + >>> phonemizer.phonemize("Be a voice, not an! echo?", separator="|") + 'b|i| ə| v|ɔ|ɪ|s, n|ɑ|t| ə|n! ɛ|k|o|ʊ?' + """ + + def __init__( + self, + language: str, + punctuations=Punctuation.default_puncs(), + keep_puncs=True, + use_espeak_phonemes=False, + keep_stress=False, + ): + super().__init__(language, punctuations=punctuations, keep_puncs=keep_puncs) + self.use_espeak_phonemes = use_espeak_phonemes + self.keep_stress = keep_stress + + @staticmethod + def name(): + return "gruut" + + def phonemize_gruut(self, text: str, separator: str = "|", tie=False) -> str: # pylint: disable=unused-argument + """Convert input text to phonemes. + + Gruut phonemizes the given `str` by seperating each phoneme character with `separator`, even for characters + that constitude a single sound. + + It doesn't affect 🐸TTS since it individually converts each character to token IDs. + + Examples:: + "hello how are you today?" -> `h|ɛ|l|o|ʊ| h|a|ʊ| ɑ|ɹ| j|u| t|ə|d|e|ɪ` + + Args: + text (str): + Text to be converted to phonemes. + + tie (bool, optional) : When True use a '͡' character between + consecutive characters of a single phoneme. Else separate phoneme + with '_'. This option requires espeak>=1.49. Default to False. 
+ """ + ph_list = [] + for sentence in gruut.sentences(text, lang=self.language, espeak=self.use_espeak_phonemes): + for word in sentence: + if word.is_break: + # Use actual character for break phoneme (e.g., comma) + if ph_list: + # Join with previous word + ph_list[-1].append(word.text) + else: + # First word is punctuation + ph_list.append([word.text]) + elif word.phonemes: + # Add phonemes for word + word_phonemes = [] + + for word_phoneme in word.phonemes: + if not self.keep_stress: + # Remove primary/secondary stress + word_phoneme = IPA.without_stress(word_phoneme) + + word_phoneme = word_phoneme.translate(GRUUT_TRANS_TABLE) + + if word_phoneme: + # Flatten phonemes + word_phonemes.extend(word_phoneme) + + if word_phonemes: + ph_list.append(word_phonemes) + + ph_words = [separator.join(word_phonemes) for word_phonemes in ph_list] + ph = f"{separator} ".join(ph_words) + return ph + + def _phonemize(self, text, separator): + return self.phonemize_gruut(text, separator, tie=False) + + def is_supported_language(self, language): + """Returns True if `language` is supported by the backend""" + return gruut.is_language_supported(language) + + @staticmethod + def supported_languages() -> List: + """Get a dictionary of supported languages. + + Returns: + List: List of language codes. + """ + return list(gruut.get_supported_languages()) + + def version(self): + """Get the version of the used backend. + + Returns: + str: Version of the used backend. + """ + return gruut.__version__ + + @classmethod + def is_available(cls): + """Return true if ESpeak is available else false""" + return importlib.util.find_spec("gruut") is not None + + +if __name__ == "__main__": + e = Gruut(language="en-us") + print(e.supported_languages()) + print(e.version()) + print(e.language) + print(e.name()) + print(e.is_available()) + + e = Gruut(language="en-us", keep_puncs=False) + print("`" + e.phonemize("hello how are you today?") + "`") + + e = Gruut(language="en-us", keep_puncs=True) + print("`" + e.phonemize("hello how, are you today?") + "`") diff --git a/viXTTS/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py b/viXTTS/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..878e5e52969f740ae74d875c91294666095e89ac --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py @@ -0,0 +1,72 @@ +from typing import Dict + +from TTS.tts.utils.text.japanese.phonemizer import japanese_text_to_phonemes +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer + +_DEF_JA_PUNCS = "、.,[]()?!〽~『』「」【】" + +_TRANS_TABLE = {"、": ","} + + +def trans(text): + for i, j in _TRANS_TABLE.items(): + text = text.replace(i, j) + return text + + +class JA_JP_Phonemizer(BasePhonemizer): + """🐸TTS Ja-Jp phonemizer using functions in `TTS.tts.utils.text.japanese.phonemizer` + + TODO: someone with JA knowledge should check this implementation + + Example: + + >>> from TTS.tts.utils.text.phonemizers import JA_JP_Phonemizer + >>> phonemizer = JA_JP_Phonemizer() + >>> phonemizer.phonemize("どちらに行きますか?", separator="|") + 'd|o|c|h|i|r|a|n|i|i|k|i|m|a|s|u|k|a|?' 
+ + """ + + language = "ja-jp" + + def __init__(self, punctuations=_DEF_JA_PUNCS, keep_puncs=True, **kwargs): # pylint: disable=unused-argument + super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs) + + @staticmethod + def name(): + return "ja_jp_phonemizer" + + def _phonemize(self, text: str, separator: str = "|") -> str: + ph = japanese_text_to_phonemes(text) + if separator is not None or separator != "": + return separator.join(ph) + return ph + + def phonemize(self, text: str, separator="|", language=None) -> str: + """Custom phonemize for JP_JA + + Skip pre-post processing steps used by the other phonemizers. + """ + return self._phonemize(text, separator) + + @staticmethod + def supported_languages() -> Dict: + return {"ja-jp": "Japanese (Japan)"} + + def version(self) -> str: + return "0.0.1" + + def is_available(self) -> bool: + return True + + +# if __name__ == "__main__": +# text = "これは、電話をかけるための私の日本語の例のテキストです。" +# e = JA_JP_Phonemizer() +# print(e.supported_languages()) +# print(e.version()) +# print(e.language) +# print(e.name()) +# print(e.is_available()) +# print("`" + e.phonemize(text) + "`") diff --git a/viXTTS/TTS/tts/utils/text/phonemizers/ko_kr_phonemizer.py b/viXTTS/TTS/tts/utils/text/phonemizers/ko_kr_phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0bdba2137b039025491a08d8b8d52454b82a72fd --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/phonemizers/ko_kr_phonemizer.py @@ -0,0 +1,65 @@ +from typing import Dict + +from TTS.tts.utils.text.korean.phonemizer import korean_text_to_phonemes +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer + +_DEF_KO_PUNCS = "、.,[]()?!〽~『』「」【】" + + +class KO_KR_Phonemizer(BasePhonemizer): + """🐸TTS ko_kr_phonemizer using functions in `TTS.tts.utils.text.korean.phonemizer` + + TODO: Add Korean to character (ᄀᄁᄂᄃᄄᄅᄆᄇᄈᄉᄊᄋᄌᄍᄎᄏᄐᄑ하ᅢᅣᅤᅥᅦᅧᅨᅩᅪᅫᅬᅭᅮᅯᅰᅱᅲᅳᅴᅵᆨᆩᆪᆫᆬᆭᆮᆯᆰᆱᆲᆳᆴᆵᆶᆷᆸᆹᆺᆻᆼᆽᆾᆿᇀᇁᇂ) + + Example: + + >>> from TTS.tts.utils.text.phonemizers import KO_KR_Phonemizer + >>> phonemizer = KO_KR_Phonemizer() + >>> phonemizer.phonemize("이 문장은 음성합성 테스트를 위한 문장입니다.", separator="|") + 'ᄋ|ᅵ| |ᄆ|ᅮ|ᆫ|ᄌ|ᅡ|ᆼ|ᄋ|ᅳ| |ᄂ|ᅳ|ᆷ|ᄉ|ᅥ|ᆼ|ᄒ|ᅡ|ᆸ|ᄊ|ᅥ|ᆼ| |ᄐ|ᅦ|ᄉ|ᅳ|ᄐ|ᅳ|ᄅ|ᅳ| |ᄅ|ᅱ|ᄒ|ᅡ|ᆫ| |ᄆ|ᅮ|ᆫ|ᄌ|ᅡ|ᆼ|ᄋ|ᅵ|ᆷ|ᄂ|ᅵ|ᄃ|ᅡ|.' + + >>> from TTS.tts.utils.text.phonemizers import KO_KR_Phonemizer + >>> phonemizer = KO_KR_Phonemizer() + >>> phonemizer.phonemize("이 문장은 음성합성 테스트를 위한 문장입니다.", separator="|", character='english') + 'I| |M|u|n|J|a|n|g|E|u| |N|e|u|m|S|e|o|n|g|H|a|b|S|s|e|o|n|g| |T|e|S|e|u|T|e|u|L|e|u| |L|w|i|H|a|n| |M|u|n|J|a|n|g|I|m|N|i|D|a|.' + + """ + + language = "ko-kr" + + def __init__(self, punctuations=_DEF_KO_PUNCS, keep_puncs=True, **kwargs): # pylint: disable=unused-argument + super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs) + + @staticmethod + def name(): + return "ko_kr_phonemizer" + + def _phonemize(self, text: str, separator: str = "", character: str = "hangeul") -> str: + ph = korean_text_to_phonemes(text, character=character) + if separator is not None or separator != "": + return separator.join(ph) + return ph + + def phonemize(self, text: str, separator: str = "", character: str = "hangeul", language=None) -> str: + return self._phonemize(text, separator, character) + + @staticmethod + def supported_languages() -> Dict: + return {"ko-kr": "hangeul(korean)"} + + def version(self) -> str: + return "0.0.2" + + def is_available(self) -> bool: + return True + + +if __name__ == "__main__": + texts = "이 문장은 음성합성 테스트를 위한 문장입니다." 
+ e = KO_KR_Phonemizer() + print(e.supported_languages()) + print(e.version()) + print(e.language) + print(e.name()) + print(e.is_available()) + print(e.phonemize(texts)) diff --git a/viXTTS/TTS/tts/utils/text/phonemizers/multi_phonemizer.py b/viXTTS/TTS/tts/utils/text/phonemizers/multi_phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..62a9c39322e051124d8c2d816cbc7d479df69dfe --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/phonemizers/multi_phonemizer.py @@ -0,0 +1,65 @@ +from typing import Dict, List + +from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name + + +class MultiPhonemizer: + """🐸TTS multi-phonemizer that operates phonemizers for multiple langugages + + Args: + custom_lang_to_phonemizer (Dict): + Custom phonemizer mapping if you want to change the defaults. In the format of + `{"lang_code", "phonemizer_name"}`. When it is None, `DEF_LANG_TO_PHONEMIZER` is used. Defaults to `{}`. + + TODO: find a way to pass custom kwargs to the phonemizers + """ + + lang_to_phonemizer = {} + + def __init__(self, lang_to_phonemizer_name: Dict = {}) -> None: # pylint: disable=dangerous-default-value + for k, v in lang_to_phonemizer_name.items(): + if v == "" and k in DEF_LANG_TO_PHONEMIZER.keys(): + lang_to_phonemizer_name[k] = DEF_LANG_TO_PHONEMIZER[k] + elif v == "": + raise ValueError(f"Phonemizer wasn't set for language {k} and doesn't have a default.") + self.lang_to_phonemizer_name = lang_to_phonemizer_name + self.lang_to_phonemizer = self.init_phonemizers(self.lang_to_phonemizer_name) + + @staticmethod + def init_phonemizers(lang_to_phonemizer_name: Dict) -> Dict: + lang_to_phonemizer = {} + for k, v in lang_to_phonemizer_name.items(): + lang_to_phonemizer[k] = get_phonemizer_by_name(v, language=k) + return lang_to_phonemizer + + @staticmethod + def name(): + return "multi-phonemizer" + + def phonemize(self, text, separator="|", language=""): + if language == "": + raise ValueError("Language must be set for multi-phonemizer to phonemize.") + return self.lang_to_phonemizer[language].phonemize(text, separator) + + def supported_languages(self) -> List: + return list(self.lang_to_phonemizer.keys()) + + def print_logs(self, level: int = 0): + indent = "\t" * level + print(f"{indent}| > phoneme language: {self.supported_languages()}") + print(f"{indent}| > phoneme backend: {self.name()}") + + +# if __name__ == "__main__": +# texts = { +# "tr": "Merhaba, bu Türkçe bit örnek!", +# "en-us": "Hello, this is English example!", +# "de": "Hallo, das ist ein Deutches Beipiel!", +# "zh-cn": "这是中国的例子", +# } +# phonemes = {} +# ph = MultiPhonemizer({"tr": "espeak", "en-us": "", "de": "gruut", "zh-cn": ""}) +# for lang, text in texts.items(): +# phoneme = ph.phonemize(text, lang) +# phonemes[lang] = phoneme +# print(phonemes) diff --git a/viXTTS/TTS/tts/utils/text/phonemizers/zh_cn_phonemizer.py b/viXTTS/TTS/tts/utils/text/phonemizers/zh_cn_phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..41480c417356fd941e71e3eff0099eb38ac7296a --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/phonemizers/zh_cn_phonemizer.py @@ -0,0 +1,62 @@ +from typing import Dict + +from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer + +_DEF_ZH_PUNCS = "、.,[]()?!〽~『』「」【】" + + +class ZH_CN_Phonemizer(BasePhonemizer): + """🐸TTS Zh-Cn phonemizer using functions in `TTS.tts.utils.text.chinese_mandarin.phonemizer` + + Args: + punctuations (str): + Set 
of characters to be treated as punctuation. Defaults to `_DEF_ZH_PUNCS`. + + keep_puncs (bool): + If True, keep the punctuations after phonemization. Defaults to False. + + Example :: + + "这是,样本中文。" -> `d|ʒ|ø|4| |ʂ|ʏ|4| |,| |i|ɑ|ŋ|4|b|œ|n|3| |d|ʒ|o|ŋ|1|w|œ|n|2| |。` + + TODO: someone with Mandarin knowledge should check this implementation + """ + + language = "zh-cn" + + def __init__(self, punctuations=_DEF_ZH_PUNCS, keep_puncs=False, **kwargs): # pylint: disable=unused-argument + super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs) + + @staticmethod + def name(): + return "zh_cn_phonemizer" + + @staticmethod + def phonemize_zh_cn(text: str, separator: str = "|") -> str: + ph = chinese_text_to_phonemes(text, separator) + return ph + + def _phonemize(self, text, separator): + return self.phonemize_zh_cn(text, separator) + + @staticmethod + def supported_languages() -> Dict: + return {"zh-cn": "Chinese (China)"} + + def version(self) -> str: + return "0.0.1" + + def is_available(self) -> bool: + return True + + +# if __name__ == "__main__": +# text = "这是,样本中文。" +# e = ZH_CN_Phonemizer() +# print(e.supported_languages()) +# print(e.version()) +# print(e.language) +# print(e.name()) +# print(e.is_available()) +# print("`" + e.phonemize(text) + "`") diff --git a/viXTTS/TTS/tts/utils/text/punctuation.py b/viXTTS/TTS/tts/utils/text/punctuation.py new file mode 100644 index 0000000000000000000000000000000000000000..36c467d08335ea24ad1c0e6705e0c351f7fb8d65 --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/punctuation.py @@ -0,0 +1,171 @@ +import collections +import re +from enum import Enum + +import six + +_DEF_PUNCS = ';:,.!?¡¿—…"«»“”' + +_PUNC_IDX = collections.namedtuple("_punc_index", ["punc", "position"]) + + +class PuncPosition(Enum): + """Enum for the punctuations positions""" + + BEGIN = 0 + END = 1 + MIDDLE = 2 + + +class Punctuation: + """Handle punctuations in text. + + Just strip punctuations from text or strip and restore them later. + + Args: + puncs (str): The punctuations to be processed. Defaults to `_DEF_PUNCS`. + + Example: + >>> punc = Punctuation() + >>> punc.strip("This is. example !") + 'This is example' + + >>> text_striped, punc_map = punc.strip_to_restore("This is. example !") + >>> ' '.join(text_striped) + 'This is example' + + >>> text_restored = punc.restore(text_striped, punc_map) + >>> text_restored[0] + 'This is. example !' + """ + + def __init__(self, puncs: str = _DEF_PUNCS): + self.puncs = puncs + + @staticmethod + def default_puncs(): + """Return default set of punctuations.""" + return _DEF_PUNCS + + @property + def puncs(self): + return self._puncs + + @puncs.setter + def puncs(self, value): + if not isinstance(value, six.string_types): + raise ValueError("[!] Punctuations must be of type str.") + self._puncs = "".join(list(dict.fromkeys(list(value)))) # remove duplicates without changing the oreder + self.puncs_regular_exp = re.compile(rf"(\s*[{re.escape(self._puncs)}]+\s*)+") + + def strip(self, text): + """Remove all the punctuations by replacing with `space`. + + Args: + text (str): The text to be processed. + + Example:: + + "This is. example !" -> "This is example " + """ + return re.sub(self.puncs_regular_exp, " ", text).rstrip().lstrip() + + def strip_to_restore(self, text): + """Remove punctuations from text to restore them later. + + Args: + text (str): The text to be processed. + + Examples :: + + "This is. example !" 
-> [["This is", "example"], [".", "!"]] + + """ + text, puncs = self._strip_to_restore(text) + return text, puncs + + def _strip_to_restore(self, text): + """Auxiliary method for Punctuation.preserve()""" + matches = list(re.finditer(self.puncs_regular_exp, text)) + if not matches: + return [text], [] + # the text is only punctuations + if len(matches) == 1 and matches[0].group() == text: + return [], [_PUNC_IDX(text, PuncPosition.BEGIN)] + # build a punctuation map to be used later to restore punctuations + puncs = [] + for match in matches: + position = PuncPosition.MIDDLE + if match == matches[0] and text.startswith(match.group()): + position = PuncPosition.BEGIN + elif match == matches[-1] and text.endswith(match.group()): + position = PuncPosition.END + puncs.append(_PUNC_IDX(match.group(), position)) + # convert str text to a List[str], each item is separated by a punctuation + splitted_text = [] + for idx, punc in enumerate(puncs): + split = text.split(punc.punc) + prefix, suffix = split[0], punc.punc.join(split[1:]) + text = suffix + if prefix == "": + # We don't want to insert an empty string in case of initial punctuation + continue + splitted_text.append(prefix) + # if the text does not end with a punctuation, add it to the last item + if idx == len(puncs) - 1 and len(suffix) > 0: + splitted_text.append(suffix) + return splitted_text, puncs + + @classmethod + def restore(cls, text, puncs): + """Restore punctuation in a text. + + Args: + text (str): The text to be processed. + puncs (List[str]): The list of punctuations map to be used for restoring. + + Examples :: + + ['This is', 'example'], ['.', '!'] -> "This is. example!" + + """ + return cls._restore(text, puncs) + + @classmethod + def _restore(cls, text, puncs): # pylint: disable=too-many-return-statements + """Auxiliary method for Punctuation.restore()""" + if not puncs: + return text + + # nothing have been phonemized, returns the puncs alone + if not text: + return ["".join(m.punc for m in puncs)] + + current = puncs[0] + + if current.position == PuncPosition.BEGIN: + return cls._restore([current.punc + text[0]] + text[1:], puncs[1:]) + + if current.position == PuncPosition.END: + return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:]) + + # POSITION == MIDDLE + if len(text) == 1: # pragma: nocover + # a corner case where the final part of an intermediate + # mark (I) has not been phonemized + return cls._restore([text[0] + current.punc], puncs[1:]) + + return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:]) + + +# if __name__ == "__main__": +# punc = Punctuation() +# text = "This is. This is, example!" 
+ +# print(punc.strip(text)) + +# split_text, puncs = punc.strip_to_restore(text) +# print(split_text, " ---- ", puncs) + +# restored_text = punc.restore(split_text, puncs) +# print(restored_text) diff --git a/viXTTS/TTS/tts/utils/text/tokenizer.py b/viXTTS/TTS/tts/utils/text/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..b7faf86e8a3120ee39171de0caa40bbc85614ddb --- /dev/null +++ b/viXTTS/TTS/tts/utils/text/tokenizer.py @@ -0,0 +1,216 @@ +from typing import Callable, Dict, List, Union + +from TTS.tts.utils.text import cleaners +from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes +from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name +from TTS.tts.utils.text.phonemizers.multi_phonemizer import MultiPhonemizer +from TTS.utils.generic_utils import get_import_path, import_class + + +class TTSTokenizer: + """🐸TTS tokenizer to convert input characters to token IDs and back. + + Token IDs for OOV chars are discarded but those are stored in `self.not_found_characters` for later. + + Args: + use_phonemes (bool): + Whether to use phonemes instead of characters. Defaults to False. + + characters (Characters): + A Characters object to use for character-to-ID and ID-to-character mappings. + + text_cleaner (callable): + A function to pre-process the text before tokenization and phonemization. Defaults to None. + + phonemizer (Phonemizer): + A phonemizer object or a dict that maps language codes to phonemizer objects. Defaults to None. + + Example: + + >>> from TTS.tts.utils.text.tokenizer import TTSTokenizer + >>> tokenizer = TTSTokenizer(use_phonemes=False, characters=Graphemes()) + >>> text = "Hello world!" + >>> ids = tokenizer.text_to_ids(text) + >>> text_hat = tokenizer.ids_to_text(ids) + >>> assert text == text_hat + """ + + def __init__( + self, + use_phonemes=False, + text_cleaner: Callable = None, + characters: "BaseCharacters" = None, + phonemizer: Union["Phonemizer", Dict] = None, + add_blank: bool = False, + use_eos_bos=False, + ): + self.text_cleaner = text_cleaner + self.use_phonemes = use_phonemes + self.add_blank = add_blank + self.use_eos_bos = use_eos_bos + self.characters = characters + self.not_found_characters = [] + self.phonemizer = phonemizer + + @property + def characters(self): + return self._characters + + @characters.setter + def characters(self, new_characters): + self._characters = new_characters + self.pad_id = self.characters.char_to_id(self.characters.pad) if self.characters.pad else None + self.blank_id = self.characters.char_to_id(self.characters.blank) if self.characters.blank else None + + def encode(self, text: str) -> List[int]: + """Encodes a string of text as a sequence of IDs.""" + token_ids = [] + for char in text: + try: + idx = self.characters.char_to_id(char) + token_ids.append(idx) + except KeyError: + # discard but store not found characters + if char not in self.not_found_characters: + self.not_found_characters.append(char) + print(text) + print(f" [!] Character {repr(char)} not found in the vocabulary. Discarding it.") + return token_ids + + def decode(self, token_ids: List[int]) -> str: + """Decodes a sequence of IDs to a string of text.""" + text = "" + for token_id in token_ids: + text += self.characters.id_to_char(token_id) + return text + + def text_to_ids(self, text: str, language: str = None) -> List[int]: # pylint: disable=unused-argument + """Converts a string of text to a sequence of token IDs. + + Args: + text(str): + The text to convert to token IDs. 
+ + language(str): + The language code of the text. Defaults to None. + + TODO: + - Add support for language-specific processing. + + 1. Text normalizatin + 2. Phonemization (if use_phonemes is True) + 3. Add blank char between characters + 4. Add BOS and EOS characters + 5. Text to token IDs + """ + # TODO: text cleaner should pick the right routine based on the language + if self.text_cleaner is not None: + text = self.text_cleaner(text) + if self.use_phonemes: + text = self.phonemizer.phonemize(text, separator="", language=language) + text = self.encode(text) + if self.add_blank: + text = self.intersperse_blank_char(text, True) + if self.use_eos_bos: + text = self.pad_with_bos_eos(text) + return text + + def ids_to_text(self, id_sequence: List[int]) -> str: + """Converts a sequence of token IDs to a string of text.""" + return self.decode(id_sequence) + + def pad_with_bos_eos(self, char_sequence: List[str]): + """Pads a sequence with the special BOS and EOS characters.""" + return [self.characters.bos_id] + list(char_sequence) + [self.characters.eos_id] + + def intersperse_blank_char(self, char_sequence: List[str], use_blank_char: bool = False): + """Intersperses the blank character between characters in a sequence. + + Use the ```blank``` character if defined else use the ```pad``` character. + """ + char_to_use = self.characters.blank_id if use_blank_char else self.characters.pad + result = [char_to_use] * (len(char_sequence) * 2 + 1) + result[1::2] = char_sequence + return result + + def print_logs(self, level: int = 0): + indent = "\t" * level + print(f"{indent}| > add_blank: {self.add_blank}") + print(f"{indent}| > use_eos_bos: {self.use_eos_bos}") + print(f"{indent}| > use_phonemes: {self.use_phonemes}") + if self.use_phonemes: + print(f"{indent}| > phonemizer:") + self.phonemizer.print_logs(level + 1) + if len(self.not_found_characters) > 0: + print(f"{indent}| > {len(self.not_found_characters)} not found characters:") + for char in self.not_found_characters: + print(f"{indent}| > {char}") + + @staticmethod + def init_from_config(config: "Coqpit", characters: "BaseCharacters" = None): + """Init Tokenizer object from config + + Args: + config (Coqpit): Coqpit model config. + characters (BaseCharacters): Defines the model character set. If not set, use the default options based on + the config values. Defaults to None. 
+ """ + # init cleaners + text_cleaner = None + if isinstance(config.text_cleaner, (str, list)): + text_cleaner = getattr(cleaners, config.text_cleaner) + + # init characters + if characters is None: + # set characters based on defined characters class + if config.characters and config.characters.characters_class: + CharactersClass = import_class(config.characters.characters_class) + characters, new_config = CharactersClass.init_from_config(config) + # set characters based on config + else: + if config.use_phonemes: + # init phoneme set + characters, new_config = IPAPhonemes().init_from_config(config) + else: + # init character set + characters, new_config = Graphemes().init_from_config(config) + + else: + characters, new_config = characters.init_from_config(config) + + # set characters class + new_config.characters.characters_class = get_import_path(characters) + + # init phonemizer + phonemizer = None + if config.use_phonemes: + if "phonemizer" in config and config.phonemizer == "multi_phonemizer": + lang_to_phonemizer_name = {} + for dataset in config.datasets: + if dataset.language != "": + lang_to_phonemizer_name[dataset.language] = dataset.phonemizer + else: + raise ValueError("Multi phonemizer requires language to be set for each dataset.") + phonemizer = MultiPhonemizer(lang_to_phonemizer_name) + else: + phonemizer_kwargs = {"language": config.phoneme_language} + if "phonemizer" in config and config.phonemizer: + phonemizer = get_phonemizer_by_name(config.phonemizer, **phonemizer_kwargs) + else: + try: + phonemizer = get_phonemizer_by_name( + DEF_LANG_TO_PHONEMIZER[config.phoneme_language], **phonemizer_kwargs + ) + new_config.phonemizer = phonemizer.name() + except KeyError as e: + raise ValueError( + f"""No phonemizer found for language {config.phoneme_language}. 
+ You may need to install a third party library for this language.""" + ) from e + + return ( + TTSTokenizer( + config.use_phonemes, text_cleaner, characters, phonemizer, config.add_blank, config.enable_eos_bos_chars + ), + new_config, + ) diff --git a/viXTTS/TTS/tts/utils/visual.py b/viXTTS/TTS/tts/utils/visual.py new file mode 100644 index 0000000000000000000000000000000000000000..fba7bc508ef962d5a93c794ca868acd46d07ec16 --- /dev/null +++ b/viXTTS/TTS/tts/utils/visual.py @@ -0,0 +1,238 @@ +import librosa +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib.colors import LogNorm + +matplotlib.use("Agg") + + +def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None, output_fig=False, plot_log=False): + if isinstance(alignment, torch.Tensor): + alignment_ = alignment.detach().cpu().numpy().squeeze() + else: + alignment_ = alignment + alignment_ = alignment_.astype(np.float32) if alignment_.dtype == np.float16 else alignment_ + fig, ax = plt.subplots(figsize=fig_size) + im = ax.imshow( + alignment_.T, aspect="auto", origin="lower", interpolation="none", norm=LogNorm() if plot_log else None + ) + fig.colorbar(im, ax=ax) + xlabel = "Decoder timestep" + if info is not None: + xlabel += "\n\n" + info + plt.xlabel(xlabel) + plt.ylabel("Encoder timestep") + # plt.yticks(range(len(text)), list(text)) + plt.tight_layout() + if title is not None: + plt.title(title) + if not output_fig: + plt.close() + return fig + + +def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False): + if isinstance(spectrogram, torch.Tensor): + spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T + else: + spectrogram_ = spectrogram.T + spectrogram_ = spectrogram_.astype(np.float32) if spectrogram_.dtype == np.float16 else spectrogram_ + if ap is not None: + spectrogram_ = ap.denormalize(spectrogram_) # pylint: disable=protected-access + fig = plt.figure(figsize=fig_size) + plt.imshow(spectrogram_, aspect="auto", origin="lower") + plt.colorbar() + plt.tight_layout() + if not output_fig: + plt.close() + return fig + + +def plot_pitch(pitch, spectrogram, ap=None, fig_size=(30, 10), output_fig=False): + """Plot pitch curves on top of the spectrogram. + + Args: + pitch (np.array): Pitch values. + spectrogram (np.array): Spectrogram values. + + Shapes: + pitch: :math:`(T,)` + spec: :math:`(C, T)` + """ + + if isinstance(spectrogram, torch.Tensor): + spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T + else: + spectrogram_ = spectrogram.T + spectrogram_ = spectrogram_.astype(np.float32) if spectrogram_.dtype == np.float16 else spectrogram_ + if ap is not None: + spectrogram_ = ap.denormalize(spectrogram_) # pylint: disable=protected-access + + old_fig_size = plt.rcParams["figure.figsize"] + if fig_size is not None: + plt.rcParams["figure.figsize"] = fig_size + + fig, ax = plt.subplots() + + ax.imshow(spectrogram_, aspect="auto", origin="lower") + ax.set_xlabel("time") + ax.set_ylabel("spec_freq") + + ax2 = ax.twinx() + ax2.plot(pitch, linewidth=5.0, color="red") + ax2.set_ylabel("F0") + + plt.rcParams["figure.figsize"] = old_fig_size + if not output_fig: + plt.close() + return fig + + +def plot_avg_pitch(pitch, chars, fig_size=(30, 10), output_fig=False): + """Plot pitch curves on top of the input characters. + + Args: + pitch (np.array): Pitch values. + chars (str): Characters to place to the x-axis. 
+ + Shapes: + pitch: :math:`(T,)` + """ + old_fig_size = plt.rcParams["figure.figsize"] + if fig_size is not None: + plt.rcParams["figure.figsize"] = fig_size + + fig, ax = plt.subplots() + + x = np.array(range(len(chars))) + my_xticks = chars + plt.xticks(x, my_xticks) + + ax.set_xlabel("characters") + ax.set_ylabel("freq") + + ax2 = ax.twinx() + ax2.plot(pitch, linewidth=5.0, color="red") + ax2.set_ylabel("F0") + + plt.rcParams["figure.figsize"] = old_fig_size + if not output_fig: + plt.close() + return fig + + +def plot_avg_energy(energy, chars, fig_size=(30, 10), output_fig=False): + """Plot energy curves on top of the input characters. + + Args: + energy (np.array): energy values. + chars (str): Characters to place to the x-axis. + + Shapes: + energy: :math:`(T,)` + """ + old_fig_size = plt.rcParams["figure.figsize"] + if fig_size is not None: + plt.rcParams["figure.figsize"] = fig_size + + fig, ax = plt.subplots() + + x = np.array(range(len(chars))) + my_xticks = chars + plt.xticks(x, my_xticks) + + ax.set_xlabel("characters") + ax.set_ylabel("freq") + + ax2 = ax.twinx() + ax2.plot(energy, linewidth=5.0, color="red") + ax2.set_ylabel("energy") + + plt.rcParams["figure.figsize"] = old_fig_size + if not output_fig: + plt.close() + return fig + + +def visualize( + alignment, + postnet_output, + text, + hop_length, + CONFIG, + tokenizer, + stop_tokens=None, + decoder_output=None, + output_path=None, + figsize=(8, 24), + output_fig=False, +): + """Intended to be used in Notebooks.""" + + if decoder_output is not None: + num_plot = 4 + else: + num_plot = 3 + + label_fontsize = 16 + fig = plt.figure(figsize=figsize) + + plt.subplot(num_plot, 1, 1) + plt.imshow(alignment.T, aspect="auto", origin="lower", interpolation=None) + plt.xlabel("Decoder timestamp", fontsize=label_fontsize) + plt.ylabel("Encoder timestamp", fontsize=label_fontsize) + # compute phoneme representation and back + if CONFIG.use_phonemes: + seq = tokenizer.text_to_ids(text) + text = tokenizer.ids_to_text(seq) + print(text) + plt.yticks(range(len(text)), list(text)) + plt.colorbar() + + if stop_tokens is not None: + # plot stopnet predictions + plt.subplot(num_plot, 1, 2) + plt.plot(range(len(stop_tokens)), list(stop_tokens)) + + # plot postnet spectrogram + plt.subplot(num_plot, 1, 3) + librosa.display.specshow( + postnet_output.T, + sr=CONFIG.audio["sample_rate"], + hop_length=hop_length, + x_axis="time", + y_axis="linear", + fmin=CONFIG.audio["mel_fmin"], + fmax=CONFIG.audio["mel_fmax"], + ) + + plt.xlabel("Time", fontsize=label_fontsize) + plt.ylabel("Hz", fontsize=label_fontsize) + plt.tight_layout() + plt.colorbar() + + if decoder_output is not None: + plt.subplot(num_plot, 1, 4) + librosa.display.specshow( + decoder_output.T, + sr=CONFIG.audio["sample_rate"], + hop_length=hop_length, + x_axis="time", + y_axis="linear", + fmin=CONFIG.audio["mel_fmin"], + fmax=CONFIG.audio["mel_fmax"], + ) + plt.xlabel("Time", fontsize=label_fontsize) + plt.ylabel("Hz", fontsize=label_fontsize) + plt.tight_layout() + plt.colorbar() + + if output_path: + print(output_path) + fig.savefig(output_path) + plt.close() + + if not output_fig: + plt.close() diff --git a/viXTTS/TTS/utils/__init__.py b/viXTTS/TTS/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/utils/__pycache__/__init__.cpython-310.pyc b/viXTTS/TTS/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..a23d8b47f609f67e71b24603492ee26593605e82 Binary files /dev/null and b/viXTTS/TTS/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/viXTTS/TTS/utils/__pycache__/generic_utils.cpython-310.pyc b/viXTTS/TTS/utils/__pycache__/generic_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb7bf52d236b2a111c9d4a684514164608d66308 Binary files /dev/null and b/viXTTS/TTS/utils/__pycache__/generic_utils.cpython-310.pyc differ diff --git a/viXTTS/TTS/utils/__pycache__/io.cpython-310.pyc b/viXTTS/TTS/utils/__pycache__/io.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08d981c0298ae255ddb9aafa93b2f471c7ad55b2 Binary files /dev/null and b/viXTTS/TTS/utils/__pycache__/io.cpython-310.pyc differ diff --git a/viXTTS/TTS/utils/audio/__init__.py b/viXTTS/TTS/utils/audio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f18f22199908ee0dd5445e34527f5fddb65cfed8 --- /dev/null +++ b/viXTTS/TTS/utils/audio/__init__.py @@ -0,0 +1 @@ +from TTS.utils.audio.processor import AudioProcessor diff --git a/viXTTS/TTS/utils/audio/__pycache__/__init__.cpython-310.pyc b/viXTTS/TTS/utils/audio/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90999f45fc8cf3857117533b3d5740a7b7e9897a Binary files /dev/null and b/viXTTS/TTS/utils/audio/__pycache__/__init__.cpython-310.pyc differ diff --git a/viXTTS/TTS/utils/audio/__pycache__/numpy_transforms.cpython-310.pyc b/viXTTS/TTS/utils/audio/__pycache__/numpy_transforms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2aaae46c9a51397cb1293bfe2059fb8d4ae1cff9 Binary files /dev/null and b/viXTTS/TTS/utils/audio/__pycache__/numpy_transforms.cpython-310.pyc differ diff --git a/viXTTS/TTS/utils/audio/__pycache__/processor.cpython-310.pyc b/viXTTS/TTS/utils/audio/__pycache__/processor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50ff965fced2376c1e8501e5b912e85860bac52b Binary files /dev/null and b/viXTTS/TTS/utils/audio/__pycache__/processor.cpython-310.pyc differ diff --git a/viXTTS/TTS/utils/audio/__pycache__/torch_transforms.cpython-310.pyc b/viXTTS/TTS/utils/audio/__pycache__/torch_transforms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3bb150089336f253e54f841de742f79e0858a31 Binary files /dev/null and b/viXTTS/TTS/utils/audio/__pycache__/torch_transforms.cpython-310.pyc differ diff --git a/viXTTS/TTS/utils/audio/numpy_transforms.py b/viXTTS/TTS/utils/audio/numpy_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..af88569fc3653f054bc2320ba414fb8528cc209a --- /dev/null +++ b/viXTTS/TTS/utils/audio/numpy_transforms.py @@ -0,0 +1,485 @@ +from io import BytesIO +from typing import Tuple + +import librosa +import numpy as np +import scipy +import soundfile as sf +from librosa import magphase, pyin + +# For using kwargs +# pylint: disable=unused-argument + + +def build_mel_basis( + *, + sample_rate: int = None, + fft_size: int = None, + num_mels: int = None, + mel_fmax: int = None, + mel_fmin: int = None, + **kwargs, +) -> np.ndarray: + """Build melspectrogram basis. + + Returns: + np.ndarray: melspectrogram basis. 
+ """ + if mel_fmax is not None: + assert mel_fmax <= sample_rate // 2 + assert mel_fmax - mel_fmin > 0 + return librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=mel_fmin, fmax=mel_fmax) + + +def millisec_to_length( + *, frame_length_ms: int = None, frame_shift_ms: int = None, sample_rate: int = None, **kwargs +) -> Tuple[int, int]: + """Compute hop and window length from milliseconds. + + Returns: + Tuple[int, int]: hop length and window length for STFT. + """ + factor = frame_length_ms / frame_shift_ms + assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms" + win_length = int(frame_length_ms / 1000.0 * sample_rate) + hop_length = int(win_length / float(factor)) + return win_length, hop_length + + +def _log(x, base): + if base == 10: + return np.log10(x) + return np.log(x) + + +def _exp(x, base): + if base == 10: + return np.power(10, x) + return np.exp(x) + + +def amp_to_db(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray: + """Convert amplitude values to decibels. + + Args: + x (np.ndarray): Amplitude spectrogram. + gain (float): Gain factor. Defaults to 1. + base (int): Logarithm base. Defaults to 10. + + Returns: + np.ndarray: Decibels spectrogram. + """ + assert (x < 0).sum() == 0, " [!] Input values must be non-negative." + return gain * _log(np.maximum(1e-8, x), base) + + +# pylint: disable=no-self-use +def db_to_amp(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray: + """Convert decibels spectrogram to amplitude spectrogram. + + Args: + x (np.ndarray): Decibels spectrogram. + gain (float): Gain factor. Defaults to 1. + base (int): Logarithm base. Defaults to 10. + + Returns: + np.ndarray: Amplitude spectrogram. + """ + return _exp(x / gain, base) + + +def preemphasis(*, x: np.ndarray, coef: float = 0.97, **kwargs) -> np.ndarray: + """Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values. + + Args: + x (np.ndarray): Audio signal. + + Raises: + RuntimeError: Preemphasis coeff is set to 0. + + Returns: + np.ndarray: Decorrelated audio signal. + """ + if coef == 0: + raise RuntimeError(" [!] Preemphasis is set 0.0.") + return scipy.signal.lfilter([1, -coef], [1], x) + + +def deemphasis(*, x: np.ndarray = None, coef: float = 0.97, **kwargs) -> np.ndarray: + """Reverse pre-emphasis.""" + if coef == 0: + raise RuntimeError(" [!] Preemphasis is set 0.0.") + return scipy.signal.lfilter([1], [1, -coef], x) + + +def spec_to_mel(*, spec: np.ndarray, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray: + """Convert a full scale linear spectrogram output of a network to a melspectrogram. + + Args: + spec (np.ndarray): Normalized full scale linear spectrogram. + + Shapes: + - spec: :math:`[C, T]` + + Returns: + np.ndarray: Normalized melspectrogram. + """ + return np.dot(mel_basis, spec) + + +def mel_to_spec(*, mel: np.ndarray = None, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray: + """Convert a melspectrogram to full scale spectrogram.""" + assert (mel < 0).sum() == 0, " [!] Input values must be non-negative." + inv_mel_basis = np.linalg.pinv(mel_basis) + return np.maximum(1e-10, np.dot(inv_mel_basis, mel)) + + +def wav_to_spec(*, wav: np.ndarray = None, **kwargs) -> np.ndarray: + """Compute a spectrogram from a waveform. + + Args: + wav (np.ndarray): Waveform. Shape :math:`[T_wav,]` + + Returns: + np.ndarray: Spectrogram. Shape :math:`[C, T_spec]`. 
:math:`T_spec == T_wav / hop_length` + """ + D = stft(y=wav, **kwargs) + S = np.abs(D) + return S.astype(np.float32) + + +def wav_to_mel(*, wav: np.ndarray = None, mel_basis=None, **kwargs) -> np.ndarray: + """Compute a melspectrogram from a waveform.""" + D = stft(y=wav, **kwargs) + S = spec_to_mel(spec=np.abs(D), mel_basis=mel_basis, **kwargs) + return S.astype(np.float32) + + +def spec_to_wav(*, spec: np.ndarray, power: float = 1.5, **kwargs) -> np.ndarray: + """Convert a spectrogram to a waveform using Griffi-Lim vocoder.""" + S = spec.copy() + return griffin_lim(spec=S**power, **kwargs) + + +def mel_to_wav(*, mel: np.ndarray = None, power: float = 1.5, **kwargs) -> np.ndarray: + """Convert a melspectrogram to a waveform using Griffi-Lim vocoder.""" + S = mel.copy() + S = mel_to_spec(mel=S, mel_basis=kwargs["mel_basis"]) # Convert back to linear + return griffin_lim(spec=S**power, **kwargs) + + +### STFT and ISTFT ### +def stft( + *, + y: np.ndarray = None, + fft_size: int = None, + hop_length: int = None, + win_length: int = None, + pad_mode: str = "reflect", + window: str = "hann", + center: bool = True, + **kwargs, +) -> np.ndarray: + """Librosa STFT wrapper. + + Check http://librosa.org/doc/main/generated/librosa.stft.html argument details. + + Returns: + np.ndarray: Complex number array. + """ + return librosa.stft( + y=y, + n_fft=fft_size, + hop_length=hop_length, + win_length=win_length, + pad_mode=pad_mode, + window=window, + center=center, + ) + + +def istft( + *, + y: np.ndarray = None, + hop_length: int = None, + win_length: int = None, + window: str = "hann", + center: bool = True, + **kwargs, +) -> np.ndarray: + """Librosa iSTFT wrapper. + + Check http://librosa.org/doc/main/generated/librosa.istft.html argument details. + + Returns: + np.ndarray: Complex number array. + """ + return librosa.istft(y, hop_length=hop_length, win_length=win_length, center=center, window=window) + + +def griffin_lim(*, spec: np.ndarray = None, num_iter=60, **kwargs) -> np.ndarray: + angles = np.exp(2j * np.pi * np.random.rand(*spec.shape)) + S_complex = np.abs(spec).astype(complex) + y = istft(y=S_complex * angles, **kwargs) + if not np.isfinite(y).all(): + print(" [!] Waveform is not finite everywhere. Skipping the GL.") + return np.array([0.0]) + for _ in range(num_iter): + angles = np.exp(1j * np.angle(stft(y=y, **kwargs))) + y = istft(y=S_complex * angles, **kwargs) + return y + + +def compute_stft_paddings( + *, x: np.ndarray = None, hop_length: int = None, pad_two_sides: bool = False, **kwargs +) -> Tuple[int, int]: + """Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding + (first and final frames)""" + pad = (x.shape[0] // hop_length + 1) * hop_length - x.shape[0] + if not pad_two_sides: + return 0, pad + return pad // 2, pad // 2 + pad % 2 + + +def compute_f0( + *, + x: np.ndarray = None, + pitch_fmax: float = None, + pitch_fmin: float = None, + hop_length: int = None, + win_length: int = None, + sample_rate: int = None, + stft_pad_mode: str = "reflect", + center: bool = True, + **kwargs, +) -> np.ndarray: + """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram. + + Args: + x (np.ndarray): Waveform. Shape :math:`[T_wav,]` + pitch_fmax (float): Pitch max value. + pitch_fmin (float): Pitch min value. + hop_length (int): Number of frames between STFT columns. + win_length (int): STFT window length. + sample_rate (int): Audio sampling rate. + stft_pad_mode (str): Padding mode for STFT. 
+ center (bool): Centered padding. + + Returns: + np.ndarray: Pitch. Shape :math:`[T_pitch,]`. :math:`T_pitch == T_wav / hop_length` + + Examples: + >>> WAV_FILE = filename = librosa.example('vibeace') + >>> from TTS.config import BaseAudioConfig + >>> from TTS.utils.audio import AudioProcessor + >>> conf = BaseAudioConfig(pitch_fmax=640, pitch_fmin=1) + >>> ap = AudioProcessor(**conf) + >>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate] + >>> pitch = ap.compute_f0(wav) + """ + assert pitch_fmax is not None, " [!] Set `pitch_fmax` before calling `compute_f0`." + assert pitch_fmin is not None, " [!] Set `pitch_fmin` before calling `compute_f0`." + + f0, voiced_mask, _ = pyin( + y=x.astype(np.double), + fmin=pitch_fmin, + fmax=pitch_fmax, + sr=sample_rate, + frame_length=win_length, + win_length=win_length // 2, + hop_length=hop_length, + pad_mode=stft_pad_mode, + center=center, + n_thresholds=100, + beta_parameters=(2, 18), + boltzmann_parameter=2, + resolution=0.1, + max_transition_rate=35.92, + switch_prob=0.01, + no_trough_prob=0.01, + ) + f0[~voiced_mask] = 0.0 + + return f0 + + +def compute_energy(y: np.ndarray, **kwargs) -> np.ndarray: + """Compute energy of a waveform using the same parameters used for computing melspectrogram. + Args: + y (np.ndarray): Waveform. Shape :math:`[T_wav,]` + Returns: + np.ndarray: energy. Shape :math:`[T_energy,]`. :math:`T_energy == T_wav / hop_length` + Examples: + >>> WAV_FILE = filename = librosa.example('vibeace') + >>> from TTS.config import BaseAudioConfig + >>> from TTS.utils.audio import AudioProcessor + >>> conf = BaseAudioConfig() + >>> ap = AudioProcessor(**conf) + >>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate] + >>> energy = ap.compute_energy(wav) + """ + x = stft(y=y, **kwargs) + mag, _ = magphase(x) + energy = np.sqrt(np.sum(mag**2, axis=0)) + return energy + + +### Audio Processing ### +def find_endpoint( + *, + wav: np.ndarray = None, + trim_db: float = -40, + sample_rate: int = None, + min_silence_sec=0.8, + gain: float = None, + base: int = None, + **kwargs, +) -> int: + """Find the last point without silence at the end of an audio signal. + + Args: + wav (np.ndarray): Audio signal. + trim_db (float, optional): Silence threshold in decibels. Defaults to -40. + min_silence_sec (float, optional): Ignore silences that are shorter than this in secs. Defaults to 0.8. + gain (float, optional): Gain to be used to convert trim_db to trim_amp. Defaults to None. + base (int, optional): Base of the logarithm used to convert trim_db to trim_amp. Defaults to 10. + + Returns: + int: Last point without silence. + """ + window_length = int(sample_rate * min_silence_sec) + hop_length = int(window_length / 4) + threshold = db_to_amp(x=-trim_db, gain=gain, base=base) + for x in range(hop_length, len(wav) - window_length, hop_length): + if np.max(wav[x : x + window_length]) < threshold: + return x + hop_length + return len(wav) + + +def trim_silence( + *, + wav: np.ndarray = None, + sample_rate: int = None, + trim_db: float = None, + win_length: int = None, + hop_length: int = None, + **kwargs, +) -> np.ndarray: + """Trim silent parts with a threshold and 0.01 sec margin""" + margin = int(sample_rate * 0.01) + wav = wav[margin:-margin] + return librosa.effects.trim(wav, top_db=trim_db, frame_length=win_length, hop_length=hop_length)[0] + + +def volume_norm(*, x: np.ndarray = None, coef: float = 0.95, **kwargs) -> np.ndarray: + """Normalize the volume of an audio signal.
+ + Args: + x (np.ndarray): Raw waveform. + coef (float): Coefficient to rescale the maximum value. Defaults to 0.95. + + Returns: + np.ndarray: Volume normalized waveform. + """ + return x / abs(x).max() * coef + + +def rms_norm(*, wav: np.ndarray = None, db_level: float = -27.0, **kwargs) -> np.ndarray: + r = 10 ** (db_level / 20) + a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2)) + return wav * a + + +def rms_volume_norm(*, x: np.ndarray, db_level: float = -27.0, **kwargs) -> np.ndarray: + """Normalize the volume based on RMS of the signal. + + Args: + x (np.ndarray): Raw waveform. + db_level (float): Target dB level in RMS. Defaults to -27.0. + + Returns: + np.ndarray: RMS normalized waveform. + """ + assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0" + wav = rms_norm(wav=x, db_level=db_level) + return wav + + +def load_wav(*, filename: str, sample_rate: int = None, resample: bool = False, **kwargs) -> np.ndarray: + """Read a wav file using Librosa and optionally resample, silence trim, volume normalize. + + Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before. + + Args: + filename (str): Path to the wav file. + sr (int, optional): Sampling rate for resampling. Defaults to None. + resample (bool, optional): Resample the audio file when loading. Slows down the I/O time. Defaults to False. + + Returns: + np.ndarray: Loaded waveform. + """ + if resample: + # loading with resampling. It is significantly slower. + x, _ = librosa.load(filename, sr=sample_rate) + else: + # SF is faster than librosa for loading files + x, _ = sf.read(filename) + return x + + +def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, pipe_out=None, **kwargs) -> None: + """Save float waveform to a file using Scipy. + + Args: + wav (np.ndarray): Waveform with float values in range [-1, 1] to save. + path (str): Path to a output file. + sr (int, optional): Sampling rate used for saving to the file. Defaults to None. + pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe. + """ + wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav)))) + + wav_norm = wav_norm.astype(np.int16) + if pipe_out: + wav_buffer = BytesIO() + scipy.io.wavfile.write(wav_buffer, sample_rate, wav_norm) + wav_buffer.seek(0) + pipe_out.buffer.write(wav_buffer.read()) + scipy.io.wavfile.write(path, sample_rate, wav_norm) + + +def mulaw_encode(*, wav: np.ndarray, mulaw_qc: int, **kwargs) -> np.ndarray: + mu = 2**mulaw_qc - 1 + signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu) + signal = (signal + 1) / 2 * mu + 0.5 + return np.floor( + signal, + ) + + +def mulaw_decode(*, wav, mulaw_qc: int, **kwargs) -> np.ndarray: + """Recovers waveform from quantized values.""" + mu = 2**mulaw_qc - 1 + x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1) + return x + + +def encode_16bits(*, x: np.ndarray, **kwargs) -> np.ndarray: + return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16) + + +def quantize(*, x: np.ndarray, quantize_bits: int, **kwargs) -> np.ndarray: + """Quantize a waveform to a given number of bits. + + Args: + x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`. + quantize_bits (int): Number of quantization bits. + + Returns: + np.ndarray: Quantized waveform. 
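+
+    Example:
+        An illustrative sketch; it assumes ``x`` is already normalized into ``[-1, 1]``:
+
+        >>> import numpy as np
+        >>> quantize(x=np.array([-1.0, 0.0, 1.0]), quantize_bits=8)
+        array([  0. , 127.5, 255. ])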
+ """ + return (x + 1.0) * (2**quantize_bits - 1) / 2 + + +def dequantize(*, x, quantize_bits, **kwargs) -> np.ndarray: + """Dequantize a waveform from the given number of bits.""" + return 2 * x / (2**quantize_bits - 1) - 1 diff --git a/viXTTS/TTS/utils/audio/processor.py b/viXTTS/TTS/utils/audio/processor.py new file mode 100644 index 0000000000000000000000000000000000000000..c53bad562ea62e3f49a8c5406e26a9f9014135ef --- /dev/null +++ b/viXTTS/TTS/utils/audio/processor.py @@ -0,0 +1,633 @@ +from io import BytesIO +from typing import Dict, Tuple + +import librosa +import numpy as np +import scipy.io.wavfile +import scipy.signal + +from TTS.tts.utils.helpers import StandardScaler +from TTS.utils.audio.numpy_transforms import ( + amp_to_db, + build_mel_basis, + compute_f0, + db_to_amp, + deemphasis, + find_endpoint, + griffin_lim, + load_wav, + mel_to_spec, + millisec_to_length, + preemphasis, + rms_volume_norm, + spec_to_mel, + stft, + trim_silence, + volume_norm, +) + +# pylint: disable=too-many-public-methods + + +class AudioProcessor(object): + """Audio Processor for TTS. + + Note: + All the class arguments are set to default values to enable a flexible initialization + of the class with the model config. They are not meaningful for all the arguments. + + Args: + sample_rate (int, optional): + target audio sampling rate. Defaults to None. + + resample (bool, optional): + enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False. + + num_mels (int, optional): + number of melspectrogram dimensions. Defaults to None. + + log_func (int, optional): + log exponent used for converting spectrogram aplitude to DB. + + min_level_db (int, optional): + minimum db threshold for the computed melspectrograms. Defaults to None. + + frame_shift_ms (int, optional): + milliseconds of frames between STFT columns. Defaults to None. + + frame_length_ms (int, optional): + milliseconds of STFT window length. Defaults to None. + + hop_length (int, optional): + number of frames between STFT columns. Used if ```frame_shift_ms``` is None. Defaults to None. + + win_length (int, optional): + STFT window length. Used if ```frame_length_ms``` is None. Defaults to None. + + ref_level_db (int, optional): + reference DB level to avoid background noise. In general <20DB corresponds to the air noise. Defaults to None. + + fft_size (int, optional): + FFT window size for STFT. Defaults to 1024. + + power (int, optional): + Exponent value applied to the spectrogram before GriffinLim. Defaults to None. + + preemphasis (float, optional): + Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0. + + signal_norm (bool, optional): + enable/disable signal normalization. Defaults to None. + + symmetric_norm (bool, optional): + enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else [0, k], Defaults to None. + + max_norm (float, optional): + ```k``` defining the normalization range. Defaults to None. + + mel_fmin (int, optional): + minimum filter frequency for computing melspectrograms. Defaults to None. + + mel_fmax (int, optional): + maximum filter frequency for computing melspectrograms. Defaults to None. + + pitch_fmin (int, optional): + minimum filter frequency for computing pitch. Defaults to None. + + pitch_fmax (int, optional): + maximum filter frequency for computing pitch. Defaults to None. + + spec_gain (int, optional): + gain applied when converting amplitude to DB. 
Defaults to 20. + + stft_pad_mode (str, optional): + Padding mode for STFT. Defaults to 'reflect'. + + clip_norm (bool, optional): + enable/disable clipping the our of range values in the normalized audio signal. Defaults to True. + + griffin_lim_iters (int, optional): + Number of GriffinLim iterations. Defaults to None. + + do_trim_silence (bool, optional): + enable/disable silence trimming when loading the audio signal. Defaults to False. + + trim_db (int, optional): + DB threshold used for silence trimming. Defaults to 60. + + do_sound_norm (bool, optional): + enable/disable signal normalization. Defaults to False. + + do_amp_to_db_linear (bool, optional): + enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True. + + do_amp_to_db_mel (bool, optional): + enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True. + + do_rms_norm (bool, optional): + enable/disable RMS volume normalization when loading an audio file. Defaults to False. + + db_level (int, optional): + dB level used for rms normalization. The range is -99 to 0. Defaults to None. + + stats_path (str, optional): + Path to the computed stats file. Defaults to None. + + verbose (bool, optional): + enable/disable logging. Defaults to True. + + """ + + def __init__( + self, + sample_rate=None, + resample=False, + num_mels=None, + log_func="np.log10", + min_level_db=None, + frame_shift_ms=None, + frame_length_ms=None, + hop_length=None, + win_length=None, + ref_level_db=None, + fft_size=1024, + power=None, + preemphasis=0.0, + signal_norm=None, + symmetric_norm=None, + max_norm=None, + mel_fmin=None, + mel_fmax=None, + pitch_fmax=None, + pitch_fmin=None, + spec_gain=20, + stft_pad_mode="reflect", + clip_norm=True, + griffin_lim_iters=None, + do_trim_silence=False, + trim_db=60, + do_sound_norm=False, + do_amp_to_db_linear=True, + do_amp_to_db_mel=True, + do_rms_norm=False, + db_level=None, + stats_path=None, + verbose=True, + **_, + ): + # setup class attributed + self.sample_rate = sample_rate + self.resample = resample + self.num_mels = num_mels + self.log_func = log_func + self.min_level_db = min_level_db or 0 + self.frame_shift_ms = frame_shift_ms + self.frame_length_ms = frame_length_ms + self.ref_level_db = ref_level_db + self.fft_size = fft_size + self.power = power + self.preemphasis = preemphasis + self.griffin_lim_iters = griffin_lim_iters + self.signal_norm = signal_norm + self.symmetric_norm = symmetric_norm + self.mel_fmin = mel_fmin or 0 + self.mel_fmax = mel_fmax + self.pitch_fmin = pitch_fmin + self.pitch_fmax = pitch_fmax + self.spec_gain = float(spec_gain) + self.stft_pad_mode = stft_pad_mode + self.max_norm = 1.0 if max_norm is None else float(max_norm) + self.clip_norm = clip_norm + self.do_trim_silence = do_trim_silence + self.trim_db = trim_db + self.do_sound_norm = do_sound_norm + self.do_amp_to_db_linear = do_amp_to_db_linear + self.do_amp_to_db_mel = do_amp_to_db_mel + self.do_rms_norm = do_rms_norm + self.db_level = db_level + self.stats_path = stats_path + # setup exp_func for db to amp conversion + if log_func == "np.log": + self.base = np.e + elif log_func == "np.log10": + self.base = 10 + else: + raise ValueError(" [!] 
unknown `log_func` value.") + # setup stft parameters + if hop_length is None: + # compute stft parameters from given time values + self.win_length, self.hop_length = millisec_to_length( + frame_length_ms=self.frame_length_ms, frame_shift_ms=self.frame_shift_ms, sample_rate=self.sample_rate + ) + else: + # use stft parameters from config file + self.hop_length = hop_length + self.win_length = win_length + assert min_level_db != 0.0, " [!] min_level_db is 0" + assert ( + self.win_length <= self.fft_size + ), f" [!] win_length cannot be larger than fft_size - {self.win_length} vs {self.fft_size}" + members = vars(self) + if verbose: + print(" > Setting up Audio Processor...") + for key, value in members.items(): + print(" | > {}:{}".format(key, value)) + # create spectrogram utils + self.mel_basis = build_mel_basis( + sample_rate=self.sample_rate, + fft_size=self.fft_size, + num_mels=self.num_mels, + mel_fmax=self.mel_fmax, + mel_fmin=self.mel_fmin, + ) + # setup scaler + if stats_path and signal_norm: + mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path) + self.setup_scaler(mel_mean, mel_std, linear_mean, linear_std) + self.signal_norm = True + self.max_norm = None + self.clip_norm = None + self.symmetric_norm = None + + @staticmethod + def init_from_config(config: "Coqpit", verbose=True): + if "audio" in config: + return AudioProcessor(verbose=verbose, **config.audio) + return AudioProcessor(verbose=verbose, **config) + + ### normalization ### + def normalize(self, S: np.ndarray) -> np.ndarray: + """Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]` + + Args: + S (np.ndarray): Spectrogram to normalize. + + Raises: + RuntimeError: Mean and variance is computed from incompatible parameters. + + Returns: + np.ndarray: Normalized spectrogram. + """ + # pylint: disable=no-else-return + S = S.copy() + if self.signal_norm: + # mean-var scaling + if hasattr(self, "mel_scaler"): + if S.shape[0] == self.num_mels: + return self.mel_scaler.transform(S.T).T + elif S.shape[0] == self.fft_size / 2: + return self.linear_scaler.transform(S.T).T + else: + raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.") + # range normalization + S -= self.ref_level_db # discard certain range of DB assuming it is air noise + S_norm = (S - self.min_level_db) / (-self.min_level_db) + if self.symmetric_norm: + S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm + if self.clip_norm: + S_norm = np.clip( + S_norm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type + ) + return S_norm + else: + S_norm = self.max_norm * S_norm + if self.clip_norm: + S_norm = np.clip(S_norm, 0, self.max_norm) + return S_norm + else: + return S + + def denormalize(self, S: np.ndarray) -> np.ndarray: + """Denormalize spectrogram values. + + Args: + S (np.ndarray): Spectrogram to denormalize. + + Raises: + RuntimeError: Mean and variance are incompatible. + + Returns: + np.ndarray: Denormalized spectrogram. + """ + # pylint: disable=no-else-return + S_denorm = S.copy() + if self.signal_norm: + # mean-var scaling + if hasattr(self, "mel_scaler"): + if S_denorm.shape[0] == self.num_mels: + return self.mel_scaler.inverse_transform(S_denorm.T).T + elif S_denorm.shape[0] == self.fft_size / 2: + return self.linear_scaler.inverse_transform(S_denorm.T).T + else: + raise RuntimeError(" [!] 
Mean-Var stats does not match the given feature dimensions.") + if self.symmetric_norm: + if self.clip_norm: + S_denorm = np.clip( + S_denorm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type + ) + S_denorm = ((S_denorm + self.max_norm) * -self.min_level_db / (2 * self.max_norm)) + self.min_level_db + return S_denorm + self.ref_level_db + else: + if self.clip_norm: + S_denorm = np.clip(S_denorm, 0, self.max_norm) + S_denorm = (S_denorm * -self.min_level_db / self.max_norm) + self.min_level_db + return S_denorm + self.ref_level_db + else: + return S_denorm + + ### Mean-STD scaling ### + def load_stats(self, stats_path: str) -> Tuple[np.array, np.array, np.array, np.array, Dict]: + """Loading mean and variance statistics from a `npy` file. + + Args: + stats_path (str): Path to the `npy` file containing + + Returns: + Tuple[np.array, np.array, np.array, np.array, Dict]: loaded statistics and the config used to + compute them. + """ + stats = np.load(stats_path, allow_pickle=True).item() # pylint: disable=unexpected-keyword-arg + mel_mean = stats["mel_mean"] + mel_std = stats["mel_std"] + linear_mean = stats["linear_mean"] + linear_std = stats["linear_std"] + stats_config = stats["audio_config"] + # check all audio parameters used for computing stats + skip_parameters = ["griffin_lim_iters", "stats_path", "do_trim_silence", "ref_level_db", "power"] + for key in stats_config.keys(): + if key in skip_parameters: + continue + if key not in ["sample_rate", "trim_db"]: + assert ( + stats_config[key] == self.__dict__[key] + ), f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}" + return mel_mean, mel_std, linear_mean, linear_std, stats_config + + # pylint: disable=attribute-defined-outside-init + def setup_scaler( + self, mel_mean: np.ndarray, mel_std: np.ndarray, linear_mean: np.ndarray, linear_std: np.ndarray + ) -> None: + """Initialize scaler objects used in mean-std normalization. + + Args: + mel_mean (np.ndarray): Mean for melspectrograms. + mel_std (np.ndarray): STD for melspectrograms. + linear_mean (np.ndarray): Mean for full scale spectrograms. + linear_std (np.ndarray): STD for full scale spectrograms. + """ + self.mel_scaler = StandardScaler() + self.mel_scaler.set_stats(mel_mean, mel_std) + self.linear_scaler = StandardScaler() + self.linear_scaler.set_stats(linear_mean, linear_std) + + ### Preemphasis ### + def apply_preemphasis(self, x: np.ndarray) -> np.ndarray: + """Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values. + + Args: + x (np.ndarray): Audio signal. + + Raises: + RuntimeError: Preemphasis coeff is set to 0. + + Returns: + np.ndarray: Decorrelated audio signal. + """ + return preemphasis(x=x, coef=self.preemphasis) + + def apply_inv_preemphasis(self, x: np.ndarray) -> np.ndarray: + """Reverse pre-emphasis.""" + return deemphasis(x=x, coef=self.preemphasis) + + ### SPECTROGRAMs ### + def spectrogram(self, y: np.ndarray) -> np.ndarray: + """Compute a spectrogram from a waveform. + + Args: + y (np.ndarray): Waveform. + + Returns: + np.ndarray: Spectrogram. 
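+
+        Example:
+            An illustrative sketch; the config values below are assumptions for demonstration, not library defaults:
+
+            >>> import numpy as np
+            >>> from TTS.utils.audio import AudioProcessor
+            >>> ap = AudioProcessor(sample_rate=22050, num_mels=80, fft_size=1024, hop_length=256, win_length=1024, min_level_db=-100, ref_level_db=20, verbose=False)
+            >>> spec = ap.spectrogram(np.random.uniform(-1, 1, size=22050))
+            >>> spec.shape[0]  # fft_size // 2 + 1 frequency bins
+            513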
+ """ + if self.preemphasis != 0: + y = self.apply_preemphasis(y) + D = stft( + y=y, + fft_size=self.fft_size, + hop_length=self.hop_length, + win_length=self.win_length, + pad_mode=self.stft_pad_mode, + ) + if self.do_amp_to_db_linear: + S = amp_to_db(x=np.abs(D), gain=self.spec_gain, base=self.base) + else: + S = np.abs(D) + return self.normalize(S).astype(np.float32) + + def melspectrogram(self, y: np.ndarray) -> np.ndarray: + """Compute a melspectrogram from a waveform.""" + if self.preemphasis != 0: + y = self.apply_preemphasis(y) + D = stft( + y=y, + fft_size=self.fft_size, + hop_length=self.hop_length, + win_length=self.win_length, + pad_mode=self.stft_pad_mode, + ) + S = spec_to_mel(spec=np.abs(D), mel_basis=self.mel_basis) + if self.do_amp_to_db_mel: + S = amp_to_db(x=S, gain=self.spec_gain, base=self.base) + + return self.normalize(S).astype(np.float32) + + def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray: + """Convert a spectrogram to a waveform using Griffi-Lim vocoder.""" + S = self.denormalize(spectrogram) + S = db_to_amp(x=S, gain=self.spec_gain, base=self.base) + # Reconstruct phase + W = self._griffin_lim(S**self.power) + return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W + + def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray: + """Convert a melspectrogram to a waveform using Griffi-Lim vocoder.""" + D = self.denormalize(mel_spectrogram) + S = db_to_amp(x=D, gain=self.spec_gain, base=self.base) + S = mel_to_spec(mel=S, mel_basis=self.mel_basis) # Convert back to linear + W = self._griffin_lim(S**self.power) + return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W + + def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray: + """Convert a full scale linear spectrogram output of a network to a melspectrogram. + + Args: + linear_spec (np.ndarray): Normalized full scale linear spectrogram. + + Returns: + np.ndarray: Normalized melspectrogram. + """ + S = self.denormalize(linear_spec) + S = db_to_amp(x=S, gain=self.spec_gain, base=self.base) + S = spec_to_mel(spec=np.abs(S), mel_basis=self.mel_basis) + S = amp_to_db(x=S, gain=self.spec_gain, base=self.base) + mel = self.normalize(S) + return mel + + def _griffin_lim(self, S): + return griffin_lim( + spec=S, + num_iter=self.griffin_lim_iters, + hop_length=self.hop_length, + win_length=self.win_length, + fft_size=self.fft_size, + pad_mode=self.stft_pad_mode, + ) + + def compute_f0(self, x: np.ndarray) -> np.ndarray: + """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram. + + Args: + x (np.ndarray): Waveform. + + Returns: + np.ndarray: Pitch. 
+ + Examples: + >>> WAV_FILE = filename = librosa.example('vibeace') + >>> from TTS.config import BaseAudioConfig + >>> from TTS.utils.audio import AudioProcessor + >>> conf = BaseAudioConfig(pitch_fmax=640, pitch_fmin=1) + >>> ap = AudioProcessor(**conf) + >>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate] + >>> pitch = ap.compute_f0(wav) + """ + # align F0 length to the spectrogram length + if len(x) % self.hop_length == 0: + x = np.pad(x, (0, self.hop_length // 2), mode=self.stft_pad_mode) + + f0 = compute_f0( + x=x, + pitch_fmax=self.pitch_fmax, + pitch_fmin=self.pitch_fmin, + hop_length=self.hop_length, + win_length=self.win_length, + sample_rate=self.sample_rate, + stft_pad_mode=self.stft_pad_mode, + center=True, + ) + + return f0 + + ### Audio Processing ### + def find_endpoint(self, wav: np.ndarray, min_silence_sec=0.8) -> int: + """Find the last point without silence at the end of a audio signal. + + Args: + wav (np.ndarray): Audio signal. + threshold_db (int, optional): Silence threshold in decibels. Defaults to -40. + min_silence_sec (float, optional): Ignore silences that are shorter then this in secs. Defaults to 0.8. + + Returns: + int: Last point without silence. + """ + return find_endpoint( + wav=wav, + trim_db=self.trim_db, + sample_rate=self.sample_rate, + min_silence_sec=min_silence_sec, + gain=self.spec_gain, + base=self.base, + ) + + def trim_silence(self, wav): + """Trim silent parts with a threshold and 0.01 sec margin""" + return trim_silence( + wav=wav, + sample_rate=self.sample_rate, + trim_db=self.trim_db, + win_length=self.win_length, + hop_length=self.hop_length, + ) + + @staticmethod + def sound_norm(x: np.ndarray) -> np.ndarray: + """Normalize the volume of an audio signal. + + Args: + x (np.ndarray): Raw waveform. + + Returns: + np.ndarray: Volume normalized waveform. + """ + return volume_norm(x=x) + + def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray: + """Normalize the volume based on RMS of the signal. + + Args: + x (np.ndarray): Raw waveform. + + Returns: + np.ndarray: RMS normalized waveform. + """ + if db_level is None: + db_level = self.db_level + return rms_volume_norm(x=x, db_level=db_level) + + ### save and load ### + def load_wav(self, filename: str, sr: int = None) -> np.ndarray: + """Read a wav file using Librosa and optionally resample, silence trim, volume normalize. + + Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before. + + Args: + filename (str): Path to the wav file. + sr (int, optional): Sampling rate for resampling. Defaults to None. + + Returns: + np.ndarray: Loaded waveform. + """ + if sr is not None: + x = load_wav(filename=filename, sample_rate=sr, resample=True) + else: + x = load_wav(filename=filename, sample_rate=self.sample_rate, resample=self.resample) + if self.do_trim_silence: + try: + x = self.trim_silence(x) + except ValueError: + print(f" [!] File cannot be trimmed for silence - {filename}") + if self.do_sound_norm: + x = self.sound_norm(x) + if self.do_rms_norm: + x = self.rms_volume_norm(x, self.db_level) + return x + + def save_wav(self, wav: np.ndarray, path: str, sr: int = None, pipe_out=None) -> None: + """Save a waveform to a file using Scipy. + + Args: + wav (np.ndarray): Waveform to save. + path (str): Path to a output file. + sr (int, optional): Sampling rate used for saving to the file. Defaults to None. + pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe. 
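+
+        Example:
+            An illustrative sketch; it assumes ``ap`` is an initialized ``AudioProcessor`` and the output path is only a placeholder:
+
+            >>> import numpy as np
+            >>> wav = np.random.uniform(-0.5, 0.5, size=22050).astype(np.float32)
+            >>> ap.save_wav(wav, "/tmp/example.wav", sr=22050)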
+ """ + if self.do_rms_norm: + wav_norm = self.rms_volume_norm(wav, self.db_level) * 32767 + else: + wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav)))) + + wav_norm = wav_norm.astype(np.int16) + if pipe_out: + wav_buffer = BytesIO() + scipy.io.wavfile.write(wav_buffer, sr if sr else self.sample_rate, wav_norm) + wav_buffer.seek(0) + pipe_out.buffer.write(wav_buffer.read()) + scipy.io.wavfile.write(path, sr if sr else self.sample_rate, wav_norm) + + def get_duration(self, filename: str) -> float: + """Get the duration of a wav file using Librosa. + + Args: + filename (str): Path to the wav file. + """ + return librosa.get_duration(filename=filename) diff --git a/viXTTS/TTS/utils/audio/torch_transforms.py b/viXTTS/TTS/utils/audio/torch_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..fd40ebb048b915a836ba0d84dc22054d23b1d886 --- /dev/null +++ b/viXTTS/TTS/utils/audio/torch_transforms.py @@ -0,0 +1,165 @@ +import librosa +import torch +from torch import nn + + +class TorchSTFT(nn.Module): # pylint: disable=abstract-method + """Some of the audio processing funtions using Torch for faster batch processing. + + Args: + + n_fft (int): + FFT window size for STFT. + + hop_length (int): + number of frames between STFT columns. + + win_length (int, optional): + STFT window length. + + pad_wav (bool, optional): + If True pad the audio with (n_fft - hop_length) / 2). Defaults to False. + + window (str, optional): + The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window" + + sample_rate (int, optional): + target audio sampling rate. Defaults to None. + + mel_fmin (int, optional): + minimum filter frequency for computing melspectrograms. Defaults to None. + + mel_fmax (int, optional): + maximum filter frequency for computing melspectrograms. Defaults to None. + + n_mels (int, optional): + number of melspectrogram dimensions. Defaults to None. + + use_mel (bool, optional): + If True compute the melspectrograms otherwise. Defaults to False. + + do_amp_to_db_linear (bool, optional): + enable/disable amplitude to dB conversion of linear spectrograms. Defaults to False. + + spec_gain (float, optional): + gain applied when converting amplitude to DB. Defaults to 1.0. + + power (float, optional): + Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None. + + use_htk (bool, optional): + Use HTK formula in mel filter instead of Slaney. + + mel_norm (None, 'slaney', or number, optional): + If 'slaney', divide the triangular mel weights by the width of the mel band + (area normalization). + + If numeric, use `librosa.util.normalize` to normalize each filter by to unit l_p norm. + See `librosa.util.normalize` for a full description of supported norm values + (including `+-np.inf`). + + Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney". 
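+
+    Example:
+        An illustrative sketch; the argument values are arbitrary, not recommended defaults:
+
+        >>> import torch
+        >>> torch_stft = TorchSTFT(n_fft=1024, hop_length=256, win_length=1024, sample_rate=22050, n_mels=80, use_mel=True)
+        >>> mel = torch_stft(torch.randn(2, 22050))  # [B, n_mels, T]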
+ """ + + def __init__( + self, + n_fft, + hop_length, + win_length, + pad_wav=False, + window="hann_window", + sample_rate=None, + mel_fmin=0, + mel_fmax=None, + n_mels=80, + use_mel=False, + do_amp_to_db=False, + spec_gain=1.0, + power=None, + use_htk=False, + mel_norm="slaney", + normalized=False, + ): + super().__init__() + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.pad_wav = pad_wav + self.sample_rate = sample_rate + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.n_mels = n_mels + self.use_mel = use_mel + self.do_amp_to_db = do_amp_to_db + self.spec_gain = spec_gain + self.power = power + self.use_htk = use_htk + self.mel_norm = mel_norm + self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False) + self.mel_basis = None + self.normalized = normalized + if use_mel: + self._build_mel_basis() + + def __call__(self, x): + """Compute spectrogram frames by torch based stft. + + Args: + x (Tensor): input waveform + + Returns: + Tensor: spectrogram frames. + + Shapes: + x: [B x T] or [:math:`[B, 1, T]`] + """ + if x.ndim == 2: + x = x.unsqueeze(1) + if self.pad_wav: + padding = int((self.n_fft - self.hop_length) / 2) + x = torch.nn.functional.pad(x, (padding, padding), mode="reflect") + # B x D x T x 2 + o = torch.stft( + x.squeeze(1), + self.n_fft, + self.hop_length, + self.win_length, + self.window, + center=True, + pad_mode="reflect", # compatible with audio.py + normalized=self.normalized, + onesided=True, + return_complex=False, + ) + M = o[:, :, :, 0] + P = o[:, :, :, 1] + S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8)) + + if self.power is not None: + S = S**self.power + + if self.use_mel: + S = torch.matmul(self.mel_basis.to(x), S) + if self.do_amp_to_db: + S = self._amp_to_db(S, spec_gain=self.spec_gain) + return S + + def _build_mel_basis(self): + mel_basis = librosa.filters.mel( + sr=self.sample_rate, + n_fft=self.n_fft, + n_mels=self.n_mels, + fmin=self.mel_fmin, + fmax=self.mel_fmax, + htk=self.use_htk, + norm=self.mel_norm, + ) + self.mel_basis = torch.from_numpy(mel_basis).float() + + @staticmethod + def _amp_to_db(x, spec_gain=1.0): + return torch.log(torch.clamp(x, min=1e-5) * spec_gain) + + @staticmethod + def _db_to_amp(x, spec_gain=1.0): + return torch.exp(x) / spec_gain diff --git a/viXTTS/TTS/utils/callbacks.py b/viXTTS/TTS/utils/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..511d215c656f1ce3ed31484963db64fae4dc77d4 --- /dev/null +++ b/viXTTS/TTS/utils/callbacks.py @@ -0,0 +1,105 @@ +class TrainerCallback: + @staticmethod + def on_init_start(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_init_start"): + trainer.model.module.on_init_start(trainer) + else: + if hasattr(trainer.model, "on_init_start"): + trainer.model.on_init_start(trainer) + + if hasattr(trainer.criterion, "on_init_start"): + trainer.criterion.on_init_start(trainer) + + if hasattr(trainer.optimizer, "on_init_start"): + trainer.optimizer.on_init_start(trainer) + + @staticmethod + def on_init_end(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_init_end"): + trainer.model.module.on_init_end(trainer) + else: + if hasattr(trainer.model, "on_init_end"): + trainer.model.on_init_end(trainer) + + if hasattr(trainer.criterion, "on_init_end"): + trainer.criterion.on_init_end(trainer) + + if hasattr(trainer.optimizer, "on_init_end"): + trainer.optimizer.on_init_end(trainer) + + @staticmethod + def 
on_epoch_start(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_epoch_start"): + trainer.model.module.on_epoch_start(trainer) + else: + if hasattr(trainer.model, "on_epoch_start"): + trainer.model.on_epoch_start(trainer) + + if hasattr(trainer.criterion, "on_epoch_start"): + trainer.criterion.on_epoch_start(trainer) + + if hasattr(trainer.optimizer, "on_epoch_start"): + trainer.optimizer.on_epoch_start(trainer) + + @staticmethod + def on_epoch_end(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_epoch_end"): + trainer.model.module.on_epoch_end(trainer) + else: + if hasattr(trainer.model, "on_epoch_end"): + trainer.model.on_epoch_end(trainer) + + if hasattr(trainer.criterion, "on_epoch_end"): + trainer.criterion.on_epoch_end(trainer) + + if hasattr(trainer.optimizer, "on_epoch_end"): + trainer.optimizer.on_epoch_end(trainer) + + @staticmethod + def on_train_step_start(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_train_step_start"): + trainer.model.module.on_train_step_start(trainer) + else: + if hasattr(trainer.model, "on_train_step_start"): + trainer.model.on_train_step_start(trainer) + + if hasattr(trainer.criterion, "on_train_step_start"): + trainer.criterion.on_train_step_start(trainer) + + if hasattr(trainer.optimizer, "on_train_step_start"): + trainer.optimizer.on_train_step_start(trainer) + + @staticmethod + def on_train_step_end(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_train_step_end"): + trainer.model.module.on_train_step_end(trainer) + else: + if hasattr(trainer.model, "on_train_step_end"): + trainer.model.on_train_step_end(trainer) + + if hasattr(trainer.criterion, "on_train_step_end"): + trainer.criterion.on_train_step_end(trainer) + + if hasattr(trainer.optimizer, "on_train_step_end"): + trainer.optimizer.on_train_step_end(trainer) + + @staticmethod + def on_keyboard_interrupt(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_keyboard_interrupt"): + trainer.model.module.on_keyboard_interrupt(trainer) + else: + if hasattr(trainer.model, "on_keyboard_interrupt"): + trainer.model.on_keyboard_interrupt(trainer) + + if hasattr(trainer.criterion, "on_keyboard_interrupt"): + trainer.criterion.on_keyboard_interrupt(trainer) + + if hasattr(trainer.optimizer, "on_keyboard_interrupt"): + trainer.optimizer.on_keyboard_interrupt(trainer) diff --git a/viXTTS/TTS/utils/capacitron_optimizer.py b/viXTTS/TTS/utils/capacitron_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..7206ffd508896cab96a22288f33a93e999c5f009 --- /dev/null +++ b/viXTTS/TTS/utils/capacitron_optimizer.py @@ -0,0 +1,67 @@ +from typing import Generator + +from trainer.trainer_utils import get_optimizer + + +class CapacitronOptimizer: + """Double optimizer class for the Capacitron model.""" + + def __init__(self, config: dict, model_params: Generator) -> None: + self.primary_params, self.secondary_params = self.split_model_parameters(model_params) + + optimizer_names = list(config.optimizer_params.keys()) + optimizer_parameters = list(config.optimizer_params.values()) + + self.primary_optimizer = get_optimizer( + optimizer_names[0], + optimizer_parameters[0], + config.lr, + parameters=self.primary_params, + ) + + self.secondary_optimizer = get_optimizer( + optimizer_names[1], + self.extract_optimizer_parameters(optimizer_parameters[1]), + 
optimizer_parameters[1]["lr"], + parameters=self.secondary_params, + ) + + self.param_groups = self.primary_optimizer.param_groups + + def first_step(self): + self.secondary_optimizer.step() + self.secondary_optimizer.zero_grad() + self.primary_optimizer.zero_grad() + + def step(self): + # Update param groups to display the correct learning rate + self.param_groups = self.primary_optimizer.param_groups + self.primary_optimizer.step() + + def zero_grad(self, set_to_none=False): + self.primary_optimizer.zero_grad(set_to_none) + self.secondary_optimizer.zero_grad(set_to_none) + + def load_state_dict(self, state_dict): + self.primary_optimizer.load_state_dict(state_dict[0]) + self.secondary_optimizer.load_state_dict(state_dict[1]) + + def state_dict(self): + return [self.primary_optimizer.state_dict(), self.secondary_optimizer.state_dict()] + + @staticmethod + def split_model_parameters(model_params: Generator) -> list: + primary_params = [] + secondary_params = [] + for name, param in model_params: + if param.requires_grad: + if name == "capacitron_vae_layer.beta": + secondary_params.append(param) + else: + primary_params.append(param) + return [iter(primary_params), iter(secondary_params)] + + @staticmethod + def extract_optimizer_parameters(params: dict) -> dict: + """Extract parameters that are not the learning rate""" + return {k: v for k, v in params.items() if k != "lr"} diff --git a/viXTTS/TTS/utils/distribute.py b/viXTTS/TTS/utils/distribute.py new file mode 100644 index 0000000000000000000000000000000000000000..a51ef7661ece97c87c165ad1aba4c9d9700379dc --- /dev/null +++ b/viXTTS/TTS/utils/distribute.py @@ -0,0 +1,20 @@ +# edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py +import torch +import torch.distributed as dist + + +def reduce_tensor(tensor, num_gpus): + rt = tensor.clone() + dist.all_reduce(rt, op=dist.reduce_op.SUM) + rt /= num_gpus + return rt + + +def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url): + assert torch.cuda.is_available(), "Distributed mode requires CUDA." + + # Set cuda device so everything is done on the right GPU. + torch.cuda.set_device(rank % torch.cuda.device_count()) + + # Initialize distributed communication + dist.init_process_group(dist_backend, init_method=dist_url, world_size=num_gpus, rank=rank, group_name=group_name) diff --git a/viXTTS/TTS/utils/download.py b/viXTTS/TTS/utils/download.py new file mode 100644 index 0000000000000000000000000000000000000000..3f06b578248441d9951bf0ee62e2764ffd9ff9d7 --- /dev/null +++ b/viXTTS/TTS/utils/download.py @@ -0,0 +1,206 @@ +# Adapted from https://github.com/pytorch/audio/ + +import hashlib +import logging +import os +import tarfile +import urllib +import urllib.request +import zipfile +from os.path import expanduser +from typing import Any, Iterable, List, Optional + +from torch.utils.model_zoo import tqdm + + +def stream_url( + url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True +) -> Iterable: + """Stream url by chunk + + Args: + url (str): Url. + start_byte (int or None, optional): Start streaming at that point (Default: ``None``). + block_size (int, optional): Size of chunks to stream (Default: ``32 * 1024``). + progress_bar (bool, optional): Display a progress bar (Default: ``True``). 
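+
+    Example:
+        An illustrative sketch; the URL and file name are placeholders:
+
+        >>> with open("archive.zip", "wb") as fout:
+        ...     for chunk in stream_url("https://example.com/archive.zip"):
+        ...         fout.write(chunk)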
+ """ + + # If we already have the whole file, there is no need to download it again + req = urllib.request.Request(url, method="HEAD") + with urllib.request.urlopen(req) as response: + url_size = int(response.info().get("Content-Length", -1)) + if url_size == start_byte: + return + + req = urllib.request.Request(url) + if start_byte: + req.headers["Range"] = "bytes={}-".format(start_byte) + + with urllib.request.urlopen(req) as upointer, tqdm( + unit="B", + unit_scale=True, + unit_divisor=1024, + total=url_size, + disable=not progress_bar, + ) as pbar: + num_bytes = 0 + while True: + chunk = upointer.read(block_size) + if not chunk: + break + yield chunk + num_bytes += len(chunk) + pbar.update(len(chunk)) + + +def download_url( + url: str, + download_folder: str, + filename: Optional[str] = None, + hash_value: Optional[str] = None, + hash_type: str = "sha256", + progress_bar: bool = True, + resume: bool = False, +) -> None: + """Download file to disk. + + Args: + url (str): Url. + download_folder (str): Folder to download file. + filename (str or None, optional): Name of downloaded file. If None, it is inferred from the url + (Default: ``None``). + hash_value (str or None, optional): Hash for url (Default: ``None``). + hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``). + progress_bar (bool, optional): Display a progress bar (Default: ``True``). + resume (bool, optional): Enable resuming download (Default: ``False``). + """ + + req = urllib.request.Request(url, method="HEAD") + req_info = urllib.request.urlopen(req).info() # pylint: disable=consider-using-with + + # Detect filename + filename = filename or req_info.get_filename() or os.path.basename(url) + filepath = os.path.join(download_folder, filename) + if resume and os.path.exists(filepath): + mode = "ab" + local_size: Optional[int] = os.path.getsize(filepath) + + elif not resume and os.path.exists(filepath): + raise RuntimeError("{} already exists. Delete the file manually and retry.".format(filepath)) + else: + mode = "wb" + local_size = None + + if hash_value and local_size == int(req_info.get("Content-Length", -1)): + with open(filepath, "rb") as file_obj: + if validate_file(file_obj, hash_value, hash_type): + return + raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath)) + + with open(filepath, mode) as fpointer: + for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar): + fpointer.write(chunk) + + with open(filepath, "rb") as file_obj: + if hash_value and not validate_file(file_obj, hash_value, hash_type): + raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath)) + + +def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> bool: + """Validate a given file object with its hash. + + Args: + file_obj: File object to read from. + hash_value (str): Hash for url. + hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``). + + Returns: + bool: return True if its a valid file, else False. 
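+
+    Example:
+        An illustrative sketch; the file name and hash value are placeholders:
+
+        >>> with open("model.pth", "rb") as fin:
+        ...     ok = validate_file(fin, hash_value="d41d8cd98f00b204e9800998ecf8427e", hash_type="md5")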
+ """ + + if hash_type == "sha256": + hash_func = hashlib.sha256() + elif hash_type == "md5": + hash_func = hashlib.md5() + else: + raise ValueError + + while True: + # Read by chunk to avoid filling memory + chunk = file_obj.read(1024**2) + if not chunk: + break + hash_func.update(chunk) + + return hash_func.hexdigest() == hash_value + + +def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]: + """Extract archive. + Args: + from_path (str): the path of the archive. + to_path (str or None, optional): the root path of the extraced files (directory of from_path) + (Default: ``None``) + overwrite (bool, optional): overwrite existing files (Default: ``False``) + + Returns: + list: List of paths to extracted files even if not overwritten. + """ + + if to_path is None: + to_path = os.path.dirname(from_path) + + try: + with tarfile.open(from_path, "r") as tar: + logging.info("Opened tar file %s.", from_path) + files = [] + for file_ in tar: # type: Any + file_path = os.path.join(to_path, file_.name) + if file_.isfile(): + files.append(file_path) + if os.path.exists(file_path): + logging.info("%s already extracted.", file_path) + if not overwrite: + continue + tar.extract(file_, to_path) + return files + except tarfile.ReadError: + pass + + try: + with zipfile.ZipFile(from_path, "r") as zfile: + logging.info("Opened zip file %s.", from_path) + files = zfile.namelist() + for file_ in files: + file_path = os.path.join(to_path, file_) + if os.path.exists(file_path): + logging.info("%s already extracted.", file_path) + if not overwrite: + continue + zfile.extract(file_, to_path) + return files + except zipfile.BadZipFile: + pass + + raise NotImplementedError(" > [!] only supports tar.gz, tgz, and zip achives.") + + +def download_kaggle_dataset(dataset_path: str, dataset_name: str, output_path: str): + """Download dataset from kaggle. + Args: + dataset_path (str): + This the kaggle link to the dataset. for example vctk is 'mfekadu/english-multispeaker-corpus-for-voice-cloning' + dataset_name (str): Name of the folder the dataset will be saved in. + output_path (str): Path of the location you want the dataset folder to be saved to. + """ + data_path = os.path.join(output_path, dataset_name) + try: + import kaggle # pylint: disable=import-outside-toplevel + + kaggle.api.authenticate() + print(f"""\nDownloading {dataset_name}...""") + kaggle.api.dataset_download_files(dataset_path, path=data_path, unzip=True) + except OSError: + print( + f"""[!] in order to download kaggle datasets, you need to have a kaggle api token stored in your {os.path.join(expanduser('~'), '.kaggle/kaggle.json')}""" + ) diff --git a/viXTTS/TTS/utils/downloaders.py b/viXTTS/TTS/utils/downloaders.py new file mode 100644 index 0000000000000000000000000000000000000000..104dc7b94e17b1d7f828103d2396d6c5115b628a --- /dev/null +++ b/viXTTS/TTS/utils/downloaders.py @@ -0,0 +1,126 @@ +import os +from typing import Optional + +from TTS.utils.download import download_kaggle_dataset, download_url, extract_archive + + +def download_ljspeech(path: str): + """Download and extract LJSpeech dataset + + Args: + path (str): path to the directory where the dataset will be stored. 
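+
+    Example:
+        An illustrative call; the target directory is arbitrary:
+
+        >>> download_ljspeech("/data/datasets/LJSpeech")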
+ """ + os.makedirs(path, exist_ok=True) + url = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2" + download_url(url, path) + basename = os.path.basename(url) + archive = os.path.join(path, basename) + print(" > Extracting archive file...") + extract_archive(archive) + + +def download_vctk(path: str, use_kaggle: Optional[bool] = False): + """Download and extract VCTK dataset. + + Args: + path (str): path to the directory where the dataset will be stored. + + use_kaggle (bool, optional): Downloads vctk dataset from kaggle. Is generally faster. Defaults to False. + """ + if use_kaggle: + download_kaggle_dataset("mfekadu/english-multispeaker-corpus-for-voice-cloning", "VCTK", path) + else: + os.makedirs(path, exist_ok=True) + url = "https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip" + download_url(url, path) + basename = os.path.basename(url) + archive = os.path.join(path, basename) + print(" > Extracting archive file...") + extract_archive(archive) + + +def download_tweb(path: str): + """Download and extract Tweb dataset + + Args: + path (str): Path to the directory where the dataset will be stored. + """ + download_kaggle_dataset("bryanpark/the-world-english-bible-speech-dataset", "TWEB", path) + + +def download_libri_tts(path: str, subset: Optional[str] = "all"): + """Download and extract libri tts dataset. + + Args: + path (str): Path to the directory where the dataset will be stored. + + subset (str, optional): Name of the subset to download. If you only want to download a certain + portion specify it here. Defaults to 'all'. + """ + + subset_dict = { + "libri-tts-clean-100": "http://www.openslr.org/resources/60/train-clean-100.tar.gz", + "libri-tts-clean-360": "http://www.openslr.org/resources/60/train-clean-360.tar.gz", + "libri-tts-other-500": "http://www.openslr.org/resources/60/train-other-500.tar.gz", + "libri-tts-dev-clean": "http://www.openslr.org/resources/60/dev-clean.tar.gz", + "libri-tts-dev-other": "http://www.openslr.org/resources/60/dev-other.tar.gz", + "libri-tts-test-clean": "http://www.openslr.org/resources/60/test-clean.tar.gz", + "libri-tts-test-other": "http://www.openslr.org/resources/60/test-other.tar.gz", + } + + os.makedirs(path, exist_ok=True) + if subset == "all": + for sub, val in subset_dict.items(): + print(f" > Downloading {sub}...") + download_url(val, path) + basename = os.path.basename(val) + archive = os.path.join(path, basename) + print(" > Extracting archive file...") + extract_archive(archive) + print(" > All subsets downloaded") + else: + url = subset_dict[subset] + download_url(url, path) + basename = os.path.basename(url) + archive = os.path.join(path, basename) + print(" > Extracting archive file...") + extract_archive(archive) + + +def download_thorsten_de(path: str): + """Download and extract Thorsten german male voice dataset. + + Args: + path (str): Path to the directory where the dataset will be stored. + """ + os.makedirs(path, exist_ok=True) + url = "https://www.openslr.org/resources/95/thorsten-de_v02.tgz" + download_url(url, path) + basename = os.path.basename(url) + archive = os.path.join(path, basename) + print(" > Extracting archive file...") + extract_archive(archive) + + +def download_mailabs(path: str, language: str = "english"): + """Download and extract Mailabs dataset. + + Args: + path (str): Path to the directory where the dataset will be stored. + + language (str): Language subset to download. Defaults to english. 
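+
+    Example:
+        An illustrative call; the target directory is arbitrary:
+
+        >>> download_mailabs("/data/datasets/mailabs", language="german")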
+ """ + language_dict = { + "english": "https://data.solak.de/data/Training/stt_tts/en_US.tgz", + "german": "https://data.solak.de/data/Training/stt_tts/de_DE.tgz", + "french": "https://data.solak.de/data/Training/stt_tts/fr_FR.tgz", + "italian": "https://data.solak.de/data/Training/stt_tts/it_IT.tgz", + "spanish": "https://data.solak.de/data/Training/stt_tts/es_ES.tgz", + } + os.makedirs(path, exist_ok=True) + url = language_dict[language] + download_url(url, path) + basename = os.path.basename(url) + archive = os.path.join(path, basename) + print(" > Extracting archive file...") + extract_archive(archive) diff --git a/viXTTS/TTS/utils/generic_utils.py b/viXTTS/TTS/utils/generic_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4fa4741ab79686123605fa84f1f9d3863b340e44 --- /dev/null +++ b/viXTTS/TTS/utils/generic_utils.py @@ -0,0 +1,239 @@ +# -*- coding: utf-8 -*- +import datetime +import importlib +import logging +import os +import re +import subprocess +import sys +from pathlib import Path +from typing import Dict + +import fsspec +import torch + + +def to_cuda(x: torch.Tensor) -> torch.Tensor: + if x is None: + return None + if torch.is_tensor(x): + x = x.contiguous() + if torch.cuda.is_available(): + x = x.cuda(non_blocking=True) + return x + + +def get_cuda(): + use_cuda = torch.cuda.is_available() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + return use_cuda, device + + +def get_git_branch(): + try: + out = subprocess.check_output(["git", "branch"]).decode("utf8") + current = next(line for line in out.split("\n") if line.startswith("*")) + current.replace("* ", "") + except subprocess.CalledProcessError: + current = "inside_docker" + except (FileNotFoundError, StopIteration) as e: + current = "unknown" + return current + + +def get_commit_hash(): + """https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script""" + # try: + # subprocess.check_output(['git', 'diff-index', '--quiet', + # 'HEAD']) # Verify client is clean + # except: + # raise RuntimeError( + # " !! Commit before training to get the commit hash.") + try: + commit = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode().strip() + # Not copying .git folder into docker container + except (subprocess.CalledProcessError, FileNotFoundError): + commit = "0000000" + return commit + + +def get_experiment_folder_path(root_path, model_name): + """Get an experiment folder path with the current date and time""" + date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p") + commit_hash = get_commit_hash() + output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash) + return output_folder + + +def remove_experiment_folder(experiment_path): + """Check folder if there is a checkpoint, otherwise remove the folder""" + fs = fsspec.get_mapper(experiment_path).fs + checkpoint_files = fs.glob(experiment_path + "/*.pth") + if not checkpoint_files: + if fs.exists(experiment_path): + fs.rm(experiment_path, recursive=True) + print(" ! Run is removed from {}".format(experiment_path)) + else: + print(" ! 
Run is kept in {}".format(experiment_path)) + + +def count_parameters(model): + r"""Count number of trainable parameters in a network""" + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def to_camel(text): + text = text.capitalize() + text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) + text = text.replace("Tts", "TTS") + text = text.replace("vc", "VC") + return text + + +def find_module(module_path: str, module_name: str) -> object: + module_name = module_name.lower() + module = importlib.import_module(module_path + "." + module_name) + class_name = to_camel(module_name) + return getattr(module, class_name) + + +def import_class(module_path: str) -> object: + """Import a class from a module path. + + Args: + module_path (str): The module path of the class. + + Returns: + object: The imported class. + """ + class_name = module_path.split(".")[-1] + module_path = ".".join(module_path.split(".")[:-1]) + module = importlib.import_module(module_path) + return getattr(module, class_name) + + +def get_import_path(obj: object) -> str: + """Get the import path of a class. + + Args: + obj (object): The class object. + + Returns: + str: The import path of the class. + """ + return ".".join([type(obj).__module__, type(obj).__name__]) + + +def get_user_data_dir(appname): + TTS_HOME = os.environ.get("TTS_HOME") + XDG_DATA_HOME = os.environ.get("XDG_DATA_HOME") + if TTS_HOME is not None: + ans = Path(TTS_HOME).expanduser().resolve(strict=False) + elif XDG_DATA_HOME is not None: + ans = Path(XDG_DATA_HOME).expanduser().resolve(strict=False) + elif sys.platform == "win32": + import winreg # pylint: disable=import-outside-toplevel + + key = winreg.OpenKey( + winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders" + ) + dir_, _ = winreg.QueryValueEx(key, "Local AppData") + ans = Path(dir_).resolve(strict=False) + elif sys.platform == "darwin": + ans = Path("~/Library/Application Support/").expanduser() + else: + ans = Path.home().joinpath(".local/share") + return ans.joinpath(appname) + + +def set_init_dict(model_dict, checkpoint_state, c): + # Partial initialization: if there is a mismatch with new and old layer, it is skipped. + for k, v in checkpoint_state.items(): + if k not in model_dict: + print(" | > Layer missing in the model definition: {}".format(k)) + # 1. filter out unnecessary keys + pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict} + # 2. filter out different size layers + pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()} + # 3. skip reinit layers + if c.has("reinit_layers") and c.reinit_layers is not None: + for reinit_layer_name in c.reinit_layers: + pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} + # 4. overwrite entries in the existing state dict + model_dict.update(pretrained_dict) + print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict))) + return model_dict + + +def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict: + """Format kwargs to hande auxilary inputs to models. + + Args: + def_args (Dict): A dictionary of argument names and their default values if not defined in `kwargs`. + kwargs (Dict): A `dict` or `kwargs` that includes auxilary inputs to the model. + + Returns: + Dict: arguments with formatted auxilary inputs. 
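Two hedged examples for the naming and kwargs helpers above (values are illustrative only):

from TTS.utils.generic_utils import to_camel, format_aux_input

to_camel("glow_tts")     # -> "GlowTTS": snake_case to CamelCase, then "Tts" rewritten to "TTS"
to_camel("forward_tts")  # -> "ForwardTTS"
format_aux_input({"d_vectors": None, "speaker_ids": None}, {"d_vectors": 1})
# -> {"d_vectors": 1, "speaker_ids": None}: missing or None kwargs fall back to the defaults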
+ """ + kwargs = kwargs.copy() + for name in def_args: + if name not in kwargs or kwargs[name] is None: + kwargs[name] = def_args[name] + return kwargs + + +class KeepAverage: + def __init__(self): + self.avg_values = {} + self.iters = {} + + def __getitem__(self, key): + return self.avg_values[key] + + def items(self): + return self.avg_values.items() + + def add_value(self, name, init_val=0, init_iter=0): + self.avg_values[name] = init_val + self.iters[name] = init_iter + + def update_value(self, name, value, weighted_avg=False): + if name not in self.avg_values: + # add value if not exist before + self.add_value(name, init_val=value) + else: + # else update existing value + if weighted_avg: + self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value + self.iters[name] += 1 + else: + self.avg_values[name] = self.avg_values[name] * self.iters[name] + value + self.iters[name] += 1 + self.avg_values[name] /= self.iters[name] + + def add_values(self, name_dict): + for key, value in name_dict.items(): + self.add_value(key, init_val=value) + + def update_values(self, value_dict): + for key, value in value_dict.items(): + self.update_value(key, value) + + +def get_timestamp(): + return datetime.now().strftime("%y%m%d-%H%M%S") + + +def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False): + lg = logging.getLogger(logger_name) + formatter = logging.Formatter("%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S") + lg.setLevel(level) + if tofile: + log_file = os.path.join(root, phase + "_{}.log".format(get_timestamp())) + fh = logging.FileHandler(log_file, mode="w") + fh.setFormatter(formatter) + lg.addHandler(fh) + if screen: + sh = logging.StreamHandler() + sh.setFormatter(formatter) + lg.addHandler(sh) diff --git a/viXTTS/TTS/utils/io.py b/viXTTS/TTS/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..3107ba661babc57bb67f6b06c84a00f5a468b132 --- /dev/null +++ b/viXTTS/TTS/utils/io.py @@ -0,0 +1,70 @@ +import os +import pickle as pickle_tts +from typing import Any, Callable, Dict, Union + +import fsspec +import torch + +from TTS.utils.generic_utils import get_user_data_dir + + +class RenamingUnpickler(pickle_tts.Unpickler): + """Overload default pickler to solve module renaming problem""" + + def find_class(self, module, name): + return super().find_class(module.replace("mozilla_voice_tts", "TTS"), name) + + +class AttrDict(dict): + """A custom dict which converts dict keys + to class attributes""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__dict__ = self + + +def load_fsspec( + path: str, + map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None, + cache: bool = True, + **kwargs, +) -> Any: + """Like torch.load but can load from other locations (e.g. s3:// , gs://). + + Args: + path: Any path or url supported by fsspec. + map_location: torch.device or str. + cache: If True, cache a remote file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to True. + **kwargs: Keyword arguments forwarded to torch.load. + + Returns: + Object stored in path. 
+ """ + is_local = os.path.isdir(path) or os.path.isfile(path) + if cache and not is_local: + with fsspec.open( + f"filecache::{path}", + filecache={"cache_storage": str(get_user_data_dir("tts_cache"))}, + mode="rb", + ) as f: + return torch.load(f, map_location=map_location, **kwargs) + else: + with fsspec.open(path, "rb") as f: + return torch.load(f, map_location=map_location, **kwargs) + + +def load_checkpoint( + model, checkpoint_path, use_cuda=False, eval=False, cache=False +): # pylint: disable=redefined-builtin + try: + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + except ModuleNotFoundError: + pickle_tts.Unpickler = RenamingUnpickler + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts, cache=cache) + model.load_state_dict(state["model"]) + if use_cuda: + model.cuda() + if eval: + model.eval() + return model, state diff --git a/viXTTS/TTS/utils/manage.py b/viXTTS/TTS/utils/manage.py new file mode 100644 index 0000000000000000000000000000000000000000..3a527f46099ce79cc08a70ccb1aa22aaf0f7c429 --- /dev/null +++ b/viXTTS/TTS/utils/manage.py @@ -0,0 +1,621 @@ +import json +import os +import re +import tarfile +import zipfile +from pathlib import Path +from shutil import copyfile, rmtree +from typing import Dict, List, Tuple + +import fsspec +import requests +from tqdm import tqdm + +from TTS.config import load_config, read_json_with_comments +from TTS.utils.generic_utils import get_user_data_dir + +LICENSE_URLS = { + "cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/", + "mpl": "https://www.mozilla.org/en-US/MPL/2.0/", + "mpl2": "https://www.mozilla.org/en-US/MPL/2.0/", + "mpl 2.0": "https://www.mozilla.org/en-US/MPL/2.0/", + "mit": "https://choosealicense.com/licenses/mit/", + "apache 2.0": "https://choosealicense.com/licenses/apache-2.0/", + "apache2": "https://choosealicense.com/licenses/apache-2.0/", + "cc-by-sa 4.0": "https://creativecommons.org/licenses/by-sa/4.0/", + "cpml": "https://coqui.ai/cpml.txt", +} + + +class ModelManager(object): + tqdm_progress = None + """Manage TTS models defined in .models.json. + It provides an interface to list and download + models defines in '.model.json' + + Models are downloaded under '.TTS' folder in the user's + home path. + + Args: + models_file (str): path to .model.json file. Defaults to None. + output_prefix (str): prefix to `tts` to download models. Defaults to None + progress_bar (bool): print a progress bar when donwloading a file. Defaults to False. + verbose (bool): print info. Defaults to True. + """ + + def __init__(self, models_file=None, output_prefix=None, progress_bar=False, verbose=True): + super().__init__() + self.progress_bar = progress_bar + self.verbose = verbose + if output_prefix is None: + self.output_prefix = get_user_data_dir("tts") + else: + self.output_prefix = os.path.join(output_prefix, "tts") + self.models_dict = None + if models_file is not None: + self.read_models_file(models_file) + else: + # try the default location + path = Path(__file__).parent / "../.models.json" + self.read_models_file(path) + + def read_models_file(self, file_path): + """Read .models.json as a dict + + Args: + file_path (str): path to .models.json. 
+ """ + self.models_dict = read_json_with_comments(file_path) + + def _list_models(self, model_type, model_count=0): + if self.verbose: + print("\n Name format: type/language/dataset/model") + model_list = [] + for lang in self.models_dict[model_type]: + for dataset in self.models_dict[model_type][lang]: + for model in self.models_dict[model_type][lang][dataset]: + model_full_name = f"{model_type}--{lang}--{dataset}--{model}" + output_path = os.path.join(self.output_prefix, model_full_name) + if self.verbose: + if os.path.exists(output_path): + print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]") + else: + print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}") + model_list.append(f"{model_type}/{lang}/{dataset}/{model}") + model_count += 1 + return model_list + + def _list_for_model_type(self, model_type): + models_name_list = [] + model_count = 1 + models_name_list.extend(self._list_models(model_type, model_count)) + return models_name_list + + def list_models(self): + models_name_list = [] + model_count = 1 + for model_type in self.models_dict: + model_list = self._list_models(model_type, model_count) + models_name_list.extend(model_list) + return models_name_list + + def model_info_by_idx(self, model_query): + """Print the description of the model from .models.json file using model_idx + + Args: + model_query (str): / + """ + model_name_list = [] + model_type, model_query_idx = model_query.split("/") + try: + model_query_idx = int(model_query_idx) + if model_query_idx <= 0: + print("> model_query_idx should be a positive integer!") + return + except: + print("> model_query_idx should be an integer!") + return + model_count = 0 + if model_type in self.models_dict: + for lang in self.models_dict[model_type]: + for dataset in self.models_dict[model_type][lang]: + for model in self.models_dict[model_type][lang][dataset]: + model_name_list.append(f"{model_type}/{lang}/{dataset}/{model}") + model_count += 1 + else: + print(f"> model_type {model_type} does not exist in the list.") + return + if model_query_idx > model_count: + print(f"model query idx exceeds the number of available models [{model_count}] ") + else: + model_type, lang, dataset, model = model_name_list[model_query_idx - 1].split("/") + print(f"> model type : {model_type}") + print(f"> language supported : {lang}") + print(f"> dataset used : {dataset}") + print(f"> model name : {model}") + if "description" in self.models_dict[model_type][lang][dataset][model]: + print(f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}") + else: + print("> description : coming soon") + if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]: + print(f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}") + + def model_info_by_full_name(self, model_query_name): + """Print the description of the model from .models.json file using model_full_name + + Args: + model_query_name (str): Format is /// + """ + model_type, lang, dataset, model = model_query_name.split("/") + if model_type in self.models_dict: + if lang in self.models_dict[model_type]: + if dataset in self.models_dict[model_type][lang]: + if model in self.models_dict[model_type][lang][dataset]: + print(f"> model type : {model_type}") + print(f"> language supported : {lang}") + print(f"> dataset used : {dataset}") + print(f"> model name : {model}") + if "description" in self.models_dict[model_type][lang][dataset][model]: + print( + f"> description : 
{self.models_dict[model_type][lang][dataset][model]['description']}" + ) + else: + print("> description : coming soon") + if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]: + print( + f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}" + ) + else: + print(f"> model {model} does not exist for {model_type}/{lang}/{dataset}.") + else: + print(f"> dataset {dataset} does not exist for {model_type}/{lang}.") + else: + print(f"> lang {lang} does not exist for {model_type}.") + else: + print(f"> model_type {model_type} does not exist in the list.") + + def list_tts_models(self): + """Print all `TTS` models and return a list of model names + + Format is `language/dataset/model` + """ + return self._list_for_model_type("tts_models") + + def list_vocoder_models(self): + """Print all the `vocoder` models and return a list of model names + + Format is `language/dataset/model` + """ + return self._list_for_model_type("vocoder_models") + + def list_vc_models(self): + """Print all the voice conversion models and return a list of model names + + Format is `language/dataset/model` + """ + return self._list_for_model_type("voice_conversion_models") + + def list_langs(self): + """Print all the available languages""" + print(" Name format: type/language") + for model_type in self.models_dict: + for lang in self.models_dict[model_type]: + print(f" >: {model_type}/{lang} ") + + def list_datasets(self): + """Print all the datasets""" + print(" Name format: type/language/dataset") + for model_type in self.models_dict: + for lang in self.models_dict[model_type]: + for dataset in self.models_dict[model_type][lang]: + print(f" >: {model_type}/{lang}/{dataset}") + + @staticmethod + def print_model_license(model_item: Dict): + """Print the license of a model + + Args: + model_item (dict): model item in the models.json + """ + if "license" in model_item and model_item["license"].strip() != "": + print(f" > Model's license - {model_item['license']}") + if model_item["license"].lower() in LICENSE_URLS: + print(f" > Check {LICENSE_URLS[model_item['license'].lower()]} for more info.") + else: + print(" > Check https://opensource.org/licenses for more info.") + else: + print(" > Model's license - No license information available") + + def _download_github_model(self, model_item: Dict, output_path: str): + if isinstance(model_item["github_rls_url"], list): + self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar) + else: + self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar) + + def _download_hf_model(self, model_item: Dict, output_path: str): + if isinstance(model_item["hf_url"], list): + self._download_model_files(model_item["hf_url"], output_path, self.progress_bar) + else: + self._download_zip_file(model_item["hf_url"], output_path, self.progress_bar) + + def download_fairseq_model(self, model_name, output_path): + URI_PREFIX = "https://coqui.gateway.scarf.sh/fairseq/" + _, lang, _, _ = model_name.split("/") + model_download_uri = os.path.join(URI_PREFIX, f"{lang}.tar.gz") + self._download_tar_file(model_download_uri, output_path, self.progress_bar) + + @staticmethod + def set_model_url(model_item: Dict): + model_item["model_url"] = None + if "github_rls_url" in model_item: + model_item["model_url"] = model_item["github_rls_url"] + elif "hf_url" in model_item: + model_item["model_url"] = model_item["hf_url"] + elif "fairseq" in model_item["model_name"]: + model_item["model_url"] = 
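A browsing sketch for ModelManager; by default it reads the bundled TTS/.models.json and prints entries in type/language/dataset/model form:

from TTS.utils.manage import ModelManager

manager = ModelManager()
manager.list_models()                      # prints every known model, marking already downloaded ones
manager.model_info_by_idx("tts_models/1")  # description/license details of the first tts_models entry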
"https://coqui.gateway.scarf.sh/fairseq/" + elif "xtts" in model_item["model_name"]: + model_item["model_url"] = "https://coqui.gateway.scarf.sh/xtts/" + return model_item + + def _set_model_item(self, model_name): + # fetch model info from the dict + if "fairseq" in model_name: + model_type = "tts_models" + lang = model_name.split("/")[1] + model_item = { + "model_type": "tts_models", + "license": "CC BY-NC 4.0", + "default_vocoder": None, + "author": "fairseq", + "description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.", + } + model_item["model_name"] = model_name + elif "xtts" in model_name and len(model_name.split("/")) != 4: + # loading xtts models with only model name (e.g. xtts_v2.0.2) + # check model name has the version number with regex + version_regex = r"v\d+\.\d+\.\d+" + if re.search(version_regex, model_name): + model_version = model_name.split("_")[-1] + else: + model_version = "main" + model_type = "tts_models" + lang = "multilingual" + dataset = "multi-dataset" + model = model_name + model_item = { + "default_vocoder": None, + "license": "CPML", + "contact": "info@coqui.ai", + "tos_required": True, + "hf_url": [ + f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/model.pth", + f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/config.json", + f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/vocab.json", + f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/hash.md5", + f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/speakers_xtts.pth", + ], + } + else: + # get model from models.json + model_type, lang, dataset, model = model_name.split("/") + model_item = self.models_dict[model_type][lang][dataset][model] + model_item["model_type"] = model_type + + model_full_name = f"{model_type}--{lang}--{dataset}--{model}" + md5hash = model_item["model_hash"] if "model_hash" in model_item else None + model_item = self.set_model_url(model_item) + return model_item, model_full_name, model, md5hash + + @staticmethod + def ask_tos(model_full_path): + """Ask the user to agree to the terms of service""" + tos_path = os.path.join(model_full_path, "tos_agreed.txt") + print(" > You must confirm the following:") + print(' | > "I have purchased a commercial license from Coqui: licensing@coqui.ai"') + print(' | > "Otherwise, I agree to the terms of the non-commercial CPML: https://coqui.ai/cpml" - [y/n]') + answer = input(" | | > ") + if answer.lower() == "y": + with open(tos_path, "w", encoding="utf-8") as f: + f.write("I have read, understood and agreed to the Terms and Conditions.") + return True + return False + + @staticmethod + def tos_agreed(model_item, model_full_path): + """Check if the user has agreed to the terms of service""" + if "tos_required" in model_item and model_item["tos_required"]: + tos_path = os.path.join(model_full_path, "tos_agreed.txt") + if os.path.exists(tos_path) or os.environ.get("COQUI_TOS_AGREED") == "1": + return True + return False + return True + + def create_dir_and_download_model(self, model_name, model_item, output_path): + os.makedirs(output_path, exist_ok=True) + # handle TOS + if not self.tos_agreed(model_item, output_path): + if not self.ask_tos(output_path): + os.rmdir(output_path) + raise Exception(" [!] 
You must agree to the terms of service to use this model.") + print(f" > Downloading model to {output_path}") + try: + if "fairseq" in model_name: + self.download_fairseq_model(model_name, output_path) + elif "github_rls_url" in model_item: + self._download_github_model(model_item, output_path) + elif "hf_url" in model_item: + self._download_hf_model(model_item, output_path) + + except requests.RequestException as e: + print(f" > Failed to download the model file to {output_path}") + rmtree(output_path) + raise e + self.print_model_license(model_item=model_item) + + def check_if_configs_are_equal(self, model_name, model_item, output_path): + with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f: + config_local = json.load(f) + remote_url = None + for url in model_item["hf_url"]: + if "config.json" in url: + remote_url = url + break + + with fsspec.open(remote_url, "r", encoding="utf-8") as f: + config_remote = json.load(f) + + if not config_local == config_remote: + print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...") + self.create_dir_and_download_model(model_name, model_item, output_path) + + def download_model(self, model_name): + """Download model files given the full model name. + Model name is in the format + 'type/language/dataset/model' + e.g. 'tts_model/en/ljspeech/tacotron' + + Every model must have the following files: + - *.pth : pytorch model checkpoint file. + - config.json : model config file. + - scale_stats.npy (if exist): scale values for preprocessing. + + Args: + model_name (str): model name as explained above. + """ + model_item, model_full_name, model, md5sum = self._set_model_item(model_name) + # set the model specific output path + output_path = os.path.join(self.output_prefix, model_full_name) + if os.path.exists(output_path): + if md5sum is not None: + md5sum_file = os.path.join(output_path, "hash.md5") + if os.path.isfile(md5sum_file): + with open(md5sum_file, mode="r") as f: + if not f.read() == md5sum: + print(f" > {model_name} has been updated, clearing model cache...") + self.create_dir_and_download_model(model_name, model_item, output_path) + else: + print(f" > {model_name} is already downloaded.") + else: + print(f" > {model_name} has been updated, clearing model cache...") + self.create_dir_and_download_model(model_name, model_item, output_path) + # if the configs are different, redownload it + # ToDo: we need a better way to handle it + if "xtts" in model_name: + try: + self.check_if_configs_are_equal(model_name, model_item, output_path) + except: + pass + else: + print(f" > {model_name} is already downloaded.") + else: + self.create_dir_and_download_model(model_name, model_item, output_path) + + # find downloaded files + output_model_path = output_path + output_config_path = None + if ( + model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name + ): # TODO:This is stupid but don't care for now. 
+ output_model_path, output_config_path = self._find_files(output_path) + # update paths in the config.json + self._update_paths(output_path, output_config_path) + return output_model_path, output_config_path, model_item + + @staticmethod + def _find_files(output_path: str) -> Tuple[str, str]: + """Find the model and config files in the output path + + Args: + output_path (str): path to the model files + + Returns: + Tuple[str, str]: path to the model file and config file + """ + model_file = None + config_file = None + for file_name in os.listdir(output_path): + if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]: + model_file = os.path.join(output_path, file_name) + elif file_name == "config.json": + config_file = os.path.join(output_path, file_name) + if model_file is None: + raise ValueError(" [!] Model file not found in the output path") + if config_file is None: + raise ValueError(" [!] Config file not found in the output path") + return model_file, config_file + + @staticmethod + def _find_speaker_encoder(output_path: str) -> str: + """Find the speaker encoder file in the output path + + Args: + output_path (str): path to the model files + + Returns: + str: path to the speaker encoder file + """ + speaker_encoder_file = None + for file_name in os.listdir(output_path): + if file_name in ["model_se.pth", "model_se.pth.tar"]: + speaker_encoder_file = os.path.join(output_path, file_name) + return speaker_encoder_file + + def _update_paths(self, output_path: str, config_path: str) -> None: + """Update paths for certain files in config.json after download. + + Args: + output_path (str): local path the model is downloaded to. + config_path (str): local config.json path. + """ + output_stats_path = os.path.join(output_path, "scale_stats.npy") + output_d_vector_file_path = os.path.join(output_path, "speakers.json") + output_d_vector_file_pth_path = os.path.join(output_path, "speakers.pth") + output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json") + output_speaker_ids_file_pth_path = os.path.join(output_path, "speaker_ids.pth") + speaker_encoder_config_path = os.path.join(output_path, "config_se.json") + speaker_encoder_model_path = self._find_speaker_encoder(output_path) + + # update the scale_path.npy file path in the model config.json + self._update_path("audio.stats_path", output_stats_path, config_path) + + # update the speakers.json file path in the model config.json to the current path + self._update_path("d_vector_file", output_d_vector_file_path, config_path) + self._update_path("d_vector_file", output_d_vector_file_pth_path, config_path) + self._update_path("model_args.d_vector_file", output_d_vector_file_path, config_path) + self._update_path("model_args.d_vector_file", output_d_vector_file_pth_path, config_path) + + # update the speaker_ids.json file path in the model config.json to the current path + self._update_path("speakers_file", output_speaker_ids_file_path, config_path) + self._update_path("speakers_file", output_speaker_ids_file_pth_path, config_path) + self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path) + self._update_path("model_args.speakers_file", output_speaker_ids_file_pth_path, config_path) + + # update the speaker_encoder file path in the model config.json to the current path + self._update_path("speaker_encoder_model_path", speaker_encoder_model_path, config_path) + self._update_path("model_args.speaker_encoder_model_path", speaker_encoder_model_path, config_path) + 
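A download sketch using ModelManager.download_model(); the XTTS entries are gated behind the CPML terms, so either answer the interactive prompt or export COQUI_TOS_AGREED=1 beforehand:

from TTS.utils.manage import ModelManager

manager = ModelManager(progress_bar=True)
model_path, config_path, model_item = manager.download_model(
    "tts_models/multilingual/multi-dataset/xtts_v2"
)
# For XTTS, model_path is the download directory and config_path is None,
# since _find_files() is skipped for this model family (see the condition above).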
self._update_path("speaker_encoder_config_path", speaker_encoder_config_path, config_path) + self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path) + + @staticmethod + def _update_path(field_name, new_path, config_path): + """Update the path in the model config.json for the current environment after download""" + if new_path and os.path.exists(new_path): + config = load_config(config_path) + field_names = field_name.split(".") + if len(field_names) > 1: + # field name points to a sub-level field + sub_conf = config + for fd in field_names[:-1]: + if fd in sub_conf: + sub_conf = sub_conf[fd] + else: + return + if isinstance(sub_conf[field_names[-1]], list): + sub_conf[field_names[-1]] = [new_path] + else: + sub_conf[field_names[-1]] = new_path + else: + # field name points to a top-level field + if not field_name in config: + return + if isinstance(config[field_name], list): + config[field_name] = [new_path] + else: + config[field_name] = new_path + config.save_json(config_path) + + @staticmethod + def _download_zip_file(file_url, output_folder, progress_bar): + """Download the github releases""" + # download the file + r = requests.get(file_url, stream=True) + # extract the file + try: + total_size_in_bytes = int(r.headers.get("content-length", 0)) + block_size = 1024 # 1 Kibibyte + if progress_bar: + ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1]) + with open(temp_zip_name, "wb") as file: + for data in r.iter_content(block_size): + if progress_bar: + ModelManager.tqdm_progress.update(len(data)) + file.write(data) + with zipfile.ZipFile(temp_zip_name) as z: + z.extractall(output_folder) + os.remove(temp_zip_name) # delete zip after extract + except zipfile.BadZipFile: + print(f" > Error: Bad zip file - {file_url}") + raise zipfile.BadZipFile # pylint: disable=raise-missing-from + # move the files to the outer path + for file_path in z.namelist(): + src_path = os.path.join(output_folder, file_path) + if os.path.isfile(src_path): + dst_path = os.path.join(output_folder, os.path.basename(file_path)) + if src_path != dst_path: + copyfile(src_path, dst_path) + # remove redundant (hidden or not) folders + for file_path in z.namelist(): + if os.path.isdir(os.path.join(output_folder, file_path)): + rmtree(os.path.join(output_folder, file_path)) + + @staticmethod + def _download_tar_file(file_url, output_folder, progress_bar): + """Download the github releases""" + # download the file + r = requests.get(file_url, stream=True) + # extract the file + try: + total_size_in_bytes = int(r.headers.get("content-length", 0)) + block_size = 1024 # 1 Kibibyte + if progress_bar: + ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1]) + with open(temp_tar_name, "wb") as file: + for data in r.iter_content(block_size): + if progress_bar: + ModelManager.tqdm_progress.update(len(data)) + file.write(data) + with tarfile.open(temp_tar_name) as t: + t.extractall(output_folder) + tar_names = t.getnames() + os.remove(temp_tar_name) # delete tar after extract + except tarfile.ReadError: + print(f" > Error: Bad tar file - {file_url}") + raise tarfile.ReadError # pylint: disable=raise-missing-from + # move the files to the outer path + for file_path in os.listdir(os.path.join(output_folder, tar_names[0])): + src_path = os.path.join(output_folder, tar_names[0], 
file_path) + dst_path = os.path.join(output_folder, os.path.basename(file_path)) + if src_path != dst_path: + copyfile(src_path, dst_path) + # remove the extracted folder + rmtree(os.path.join(output_folder, tar_names[0])) + + @staticmethod + def _download_model_files(file_urls, output_folder, progress_bar): + """Download the github releases""" + for file_url in file_urls: + # download the file + r = requests.get(file_url, stream=True) + # extract the file + bease_filename = file_url.split("/")[-1] + temp_zip_name = os.path.join(output_folder, bease_filename) + total_size_in_bytes = int(r.headers.get("content-length", 0)) + block_size = 1024 # 1 Kibibyte + with open(temp_zip_name, "wb") as file: + if progress_bar: + ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + for data in r.iter_content(block_size): + if progress_bar: + ModelManager.tqdm_progress.update(len(data)) + file.write(data) + + @staticmethod + def _check_dict_key(my_dict, key): + if key in my_dict.keys() and my_dict[key] is not None: + if not isinstance(key, str): + return True + if isinstance(key, str) and len(my_dict[key]) > 0: + return True + return False diff --git a/viXTTS/TTS/utils/radam.py b/viXTTS/TTS/utils/radam.py new file mode 100644 index 0000000000000000000000000000000000000000..cbd14990f33cb671f030e401a3a2f9b96c2710cd --- /dev/null +++ b/viXTTS/TTS/utils/radam.py @@ -0,0 +1,105 @@ +# modified from https://github.com/LiyuanLucasLiu/RAdam + +import math + +import torch +from torch.optim.optimizer import Optimizer + + +class RAdam(Optimizer): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): + if lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if eps < 0.0: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + self.degenerated_to_sgd = degenerated_to_sgd + if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): + for param in params: + if "betas" in param and (param["betas"][0] != betas[0] or param["betas"][1] != betas[1]): + param["buffer"] = [[None, None, None] for _ in range(10)] + defaults = dict( + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)] + ) + super().__init__(params, defaults) + + def __setstate__(self, state): # pylint: disable=useless-super-delegation + super().__setstate__(state) + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError("RAdam does not support sparse gradients") + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p_data_fp32) + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) + else: + state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) + state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + state["step"] += 1 + buffered = 
group["buffer"][int(state["step"] % 10)] + if state["step"] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state["step"] + beta2_t = beta2 ** state["step"] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = math.sqrt( + (1 - beta2_t) + * (N_sma - 4) + / (N_sma_max - 4) + * (N_sma - 2) + / N_sma + * N_sma_max + / (N_sma_max - 2) + ) / (1 - beta1 ** state["step"]) + elif self.degenerated_to_sgd: + step_size = 1.0 / (1 - beta1 ** state["step"]) + else: + step_size = -1 + buffered[2] = step_size + + # more conservative since it's an approximated value + if N_sma >= 5: + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"]) + denom = exp_avg_sq.sqrt().add_(group["eps"]) + p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group["lr"]) + p.data.copy_(p_data_fp32) + elif step_size > 0: + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"]) + p_data_fp32.add_(exp_avg, alpha=-step_size * group["lr"]) + p.data.copy_(p_data_fp32) + + return loss diff --git a/viXTTS/TTS/utils/samplers.py b/viXTTS/TTS/utils/samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..b08a763a33e40d00a32577e31ffc12b8e228bc46 --- /dev/null +++ b/viXTTS/TTS/utils/samplers.py @@ -0,0 +1,201 @@ +import math +import random +from typing import Callable, List, Union + +from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler + + +class SubsetSampler(Sampler): + """ + Samples elements sequentially from a given list of indices. + + Args: + indices (list): a sequence of indices + """ + + def __init__(self, indices): + super().__init__(indices) + self.indices = indices + + def __iter__(self): + return (self.indices[i] for i in range(len(self.indices))) + + def __len__(self): + return len(self.indices) + + +class PerfectBatchSampler(Sampler): + """ + Samples a mini-batch of indices for a balanced class batching + + Args: + dataset_items(list): dataset items to sample from. + classes (list): list of classes of dataset_items to sample from. + batch_size (int): total number of samples to be sampled in a mini-batch. + num_gpus (int): number of GPU in the data parallel mode. + shuffle (bool): if True, samples randomly, otherwise samples sequentially. + drop_last (bool): if True, drops last incomplete batch. + """ + + def __init__( + self, + dataset_items, + classes, + batch_size, + num_classes_in_batch, + num_gpus=1, + shuffle=True, + drop_last=False, + label_key="class_name", + ): + super().__init__(dataset_items) + assert ( + batch_size % (num_classes_in_batch * num_gpus) == 0 + ), "Batch size must be divisible by number of classes times the number of data parallel devices (if enabled)." 
+ + label_indices = {} + for idx, item in enumerate(dataset_items): + label = item[label_key] + if label not in label_indices.keys(): + label_indices[label] = [idx] + else: + label_indices[label].append(idx) + + if shuffle: + self._samplers = [SubsetRandomSampler(label_indices[key]) for key in classes] + else: + self._samplers = [SubsetSampler(label_indices[key]) for key in classes] + + self._batch_size = batch_size + self._drop_last = drop_last + self._dp_devices = num_gpus + self._num_classes_in_batch = num_classes_in_batch + + def __iter__(self): + batch = [] + if self._num_classes_in_batch != len(self._samplers): + valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch) + else: + valid_samplers_idx = None + + iters = [iter(s) for s in self._samplers] + done = False + + while True: + b = [] + for i, it in enumerate(iters): + if valid_samplers_idx is not None and i not in valid_samplers_idx: + continue + idx = next(it, None) + if idx is None: + done = True + break + b.append(idx) + if done: + break + batch += b + if len(batch) == self._batch_size: + yield batch + batch = [] + if valid_samplers_idx is not None: + valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch) + + if not self._drop_last: + if len(batch) > 0: + groups = len(batch) // self._num_classes_in_batch + if groups % self._dp_devices == 0: + yield batch + else: + batch = batch[: (groups // self._dp_devices) * self._dp_devices * self._num_classes_in_batch] + if len(batch) > 0: + yield batch + + def __len__(self): + class_batch_size = self._batch_size // self._num_classes_in_batch + return min(((len(s) + class_batch_size - 1) // class_batch_size) for s in self._samplers) + + +def identity(x): + return x + + +class SortedSampler(Sampler): + """Samples elements sequentially, always in the same order. + + Taken from https://github.com/PetrochukM/PyTorch-NLP + + Args: + data (iterable): Iterable data. + sort_key (callable): Specifies a function of one argument that is used to extract a + numerical comparison key from each list element. + + Example: + >>> list(SortedSampler(range(10), sort_key=lambda i: -i)) + [9, 8, 7, 6, 5, 4, 3, 2, 1, 0] + + """ + + def __init__(self, data, sort_key: Callable = identity): + super().__init__(data) + self.data = data + self.sort_key = sort_key + zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)] + zip_ = sorted(zip_, key=lambda r: r[1]) + self.sorted_indexes = [item[0] for item in zip_] + + def __iter__(self): + return iter(self.sorted_indexes) + + def __len__(self): + return len(self.data) + + +class BucketBatchSampler(BatchSampler): + """Bucket batch sampler + + Adapted from https://github.com/PetrochukM/PyTorch-NLP + + Args: + sampler (torch.data.utils.sampler.Sampler): + batch_size (int): Size of mini-batch. + drop_last (bool): If `True` the sampler will drop the last batch if its size would be less + than `batch_size`. + data (list): List of data samples. + sort_key (callable, optional): Callable to specify a comparison key for sorting. + bucket_size_multiplier (int, optional): Buckets are of size + `batch_size * bucket_size_multiplier`. 
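A balanced-batching sketch for PerfectBatchSampler above: with two classes and batch_size=4, every yielded batch holds two indices from each class (labels are synthetic):

from TTS.utils.samplers import PerfectBatchSampler

items = [{"class_name": "spk_a"}] * 8 + [{"class_name": "spk_b"}] * 8
sampler = PerfectBatchSampler(items, classes=["spk_a", "spk_b"],
                              batch_size=4, num_classes_in_batch=2)
for batch in sampler:
    labels = [items[i]["class_name"] for i in batch]  # always two "spk_a" and two "spk_b" indices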
+ + Example: + >>> sampler = WeightedRandomSampler(weights, len(weights)) + >>> sampler = BucketBatchSampler(sampler, data=data_items, batch_size=32, drop_last=True) + """ + + def __init__( + self, + sampler, + data, + batch_size, + drop_last, + sort_key: Union[Callable, List] = identity, + bucket_size_multiplier=100, + ): + super().__init__(sampler, batch_size, drop_last) + self.data = data + self.sort_key = sort_key + _bucket_size = batch_size * bucket_size_multiplier + if hasattr(sampler, "__len__"): + _bucket_size = min(_bucket_size, len(sampler)) + self.bucket_sampler = BatchSampler(sampler, _bucket_size, False) + + def __iter__(self): + for idxs in self.bucket_sampler: + bucket_data = [self.data[idx] for idx in idxs] + sorted_sampler = SortedSampler(bucket_data, self.sort_key) + for batch_idx in SubsetRandomSampler(list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))): + sorted_idxs = [idxs[i] for i in batch_idx] + yield sorted_idxs + + def __len__(self): + if self.drop_last: + return len(self.sampler) // self.batch_size + return math.ceil(len(self.sampler) / self.batch_size) diff --git a/viXTTS/TTS/utils/synthesizer.py b/viXTTS/TTS/utils/synthesizer.py new file mode 100644 index 0000000000000000000000000000000000000000..b98647c30c4cb8b99cd02aa6920306a472dd1569 --- /dev/null +++ b/viXTTS/TTS/utils/synthesizer.py @@ -0,0 +1,505 @@ +import os +import time +from typing import List + +import numpy as np +import pysbd +import torch +from torch import nn + +from TTS.config import load_config +from TTS.tts.configs.vits_config import VitsConfig +from TTS.tts.models import setup_model as setup_tts_model +from TTS.tts.models.vits import Vits + +# pylint: disable=unused-wildcard-import +# pylint: disable=wildcard-import +from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence +from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import save_wav +from TTS.vc.models import setup_model as setup_vc_model +from TTS.vocoder.models import setup_model as setup_vocoder_model +from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input + + +class Synthesizer(nn.Module): + def __init__( + self, + tts_checkpoint: str = "", + tts_config_path: str = "", + tts_speakers_file: str = "", + tts_languages_file: str = "", + vocoder_checkpoint: str = "", + vocoder_config: str = "", + encoder_checkpoint: str = "", + encoder_config: str = "", + vc_checkpoint: str = "", + vc_config: str = "", + model_dir: str = "", + voice_dir: str = None, + use_cuda: bool = False, + ) -> None: + """General 🐸 TTS interface for inference. It takes a tts and a vocoder + model and synthesize speech from the provided text. + + The text is divided into a list of sentences using `pysbd` and synthesize + speech on each sentence separately. + + If you have certain special characters in your text, you need to handle + them before providing the text to Synthesizer. + + TODO: set the segmenter based on the source language + + Args: + tts_checkpoint (str, optional): path to the tts model file. + tts_config_path (str, optional): path to the tts config file. + vocoder_checkpoint (str, optional): path to the vocoder model file. Defaults to None. + vocoder_config (str, optional): path to the vocoder config file. Defaults to None. + encoder_checkpoint (str, optional): path to the speaker encoder model file. Defaults to `""`, + encoder_config (str, optional): path to the speaker encoder config file. 
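A length-bucketing sketch for BucketBatchSampler defined above in samplers.py; the docstring example uses WeightedRandomSampler, but any index sampler works, and the sort_key here is an assumed text-length key:

from torch.utils.data.sampler import RandomSampler
from TTS.utils.samplers import BucketBatchSampler

data = [{"text": "a" * n} for n in range(1, 257)]
batch_sampler = BucketBatchSampler(RandomSampler(data), data=data, batch_size=16,
                                   drop_last=False, sort_key=lambda item: len(item["text"]),
                                   bucket_size_multiplier=4)
for batch in batch_sampler:
    pass  # each `batch` is a list of indices whose items have similar lengths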
Defaults to `""`, + vc_checkpoint (str, optional): path to the voice conversion model file. Defaults to `""`, + vc_config (str, optional): path to the voice conversion config file. Defaults to `""`, + use_cuda (bool, optional): enable/disable cuda. Defaults to False. + """ + super().__init__() + self.tts_checkpoint = tts_checkpoint + self.tts_config_path = tts_config_path + self.tts_speakers_file = tts_speakers_file + self.tts_languages_file = tts_languages_file + self.vocoder_checkpoint = vocoder_checkpoint + self.vocoder_config = vocoder_config + self.encoder_checkpoint = encoder_checkpoint + self.encoder_config = encoder_config + self.vc_checkpoint = vc_checkpoint + self.vc_config = vc_config + self.use_cuda = use_cuda + + self.tts_model = None + self.vocoder_model = None + self.vc_model = None + self.speaker_manager = None + self.tts_speakers = {} + self.language_manager = None + self.num_languages = 0 + self.tts_languages = {} + self.d_vector_dim = 0 + self.seg = self._get_segmenter("en") + self.use_cuda = use_cuda + self.voice_dir = voice_dir + if self.use_cuda: + assert torch.cuda.is_available(), "CUDA is not availabe on this machine." + + if tts_checkpoint: + self._load_tts(tts_checkpoint, tts_config_path, use_cuda) + self.output_sample_rate = self.tts_config.audio["sample_rate"] + + if vocoder_checkpoint: + self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda) + self.output_sample_rate = self.vocoder_config.audio["sample_rate"] + + if vc_checkpoint: + self._load_vc(vc_checkpoint, vc_config, use_cuda) + self.output_sample_rate = self.vc_config.audio["output_sample_rate"] + + if model_dir: + if "fairseq" in model_dir: + self._load_fairseq_from_dir(model_dir, use_cuda) + self.output_sample_rate = self.tts_config.audio["sample_rate"] + else: + self._load_tts_from_dir(model_dir, use_cuda) + self.output_sample_rate = self.tts_config.audio["output_sample_rate"] + + @staticmethod + def _get_segmenter(lang: str): + """get the sentence segmenter for the given language. + + Args: + lang (str): target language code. + + Returns: + [type]: [description] + """ + return pysbd.Segmenter(language=lang, clean=True) + + def _load_vc(self, vc_checkpoint: str, vc_config_path: str, use_cuda: bool) -> None: + """Load the voice conversion model. + + 1. Load the model config. + 2. Init the model from the config. + 3. Load the model weights. + 4. Move the model to the GPU if CUDA is enabled. + + Args: + vc_checkpoint (str): path to the model checkpoint. + tts_config_path (str): path to the model config file. + use_cuda (bool): enable/disable CUDA use. + """ + # pylint: disable=global-statement + self.vc_config = load_config(vc_config_path) + self.vc_model = setup_vc_model(config=self.vc_config) + self.vc_model.load_checkpoint(self.vc_config, vc_checkpoint) + if use_cuda: + self.vc_model.cuda() + + def _load_fairseq_from_dir(self, model_dir: str, use_cuda: bool) -> None: + """Load the fairseq model from a directory. + + We assume it is VITS and the model knows how to load itself from the directory and there is a config.json file in the directory. + """ + self.tts_config = VitsConfig() + self.tts_model = Vits.init_from_config(self.tts_config) + self.tts_model.load_fairseq_checkpoint(self.tts_config, checkpoint_dir=model_dir, eval=True) + self.tts_config = self.tts_model.config + if use_cuda: + self.tts_model.cuda() + + def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None: + """Load the TTS model from a directory. 
+ + We assume the model knows how to load itself from the directory and there is a config.json file in the directory. + """ + config = load_config(os.path.join(model_dir, "config.json")) + self.tts_config = config + self.tts_model = setup_tts_model(config) + self.tts_model.load_checkpoint(config, checkpoint_dir=model_dir, eval=True) + if use_cuda: + self.tts_model.cuda() + + def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None: + """Load the TTS model. + + 1. Load the model config. + 2. Init the model from the config. + 3. Load the model weights. + 4. Move the model to the GPU if CUDA is enabled. + 5. Init the speaker manager in the model. + + Args: + tts_checkpoint (str): path to the model checkpoint. + tts_config_path (str): path to the model config file. + use_cuda (bool): enable/disable CUDA use. + """ + # pylint: disable=global-statement + self.tts_config = load_config(tts_config_path) + if self.tts_config["use_phonemes"] and self.tts_config["phonemizer"] is None: + raise ValueError("Phonemizer is not defined in the TTS config.") + + self.tts_model = setup_tts_model(config=self.tts_config) + + if not self.encoder_checkpoint: + self._set_speaker_encoder_paths_from_tts_config() + + self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True) + if use_cuda: + self.tts_model.cuda() + + if self.encoder_checkpoint and hasattr(self.tts_model, "speaker_manager"): + self.tts_model.speaker_manager.init_encoder(self.encoder_checkpoint, self.encoder_config, use_cuda) + + def _set_speaker_encoder_paths_from_tts_config(self): + """Set the encoder paths from the tts model config for models with speaker encoders.""" + if hasattr(self.tts_config, "model_args") and hasattr( + self.tts_config.model_args, "speaker_encoder_config_path" + ): + self.encoder_checkpoint = self.tts_config.model_args.speaker_encoder_model_path + self.encoder_config = self.tts_config.model_args.speaker_encoder_config_path + + def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None: + """Load the vocoder model. + + 1. Load the vocoder config. + 2. Init the AudioProcessor for the vocoder. + 3. Init the vocoder model from the config. + 4. Move the model to the GPU if CUDA is enabled. + + Args: + model_file (str): path to the model checkpoint. + model_config (str): path to the model config file. + use_cuda (bool): enable/disable CUDA use. + """ + self.vocoder_config = load_config(model_config) + self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config.audio) + self.vocoder_model = setup_vocoder_model(self.vocoder_config) + self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True) + if use_cuda: + self.vocoder_model.cuda() + + def split_into_sentences(self, text) -> List[str]: + """Split give text into sentences. + + Args: + text (str): input text in string format. + + Returns: + List[str]: list of sentences. + """ + return self.seg.segment(text) + + def save_wav(self, wav: List[int], path: str, pipe_out=None) -> None: + """Save the waveform as a file. + + Args: + wav (List[int]): waveform as a list of values. + path (str): output path to save the waveform. + pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe. 
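An end-to-end sketch tying ModelManager and Synthesizer together for a directory-style model such as XTTS; the reference wav, output path and language code are placeholders that must match the loaded checkpoint, and for the Vietnamese viXTTS weights model_dir would instead point at the locally downloaded viXTTS checkpoint directory:

from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer

model_dir, _, _ = ModelManager().download_model("tts_models/multilingual/multi-dataset/xtts_v2")
synth = Synthesizer(model_dir=model_dir, use_cuda=False)
wav = synth.tts("This is a quick synthesis test.", speaker_wav="reference.wav", language_name="en")
synth.save_wav(wav, "output.wav")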
+ """ + # if tensor convert to numpy + if torch.is_tensor(wav): + wav = wav.cpu().numpy() + if isinstance(wav, list): + wav = np.array(wav) + save_wav(wav=wav, path=path, sample_rate=self.output_sample_rate, pipe_out=pipe_out) + + def voice_conversion(self, source_wav: str, target_wav: str) -> List[int]: + output_wav = self.vc_model.voice_conversion(source_wav, target_wav) + return output_wav + + def tts( + self, + text: str = "", + speaker_name: str = "", + language_name: str = "", + speaker_wav=None, + style_wav=None, + style_text=None, + reference_wav=None, + reference_speaker_name=None, + split_sentences: bool = True, + **kwargs, + ) -> List[int]: + """🐸 TTS magic. Run all the models and generate speech. + + Args: + text (str): input text. + speaker_name (str, optional): speaker id for multi-speaker models. Defaults to "". + language_name (str, optional): language id for multi-language models. Defaults to "". + speaker_wav (Union[str, List[str]], optional): path to the speaker wav for voice cloning. Defaults to None. + style_wav ([type], optional): style waveform for GST. Defaults to None. + style_text ([type], optional): transcription of style_wav for Capacitron. Defaults to None. + reference_wav ([type], optional): reference waveform for voice conversion. Defaults to None. + reference_speaker_name ([type], optional): speaker id of reference waveform. Defaults to None. + split_sentences (bool, optional): split the input text into sentences. Defaults to True. + **kwargs: additional arguments to pass to the TTS model. + Returns: + List[int]: [description] + """ + start_time = time.time() + wavs = [] + + if not text and not reference_wav: + raise ValueError( + "You need to define either `text` (for sythesis) or a `reference_wav` (for voice conversion) to use the Coqui TTS API." + ) + + if text: + sens = [text] + if split_sentences: + print(" > Text splitted to sentences.") + sens = self.split_into_sentences(text) + print(sens) + + # handle multi-speaker + if "voice_dir" in kwargs: + self.voice_dir = kwargs["voice_dir"] + kwargs.pop("voice_dir") + speaker_embedding = None + speaker_id = None + if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "name_to_id"): + if speaker_name and isinstance(speaker_name, str) and not self.tts_config.model == "xtts": + if self.tts_config.use_d_vector_file: + # get the average speaker embedding from the saved d_vectors. + speaker_embedding = self.tts_model.speaker_manager.get_mean_embedding( + speaker_name, num_samples=None, randomize=False + ) + speaker_embedding = np.array(speaker_embedding)[None, :] # [1 x embedding_dim] + else: + # get speaker idx from the speaker name + speaker_id = self.tts_model.speaker_manager.name_to_id[speaker_name] + # handle Neon models with single speaker. + elif len(self.tts_model.speaker_manager.name_to_id) == 1: + speaker_id = list(self.tts_model.speaker_manager.name_to_id.values())[0] + elif not speaker_name and not speaker_wav: + raise ValueError( + " [!] Looks like you are using a multi-speaker model. " + "You need to define either a `speaker_idx` or a `speaker_wav` to use a multi-speaker model." + ) + else: + speaker_embedding = None + else: + if speaker_name and self.voice_dir is None: + raise ValueError( + f" [!] Missing speakers.json file path for selecting speaker {speaker_name}." + "Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. 
" + ) + + # handle multi-lingual + language_id = None + if self.tts_languages_file or ( + hasattr(self.tts_model, "language_manager") + and self.tts_model.language_manager is not None + and not self.tts_config.model == "xtts" + ): + if len(self.tts_model.language_manager.name_to_id) == 1: + language_id = list(self.tts_model.language_manager.name_to_id.values())[0] + + elif language_name and isinstance(language_name, str): + try: + language_id = self.tts_model.language_manager.name_to_id[language_name] + except KeyError as e: + raise ValueError( + f" [!] Looks like you use a multi-lingual model. " + f"Language {language_name} is not in the available languages: " + f"{self.tts_model.language_manager.name_to_id.keys()}." + ) from e + + elif not language_name: + raise ValueError( + " [!] Look like you use a multi-lingual model. " + "You need to define either a `language_name` or a `style_wav` to use a multi-lingual model." + ) + + else: + raise ValueError( + f" [!] Missing language_ids.json file path for selecting language {language_name}." + "Define path for language_ids.json if it is a multi-lingual model or remove defined language idx. " + ) + + # compute a new d_vector from the given clip. + if ( + speaker_wav is not None + and self.tts_model.speaker_manager is not None + and hasattr(self.tts_model.speaker_manager, "encoder_ap") + and self.tts_model.speaker_manager.encoder_ap is not None + ): + speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav) + + vocoder_device = "cpu" + use_gl = self.vocoder_model is None + if not use_gl: + vocoder_device = next(self.vocoder_model.parameters()).device + if self.use_cuda: + vocoder_device = "cuda" + + if not reference_wav: # not voice conversion + for sen in sens: + if hasattr(self.tts_model, "synthesize"): + outputs = self.tts_model.synthesize( + text=sen, + config=self.tts_config, + speaker_id=speaker_name, + voice_dirs=self.voice_dir, + d_vector=speaker_embedding, + speaker_wav=speaker_wav, + language=language_name, + **kwargs, + ) + else: + # synthesize voice + outputs = synthesis( + model=self.tts_model, + text=sen, + CONFIG=self.tts_config, + use_cuda=self.use_cuda, + speaker_id=speaker_id, + style_wav=style_wav, + style_text=style_text, + use_griffin_lim=use_gl, + d_vector=speaker_embedding, + language_id=language_id, + ) + waveform = outputs["wav"] + if not use_gl: + mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy() + # denormalize tts output based on tts audio config + mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T + # renormalize spectrogram based on vocoder config + vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T) + # compute scale factor for possible sample rate mismatch + scale_factor = [ + 1, + self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate, + ] + if scale_factor[1] != 1: + print(" > interpolating tts model output.") + vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input) + else: + vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable + # run vocoder model + # [1, T, C] + waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device)) + if torch.is_tensor(waveform) and waveform.device != torch.device("cpu") and not use_gl: + waveform = waveform.cpu() + if not use_gl: + waveform = waveform.numpy() + waveform = waveform.squeeze() + + # trim silence + if "do_trim_silence" in self.tts_config.audio and self.tts_config.audio["do_trim_silence"]: + 
waveform = trim_silence(waveform, self.tts_model.ap) + + wavs += list(waveform) + wavs += [0] * 10000 + else: + # get the speaker embedding or speaker id for the reference wav file + reference_speaker_embedding = None + reference_speaker_id = None + if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "name_to_id"): + if reference_speaker_name and isinstance(reference_speaker_name, str): + if self.tts_config.use_d_vector_file: + # get the speaker embedding from the saved d_vectors. + reference_speaker_embedding = self.tts_model.speaker_manager.get_embeddings_by_name( + reference_speaker_name + )[0] + reference_speaker_embedding = np.array(reference_speaker_embedding)[ + None, : + ] # [1 x embedding_dim] + else: + # get speaker idx from the speaker name + reference_speaker_id = self.tts_model.speaker_manager.name_to_id[reference_speaker_name] + else: + reference_speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip( + reference_wav + ) + outputs = transfer_voice( + model=self.tts_model, + CONFIG=self.tts_config, + use_cuda=self.use_cuda, + reference_wav=reference_wav, + speaker_id=speaker_id, + d_vector=speaker_embedding, + use_griffin_lim=use_gl, + reference_speaker_id=reference_speaker_id, + reference_d_vector=reference_speaker_embedding, + ) + waveform = outputs + if not use_gl: + mel_postnet_spec = outputs[0].detach().cpu().numpy() + # denormalize tts output based on tts audio config + mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T + # renormalize spectrogram based on vocoder config + vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T) + # compute scale factor for possible sample rate mismatch + scale_factor = [ + 1, + self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate, + ] + if scale_factor[1] != 1: + print(" > interpolating tts model output.") + vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input) + else: + vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable + # run vocoder model + # [1, T, C] + waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device)) + if torch.is_tensor(waveform) and waveform.device != torch.device("cpu"): + waveform = waveform.cpu() + if not use_gl: + waveform = waveform.numpy() + wavs = waveform.squeeze() + + # compute stats + process_time = time.time() - start_time + audio_time = len(wavs) / self.tts_config.audio["sample_rate"] + print(f" > Processing time: {process_time}") + print(f" > Real-time factor: {process_time / audio_time}") + return wavs diff --git a/viXTTS/TTS/utils/training.py b/viXTTS/TTS/utils/training.py new file mode 100644 index 0000000000000000000000000000000000000000..b51f55e92b56bece69ae61f99f68b48c88938261 --- /dev/null +++ b/viXTTS/TTS/utils/training.py @@ -0,0 +1,44 @@ +import numpy as np +import torch + + +def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None): + r"""Check model gradient against unexpected jumps and failures""" + skip_flag = False + if ignore_stopnet: + if not amp_opt_params: + grad_norm = torch.nn.utils.clip_grad_norm_( + [param for name, param in model.named_parameters() if "stopnet" not in name], grad_clip + ) + else: + grad_norm = torch.nn.utils.clip_grad_norm_(amp_opt_params, grad_clip) + else: + if not amp_opt_params: + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) + else: + grad_norm = torch.nn.utils.clip_grad_norm_(amp_opt_params, grad_clip) + + # compatibility with different torch versions + if 
isinstance(grad_norm, float): + if np.isinf(grad_norm): + print(" | > Gradient is INF !!") + skip_flag = True + else: + if torch.isinf(grad_norm): + print(" | > Gradient is INF !!") + skip_flag = True + return grad_norm, skip_flag + + +def gradual_training_scheduler(global_step, config): + """Setup the gradual training schedule wrt number + of active GPUs""" + num_gpus = torch.cuda.device_count() + if num_gpus == 0: + num_gpus = 1 + new_values = None + # we set the scheduling wrt num_gpus + for values in config.gradual_training: + if global_step * num_gpus >= values[0]: + new_values = values + return new_values[1], new_values[2] diff --git a/viXTTS/TTS/utils/vad.py b/viXTTS/TTS/utils/vad.py new file mode 100644 index 0000000000000000000000000000000000000000..aefce2b50b9b3924c4c0e8fb2e4a988e0758e9cd --- /dev/null +++ b/viXTTS/TTS/utils/vad.py @@ -0,0 +1,88 @@ +import torch +import torchaudio + + +def read_audio(path): + wav, sr = torchaudio.load(path) + + if wav.size(0) > 1: + wav = wav.mean(dim=0, keepdim=True) + + return wav.squeeze(0), sr + + +def resample_wav(wav, sr, new_sr): + wav = wav.unsqueeze(0) + transform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=new_sr) + wav = transform(wav) + return wav.squeeze(0) + + +def map_timestamps_to_new_sr(vad_sr, new_sr, timestamps, just_begging_end=False): + factor = new_sr / vad_sr + new_timestamps = [] + if just_begging_end and timestamps: + # get just the start and end timestamps + new_dict = {"start": int(timestamps[0]["start"] * factor), "end": int(timestamps[-1]["end"] * factor)} + new_timestamps.append(new_dict) + else: + for ts in timestamps: + # map to the new SR + new_dict = {"start": int(ts["start"] * factor), "end": int(ts["end"] * factor)} + new_timestamps.append(new_dict) + + return new_timestamps + + +def get_vad_model_and_utils(use_cuda=False, use_onnx=False): + model, utils = torch.hub.load( + repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=True, onnx=use_onnx, force_onnx_cpu=True + ) + if use_cuda: + model = model.cuda() + + get_speech_timestamps, save_audio, _, _, collect_chunks = utils + return model, get_speech_timestamps, save_audio, collect_chunks + + +def remove_silence( + model_and_utils, audio_path, out_path, vad_sample_rate=8000, trim_just_beginning_and_end=True, use_cuda=False +): + # get the VAD model and utils functions + model, get_speech_timestamps, _, collect_chunks = model_and_utils + + # read ground truth wav and resample the audio for the VAD + try: + wav, gt_sample_rate = read_audio(audio_path) + except: + print(f"> ❗ Failed to read {audio_path}") + return None, False + + # if needed, resample the audio for the VAD model + if gt_sample_rate != vad_sample_rate: + wav_vad = resample_wav(wav, gt_sample_rate, vad_sample_rate) + else: + wav_vad = wav + + if use_cuda: + wav_vad = wav_vad.cuda() + + # get speech timestamps from full audio file + speech_timestamps = get_speech_timestamps(wav_vad, model, sampling_rate=vad_sample_rate, window_size_samples=768) + + # map the current speech_timestamps to the sample rate of the ground truth audio + new_speech_timestamps = map_timestamps_to_new_sr( + vad_sample_rate, gt_sample_rate, speech_timestamps, trim_just_beginning_and_end + ) + + # if have speech timestamps else save the wav + if new_speech_timestamps: + wav = collect_chunks(new_speech_timestamps, wav) + is_speech = True + else: + print(f"> The file {audio_path} probably does not have speech please check it !!") + is_speech = False + + # save + torchaudio.save(out_path, wav[None, :], 
gt_sample_rate) + return out_path, is_speech diff --git a/viXTTS/TTS/vc/configs/__init__.py b/viXTTS/TTS/vc/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/vc/configs/freevc_config.py b/viXTTS/TTS/vc/configs/freevc_config.py new file mode 100644 index 0000000000000000000000000000000000000000..207181b303982f260c46619bc8ac470f5e950223 --- /dev/null +++ b/viXTTS/TTS/vc/configs/freevc_config.py @@ -0,0 +1,278 @@ +from dataclasses import dataclass, field +from typing import List, Optional + +from coqpit import Coqpit + +from TTS.vc.configs.shared_configs import BaseVCConfig + + +@dataclass +class FreeVCAudioConfig(Coqpit): + """Audio configuration + + Args: + max_wav_value (float): + The maximum value of the waveform. + + input_sample_rate (int): + The sampling rate of the input waveform. + + output_sample_rate (int): + The sampling rate of the output waveform. + + filter_length (int): + The length of the filter. + + hop_length (int): + The hop length. + + win_length (int): + The window length. + + n_mel_channels (int): + The number of mel channels. + + mel_fmin (float): + The minimum frequency of the mel filterbank. + + mel_fmax (Optional[float]): + The maximum frequency of the mel filterbank. + """ + + max_wav_value: float = field(default=32768.0) + input_sample_rate: int = field(default=16000) + output_sample_rate: int = field(default=24000) + filter_length: int = field(default=1280) + hop_length: int = field(default=320) + win_length: int = field(default=1280) + n_mel_channels: int = field(default=80) + mel_fmin: float = field(default=0.0) + mel_fmax: Optional[float] = field(default=None) + + +@dataclass +class FreeVCArgs(Coqpit): + """FreeVC model arguments + + Args: + spec_channels (int): + The number of channels in the spectrogram. + + inter_channels (int): + The number of channels in the intermediate layers. + + hidden_channels (int): + The number of channels in the hidden layers. + + filter_channels (int): + The number of channels in the filter layers. + + n_heads (int): + The number of attention heads. + + n_layers (int): + The number of layers. + + kernel_size (int): + The size of the kernel. + + p_dropout (float): + The dropout probability. + + resblock (str): + The type of residual block. + + resblock_kernel_sizes (List[int]): + The kernel sizes for the residual blocks. + + resblock_dilation_sizes (List[List[int]]): + The dilation sizes for the residual blocks. + + upsample_rates (List[int]): + The upsample rates. + + upsample_initial_channel (int): + The number of channels in the initial upsample layer. + + upsample_kernel_sizes (List[int]): + The kernel sizes for the upsample layers. + + n_layers_q (int): + The number of layers in the quantization network. + + use_spectral_norm (bool): + Whether to use spectral normalization. + + gin_channels (int): + The number of channels in the global conditioning vector. + + ssl_dim (int): + The dimension of the self-supervised learning embedding. + + use_spk (bool): + Whether to use external speaker encoder. 
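+
+        num_spks (int):
+            The number of speakers. If a `speaker_manager` is passed to the model, its speaker count takes
+            precedence (see `FreeVC.init_multispeaker`).
+
+        segment_size (int):
+            The length of the segments randomly sliced from the latent sequence during training
+            (see `FreeVC.forward`).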
+ """ + + spec_channels: int = field(default=641) + inter_channels: int = field(default=192) + hidden_channels: int = field(default=192) + filter_channels: int = field(default=768) + n_heads: int = field(default=2) + n_layers: int = field(default=6) + kernel_size: int = field(default=3) + p_dropout: float = field(default=0.1) + resblock: str = field(default="1") + resblock_kernel_sizes: List[int] = field(default_factory=lambda: [3, 7, 11]) + resblock_dilation_sizes: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) + upsample_rates: List[int] = field(default_factory=lambda: [10, 8, 2, 2]) + upsample_initial_channel: int = field(default=512) + upsample_kernel_sizes: List[int] = field(default_factory=lambda: [16, 16, 4, 4]) + n_layers_q: int = field(default=3) + use_spectral_norm: bool = field(default=False) + gin_channels: int = field(default=256) + ssl_dim: int = field(default=1024) + use_spk: bool = field(default=False) + num_spks: int = field(default=0) + segment_size: int = field(default=8960) + + +@dataclass +class FreeVCConfig(BaseVCConfig): + """Defines parameters for FreeVC End2End TTS model. + + Args: + model (str): + Model name. Do not change unless you know what you are doing. + + model_args (FreeVCArgs): + Model architecture arguments. Defaults to `FreeVCArgs()`. + + audio (FreeVCAudioConfig): + Audio processing configuration. Defaults to `FreeVCAudioConfig()`. + + grad_clip (List): + Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`. + + lr_gen (float): + Initial learning rate for the generator. Defaults to 0.0002. + + lr_disc (float): + Initial learning rate for the discriminator. Defaults to 0.0002. + + lr_scheduler_gen (str): + Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to + `ExponentialLR`. + + lr_scheduler_gen_params (dict): + Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`. + + lr_scheduler_disc (str): + Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to + `ExponentialLR`. + + lr_scheduler_disc_params (dict): + Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`. + + scheduler_after_epoch (bool): + If true, step the schedulers after each epoch else after each step. Defaults to `False`. + + optimizer (str): + Name of the optimizer to use with both the generator and the discriminator networks. One of the + `torch.optim.*`. Defaults to `AdamW`. + + kl_loss_alpha (float): + Loss weight for KL loss. Defaults to 1.0. + + disc_loss_alpha (float): + Loss weight for the discriminator loss. Defaults to 1.0. + + gen_loss_alpha (float): + Loss weight for the generator loss. Defaults to 1.0. + + feat_loss_alpha (float): + Loss weight for the feature matching loss. Defaults to 1.0. + + mel_loss_alpha (float): + Loss weight for the mel loss. Defaults to 45.0. + + return_wav (bool): + If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`. + + compute_linear_spec (bool): + If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`. + + use_weighted_sampler (bool): + If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`. 
+ + weighted_sampler_attrs (dict): + Key retuned by the formatter to be used for weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities + by overweighting `root_path` by 2.0. Defaults to `{}`. + + weighted_sampler_multipliers (dict): + Weight each unique value of a key returned by the formatter for weighted sampling. + For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}`. + It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`. + + r (int): + Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`. + + add_blank (bool): + If true, a blank token is added in between every character. Defaults to `True`. + + test_sentences (List[List]): + List of sentences with speaker and language information to be used for testing. + + language_ids_file (str): + Path to the language ids file. + + use_language_embedding (bool): + If true, language embedding is used. Defaults to `False`. + + Note: + Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters. + + Example: + + >>> from TTS.vc.configs.freevc_config import FreeVCConfig + >>> config = FreeVCConfig() + """ + + model: str = "freevc" + # model specific params + model_args: FreeVCArgs = field(default_factory=FreeVCArgs) + audio: FreeVCAudioConfig = field(default_factory=FreeVCAudioConfig) + + # optimizer + # TODO with training support + + # loss params + # TODO with training support + + # data loader params + return_wav: bool = True + compute_linear_spec: bool = True + + # sampler params + use_weighted_sampler: bool = False # TODO: move it to the base config + weighted_sampler_attrs: dict = field(default_factory=lambda: {}) + weighted_sampler_multipliers: dict = field(default_factory=lambda: {}) + + # overrides + r: int = 1 # DO NOT CHANGE + add_blank: bool = True + + # multi-speaker settings + # use speaker embedding layer + num_speakers: int = 0 + speakers_file: str = None + speaker_embedding_channels: int = 256 + + # use d-vectors + use_d_vector_file: bool = False + d_vector_file: List[str] = None + d_vector_dim: int = None + + def __post_init__(self): + for key, val in self.model_args.items(): + if hasattr(self, key): + self[key] = val diff --git a/viXTTS/TTS/vc/configs/shared_configs.py b/viXTTS/TTS/vc/configs/shared_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..74164a744452a00c7f318fbdcc55438cddcc70be --- /dev/null +++ b/viXTTS/TTS/vc/configs/shared_configs.py @@ -0,0 +1,155 @@ +from dataclasses import asdict, dataclass, field +from typing import Dict, List + +from coqpit import Coqpit, check_argument + +from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig + + +@dataclass +class BaseVCConfig(BaseTrainingConfig): + """Shared parameters among all the tts models. + + Args: + + audio (BaseAudioConfig): + Audio processor config object instance. + + batch_group_size (int): + Size of the batch groups used for bucketing. By default, the dataloader orders samples by the sequence + length for a more efficient and stable training. If `batch_group_size > 1` then it performs bucketing to + prevent using the same batches for each epoch. + + loss_masking (bool): + enable / disable masking loss values against padded segments of samples in a batch. 
+ + min_text_len (int): + Minimum length of input text to be used. All shorter samples will be ignored. Defaults to 0. + + max_text_len (int): + Maximum length of input text to be used. All longer samples will be ignored. Defaults to float("inf"). + + min_audio_len (int): + Minimum length of input audio to be used. All shorter samples will be ignored. Defaults to 0. + + max_audio_len (int): + Maximum length of input audio to be used. All longer samples will be ignored. The maximum length in the + dataset defines the VRAM used in the training. Hence, pay attention to this value if you encounter an + OOM error in training. Defaults to float("inf"). + + compute_f0 (int): + (Not in use yet). + + compute_energy (int): + (Not in use yet). + + compute_linear_spec (bool): + If True data loader computes and returns linear spectrograms alongside the other data. + + precompute_num_workers (int): + Number of workers to precompute features. Defaults to 0. + + use_noise_augment (bool): + Augment the input audio with random noise. + + start_by_longest (bool): + If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues. + Defaults to False. + + shuffle (bool): + If True, the data loader will shuffle the dataset when there is not sampler defined. Defaults to True. + + drop_last (bool): + If True, the data loader will drop the last batch if it is not complete. It helps to prevent + issues that emerge from the partial batch statistics. Defaults to True. + + add_blank (bool): + Add blank characters between each other two characters. It improves performance for some models at expense + of slower run-time due to the longer input sequence. + + datasets (List[BaseDatasetConfig]): + List of datasets used for training. If multiple datasets are provided, they are merged and used together + for training. + + optimizer (str): + Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`. + Defaults to ``. + + optimizer_params (dict): + Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}` + + lr_scheduler (str): + Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or + `TTS.utils.training`. Defaults to ``. + + lr_scheduler_params (dict): + Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`. + + test_sentences (List[str]): + List of sentences to be used at testing. Defaults to '[]' + + eval_split_max_size (int): + Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled). + + eval_split_size (float): + If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. + If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). + + use_speaker_weighted_sampler (bool): + Enable / Disable the batch balancer by speaker. Defaults to ```False```. + + speaker_weighted_sampler_alpha (float): + Number that control the influence of the speaker sampler weights. Defaults to ```1.0```. + + use_language_weighted_sampler (bool): + Enable / Disable the batch balancer by language. Defaults to ```False```. + + language_weighted_sampler_alpha (float): + Number that control the influence of the language sampler weights. Defaults to ```1.0```. + + use_length_weighted_sampler (bool): + Enable / Disable the batch balancer by audio length. If enabled the dataset will be divided + into 10 buckets considering the min and max audio of the dataset. 
The sampler weights will be + computed forcing to have the same quantity of data for each bucket in each training batch. Defaults to ```False```. + + length_weighted_sampler_alpha (float): + Number that control the influence of the length sampler weights. Defaults to ```1.0```. + """ + + audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) + # training params + batch_group_size: int = 0 + loss_masking: bool = None + # dataloading + min_audio_len: int = 1 + max_audio_len: int = float("inf") + min_text_len: int = 1 + max_text_len: int = float("inf") + compute_f0: bool = False + compute_energy: bool = False + compute_linear_spec: bool = False + precompute_num_workers: int = 0 + use_noise_augment: bool = False + start_by_longest: bool = False + shuffle: bool = False + drop_last: bool = False + # dataset + datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) + # optimizer + optimizer: str = "radam" + optimizer_params: dict = None + # scheduler + lr_scheduler: str = None + lr_scheduler_params: dict = field(default_factory=lambda: {}) + # testing + test_sentences: List[str] = field(default_factory=lambda: []) + # evaluation + eval_split_max_size: int = None + eval_split_size: float = 0.01 + # weighted samplers + use_speaker_weighted_sampler: bool = False + speaker_weighted_sampler_alpha: float = 1.0 + use_language_weighted_sampler: bool = False + language_weighted_sampler_alpha: float = 1.0 + use_length_weighted_sampler: bool = False + length_weighted_sampler_alpha: float = 1.0 diff --git a/viXTTS/TTS/vc/models/__init__.py b/viXTTS/TTS/vc/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5a09b4e53e2a4e40dd7c9d0d4c7e5b3c30f317a5 --- /dev/null +++ b/viXTTS/TTS/vc/models/__init__.py @@ -0,0 +1,17 @@ +import importlib +import re +from typing import Dict, List, Union + + +def to_camel(text): + text = text.capitalize() + return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) + + +def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseVC": + print(" > Using model: {}".format(config.model)) + # fetch the right model implementation. + if "model" in config and config["model"].lower() == "freevc": + MyModel = importlib.import_module("TTS.vc.models.freevc").FreeVC + model = MyModel.init_from_config(config, samples) + return model diff --git a/viXTTS/TTS/vc/models/base_vc.py b/viXTTS/TTS/vc/models/base_vc.py new file mode 100644 index 0000000000000000000000000000000000000000..19f2761bbc42051a2f03d1170c011955dcdd28cb --- /dev/null +++ b/viXTTS/TTS/vc/models/base_vc.py @@ -0,0 +1,429 @@ +import os +import random +from typing import Dict, List, Tuple, Union + +import torch +import torch.distributed as dist +from coqpit import Coqpit +from torch import nn +from torch.utils.data import DataLoader +from torch.utils.data.sampler import WeightedRandomSampler +from trainer.torch import DistributedSampler, DistributedSamplerWrapper + +from TTS.model import BaseTrainerModel +from TTS.tts.datasets.dataset import TTSDataset +from TTS.tts.utils.data import get_length_balancer_weights +from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights +from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram + +# pylint: skip-file + + +class BaseVC(BaseTrainerModel): + """Base `vc` class. Every new `vc` model must inherit this. 
+
+    It defines common `vc` specific functions on top of `Model` implementation.
+    """
+
+    MODEL_TYPE = "vc"
+
+    def __init__(
+        self,
+        config: Coqpit,
+        ap: "AudioProcessor",
+        speaker_manager: SpeakerManager = None,
+        language_manager: LanguageManager = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.ap = ap
+        self.speaker_manager = speaker_manager
+        self.language_manager = language_manager
+        self._set_model_args(config)
+
+    def _set_model_args(self, config: Coqpit):
+        """Setup model args based on the config type (`ModelConfig` or `ModelArgs`).
+
+        `ModelArgs` has all the fields required to initialize the model architecture.
+
+        `ModelConfig` has all the fields required for training, inference and contains `ModelArgs`.
+
+        If the config is for training with a name like "*Config", then the model args are embedded in
+        `config.model_args`.
+
+        If the config is for the model with a name like "*Args", then we assign it directly.
+        """
+        # don't use isinstance to avoid recursive imports
+        if "Config" in config.__class__.__name__:
+            self.config = config
+            self.args = config.model_args
+        elif "Args" in config.__class__.__name__:
+            self.args = config
+        else:
+            raise ValueError("config must be either a *Config or *Args")
+
+    def init_multispeaker(self, config: Coqpit, data: List = None):
+        """Initialize a speaker embedding layer if needed and define expected embedding channel size for defining
+        `in_channels` size of the connected layers.
+
+        This implementation yields 3 possible outcomes:
+
+        1. If `config.use_speaker_embedding` and `config.use_d_vector_file` are False, do nothing.
+        2. If `config.use_d_vector_file` is True, set expected embedding channel size to `config.d_vector_dim` or 512.
+        3. If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of
+           `config.d_vector_dim` or 512.
+
+        You can override this function for new models.
+
+        Args:
+            config (Coqpit): Model configuration.
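+            data (List, optional): Dataset items used to infer the number of speakers. Defaults to None.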
+ """ + # set number of speakers + if self.speaker_manager is not None: + self.num_speakers = self.speaker_manager.num_speakers + elif hasattr(config, "num_speakers"): + self.num_speakers = config.num_speakers + + # set ultimate speaker embedding size + if config.use_speaker_embedding or config.use_d_vector_file: + self.embedded_speaker_dim = ( + config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512 + ) + # init speaker embedding layer + if config.use_speaker_embedding and not config.use_d_vector_file: + print(" > Init speaker_embedding layer.") + self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) + self.speaker_embedding.weight.data.normal_(0, 0.3) + + def get_aux_input(self, **kwargs) -> Dict: + """Prepare and return `aux_input` used by `forward()`""" + return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None} + + def get_aux_input_from_test_sentences(self, sentence_info): + if hasattr(self.config, "model_args"): + config = self.config.model_args + else: + config = self.config + + # extract speaker and language info + text, speaker_name, style_wav, language_name = None, None, None, None + + if isinstance(sentence_info, list): + if len(sentence_info) == 1: + text = sentence_info[0] + elif len(sentence_info) == 2: + text, speaker_name = sentence_info + elif len(sentence_info) == 3: + text, speaker_name, style_wav = sentence_info + elif len(sentence_info) == 4: + text, speaker_name, style_wav, language_name = sentence_info + else: + text = sentence_info + + # get speaker id/d_vector + speaker_id, d_vector, language_id = None, None, None + if self.speaker_manager is not None: + if config.use_d_vector_file: + if speaker_name is None: + d_vector = self.speaker_manager.get_random_embedding() + else: + d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name) + elif config.use_speaker_embedding: + if speaker_name is None: + speaker_id = self.speaker_manager.get_random_id() + else: + speaker_id = self.speaker_manager.name_to_id[speaker_name] + + # get language id + if self.language_manager is not None and config.use_language_embedding and language_name is not None: + language_id = self.language_manager.name_to_id[language_name] + + return { + "text": text, + "speaker_id": speaker_id, + "style_wav": style_wav, + "d_vector": d_vector, + "language_id": language_id, + } + + def format_batch(self, batch: Dict) -> Dict: + """Generic batch formatting for `VCDataset`. + + You must override this if you use a custom dataset. 
+ + Args: + batch (Dict): [description] + + Returns: + Dict: [description] + """ + # setup input batch + text_input = batch["token_id"] + text_lengths = batch["token_id_lengths"] + speaker_names = batch["speaker_names"] + linear_input = batch["linear"] + mel_input = batch["mel"] + mel_lengths = batch["mel_lengths"] + stop_targets = batch["stop_targets"] + item_idx = batch["item_idxs"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + attn_mask = batch["attns"] + waveform = batch["waveform"] + pitch = batch["pitch"] + energy = batch["energy"] + language_ids = batch["language_ids"] + max_text_length = torch.max(text_lengths.float()) + max_spec_length = torch.max(mel_lengths.float()) + + # compute durations from attention masks + durations = None + if attn_mask is not None: + durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2]) + for idx, am in enumerate(attn_mask): + # compute raw durations + c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1] + # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True) + c_idxs, counts = torch.unique(c_idxs, return_counts=True) + dur = torch.ones([text_lengths[idx]]).to(counts.dtype) + dur[c_idxs] = counts + # smooth the durations and set any 0 duration to 1 + # by cutting off from the largest duration indeces. + extra_frames = dur.sum() - mel_lengths[idx] + largest_idxs = torch.argsort(-dur)[:extra_frames] + dur[largest_idxs] -= 1 + assert ( + dur.sum() == mel_lengths[idx] + ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}" + durations[idx, : text_lengths[idx]] = dur + + # set stop targets wrt reduction factor + stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2) + stop_target_lengths = torch.divide(mel_lengths, self.config.r).ceil_() + + return { + "text_input": text_input, + "text_lengths": text_lengths, + "speaker_names": speaker_names, + "mel_input": mel_input, + "mel_lengths": mel_lengths, + "linear_input": linear_input, + "stop_targets": stop_targets, + "stop_target_lengths": stop_target_lengths, + "attn_mask": attn_mask, + "durations": durations, + "speaker_ids": speaker_ids, + "d_vectors": d_vectors, + "max_text_length": float(max_text_length), + "max_spec_length": float(max_spec_length), + "item_idx": item_idx, + "waveform": waveform, + "pitch": pitch, + "energy": energy, + "language_ids": language_ids, + "audio_unique_names": batch["audio_unique_names"], + } + + def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1): + weights = None + data_items = dataset.samples + + if getattr(config, "use_language_weighted_sampler", False): + alpha = getattr(config, "language_weighted_sampler_alpha", 1.0) + print(" > Using Language weighted sampler with alpha:", alpha) + weights = get_language_balancer_weights(data_items) * alpha + + if getattr(config, "use_speaker_weighted_sampler", False): + alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0) + print(" > Using Speaker weighted sampler with alpha:", alpha) + if weights is not None: + weights += get_speaker_balancer_weights(data_items) * alpha + else: + weights = get_speaker_balancer_weights(data_items) * alpha + + if getattr(config, "use_length_weighted_sampler", False): + alpha = getattr(config, "length_weighted_sampler_alpha", 1.0) + print(" > Using Length weighted sampler with alpha:", alpha) + if weights is not None: + weights += get_length_balancer_weights(data_items) * 
alpha + else: + weights = get_length_balancer_weights(data_items) * alpha + + if weights is not None: + sampler = WeightedRandomSampler(weights, len(weights)) + else: + sampler = None + + # sampler for DDP + if sampler is None: + sampler = DistributedSampler(dataset) if num_gpus > 1 else None + else: # If a sampler is already defined use this sampler and DDP sampler together + sampler = DistributedSamplerWrapper(sampler) if num_gpus > 1 else sampler + + return sampler + + def get_data_loader( + self, + config: Coqpit, + assets: Dict, + is_eval: bool, + samples: Union[List[Dict], List[List]], + verbose: bool, + num_gpus: int, + rank: int = None, + ) -> "DataLoader": + if is_eval and not config.run_eval: + loader = None + else: + # setup multi-speaker attributes + if self.speaker_manager is not None: + if hasattr(config, "model_args"): + speaker_id_mapping = ( + self.speaker_manager.name_to_id if config.model_args.use_speaker_embedding else None + ) + d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None + config.use_d_vector_file = config.model_args.use_d_vector_file + else: + speaker_id_mapping = self.speaker_manager.name_to_id if config.use_speaker_embedding else None + d_vector_mapping = self.speaker_manager.embeddings if config.use_d_vector_file else None + else: + speaker_id_mapping = None + d_vector_mapping = None + + # setup multi-lingual attributes + if self.language_manager is not None: + language_id_mapping = self.language_manager.name_to_id if self.args.use_language_embedding else None + else: + language_id_mapping = None + + # init dataloader + dataset = TTSDataset( + outputs_per_step=config.r if "r" in config else 1, + compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec, + compute_f0=config.get("compute_f0", False), + f0_cache_path=config.get("f0_cache_path", None), + compute_energy=config.get("compute_energy", False), + energy_cache_path=config.get("energy_cache_path", None), + samples=samples, + ap=self.ap, + return_wav=config.return_wav if "return_wav" in config else False, + batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, + min_text_len=config.min_text_len, + max_text_len=config.max_text_len, + min_audio_len=config.min_audio_len, + max_audio_len=config.max_audio_len, + phoneme_cache_path=config.phoneme_cache_path, + precompute_num_workers=config.precompute_num_workers, + use_noise_augment=False if is_eval else config.use_noise_augment, + verbose=verbose, + speaker_id_mapping=speaker_id_mapping, + d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, + tokenizer=None, + start_by_longest=config.start_by_longest, + language_id_mapping=language_id_mapping, + ) + + # wait all the DDP process to be ready + if num_gpus > 1: + dist.barrier() + + # sort input sequences from short to long + dataset.preprocess_samples() + + # get samplers + sampler = self.get_sampler(config, dataset, num_gpus) + + loader = DataLoader( + dataset, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + shuffle=config.shuffle if sampler is None else False, # if there is no other sampler + collate_fn=dataset.collate_fn, + drop_last=config.drop_last, # setting this False might cause issues in AMP training. 
+ sampler=sampler, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + return loader + + def _get_test_aux_input( + self, + ) -> Dict: + d_vector = None + if self.config.use_d_vector_file: + d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings] + d_vector = (random.sample(sorted(d_vector), 1),) + + aux_inputs = { + "speaker_id": None + if not self.config.use_speaker_embedding + else random.sample(sorted(self.speaker_manager.name_to_id.values()), 1), + "d_vector": d_vector, + "style_wav": None, # TODO: handle GST style input + } + return aux_inputs + + def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: + """Generic test run for `vc` models used by `Trainer`. + + You can override this for a different behaviour. + + Args: + assets (dict): A dict of training assets. For `vc` models, it must include `{'audio_processor': ap}`. + + Returns: + Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. + """ + print(" | > Synthesizing test sentences.") + test_audios = {} + test_figures = {} + test_sentences = self.config.test_sentences + aux_inputs = self._get_test_aux_input() + for idx, sen in enumerate(test_sentences): + if isinstance(sen, list): + aux_inputs = self.get_aux_input_from_test_sentences(sen) + sen = aux_inputs["text"] + outputs_dict = synthesis( + self, + sen, + self.config, + "cuda" in str(next(self.parameters()).device), + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + style_wav=aux_inputs["style_wav"], + use_griffin_lim=True, + do_trim_silence=False, + ) + test_audios["{}-audio".format(idx)] = outputs_dict["wav"] + test_figures["{}-prediction".format(idx)] = plot_spectrogram( + outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False + ) + test_figures["{}-alignment".format(idx)] = plot_alignment( + outputs_dict["outputs"]["alignments"], output_fig=False + ) + return test_figures, test_audios + + def on_init_start(self, trainer): + """Save the speaker.pth and language_ids.json at the beginning of the training. 
Also update both paths.""" + if self.speaker_manager is not None: + output_path = os.path.join(trainer.output_path, "speakers.pth") + self.speaker_manager.save_ids_to_file(output_path) + trainer.config.speakers_file = output_path + # some models don't have `model_args` set + if hasattr(trainer.config, "model_args"): + trainer.config.model_args.speakers_file = output_path + trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) + print(f" > `speakers.pth` is saved to {output_path}.") + print(" > `speakers_file` is updated in the config.json.") + + if self.language_manager is not None: + output_path = os.path.join(trainer.output_path, "language_ids.json") + self.language_manager.save_ids_to_file(output_path) + trainer.config.language_ids_file = output_path + if hasattr(trainer.config, "model_args"): + trainer.config.model_args.language_ids_file = output_path + trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) + print(f" > `language_ids.json` is saved to {output_path}.") + print(" > `language_ids_file` is updated in the config.json.") diff --git a/viXTTS/TTS/vc/models/freevc.py b/viXTTS/TTS/vc/models/freevc.py new file mode 100644 index 0000000000000000000000000000000000000000..8bb9989224215c8659bd5f2a04f74f07b5104d37 --- /dev/null +++ b/viXTTS/TTS/vc/models/freevc.py @@ -0,0 +1,562 @@ +from typing import Dict, List, Optional, Tuple, Union + +import librosa +import numpy as np +import torch +from coqpit import Coqpit +from torch import nn +from torch.nn import Conv1d, Conv2d, ConvTranspose1d +from torch.nn import functional as F +from torch.nn.utils import spectral_norm +from torch.nn.utils.parametrizations import weight_norm +from torch.nn.utils.parametrize import remove_parametrizations + +import TTS.vc.modules.freevc.commons as commons +import TTS.vc.modules.freevc.modules as modules +from TTS.tts.utils.speakers import SpeakerManager +from TTS.utils.io import load_fsspec +from TTS.vc.configs.freevc_config import FreeVCConfig +from TTS.vc.models.base_vc import BaseVC +from TTS.vc.modules.freevc.commons import get_padding, init_weights +from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch +from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx +from TTS.vc.modules.freevc.wavlm import get_wavlm + + +class ResidualCouplingBlock(nn.Module): + def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append( + modules.ResidualCouplingLayer( + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + mean_only=True, + ) + ) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class Encoder(nn.Module): + def __init__( + self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0 + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = 
hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + +class Generator(torch.nn.Module): + def __init__( + self, + initial_channel, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=0, + ): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) + resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x, g=None): + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, modules.LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_parametrizations(l, "weight") + for l in self.resblocks: + remove_parametrizations(l, "weight") + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), + ] + ) + self.conv_post = norm_f(Conv2d(1024, 1, 
(3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminator, self).__init__() + periods = [2, 3, 5, 7, 11] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class SpeakerEncoder(torch.nn.Module): + def __init__(self, mel_n_channels=80, model_num_layers=3, model_hidden_size=256, model_embedding_size=256): + super(SpeakerEncoder, self).__init__() + self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True) + self.linear = nn.Linear(model_hidden_size, model_embedding_size) + self.relu = nn.ReLU() + + def forward(self, mels): + self.lstm.flatten_parameters() + _, (hidden, _) = self.lstm(mels) + embeds_raw = self.relu(self.linear(hidden[-1])) + return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True) + + def compute_partial_slices(self, total_frames, partial_frames, partial_hop): + mel_slices = [] + for i in range(0, total_frames - partial_frames, partial_hop): + mel_range = torch.arange(i, i + partial_frames) + mel_slices.append(mel_range) + + return mel_slices + + def embed_utterance(self, mel, partial_frames=128, partial_hop=64): + mel_len = mel.size(1) + last_mel = mel[:, -partial_frames:] + + if mel_len > partial_frames: + mel_slices = self.compute_partial_slices(mel_len, partial_frames, partial_hop) + mels = list(mel[:, s] for s in mel_slices) + mels.append(last_mel) + mels = torch.stack(tuple(mels), 0).squeeze(1) + + with torch.no_grad(): + partial_embeds = self(mels) + embed = torch.mean(partial_embeds, axis=0).unsqueeze(0) + # embed = embed / torch.linalg.norm(embed, 2) + else: + with torch.no_grad(): + embed = self(last_mel) + + return embed + + +class FreeVC(BaseVC): + """ + + Papaer:: + 
https://arxiv.org/abs/2210.15418# + + Paper Abstract:: + Voice conversion (VC) can be achieved by first extracting source content information and target speaker + information, and then reconstructing waveform with these information. However, current approaches normally + either extract dirty content information with speaker information leaked in, or demand a large amount of + annotated data for training. Besides, the quality of reconstructed waveform can be degraded by the + mismatch between conversion model and vocoder. In this paper, we adopt the end-to-end framework of VITS for + high-quality waveform reconstruction, and propose strategies for clean content information extraction without + text annotation. We disentangle content information by imposing an information bottleneck to WavLM features, + and propose the spectrogram-resize based data augmentation to improve the purity of extracted content + information. Experimental results show that the proposed method outperforms the latest VC models trained with + annotated data and has greater robustness. + + Original Code:: + https://github.com/OlaWod/FreeVC + + Examples: + >>> from TTS.vc.configs.freevc_config import FreeVCConfig + >>> from TTS.vc.models.freevc import FreeVC + >>> config = FreeVCConfig() + >>> model = FreeVC(config) + """ + + def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None): + super().__init__(config, None, speaker_manager, None) + + self.init_multispeaker(config) + + self.spec_channels = self.args.spec_channels + self.inter_channels = self.args.inter_channels + self.hidden_channels = self.args.hidden_channels + self.filter_channels = self.args.filter_channels + self.n_heads = self.args.n_heads + self.n_layers = self.args.n_layers + self.kernel_size = self.args.kernel_size + self.p_dropout = self.args.p_dropout + self.resblock = self.args.resblock + self.resblock_kernel_sizes = self.args.resblock_kernel_sizes + self.resblock_dilation_sizes = self.args.resblock_dilation_sizes + self.upsample_rates = self.args.upsample_rates + self.upsample_initial_channel = self.args.upsample_initial_channel + self.upsample_kernel_sizes = self.args.upsample_kernel_sizes + self.segment_size = self.args.segment_size + self.gin_channels = self.args.gin_channels + self.ssl_dim = self.args.ssl_dim + self.use_spk = self.args.use_spk + + self.enc_p = Encoder(self.args.ssl_dim, self.inter_channels, self.hidden_channels, 5, 1, 16) + self.dec = Generator( + self.inter_channels, + self.resblock, + self.resblock_kernel_sizes, + self.resblock_dilation_sizes, + self.upsample_rates, + self.upsample_initial_channel, + self.upsample_kernel_sizes, + gin_channels=self.gin_channels, + ) + self.enc_q = Encoder( + self.spec_channels, self.inter_channels, self.hidden_channels, 5, 1, 16, gin_channels=self.gin_channels + ) + self.flow = ResidualCouplingBlock( + self.inter_channels, self.hidden_channels, 5, 1, 4, gin_channels=self.gin_channels + ) + if not self.use_spk: + self.enc_spk = SpeakerEncoder(model_hidden_size=self.gin_channels, model_embedding_size=self.gin_channels) + else: + self.load_pretrained_speaker_encoder() + + self.wavlm = get_wavlm() + + @property + def device(self): + return next(self.parameters()).device + + def load_pretrained_speaker_encoder(self): + """Load pretrained speaker encoder model as mentioned in the paper.""" + print(" > Loading pretrained speaker encoder model ...") + self.enc_spk_ex = SpeakerEncoderEx( + "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/speaker_encoder.pt" + ) + + def 
init_multispeaker(self, config: Coqpit): + """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer + or with external `d_vectors` computed from a speaker encoder model. + + You must provide a `speaker_manager` at initialization to set up the multi-speaker modules. + + Args: + config (Coqpit): Model configuration. + data (List, optional): Dataset items to infer number of speakers. Defaults to None. + """ + self.num_spks = self.args.num_spks + if self.speaker_manager: + self.num_spks = self.speaker_manager.num_spks + + def forward( + self, + c: torch.Tensor, + spec: torch.Tensor, + g: Optional[torch.Tensor] = None, + mel: Optional[torch.Tensor] = None, + c_lengths: Optional[torch.Tensor] = None, + spec_lengths: Optional[torch.Tensor] = None, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + ]: + """ + Forward pass of the model. + + Args: + c: WavLM features. Shape: (batch_size, c_seq_len). + spec: The input spectrogram. Shape: (batch_size, spec_seq_len, spec_dim). + g: The speaker embedding. Shape: (batch_size, spk_emb_dim). + mel: The input mel-spectrogram for the speaker encoder. Shape: (batch_size, mel_seq_len, mel_dim). + c_lengths: The lengths of the WavLM features. Shape: (batch_size,). + spec_lengths: The lengths of the spectrogram. Shape: (batch_size,). + + Returns: + o: The output spectrogram. Shape: (batch_size, spec_seq_len, spec_dim). + ids_slice: The slice indices. Shape: (batch_size, num_slices). + spec_mask: The spectrogram mask. Shape: (batch_size, spec_seq_len). + (z, z_p, m_p, logs_p, m_q, logs_q): A tuple of latent variables. + """ + + # If c_lengths is None, set it to the length of the last dimension of c + if c_lengths is None: + c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) + + # If spec_lengths is None, set it to the length of the last dimension of spec + if spec_lengths is None: + spec_lengths = (torch.ones(spec.size(0)) * spec.size(-1)).to(spec.device) + + # If use_spk is False, compute g from mel using enc_spk + g = None + if not self.use_spk: + g = self.enc_spk(mel).unsqueeze(-1) + + # Compute m_p, logs_p, z, m_q, logs_q, and spec_mask using enc_p and enc_q + _, m_p, logs_p, _ = self.enc_p(c, c_lengths) + z, m_q, logs_q, spec_mask = self.enc_q(spec.transpose(1, 2), spec_lengths, g=g) + + # Compute z_p using flow + z_p = self.flow(z, spec_mask, g=g) + + # Randomly slice z and compute o using dec + z_slice, ids_slice = commons.rand_slice_segments(z, spec_lengths, self.segment_size) + o = self.dec(z_slice, g=g) + + return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q) + + @torch.no_grad() + def inference(self, c, g=None, mel=None, c_lengths=None): + """ + Inference pass of the model + + Args: + c (torch.Tensor): Input tensor. Shape: (batch_size, c_seq_len). + g (torch.Tensor): Speaker embedding tensor. Shape: (batch_size, spk_emb_dim). + mel (torch.Tensor): Mel-spectrogram tensor. Shape: (batch_size, mel_seq_len, mel_dim). + c_lengths (torch.Tensor): Lengths of the input tensor. Shape: (batch_size,). + + Returns: + torch.Tensor: Output tensor. 
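+                Shape: (batch_size, 1, audio_seq_len).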
+ """ + if c_lengths == None: + c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) + if not self.use_spk: + g = self.enc_spk.embed_utterance(mel) + g = g.unsqueeze(-1) + z_p, m_p, logs_p, c_mask = self.enc_p(c, c_lengths) + z = self.flow(z_p, c_mask, g=g, reverse=True) + o = self.dec(z * c_mask, g=g) + return o + + def extract_wavlm_features(self, y): + """Extract WavLM features from an audio tensor. + + Args: + y (torch.Tensor): Audio tensor. Shape: (batch_size, audio_seq_len). + """ + + with torch.no_grad(): + c = self.wavlm.extract_features(y)[0] + c = c.transpose(1, 2) + return c + + def load_audio(self, wav): + """Read and format the input audio.""" + if isinstance(wav, str): + wav, _ = librosa.load(wav, sr=self.config.audio.input_sample_rate) + if isinstance(wav, np.ndarray): + wav = torch.from_numpy(wav).to(self.device) + if isinstance(wav, torch.Tensor): + wav = wav.to(self.device) + if isinstance(wav, list): + wav = torch.from_numpy(np.array(wav)).to(self.device) + return wav.float() + + @torch.inference_mode() + def voice_conversion(self, src, tgt): + """ + Voice conversion pass of the model. + + Args: + src (str or torch.Tensor): Source utterance. + tgt (str or torch.Tensor): Target utterance. + + Returns: + torch.Tensor: Output tensor. + """ + + wav_tgt = self.load_audio(tgt).cpu().numpy() + wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20) + + if self.config.model_args.use_spk: + g_tgt = self.enc_spk_ex.embed_utterance(wav_tgt) + g_tgt = torch.from_numpy(g_tgt)[None, :, None].to(self.device) + else: + wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(self.device) + mel_tgt = mel_spectrogram_torch( + wav_tgt, + self.config.audio.filter_length, + self.config.audio.n_mel_channels, + self.config.audio.input_sample_rate, + self.config.audio.hop_length, + self.config.audio.win_length, + self.config.audio.mel_fmin, + self.config.audio.mel_fmax, + ) + # src + wav_src = self.load_audio(src) + c = self.extract_wavlm_features(wav_src[None, :]) + + if self.config.model_args.use_spk: + audio = self.inference(c, g=g_tgt) + else: + audio = self.inference(c, mel=mel_tgt.transpose(1, 2)) + audio = audio[0][0].data.cpu().float().numpy() + return audio + + def eval_step(): + ... + + @staticmethod + def init_from_config(config: FreeVCConfig, samples: Union[List[List], List[Dict]] = None, verbose=True): + model = FreeVC(config) + return model + + def load_checkpoint(self, config, checkpoint_path, eval=False, strict=True, cache=False): + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"], strict=strict) + if eval: + self.eval() + + def train_step(): + ... 
diff --git a/viXTTS/TTS/vc/modules/__init__.py b/viXTTS/TTS/vc/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/vc/modules/freevc/__init__.py b/viXTTS/TTS/vc/modules/freevc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/vc/modules/freevc/commons.py b/viXTTS/TTS/vc/modules/freevc/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..e799cc2a5bea018706abe7556780d1102e5d0889 --- /dev/null +++ b/viXTTS/TTS/vc/modules/freevc/commons.py @@ -0,0 +1,164 @@ +import math + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def intersperse(lst, item): + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def kl_divergence(m_p, logs_p, m_q, logs_q): + """KL(P||Q)""" + kl = (logs_q - logs_p) - 0.5 + kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) + return kl + + +def rand_gumbel(shape): + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) + + +def rand_gumbel_like(x): + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def rand_spec_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + 
signal.to(dtype=x.dtype, device=x.device) + + +def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + + +def subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def shift_1d(x): + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + device = duration.device + + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2, 3) * mask + return path + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1.0 / norm_type) + return total_norm diff --git a/viXTTS/TTS/vc/modules/freevc/mel_processing.py b/viXTTS/TTS/vc/modules/freevc/mel_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..2dcbf214935a1fde832a32139145ce87fa752598 --- /dev/null +++ b/viXTTS/TTS/vc/modules/freevc/mel_processing.py @@ -0,0 +1,125 @@ +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + +MAX_WAV_VALUE = 32768.0 + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_size) + "_" + dtype_device + if wnsize_dtype_device not in hann_window: + 
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): + global mel_basis + dtype_device = str(spec.dtype) + "_" + str(spec.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = spectral_normalize_torch(spec) + return spec + + +def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + wnsize_dtype_device = str(win_size) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = spectral_normalize_torch(spec) + + return spec diff --git a/viXTTS/TTS/vc/modules/freevc/modules.py b/viXTTS/TTS/vc/modules/freevc/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..9bb549900382ba838fdeb30256681223c847e119 --- /dev/null +++ b/viXTTS/TTS/vc/modules/freevc/modules.py @@ -0,0 +1,387 @@ +import torch +from torch import nn +from torch.nn import Conv1d +from torch.nn import functional as F +from torch.nn.utils.parametrizations import weight_norm +from torch.nn.utils.parametrize import remove_parametrizations + +import TTS.vc.modules.freevc.commons as commons +from TTS.vc.modules.freevc.commons import get_padding, init_weights + +LRELU_SLOPE = 0.1 + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class ConvReluNorm(nn.Module): + def __init__(self, 
in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + assert n_layers > 1, "Number of layers should be larger than 0." + + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DDSConv(nn.Module): + """ + Dialted and Depth-Separable Convolution + """ + + def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.drop = nn.Dropout(p_dropout) + self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(n_layers): + dilation = kernel_size**i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append( + nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding) + ) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm(channels)) + self.norms_2.append(LayerNorm(channels)) + + def forward(self, x, x_mask, g=None): + if g is not None: + x = x + g + for i in range(self.n_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.drop(y) + x = x + y + return x * x_mask + + +class WN(torch.nn.Module): + def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): + super(WN, self).__init__() + assert kernel_size % 2 == 1 + self.hidden_channels = hidden_channels + self.kernel_size = (kernel_size,) + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.drop = nn.Dropout(p_dropout) + + if gin_channels != 0: + cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) + self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight") + + for i in range(n_layers): + dilation = dilation_rate**i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = torch.nn.Conv1d( + hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding + ) + in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight") + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + 
res_skip_channels = hidden_channels + + res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight") + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x, x_mask, g=None, **kwargs): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + if g is not None: + g = self.cond_layer(g) + + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] + else: + g_l = torch.zeros_like(x_in) + + acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) + acts = self.drop(acts) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:, : self.hidden_channels, :] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:, self.hidden_channels :, :] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.gin_channels != 0: + remove_parametrizations(self.cond_layer, "weight") + for l in self.in_layers: + remove_parametrizations(l, "weight") + for l in self.res_skip_layers: + remove_parametrizations(l, "weight") + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x, x_mask=None): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c2(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_parametrizations(l, "weight") + for l in self.convs2: + remove_parametrizations(l, "weight") + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x, x_mask=None): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + if 
x_mask is not None: + xt = xt * x_mask + xt = c(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_parametrizations(l, "weight") + + +class Log(nn.Module): + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = torch.exp(x) * x_mask + return x + + +class Flip(nn.Module): + def forward(self, x, *args, reverse=False, **kwargs): + x = torch.flip(x, [1]) + if not reverse: + logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) + return x, logdet + else: + return x + + +class ElementwiseAffine(nn.Module): + def __init__(self, channels): + super().__init__() + self.channels = channels + self.m = nn.Parameter(torch.zeros(channels, 1)) + self.logs = nn.Parameter(torch.zeros(channels, 1)) + + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = self.m + torch.exp(self.logs) * x + y = y * x_mask + logdet = torch.sum(self.logs * x_mask, [1, 2]) + return y, logdet + else: + x = (x - self.m) * torch.exp(-self.logs) * x_mask + return x + + +class ResidualCouplingLayer(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=0, + gin_channels=0, + mean_only=False, + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = WN( + hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels + ) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x diff --git a/viXTTS/TTS/vc/modules/freevc/speaker_encoder/__init__.py b/viXTTS/TTS/vc/modules/freevc/speaker_encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/vc/modules/freevc/speaker_encoder/audio.py b/viXTTS/TTS/vc/modules/freevc/speaker_encoder/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..52f6fd0893a1e67270ccd6af4a50d1370a8fa5c4 --- /dev/null +++ b/viXTTS/TTS/vc/modules/freevc/speaker_encoder/audio.py @@ -0,0 +1,65 @@ +import struct +from pathlib import Path +from typing import Optional, Union + +# import webrtcvad +import librosa +import numpy as np +from scipy.ndimage.morphology import binary_dilation + +from TTS.vc.modules.freevc.speaker_encoder.hparams import * + +int16_max = (2**15) - 1 + + +def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray], source_sr: 
Optional[int] = None): + """ + Applies the preprocessing operations used in training the Speaker Encoder to a waveform + either on disk or in memory. The waveform will be resampled to match the data hyperparameters. + + :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not + just .wav), either the waveform as a numpy array of floats. + :param source_sr: if passing an audio waveform, the sampling rate of the waveform before + preprocessing. After preprocessing, the waveform's sampling rate will match the data + hyperparameters. If passing a filepath, the sampling rate will be automatically detected and + this argument will be ignored. + """ + # Load the wav from disk if needed + if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path): + wav, source_sr = librosa.load(fpath_or_wav, sr=None) + else: + wav = fpath_or_wav + + # Resample the wav if needed + if source_sr is not None and source_sr != sampling_rate: + wav = librosa.resample(wav, source_sr, sampling_rate) + + # Apply the preprocessing: normalize volume and shorten long silences + wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True) + wav = trim_long_silences(wav) + + return wav + + +def wav_to_mel_spectrogram(wav): + """ + Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform. + Note: this not a log-mel spectrogram. + """ + frames = librosa.feature.melspectrogram( + y=wav, + sr=sampling_rate, + n_fft=int(sampling_rate * mel_window_length / 1000), + hop_length=int(sampling_rate * mel_window_step / 1000), + n_mels=mel_n_channels, + ) + return frames.astype(np.float32).T + + +def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False): + if increase_only and decrease_only: + raise ValueError("Both increase only and decrease only are set") + dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav**2)) + if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only): + return wav + return wav * (10 ** (dBFS_change / 20)) diff --git a/viXTTS/TTS/vc/modules/freevc/speaker_encoder/hparams.py b/viXTTS/TTS/vc/modules/freevc/speaker_encoder/hparams.py new file mode 100644 index 0000000000000000000000000000000000000000..2c536ae16cf8134d66c83aaf978ed01fc396b680 --- /dev/null +++ b/viXTTS/TTS/vc/modules/freevc/speaker_encoder/hparams.py @@ -0,0 +1,31 @@ +## Mel-filterbank +mel_window_length = 25 # In milliseconds +mel_window_step = 10 # In milliseconds +mel_n_channels = 40 + + +## Audio +sampling_rate = 16000 +# Number of spectrogram frames in a partial utterance +partials_n_frames = 160 # 1600 ms + + +## Voice Activation Detection +# Window size of the VAD. Must be either 10, 20 or 30 milliseconds. +# This sets the granularity of the VAD. Should not need to be changed. +vad_window_length = 30 # In milliseconds +# Number of frames to average together when performing the moving average smoothing. +# The larger this value, the larger the VAD variations must be to not get smoothed out. +vad_moving_average_width = 8 +# Maximum number of consecutive silent frames a segment can have. 
+vad_max_silence_length = 6 + + +## Audio volume normalization +audio_norm_target_dBFS = -30 + + +## Model parameters +model_hidden_size = 256 +model_embedding_size = 256 +model_num_layers = 3 diff --git a/viXTTS/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py b/viXTTS/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..2e21a14fd833ecf4f1b1d5b8573ac100dffacbfa --- /dev/null +++ b/viXTTS/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py @@ -0,0 +1,175 @@ +from pathlib import Path +from time import perf_counter as timer +from typing import List, Union + +import numpy as np +import torch +from torch import nn + +from TTS.utils.io import load_fsspec +from TTS.vc.modules.freevc.speaker_encoder import audio +from TTS.vc.modules.freevc.speaker_encoder.hparams import * + + +class SpeakerEncoder(nn.Module): + def __init__(self, weights_fpath, device: Union[str, torch.device] = None, verbose=True): + """ + :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). + If None, defaults to cuda if it is available on your machine, otherwise the model will + run on cpu. Outputs are always returned on the cpu, as numpy arrays. + """ + super().__init__() + + # Define the network + self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True) + self.linear = nn.Linear(model_hidden_size, model_embedding_size) + self.relu = nn.ReLU() + + # Get the target device + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + elif isinstance(device, str): + device = torch.device(device) + self.device = device + + # Load the pretrained model'speaker weights + # weights_fpath = Path(__file__).resolve().parent.joinpath("pretrained.pt") + # if not weights_fpath.exists(): + # raise Exception("Couldn't find the voice encoder pretrained model at %s." % + # weights_fpath) + + start = timer() + checkpoint = load_fsspec(weights_fpath, map_location="cpu") + + self.load_state_dict(checkpoint["model_state"], strict=False) + self.to(device) + + if verbose: + print("Loaded the voice encoder model on %s in %.2f seconds." % (device.type, timer() - start)) + + def forward(self, mels: torch.FloatTensor): + """ + Computes the embeddings of a batch of utterance spectrograms. + :param mels: a batch of mel spectrograms of same duration as a float32 tensor of shape + (batch_size, n_frames, n_channels) + :return: the embeddings as a float 32 tensor of shape (batch_size, embedding_size). + Embeddings are positive and L2-normed, thus they lay in the range [0, 1]. + """ + # Pass the input through the LSTM layers and retrieve the final hidden state of the last + # layer. Apply a cutoff to 0 for negative values and L2 normalize the embeddings. + _, (hidden, _) = self.lstm(mels) + embeds_raw = self.relu(self.linear(hidden[-1])) + return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True) + + @staticmethod + def compute_partial_slices(n_samples: int, rate, min_coverage): + """ + Computes where to split an utterance waveform and its corresponding mel spectrogram to + obtain partial utterances of each. Both the waveform and the + mel spectrogram slices are returned, so as to make each partial utterance waveform + correspond to its spectrogram. + + The returned ranges may be indexing further than the length of the waveform. It is + recommended that you pad the waveform with zeros up to wav_slices[-1].stop. 
+ + :param n_samples: the number of samples in the waveform + :param rate: how many partial utterances should occur per second. Partial utterances must + cover the span of the entire utterance, thus the rate should not be lower than the inverse + of the duration of a partial utterance. By default, partial utterances are 1.6s long and + the minimum rate is thus 0.625. + :param min_coverage: when reaching the last partial utterance, it may or may not have + enough frames. If at least <min_coverage> of <partials_n_frames> are present, + then the last partial utterance will be considered by zero-padding the audio. Otherwise, + it will be discarded. If there aren't enough frames for one partial utterance, + this parameter is ignored so that the function always returns at least one slice. + :return: the waveform slices and mel spectrogram slices as lists of array slices. Index + respectively the waveform and the mel spectrogram with these slices to obtain the partial + utterances. + """ + assert 0 < min_coverage <= 1 + + # Compute how many frames separate two partial utterances + samples_per_frame = int((sampling_rate * mel_window_step / 1000)) + n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) + frame_step = int(np.round((sampling_rate / rate) / samples_per_frame)) + assert 0 < frame_step, "The rate is too high" + assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % ( + sampling_rate / (samples_per_frame * partials_n_frames) + ) + + # Compute the slices + wav_slices, mel_slices = [], [] + steps = max(1, n_frames - partials_n_frames + frame_step + 1) + for i in range(0, steps, frame_step): + mel_range = np.array([i, i + partials_n_frames]) + wav_range = mel_range * samples_per_frame + mel_slices.append(slice(*mel_range)) + wav_slices.append(slice(*wav_range)) + + # Evaluate whether extra padding is warranted or not + last_wav_range = wav_slices[-1] + coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start) + if coverage < min_coverage and len(mel_slices) > 1: + mel_slices = mel_slices[:-1] + wav_slices = wav_slices[:-1] + + return wav_slices, mel_slices + + def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75): + """ + Computes an embedding for a single utterance. The utterance is divided in partial + utterances and an embedding is computed for each. The complete utterance embedding is the + L2-normed average embedding of the partial utterances. + + TODO: independent batched version of this function + + :param wav: a preprocessed utterance waveform as a numpy array of float32 + :param return_partials: if True, the partial embeddings will also be returned along with + the wav slices corresponding to each partial utterance. + :param rate: how many partial utterances should occur per second. Partial utterances must + cover the span of the entire utterance, thus the rate should not be lower than the inverse + of the duration of a partial utterance. By default, partial utterances are 1.6s long and + the minimum rate is thus 0.625. + :param min_coverage: when reaching the last partial utterance, it may or may not have + enough frames. If at least <min_coverage> of <partials_n_frames> are present, + then the last partial utterance will be considered by zero-padding the audio. Otherwise, + it will be discarded. If there aren't enough frames for one partial utterance, + this parameter is ignored so that the function always returns at least one slice. + :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). 
If + is True, the partial utterances as a numpy array of float32 of shape + (n_partials, model_embedding_size) and the wav partials as a list of slices will also be + returned. + """ + # Compute where to split the utterance into partials and pad the waveform with zeros if + # the partial utterances cover a larger range. + wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage) + max_wave_length = wav_slices[-1].stop + if max_wave_length >= len(wav): + wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant") + + # Split the utterance into partials and forward them through the model + mel = audio.wav_to_mel_spectrogram(wav) + mels = np.array([mel[s] for s in mel_slices]) + with torch.no_grad(): + mels = torch.from_numpy(mels).to(self.device) + partial_embeds = self(mels).cpu().numpy() + + # Compute the utterance embedding from the partial embeddings + raw_embed = np.mean(partial_embeds, axis=0) + embed = raw_embed / np.linalg.norm(raw_embed, 2) + + if return_partials: + return embed, partial_embeds, wav_slices + return embed + + def embed_speaker(self, wavs: List[np.ndarray], **kwargs): + """ + Compute the embedding of a collection of wavs (presumably from the same speaker) by + averaging their embedding and L2-normalizing it. + + :param wavs: list of wavs a numpy arrays of float32. + :param kwargs: extra arguments to embed_utterance() + :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). + """ + raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) for wav in wavs], axis=0) + return raw_embed / np.linalg.norm(raw_embed, 2) diff --git a/viXTTS/TTS/vc/modules/freevc/wavlm/__init__.py b/viXTTS/TTS/vc/modules/freevc/wavlm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6edada407b2210b5b99f6628e4f765a24c4d3dcb --- /dev/null +++ b/viXTTS/TTS/vc/modules/freevc/wavlm/__init__.py @@ -0,0 +1,35 @@ +import os +import urllib.request + +import torch + +from TTS.utils.generic_utils import get_user_data_dir +from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig + +model_uri = "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/WavLM-Large.pt" + + +def get_wavlm(device="cpu"): + """Download the model and return the model object.""" + + output_path = get_user_data_dir("tts") + + output_path = os.path.join(output_path, "wavlm") + if not os.path.exists(output_path): + os.makedirs(output_path) + + output_path = os.path.join(output_path, "WavLM-Large.pt") + if not os.path.exists(output_path): + print(f" > Downloading WavLM model to {output_path} ...") + urllib.request.urlretrieve(model_uri, output_path) + + checkpoint = torch.load(output_path, map_location=torch.device(device)) + cfg = WavLMConfig(checkpoint["cfg"]) + wavlm = WavLM(cfg).to(device) + wavlm.load_state_dict(checkpoint["model"]) + wavlm.eval() + return wavlm + + +if __name__ == "__main__": + wavlm = get_wavlm() diff --git a/viXTTS/TTS/vc/modules/freevc/wavlm/config.json b/viXTTS/TTS/vc/modules/freevc/wavlm/config.json new file mode 100644 index 0000000000000000000000000000000000000000..c6f851b93d5eee0c975af1e6708735bbf9bb8be4 --- /dev/null +++ b/viXTTS/TTS/vc/modules/freevc/wavlm/config.json @@ -0,0 +1,99 @@ +{ + "_name_or_path": "./wavlm-large/", + "activation_dropout": 0.0, + "adapter_kernel_size": 3, + "adapter_stride": 2, + "add_adapter": false, + "apply_spec_augment": true, + "architectures": [ + "WavLMModel" + ], + "attention_dropout": 0.1, + "bos_token_id": 1, + "classifier_proj_size": 256, + 
"codevector_dim": 768, + "contrastive_logits_temperature": 0.1, + "conv_bias": false, + "conv_dim": [ + 512, + 512, + 512, + 512, + 512, + 512, + 512 + ], + "conv_kernel": [ + 10, + 3, + 3, + 3, + 3, + 2, + 2 + ], + "conv_stride": [ + 5, + 2, + 2, + 2, + 2, + 2, + 2 + ], + "ctc_loss_reduction": "sum", + "ctc_zero_infinity": false, + "diversity_loss_weight": 0.1, + "do_stable_layer_norm": true, + "eos_token_id": 2, + "feat_extract_activation": "gelu", + "feat_extract_dropout": 0.0, + "feat_extract_norm": "layer", + "feat_proj_dropout": 0.1, + "feat_quantizer_dropout": 0.0, + "final_dropout": 0.0, + "gradient_checkpointing": false, + "hidden_act": "gelu", + "hidden_dropout": 0.1, + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 4096, + "layer_norm_eps": 1e-05, + "layerdrop": 0.1, + "mask_channel_length": 10, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.0, + "mask_channel_selection": "static", + "mask_feature_length": 10, + "mask_feature_min_masks": 0, + "mask_feature_prob": 0.0, + "mask_time_length": 10, + "mask_time_min_masks": 2, + "mask_time_min_space": 1, + "mask_time_other": 0.0, + "mask_time_prob": 0.075, + "mask_time_selection": "static", + "max_bucket_distance": 800, + "model_type": "wavlm", + "num_adapter_layers": 3, + "num_attention_heads": 16, + "num_buckets": 320, + "num_codevector_groups": 2, + "num_codevectors_per_group": 320, + "num_conv_pos_embedding_groups": 16, + "num_conv_pos_embeddings": 128, + "num_ctc_classes": 80, + "num_feat_extract_layers": 7, + "num_hidden_layers": 24, + "num_negatives": 100, + "output_hidden_size": 1024, + "pad_token_id": 0, + "proj_codevector_dim": 768, + "replace_prob": 0.5, + "tokenizer_class": "Wav2Vec2CTCTokenizer", + "torch_dtype": "float32", + "transformers_version": "4.15.0.dev0", + "use_weighted_layer_sum": false, + "vocab_size": 32 + } \ No newline at end of file diff --git a/viXTTS/TTS/vc/modules/freevc/wavlm/modules.py b/viXTTS/TTS/vc/modules/freevc/wavlm/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..37c1a6e8774cdfd439baa38a8a7ad55fd79ebf7c --- /dev/null +++ b/viXTTS/TTS/vc/modules/freevc/wavlm/modules.py @@ -0,0 +1,768 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import warnings +from typing import Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn import Parameter + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None): + super().__init__() + self.deconstruct_idx = deconstruct_idx + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(-2, -1) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class Fp32GroupNorm(nn.GroupNorm): 
+ def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.group_norm( + input.float(), + self.num_groups, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class Swish(nn.Module): + """Swish function""" + + def __init__(self): + """Construct an MultiHeadedAttention object.""" + super(Swish, self).__init__() + self.act = torch.nn.Sigmoid() + + def forward(self, x): + return x * self.act(x) + + +class GLU_Linear(nn.Module): + def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True): + super(GLU_Linear, self).__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + + if glu_type == "sigmoid": + self.glu_act = torch.nn.Sigmoid() + elif glu_type == "swish": + self.glu_act = Swish() + elif glu_type == "relu": + self.glu_act = torch.nn.ReLU() + elif glu_type == "gelu": + self.glu_act = torch.nn.GELU() + + if bias_in_glu: + self.linear = nn.Linear(input_dim, output_dim * 2, True) + else: + self.linear = nn.Linear(input_dim, output_dim * 2, False) + + def forward(self, x): + # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case + x = self.linear(x) + + if self.glu_type == "bilinear": + x = x[:, :, 0 : self.output_dim] * x[:, :, self.output_dim : self.output_dim * 2] + else: + x = x[:, :, 0 : self.output_dim] * self.glu_act(x[:, :, self.output_dim : self.output_dim * 2]) + + return x + + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + + +def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) + + +def get_activation_fn(activation: str): + """Returns the activation function corresponding to `activation`""" + + if activation == "relu": + return F.relu + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + warnings.warn("--activation-fn=gelu_fast has been renamed to gelu_accurate") + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + elif activation == "glu": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +def init_bert_params(module): + """ + Initialize the weights specific to the BERT Model. + This overrides the default initializations depending on the specified arguments. + 1. If normal_init_linear_weights is set then weights of linear + layer will be initialized using the normal distribution and + bais will be set to the specified value. + 2. 
If normal_init_embed_weights is set then weights of embedding + layer will be initialized using the normal distribution. + 3. If normal_init_proj_weights is set then weights of + in_project_weight for MultiHeadAttention initialized using + the normal distribution (to be validated). + """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert module.weight.size(1) % block_size == 0, "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert module.in_channels % block_size == 0, "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros(in_features // block_size * out_features, device=weight.device) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device) + mask.bernoulli_(p) + mask = 
mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + + # scale weights and apply mask + mask = mask.to(torch.bool) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module + + +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + has_relative_attention_bias=False, + num_buckets=32, + max_distance=128, + gru_rel_pos=False, + rescale_init=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = nn.Dropout(dropout) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + + self.head_dim = embed_dim // num_heads + self.q_head_dim = self.head_dim + self.k_head_dim = self.head_dim + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim**-0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + k_bias = True + if rescale_init: + k_bias = False + + k_embed_dim = embed_dim + q_embed_dim = embed_dim + + self.k_proj = quant_noise(nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size) + self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size) + self.q_proj = quant_noise(nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size) + + self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.grep_linear = nn.Linear(self.q_head_dim, 8) + self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) + + self.reset_parameters() + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + 
nn.init.xavier_normal_(self.bias_v) + if self.has_relative_attention_bias: + nn.init.xavier_normal_(self.relative_attention_bias.weight) + + def _relative_positions_bucket(self, relative_positions, bidirectional=True): + num_buckets = self.num_buckets + max_distance = self.max_distance + relative_buckets = 0 + + if bidirectional: + num_buckets = num_buckets // 2 + relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) + values = values.permute([2, 0, 1]) + return values + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + position_bias: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. 
+ """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if self.has_relative_attention_bias and position_bias is None: + position_bias = self.compute_bias(tgt_len, src_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len) + + if ( + not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting() + and self.q_head_dim == self.head_dim + ): + assert key is not None and value is not None + assert attn_mask is None + + attn_mask_rel_pos = None + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos: + query_layer = query.transpose(0, 1) + new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1) + query_layer = query_layer.view(*new_x_shape) + query_layer = query_layer.permute(0, 2, 1, 3) + _B, _H, _L, __ = query_layer.size() + + gate_a, gate_b = torch.sigmoid( + self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False) + ).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len)) + k_proj_bias = self.k_proj.bias + if k_proj_bias is None: + k_proj_bias = torch.zeros_like(self.q_proj.bias) + + x, attn = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training, + # self.training or self.dropout_module.apply_during_inference, + key_padding_mask, + need_weights, + attn_mask_rel_pos, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + return x, attn, position_bias + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = 
torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.q_head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * self.num_heads, self.k_head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. 
+ if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v, position_bias + + if position_bias is not None: + if self.gru_rel_pos == 1: + query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) + _B, _H, _L, __ = query_layer.size() + gate_a, gate_b = torch.sigmoid( + self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False) + ).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + position_bias = position_bias.view(attn_weights.size()) + + attn_weights = attn_weights + position_bias + + attn_weights_float = F.softmax(attn_weights, dim=-1) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights, position_bias + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or 
current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights diff --git a/viXTTS/TTS/vc/modules/freevc/wavlm/wavlm.py b/viXTTS/TTS/vc/modules/freevc/wavlm/wavlm.py new file mode 100644 index 0000000000000000000000000000000000000000..fc93bd4f50773582c9fca5497ed639b2cfe23058 --- /dev/null +++ b/viXTTS/TTS/vc/modules/freevc/wavlm/wavlm.py @@ -0,0 +1,719 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import logging +import math +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import LayerNorm + +from TTS.vc.modules.freevc.wavlm.modules import ( + Fp32GroupNorm, + Fp32LayerNorm, + GLU_Linear, + GradMultiply, + MultiheadAttention, + SamePad, + TransposeLast, + get_activation_fn, + init_bert_params, +) + +logger = logging.getLogger(__name__) + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. 
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])]) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + + return mask + + +class WavLMConfig: + def __init__(self, cfg=None): + self.extractor_mode: str = "default" # mode for feature extractor. 
default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True) + self.encoder_layers: int = 12 # num encoder layers in the transformer + + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...] + self.conv_bias: bool = False # include bias in conv encoder + self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this + + self.normalize: bool = False # normalize input to have 0 mean and unit variance during training + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr) + + # masking + self.mask_length: int = 10 # mask length + self.mask_prob: float = 0.65 # probability of replacing a token with mask + self.mask_selection: str = "static" # how to choose mask length + self.mask_other: float = ( + 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh + ) + self.no_mask_overlap: bool = False # whether to allow masks to overlap + self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # channel masking + self.mask_channel_length: int = 10 # length of the mask for features (channels) + self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0 + self.mask_channel_selection: str = "static" # how to choose mask length for channel masking + self.mask_channel_other: float = ( + 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices + ) + self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap + self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class WavLM(nn.Module): + def __init__( + self, + cfg: WavLMConfig, + ) -> None: + super().__init__() + logger.info(f"WavLM Config: {cfg.__dict__}") + + self.cfg = cfg + feature_enc_layers = 
eval(cfg.conv_feature_layers) + self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None + ) + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + + self.mask_emb = nn.Parameter(torch.FloatTensor(cfg.encoder_embed_dim).uniform_()) + + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + def apply_mask(self, x, padding_mask): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = torch.from_numpy(mask_channel_indices).to(x.device).unsqueeze(1).expand(-1, T, -1) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def forward_padding_mask( + self, + features: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1) + # padding_mask = padding_mask.all(-1) + padding_mask = padding_mask.any(-1) + return padding_mask + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ret_layer_results: bool = False, + ): + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + + if mask: + x, mask_indices = self.apply_mask(features, padding_mask) + else: + x = features + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), 
float + # padding_mask: (B, T), bool + # mask_indices: (B, T), bool + x, layer_results = self.encoder( + x, padding_mask=padding_mask, layer=None if output_layer is None else output_layer - 1 + ) + + res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results} + + feature = res["features"] if ret_conv else res["x"] + if ret_layer_results: + feature = (feature, res["layer_results"]) + return feature, res["padding_mask"] + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + conv_type: str = "default", + ): + super().__init__() + + assert mode in {"default", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert (is_layer_norm and is_group_norm) == False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=True), + TransposeLast(), + ), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + Fp32GroupNorm(dim, dim, affine=True), + nn.GELU(), + ) + else: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + self.conv_type = conv_type + if self.conv_type == "default": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + elif self.conv_type == "conv2d": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + + self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride)) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + elif self.conv_type == "custom": + in_d = 1 + idim = 80 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride, padding=1)) + self.conv_layers.append(torch.nn.LayerNorm([dim, idim])) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + if (i + 1) % 2 == 0: + self.conv_layers.append(torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)) + idim = int(math.ceil(idim / 2)) + else: + pass + + def forward(self, x, mask=None): + # BxT -> BxCxT + x = x.unsqueeze(1) + if self.conv_type == "custom": + for conv in self.conv_layers: + if isinstance(conv, nn.LayerNorm): + x = x.transpose(1, 2) + x = conv(x).transpose(1, 2) + else: + x = conv(x) + x = x.transpose(2, 3).contiguous() + x = x.view(x.size(0), -1, x.size(-1)) + else: + for conv in self.conv_layers: + x = conv(x) + if self.conv_type == "conv2d": + b, c, t, f = x.size() + x = x.transpose(2, 3).contiguous().view(b, c * f, t) + return x + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + 
kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.parametrizations.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + if hasattr(args, "relative_position_embedding"): + self.relative_position_embedding = args.relative_position_embedding + self.num_buckets = args.num_buckets + self.max_distance = args.max_distance + else: + self.relative_position_embedding = False + self.num_buckets = 0 + self.max_distance = 0 + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + has_relative_attention_bias=(self.relative_position_embedding and i == 0), + num_buckets=self.num_buckets, + max_distance=self.max_distance, + gru_rel_pos=args.gru_rel_pos, + ) + for i in range(args.encoder_layers) + ] + ) + + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + def forward(self, x, padding_mask=None, streaming_mask=None, layer=None): + x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None): + if padding_mask is not None: + x[padding_mask] = 0 + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x += x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + z = None + if tgt_layer is not None: + layer_results.append((x, z)) + r = None + pos_bias = None + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z, pos_bias = layer( + x, + self_attn_padding_mask=padding_mask, + need_weights=False, + self_attn_mask=streaming_mask, + pos_bias=pos_bias, + ) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, layer_results + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. 
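+    When `has_relative_attention_bias` or `gru_rel_pos` is enabled, the self-attention block
+    additionally applies WavLM's (optionally gated) relative position bias.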
+ """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + has_relative_attention_bias: bool = False, + num_buckets: int = 0, + max_distance: int = 0, + rescale_init: bool = False, + gru_rel_pos: bool = False, + ) -> None: + super().__init__() + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_name = activation_fn + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + has_relative_attention_bias=has_relative_attention_bias, + num_buckets=num_buckets, + max_distance=max_distance, + rescale_init=rescale_init, + gru_rel_pos=gru_rel_pos, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + + if self.activation_name == "glu": + self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") + else: + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + pos_bias=None, + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. + """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + position_bias=pos_bias, + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + attn_mask=self_attn_mask, + position_bias=pos_bias, + ) + + x = self.dropout1(x) + x = residual + x + + x = self.self_attn_layer_norm(x) + + residual = x + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + x = self.final_layer_norm(x) + + return x, attn, pos_bias diff --git a/viXTTS/TTS/vocoder/README.md b/viXTTS/TTS/vocoder/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b9fb17c8f09fa6e8c217087e31fb8c52d96da536 --- /dev/null +++ b/viXTTS/TTS/vocoder/README.md @@ -0,0 +1,39 @@ +# Mozilla TTS Vocoders (Experimental) + +Here there are vocoder model implementations which can be combined with the other TTS models. 
+
+Currently, the following models are implemented:
+
+- MelGAN
+- MultiBand-MelGAN
+- ParallelWaveGAN
+- GAN-TTS (discriminator only)
+
+It is also easy to adapt other vocoder models, as we provide a flexible and modular (but not too modular) framework.
+
+## Training a model
+
+An example [Colab Notebook]() training MelGAN on the LJSpeech dataset will be linked here soon.
+
+In order to train a new model, gather all wav files into a folder and set that folder as `data_path` in your `config.json`.
+
+Define the other relevant parameters in your `config.json` and then start training with the following command.
+
+```CUDA_VISIBLE_DEVICES='0' python tts/bin/train_vocoder.py --config_path path/to/config.json```
+
+Example config files can be found under the `tts/vocoder/configs/` folder.
+
+You can continue a previous training run with the following command.
+
+```CUDA_VISIBLE_DEVICES='0' python tts/bin/train_vocoder.py --continue_path path/to/your/model/folder```
+
+You can fine-tune a pre-trained model with the following command.
+
+```CUDA_VISIBLE_DEVICES='0' python tts/bin/train_vocoder.py --restore_path path/to/your/model.pth```
+
+Restoring a model starts a new training run in a different folder; it only loads the model weights from the given checkpoint file. Continuing a training run, in contrast, resumes from the directory where the previous run left off.
+
+You can also follow your training runs on TensorBoard, just as with our TTS models.
+
+## Acknowledgement
+Thanks to @kan-bayashi; his [repository](https://github.com/kan-bayashi/ParallelWaveGAN) was the starting point of our work.
diff --git a/viXTTS/TTS/vocoder/__init__.py b/viXTTS/TTS/vocoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/viXTTS/TTS/vocoder/__pycache__/__init__.cpython-310.pyc b/viXTTS/TTS/vocoder/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e6af2e2e4801addc88a646b15200753e3ca56f2f
Binary files /dev/null and b/viXTTS/TTS/vocoder/__pycache__/__init__.cpython-310.pyc differ
diff --git a/viXTTS/TTS/vocoder/configs/__init__.py b/viXTTS/TTS/vocoder/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5e11b990c6d7294e7cb00c3e024bbb5f94a8105
--- /dev/null
+++ b/viXTTS/TTS/vocoder/configs/__init__.py
@@ -0,0 +1,17 @@
+import importlib
+import os
+from inspect import isclass
+
+# import all files under configs/
+configs_dir = os.path.dirname(__file__)
+for file in os.listdir(configs_dir):
+    path = os.path.join(configs_dir, file)
+    if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)):
+        config_name = file[: file.find(".py")] if file.endswith(".py") else file
+        module = importlib.import_module("TTS.vocoder.configs."
+ config_name) + for attribute_name in dir(module): + attribute = getattr(module, attribute_name) + + if isclass(attribute): + # Add the class to this package's variables + globals()[attribute_name] = attribute diff --git a/viXTTS/TTS/vocoder/configs/fullband_melgan_config.py b/viXTTS/TTS/vocoder/configs/fullband_melgan_config.py new file mode 100644 index 0000000000000000000000000000000000000000..2ab83aace678e328a8f99a5f0dc63e54ed99d4c4 --- /dev/null +++ b/viXTTS/TTS/vocoder/configs/fullband_melgan_config.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass, field + +from .shared_configs import BaseGANVocoderConfig + + +@dataclass +class FullbandMelganConfig(BaseGANVocoderConfig): + """Defines parameters for FullBand MelGAN vocoder. + + Example: + + >>> from TTS.vocoder.configs import FullbandMelganConfig + >>> config = FullbandMelganConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `fullband_melgan`. + discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to + 'melgan_multiscale_discriminator`. + discriminator_model_params (dict): The discriminator model parameters. Defaults to + '{"base_channels": 16, "max_channels": 1024, "downsample_factors": [4, 4, 4, 4]}` + generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is + considered as a generator too. Defaults to `melgan_generator`. + batch_size (int): + Batch size used at training. Larger values use more memory. Defaults to 16. + seq_len (int): + Audio segment length used at training. Larger values use more memory. Defaults to 8192. + pad_short (int): + Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0. + use_noise_augment (bool): + enable / disable random noise added to the input waveform. The noise is added after computing the + features. Defaults to True. + use_cache (bool): + enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is + not large enough. Defaults to True. + use_stft_loss (bool): + enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True. + use_subband_stft (bool): + enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True. + use_mse_gan_loss (bool): + enable / disable using Mean Squeare Error GAN loss. Defaults to True. + use_hinge_gan_loss (bool): + enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models. + Defaults to False. + use_feat_match_loss (bool): + enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True. + use_l1_spec_loss (bool): + enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False. + stft_loss_params (dict): STFT loss parameters. Default to + `{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}` + stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total + model loss. Defaults to 0.5. + subband_stft_loss_weight (float): + Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + mse_G_loss_weight (float): + MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5. 
+ hinge_G_loss_weight (float): + Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + feat_match_loss_weight (float): + Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108. + l1_spec_loss_weight (float): + L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + """ + + model: str = "fullband_melgan" + + # Model specific params + discriminator_model: str = "melgan_multiscale_discriminator" + discriminator_model_params: dict = field( + default_factory=lambda: {"base_channels": 16, "max_channels": 512, "downsample_factors": [4, 4, 4]} + ) + generator_model: str = "melgan_generator" + generator_model_params: dict = field( + default_factory=lambda: {"upsample_factors": [8, 8, 2, 2], "num_res_blocks": 4} + ) + + # Training - overrides + batch_size: int = 16 + seq_len: int = 8192 + pad_short: int = 2000 + use_noise_augment: bool = True + use_cache: bool = True + + # LOSS PARAMETERS - overrides + use_stft_loss: bool = True + use_subband_stft_loss: bool = False + use_mse_gan_loss: bool = True + use_hinge_gan_loss: bool = False + use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN) + use_l1_spec_loss: bool = False + + stft_loss_params: dict = field( + default_factory=lambda: { + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240], + } + ) + + # loss weights - overrides + stft_loss_weight: float = 0.5 + subband_stft_loss_weight: float = 0 + mse_G_loss_weight: float = 2.5 + hinge_G_loss_weight: float = 0 + feat_match_loss_weight: float = 108 + l1_spec_loss_weight: float = 0.0 diff --git a/viXTTS/TTS/vocoder/configs/hifigan_config.py b/viXTTS/TTS/vocoder/configs/hifigan_config.py new file mode 100644 index 0000000000000000000000000000000000000000..9a102f0c89588b1a7fe270225e4b0fefa2e4bc71 --- /dev/null +++ b/viXTTS/TTS/vocoder/configs/hifigan_config.py @@ -0,0 +1,136 @@ +from dataclasses import dataclass, field + +from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig + + +@dataclass +class HifiganConfig(BaseGANVocoderConfig): + """Defines parameters for FullBand MelGAN vocoder. + + Example: + + >>> from TTS.vocoder.configs import HifiganConfig + >>> config = HifiganConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `hifigan`. + discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to + 'hifigan_discriminator`. + generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is + considered as a generator too. Defaults to `hifigan_generator`. + generator_model_params (dict): Parameters of the generator model. Defaults to + ` + { + "upsample_factors": [8, 8, 2, 2], + "upsample_kernel_sizes": [16, 16, 4, 4], + "upsample_initial_channel": 512, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "resblock_type": "1", + } + ` + batch_size (int): + Batch size used at training. Larger values use more memory. Defaults to 16. + seq_len (int): + Audio segment length used at training. Larger values use more memory. Defaults to 8192. + pad_short (int): + Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0. + use_noise_augment (bool): + enable / disable random noise added to the input waveform. 
The noise is added after computing the + features. Defaults to True. + use_cache (bool): + enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is + not large enough. Defaults to True. + use_stft_loss (bool): + enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True. + use_subband_stft (bool): + enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True. + use_mse_gan_loss (bool): + enable / disable using Mean Squeare Error GAN loss. Defaults to True. + use_hinge_gan_loss (bool): + enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models. + Defaults to False. + use_feat_match_loss (bool): + enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True. + use_l1_spec_loss (bool): + enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False. + stft_loss_params (dict): + STFT loss parameters. Default to + `{ + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240] + }` + l1_spec_loss_params (dict): + L1 spectrogram loss parameters. Default to + `{ + "use_mel": True, + "sample_rate": 22050, + "n_fft": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mels": 80, + "mel_fmin": 0.0, + "mel_fmax": None, + }` + stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total + model loss. Defaults to 0.5. + subband_stft_loss_weight (float): + Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + mse_G_loss_weight (float): + MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5. + hinge_G_loss_weight (float): + Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + feat_match_loss_weight (float): + Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108. + l1_spec_loss_weight (float): + L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. 
+ """ + + model: str = "hifigan" + # model specific params + discriminator_model: str = "hifigan_discriminator" + generator_model: str = "hifigan_generator" + generator_model_params: dict = field( + default_factory=lambda: { + "upsample_factors": [8, 8, 2, 2], + "upsample_kernel_sizes": [16, 16, 4, 4], + "upsample_initial_channel": 512, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "resblock_type": "1", + } + ) + + # LOSS PARAMETERS - overrides + use_stft_loss: bool = False + use_subband_stft_loss: bool = False + use_mse_gan_loss: bool = True + use_hinge_gan_loss: bool = False + use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN) + use_l1_spec_loss: bool = True + + # loss weights - overrides + stft_loss_weight: float = 0 + subband_stft_loss_weight: float = 0 + mse_G_loss_weight: float = 1 + hinge_G_loss_weight: float = 0 + feat_match_loss_weight: float = 108 + l1_spec_loss_weight: float = 45 + l1_spec_loss_params: dict = field( + default_factory=lambda: { + "use_mel": True, + "sample_rate": 22050, + "n_fft": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mels": 80, + "mel_fmin": 0.0, + "mel_fmax": None, + } + ) + + # optimizer parameters + lr: float = 1e-4 + wd: float = 1e-6 diff --git a/viXTTS/TTS/vocoder/configs/melgan_config.py b/viXTTS/TTS/vocoder/configs/melgan_config.py new file mode 100644 index 0000000000000000000000000000000000000000..dc35b6f8b70891d4904baefad802d9c62fe67925 --- /dev/null +++ b/viXTTS/TTS/vocoder/configs/melgan_config.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass, field + +from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig + + +@dataclass +class MelganConfig(BaseGANVocoderConfig): + """Defines parameters for MelGAN vocoder. + + Example: + + >>> from TTS.vocoder.configs import MelganConfig + >>> config = MelganConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `melgan`. + discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to + 'melgan_multiscale_discriminator`. + discriminator_model_params (dict): The discriminator model parameters. Defaults to + '{"base_channels": 16, "max_channels": 1024, "downsample_factors": [4, 4, 4, 4]}` + generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is + considered as a generator too. Defaults to `melgan_generator`. + batch_size (int): + Batch size used at training. Larger values use more memory. Defaults to 16. + seq_len (int): + Audio segment length used at training. Larger values use more memory. Defaults to 8192. + pad_short (int): + Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0. + use_noise_augment (bool): + enable / disable random noise added to the input waveform. The noise is added after computing the + features. Defaults to True. + use_cache (bool): + enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is + not large enough. Defaults to True. + use_stft_loss (bool): + enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True. + use_subband_stft (bool): + enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True. + use_mse_gan_loss (bool): + enable / disable using Mean Squeare Error GAN loss. Defaults to True. 
+ use_hinge_gan_loss (bool): + enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models. + Defaults to False. + use_feat_match_loss (bool): + enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True. + use_l1_spec_loss (bool): + enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False. + stft_loss_params (dict): STFT loss parameters. Default to + `{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}` + stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total + model loss. Defaults to 0.5. + subband_stft_loss_weight (float): + Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + mse_G_loss_weight (float): + MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5. + hinge_G_loss_weight (float): + Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + feat_match_loss_weight (float): + Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108. + l1_spec_loss_weight (float): + L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + """ + + model: str = "melgan" + + # Model specific params + discriminator_model: str = "melgan_multiscale_discriminator" + discriminator_model_params: dict = field( + default_factory=lambda: {"base_channels": 16, "max_channels": 1024, "downsample_factors": [4, 4, 4, 4]} + ) + generator_model: str = "melgan_generator" + generator_model_params: dict = field( + default_factory=lambda: {"upsample_factors": [8, 8, 2, 2], "num_res_blocks": 3} + ) + + # Training - overrides + batch_size: int = 16 + seq_len: int = 8192 + pad_short: int = 2000 + use_noise_augment: bool = True + use_cache: bool = True + + # LOSS PARAMETERS - overrides + use_stft_loss: bool = True + use_subband_stft_loss: bool = False + use_mse_gan_loss: bool = True + use_hinge_gan_loss: bool = False + use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN) + use_l1_spec_loss: bool = False + + stft_loss_params: dict = field( + default_factory=lambda: { + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240], + } + ) + + # loss weights - overrides + stft_loss_weight: float = 0.5 + subband_stft_loss_weight: float = 0 + mse_G_loss_weight: float = 2.5 + hinge_G_loss_weight: float = 0 + feat_match_loss_weight: float = 108 + l1_spec_loss_weight: float = 0 diff --git a/viXTTS/TTS/vocoder/configs/multiband_melgan_config.py b/viXTTS/TTS/vocoder/configs/multiband_melgan_config.py new file mode 100644 index 0000000000000000000000000000000000000000..763113537f36a8615b2b77369bf5bde01527fe53 --- /dev/null +++ b/viXTTS/TTS/vocoder/configs/multiband_melgan_config.py @@ -0,0 +1,144 @@ +from dataclasses import dataclass, field + +from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig + + +@dataclass +class MultibandMelganConfig(BaseGANVocoderConfig): + """Defines parameters for MultiBandMelGAN vocoder. + + Example: + + >>> from TTS.vocoder.configs import MultibandMelganConfig + >>> config = MultibandMelganConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `multiband_melgan`. 
+ discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to + 'melgan_multiscale_discriminator`. + discriminator_model_params (dict): The discriminator model parameters. Defaults to + '{ + "base_channels": 16, + "max_channels": 512, + "downsample_factors": [4, 4, 4] + }` + generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is + considered as a generator too. Defaults to `melgan_generator`. + generator_model_param (dict): + The generator model parameters. Defaults to `{"upsample_factors": [8, 4, 2], "num_res_blocks": 4}`. + use_pqmf (bool): + enable / disable PQMF modulation for multi-band training. Defaults to True. + lr_gen (float): + Initial learning rate for the generator model. Defaults to 0.0001. + lr_disc (float): + Initial learning rate for the discriminator model. Defaults to 0.0001. + optimizer (torch.optim.Optimizer): + Optimizer used for the training. Defaults to `AdamW`. + optimizer_params (dict): + Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}` + lr_scheduler_gen (torch.optim.Scheduler): + Learning rate scheduler for the generator. Defaults to `MultiStepLR`. + lr_scheduler_gen_params (dict): + Parameters for the generator learning rate scheduler. Defaults to + `{"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}`. + lr_scheduler_disc (torch.optim.Scheduler): + Learning rate scheduler for the discriminator. Defaults to `MultiStepLR`. + lr_scheduler_dict_params (dict): + Parameters for the discriminator learning rate scheduler. Defaults to + `{"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}`. + batch_size (int): + Batch size used at training. Larger values use more memory. Defaults to 16. + seq_len (int): + Audio segment length used at training. Larger values use more memory. Defaults to 8192. + pad_short (int): + Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0. + use_noise_augment (bool): + enable / disable random noise added to the input waveform. The noise is added after computing the + features. Defaults to True. + use_cache (bool): + enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is + not large enough. Defaults to True. + steps_to_start_discriminator (int): + Number of steps required to start training the discriminator. Defaults to 0. + use_stft_loss (bool):` + enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True. + use_subband_stft (bool): + enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True. + use_mse_gan_loss (bool): + enable / disable using Mean Squeare Error GAN loss. Defaults to True. + use_hinge_gan_loss (bool): + enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models. + Defaults to False. + use_feat_match_loss (bool): + enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True. + use_l1_spec_loss (bool): + enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False. + stft_loss_params (dict): STFT loss parameters. Default to + `{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}` + stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total + model loss. Defaults to 0.5. 
+ subband_stft_loss_weight (float): + Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + mse_G_loss_weight (float): + MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5. + hinge_G_loss_weight (float): + Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + feat_match_loss_weight (float): + Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108. + l1_spec_loss_weight (float): + L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + """ + + model: str = "multiband_melgan" + + # Model specific params + discriminator_model: str = "melgan_multiscale_discriminator" + discriminator_model_params: dict = field( + default_factory=lambda: {"base_channels": 16, "max_channels": 512, "downsample_factors": [4, 4, 4]} + ) + generator_model: str = "multiband_melgan_generator" + generator_model_params: dict = field(default_factory=lambda: {"upsample_factors": [8, 4, 2], "num_res_blocks": 4}) + use_pqmf: bool = True + + # optimizer - overrides + lr_gen: float = 0.0001 # Initial learning rate. + lr_disc: float = 0.0001 # Initial learning rate. + optimizer: str = "AdamW" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "weight_decay": 0.0}) + lr_scheduler_gen: str = "MultiStepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_gen_params: dict = field( + default_factory=lambda: {"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]} + ) + lr_scheduler_disc: str = "MultiStepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_disc_params: dict = field( + default_factory=lambda: {"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]} + ) + + # Training - overrides + batch_size: int = 64 + seq_len: int = 16384 + pad_short: int = 2000 + use_noise_augment: bool = False + use_cache: bool = True + steps_to_start_discriminator: bool = 200000 + + # LOSS PARAMETERS - overrides + use_stft_loss: bool = True + use_subband_stft_loss: bool = True + use_mse_gan_loss: bool = True + use_hinge_gan_loss: bool = False + use_feat_match_loss: bool = False # requires MelGAN Discriminators (MelGAN and HifiGAN) + use_l1_spec_loss: bool = False + + subband_stft_loss_params: dict = field( + default_factory=lambda: {"n_ffts": [384, 683, 171], "hop_lengths": [30, 60, 10], "win_lengths": [150, 300, 60]} + ) + + # loss weights - overrides + stft_loss_weight: float = 0.5 + subband_stft_loss_weight: float = 0 + mse_G_loss_weight: float = 2.5 + hinge_G_loss_weight: float = 0 + feat_match_loss_weight: float = 108 + l1_spec_loss_weight: float = 0 diff --git a/viXTTS/TTS/vocoder/configs/parallel_wavegan_config.py b/viXTTS/TTS/vocoder/configs/parallel_wavegan_config.py new file mode 100644 index 0000000000000000000000000000000000000000..6059d7f04f7c9fadd6df1d424e8e164e54a7310e --- /dev/null +++ b/viXTTS/TTS/vocoder/configs/parallel_wavegan_config.py @@ -0,0 +1,134 @@ +from dataclasses import dataclass, field + +from .shared_configs import BaseGANVocoderConfig + + +@dataclass +class ParallelWaveganConfig(BaseGANVocoderConfig): + """Defines parameters for ParallelWavegan vocoder. + + Args: + model (str): + Model name used for selecting the right configuration at initialization. Defaults to `gan`. 
+ discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to + 'parallel_wavegan_discriminator`. + discriminator_model_params (dict): The discriminator model kwargs. Defaults to + '{"num_layers": 10}` + generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is + considered as a generator too. Defaults to `parallel_wavegan_generator`. + generator_model_param (dict): + The generator model kwargs. Defaults to `{"upsample_factors": [4, 4, 4, 4], "stacks": 3, "num_res_blocks": 30}`. + batch_size (int): + Batch size used at training. Larger values use more memory. Defaults to 16. + seq_len (int): + Audio segment length used at training. Larger values use more memory. Defaults to 8192. + pad_short (int): + Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0. + use_noise_augment (bool): + enable / disable random noise added to the input waveform. The noise is added after computing the + features. Defaults to True. + use_cache (bool): + enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is + not large enough. Defaults to True. + steps_to_start_discriminator (int): + Number of steps required to start training the discriminator. Defaults to 0. + use_stft_loss (bool):` + enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True. + use_subband_stft (bool): + enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True. + use_mse_gan_loss (bool): + enable / disable using Mean Squeare Error GAN loss. Defaults to True. + use_hinge_gan_loss (bool): + enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models. + Defaults to False. + use_feat_match_loss (bool): + enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True. + use_l1_spec_loss (bool): + enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False. + stft_loss_params (dict): STFT loss parameters. Default to + `{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}` + stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total + model loss. Defaults to 0.5. + subband_stft_loss_weight (float): + Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + mse_G_loss_weight (float): + MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5. + hinge_G_loss_weight (float): + Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + feat_match_loss_weight (float): + Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 0. + l1_spec_loss_weight (float): + L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + lr_gen (float): + Generator model initial learning rate. Defaults to 0.0002. + lr_disc (float): + Discriminator model initial learning rate. Defaults to 0.0002. + optimizer (torch.optim.Optimizer): + Optimizer used for the training. Defaults to `AdamW`. + optimizer_params (dict): + Optimizer kwargs. 
Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}` + lr_scheduler_gen (torch.optim.Scheduler): + Learning rate scheduler for the generator. Defaults to `ExponentialLR`. + lr_scheduler_gen_params (dict): + Parameters for the generator learning rate scheduler. Defaults to `{"gamma": 0.5, "step_size": 200000, "last_epoch": -1}`. + lr_scheduler_disc (torch.optim.Scheduler): + Learning rate scheduler for the discriminator. Defaults to `ExponentialLR`. + lr_scheduler_dict_params (dict): + Parameters for the discriminator learning rate scheduler. Defaults to `{"gamma": 0.5, "step_size": 200000, "last_epoch": -1}`. + """ + + model: str = "parallel_wavegan" + + # Model specific params + discriminator_model: str = "parallel_wavegan_discriminator" + discriminator_model_params: dict = field(default_factory=lambda: {"num_layers": 10}) + generator_model: str = "parallel_wavegan_generator" + generator_model_params: dict = field( + default_factory=lambda: {"upsample_factors": [4, 4, 4, 4], "stacks": 3, "num_res_blocks": 30} + ) + + # Training - overrides + batch_size: int = 6 + seq_len: int = 25600 + pad_short: int = 2000 + use_noise_augment: bool = False + use_cache: bool = True + steps_to_start_discriminator: int = 200000 + target_loss: str = "loss_1" + + # LOSS PARAMETERS - overrides + use_stft_loss: bool = True + use_subband_stft_loss: bool = False + use_mse_gan_loss: bool = True + use_hinge_gan_loss: bool = False + use_feat_match_loss: bool = False # requires MelGAN Discriminators (MelGAN and HifiGAN) + use_l1_spec_loss: bool = False + + stft_loss_params: dict = field( + default_factory=lambda: { + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240], + } + ) + + # loss weights - overrides + stft_loss_weight: float = 0.5 + subband_stft_loss_weight: float = 0 + mse_G_loss_weight: float = 2.5 + hinge_G_loss_weight: float = 0 + feat_match_loss_weight: float = 0 + l1_spec_loss_weight: float = 0 + + # optimizer overrides + lr_gen: float = 0.0002 # Initial learning rate. + lr_disc: float = 0.0002 # Initial learning rate. + optimizer: str = "AdamW" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "weight_decay": 0.0}) + lr_scheduler_gen: str = "StepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1}) + lr_scheduler_disc: str = "StepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_disc_params: dict = field( + default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1} + ) + scheduler_after_epoch: bool = False diff --git a/viXTTS/TTS/vocoder/configs/shared_configs.py b/viXTTS/TTS/vocoder/configs/shared_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..a558cfcabbc2abc26be60065d3ac75cebd829f28 --- /dev/null +++ b/viXTTS/TTS/vocoder/configs/shared_configs.py @@ -0,0 +1,182 @@ +from dataclasses import dataclass, field + +from TTS.config import BaseAudioConfig, BaseTrainingConfig + + +@dataclass +class BaseVocoderConfig(BaseTrainingConfig): + """Shared parameters among all the vocoder models. + Args: + audio (BaseAudioConfig): + Audio processor config instance. Defaultsto `BaseAudioConfig()`. + use_noise_augment (bool): + Augment the input audio with random noise. Defaults to False/ + eval_split_size (int): + Number of instances used for evaluation. Defaults to 10. 
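These config classes are dataclasses, so any field shown in this diff can be overridden as a keyword argument at construction time. A small usage sketch for the class defined above; the data path is hypothetical:

from TTS.vocoder.configs.parallel_wavegan_config import ParallelWaveganConfig

config = ParallelWaveganConfig(
    data_path="/path/to/wavs",  # hypothetical dataset root
    batch_size=4,
    seq_len=25600,
)
# Unset fields keep the defaults above, e.g. the StepLR schedulers.
print(config.generator_model, config.lr_scheduler_gen_params["step_size"])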
+        data_path (str):
+            Root path of the training data. All the audio files found recursively from this root path are used for
+            training. Defaults to `""`.
+        feature_path (str):
+            Root path to the precomputed feature files. Defaults to None.
+        seq_len (int):
+            Length of the waveform segments used for training. Defaults to 1000.
+        pad_short (int):
+            Extra padding for the waveforms shorter than `seq_len`. Defaults to 0.
+        conv_pad (int):
+            Extra padding for the feature frames against convolution of the edge frames. Defaults to 0.
+        use_cache (bool):
+            enable / disable in memory caching of the computed features. If the RAM is not large enough, it may cause OOM.
+            Defaults to False.
+        epochs (int):
+            Number of training epochs. Defaults to 10000.
+        wd (float):
+            Weight decay value used by the optimizer. Defaults to 0.0.
+        optimizer (torch.optim.Optimizer):
+            Optimizer used for the training. Defaults to `AdamW`.
+        optimizer_params (dict):
+            Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`
+    """
+
+    audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
+    # dataloading
+    use_noise_augment: bool = False  # enable/disable random noise augmentation in spectrograms.
+    eval_split_size: int = 10  # number of samples used for evaluation.
+    # dataset
+    data_path: str = ""  # root data path. It finds all wav files recursively from there.
+    feature_path: str = None  # if you use precomputed features
+    seq_len: int = 1000  # signal length used in training.
+    pad_short: int = 0  # additional padding for short wavs
+    conv_pad: int = 0  # additional padding against convolutions applied to spectrograms
+    use_cache: bool = False  # use in memory cache to keep the computed features. This might cause OOM.
+    # OPTIMIZER
+    epochs: int = 10000  # total number of epochs to train.
+    wd: float = 0.0  # Weight decay weight.
+    optimizer: str = "AdamW"
+    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "weight_decay": 0.0})
+
+
+@dataclass
+class BaseGANVocoderConfig(BaseVocoderConfig):
+    """Base config class used among all the GAN based vocoders.
+    Args:
+        use_stft_loss (bool):
+            enable / disable the use of STFT loss. Defaults to True.
+        use_subband_stft_loss (bool):
+            enable / disable the use of Subband STFT loss. Defaults to True.
+        use_mse_gan_loss (bool):
+            enable / disable the use of Mean Squared Error based GAN loss. Defaults to True.
+        use_hinge_gan_loss (bool):
+            enable / disable the use of Hinge GAN loss. Defaults to True.
+        use_feat_match_loss (bool):
+            enable / disable feature matching loss. Defaults to True.
+        use_l1_spec_loss (bool):
+            enable / disable L1 spectrogram loss. Defaults to True.
+        stft_loss_weight (float):
+            Loss weight that multiplies the computed loss value. Defaults to 0.
+        subband_stft_loss_weight (float):
+            Loss weight that multiplies the computed loss value. Defaults to 0.
+        mse_G_loss_weight (float):
+            Loss weight that multiplies the computed loss value. Defaults to 1.
+        hinge_G_loss_weight (float):
+            Loss weight that multiplies the computed loss value. Defaults to 0.
+        feat_match_loss_weight (float):
+            Loss weight that multiplies the computed loss value. Defaults to 100.
+        l1_spec_loss_weight (float):
+            Loss weight that multiplies the computed loss value. Defaults to 45.
+        stft_loss_params (dict):
+            Parameters for the STFT loss. Defaults to `{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}`.
+        l1_spec_loss_params (dict):
+            Parameters for the L1 spectrogram loss.
Defaults to + `{ + "use_mel": True, + "sample_rate": 22050, + "n_fft": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mels": 80, + "mel_fmin": 0.0, + "mel_fmax": None, + }` + target_loss (str): + Target loss name that defines the quality of the model. Defaults to `G_avg_loss`. + grad_clip (list): + A list of gradient clipping theresholds for each optimizer. Any value less than 0 disables clipping. + Defaults to [5, 5]. + lr_gen (float): + Generator model initial learning rate. Defaults to 0.0002. + lr_disc (float): + Discriminator model initial learning rate. Defaults to 0.0002. + lr_scheduler_gen (torch.optim.Scheduler): + Learning rate scheduler for the generator. Defaults to `ExponentialLR`. + lr_scheduler_gen_params (dict): + Parameters for the generator learning rate scheduler. Defaults to `{"gamma": 0.999, "last_epoch": -1}`. + lr_scheduler_disc (torch.optim.Scheduler): + Learning rate scheduler for the discriminator. Defaults to `ExponentialLR`. + lr_scheduler_disc_params (dict): + Parameters for the discriminator learning rate scheduler. Defaults to `{"gamma": 0.999, "last_epoch": -1}`. + scheduler_after_epoch (bool): + Whether to update the learning rate schedulers after each epoch. Defaults to True. + use_pqmf (bool): + enable / disable PQMF for subband approximation at training. Defaults to False. + steps_to_start_discriminator (int): + Number of steps required to start training the discriminator. Defaults to 0. + diff_samples_for_G_and_D (bool): + enable / disable use of different training samples for the generator and the discriminator iterations. + Enabling it results in slower iterations but faster convergance in some cases. Defaults to False. + """ + + model: str = "gan" + + # LOSS PARAMETERS + use_stft_loss: bool = True + use_subband_stft_loss: bool = True + use_mse_gan_loss: bool = True + use_hinge_gan_loss: bool = True + use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN) + use_l1_spec_loss: bool = True + + # loss weights + stft_loss_weight: float = 0 + subband_stft_loss_weight: float = 0 + mse_G_loss_weight: float = 1 + hinge_G_loss_weight: float = 0 + feat_match_loss_weight: float = 100 + l1_spec_loss_weight: float = 45 + + stft_loss_params: dict = field( + default_factory=lambda: { + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240], + } + ) + + l1_spec_loss_params: dict = field( + default_factory=lambda: { + "use_mel": True, + "sample_rate": 22050, + "n_fft": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mels": 80, + "mel_fmin": 0.0, + "mel_fmax": None, + } + ) + + target_loss: str = "loss_0" # loss value to pick the best model to save after each epoch + + # optimizer + grad_clip: float = field(default_factory=lambda: [5, 5]) + lr_gen: float = 0.0002 # Initial learning rate. + lr_disc: float = 0.0002 # Initial learning rate. + lr_scheduler_gen: str = "ExponentialLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1}) + lr_scheduler_disc: str = "ExponentialLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1}) + scheduler_after_epoch: bool = True + + use_pqmf: bool = False # enable/disable using pqmf for multi-band training. 
(Multi-band MelGAN) + steps_to_start_discriminator = 0 # start training the discriminator after this number of steps. + diff_samples_for_G_and_D: bool = False # use different samples for G and D training steps. diff --git a/viXTTS/TTS/vocoder/configs/univnet_config.py b/viXTTS/TTS/vocoder/configs/univnet_config.py new file mode 100644 index 0000000000000000000000000000000000000000..67f324cfce5f701f0d7453beab81590bef6be114 --- /dev/null +++ b/viXTTS/TTS/vocoder/configs/univnet_config.py @@ -0,0 +1,161 @@ +from dataclasses import dataclass, field +from typing import Dict + +from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig + + +@dataclass +class UnivnetConfig(BaseGANVocoderConfig): + """Defines parameters for UnivNet vocoder. + + Example: + + >>> from TTS.vocoder.configs import UnivNetConfig + >>> config = UnivNetConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `UnivNet`. + discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to + 'UnivNet_discriminator`. + generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is + considered as a generator too. Defaults to `UnivNet_generator`. + generator_model_params (dict): Parameters of the generator model. Defaults to + ` + { + "use_mel": True, + "sample_rate": 22050, + "n_fft": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mels": 80, + "mel_fmin": 0.0, + "mel_fmax": None, + } + ` + batch_size (int): + Batch size used at training. Larger values use more memory. Defaults to 32. + seq_len (int): + Audio segment length used at training. Larger values use more memory. Defaults to 8192. + pad_short (int): + Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0. + use_noise_augment (bool): + enable / disable random noise added to the input waveform. The noise is added after computing the + features. Defaults to True. + use_cache (bool): + enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is + not large enough. Defaults to True. + use_stft_loss (bool): + enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True. + use_subband_stft (bool): + enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True. + use_mse_gan_loss (bool): + enable / disable using Mean Squeare Error GAN loss. Defaults to True. + use_hinge_gan_loss (bool): + enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models. + Defaults to False. + use_feat_match_loss (bool): + enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True. + use_l1_spec_loss (bool): + enable / disable using L1 spectrogram loss originally used by univnet model. Defaults to False. + stft_loss_params (dict): + STFT loss parameters. Default to + `{ + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240] + }` + l1_spec_loss_params (dict): + L1 spectrogram loss parameters. Default to + `{ + "use_mel": True, + "sample_rate": 22050, + "n_fft": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mels": 80, + "mel_fmin": 0.0, + "mel_fmax": None, + }` + stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total + model loss. Defaults to 0.5. 
+ subband_stft_loss_weight (float): + Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + mse_G_loss_weight (float): + MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5. + hinge_G_loss_weight (float): + Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + feat_match_loss_weight (float): + Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108. + l1_spec_loss_weight (float): + L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + """ + + model: str = "univnet" + batch_size: int = 32 + # model specific params + discriminator_model: str = "univnet_discriminator" + generator_model: str = "univnet_generator" + generator_model_params: Dict = field( + default_factory=lambda: { + "in_channels": 64, + "out_channels": 1, + "hidden_channels": 32, + "cond_channels": 80, + "upsample_factors": [8, 8, 4], + "lvc_layers_each_block": 4, + "lvc_kernel_size": 3, + "kpnet_hidden_channels": 64, + "kpnet_conv_size": 3, + "dropout": 0.0, + } + ) + + # LOSS PARAMETERS - overrides + use_stft_loss: bool = True + use_subband_stft_loss: bool = False + use_mse_gan_loss: bool = True + use_hinge_gan_loss: bool = False + use_feat_match_loss: bool = False # requires MelGAN Discriminators (MelGAN and univnet) + use_l1_spec_loss: bool = False + + # loss weights - overrides + stft_loss_weight: float = 2.5 + stft_loss_params: Dict = field( + default_factory=lambda: { + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240], + } + ) + subband_stft_loss_weight: float = 0 + mse_G_loss_weight: float = 1 + hinge_G_loss_weight: float = 0 + feat_match_loss_weight: float = 0 + l1_spec_loss_weight: float = 0 + l1_spec_loss_params: Dict = field( + default_factory=lambda: { + "use_mel": True, + "sample_rate": 22050, + "n_fft": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mels": 80, + "mel_fmin": 0.0, + "mel_fmax": None, + } + ) + + # optimizer parameters + lr_gen: float = 1e-4 # Initial learning rate. + lr_disc: float = 1e-4 # Initial learning rate. + lr_scheduler_gen: str = None # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + # lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1}) + lr_scheduler_disc: str = None # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + # lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1}) + optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.5, 0.9], "weight_decay": 0.0}) + steps_to_start_discriminator: int = 200000 + + def __post_init__(self): + super().__post_init__() + self.generator_model_params["cond_channels"] = self.audio.num_mels diff --git a/viXTTS/TTS/vocoder/configs/wavegrad_config.py b/viXTTS/TTS/vocoder/configs/wavegrad_config.py new file mode 100644 index 0000000000000000000000000000000000000000..c39813ae68c3d8c77614c9a5188ac5f2a59d991d --- /dev/null +++ b/viXTTS/TTS/vocoder/configs/wavegrad_config.py @@ -0,0 +1,90 @@ +from dataclasses import dataclass, field + +from TTS.vocoder.configs.shared_configs import BaseVocoderConfig +from TTS.vocoder.models.wavegrad import WavegradArgs + + +@dataclass +class WavegradConfig(BaseVocoderConfig): + """Defines parameters for WaveGrad vocoder. 
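The `__post_init__` hook above keeps the UnivNet generator's conditioning width in sync with the audio settings. A short sketch of the effect, using only names defined in this diff:

from TTS.vocoder.configs.univnet_config import UnivnetConfig

config = UnivnetConfig()
# __post_init__ copies audio.num_mels into the generator kwargs, so the two
# values below always match (80 with the default BaseAudioConfig).
assert config.generator_model_params["cond_channels"] == config.audio.num_mels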
+ Example: + + >>> from TTS.vocoder.configs import WavegradConfig + >>> config = WavegradConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `wavegrad`. + generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is + considered as a generator too. Defaults to `wavegrad`. + model_params (WavegradArgs): Model parameters. Check `WavegradArgs` for default values. + target_loss (str): + Target loss name that defines the quality of the model. Defaults to `avg_wavegrad_loss`. + epochs (int): + Number of epochs to traing the model. Defaults to 10000. + batch_size (int): + Batch size used at training. Larger values use more memory. Defaults to 96. + seq_len (int): + Audio segment length used at training. Larger values use more memory. Defaults to 6144. + use_cache (bool): + enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is + not large enough. Defaults to True. + mixed_precision (bool): + enable / disable mixed precision training. Default is True. + eval_split_size (int): + Number of samples used for evalutaion. Defaults to 50. + train_noise_schedule (dict): + Training noise schedule. Defaults to + `{"min_val": 1e-6, "max_val": 1e-2, "num_steps": 1000}` + test_noise_schedule (dict): + Inference noise schedule. For a better performance, you may need to use `bin/tune_wavegrad.py` to find a + better schedule. Defaults to + ` + { + "min_val": 1e-6, + "max_val": 1e-2, + "num_steps": 50, + } + ` + grad_clip (float): + Gradient clipping threshold. If <= 0.0, no clipping is applied. Defaults to 1.0 + lr (float): + Initila leraning rate. Defaults to 1e-4. + lr_scheduler (str): + One of the learning rate schedulers from `torch.optim.scheduler.*`. Defaults to `MultiStepLR`. + lr_scheduler_params (dict): + kwargs for the scheduler. Defaults to `{"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}` + """ + + model: str = "wavegrad" + # Model specific params + generator_model: str = "wavegrad" + model_params: WavegradArgs = field(default_factory=WavegradArgs) + target_loss: str = "loss" # loss value to pick the best model to save after each epoch + + # Training - overrides + epochs: int = 10000 + batch_size: int = 96 + seq_len: int = 6144 + use_cache: bool = True + mixed_precision: bool = True + eval_split_size: int = 50 + + # NOISE SCHEDULE PARAMS + train_noise_schedule: dict = field(default_factory=lambda: {"min_val": 1e-6, "max_val": 1e-2, "num_steps": 1000}) + + test_noise_schedule: dict = field( + default_factory=lambda: { # inference noise schedule. Try TTS/bin/tune_wavegrad.py to find the optimal values. + "min_val": 1e-6, + "max_val": 1e-2, + "num_steps": 50, + } + ) + + # optimizer overrides + grad_clip: float = 1.0 + lr: float = 1e-4 # Initial learning rate. 
+ lr_scheduler: str = "MultiStepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_params: dict = field( + default_factory=lambda: {"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]} + ) diff --git a/viXTTS/TTS/vocoder/configs/wavernn_config.py b/viXTTS/TTS/vocoder/configs/wavernn_config.py new file mode 100644 index 0000000000000000000000000000000000000000..f39400e5e50b56d4ff79c8c148fd518b3ec3b390 --- /dev/null +++ b/viXTTS/TTS/vocoder/configs/wavernn_config.py @@ -0,0 +1,102 @@ +from dataclasses import dataclass, field + +from TTS.vocoder.configs.shared_configs import BaseVocoderConfig +from TTS.vocoder.models.wavernn import WavernnArgs + + +@dataclass +class WavernnConfig(BaseVocoderConfig): + """Defines parameters for Wavernn vocoder. + Example: + + >>> from TTS.vocoder.configs import WavernnConfig + >>> config = WavernnConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `wavernn`. + mode (str): + Output mode of the WaveRNN vocoder. `mold` for Mixture of Logistic Distribution, `gauss` for a single + Gaussian Distribution and `bits` for quantized bits as the model's output. + mulaw (bool): + enable / disable the use of Mulaw quantization for training. Only applicable if `mode == 'bits'`. Defaults + to `True`. + generator_model (str): + One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is + considered as a generator too. Defaults to `WaveRNN`. + wavernn_model_params (dict): + kwargs for the WaveRNN model. Defaults to + `{ + "rnn_dims": 512, + "fc_dims": 512, + "compute_dims": 128, + "res_out_dims": 128, + "num_res_blocks": 10, + "use_aux_net": True, + "use_upsample_net": True, + "upsample_factors": [4, 8, 8] + }` + batched (bool): + enable / disable the batched inference. It speeds up the inference by splitting the input into segments and + processing the segments in a batch. Then it merges the outputs with a certain overlap and smoothing. If + you set it False, without CUDA, it is too slow to be practical. Defaults to True. + target_samples (int): + Size of the segments in batched mode. Defaults to 11000. + overlap_sampels (int): + Size of the overlap between consecutive segments. Defaults to 550. + batch_size (int): + Batch size used at training. Larger values use more memory. Defaults to 256. + seq_len (int): + Audio segment length used at training. Larger values use more memory. Defaults to 1280. + + use_noise_augment (bool): + enable / disable random noise added to the input waveform. The noise is added after computing the + features. Defaults to True. + use_cache (bool): + enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is + not large enough. Defaults to True. + mixed_precision (bool): + enable / disable mixed precision training. Default is True. + eval_split_size (int): + Number of samples used for evalutaion. Defaults to 50. + num_epochs_before_test (int): + Number of epochs waited to run the next evalution. Since inference takes some time, it is better to + wait some number of epochs not ot waste training time. Defaults to 10. + grad_clip (float): + Gradient clipping threshold. If <= 0.0, no clipping is applied. Defaults to 4.0 + lr (float): + Initila leraning rate. Defaults to 1e-4. + lr_scheduler (str): + One of the learning rate schedulers from `torch.optim.scheduler.*`. Defaults to `MultiStepLR`. + lr_scheduler_params (dict): + kwargs for the scheduler. 
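The noise-schedule dicts in `WavegradConfig` above only describe a range and a step count. A hedged sketch of how such a description is commonly expanded into an explicit linear beta schedule; the expansion itself is an assumption here, and `TTS/bin/tune_wavegrad.py` (referenced in the field comment above) is the tool for searching better schedules:

import numpy as np

from TTS.vocoder.configs.wavegrad_config import WavegradConfig

config = WavegradConfig()
schedule = config.test_noise_schedule
# Expand (min_val, max_val, num_steps) into an explicit beta schedule.
betas = np.linspace(schedule["min_val"], schedule["max_val"], schedule["num_steps"])
print(betas.shape)  # (50,) with the defaults above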
Defaults to `{"gamma": 0.5, "milestones": [200000, 400000, 600000]}` + """ + + model: str = "wavernn" + + # Model specific params + model_args: WavernnArgs = field(default_factory=WavernnArgs) + target_loss: str = "loss" + + # Inference + batched: bool = True + target_samples: int = 11000 + overlap_samples: int = 550 + + # Training - overrides + epochs: int = 10000 + batch_size: int = 256 + seq_len: int = 1280 + use_noise_augment: bool = False + use_cache: bool = True + mixed_precision: bool = True + eval_split_size: int = 50 + num_epochs_before_test: int = ( + 10 # number of epochs to wait until the next test run (synthesizing a full audio clip). + ) + + # optimizer overrides + grad_clip: float = 4.0 + lr: float = 1e-4 # Initial learning rate. + lr_scheduler: str = "MultiStepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_params: dict = field(default_factory=lambda: {"gamma": 0.5, "milestones": [200000, 400000, 600000]}) diff --git a/viXTTS/TTS/vocoder/datasets/__init__.py b/viXTTS/TTS/vocoder/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..871eb0d20276ffc691fd6da796bf65df6c23ea0d --- /dev/null +++ b/viXTTS/TTS/vocoder/datasets/__init__.py @@ -0,0 +1,58 @@ +from typing import List + +from coqpit import Coqpit +from torch.utils.data import Dataset + +from TTS.utils.audio import AudioProcessor +from TTS.vocoder.datasets.gan_dataset import GANDataset +from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data +from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset +from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset + + +def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool) -> Dataset: + if config.model.lower() in "gan": + dataset = GANDataset( + ap=ap, + items=data_items, + seq_len=config.seq_len, + hop_len=ap.hop_length, + pad_short=config.pad_short, + conv_pad=config.conv_pad, + return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False, + is_training=not is_eval, + return_segments=not is_eval, + use_noise_augment=config.use_noise_augment, + use_cache=config.use_cache, + verbose=verbose, + ) + dataset.shuffle_mapping() + elif config.model.lower() == "wavegrad": + dataset = WaveGradDataset( + ap=ap, + items=data_items, + seq_len=config.seq_len, + hop_len=ap.hop_length, + pad_short=config.pad_short, + conv_pad=config.conv_pad, + is_training=not is_eval, + return_segments=True, + use_noise_augment=False, + use_cache=config.use_cache, + verbose=verbose, + ) + elif config.model.lower() == "wavernn": + dataset = WaveRNNDataset( + ap=ap, + items=data_items, + seq_len=config.seq_len, + hop_len=ap.hop_length, + pad=config.model_params.pad, + mode=config.model_params.mode, + mulaw=config.model_params.mulaw, + is_training=not is_eval, + verbose=verbose, + ) + else: + raise ValueError(f" [!] 
Dataset for model {config.model.lower()} cannot be found.") + return dataset diff --git a/viXTTS/TTS/vocoder/datasets/gan_dataset.py b/viXTTS/TTS/vocoder/datasets/gan_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..50c38c4deb8fd861f7cef8144df3098c3558aeb4 --- /dev/null +++ b/viXTTS/TTS/vocoder/datasets/gan_dataset.py @@ -0,0 +1,152 @@ +import glob +import os +import random +from multiprocessing import Manager + +import numpy as np +import torch +from torch.utils.data import Dataset + + +class GANDataset(Dataset): + """ + GAN Dataset searchs for all the wav files under root path + and converts them to acoustic features on the fly and returns + random segments of (audio, feature) couples. + """ + + def __init__( + self, + ap, + items, + seq_len, + hop_len, + pad_short, + conv_pad=2, + return_pairs=False, + is_training=True, + return_segments=True, + use_noise_augment=False, + use_cache=False, + verbose=False, + ): + super().__init__() + self.ap = ap + self.item_list = items + self.compute_feat = not isinstance(items[0], (tuple, list)) + self.seq_len = seq_len + self.hop_len = hop_len + self.pad_short = pad_short + self.conv_pad = conv_pad + self.return_pairs = return_pairs + self.is_training = is_training + self.return_segments = return_segments + self.use_cache = use_cache + self.use_noise_augment = use_noise_augment + self.verbose = verbose + + assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len." + self.feat_frame_len = seq_len // hop_len + (2 * conv_pad) + + # map G and D instances + self.G_to_D_mappings = list(range(len(self.item_list))) + self.shuffle_mapping() + + # cache acoustic features + if use_cache: + self.create_feature_cache() + + def create_feature_cache(self): + self.manager = Manager() + self.cache = self.manager.list() + self.cache += [None for _ in range(len(self.item_list))] + + @staticmethod + def find_wav_files(path): + return glob.glob(os.path.join(path, "**", "*.wav"), recursive=True) + + def __len__(self): + return len(self.item_list) + + def __getitem__(self, idx): + """Return different items for Generator and Discriminator and + cache acoustic features""" + + # set the seed differently for each worker + if torch.utils.data.get_worker_info(): + random.seed(torch.utils.data.get_worker_info().seed) + + if self.return_segments: + item1 = self.load_item(idx) + if self.return_pairs: + idx2 = self.G_to_D_mappings[idx] + item2 = self.load_item(idx2) + return item1, item2 + return item1 + item1 = self.load_item(idx) + return item1 + + def _pad_short_samples(self, audio, mel=None): + """Pad samples shorter than the output sequence length""" + if len(audio) < self.seq_len: + audio = np.pad(audio, (0, self.seq_len - len(audio)), mode="constant", constant_values=0.0) + + if mel is not None and mel.shape[1] < self.feat_frame_len: + pad_value = self.ap.melspectrogram(np.zeros([self.ap.win_length]))[:, 0] + mel = np.pad( + mel, + ([0, 0], [0, self.feat_frame_len - mel.shape[1]]), + mode="constant", + constant_values=pad_value.mean(), + ) + return audio, mel + + def shuffle_mapping(self): + random.shuffle(self.G_to_D_mappings) + + def load_item(self, idx): + """load (audio, feat) couple""" + if self.compute_feat: + # compute features from wav + wavpath = self.item_list[idx] + # print(wavpath) + + if self.use_cache and self.cache[idx] is not None: + audio, mel = self.cache[idx] + else: + audio = self.ap.load_wav(wavpath) + mel = self.ap.melspectrogram(audio) + audio, mel = self._pad_short_samples(audio, mel) + else: + # load 
precomputed features + wavpath, feat_path = self.item_list[idx] + + if self.use_cache and self.cache[idx] is not None: + audio, mel = self.cache[idx] + else: + audio = self.ap.load_wav(wavpath) + mel = np.load(feat_path) + audio, mel = self._pad_short_samples(audio, mel) + + # correct the audio length wrt padding applied in stft + audio = np.pad(audio, (0, self.hop_len), mode="edge") + audio = audio[: mel.shape[-1] * self.hop_len] + assert ( + mel.shape[-1] * self.hop_len == audio.shape[-1] + ), f" [!] {mel.shape[-1] * self.hop_len} vs {audio.shape[-1]}" + + audio = torch.from_numpy(audio).float().unsqueeze(0) + mel = torch.from_numpy(mel).float().squeeze(0) + + if self.return_segments: + max_mel_start = mel.shape[1] - self.feat_frame_len + mel_start = random.randint(0, max_mel_start) + mel_end = mel_start + self.feat_frame_len + mel = mel[:, mel_start:mel_end] + + audio_start = mel_start * self.hop_len + audio = audio[:, audio_start : audio_start + self.seq_len] + + if self.use_noise_augment and self.is_training and self.return_segments: + audio = audio + (1 / 32768) * torch.randn_like(audio) + return (mel, audio) diff --git a/viXTTS/TTS/vocoder/datasets/preprocess.py b/viXTTS/TTS/vocoder/datasets/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..503bb04b2fba637ec3b8ab449018558898f05024 --- /dev/null +++ b/viXTTS/TTS/vocoder/datasets/preprocess.py @@ -0,0 +1,75 @@ +import glob +import os +from pathlib import Path + +import numpy as np +from coqpit import Coqpit +from tqdm import tqdm + +from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize + + +def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor): + """Process wav and compute mel and quantized wave signal. + It is mainly used by WaveRNN dataloader. + + Args: + out_path (str): Parent folder path to save the files. + config (Coqpit): Model config. + ap (AudioProcessor): Audio processor. + """ + os.makedirs(os.path.join(out_path, "quant"), exist_ok=True) + os.makedirs(os.path.join(out_path, "mel"), exist_ok=True) + wav_files = find_wav_files(config.data_path) + for path in tqdm(wav_files): + wav_name = Path(path).stem + quant_path = os.path.join(out_path, "quant", wav_name + ".npy") + mel_path = os.path.join(out_path, "mel", wav_name + ".npy") + y = ap.load_wav(path) + mel = ap.melspectrogram(y) + np.save(mel_path, mel) + if isinstance(config.mode, int): + quant = ( + mulaw_encode(wav=y, mulaw_qc=config.mode) + if config.model_args.mulaw + else quantize(x=y, quantize_bits=config.mode) + ) + np.save(quant_path, quant) + + +def find_wav_files(data_path, file_ext="wav"): + wav_paths = glob.glob(os.path.join(data_path, "**", f"*.{file_ext}"), recursive=True) + return wav_paths + + +def find_feat_files(data_path): + feat_paths = glob.glob(os.path.join(data_path, "**", "*.npy"), recursive=True) + return feat_paths + + +def load_wav_data(data_path, eval_split_size, file_ext="wav"): + wav_paths = find_wav_files(data_path, file_ext=file_ext) + assert len(wav_paths) > 0, f" [!] {data_path} is empty." + np.random.seed(0) + np.random.shuffle(wav_paths) + return wav_paths[:eval_split_size], wav_paths[eval_split_size:] + + +def load_wav_feat_data(data_path, feat_path, eval_split_size): + wav_paths = find_wav_files(data_path) + feat_paths = find_feat_files(feat_path) + + wav_paths.sort(key=lambda x: Path(x).stem) + feat_paths.sort(key=lambda x: Path(x).stem) + + assert len(wav_paths) == len(feat_paths), f" [!] 
{len(wav_paths)} vs {feat_paths}" + for wav, feat in zip(wav_paths, feat_paths): + wav_name = Path(wav).stem + feat_name = Path(feat).stem + assert wav_name == feat_name + + items = list(zip(wav_paths, feat_paths)) + np.random.seed(0) + np.random.shuffle(items) + return items[:eval_split_size], items[eval_split_size:] diff --git a/viXTTS/TTS/vocoder/datasets/wavegrad_dataset.py b/viXTTS/TTS/vocoder/datasets/wavegrad_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..305fe430e3da880c03aee625525ce825c8ef87a3 --- /dev/null +++ b/viXTTS/TTS/vocoder/datasets/wavegrad_dataset.py @@ -0,0 +1,151 @@ +import glob +import os +import random +from multiprocessing import Manager +from typing import List, Tuple + +import numpy as np +import torch +from torch.utils.data import Dataset + + +class WaveGradDataset(Dataset): + """ + WaveGrad Dataset searchs for all the wav files under root path + and converts them to acoustic features on the fly and returns + random segments of (audio, feature) couples. + """ + + def __init__( + self, + ap, + items, + seq_len, + hop_len, + pad_short, + conv_pad=2, + is_training=True, + return_segments=True, + use_noise_augment=False, + use_cache=False, + verbose=False, + ): + super().__init__() + self.ap = ap + self.item_list = items + self.seq_len = seq_len if return_segments else None + self.hop_len = hop_len + self.pad_short = pad_short + self.conv_pad = conv_pad + self.is_training = is_training + self.return_segments = return_segments + self.use_cache = use_cache + self.use_noise_augment = use_noise_augment + self.verbose = verbose + + if return_segments: + assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len." + self.feat_frame_len = seq_len // hop_len + (2 * conv_pad) + + # cache acoustic features + if use_cache: + self.create_feature_cache() + + def create_feature_cache(self): + self.manager = Manager() + self.cache = self.manager.list() + self.cache += [None for _ in range(len(self.item_list))] + + @staticmethod + def find_wav_files(path): + return glob.glob(os.path.join(path, "**", "*.wav"), recursive=True) + + def __len__(self): + return len(self.item_list) + + def __getitem__(self, idx): + item = self.load_item(idx) + return item + + def load_test_samples(self, num_samples: int) -> List[Tuple]: + """Return test samples. + + Args: + num_samples (int): Number of samples to return. + + Returns: + List[Tuple]: melspectorgram and audio. 
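Both loaders above return the evaluation split first and the training split second, which is easy to get backwards. A quick usage sketch with a hypothetical data directory:

from TTS.vocoder.datasets.preprocess import load_wav_data

# The first eval_split_size shuffled files become the evaluation set.
eval_samples, train_samples = load_wav_data("/data/wavs", eval_split_size=10)
print(len(eval_samples), len(train_samples))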
+ + Shapes: + - melspectrogram (Tensor): :math:`[C, T]` + - audio (Tensor): :math:`[T_audio]` + """ + samples = [] + return_segments = self.return_segments + self.return_segments = False + for idx in range(num_samples): + mel, audio = self.load_item(idx) + samples.append([mel, audio]) + self.return_segments = return_segments + return samples + + def load_item(self, idx): + """load (audio, feat) couple""" + # compute features from wav + wavpath = self.item_list[idx] + + if self.use_cache and self.cache[idx] is not None: + audio = self.cache[idx] + else: + audio = self.ap.load_wav(wavpath) + + if self.return_segments: + # correct audio length wrt segment length + if audio.shape[-1] < self.seq_len + self.pad_short: + audio = np.pad( + audio, (0, self.seq_len + self.pad_short - len(audio)), mode="constant", constant_values=0.0 + ) + assert ( + audio.shape[-1] >= self.seq_len + self.pad_short + ), f"{audio.shape[-1]} vs {self.seq_len + self.pad_short}" + + # correct the audio length wrt hop length + p = (audio.shape[-1] // self.hop_len + 1) * self.hop_len - audio.shape[-1] + audio = np.pad(audio, (0, p), mode="constant", constant_values=0.0) + + if self.use_cache: + self.cache[idx] = audio + + if self.return_segments: + max_start = len(audio) - self.seq_len + start = random.randint(0, max_start) + end = start + self.seq_len + audio = audio[start:end] + + if self.use_noise_augment and self.is_training and self.return_segments: + audio = audio + (1 / 32768) * torch.randn_like(audio) + + mel = self.ap.melspectrogram(audio) + mel = mel[..., :-1] # ignore the padding + + audio = torch.from_numpy(audio).float() + mel = torch.from_numpy(mel).float().squeeze(0) + return (mel, audio) + + @staticmethod + def collate_full_clips(batch): + """This is used in tune_wavegrad.py. + It pads sequences to the max length.""" + max_mel_length = max([b[0].shape[1] for b in batch]) if len(batch) > 1 else batch[0][0].shape[1] + max_audio_length = max([b[1].shape[0] for b in batch]) if len(batch) > 1 else batch[0][1].shape[0] + + mels = torch.zeros([len(batch), batch[0][0].shape[0], max_mel_length]) + audios = torch.zeros([len(batch), max_audio_length]) + + for idx, b in enumerate(batch): + mel = b[0] + audio = b[1] + mels[idx, :, : mel.shape[1]] = mel + audios[idx, : audio.shape[0]] = audio + + return mels, audios diff --git a/viXTTS/TTS/vocoder/datasets/wavernn_dataset.py b/viXTTS/TTS/vocoder/datasets/wavernn_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a67c5b31a0902c2d1ae97a0357ba009869835711 --- /dev/null +++ b/viXTTS/TTS/vocoder/datasets/wavernn_dataset.py @@ -0,0 +1,118 @@ +import numpy as np +import torch +from torch.utils.data import Dataset + +from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize + + +class WaveRNNDataset(Dataset): + """ + WaveRNN Dataset searchs for all the wav files under root path + and converts them to acoustic features on the fly. 
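For the integer `bits` style modes described in the WaveRNN config, the target waveform is turned into discrete class indices with the same helpers this dataset uses. A minimal sketch with synthetic audio and a 10-bit mode:

import numpy as np

from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize

audio = np.random.uniform(-1.0, 1.0, 1024).astype(np.float32)
mu_targets = mulaw_encode(wav=audio, mulaw_qc=10)     # mu-law companded classes
linear_targets = quantize(x=audio, quantize_bits=10)  # plain linear quantization
print(mu_targets.min(), mu_targets.max())             # within [0, 2**10 - 1]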
+ """ + + def __init__( + self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, verbose=False, return_segments=True + ): + super().__init__() + self.ap = ap + self.compute_feat = not isinstance(items[0], (tuple, list)) + self.item_list = items + self.seq_len = seq_len + self.hop_len = hop_len + self.mel_len = seq_len // hop_len + self.pad = pad + self.mode = mode + self.mulaw = mulaw + self.is_training = is_training + self.verbose = verbose + self.return_segments = return_segments + + assert self.seq_len % self.hop_len == 0 + + def __len__(self): + return len(self.item_list) + + def __getitem__(self, index): + item = self.load_item(index) + return item + + def load_test_samples(self, num_samples): + samples = [] + return_segments = self.return_segments + self.return_segments = False + for idx in range(num_samples): + mel, audio, _ = self.load_item(idx) + samples.append([mel, audio]) + self.return_segments = return_segments + return samples + + def load_item(self, index): + """ + load (audio, feat) couple if feature_path is set + else compute it on the fly + """ + if self.compute_feat: + wavpath = self.item_list[index] + audio = self.ap.load_wav(wavpath) + if self.return_segments: + min_audio_len = 2 * self.seq_len + (2 * self.pad * self.hop_len) + else: + min_audio_len = audio.shape[0] + (2 * self.pad * self.hop_len) + if audio.shape[0] < min_audio_len: + print(" [!] Instance is too short! : {}".format(wavpath)) + audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len]) + mel = self.ap.melspectrogram(audio) + + if self.mode in ["gauss", "mold"]: + x_input = audio + elif isinstance(self.mode, int): + x_input = ( + mulaw_encode(wav=audio, mulaw_qc=self.mode) + if self.mulaw + else quantize(x=audio, quantize_bits=self.mode) + ) + else: + raise RuntimeError("Unknown dataset mode - ", self.mode) + + else: + wavpath, feat_path = self.item_list[index] + mel = np.load(feat_path.replace("/quant/", "/mel/")) + + if mel.shape[-1] < self.mel_len + 2 * self.pad: + print(" [!] Instance is too short! 
: {}".format(wavpath)) + self.item_list[index] = self.item_list[index + 1] + feat_path = self.item_list[index] + mel = np.load(feat_path.replace("/quant/", "/mel/")) + if self.mode in ["gauss", "mold"]: + x_input = self.ap.load_wav(wavpath) + elif isinstance(self.mode, int): + x_input = np.load(feat_path.replace("/mel/", "/quant/")) + else: + raise RuntimeError("Unknown dataset mode - ", self.mode) + + return mel, x_input, wavpath + + def collate(self, batch): + mel_win = self.seq_len // self.hop_len + 2 * self.pad + max_offsets = [x[0].shape[-1] - (mel_win + 2 * self.pad) for x in batch] + + mel_offsets = [np.random.randint(0, offset) for offset in max_offsets] + sig_offsets = [(offset + self.pad) * self.hop_len for offset in mel_offsets] + + mels = [x[0][:, mel_offsets[i] : mel_offsets[i] + mel_win] for i, x in enumerate(batch)] + + coarse = [x[1][sig_offsets[i] : sig_offsets[i] + self.seq_len + 1] for i, x in enumerate(batch)] + + mels = np.stack(mels).astype(np.float32) + if self.mode in ["gauss", "mold"]: + coarse = np.stack(coarse).astype(np.float32) + coarse = torch.FloatTensor(coarse) + x_input = coarse[:, : self.seq_len] + elif isinstance(self.mode, int): + coarse = np.stack(coarse).astype(np.int64) + coarse = torch.LongTensor(coarse) + x_input = 2 * coarse[:, : self.seq_len].float() / (2**self.mode - 1.0) - 1.0 + y_coarse = coarse[:, 1:] + mels = torch.FloatTensor(mels) + return x_input, mels, y_coarse diff --git a/viXTTS/TTS/vocoder/layers/__init__.py b/viXTTS/TTS/vocoder/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/vocoder/layers/hifigan.py b/viXTTS/TTS/vocoder/layers/hifigan.py new file mode 100644 index 0000000000000000000000000000000000000000..8dd75133bbc38b95ce6ea6a9a05079b9b205a539 --- /dev/null +++ b/viXTTS/TTS/vocoder/layers/hifigan.py @@ -0,0 +1,56 @@ +from torch import nn +from torch.nn.utils.parametrize import remove_parametrizations + + +# pylint: disable=dangerous-default-value +class ResStack(nn.Module): + def __init__(self, kernel, channel, padding, dilations=[1, 3, 5]): + super().__init__() + resstack = [] + for dilation in dilations: + resstack += [ + nn.LeakyReLU(0.2), + nn.ReflectionPad1d(dilation), + nn.utils.parametrizations.weight_norm( + nn.Conv1d(channel, channel, kernel_size=kernel, dilation=dilation) + ), + nn.LeakyReLU(0.2), + nn.ReflectionPad1d(padding), + nn.utils.parametrizations.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)), + ] + self.resstack = nn.Sequential(*resstack) + + self.shortcut = nn.utils.parametrizations.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)) + + def forward(self, x): + x1 = self.shortcut(x) + x2 = self.resstack(x) + return x1 + x2 + + def remove_weight_norm(self): + remove_parametrizations(self.shortcut, "weight") + remove_parametrizations(self.resstack[2], "weight") + remove_parametrizations(self.resstack[5], "weight") + remove_parametrizations(self.resstack[8], "weight") + remove_parametrizations(self.resstack[11], "weight") + remove_parametrizations(self.resstack[14], "weight") + remove_parametrizations(self.resstack[17], "weight") + + +class MRF(nn.Module): + def __init__(self, kernels, channel, dilations=[1, 3, 5]): # # pylint: disable=dangerous-default-value + super().__init__() + self.resblock1 = ResStack(kernels[0], channel, 0, dilations) + self.resblock2 = ResStack(kernels[1], channel, 6, dilations) + self.resblock3 = ResStack(kernels[2], channel, 12, dilations) + + def forward(self, x): + x1 = 
self.resblock1(x) + x2 = self.resblock2(x) + x3 = self.resblock3(x) + return x1 + x2 + x3 + + def remove_weight_norm(self): + self.resblock1.remove_weight_norm() + self.resblock2.remove_weight_norm() + self.resblock3.remove_weight_norm() diff --git a/viXTTS/TTS/vocoder/layers/losses.py b/viXTTS/TTS/vocoder/layers/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..74cfc7262b4e1f4df9e9250964383050d1d26818 --- /dev/null +++ b/viXTTS/TTS/vocoder/layers/losses.py @@ -0,0 +1,368 @@ +from typing import Dict, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from TTS.utils.audio.torch_transforms import TorchSTFT +from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss + +################################# +# GENERATOR LOSSES +################################# + + +class STFTLoss(nn.Module): + """STFT loss. Input generate and real waveforms are converted + to spectrograms compared with L1 and Spectral convergence losses. + It is from ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf""" + + def __init__(self, n_fft, hop_length, win_length): + super().__init__() + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.stft = TorchSTFT(n_fft, hop_length, win_length) + + def forward(self, y_hat, y): + y_hat_M = self.stft(y_hat) + y_M = self.stft(y) + # magnitude loss + loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M)) + # spectral convergence loss + loss_sc = torch.norm(y_M - y_hat_M, p="fro") / torch.norm(y_M, p="fro") + return loss_mag, loss_sc + + +class MultiScaleSTFTLoss(torch.nn.Module): + """Multi-scale STFT loss. Input generate and real waveforms are converted + to spectrograms compared with L1 and Spectral convergence losses. + It is from ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf""" + + def __init__(self, n_ffts=(1024, 2048, 512), hop_lengths=(120, 240, 50), win_lengths=(600, 1200, 240)): + super().__init__() + self.loss_funcs = torch.nn.ModuleList() + for n_fft, hop_length, win_length in zip(n_ffts, hop_lengths, win_lengths): + self.loss_funcs.append(STFTLoss(n_fft, hop_length, win_length)) + + def forward(self, y_hat, y): + N = len(self.loss_funcs) + loss_sc = 0 + loss_mag = 0 + for f in self.loss_funcs: + lm, lsc = f(y_hat, y) + loss_mag += lm + loss_sc += lsc + loss_sc /= N + loss_mag /= N + return loss_mag, loss_sc + + +class L1SpecLoss(nn.Module): + """L1 Loss over Spectrograms as described in HiFiGAN paper https://arxiv.org/pdf/2010.05646.pdf""" + + def __init__( + self, sample_rate, n_fft, hop_length, win_length, mel_fmin=None, mel_fmax=None, n_mels=None, use_mel=True + ): + super().__init__() + self.use_mel = use_mel + self.stft = TorchSTFT( + n_fft, + hop_length, + win_length, + sample_rate=sample_rate, + mel_fmin=mel_fmin, + mel_fmax=mel_fmax, + n_mels=n_mels, + use_mel=use_mel, + ) + + def forward(self, y_hat, y): + y_hat_M = self.stft(y_hat) + y_M = self.stft(y) + # magnitude loss + loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M)) + return loss_mag + + +class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss): + """Multiscale STFT loss for multi band model outputs. 
+ From MultiBand-MelGAN paper https://arxiv.org/abs/2005.05106""" + + # pylint: disable=no-self-use + def forward(self, y_hat, y): + y_hat = y_hat.view(-1, 1, y_hat.shape[2]) + y = y.view(-1, 1, y.shape[2]) + return super().forward(y_hat.squeeze(1), y.squeeze(1)) + + +class MSEGLoss(nn.Module): + """Mean Squared Generator Loss""" + + # pylint: disable=no-self-use + def forward(self, score_real): + loss_fake = F.mse_loss(score_real, score_real.new_ones(score_real.shape)) + return loss_fake + + +class HingeGLoss(nn.Module): + """Hinge Discriminator Loss""" + + # pylint: disable=no-self-use + def forward(self, score_real): + # TODO: this might be wrong + loss_fake = torch.mean(F.relu(1.0 - score_real)) + return loss_fake + + +################################## +# DISCRIMINATOR LOSSES +################################## + + +class MSEDLoss(nn.Module): + """Mean Squared Discriminator Loss""" + + def __init__( + self, + ): + super().__init__() + self.loss_func = nn.MSELoss() + + # pylint: disable=no-self-use + def forward(self, score_fake, score_real): + loss_real = self.loss_func(score_real, score_real.new_ones(score_real.shape)) + loss_fake = self.loss_func(score_fake, score_fake.new_zeros(score_fake.shape)) + loss_d = loss_real + loss_fake + return loss_d, loss_real, loss_fake + + +class HingeDLoss(nn.Module): + """Hinge Discriminator Loss""" + + # pylint: disable=no-self-use + def forward(self, score_fake, score_real): + loss_real = torch.mean(F.relu(1.0 - score_real)) + loss_fake = torch.mean(F.relu(1.0 + score_fake)) + loss_d = loss_real + loss_fake + return loss_d, loss_real, loss_fake + + +class MelganFeatureLoss(nn.Module): + def __init__( + self, + ): + super().__init__() + self.loss_func = nn.L1Loss() + + # pylint: disable=no-self-use + def forward(self, fake_feats, real_feats): + loss_feats = 0 + num_feats = 0 + for idx, _ in enumerate(fake_feats): + for fake_feat, real_feat in zip(fake_feats[idx], real_feats[idx]): + loss_feats += self.loss_func(fake_feat, real_feat) + num_feats += 1 + loss_feats = loss_feats / num_feats + return loss_feats + + +##################################### +# LOSS WRAPPERS +##################################### + + +def _apply_G_adv_loss(scores_fake, loss_func): + """Compute G adversarial loss function + and normalize values""" + adv_loss = 0 + if isinstance(scores_fake, list): + for score_fake in scores_fake: + fake_loss = loss_func(score_fake) + adv_loss += fake_loss + adv_loss /= len(scores_fake) + else: + fake_loss = loss_func(scores_fake) + adv_loss = fake_loss + return adv_loss + + +def _apply_D_loss(scores_fake, scores_real, loss_func): + """Compute D loss func and normalize loss values""" + loss = 0 + real_loss = 0 + fake_loss = 0 + if isinstance(scores_fake, list): + # multi-scale loss + for score_fake, score_real in zip(scores_fake, scores_real): + total_loss, real_loss_, fake_loss_ = loss_func(score_fake=score_fake, score_real=score_real) + loss += total_loss + real_loss += real_loss_ + fake_loss += fake_loss_ + # normalize loss values with number of scales (discriminators) + loss /= len(scores_fake) + real_loss /= len(scores_real) + fake_loss /= len(scores_fake) + else: + # single scale loss + total_loss, real_loss, fake_loss = loss_func(scores_fake, scores_real) + loss = total_loss + return loss, real_loss, fake_loss + + +################################## +# MODEL LOSSES +################################## + + +class GeneratorLoss(nn.Module): + """Generator Loss Wrapper. 
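The individual loss modules defined above can be exercised on their own before the wrapper below combines them. A minimal sketch with random tensors; the shapes follow how the wrapper calls these losses, and the discriminator scores are made up:

import torch

from TTS.vocoder.layers.losses import MSEGLoss, MultiScaleSTFTLoss

y = torch.rand(2, 16384) * 2 - 1      # stand-in real waveforms, shape [B, T]
y_hat = torch.rand(2, 16384) * 2 - 1  # stand-in generated waveforms, shape [B, T]
loss_mag, loss_sc = MultiScaleSTFTLoss()(y_hat, y)

scores_fake = torch.rand(2, 1, 64)    # hypothetical discriminator output
adv_loss = MSEGLoss()(scores_fake)    # pushes fake scores towards 1
print(loss_mag.item(), loss_sc.item(), adv_loss.item())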
Based on the model configuration it sets up the right set of loss functions and computes
+    the losses. It allows experimenting with different combinations of loss functions with different models by just
+    changing configurations.
+
+    Args:
+        C (AttrDict): model configuration.
+    """
+
+    def __init__(self, C):
+        super().__init__()
+        assert not (
+            C.use_mse_gan_loss and C.use_hinge_gan_loss
+        ), " [!] Cannot use HingeGANLoss and MSEGANLoss together."
+
+        self.use_stft_loss = C.use_stft_loss if "use_stft_loss" in C else False
+        self.use_subband_stft_loss = C.use_subband_stft_loss if "use_subband_stft_loss" in C else False
+        self.use_mse_gan_loss = C.use_mse_gan_loss if "use_mse_gan_loss" in C else False
+        self.use_hinge_gan_loss = C.use_hinge_gan_loss if "use_hinge_gan_loss" in C else False
+        self.use_feat_match_loss = C.use_feat_match_loss if "use_feat_match_loss" in C else False
+        self.use_l1_spec_loss = C.use_l1_spec_loss if "use_l1_spec_loss" in C else False
+
+        self.stft_loss_weight = C.stft_loss_weight if "stft_loss_weight" in C else 0.0
+        self.subband_stft_loss_weight = C.subband_stft_loss_weight if "subband_stft_loss_weight" in C else 0.0
+        self.mse_gan_loss_weight = C.mse_G_loss_weight if "mse_G_loss_weight" in C else 0.0
+        self.hinge_gan_loss_weight = C.hinge_G_loss_weight if "hinge_G_loss_weight" in C else 0.0
+        self.feat_match_loss_weight = C.feat_match_loss_weight if "feat_match_loss_weight" in C else 0.0
+        self.l1_spec_loss_weight = C.l1_spec_loss_weight if "l1_spec_loss_weight" in C else 0.0
+
+        if C.use_stft_loss:
+            self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params)
+        if C.use_subband_stft_loss:
+            self.subband_stft_loss = MultiScaleSubbandSTFTLoss(**C.subband_stft_loss_params)
+        if C.use_mse_gan_loss:
+            self.mse_loss = MSEGLoss()
+        if C.use_hinge_gan_loss:
+            self.hinge_loss = HingeGLoss()
+        if C.use_feat_match_loss:
+            self.feat_match_loss = MelganFeatureLoss()
+        if C.use_l1_spec_loss:
+            assert C.audio["sample_rate"] == C.l1_spec_loss_params["sample_rate"]
+            self.l1_spec_loss = L1SpecLoss(**C.l1_spec_loss_params)
+
+    def forward(
+        self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None
+    ):
+        gen_loss = 0
+        adv_loss = 0
+        return_dict = {}
+
+        # STFT Loss
+        if self.use_stft_loss:
+            stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat[:, :, : y.size(2)].squeeze(1), y.squeeze(1))
+            return_dict["G_stft_loss_mg"] = stft_loss_mg
+            return_dict["G_stft_loss_sc"] = stft_loss_sc
+            gen_loss = gen_loss + self.stft_loss_weight * (stft_loss_mg + stft_loss_sc)
+
+        # L1 Spec loss
+        if self.use_l1_spec_loss:
+            l1_spec_loss = self.l1_spec_loss(y_hat, y)
+            return_dict["G_l1_spec_loss"] = l1_spec_loss
+            gen_loss = gen_loss + self.l1_spec_loss_weight * l1_spec_loss
+
+        # subband STFT Loss
+        if self.use_subband_stft_loss:
+            subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub)
+            return_dict["G_subband_stft_loss_mg"] = subband_stft_loss_mg
+            return_dict["G_subband_stft_loss_sc"] = subband_stft_loss_sc
+            gen_loss = gen_loss + self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc)
+
+        # multiscale MSE adversarial loss
+        if self.use_mse_gan_loss and scores_fake is not None:
+            mse_fake_loss = _apply_G_adv_loss(scores_fake, self.mse_loss)
+            return_dict["G_mse_fake_loss"] = mse_fake_loss
+            adv_loss = adv_loss + self.mse_gan_loss_weight * mse_fake_loss
+
+        # multiscale Hinge adversarial loss
+        if self.use_hinge_gan_loss and scores_fake is not None:
+            hinge_fake_loss = _apply_G_adv_loss(scores_fake,
self.hinge_loss) + return_dict["G_hinge_fake_loss"] = hinge_fake_loss + adv_loss = adv_loss + self.hinge_gan_loss_weight * hinge_fake_loss + + # Feature Matching Loss + if self.use_feat_match_loss and not feats_fake is None: + feat_match_loss = self.feat_match_loss(feats_fake, feats_real) + return_dict["G_feat_match_loss"] = feat_match_loss + adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss + return_dict["loss"] = gen_loss + adv_loss + return_dict["G_gen_loss"] = gen_loss + return_dict["G_adv_loss"] = adv_loss + return return_dict + + +class DiscriminatorLoss(nn.Module): + """Like ```GeneratorLoss```""" + + def __init__(self, C): + super().__init__() + assert not ( + C.use_mse_gan_loss and C.use_hinge_gan_loss + ), " [!] Cannot use HingeGANLoss and MSEGANLoss together." + + self.use_mse_gan_loss = C.use_mse_gan_loss + self.use_hinge_gan_loss = C.use_hinge_gan_loss + + if C.use_mse_gan_loss: + self.mse_loss = MSEDLoss() + if C.use_hinge_gan_loss: + self.hinge_loss = HingeDLoss() + + def forward(self, scores_fake, scores_real): + loss = 0 + return_dict = {} + + if self.use_mse_gan_loss: + mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss( + scores_fake=scores_fake, scores_real=scores_real, loss_func=self.mse_loss + ) + return_dict["D_mse_gan_loss"] = mse_D_loss + return_dict["D_mse_gan_real_loss"] = mse_D_real_loss + return_dict["D_mse_gan_fake_loss"] = mse_D_fake_loss + loss += mse_D_loss + + if self.use_hinge_gan_loss: + hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss( + scores_fake=scores_fake, scores_real=scores_real, loss_func=self.hinge_loss + ) + return_dict["D_hinge_gan_loss"] = hinge_D_loss + return_dict["D_hinge_gan_real_loss"] = hinge_D_real_loss + return_dict["D_hinge_gan_fake_loss"] = hinge_D_fake_loss + loss += hinge_D_loss + + return_dict["loss"] = loss + return return_dict + + +class WaveRNNLoss(nn.Module): + def __init__(self, wave_rnn_mode: Union[str, int]): + super().__init__() + if wave_rnn_mode == "mold": + self.loss_func = discretized_mix_logistic_loss + elif wave_rnn_mode == "gauss": + self.loss_func = gaussian_loss + elif isinstance(wave_rnn_mode, int): + self.loss_func = torch.nn.CrossEntropyLoss() + else: + raise ValueError(" [!] 
Unknown mode for Wavernn.") + + def forward(self, y_hat, y) -> Dict: + loss = self.loss_func(y_hat, y) + return {"loss": loss} diff --git a/viXTTS/TTS/vocoder/layers/lvc_block.py b/viXTTS/TTS/vocoder/layers/lvc_block.py new file mode 100644 index 0000000000000000000000000000000000000000..8913a1132ec769fd304077412289c01c0d1cb17b --- /dev/null +++ b/viXTTS/TTS/vocoder/layers/lvc_block.py @@ -0,0 +1,198 @@ +import torch +import torch.nn.functional as F + + +class KernelPredictor(torch.nn.Module): + """Kernel predictor for the location-variable convolutions""" + + def __init__( # pylint: disable=dangerous-default-value + self, + cond_channels, + conv_in_channels, + conv_out_channels, + conv_layers, + conv_kernel_size=3, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + kpnet_nonlinear_activation="LeakyReLU", + kpnet_nonlinear_activation_params={"negative_slope": 0.1}, + ): + """ + Args: + cond_channels (int): number of channel for the conditioning sequence, + conv_in_channels (int): number of channel for the input sequence, + conv_out_channels (int): number of channel for the output sequence, + conv_layers (int): + kpnet_ + """ + super().__init__() + + self.conv_in_channels = conv_in_channels + self.conv_out_channels = conv_out_channels + self.conv_kernel_size = conv_kernel_size + self.conv_layers = conv_layers + + l_w = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers + l_b = conv_out_channels * conv_layers + + padding = (kpnet_conv_size - 1) // 2 + self.input_conv = torch.nn.Sequential( + torch.nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=(5 - 1) // 2, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + + self.residual_conv = torch.nn.Sequential( + torch.nn.Dropout(kpnet_dropout), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Dropout(kpnet_dropout), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Dropout(kpnet_dropout), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + + self.kernel_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_w, kpnet_conv_size, padding=padding, bias=True) + self.bias_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_b, kpnet_conv_size, padding=padding, bias=True) + + def forward(self, c): + """ + Args: + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + Returns: + """ + batch, _, cond_length = c.shape + + c = self.input_conv(c) + c = c + self.residual_conv(c) + k = self.kernel_conv(c) + b = 
self.bias_conv(c) + + kernels = k.contiguous().view( + batch, self.conv_layers, self.conv_in_channels, self.conv_out_channels, self.conv_kernel_size, cond_length + ) + bias = b.contiguous().view(batch, self.conv_layers, self.conv_out_channels, cond_length) + return kernels, bias + + +class LVCBlock(torch.nn.Module): + """the location-variable convolutions""" + + def __init__( + self, + in_channels, + cond_channels, + upsample_ratio, + conv_layers=4, + conv_kernel_size=3, + cond_hop_length=256, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + ): + super().__init__() + + self.cond_hop_length = cond_hop_length + self.conv_layers = conv_layers + self.conv_kernel_size = conv_kernel_size + self.convs = torch.nn.ModuleList() + + self.upsample = torch.nn.ConvTranspose1d( + in_channels, + in_channels, + kernel_size=upsample_ratio * 2, + stride=upsample_ratio, + padding=upsample_ratio // 2 + upsample_ratio % 2, + output_padding=upsample_ratio % 2, + ) + + self.kernel_predictor = KernelPredictor( + cond_channels=cond_channels, + conv_in_channels=in_channels, + conv_out_channels=2 * in_channels, + conv_layers=conv_layers, + conv_kernel_size=conv_kernel_size, + kpnet_hidden_channels=kpnet_hidden_channels, + kpnet_conv_size=kpnet_conv_size, + kpnet_dropout=kpnet_dropout, + ) + + for i in range(conv_layers): + padding = (3**i) * int((conv_kernel_size - 1) / 2) + conv = torch.nn.Conv1d( + in_channels, in_channels, kernel_size=conv_kernel_size, padding=padding, dilation=3**i + ) + + self.convs.append(conv) + + def forward(self, x, c): + """forward propagation of the location-variable convolutions. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length) + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + + Returns: + Tensor: the output sequence (batch, in_channels, in_length) + """ + in_channels = x.shape[1] + kernels, bias = self.kernel_predictor(c) + + x = F.leaky_relu(x, 0.2) + x = self.upsample(x) + + for i in range(self.conv_layers): + y = F.leaky_relu(x, 0.2) + y = self.convs[i](y) + y = F.leaky_relu(y, 0.2) + + k = kernels[:, i, :, :, :, :] + b = bias[:, i, :, :] + y = self.location_variable_convolution(y, k, b, 1, self.cond_hop_length) + x = x + torch.sigmoid(y[:, :in_channels, :]) * torch.tanh(y[:, in_channels:, :]) + return x + + @staticmethod + def location_variable_convolution(x, kernel, bias, dilation, hop_size): + """perform location-variable convolution operation on the input sequence (x) using the local convolution kernl. + Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length). + kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length) + bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length) + dilation (int): the dilation of convolution. + hop_size (int): the hop_size of the conditioning sequence. + Returns: + (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length). 
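+
+        Examples:
+            A minimal shape check (values are illustrative; any shapes satisfying
+            ``in_length == kernel_length * hop_size`` behave the same way).
+            >>> import torch
+            >>> x = torch.randn(1, 32, 2560)            # (batch, in_channels, kernel_length * hop_size)
+            >>> kernel = torch.randn(1, 32, 64, 3, 10)  # (batch, in_channels, out_channels, kernel_size, kernel_length)
+            >>> bias = torch.randn(1, 64, 10)           # (batch, out_channels, kernel_length)
+            >>> o = LVCBlock.location_variable_convolution(x, kernel, bias, dilation=1, hop_size=256)
+            >>> tuple(o.shape)                          # (batch, out_channels, in_length)
+            (1, 64, 2560)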
+ """ + batch, _, in_length = x.shape + batch, _, out_channels, kernel_size, kernel_length = kernel.shape + + assert in_length == ( + kernel_length * hop_size + ), f"length of (x, kernel) is not matched, {in_length} vs {kernel_length * hop_size}" + + padding = dilation * int((kernel_size - 1) / 2) + x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding) + x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding) + + if hop_size < dilation: + x = F.pad(x, (0, dilation), "constant", 0) + x = x.unfold( + 3, dilation, dilation + ) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation) + x = x[:, :, :, :, :hop_size] + x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation) + x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size) + + o = torch.einsum("bildsk,biokl->bolsd", x, kernel) + o = o + bias.unsqueeze(-1).unsqueeze(-1) + o = o.contiguous().view(batch, out_channels, -1) + return o diff --git a/viXTTS/TTS/vocoder/layers/melgan.py b/viXTTS/TTS/vocoder/layers/melgan.py new file mode 100644 index 0000000000000000000000000000000000000000..7ad41a0f78d7e33021b445936a79318c63447284 --- /dev/null +++ b/viXTTS/TTS/vocoder/layers/melgan.py @@ -0,0 +1,43 @@ +from torch import nn +from torch.nn.utils.parametrizations import weight_norm +from torch.nn.utils.parametrize import remove_parametrizations + + +class ResidualStack(nn.Module): + def __init__(self, channels, num_res_blocks, kernel_size): + super().__init__() + + assert (kernel_size - 1) % 2 == 0, " [!] kernel_size has to be odd." + base_padding = (kernel_size - 1) // 2 + + self.blocks = nn.ModuleList() + for idx in range(num_res_blocks): + layer_kernel_size = kernel_size + layer_dilation = layer_kernel_size**idx + layer_padding = base_padding * layer_dilation + self.blocks += [ + nn.Sequential( + nn.LeakyReLU(0.2), + nn.ReflectionPad1d(layer_padding), + weight_norm( + nn.Conv1d(channels, channels, kernel_size=kernel_size, dilation=layer_dilation, bias=True) + ), + nn.LeakyReLU(0.2), + weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)), + ) + ] + + self.shortcuts = nn.ModuleList( + [weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)) for _ in range(num_res_blocks)] + ) + + def forward(self, x): + for block, shortcut in zip(self.blocks, self.shortcuts): + x = shortcut(x) + block(x) + return x + + def remove_weight_norm(self): + for block, shortcut in zip(self.blocks, self.shortcuts): + remove_parametrizations(block[2], "weight") + remove_parametrizations(block[4], "weight") + remove_parametrizations(shortcut, "weight") diff --git a/viXTTS/TTS/vocoder/layers/parallel_wavegan.py b/viXTTS/TTS/vocoder/layers/parallel_wavegan.py new file mode 100644 index 0000000000000000000000000000000000000000..51142e5eceb20564585635a9040a24bc8eb3b6e3 --- /dev/null +++ b/viXTTS/TTS/vocoder/layers/parallel_wavegan.py @@ -0,0 +1,77 @@ +import torch +from torch.nn import functional as F + + +class ResidualBlock(torch.nn.Module): + """Residual block module in WaveNet.""" + + def __init__( + self, + kernel_size=3, + res_channels=64, + gate_channels=128, + skip_channels=64, + aux_channels=80, + dropout=0.0, + dilation=1, + bias=True, + use_causal_conv=False, + ): + super().__init__() + self.dropout = dropout + # no future time stamps available + if use_causal_conv: + padding = (kernel_size - 1) * dilation + else: + assert (kernel_size - 
1) % 2 == 0, "Not support even number kernel size." + padding = (kernel_size - 1) // 2 * dilation + self.use_causal_conv = use_causal_conv + + # dilation conv + self.conv = torch.nn.Conv1d( + res_channels, gate_channels, kernel_size, padding=padding, dilation=dilation, bias=bias + ) + + # local conditioning + if aux_channels > 0: + self.conv1x1_aux = torch.nn.Conv1d(aux_channels, gate_channels, 1, bias=False) + else: + self.conv1x1_aux = None + + # conv output is split into two groups + gate_out_channels = gate_channels // 2 + self.conv1x1_out = torch.nn.Conv1d(gate_out_channels, res_channels, 1, bias=bias) + self.conv1x1_skip = torch.nn.Conv1d(gate_out_channels, skip_channels, 1, bias=bias) + + def forward(self, x, c): + """ + x: B x D_res x T + c: B x D_aux x T + """ + residual = x + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.conv(x) + + # remove future time steps if use_causal_conv conv + x = x[:, :, : residual.size(-1)] if self.use_causal_conv else x + + # split into two part for gated activation + splitdim = 1 + xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim) + + # local conditioning + if c is not None: + assert self.conv1x1_aux is not None + c = self.conv1x1_aux(c) + ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim) + xa, xb = xa + ca, xb + cb + + x = torch.tanh(xa) * torch.sigmoid(xb) + + # for skip connection + s = self.conv1x1_skip(x) + + # for residual connection + x = (self.conv1x1_out(x) + residual) * (0.5**2) + + return x, s diff --git a/viXTTS/TTS/vocoder/layers/pqmf.py b/viXTTS/TTS/vocoder/layers/pqmf.py new file mode 100644 index 0000000000000000000000000000000000000000..6253efbbefc32222464a97bee99707d46bcdcf8b --- /dev/null +++ b/viXTTS/TTS/vocoder/layers/pqmf.py @@ -0,0 +1,53 @@ +import numpy as np +import torch +import torch.nn.functional as F +from scipy import signal as sig + + +# adapted from +# https://github.com/kan-bayashi/ParallelWaveGAN/tree/master/parallel_wavegan +class PQMF(torch.nn.Module): + def __init__(self, N=4, taps=62, cutoff=0.15, beta=9.0): + super().__init__() + + self.N = N + self.taps = taps + self.cutoff = cutoff + self.beta = beta + + QMF = sig.firwin(taps + 1, cutoff, window=("kaiser", beta)) + H = np.zeros((N, len(QMF))) + G = np.zeros((N, len(QMF))) + for k in range(N): + constant_factor = ( + (2 * k + 1) * (np.pi / (2 * N)) * (np.arange(taps + 1) - ((taps - 1) / 2)) + ) # TODO: (taps - 1) -> taps + phase = (-1) ** k * np.pi / 4 + H[k] = 2 * QMF * np.cos(constant_factor + phase) + + G[k] = 2 * QMF * np.cos(constant_factor - phase) + + H = torch.from_numpy(H[:, None, :]).float() + G = torch.from_numpy(G[None, :, :]).float() + + self.register_buffer("H", H) + self.register_buffer("G", G) + + updown_filter = torch.zeros((N, N, N)).float() + for k in range(N): + updown_filter[k, k, 0] = 1.0 + self.register_buffer("updown_filter", updown_filter) + self.N = N + + self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0) + + def forward(self, x): + return self.analysis(x) + + def analysis(self, x): + return F.conv1d(x, self.H, padding=self.taps // 2, stride=self.N) + + def synthesis(self, x): + x = F.conv_transpose1d(x, self.updown_filter * self.N, stride=self.N) + x = F.conv1d(x, self.G, padding=self.taps // 2) + return x diff --git a/viXTTS/TTS/vocoder/layers/qmf.dat b/viXTTS/TTS/vocoder/layers/qmf.dat new file mode 100644 index 0000000000000000000000000000000000000000..17eab1379de991c36897c2ce701802ef76849c0d --- /dev/null +++ b/viXTTS/TTS/vocoder/layers/qmf.dat @@ -0,0 +1,640 @@ + 0.0000000e+000 + -5.5252865e-004 + 
-5.6176926e-004 + -4.9475181e-004 + -4.8752280e-004 + -4.8937912e-004 + -5.0407143e-004 + -5.2265643e-004 + -5.4665656e-004 + -5.6778026e-004 + -5.8709305e-004 + -6.1327474e-004 + -6.3124935e-004 + -6.5403334e-004 + -6.7776908e-004 + -6.9416146e-004 + -7.1577365e-004 + -7.2550431e-004 + -7.4409419e-004 + -7.4905981e-004 + -7.6813719e-004 + -7.7248486e-004 + -7.8343323e-004 + -7.7798695e-004 + -7.8036647e-004 + -7.8014496e-004 + -7.7579773e-004 + -7.6307936e-004 + -7.5300014e-004 + -7.3193572e-004 + -7.2153920e-004 + -6.9179375e-004 + -6.6504151e-004 + -6.3415949e-004 + -5.9461189e-004 + -5.5645764e-004 + -5.1455722e-004 + -4.6063255e-004 + -4.0951215e-004 + -3.5011759e-004 + -2.8969812e-004 + -2.0983373e-004 + -1.4463809e-004 + -6.1733441e-005 + 1.3494974e-005 + 1.0943831e-004 + 2.0430171e-004 + 2.9495311e-004 + 4.0265402e-004 + 5.1073885e-004 + 6.2393761e-004 + 7.4580259e-004 + 8.6084433e-004 + 9.8859883e-004 + 1.1250155e-003 + 1.2577885e-003 + 1.3902495e-003 + 1.5443220e-003 + 1.6868083e-003 + 1.8348265e-003 + 1.9841141e-003 + 2.1461584e-003 + 2.3017255e-003 + 2.4625617e-003 + 2.6201759e-003 + 2.7870464e-003 + 2.9469448e-003 + 3.1125421e-003 + 3.2739613e-003 + 3.4418874e-003 + 3.6008268e-003 + 3.7603923e-003 + 3.9207432e-003 + 4.0819753e-003 + 4.2264269e-003 + 4.3730720e-003 + 4.5209853e-003 + 4.6606461e-003 + 4.7932561e-003 + 4.9137604e-003 + 5.0393023e-003 + 5.1407354e-003 + 5.2461166e-003 + 5.3471681e-003 + 5.4196776e-003 + 5.4876040e-003 + 5.5475715e-003 + 5.5938023e-003 + 5.6220643e-003 + 5.6455197e-003 + 5.6389200e-003 + 5.6266114e-003 + 5.5917129e-003 + 5.5404364e-003 + 5.4753783e-003 + 5.3838976e-003 + 5.2715759e-003 + 5.1382275e-003 + 4.9839688e-003 + 4.8109469e-003 + 4.6039530e-003 + 4.3801862e-003 + 4.1251642e-003 + 3.8456408e-003 + 3.5401247e-003 + 3.2091886e-003 + 2.8446758e-003 + 2.4508540e-003 + 2.0274176e-003 + 1.5784683e-003 + 1.0902329e-003 + 5.8322642e-004 + 2.7604519e-005 + -5.4642809e-004 + -1.1568136e-003 + -1.8039473e-003 + -2.4826724e-003 + -3.1933778e-003 + -3.9401124e-003 + -4.7222596e-003 + -5.5337211e-003 + -6.3792293e-003 + -7.2615817e-003 + -8.1798233e-003 + -9.1325330e-003 + -1.0115022e-002 + -1.1131555e-002 + -1.2185000e-002 + -1.3271822e-002 + -1.4390467e-002 + -1.5540555e-002 + -1.6732471e-002 + -1.7943338e-002 + -1.9187243e-002 + -2.0453179e-002 + -2.1746755e-002 + -2.3068017e-002 + -2.4416099e-002 + -2.5787585e-002 + -2.7185943e-002 + -2.8607217e-002 + -3.0050266e-002 + -3.1501761e-002 + -3.2975408e-002 + -3.4462095e-002 + -3.5969756e-002 + -3.7481285e-002 + -3.9005368e-002 + -4.0534917e-002 + -4.2064909e-002 + -4.3609754e-002 + -4.5148841e-002 + -4.6684303e-002 + -4.8216572e-002 + -4.9738576e-002 + -5.1255616e-002 + -5.2763075e-002 + -5.4245277e-002 + -5.5717365e-002 + -5.7161645e-002 + -5.8591568e-002 + -5.9983748e-002 + -6.1345517e-002 + -6.2685781e-002 + -6.3971590e-002 + -6.5224711e-002 + -6.6436751e-002 + -6.7607599e-002 + -6.8704383e-002 + -6.9763024e-002 + -7.0762871e-002 + -7.1700267e-002 + -7.2568258e-002 + -7.3362026e-002 + -7.4100364e-002 + -7.4745256e-002 + -7.5313734e-002 + -7.5800836e-002 + -7.6199248e-002 + -7.6499217e-002 + -7.6709349e-002 + -7.6817398e-002 + -7.6823001e-002 + -7.6720492e-002 + -7.6505072e-002 + -7.6174832e-002 + -7.5730576e-002 + -7.5157626e-002 + -7.4466439e-002 + -7.3640601e-002 + -7.2677464e-002 + -7.1582636e-002 + -7.0353307e-002 + -6.8966401e-002 + -6.7452502e-002 + -6.5769067e-002 + -6.3944481e-002 + -6.1960278e-002 + -5.9816657e-002 + -5.7515269e-002 + -5.5046003e-002 + -5.2409382e-002 + -4.9597868e-002 + 
-4.6630331e-002 + -4.3476878e-002 + -4.0145828e-002 + -3.6641812e-002 + -3.2958393e-002 + -2.9082401e-002 + -2.5030756e-002 + -2.0799707e-002 + -1.6370126e-002 + -1.1762383e-002 + -6.9636862e-003 + -1.9765601e-003 + 3.2086897e-003 + 8.5711749e-003 + 1.4128883e-002 + 1.9883413e-002 + 2.5822729e-002 + 3.1953127e-002 + 3.8277657e-002 + 4.4780682e-002 + 5.1480418e-002 + 5.8370533e-002 + 6.5440985e-002 + 7.2694330e-002 + 8.0137293e-002 + 8.7754754e-002 + 9.5553335e-002 + 1.0353295e-001 + 1.1168269e-001 + 1.2000780e-001 + 1.2850029e-001 + 1.3715518e-001 + 1.4597665e-001 + 1.5496071e-001 + 1.6409589e-001 + 1.7338082e-001 + 1.8281725e-001 + 1.9239667e-001 + 2.0212502e-001 + 2.1197359e-001 + 2.2196527e-001 + 2.3206909e-001 + 2.4230169e-001 + 2.5264803e-001 + 2.6310533e-001 + 2.7366340e-001 + 2.8432142e-001 + 2.9507167e-001 + 3.0590986e-001 + 3.1682789e-001 + 3.2781137e-001 + 3.3887227e-001 + 3.4999141e-001 + 3.6115899e-001 + 3.7237955e-001 + 3.8363500e-001 + 3.9492118e-001 + 4.0623177e-001 + 4.1756969e-001 + 4.2891199e-001 + 4.4025538e-001 + 4.5159965e-001 + 4.6293081e-001 + 4.7424532e-001 + 4.8552531e-001 + 4.9677083e-001 + 5.0798175e-001 + 5.1912350e-001 + 5.3022409e-001 + 5.4125534e-001 + 5.5220513e-001 + 5.6307891e-001 + 5.7385241e-001 + 5.8454032e-001 + 5.9511231e-001 + 6.0557835e-001 + 6.1591099e-001 + 6.2612427e-001 + 6.3619801e-001 + 6.4612697e-001 + 6.5590163e-001 + 6.6551399e-001 + 6.7496632e-001 + 6.8423533e-001 + 6.9332824e-001 + 7.0223887e-001 + 7.1094104e-001 + 7.1944626e-001 + 7.2774489e-001 + 7.3582118e-001 + 7.4368279e-001 + 7.5131375e-001 + 7.5870808e-001 + 7.6586749e-001 + 7.7277809e-001 + 7.7942875e-001 + 7.8583531e-001 + 7.9197358e-001 + 7.9784664e-001 + 8.0344858e-001 + 8.0876950e-001 + 8.1381913e-001 + 8.1857760e-001 + 8.2304199e-001 + 8.2722753e-001 + 8.3110385e-001 + 8.3469374e-001 + 8.3797173e-001 + 8.4095414e-001 + 8.4362383e-001 + 8.4598185e-001 + 8.4803158e-001 + 8.4978052e-001 + 8.5119715e-001 + 8.5230470e-001 + 8.5310209e-001 + 8.5357206e-001 + 8.5373856e-001 + 8.5357206e-001 + 8.5310209e-001 + 8.5230470e-001 + 8.5119715e-001 + 8.4978052e-001 + 8.4803158e-001 + 8.4598185e-001 + 8.4362383e-001 + 8.4095414e-001 + 8.3797173e-001 + 8.3469374e-001 + 8.3110385e-001 + 8.2722753e-001 + 8.2304199e-001 + 8.1857760e-001 + 8.1381913e-001 + 8.0876950e-001 + 8.0344858e-001 + 7.9784664e-001 + 7.9197358e-001 + 7.8583531e-001 + 7.7942875e-001 + 7.7277809e-001 + 7.6586749e-001 + 7.5870808e-001 + 7.5131375e-001 + 7.4368279e-001 + 7.3582118e-001 + 7.2774489e-001 + 7.1944626e-001 + 7.1094104e-001 + 7.0223887e-001 + 6.9332824e-001 + 6.8423533e-001 + 6.7496632e-001 + 6.6551399e-001 + 6.5590163e-001 + 6.4612697e-001 + 6.3619801e-001 + 6.2612427e-001 + 6.1591099e-001 + 6.0557835e-001 + 5.9511231e-001 + 5.8454032e-001 + 5.7385241e-001 + 5.6307891e-001 + 5.5220513e-001 + 5.4125534e-001 + 5.3022409e-001 + 5.1912350e-001 + 5.0798175e-001 + 4.9677083e-001 + 4.8552531e-001 + 4.7424532e-001 + 4.6293081e-001 + 4.5159965e-001 + 4.4025538e-001 + 4.2891199e-001 + 4.1756969e-001 + 4.0623177e-001 + 3.9492118e-001 + 3.8363500e-001 + 3.7237955e-001 + 3.6115899e-001 + 3.4999141e-001 + 3.3887227e-001 + 3.2781137e-001 + 3.1682789e-001 + 3.0590986e-001 + 2.9507167e-001 + 2.8432142e-001 + 2.7366340e-001 + 2.6310533e-001 + 2.5264803e-001 + 2.4230169e-001 + 2.3206909e-001 + 2.2196527e-001 + 2.1197359e-001 + 2.0212502e-001 + 1.9239667e-001 + 1.8281725e-001 + 1.7338082e-001 + 1.6409589e-001 + 1.5496071e-001 + 1.4597665e-001 + 1.3715518e-001 + 1.2850029e-001 + 1.2000780e-001 + 1.1168269e-001 + 1.0353295e-001 + 
9.5553335e-002 + 8.7754754e-002 + 8.0137293e-002 + 7.2694330e-002 + 6.5440985e-002 + 5.8370533e-002 + 5.1480418e-002 + 4.4780682e-002 + 3.8277657e-002 + 3.1953127e-002 + 2.5822729e-002 + 1.9883413e-002 + 1.4128883e-002 + 8.5711749e-003 + 3.2086897e-003 + -1.9765601e-003 + -6.9636862e-003 + -1.1762383e-002 + -1.6370126e-002 + -2.0799707e-002 + -2.5030756e-002 + -2.9082401e-002 + -3.2958393e-002 + -3.6641812e-002 + -4.0145828e-002 + -4.3476878e-002 + -4.6630331e-002 + -4.9597868e-002 + -5.2409382e-002 + -5.5046003e-002 + -5.7515269e-002 + -5.9816657e-002 + -6.1960278e-002 + -6.3944481e-002 + -6.5769067e-002 + -6.7452502e-002 + -6.8966401e-002 + -7.0353307e-002 + -7.1582636e-002 + -7.2677464e-002 + -7.3640601e-002 + -7.4466439e-002 + -7.5157626e-002 + -7.5730576e-002 + -7.6174832e-002 + -7.6505072e-002 + -7.6720492e-002 + -7.6823001e-002 + -7.6817398e-002 + -7.6709349e-002 + -7.6499217e-002 + -7.6199248e-002 + -7.5800836e-002 + -7.5313734e-002 + -7.4745256e-002 + -7.4100364e-002 + -7.3362026e-002 + -7.2568258e-002 + -7.1700267e-002 + -7.0762871e-002 + -6.9763024e-002 + -6.8704383e-002 + -6.7607599e-002 + -6.6436751e-002 + -6.5224711e-002 + -6.3971590e-002 + -6.2685781e-002 + -6.1345517e-002 + -5.9983748e-002 + -5.8591568e-002 + -5.7161645e-002 + -5.5717365e-002 + -5.4245277e-002 + -5.2763075e-002 + -5.1255616e-002 + -4.9738576e-002 + -4.8216572e-002 + -4.6684303e-002 + -4.5148841e-002 + -4.3609754e-002 + -4.2064909e-002 + -4.0534917e-002 + -3.9005368e-002 + -3.7481285e-002 + -3.5969756e-002 + -3.4462095e-002 + -3.2975408e-002 + -3.1501761e-002 + -3.0050266e-002 + -2.8607217e-002 + -2.7185943e-002 + -2.5787585e-002 + -2.4416099e-002 + -2.3068017e-002 + -2.1746755e-002 + -2.0453179e-002 + -1.9187243e-002 + -1.7943338e-002 + -1.6732471e-002 + -1.5540555e-002 + -1.4390467e-002 + -1.3271822e-002 + -1.2185000e-002 + -1.1131555e-002 + -1.0115022e-002 + -9.1325330e-003 + -8.1798233e-003 + -7.2615817e-003 + -6.3792293e-003 + -5.5337211e-003 + -4.7222596e-003 + -3.9401124e-003 + -3.1933778e-003 + -2.4826724e-003 + -1.8039473e-003 + -1.1568136e-003 + -5.4642809e-004 + 2.7604519e-005 + 5.8322642e-004 + 1.0902329e-003 + 1.5784683e-003 + 2.0274176e-003 + 2.4508540e-003 + 2.8446758e-003 + 3.2091886e-003 + 3.5401247e-003 + 3.8456408e-003 + 4.1251642e-003 + 4.3801862e-003 + 4.6039530e-003 + 4.8109469e-003 + 4.9839688e-003 + 5.1382275e-003 + 5.2715759e-003 + 5.3838976e-003 + 5.4753783e-003 + 5.5404364e-003 + 5.5917129e-003 + 5.6266114e-003 + 5.6389200e-003 + 5.6455197e-003 + 5.6220643e-003 + 5.5938023e-003 + 5.5475715e-003 + 5.4876040e-003 + 5.4196776e-003 + 5.3471681e-003 + 5.2461166e-003 + 5.1407354e-003 + 5.0393023e-003 + 4.9137604e-003 + 4.7932561e-003 + 4.6606461e-003 + 4.5209853e-003 + 4.3730720e-003 + 4.2264269e-003 + 4.0819753e-003 + 3.9207432e-003 + 3.7603923e-003 + 3.6008268e-003 + 3.4418874e-003 + 3.2739613e-003 + 3.1125421e-003 + 2.9469448e-003 + 2.7870464e-003 + 2.6201759e-003 + 2.4625617e-003 + 2.3017255e-003 + 2.1461584e-003 + 1.9841141e-003 + 1.8348265e-003 + 1.6868083e-003 + 1.5443220e-003 + 1.3902495e-003 + 1.2577885e-003 + 1.1250155e-003 + 9.8859883e-004 + 8.6084433e-004 + 7.4580259e-004 + 6.2393761e-004 + 5.1073885e-004 + 4.0265402e-004 + 2.9495311e-004 + 2.0430171e-004 + 1.0943831e-004 + 1.3494974e-005 + -6.1733441e-005 + -1.4463809e-004 + -2.0983373e-004 + -2.8969812e-004 + -3.5011759e-004 + -4.0951215e-004 + -4.6063255e-004 + -5.1455722e-004 + -5.5645764e-004 + -5.9461189e-004 + -6.3415949e-004 + -6.6504151e-004 + -6.9179375e-004 + -7.2153920e-004 + -7.3193572e-004 + -7.5300014e-004 + 
-7.6307936e-004 + -7.7579773e-004 + -7.8014496e-004 + -7.8036647e-004 + -7.7798695e-004 + -7.8343323e-004 + -7.7248486e-004 + -7.6813719e-004 + -7.4905981e-004 + -7.4409419e-004 + -7.2550431e-004 + -7.1577365e-004 + -6.9416146e-004 + -6.7776908e-004 + -6.5403334e-004 + -6.3124935e-004 + -6.1327474e-004 + -5.8709305e-004 + -5.6778026e-004 + -5.4665656e-004 + -5.2265643e-004 + -5.0407143e-004 + -4.8937912e-004 + -4.8752280e-004 + -4.9475181e-004 + -5.6176926e-004 + -5.5252865e-004 diff --git a/viXTTS/TTS/vocoder/layers/upsample.py b/viXTTS/TTS/vocoder/layers/upsample.py new file mode 100644 index 0000000000000000000000000000000000000000..e169db00b2749493e1cec07ee51c93178dada118 --- /dev/null +++ b/viXTTS/TTS/vocoder/layers/upsample.py @@ -0,0 +1,102 @@ +import torch +from torch.nn import functional as F + + +class Stretch2d(torch.nn.Module): + def __init__(self, x_scale, y_scale, mode="nearest"): + super().__init__() + self.x_scale = x_scale + self.y_scale = y_scale + self.mode = mode + + def forward(self, x): + """ + x (Tensor): Input tensor (B, C, F, T). + Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale), + """ + return F.interpolate(x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode) + + +class UpsampleNetwork(torch.nn.Module): + # pylint: disable=dangerous-default-value + def __init__( + self, + upsample_factors, + nonlinear_activation=None, + nonlinear_activation_params={}, + interpolate_mode="nearest", + freq_axis_kernel_size=1, + use_causal_conv=False, + ): + super().__init__() + self.use_causal_conv = use_causal_conv + self.up_layers = torch.nn.ModuleList() + for scale in upsample_factors: + # interpolation layer + stretch = Stretch2d(scale, 1, interpolate_mode) + self.up_layers += [stretch] + + # conv layer + assert (freq_axis_kernel_size - 1) % 2 == 0, "Not support even number freq axis kernel size." + freq_axis_padding = (freq_axis_kernel_size - 1) // 2 + kernel_size = (freq_axis_kernel_size, scale * 2 + 1) + if use_causal_conv: + padding = (freq_axis_padding, scale * 2) + else: + padding = (freq_axis_padding, scale) + conv = torch.nn.Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False) + self.up_layers += [conv] + + # nonlinear + if nonlinear_activation is not None: + nonlinear = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params) + self.up_layers += [nonlinear] + + def forward(self, c): + """ + c : (B, C, T_in). 
+ Tensor: (B, C, T_upsample) + """ + c = c.unsqueeze(1) # (B, 1, C, T) + for f in self.up_layers: + c = f(c) + return c.squeeze(1) # (B, C, T') + + +class ConvUpsample(torch.nn.Module): + # pylint: disable=dangerous-default-value + def __init__( + self, + upsample_factors, + nonlinear_activation=None, + nonlinear_activation_params={}, + interpolate_mode="nearest", + freq_axis_kernel_size=1, + aux_channels=80, + aux_context_window=0, + use_causal_conv=False, + ): + super().__init__() + self.aux_context_window = aux_context_window + self.use_causal_conv = use_causal_conv and aux_context_window > 0 + # To capture wide-context information in conditional features + kernel_size = aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1 + # NOTE(kan-bayashi): Here do not use padding because the input is already padded + self.conv_in = torch.nn.Conv1d(aux_channels, aux_channels, kernel_size=kernel_size, bias=False) + self.upsample = UpsampleNetwork( + upsample_factors=upsample_factors, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, + interpolate_mode=interpolate_mode, + freq_axis_kernel_size=freq_axis_kernel_size, + use_causal_conv=use_causal_conv, + ) + + def forward(self, c): + """ + c : (B, C, T_in). + Tensor: (B, C, T_upsampled), + """ + c_ = self.conv_in(c) + c = c_[:, :, : -self.aux_context_window] if self.use_causal_conv else c_ + return self.upsample(c) diff --git a/viXTTS/TTS/vocoder/layers/wavegrad.py b/viXTTS/TTS/vocoder/layers/wavegrad.py new file mode 100644 index 0000000000000000000000000000000000000000..9f1512c6d4b836c02b8b04136d69a32833d9f8c3 --- /dev/null +++ b/viXTTS/TTS/vocoder/layers/wavegrad.py @@ -0,0 +1,166 @@ +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.utils.parametrizations import weight_norm +from torch.nn.utils.parametrize import remove_parametrizations + + +class Conv1d(nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + nn.init.orthogonal_(self.weight) + nn.init.zeros_(self.bias) + + +class PositionalEncoding(nn.Module): + """Positional encoding with noise level conditioning""" + + def __init__(self, n_channels, max_len=10000): + super().__init__() + self.n_channels = n_channels + self.max_len = max_len + self.C = 5000 + self.pe = torch.zeros(0, 0) + + def forward(self, x, noise_level): + if x.shape[2] > self.pe.shape[1]: + self.init_pe_matrix(x.shape[1], x.shape[2], x) + return x + noise_level[..., None, None] + self.pe[:, : x.size(2)].repeat(x.shape[0], 1, 1) / self.C + + def init_pe_matrix(self, n_channels, max_len, x): + pe = torch.zeros(max_len, n_channels) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.pow(10000, torch.arange(0, n_channels, 2).float() / n_channels) + + pe[:, 0::2] = torch.sin(position / div_term) + pe[:, 1::2] = torch.cos(position / div_term) + self.pe = pe.transpose(0, 1).to(x) + + +class FiLM(nn.Module): + def __init__(self, input_size, output_size): + super().__init__() + self.encoding = PositionalEncoding(input_size) + self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1) + self.output_conv = nn.Conv1d(input_size, output_size * 2, 3, padding=1) + + nn.init.xavier_uniform_(self.input_conv.weight) + nn.init.xavier_uniform_(self.output_conv.weight) + nn.init.zeros_(self.input_conv.bias) + nn.init.zeros_(self.output_conv.bias) + + def forward(self, x, noise_scale): + o = self.input_conv(x) + o = F.leaky_relu(o, 0.2) + o = self.encoding(o, 
noise_scale) + shift, scale = torch.chunk(self.output_conv(o), 2, dim=1) + return shift, scale + + def remove_weight_norm(self): + remove_parametrizations(self.input_conv, "weight") + remove_parametrizations(self.output_conv, "weight") + + def apply_weight_norm(self): + self.input_conv = weight_norm(self.input_conv) + self.output_conv = weight_norm(self.output_conv) + + +@torch.jit.script +def shif_and_scale(x, scale, shift): + o = shift + scale * x + return o + + +class UBlock(nn.Module): + def __init__(self, input_size, hidden_size, factor, dilation): + super().__init__() + assert isinstance(dilation, (list, tuple)) + assert len(dilation) == 4 + + self.factor = factor + self.res_block = Conv1d(input_size, hidden_size, 1) + self.main_block = nn.ModuleList( + [ + Conv1d(input_size, hidden_size, 3, dilation=dilation[0], padding=dilation[0]), + Conv1d(hidden_size, hidden_size, 3, dilation=dilation[1], padding=dilation[1]), + ] + ) + self.out_block = nn.ModuleList( + [ + Conv1d(hidden_size, hidden_size, 3, dilation=dilation[2], padding=dilation[2]), + Conv1d(hidden_size, hidden_size, 3, dilation=dilation[3], padding=dilation[3]), + ] + ) + + def forward(self, x, shift, scale): + x_inter = F.interpolate(x, size=x.shape[-1] * self.factor) + res = self.res_block(x_inter) + o = F.leaky_relu(x_inter, 0.2) + o = F.interpolate(o, size=x.shape[-1] * self.factor) + o = self.main_block[0](o) + o = shif_and_scale(o, scale, shift) + o = F.leaky_relu(o, 0.2) + o = self.main_block[1](o) + res2 = res + o + o = shif_and_scale(res2, scale, shift) + o = F.leaky_relu(o, 0.2) + o = self.out_block[0](o) + o = shif_and_scale(o, scale, shift) + o = F.leaky_relu(o, 0.2) + o = self.out_block[1](o) + o = o + res2 + return o + + def remove_weight_norm(self): + remove_parametrizations(self.res_block, "weight") + for _, layer in enumerate(self.main_block): + if len(layer.state_dict()) != 0: + remove_parametrizations(layer, "weight") + for _, layer in enumerate(self.out_block): + if len(layer.state_dict()) != 0: + remove_parametrizations(layer, "weight") + + def apply_weight_norm(self): + self.res_block = weight_norm(self.res_block) + for idx, layer in enumerate(self.main_block): + if len(layer.state_dict()) != 0: + self.main_block[idx] = weight_norm(layer) + for idx, layer in enumerate(self.out_block): + if len(layer.state_dict()) != 0: + self.out_block[idx] = weight_norm(layer) + + +class DBlock(nn.Module): + def __init__(self, input_size, hidden_size, factor): + super().__init__() + self.factor = factor + self.res_block = Conv1d(input_size, hidden_size, 1) + self.main_block = nn.ModuleList( + [ + Conv1d(input_size, hidden_size, 3, dilation=1, padding=1), + Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2), + Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4), + ] + ) + + def forward(self, x): + size = x.shape[-1] // self.factor + res = self.res_block(x) + res = F.interpolate(res, size=size) + o = F.interpolate(x, size=size) + for layer in self.main_block: + o = F.leaky_relu(o, 0.2) + o = layer(o) + return o + res + + def remove_weight_norm(self): + remove_parametrizations(self.res_block, "weight") + for _, layer in enumerate(self.main_block): + if len(layer.state_dict()) != 0: + remove_parametrizations(layer, "weight") + + def apply_weight_norm(self): + self.res_block = weight_norm(self.res_block) + for idx, layer in enumerate(self.main_block): + if len(layer.state_dict()) != 0: + self.main_block[idx] = weight_norm(layer) diff --git a/viXTTS/TTS/vocoder/models/__init__.py 
b/viXTTS/TTS/vocoder/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..65901617b69d3ae708e09226c5e4ad903f89a929 --- /dev/null +++ b/viXTTS/TTS/vocoder/models/__init__.py @@ -0,0 +1,154 @@ +import importlib +import re + +from coqpit import Coqpit + + +def to_camel(text): + text = text.capitalize() + return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) + + +def setup_model(config: Coqpit): + """Load models directly from configuration.""" + if "discriminator_model" in config and "generator_model" in config: + MyModel = importlib.import_module("TTS.vocoder.models.gan") + MyModel = getattr(MyModel, "GAN") + else: + MyModel = importlib.import_module("TTS.vocoder.models." + config.model.lower()) + if config.model.lower() == "wavernn": + MyModel = getattr(MyModel, "Wavernn") + elif config.model.lower() == "gan": + MyModel = getattr(MyModel, "GAN") + elif config.model.lower() == "wavegrad": + MyModel = getattr(MyModel, "Wavegrad") + else: + try: + MyModel = getattr(MyModel, to_camel(config.model)) + except ModuleNotFoundError as e: + raise ValueError(f"Model {config.model} not exist!") from e + print(" > Vocoder Model: {}".format(config.model)) + return MyModel.init_from_config(config) + + +def setup_generator(c): + """TODO: use config object as arguments""" + print(" > Generator Model: {}".format(c.generator_model)) + MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower()) + MyModel = getattr(MyModel, to_camel(c.generator_model)) + # this is to preserve the Wavernn class name (instead of Wavernn) + if c.generator_model.lower() in "hifigan_generator": + model = MyModel(in_channels=c.audio["num_mels"], out_channels=1, **c.generator_model_params) + elif c.generator_model.lower() in "melgan_generator": + model = MyModel( + in_channels=c.audio["num_mels"], + out_channels=1, + proj_kernel=7, + base_channels=512, + upsample_factors=c.generator_model_params["upsample_factors"], + res_kernel=3, + num_res_blocks=c.generator_model_params["num_res_blocks"], + ) + elif c.generator_model in "melgan_fb_generator": + raise ValueError("melgan_fb_generator is now fullband_melgan_generator") + elif c.generator_model.lower() in "multiband_melgan_generator": + model = MyModel( + in_channels=c.audio["num_mels"], + out_channels=4, + proj_kernel=7, + base_channels=384, + upsample_factors=c.generator_model_params["upsample_factors"], + res_kernel=3, + num_res_blocks=c.generator_model_params["num_res_blocks"], + ) + elif c.generator_model.lower() in "fullband_melgan_generator": + model = MyModel( + in_channels=c.audio["num_mels"], + out_channels=1, + proj_kernel=7, + base_channels=512, + upsample_factors=c.generator_model_params["upsample_factors"], + res_kernel=3, + num_res_blocks=c.generator_model_params["num_res_blocks"], + ) + elif c.generator_model.lower() in "parallel_wavegan_generator": + model = MyModel( + in_channels=1, + out_channels=1, + kernel_size=3, + num_res_blocks=c.generator_model_params["num_res_blocks"], + stacks=c.generator_model_params["stacks"], + res_channels=64, + gate_channels=128, + skip_channels=64, + aux_channels=c.audio["num_mels"], + dropout=0.0, + bias=True, + use_weight_norm=True, + upsample_factors=c.generator_model_params["upsample_factors"], + ) + elif c.generator_model.lower() in "univnet_generator": + model = MyModel(**c.generator_model_params) + else: + raise NotImplementedError(f"Model {c.generator_model} not implemented!") + return model + + +def setup_discriminator(c): + """TODO: use config objekt 
as arguments""" + print(" > Discriminator Model: {}".format(c.discriminator_model)) + if "parallel_wavegan" in c.discriminator_model: + MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator") + else: + MyModel = importlib.import_module("TTS.vocoder.models." + c.discriminator_model.lower()) + MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower())) + if c.discriminator_model in "hifigan_discriminator": + model = MyModel() + if c.discriminator_model in "random_window_discriminator": + model = MyModel( + cond_channels=c.audio["num_mels"], + hop_length=c.audio["hop_length"], + uncond_disc_donwsample_factors=c.discriminator_model_params["uncond_disc_donwsample_factors"], + cond_disc_downsample_factors=c.discriminator_model_params["cond_disc_downsample_factors"], + cond_disc_out_channels=c.discriminator_model_params["cond_disc_out_channels"], + window_sizes=c.discriminator_model_params["window_sizes"], + ) + if c.discriminator_model in "melgan_multiscale_discriminator": + model = MyModel( + in_channels=1, + out_channels=1, + kernel_sizes=(5, 3), + base_channels=c.discriminator_model_params["base_channels"], + max_channels=c.discriminator_model_params["max_channels"], + downsample_factors=c.discriminator_model_params["downsample_factors"], + ) + if c.discriminator_model == "residual_parallel_wavegan_discriminator": + model = MyModel( + in_channels=1, + out_channels=1, + kernel_size=3, + num_layers=c.discriminator_model_params["num_layers"], + stacks=c.discriminator_model_params["stacks"], + res_channels=64, + gate_channels=128, + skip_channels=64, + dropout=0.0, + bias=True, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + ) + if c.discriminator_model == "parallel_wavegan_discriminator": + model = MyModel( + in_channels=1, + out_channels=1, + kernel_size=3, + num_layers=c.discriminator_model_params["num_layers"], + conv_channels=64, + dilation_factor=1, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + bias=True, + ) + if c.discriminator_model == "univnet_discriminator": + model = MyModel() + return model diff --git a/viXTTS/TTS/vocoder/models/base_vocoder.py b/viXTTS/TTS/vocoder/models/base_vocoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0bcbe7ba1cb933c2bd3e8925c32d781aeaf79add --- /dev/null +++ b/viXTTS/TTS/vocoder/models/base_vocoder.py @@ -0,0 +1,55 @@ +from coqpit import Coqpit + +from TTS.model import BaseTrainerModel + +# pylint: skip-file + + +class BaseVocoder(BaseTrainerModel): + """Base `vocoder` class. Every new `vocoder` model must inherit this. + + It defines `vocoder` specific functions on top of `Model`. + + Notes on input/output tensor shapes: + Any input or output tensor of the model must be shaped as + + - 3D tensors `batch x time x channels` + - 2D tensors `batch x channels` + - 1D tensors `batch x 1` + """ + + MODEL_TYPE = "vocoder" + + def __init__(self, config): + super().__init__() + self._set_model_args(config) + + def _set_model_args(self, config: Coqpit): + """Setup model args based on the config type. + + If the config is for training with a name like "*Config", then the model args are embeded in the + config.model_args + + If the config is for the model with a name like "*Args", then we assign the directly. 
+ """ + # don't use isintance not to import recursively + if "Config" in config.__class__.__name__: + if "characters" in config: + _, self.config, num_chars = self.get_characters(config) + self.config.num_chars = num_chars + if hasattr(self.config, "model_args"): + config.model_args.num_chars = num_chars + if "model_args" in config: + self.args = self.config.model_args + # This is for backward compatibility + if "model_params" in config: + self.args = self.config.model_params + else: + self.config = config + if "model_args" in config: + self.args = self.config.model_args + # This is for backward compatibility + if "model_params" in config: + self.args = self.config.model_params + else: + raise ValueError("config must be either a *Config or *Args") diff --git a/viXTTS/TTS/vocoder/models/fullband_melgan_generator.py b/viXTTS/TTS/vocoder/models/fullband_melgan_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..ee25559af0d468aac535841bdfdd33b366250f43 --- /dev/null +++ b/viXTTS/TTS/vocoder/models/fullband_melgan_generator.py @@ -0,0 +1,33 @@ +import torch + +from TTS.vocoder.models.melgan_generator import MelganGenerator + + +class FullbandMelganGenerator(MelganGenerator): + def __init__( + self, + in_channels=80, + out_channels=1, + proj_kernel=7, + base_channels=512, + upsample_factors=(2, 8, 2, 2), + res_kernel=3, + num_res_blocks=4, + ): + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + proj_kernel=proj_kernel, + base_channels=base_channels, + upsample_factors=upsample_factors, + res_kernel=res_kernel, + num_res_blocks=num_res_blocks, + ) + + @torch.no_grad() + def inference(self, cond_features): + cond_features = cond_features.to(self.layers[1].weight.device) + cond_features = torch.nn.functional.pad( + cond_features, (self.inference_padding, self.inference_padding), "replicate" + ) + return self.layers(cond_features) diff --git a/viXTTS/TTS/vocoder/models/gan.py b/viXTTS/TTS/vocoder/models/gan.py new file mode 100644 index 0000000000000000000000000000000000000000..19c30e983e5bb2066d3ccd22dc5cb21c091cb60a --- /dev/null +++ b/viXTTS/TTS/vocoder/models/gan.py @@ -0,0 +1,374 @@ +from inspect import signature +from typing import Dict, List, Tuple + +import numpy as np +import torch +from coqpit import Coqpit +from torch import nn +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.utils.audio import AudioProcessor +from TTS.utils.io import load_fsspec +from TTS.vocoder.datasets.gan_dataset import GANDataset +from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss +from TTS.vocoder.models import setup_discriminator, setup_generator +from TTS.vocoder.models.base_vocoder import BaseVocoder +from TTS.vocoder.utils.generic_utils import plot_results + + +class GAN(BaseVocoder): + def __init__(self, config: Coqpit, ap: AudioProcessor = None): + """Wrap a generator and a discriminator network. It provides a compatible interface for the trainer. + It also helps mixing and matching different generator and disciminator networks easily. + + To implement a new GAN models, you just need to define the generator and the discriminator networks, the rest + is handled by the `GAN` class. + + Args: + config (Coqpit): Model configuration. + ap (AudioProcessor): 🐸TTS AudioProcessor instance. Defaults to None. + + Examples: + Initializing the GAN model with HifiGAN generator and discriminator. 
+ >>> from TTS.vocoder.configs import HifiganConfig + >>> config = HifiganConfig() + >>> model = GAN(config) + """ + super().__init__(config) + self.config = config + self.model_g = setup_generator(config) + self.model_d = setup_discriminator(config) + self.train_disc = False # if False, train only the generator. + self.y_hat_g = None # the last generator prediction to be passed onto the discriminator + self.ap = ap + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Run the generator's forward pass. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: output of the GAN generator network. + """ + return self.model_g.forward(x) + + def inference(self, x: torch.Tensor) -> torch.Tensor: + """Run the generator's inference pass. + + Args: + x (torch.Tensor): Input tensor. + Returns: + torch.Tensor: output of the GAN generator network. + """ + return self.model_g.inference(x) + + def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[Dict, Dict]: + """Compute model outputs and the loss values. `optimizer_idx` selects the generator or the discriminator for + network on the current pass. + + Args: + batch (Dict): Batch of samples returned by the dataloader. + criterion (Dict): Criterion used to compute the losses. + optimizer_idx (int): ID of the optimizer in use on the current pass. + + Raises: + ValueError: `optimizer_idx` is an unexpected value. + + Returns: + Tuple[Dict, Dict]: model outputs and the computed loss values. + """ + outputs = {} + loss_dict = {} + + x = batch["input"] + y = batch["waveform"] + + if optimizer_idx not in [0, 1]: + raise ValueError(" [!] Unexpected `optimizer_idx`.") + + if optimizer_idx == 0: + # DISCRIMINATOR optimization + + # generator pass + y_hat = self.model_g(x)[:, :, : y.size(2)] + + # cache for generator loss + # pylint: disable=W0201 + self.y_hat_g = y_hat + self.y_hat_sub = None + self.y_sub_g = None + + # PQMF formatting + if y_hat.shape[1] > 1: + self.y_hat_sub = y_hat + y_hat = self.model_g.pqmf_synthesis(y_hat) + self.y_hat_g = y_hat # save for generator loss + self.y_sub_g = self.model_g.pqmf_analysis(y) + + scores_fake, feats_fake, feats_real = None, None, None + + if self.train_disc: + # use different samples for G and D trainings + if self.config.diff_samples_for_G_and_D: + x_d = batch["input_disc"] + y_d = batch["waveform_disc"] + # use a different sample than generator + with torch.no_grad(): + y_hat = self.model_g(x_d) + + # PQMF formatting + if y_hat.shape[1] > 1: + y_hat = self.model_g.pqmf_synthesis(y_hat) + else: + # use the same samples as generator + x_d = x.clone() + y_d = y.clone() + y_hat = self.y_hat_g + + # run D with or without cond. 
features + if len(signature(self.model_d.forward).parameters) == 2: + D_out_fake = self.model_d(y_hat.detach().clone(), x_d) + D_out_real = self.model_d(y_d, x_d) + else: + D_out_fake = self.model_d(y_hat.detach()) + D_out_real = self.model_d(y_d) + + # format D outputs + if isinstance(D_out_fake, tuple): + # self.model_d returns scores and features + scores_fake, feats_fake = D_out_fake + if D_out_real is None: + scores_real, feats_real = None, None + else: + scores_real, feats_real = D_out_real + else: + # model D returns only scores + scores_fake = D_out_fake + scores_real = D_out_real + + # compute losses + loss_dict = criterion[optimizer_idx](scores_fake, scores_real) + outputs = {"model_outputs": y_hat} + + if optimizer_idx == 1: + # GENERATOR loss + scores_fake, feats_fake, feats_real = None, None, None + if self.train_disc: + if len(signature(self.model_d.forward).parameters) == 2: + D_out_fake = self.model_d(self.y_hat_g, x) + else: + D_out_fake = self.model_d(self.y_hat_g) + D_out_real = None + + if self.config.use_feat_match_loss: + with torch.no_grad(): + D_out_real = self.model_d(y) + + # format D outputs + if isinstance(D_out_fake, tuple): + scores_fake, feats_fake = D_out_fake + if D_out_real is None: + feats_real = None + else: + _, feats_real = D_out_real + else: + scores_fake = D_out_fake + feats_fake, feats_real = None, None + + # compute losses + loss_dict = criterion[optimizer_idx]( + self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g + ) + outputs = {"model_outputs": self.y_hat_g} + return outputs, loss_dict + + def _log(self, name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]: + """Logging shared by the training and evaluation. + + Args: + name (str): Name of the run. `train` or `eval`, + ap (AudioProcessor): Audio processor used in training. + batch (Dict): Batch used in the last train/eval step. + outputs (Dict): Model outputs from the last train/eval step. + + Returns: + Tuple[Dict, Dict]: log figures and audio samples. 
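+
+        Examples:
+            Sketch of the returned structures; only meaningful inside a train/eval step, where
+            ``batch`` and ``outputs`` come from the trainer.
+            >>> figures, audios = self._log("train", self.ap, batch, outputs)  # doctest: +SKIP
+            >>> sorted(audios)  # doctest: +SKIP
+            ['train/audio']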
+ """ + y_hat = outputs[0]["model_outputs"] if self.train_disc else outputs[1]["model_outputs"] + y = batch["waveform"] + figures = plot_results(y_hat, y, ap, name) + sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() + audios = {f"{name}/audio": sample_voice} + return figures, audios + + def train_log( + self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument + ) -> Tuple[Dict, np.ndarray]: + """Call `_log()` for training.""" + figures, audios = self._log("eval", self.ap, batch, outputs) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + @torch.no_grad() + def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]: + """Call `train_step()` with `no_grad()`""" + self.train_disc = True # Avoid a bug in the Training with the missing discriminator loss + return self.train_step(batch, criterion, optimizer_idx) + + def eval_log( + self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument + ) -> Tuple[Dict, np.ndarray]: + """Call `_log()` for evaluation.""" + figures, audios = self._log("eval", self.ap, batch, outputs) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def load_checkpoint( + self, + config: Coqpit, + checkpoint_path: str, + eval: bool = False, # pylint: disable=unused-argument, redefined-builtin + cache: bool = False, + ) -> None: + """Load a GAN checkpoint and initialize model parameters. + + Args: + config (Coqpit): Model config. + checkpoint_path (str): Checkpoint file path. + eval (bool, optional): If true, load the model for inference. If falseDefaults to False. + """ + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + # band-aid for older than v0.0.15 GAN models + if "model_disc" in state: + self.model_g.load_checkpoint(config, checkpoint_path, eval) + else: + self.load_state_dict(state["model"]) + if eval: + self.model_d = None + if hasattr(self.model_g, "remove_weight_norm"): + self.model_g.remove_weight_norm() + + def on_train_step_start(self, trainer) -> None: + """Enable the discriminator training based on `steps_to_start_discriminator` + + Args: + trainer (Trainer): Trainer object. + """ + self.train_disc = trainer.total_steps_done >= self.config.steps_to_start_discriminator + + def get_optimizer(self) -> List: + """Initiate and return the GAN optimizers based on the config parameters. + + It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator. + + Returns: + List: optimizers. + """ + optimizer1 = get_optimizer( + self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, self.model_g + ) + optimizer2 = get_optimizer( + self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.model_d + ) + return [optimizer2, optimizer1] + + def get_lr(self) -> List: + """Set the initial learning rates for each optimizer. + + Returns: + List: learning rates for each optimizer. + """ + return [self.config.lr_disc, self.config.lr_gen] + + def get_scheduler(self, optimizer) -> List: + """Set the schedulers for each optimizer. + + Args: + optimizer (List[`torch.optim.Optimizer`]): List of optimizers. + + Returns: + List: Schedulers, one for each optimizer. 
+ """ + scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0]) + scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1]) + return [scheduler2, scheduler1] + + @staticmethod + def format_batch(batch: List) -> Dict: + """Format the batch for training. + + Args: + batch (List): Batch out of the dataloader. + + Returns: + Dict: formatted model inputs. + """ + if isinstance(batch[0], list): + x_G, y_G = batch[0] + x_D, y_D = batch[1] + return {"input": x_G, "waveform": y_G, "input_disc": x_D, "waveform_disc": y_D} + x, y = batch + return {"input": x, "waveform": y} + + def get_data_loader( # pylint: disable=no-self-use, unused-argument + self, + config: Coqpit, + assets: Dict, + is_eval: True, + samples: List, + verbose: bool, + num_gpus: int, + rank: int = None, # pylint: disable=unused-argument + ): + """Initiate and return the GAN dataloader. + + Args: + config (Coqpit): Model config. + ap (AudioProcessor): Audio processor. + is_eval (True): Set the dataloader for evaluation if true. + samples (List): Data samples. + verbose (bool): Log information if true. + num_gpus (int): Number of GPUs in use. + rank (int): Rank of the current GPU. Defaults to None. + + Returns: + DataLoader: Torch dataloader. + """ + dataset = GANDataset( + ap=self.ap, + items=samples, + seq_len=config.seq_len, + hop_len=self.ap.hop_length, + pad_short=config.pad_short, + conv_pad=config.conv_pad, + return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False, + is_training=not is_eval, + return_segments=not is_eval, + use_noise_augment=config.use_noise_augment, + use_cache=config.use_cache, + verbose=verbose, + ) + dataset.shuffle_mapping() + sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None + loader = DataLoader( + dataset, + batch_size=1 if is_eval else config.batch_size, + shuffle=num_gpus == 0, + drop_last=False, + sampler=sampler, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + return loader + + def get_criterion(self): + """Return criterions for the optimizers""" + return [DiscriminatorLoss(self.config), GeneratorLoss(self.config)] + + @staticmethod + def init_from_config(config: Coqpit, verbose=True) -> "GAN": + ap = AudioProcessor.init_from_config(config, verbose=verbose) + return GAN(config, ap=ap) diff --git a/viXTTS/TTS/vocoder/models/hifigan_discriminator.py b/viXTTS/TTS/vocoder/models/hifigan_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..7447a5fbc45991578039068ea6e2e951bba2eb21 --- /dev/null +++ b/viXTTS/TTS/vocoder/models/hifigan_discriminator.py @@ -0,0 +1,217 @@ +# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py +import torch +from torch import nn +from torch.nn import functional as F + +LRELU_SLOPE = 0.1 + + +class DiscriminatorP(torch.nn.Module): + """HiFiGAN Periodic Discriminator + + Takes every Pth value from the input waveform and applied a stack of convoluations. + + Note: + if `period` is 2 + `waveform = [1, 2, 3, 4, 5, 6 ...] --> [1, 3, 5 ... ] --> convs -> score, feat` + + Args: + x (Tensor): input waveform. + + Returns: + [Tensor]: discriminator scores per sample in the batch. + [List[Tensor]]: list of features from each convolutional layer. 
+ + Shapes: + x: [B, 1, T] + """ + + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super().__init__() + self.period = period + get_padding = lambda k, d: int((k * d - d) / 2) + norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm + self.convs = nn.ModuleList( + [ + norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ] + ) + self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + [Tensor]: discriminator scores per sample in the batch. + [List[Tensor]]: list of features from each convolutional layer. + + Shapes: + x: [B, 1, T] + """ + feat = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + feat.append(x) + x = self.conv_post(x) + feat.append(x) + x = torch.flatten(x, 1, -1) + + return x, feat + + +class MultiPeriodDiscriminator(torch.nn.Module): + """HiFiGAN Multi-Period Discriminator (MPD) + Wrapper for the `PeriodDiscriminator` to apply it in different periods. + Periods are suggested to be prime numbers to reduce the overlap between each discriminator. + """ + + def __init__(self, use_spectral_norm=False): + super().__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorP(2, use_spectral_norm=use_spectral_norm), + DiscriminatorP(3, use_spectral_norm=use_spectral_norm), + DiscriminatorP(5, use_spectral_norm=use_spectral_norm), + DiscriminatorP(7, use_spectral_norm=use_spectral_norm), + DiscriminatorP(11, use_spectral_norm=use_spectral_norm), + ] + ) + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + [List[Tensor]]: list of scores from each discriminator. + [List[List[Tensor]]]: list of list of features from each discriminator's each convolutional layer. + + Shapes: + x: [B, 1, T] + """ + scores = [] + feats = [] + for _, d in enumerate(self.discriminators): + score, feat = d(x) + scores.append(score) + feats.append(feat) + return scores, feats + + +class DiscriminatorS(torch.nn.Module): + """HiFiGAN Scale Discriminator. + It is similar to `MelganDiscriminator` but with a specific architecture explained in the paper. + + Args: + use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm. 
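+
+    Examples:
+        A quick smoke test; the input length (1 s at 22.05 kHz) is illustrative.
+        >>> import torch
+        >>> d = DiscriminatorS()
+        >>> score, feats = d(torch.randn(2, 1, 22050))
+        >>> len(feats)  # 7 conv layers + the post conv
+        8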
+ + """ + + def __init__(self, use_spectral_norm=False): + super().__init__() + norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm + self.convs = nn.ModuleList( + [ + norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)), + norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + Tensor: discriminator scores. + List[Tensor]: list of features from the convolutiona layers. + """ + feat = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + feat.append(x) + x = self.conv_post(x) + feat.append(x) + x = torch.flatten(x, 1, -1) + return x, feat + + +class MultiScaleDiscriminator(torch.nn.Module): + """HiFiGAN Multi-Scale Discriminator. + It is similar to `MultiScaleMelganDiscriminator` but specially tailored for HiFiGAN as in the paper. + """ + + def __init__(self): + super().__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ] + ) + self.meanpools = nn.ModuleList([nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)]) + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + List[Tensor]: discriminator scores. + List[List[Tensor]]: list of list of features from each layers of each discriminator. + """ + scores = [] + feats = [] + for i, d in enumerate(self.discriminators): + if i != 0: + x = self.meanpools[i - 1](x) + score, feat = d(x) + scores.append(score) + feats.append(feat) + return scores, feats + + +class HifiganDiscriminator(nn.Module): + """HiFiGAN discriminator wrapping MPD and MSD.""" + + def __init__(self): + super().__init__() + self.mpd = MultiPeriodDiscriminator() + self.msd = MultiScaleDiscriminator() + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + List[Tensor]: discriminator scores. + List[List[Tensor]]: list of list of features from each layers of each discriminator. + """ + scores, feats = self.mpd(x) + scores_, feats_ = self.msd(x) + return scores + scores_, feats + feats_ diff --git a/viXTTS/TTS/vocoder/models/hifigan_generator.py b/viXTTS/TTS/vocoder/models/hifigan_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..92475322590c6d4da66b322aa418c452d8eb8b27 --- /dev/null +++ b/viXTTS/TTS/vocoder/models/hifigan_generator.py @@ -0,0 +1,301 @@ +# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py +import torch +from torch import nn +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn import functional as F +from torch.nn.utils.parametrizations import weight_norm +from torch.nn.utils.parametrize import remove_parametrizations + +from TTS.utils.io import load_fsspec + +LRELU_SLOPE = 0.1 + + +def get_padding(k, d): + return int((k * d - d) / 2) + + +class ResBlock1(torch.nn.Module): + """Residual Block Type 1. It has 3 convolutional layers in each convolutional block. 
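# Hedged sketch of the HifiganDiscriminator wrapper defined in the discriminator file
# above: it concatenates the multi-period and multi-scale outputs, so one call returns
# 5 + 3 = 8 score tensors. Input length is arbitrary.
import torch
from TTS.vocoder.models.hifigan_discriminator import HifiganDiscriminator

disc = HifiganDiscriminator()
scores, feats = disc(torch.randn(1, 1, 16384))   # dummy [B, 1, T] waveform
print(len(scores), len(feats))                   # 8 and 8 (5 periods + 3 scales)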
+ + Network:: + + x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o + |--------------------------------------------------------------------------------------------------| + + + Args: + channels (int): number of hidden channels for the convolutional layers. + kernel_size (int): size of the convolution filter in each layer. + dilations (list): list of dilation value for each conv layer in a block. + """ + + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super().__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + ] + ) + + def forward(self, x): + """ + Args: + x (Tensor): input tensor. + Returns: + Tensor: output tensor. + Shapes: + x: [B, C, T] + """ + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_parametrizations(l, "weight") + for l in self.convs2: + remove_parametrizations(l, "weight") + + +class ResBlock2(torch.nn.Module): + """Residual Block Type 2. It has 1 convolutional layers in each convolutional block. + + Network:: + + x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o + |---------------------------------------------------| + + + Args: + channels (int): number of hidden channels for the convolutional layers. + kernel_size (int): size of the convolution filter in each layer. + dilations (list): list of dilation value for each conv layer in a block. 
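# Quick shape check (editor's example): both residual block types preserve [B, C, T],
# which is what allows the MRF module to sum their outputs across kernel sizes.
import torch
from TTS.vocoder.models.hifigan_generator import ResBlock1, ResBlock2

x = torch.randn(1, 64, 100)
print(ResBlock1(channels=64)(x).shape)   # torch.Size([1, 64, 100])
print(ResBlock2(channels=64)(x).shape)   # torch.Size([1, 64, 100])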
+ """ + + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super().__init__() + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_parametrizations(l, "weight") + + +class HifiganGenerator(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + resblock_type, + resblock_dilation_sizes, + resblock_kernel_sizes, + upsample_kernel_sizes, + upsample_initial_channel, + upsample_factors, + inference_padding=5, + cond_channels=0, + conv_pre_weight_norm=True, + conv_post_weight_norm=True, + conv_post_bias=True, + ): + r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF) + + Network: + x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o + .. -> zI ---| + resblockN_kNx1 -> zN ---' + + Args: + in_channels (int): number of input tensor channels. + out_channels (int): number of output tensor channels. + resblock_type (str): type of the `ResBlock`. '1' or '2'. + resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`. + resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`. + upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution. + upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2 + for each consecutive upsampling layer. + upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer. + inference_padding (int): constant padding applied to the input at inference time. Defaults to 5. + """ + super().__init__() + self.inference_padding = inference_padding + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_factors) + # initial upsampling layers + self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if resblock_type == "1" else ResBlock2 + # upsampling layers + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + # MRF blocks + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + # post convolution layer + self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)) + if cond_channels > 0: + self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1) + + if not conv_pre_weight_norm: + remove_parametrizations(self.conv_pre, "weight") + + if not conv_post_weight_norm: + remove_parametrizations(self.conv_post, "weight") + + def forward(self, x, g=None): + """ + Args: + x (Tensor): feature input tensor. + g (Tensor): global conditioning input tensor. + + Returns: + Tensor: output waveform. 
+ + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + o = self.conv_pre(x) + if hasattr(self, "cond_layer"): + o = o + self.cond_layer(g) + for i in range(self.num_upsamples): + o = F.leaky_relu(o, LRELU_SLOPE) + o = self.ups[i](o) + z_sum = None + for j in range(self.num_kernels): + if z_sum is None: + z_sum = self.resblocks[i * self.num_kernels + j](o) + else: + z_sum += self.resblocks[i * self.num_kernels + j](o) + o = z_sum / self.num_kernels + o = F.leaky_relu(o) + o = self.conv_post(o) + o = torch.tanh(o) + return o + + @torch.no_grad() + def inference(self, c): + """ + Args: + x (Tensor): conditioning input tensor. + + Returns: + Tensor: output waveform. + + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + c = c.to(self.conv_pre.weight.device) + c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate") + return self.forward(c) + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_parametrizations(l, "weight") + for l in self.resblocks: + l.remove_weight_norm() + remove_parametrizations(self.conv_pre, "weight") + remove_parametrizations(self.conv_post, "weight") + + def load_checkpoint( + self, config, checkpoint_path, eval=False, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + self.remove_weight_norm() diff --git a/viXTTS/TTS/vocoder/models/melgan_discriminator.py b/viXTTS/TTS/vocoder/models/melgan_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..e41467da3c1d2fdceb9e86cecdc848b62ee731ff --- /dev/null +++ b/viXTTS/TTS/vocoder/models/melgan_discriminator.py @@ -0,0 +1,84 @@ +import numpy as np +from torch import nn +from torch.nn.utils.parametrizations import weight_norm + + +class MelganDiscriminator(nn.Module): + def __init__( + self, + in_channels=1, + out_channels=1, + kernel_sizes=(5, 3), + base_channels=16, + max_channels=1024, + downsample_factors=(4, 4, 4, 4), + groups_denominator=4, + ): + super().__init__() + self.layers = nn.ModuleList() + + layer_kernel_size = np.prod(kernel_sizes) + layer_padding = (layer_kernel_size - 1) // 2 + + # initial layer + self.layers += [ + nn.Sequential( + nn.ReflectionPad1d(layer_padding), + weight_norm(nn.Conv1d(in_channels, base_channels, layer_kernel_size, stride=1)), + nn.LeakyReLU(0.2, inplace=True), + ) + ] + + # downsampling layers + layer_in_channels = base_channels + for downsample_factor in downsample_factors: + layer_out_channels = min(layer_in_channels * downsample_factor, max_channels) + layer_kernel_size = downsample_factor * 10 + 1 + layer_padding = (layer_kernel_size - 1) // 2 + layer_groups = layer_in_channels // groups_denominator + self.layers += [ + nn.Sequential( + weight_norm( + nn.Conv1d( + layer_in_channels, + layer_out_channels, + kernel_size=layer_kernel_size, + stride=downsample_factor, + padding=layer_padding, + groups=layer_groups, + ) + ), + nn.LeakyReLU(0.2, inplace=True), + ) + ] + layer_in_channels = layer_out_channels + + # last 2 layers + layer_padding1 = (kernel_sizes[0] - 1) // 2 + layer_padding2 = (kernel_sizes[1] - 1) // 2 + self.layers += [ + nn.Sequential( + weight_norm( + nn.Conv1d( + layer_out_channels, + layer_out_channels, + kernel_size=kernel_sizes[0], + stride=1, + padding=layer_padding1, + ) + ), + nn.LeakyReLU(0.2, inplace=True), + ), + weight_norm( + nn.Conv1d( + 
layer_out_channels, out_channels, kernel_size=kernel_sizes[1], stride=1, padding=layer_padding2 + ) + ), + ] + + def forward(self, x): + feats = [] + for layer in self.layers: + x = layer(x) + feats.append(x) + return x, feats diff --git a/viXTTS/TTS/vocoder/models/melgan_generator.py b/viXTTS/TTS/vocoder/models/melgan_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..bb3fee789cd3ec0276957882e2a9e374fb9b0a6b --- /dev/null +++ b/viXTTS/TTS/vocoder/models/melgan_generator.py @@ -0,0 +1,95 @@ +import torch +from torch import nn +from torch.nn.utils.parametrizations import weight_norm + +from TTS.utils.io import load_fsspec +from TTS.vocoder.layers.melgan import ResidualStack + + +class MelganGenerator(nn.Module): + def __init__( + self, + in_channels=80, + out_channels=1, + proj_kernel=7, + base_channels=512, + upsample_factors=(8, 8, 2, 2), + res_kernel=3, + num_res_blocks=3, + ): + super().__init__() + + # assert model parameters + assert (proj_kernel - 1) % 2 == 0, " [!] proj_kernel should be an odd number." + + # setup additional model parameters + base_padding = (proj_kernel - 1) // 2 + act_slope = 0.2 + self.inference_padding = 2 + + # initial layer + layers = [] + layers += [ + nn.ReflectionPad1d(base_padding), + weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=proj_kernel, stride=1, bias=True)), + ] + + # upsampling layers and residual stacks + for idx, upsample_factor in enumerate(upsample_factors): + layer_in_channels = base_channels // (2**idx) + layer_out_channels = base_channels // (2 ** (idx + 1)) + layer_filter_size = upsample_factor * 2 + layer_stride = upsample_factor + layer_output_padding = upsample_factor % 2 + layer_padding = upsample_factor // 2 + layer_output_padding + layers += [ + nn.LeakyReLU(act_slope), + weight_norm( + nn.ConvTranspose1d( + layer_in_channels, + layer_out_channels, + layer_filter_size, + stride=layer_stride, + padding=layer_padding, + output_padding=layer_output_padding, + bias=True, + ) + ), + ResidualStack(channels=layer_out_channels, num_res_blocks=num_res_blocks, kernel_size=res_kernel), + ] + + layers += [nn.LeakyReLU(act_slope)] + + # final layer + layers += [ + nn.ReflectionPad1d(base_padding), + weight_norm(nn.Conv1d(layer_out_channels, out_channels, proj_kernel, stride=1, bias=True)), + nn.Tanh(), + ] + self.layers = nn.Sequential(*layers) + + def forward(self, c): + return self.layers(c) + + def inference(self, c): + c = c.to(self.layers[1].weight.device) + c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate") + return self.layers(c) + + def remove_weight_norm(self): + for _, layer in enumerate(self.layers): + if len(layer.state_dict()) != 0: + try: + nn.utils.parametrize.remove_parametrizations(layer, "weight") + except ValueError: + layer.remove_weight_norm() + + def load_checkpoint( + self, config, checkpoint_path, eval=False, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + self.remove_weight_norm() diff --git a/viXTTS/TTS/vocoder/models/melgan_multiscale_discriminator.py b/viXTTS/TTS/vocoder/models/melgan_multiscale_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..b4909f37c0c91c6fee8bb0baab98a8662039dea1 --- /dev/null +++ b/viXTTS/TTS/vocoder/models/melgan_multiscale_discriminator.py @@ -0,0 +1,50 @@ +from torch import 
nn + +from TTS.vocoder.models.melgan_discriminator import MelganDiscriminator + + +class MelganMultiscaleDiscriminator(nn.Module): + def __init__( + self, + in_channels=1, + out_channels=1, + num_scales=3, + kernel_sizes=(5, 3), + base_channels=16, + max_channels=1024, + downsample_factors=(4, 4, 4), + pooling_kernel_size=4, + pooling_stride=2, + pooling_padding=2, + groups_denominator=4, + ): + super().__init__() + + self.discriminators = nn.ModuleList( + [ + MelganDiscriminator( + in_channels=in_channels, + out_channels=out_channels, + kernel_sizes=kernel_sizes, + base_channels=base_channels, + max_channels=max_channels, + downsample_factors=downsample_factors, + groups_denominator=groups_denominator, + ) + for _ in range(num_scales) + ] + ) + + self.pooling = nn.AvgPool1d( + kernel_size=pooling_kernel_size, stride=pooling_stride, padding=pooling_padding, count_include_pad=False + ) + + def forward(self, x): + scores = [] + feats = [] + for disc in self.discriminators: + score, feat = disc(x) + scores.append(score) + feats.append(feat) + x = self.pooling(x) + return scores, feats diff --git a/viXTTS/TTS/vocoder/models/multiband_melgan_generator.py b/viXTTS/TTS/vocoder/models/multiband_melgan_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..25d6590659cf5863176eb6609c7609b0e1b28d12 --- /dev/null +++ b/viXTTS/TTS/vocoder/models/multiband_melgan_generator.py @@ -0,0 +1,41 @@ +import torch + +from TTS.vocoder.layers.pqmf import PQMF +from TTS.vocoder.models.melgan_generator import MelganGenerator + + +class MultibandMelganGenerator(MelganGenerator): + def __init__( + self, + in_channels=80, + out_channels=4, + proj_kernel=7, + base_channels=384, + upsample_factors=(2, 8, 2, 2), + res_kernel=3, + num_res_blocks=3, + ): + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + proj_kernel=proj_kernel, + base_channels=base_channels, + upsample_factors=upsample_factors, + res_kernel=res_kernel, + num_res_blocks=num_res_blocks, + ) + self.pqmf_layer = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0) + + def pqmf_analysis(self, x): + return self.pqmf_layer.analysis(x) + + def pqmf_synthesis(self, x): + return self.pqmf_layer.synthesis(x) + + @torch.no_grad() + def inference(self, cond_features): + cond_features = cond_features.to(self.layers[1].weight.device) + cond_features = torch.nn.functional.pad( + cond_features, (self.inference_padding, self.inference_padding), "replicate" + ) + return self.pqmf_synthesis(self.layers(cond_features)) diff --git a/viXTTS/TTS/vocoder/models/parallel_wavegan_discriminator.py b/viXTTS/TTS/vocoder/models/parallel_wavegan_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..d02af75f05d74041a6c8ce734dd6efbe5d3080ae --- /dev/null +++ b/viXTTS/TTS/vocoder/models/parallel_wavegan_discriminator.py @@ -0,0 +1,187 @@ +import math + +import torch +from torch import nn +from torch.nn.utils.parametrize import remove_parametrizations + +from TTS.vocoder.layers.parallel_wavegan import ResidualBlock + + +class ParallelWaveganDiscriminator(nn.Module): + """PWGAN discriminator as in https://arxiv.org/abs/1910.11480. + It classifies each audio window real/fake and returns a sequence + of predictions. + It is a stack of convolutional blocks with dilation. 
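# Hedged sketch of the multi-band pipeline defined above: the generator predicts 4 PQMF
# sub-bands and `pqmf_synthesis` recombines them, multiplying the time axis by 4. Shapes
# assume the default upsample_factors (2, 8, 2, 2), i.e. 64 sub-band samples per frame.
import torch
from TTS.vocoder.models.multiband_melgan_generator import MultibandMelganGenerator

gen = MultibandMelganGenerator()
mel = torch.randn(1, 80, 40)              # [B, 80, T_frames]
bands = gen(mel)                          # [1, 4, 40 * 64] sub-band signals
wav = gen.pqmf_synthesis(bands)           # [1, 1, 40 * 256] full-band waveform
print(bands.shape, wav.shape)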
+ """ + + # pylint: disable=dangerous-default-value + def __init__( + self, + in_channels=1, + out_channels=1, + kernel_size=3, + num_layers=10, + conv_channels=64, + dilation_factor=1, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + bias=True, + ): + super().__init__() + assert (kernel_size - 1) % 2 == 0, " [!] does not support even number kernel size." + assert dilation_factor > 0, " [!] dilation factor must be > 0." + self.conv_layers = nn.ModuleList() + conv_in_channels = in_channels + for i in range(num_layers - 1): + if i == 0: + dilation = 1 + else: + dilation = i if dilation_factor == 1 else dilation_factor**i + conv_in_channels = conv_channels + padding = (kernel_size - 1) // 2 * dilation + conv_layer = [ + nn.Conv1d( + conv_in_channels, + conv_channels, + kernel_size=kernel_size, + padding=padding, + dilation=dilation, + bias=bias, + ), + getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params), + ] + self.conv_layers += conv_layer + padding = (kernel_size - 1) // 2 + last_conv_layer = nn.Conv1d(conv_in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias) + self.conv_layers += [last_conv_layer] + self.apply_weight_norm() + + def forward(self, x): + """ + x : (B, 1, T). + Returns: + Tensor: (B, 1, T) + """ + for f in self.conv_layers: + x = f(x) + return x + + def apply_weight_norm(self): + def _apply_weight_norm(m): + if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): + torch.nn.utils.parametrizations.weight_norm(m) + + self.apply(_apply_weight_norm) + + def remove_weight_norm(self): + def _remove_weight_norm(m): + try: + # print(f"Weight norm is removed from {m}.") + remove_parametrizations(m, "weight") + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + +class ResidualParallelWaveganDiscriminator(nn.Module): + # pylint: disable=dangerous-default-value + def __init__( + self, + in_channels=1, + out_channels=1, + kernel_size=3, + num_layers=30, + stacks=3, + res_channels=64, + gate_channels=128, + skip_channels=64, + dropout=0.0, + bias=True, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + ): + super().__init__() + assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." 
+ + self.in_channels = in_channels + self.out_channels = out_channels + self.num_layers = num_layers + self.stacks = stacks + self.kernel_size = kernel_size + self.res_factor = math.sqrt(1.0 / num_layers) + + # check the number of num_layers and stacks + assert num_layers % stacks == 0 + layers_per_stack = num_layers // stacks + + # define first convolution + self.first_conv = nn.Sequential( + nn.Conv1d(in_channels, res_channels, kernel_size=1, padding=0, dilation=1, bias=True), + getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params), + ) + + # define residual blocks + self.conv_layers = nn.ModuleList() + for layer in range(num_layers): + dilation = 2 ** (layer % layers_per_stack) + conv = ResidualBlock( + kernel_size=kernel_size, + res_channels=res_channels, + gate_channels=gate_channels, + skip_channels=skip_channels, + aux_channels=-1, + dilation=dilation, + dropout=dropout, + bias=bias, + use_causal_conv=False, + ) + self.conv_layers += [conv] + + # define output layers + self.last_conv_layers = nn.ModuleList( + [ + getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params), + nn.Conv1d(skip_channels, skip_channels, kernel_size=1, padding=0, dilation=1, bias=True), + getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params), + nn.Conv1d(skip_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=True), + ] + ) + + # apply weight norm + self.apply_weight_norm() + + def forward(self, x): + """ + x: (B, 1, T). + """ + x = self.first_conv(x) + + skips = 0 + for f in self.conv_layers: + x, h = f(x, None) + skips += h + skips *= self.res_factor + + # apply final layers + x = skips + for f in self.last_conv_layers: + x = f(x) + return x + + def apply_weight_norm(self): + def _apply_weight_norm(m): + if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): + torch.nn.utils.parametrizations.weight_norm(m) + + self.apply(_apply_weight_norm) + + def remove_weight_norm(self): + def _remove_weight_norm(m): + try: + print(f"Weight norm is removed from {m}.") + remove_parametrizations(m, "weight") + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) diff --git a/viXTTS/TTS/vocoder/models/parallel_wavegan_generator.py b/viXTTS/TTS/vocoder/models/parallel_wavegan_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..8338d946537ecdd2b37c5c53bb9846a4d73a9290 --- /dev/null +++ b/viXTTS/TTS/vocoder/models/parallel_wavegan_generator.py @@ -0,0 +1,164 @@ +import math + +import numpy as np +import torch +from torch.nn.utils.parametrize import remove_parametrizations + +from TTS.utils.io import load_fsspec +from TTS.vocoder.layers.parallel_wavegan import ResidualBlock +from TTS.vocoder.layers.upsample import ConvUpsample + + +class ParallelWaveganGenerator(torch.nn.Module): + """PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf. + It is similar to WaveNet with no causal convolution. + It is conditioned on an aux feature (spectrogram) to generate + an output waveform from an input noise. 
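# Illustrative usage, assuming the default 4x4x4x4 (= 256x) upsampling of the generator
# defined below; it draws its own Gaussian noise internally, so only the spectrogram is
# passed in. Batch size and frame count are arbitrary.
import torch
from TTS.vocoder.models.parallel_wavegan_generator import ParallelWaveganGenerator

gen = ParallelWaveganGenerator()
mel = torch.randn(1, 80, 30)      # [B, 80, T_frames]
wav = gen(mel)                    # [1, 1, 30 * 256]
print(wav.shape)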
+ """ + + # pylint: disable=dangerous-default-value + def __init__( + self, + in_channels=1, + out_channels=1, + kernel_size=3, + num_res_blocks=30, + stacks=3, + res_channels=64, + gate_channels=128, + skip_channels=64, + aux_channels=80, + dropout=0.0, + bias=True, + use_weight_norm=True, + upsample_factors=[4, 4, 4, 4], + inference_padding=2, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.aux_channels = aux_channels + self.num_res_blocks = num_res_blocks + self.stacks = stacks + self.kernel_size = kernel_size + self.upsample_factors = upsample_factors + self.upsample_scale = np.prod(upsample_factors) + self.inference_padding = inference_padding + self.use_weight_norm = use_weight_norm + + # check the number of layers and stacks + assert num_res_blocks % stacks == 0 + layers_per_stack = num_res_blocks // stacks + + # define first convolution + self.first_conv = torch.nn.Conv1d(in_channels, res_channels, kernel_size=1, bias=True) + + # define conv + upsampling network + self.upsample_net = ConvUpsample(upsample_factors=upsample_factors) + + # define residual blocks + self.conv_layers = torch.nn.ModuleList() + for layer in range(num_res_blocks): + dilation = 2 ** (layer % layers_per_stack) + conv = ResidualBlock( + kernel_size=kernel_size, + res_channels=res_channels, + gate_channels=gate_channels, + skip_channels=skip_channels, + aux_channels=aux_channels, + dilation=dilation, + dropout=dropout, + bias=bias, + ) + self.conv_layers += [conv] + + # define output layers + self.last_conv_layers = torch.nn.ModuleList( + [ + torch.nn.ReLU(inplace=True), + torch.nn.Conv1d(skip_channels, skip_channels, kernel_size=1, bias=True), + torch.nn.ReLU(inplace=True), + torch.nn.Conv1d(skip_channels, out_channels, kernel_size=1, bias=True), + ] + ) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + def forward(self, c): + """ + c: (B, C ,T'). + o: Output tensor (B, out_channels, T) + """ + # random noise + x = torch.randn([c.shape[0], 1, c.shape[2] * self.upsample_scale]) + x = x.to(self.first_conv.bias.device) + + # perform upsampling + if c is not None and self.upsample_net is not None: + c = self.upsample_net(c) + assert ( + c.shape[-1] == x.shape[-1] + ), f" [!] Upsampling scale does not match the expected output. 
{c.shape} vs {x.shape}" + + # encode to hidden representation + x = self.first_conv(x) + skips = 0 + for f in self.conv_layers: + x, h = f(x, c) + skips += h + skips *= math.sqrt(1.0 / len(self.conv_layers)) + + # apply final layers + x = skips + for f in self.last_conv_layers: + x = f(x) + + return x + + @torch.no_grad() + def inference(self, c): + c = c.to(self.first_conv.weight.device) + c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate") + return self.forward(c) + + def remove_weight_norm(self): + def _remove_weight_norm(m): + try: + # print(f"Weight norm is removed from {m}.") + remove_parametrizations(m, "weight") + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + def _apply_weight_norm(m): + if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): + torch.nn.utils.parametrizations.weight_norm(m) + # print(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + @staticmethod + def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x): + assert layers % stacks == 0 + layers_per_cycle = layers // stacks + dilations = [dilation(i % layers_per_cycle) for i in range(layers)] + return (kernel_size - 1) * sum(dilations) + 1 + + @property + def receptive_field_size(self): + return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size) + + def load_checkpoint( + self, config, checkpoint_path, eval=False, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + if self.use_weight_norm: + self.remove_weight_norm() diff --git a/viXTTS/TTS/vocoder/models/random_window_discriminator.py b/viXTTS/TTS/vocoder/models/random_window_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..79b68e9780ff1703b0f4955e37fe93c02d97ea09 --- /dev/null +++ b/viXTTS/TTS/vocoder/models/random_window_discriminator.py @@ -0,0 +1,203 @@ +import numpy as np +from torch import nn + + +class GBlock(nn.Module): + def __init__(self, in_channels, cond_channels, downsample_factor): + super().__init__() + + self.in_channels = in_channels + self.cond_channels = cond_channels + self.downsample_factor = downsample_factor + + self.start = nn.Sequential( + nn.AvgPool1d(downsample_factor, stride=downsample_factor), + nn.ReLU(), + nn.Conv1d(in_channels, in_channels * 2, kernel_size=3, padding=1), + ) + self.lc_conv1d = nn.Conv1d(cond_channels, in_channels * 2, kernel_size=1) + self.end = nn.Sequential( + nn.ReLU(), nn.Conv1d(in_channels * 2, in_channels * 2, kernel_size=3, dilation=2, padding=2) + ) + self.residual = nn.Sequential( + nn.Conv1d(in_channels, in_channels * 2, kernel_size=1), + nn.AvgPool1d(downsample_factor, stride=downsample_factor), + ) + + def forward(self, inputs, conditions): + outputs = self.start(inputs) + self.lc_conv1d(conditions) + outputs = self.end(outputs) + residual_outputs = self.residual(inputs) + outputs = outputs + residual_outputs + + return outputs + + +class DBlock(nn.Module): + def __init__(self, in_channels, out_channels, downsample_factor): + super().__init__() + + self.in_channels = in_channels + self.downsample_factor = downsample_factor + self.out_channels = out_channels + + self.donwsample_layer = nn.AvgPool1d(downsample_factor, stride=downsample_factor) + self.layers = nn.Sequential( + 
nn.ReLU(), + nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1), + nn.ReLU(), + nn.Conv1d(out_channels, out_channels, kernel_size=3, dilation=2, padding=2), + ) + self.residual = nn.Sequential( + nn.Conv1d(in_channels, out_channels, kernel_size=1), + ) + + def forward(self, inputs): + if self.downsample_factor > 1: + outputs = self.layers(self.donwsample_layer(inputs)) + self.donwsample_layer(self.residual(inputs)) + else: + outputs = self.layers(inputs) + self.residual(inputs) + return outputs + + +class ConditionalDiscriminator(nn.Module): + def __init__(self, in_channels, cond_channels, downsample_factors=(2, 2, 2), out_channels=(128, 256)): + super().__init__() + + assert len(downsample_factors) == len(out_channels) + 1 + + self.in_channels = in_channels + self.cond_channels = cond_channels + self.downsample_factors = downsample_factors + self.out_channels = out_channels + + self.pre_cond_layers = nn.ModuleList() + self.post_cond_layers = nn.ModuleList() + + # layers before condition features + self.pre_cond_layers += [DBlock(in_channels, 64, 1)] + in_channels = 64 + for i, channel in enumerate(out_channels): + self.pre_cond_layers.append(DBlock(in_channels, channel, downsample_factors[i])) + in_channels = channel + + # condition block + self.cond_block = GBlock(in_channels, cond_channels, downsample_factors[-1]) + + # layers after condition block + self.post_cond_layers += [ + DBlock(in_channels * 2, in_channels * 2, 1), + DBlock(in_channels * 2, in_channels * 2, 1), + nn.AdaptiveAvgPool1d(1), + nn.Conv1d(in_channels * 2, 1, kernel_size=1), + ] + + def forward(self, inputs, conditions): + batch_size = inputs.size()[0] + outputs = inputs.view(batch_size, self.in_channels, -1) + for layer in self.pre_cond_layers: + outputs = layer(outputs) + outputs = self.cond_block(outputs, conditions) + for layer in self.post_cond_layers: + outputs = layer(outputs) + + return outputs + + +class UnconditionalDiscriminator(nn.Module): + def __init__(self, in_channels, base_channels=64, downsample_factors=(8, 4), out_channels=(128, 256)): + super().__init__() + + self.downsample_factors = downsample_factors + self.in_channels = in_channels + self.downsample_factors = downsample_factors + self.out_channels = out_channels + + self.layers = nn.ModuleList() + self.layers += [DBlock(self.in_channels, base_channels, 1)] + in_channels = base_channels + for i, factor in enumerate(downsample_factors): + self.layers.append(DBlock(in_channels, out_channels[i], factor)) + in_channels *= 2 + self.layers += [ + DBlock(in_channels, in_channels, 1), + DBlock(in_channels, in_channels, 1), + nn.AdaptiveAvgPool1d(1), + nn.Conv1d(in_channels, 1, kernel_size=1), + ] + + def forward(self, inputs): + batch_size = inputs.size()[0] + outputs = inputs.view(batch_size, self.in_channels, -1) + for layer in self.layers: + outputs = layer(outputs) + return outputs + + +class RandomWindowDiscriminator(nn.Module): + """Random Window Discriminator as described in + http://arxiv.org/abs/1909.11646""" + + def __init__( + self, + cond_channels, + hop_length, + uncond_disc_donwsample_factors=(8, 4), + cond_disc_downsample_factors=((8, 4, 2, 2, 2), (8, 4, 2, 2), (8, 4, 2), (8, 4), (4, 2, 2)), + cond_disc_out_channels=((128, 128, 256, 256), (128, 256, 256), (128, 256), (256,), (128, 256)), + window_sizes=(512, 1024, 2048, 4096, 8192), + ): + super().__init__() + self.cond_channels = cond_channels + self.window_sizes = window_sizes + self.hop_length = hop_length + self.base_window_size = self.hop_length * 2 + self.ks = [ws // 
self.base_window_size for ws in window_sizes] + + # check arguments + assert len(cond_disc_downsample_factors) == len(cond_disc_out_channels) == len(window_sizes) + for ws in window_sizes: + assert ws % hop_length == 0 + + for idx, cf in enumerate(cond_disc_downsample_factors): + assert np.prod(cf) == hop_length // self.ks[idx] + + # define layers + self.unconditional_discriminators = nn.ModuleList([]) + for k in self.ks: + layer = UnconditionalDiscriminator( + in_channels=k, base_channels=64, downsample_factors=uncond_disc_donwsample_factors + ) + self.unconditional_discriminators.append(layer) + + self.conditional_discriminators = nn.ModuleList([]) + for idx, k in enumerate(self.ks): + layer = ConditionalDiscriminator( + in_channels=k, + cond_channels=cond_channels, + downsample_factors=cond_disc_downsample_factors[idx], + out_channels=cond_disc_out_channels[idx], + ) + self.conditional_discriminators.append(layer) + + def forward(self, x, c): + scores = [] + feats = [] + # unconditional pass + for window_size, layer in zip(self.window_sizes, self.unconditional_discriminators): + index = np.random.randint(x.shape[-1] - window_size) + + score = layer(x[:, :, index : index + window_size]) + scores.append(score) + + # conditional pass + for window_size, layer in zip(self.window_sizes, self.conditional_discriminators): + frame_size = window_size // self.hop_length + lc_index = np.random.randint(c.shape[-1] - frame_size) + sample_index = lc_index * self.hop_length + x_sub = x[:, :, sample_index : (lc_index + frame_size) * self.hop_length] + c_sub = c[:, :, lc_index : lc_index + frame_size] + + score = layer(x_sub, c_sub) + scores.append(score) + return scores, feats diff --git a/viXTTS/TTS/vocoder/models/univnet_discriminator.py b/viXTTS/TTS/vocoder/models/univnet_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..497d67ac76d0cb797db92bba3f9ad24ab5293a64 --- /dev/null +++ b/viXTTS/TTS/vocoder/models/univnet_discriminator.py @@ -0,0 +1,95 @@ +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.utils import spectral_norm +from torch.nn.utils.parametrizations import weight_norm + +from TTS.utils.audio.torch_transforms import TorchSTFT +from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator + +LRELU_SLOPE = 0.1 + + +class SpecDiscriminator(nn.Module): + """docstring for Discriminator.""" + + def __init__(self, fft_size=1024, hop_length=120, win_length=600, use_spectral_norm=False): + super().__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.fft_size = fft_size + self.hop_length = hop_length + self.win_length = win_length + self.stft = TorchSTFT(fft_size, hop_length, win_length) + self.discriminators = nn.ModuleList( + [ + norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))), + ] + ) + + self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1)) + + def forward(self, y): + fmap = [] + with torch.no_grad(): + y = y.squeeze(1) + y = self.stft(y) + y = y.unsqueeze(1) + for _, d in enumerate(self.discriminators): + y = d(y) + y = F.leaky_relu(y, LRELU_SLOPE) + fmap.append(y) + + y = self.out(y) + fmap.append(y) + + return torch.flatten(y, 1, -1), fmap + + +class 
MultiResSpecDiscriminator(torch.nn.Module): + def __init__( # pylint: disable=dangerous-default-value + self, fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240], window="hann_window" + ): + super().__init__() + self.discriminators = nn.ModuleList( + [ + SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window), + SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window), + SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window), + ] + ) + + def forward(self, x): + scores = [] + feats = [] + for d in self.discriminators: + score, feat = d(x) + scores.append(score) + feats.append(feat) + + return scores, feats + + +class UnivnetDiscriminator(nn.Module): + """Univnet discriminator wrapping MPD and MSD.""" + + def __init__(self): + super().__init__() + self.mpd = MultiPeriodDiscriminator() + self.msd = MultiResSpecDiscriminator() + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + List[Tensor]: discriminator scores. + List[List[Tensor]]: list of list of features from each layers of each discriminator. + """ + scores, feats = self.mpd(x) + scores_, feats_ = self.msd(x) + return scores + scores_, feats + feats_ diff --git a/viXTTS/TTS/vocoder/models/univnet_generator.py b/viXTTS/TTS/vocoder/models/univnet_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..5e66b70df8c64cb64b300c337f782ea8b7235fc0 --- /dev/null +++ b/viXTTS/TTS/vocoder/models/univnet_generator.py @@ -0,0 +1,157 @@ +from typing import List + +import numpy as np +import torch +import torch.nn.functional as F +from torch.nn.utils import parametrize + +from TTS.vocoder.layers.lvc_block import LVCBlock + +LRELU_SLOPE = 0.1 + + +class UnivnetGenerator(torch.nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int, + cond_channels: int, + upsample_factors: List[int], + lvc_layers_each_block: int, + lvc_kernel_size: int, + kpnet_hidden_channels: int, + kpnet_conv_size: int, + dropout: float, + use_weight_norm=True, + ): + """Univnet Generator network. + + Paper: https://arxiv.org/pdf/2106.07889.pdf + + Args: + in_channels (int): Number of input tensor channels. + out_channels (int): Number of channels of the output tensor. + hidden_channels (int): Number of hidden network channels. + cond_channels (int): Number of channels of the conditioning tensors. + upsample_factors (List[int]): List of uplsample factors for the upsampling layers. + lvc_layers_each_block (int): Number of LVC layers in each block. + lvc_kernel_size (int): Kernel size of the LVC layers. + kpnet_hidden_channels (int): Number of hidden channels in the key-point network. + kpnet_conv_size (int): Number of convolution channels in the key-point network. + dropout (float): Dropout rate. + use_weight_norm (bool, optional): Enable/disable weight norm. Defaults to True. 
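# Hedged sketch of the UnivNet discriminator stack defined above: 5 period discriminators
# plus 3 multi-resolution spectrogram discriminators give 8 score tensors per call,
# assuming TorchSTFT behaves as in upstream Coqui TTS. Input length is arbitrary.
import torch
from TTS.vocoder.models.univnet_discriminator import UnivnetDiscriminator

disc = UnivnetDiscriminator()
scores, feats = disc(torch.randn(1, 1, 8192))
print(len(scores), len(feats))    # 8 and 8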
+ """ + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.cond_channels = cond_channels + self.upsample_scale = np.prod(upsample_factors) + self.lvc_block_nums = len(upsample_factors) + + # define first convolution + self.first_conv = torch.nn.Conv1d( + in_channels, hidden_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True + ) + + # define residual blocks + self.lvc_blocks = torch.nn.ModuleList() + cond_hop_length = 1 + for n in range(self.lvc_block_nums): + cond_hop_length = cond_hop_length * upsample_factors[n] + lvcb = LVCBlock( + in_channels=hidden_channels, + cond_channels=cond_channels, + upsample_ratio=upsample_factors[n], + conv_layers=lvc_layers_each_block, + conv_kernel_size=lvc_kernel_size, + cond_hop_length=cond_hop_length, + kpnet_hidden_channels=kpnet_hidden_channels, + kpnet_conv_size=kpnet_conv_size, + kpnet_dropout=dropout, + ) + self.lvc_blocks += [lvcb] + + # define output layers + self.last_conv_layers = torch.nn.ModuleList( + [ + torch.nn.Conv1d( + hidden_channels, out_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True + ), + ] + ) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + def forward(self, c): + """Calculate forward propagation. + Args: + c (Tensor): Local conditioning auxiliary features (B, C ,T'). + Returns: + Tensor: Output tensor (B, out_channels, T) + """ + # random noise + x = torch.randn([c.shape[0], self.in_channels, c.shape[2]]) + x = x.to(self.first_conv.bias.device) + x = self.first_conv(x) + + for n in range(self.lvc_block_nums): + x = self.lvc_blocks[n](x, c) + + # apply final layers + for f in self.last_conv_layers: + x = F.leaky_relu(x, LRELU_SLOPE) + x = f(x) + x = torch.tanh(x) + return x + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + # print(f"Weight norm is removed from {m}.") + parametrize.remove_parametrizations(m, "weight") + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): + torch.nn.utils.parametrizations.weight_norm(m) + # print(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + @staticmethod + def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x): + assert layers % stacks == 0 + layers_per_cycle = layers // stacks + dilations = [dilation(i % layers_per_cycle) for i in range(layers)] + return (kernel_size - 1) * sum(dilations) + 1 + + @property + def receptive_field_size(self): + """Return receptive field size.""" + return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size) + + @torch.no_grad() + def inference(self, c): + """Perform inference. + Args: + c (Tensor): Local conditioning auxiliary features :math:`(B, C, T)`. 
+ Returns: + Tensor: Output tensor (T, out_channels) + """ + x = torch.randn([c.shape[0], self.in_channels, c.shape[2]]) + x = x.to(self.first_conv.bias.device) + + c = c.to(next(self.parameters())) + return self.forward(c) diff --git a/viXTTS/TTS/vocoder/models/wavegrad.py b/viXTTS/TTS/vocoder/models/wavegrad.py new file mode 100644 index 0000000000000000000000000000000000000000..c1166e0914fcc5a8a595fb4d61792540a86e2e0e --- /dev/null +++ b/viXTTS/TTS/vocoder/models/wavegrad.py @@ -0,0 +1,345 @@ +from dataclasses import dataclass, field +from typing import Dict, List, Tuple + +import numpy as np +import torch +from coqpit import Coqpit +from torch import nn +from torch.nn.utils.parametrizations import weight_norm +from torch.nn.utils.parametrize import remove_parametrizations +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.utils.io import load_fsspec +from TTS.vocoder.datasets import WaveGradDataset +from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock +from TTS.vocoder.models.base_vocoder import BaseVocoder +from TTS.vocoder.utils.generic_utils import plot_results + + +@dataclass +class WavegradArgs(Coqpit): + in_channels: int = 80 + out_channels: int = 1 + use_weight_norm: bool = False + y_conv_channels: int = 32 + x_conv_channels: int = 768 + dblock_out_channels: List[int] = field(default_factory=lambda: [128, 128, 256, 512]) + ublock_out_channels: List[int] = field(default_factory=lambda: [512, 512, 256, 128, 128]) + upsample_factors: List[int] = field(default_factory=lambda: [4, 4, 4, 2, 2]) + upsample_dilations: List[List[int]] = field( + default_factory=lambda: [[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]] + ) + + +class Wavegrad(BaseVocoder): + """🐸 🌊 WaveGrad 🌊 model. + Paper - https://arxiv.org/abs/2009.00713 + + Examples: + Initializing the model. + + >>> from TTS.vocoder.configs import WavegradConfig + >>> config = WavegradConfig() + >>> model = Wavegrad(config) + + Paper Abstract: + This paper introduces WaveGrad, a conditional model for waveform generation which estimates gradients of the + data density. The model is built on prior work on score matching and diffusion probabilistic models. It starts + from a Gaussian white noise signal and iteratively refines the signal via a gradient-based sampler conditioned + on the mel-spectrogram. WaveGrad offers a natural way to trade inference speed for sample quality by adjusting + the number of refinement steps, and bridges the gap between non-autoregressive and autoregressive models in + terms of audio quality. We find that it can generate high fidelity audio samples using as few as six iterations. + Experiments reveal WaveGrad to generate high fidelity audio, outperforming adversarial non-autoregressive + baselines and matching a strong likelihood-based autoregressive baseline using fewer sequential operations. + Audio samples are available at this https URL. 
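# Hedged usage sketch for the WaveGrad model below: a beta schedule must be loaded before
# sampling. The min/max/num_steps values are placeholders, not this repo's defaults, and
# the config import follows the Examples section of the class docstring.
import numpy as np
import torch
from TTS.vocoder.configs import WavegradConfig
from TTS.vocoder.models.wavegrad import Wavegrad

config = WavegradConfig()
model = Wavegrad(config)
betas = np.linspace(1e-6, 1e-2, 50)   # illustrative test-time noise schedule
model.compute_noise_level(betas)
mel = torch.randn(1, 80, 40)          # [B, C, T_frames]
wav = model.inference(mel)            # [1, 1, hop_len * 40], iteratively denoised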
+ """ + + # pylint: disable=dangerous-default-value + def __init__(self, config: Coqpit): + super().__init__(config) + self.config = config + self.use_weight_norm = config.model_params.use_weight_norm + self.hop_len = np.prod(config.model_params.upsample_factors) + self.noise_level = None + self.num_steps = None + self.beta = None + self.alpha = None + self.alpha_hat = None + self.c1 = None + self.c2 = None + self.sigma = None + + # dblocks + self.y_conv = Conv1d(1, config.model_params.y_conv_channels, 5, padding=2) + self.dblocks = nn.ModuleList([]) + ic = config.model_params.y_conv_channels + for oc, df in zip(config.model_params.dblock_out_channels, reversed(config.model_params.upsample_factors)): + self.dblocks.append(DBlock(ic, oc, df)) + ic = oc + + # film + self.film = nn.ModuleList([]) + ic = config.model_params.y_conv_channels + for oc in reversed(config.model_params.ublock_out_channels): + self.film.append(FiLM(ic, oc)) + ic = oc + + # ublocksn + self.ublocks = nn.ModuleList([]) + ic = config.model_params.x_conv_channels + for oc, uf, ud in zip( + config.model_params.ublock_out_channels, + config.model_params.upsample_factors, + config.model_params.upsample_dilations, + ): + self.ublocks.append(UBlock(ic, oc, uf, ud)) + ic = oc + + self.x_conv = Conv1d(config.model_params.in_channels, config.model_params.x_conv_channels, 3, padding=1) + self.out_conv = Conv1d(oc, config.model_params.out_channels, 3, padding=1) + + if config.model_params.use_weight_norm: + self.apply_weight_norm() + + def forward(self, x, spectrogram, noise_scale): + shift_and_scale = [] + + x = self.y_conv(x) + shift_and_scale.append(self.film[0](x, noise_scale)) + + for film, layer in zip(self.film[1:], self.dblocks): + x = layer(x) + shift_and_scale.append(film(x, noise_scale)) + + x = self.x_conv(spectrogram) + for layer, (film_shift, film_scale) in zip(self.ublocks, reversed(shift_and_scale)): + x = layer(x, film_shift, film_scale) + x = self.out_conv(x) + return x + + def load_noise_schedule(self, path): + beta = np.load(path, allow_pickle=True).item()["beta"] # pylint: disable=unexpected-keyword-arg + self.compute_noise_level(beta) + + @torch.no_grad() + def inference(self, x, y_n=None): + """ + Shapes: + x: :math:`[B, C , T]` + y_n: :math:`[B, 1, T]` + """ + if y_n is None: + y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1]) + else: + y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0) + y_n = y_n.type_as(x) + sqrt_alpha_hat = self.noise_level.to(x) + for n in range(len(self.alpha) - 1, -1, -1): + y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0]))) + if n > 0: + z = torch.randn_like(y_n) + y_n += self.sigma[n - 1] * z + y_n.clamp_(-1.0, 1.0) + return y_n + + def compute_y_n(self, y_0): + """Compute noisy audio based on noise schedule""" + self.noise_level = self.noise_level.to(y_0) + if len(y_0.shape) == 3: + y_0 = y_0.squeeze(1) + s = torch.randint(0, self.num_steps - 1, [y_0.shape[0]]) + l_a, l_b = self.noise_level[s], self.noise_level[s + 1] + noise_scale = l_a + torch.rand(y_0.shape[0]).to(y_0) * (l_b - l_a) + noise_scale = noise_scale.unsqueeze(1) + noise = torch.randn_like(y_0) + noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2) ** 0.5 * noise + return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0] + + def compute_noise_level(self, beta): + """Compute noise schedule parameters""" + self.num_steps = len(beta) + alpha = 1 - beta + alpha_hat = np.cumprod(alpha) + noise_level = np.concatenate([[1.0], alpha_hat**0.5], 
axis=0) + noise_level = alpha_hat**0.5 + + # pylint: disable=not-callable + self.beta = torch.tensor(beta.astype(np.float32)) + self.alpha = torch.tensor(alpha.astype(np.float32)) + self.alpha_hat = torch.tensor(alpha_hat.astype(np.float32)) + self.noise_level = torch.tensor(noise_level.astype(np.float32)) + + self.c1 = 1 / self.alpha**0.5 + self.c2 = (1 - self.alpha) / (1 - self.alpha_hat) ** 0.5 + self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:]) ** 0.5 + + def remove_weight_norm(self): + for _, layer in enumerate(self.dblocks): + if len(layer.state_dict()) != 0: + try: + remove_parametrizations(layer, "weight") + except ValueError: + layer.remove_weight_norm() + + for _, layer in enumerate(self.film): + if len(layer.state_dict()) != 0: + try: + remove_parametrizations(layer, "weight") + except ValueError: + layer.remove_weight_norm() + + for _, layer in enumerate(self.ublocks): + if len(layer.state_dict()) != 0: + try: + remove_parametrizations(layer, "weight") + except ValueError: + layer.remove_weight_norm() + + remove_parametrizations(self.x_conv, "weight") + remove_parametrizations(self.out_conv, "weight") + remove_parametrizations(self.y_conv, "weight") + + def apply_weight_norm(self): + for _, layer in enumerate(self.dblocks): + if len(layer.state_dict()) != 0: + layer.apply_weight_norm() + + for _, layer in enumerate(self.film): + if len(layer.state_dict()) != 0: + layer.apply_weight_norm() + + for _, layer in enumerate(self.ublocks): + if len(layer.state_dict()) != 0: + layer.apply_weight_norm() + + self.x_conv = weight_norm(self.x_conv) + self.out_conv = weight_norm(self.out_conv) + self.y_conv = weight_norm(self.y_conv) + + def load_checkpoint( + self, config, checkpoint_path, eval=False, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + if self.config.model_params.use_weight_norm: + self.remove_weight_norm() + betas = np.linspace( + config["test_noise_schedule"]["min_val"], + config["test_noise_schedule"]["max_val"], + config["test_noise_schedule"]["num_steps"], + ) + self.compute_noise_level(betas) + else: + betas = np.linspace( + config["train_noise_schedule"]["min_val"], + config["train_noise_schedule"]["max_val"], + config["train_noise_schedule"]["num_steps"], + ) + self.compute_noise_level(betas) + + def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]: + # format data + x = batch["input"] + y = batch["waveform"] + + # set noise scale + noise, x_noisy, noise_scale = self.compute_y_n(y) + + # forward pass + noise_hat = self.forward(x_noisy, x, noise_scale) + + # compute losses + loss = criterion(noise, noise_hat) + return {"model_output": noise_hat}, {"loss": loss} + + def train_log( # pylint: disable=no-self-use + self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument + ) -> Tuple[Dict, np.ndarray]: + pass + + @torch.no_grad() + def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]: + return self.train_step(batch, criterion) + + def eval_log( # pylint: disable=no-self-use + self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument + ) -> None: + pass + + def test(self, assets: Dict, test_loader: "DataLoader", outputs=None): # pylint: disable=unused-argument + # setup noise schedule and 
inference + ap = assets["audio_processor"] + noise_schedule = self.config["test_noise_schedule"] + betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"]) + self.compute_noise_level(betas) + samples = test_loader.dataset.load_test_samples(1) + for sample in samples: + x = sample[0] + x = x[None, :, :].to(next(self.parameters()).device) + y = sample[1] + y = y[None, :] + # compute voice + y_pred = self.inference(x) + # compute spectrograms + figures = plot_results(y_pred, y, ap, "test") + # Sample audio + sample_voice = y_pred[0].squeeze(0).detach().cpu().numpy() + return figures, {"test/audio": sample_voice} + + def get_optimizer(self): + return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self) + + def get_scheduler(self, optimizer): + return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer) + + @staticmethod + def get_criterion(): + return torch.nn.L1Loss() + + @staticmethod + def format_batch(batch: Dict) -> Dict: + # return a whole audio segment + m, y = batch[0], batch[1] + y = y.unsqueeze(1) + return {"input": m, "waveform": y} + + def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int): + ap = assets["audio_processor"] + dataset = WaveGradDataset( + ap=ap, + items=samples, + seq_len=self.config.seq_len, + hop_len=ap.hop_length, + pad_short=self.config.pad_short, + conv_pad=self.config.conv_pad, + is_training=not is_eval, + return_segments=True, + use_noise_augment=False, + use_cache=config.use_cache, + verbose=verbose, + ) + sampler = DistributedSampler(dataset) if num_gpus > 1 else None + loader = DataLoader( + dataset, + batch_size=self.config.batch_size, + shuffle=num_gpus <= 1, + drop_last=False, + sampler=sampler, + num_workers=self.config.num_eval_loader_workers if is_eval else self.config.num_loader_workers, + pin_memory=False, + ) + return loader + + def on_epoch_start(self, trainer): # pylint: disable=unused-argument + noise_schedule = self.config["train_noise_schedule"] + betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"]) + self.compute_noise_level(betas) + + @staticmethod + def init_from_config(config: "WavegradConfig"): + return Wavegrad(config) diff --git a/viXTTS/TTS/vocoder/models/wavernn.py b/viXTTS/TTS/vocoder/models/wavernn.py new file mode 100644 index 0000000000000000000000000000000000000000..7f74ba3ebf71f1b33cfade2a65f58ef40b6c3c48 --- /dev/null +++ b/viXTTS/TTS/vocoder/models/wavernn.py @@ -0,0 +1,646 @@ +import sys +import time +from dataclasses import dataclass, field +from typing import Dict, List, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from coqpit import Coqpit +from torch import nn +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from TTS.tts.utils.visual import plot_spectrogram +from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import mulaw_decode +from TTS.utils.io import load_fsspec +from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset +from TTS.vocoder.layers.losses import WaveRNNLoss +from TTS.vocoder.models.base_vocoder import BaseVocoder +from TTS.vocoder.utils.distribution import sample_from_discretized_mix_logistic, sample_from_gaussian + + +def stream(string, variables): + sys.stdout.write(f"\r{string}" % variables) + + +# pylint: disable=abstract-method +# relates 
https://github.com/pytorch/pytorch/issues/42305 +class ResBlock(nn.Module): + def __init__(self, dims): + super().__init__() + self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False) + self.batch_norm1 = nn.BatchNorm1d(dims) + self.batch_norm2 = nn.BatchNorm1d(dims) + + def forward(self, x): + residual = x + x = self.conv1(x) + x = self.batch_norm1(x) + x = F.relu(x) + x = self.conv2(x) + x = self.batch_norm2(x) + return x + residual + + +class MelResNet(nn.Module): + def __init__(self, num_res_blocks, in_dims, compute_dims, res_out_dims, pad): + super().__init__() + k_size = pad * 2 + 1 + self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False) + self.batch_norm = nn.BatchNorm1d(compute_dims) + self.layers = nn.ModuleList() + for _ in range(num_res_blocks): + self.layers.append(ResBlock(compute_dims)) + self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1) + + def forward(self, x): + x = self.conv_in(x) + x = self.batch_norm(x) + x = F.relu(x) + for f in self.layers: + x = f(x) + x = self.conv_out(x) + return x + + +class Stretch2d(nn.Module): + def __init__(self, x_scale, y_scale): + super().__init__() + self.x_scale = x_scale + self.y_scale = y_scale + + def forward(self, x): + b, c, h, w = x.size() + x = x.unsqueeze(-1).unsqueeze(3) + x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale) + return x.view(b, c, h * self.y_scale, w * self.x_scale) + + +class UpsampleNetwork(nn.Module): + def __init__( + self, + feat_dims, + upsample_scales, + compute_dims, + num_res_blocks, + res_out_dims, + pad, + use_aux_net, + ): + super().__init__() + self.total_scale = np.cumproduct(upsample_scales)[-1] + self.indent = pad * self.total_scale + self.use_aux_net = use_aux_net + if use_aux_net: + self.resnet = MelResNet(num_res_blocks, feat_dims, compute_dims, res_out_dims, pad) + self.resnet_stretch = Stretch2d(self.total_scale, 1) + self.up_layers = nn.ModuleList() + for scale in upsample_scales: + k_size = (1, scale * 2 + 1) + padding = (0, scale) + stretch = Stretch2d(scale, 1) + conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False) + conv.weight.data.fill_(1.0 / k_size[1]) + self.up_layers.append(stretch) + self.up_layers.append(conv) + + def forward(self, m): + if self.use_aux_net: + aux = self.resnet(m).unsqueeze(1) + aux = self.resnet_stretch(aux) + aux = aux.squeeze(1) + aux = aux.transpose(1, 2) + else: + aux = None + m = m.unsqueeze(1) + for f in self.up_layers: + m = f(m) + m = m.squeeze(1)[:, :, self.indent : -self.indent] + return m.transpose(1, 2), aux + + +class Upsample(nn.Module): + def __init__(self, scale, pad, num_res_blocks, feat_dims, compute_dims, res_out_dims, use_aux_net): + super().__init__() + self.scale = scale + self.pad = pad + self.indent = pad * scale + self.use_aux_net = use_aux_net + self.resnet = MelResNet(num_res_blocks, feat_dims, compute_dims, res_out_dims, pad) + + def forward(self, m): + if self.use_aux_net: + aux = self.resnet(m) + aux = torch.nn.functional.interpolate(aux, scale_factor=self.scale, mode="linear", align_corners=True) + aux = aux.transpose(1, 2) + else: + aux = None + m = torch.nn.functional.interpolate(m, scale_factor=self.scale, mode="linear", align_corners=True) + m = m[:, :, self.indent : -self.indent] + m = m * 0.045 # empirically found + + return m.transpose(1, 2), aux + + +@dataclass +class WavernnArgs(Coqpit): + """🐸 WaveRNN model arguments. + + rnn_dims (int): + Number of hidden channels in RNN layers. 
Defaults to 512.
+        fc_dims (int):
+            Number of hidden channels in fully-connected layers. Defaults to 512.
+        compute_dims (int):
+            Number of hidden channels in the feature ResNet. Defaults to 128.
+        res_out_dims (int):
+            Number of hidden channels in the feature ResNet output. Defaults to 128.
+        num_res_blocks (int):
+            Number of residual blocks in the ResNet. Defaults to 10.
+        use_aux_net (bool):
+            enable/disable the feature ResNet. Defaults to True.
+        use_upsample_net (bool):
+            enable/disable the upsampling network. If False, basic upsampling is used. Defaults to True.
+        upsample_factors (list):
+            Upsampling factors. The product of the values must match the `hop_length`. Defaults to ```[4, 8, 8]```.
+        mode (str):
+            Output mode of the WaveRNN vocoder. `mold` for Mixture of Logistic Distribution, `gauss` for a single
+            Gaussian Distribution and `bits` for quantized bits as the model's output.
+        mulaw (bool):
+            enable/disable the use of Mulaw quantization for training. Only applicable if `mode == 'bits'`. Defaults
+            to `True`.
+        pad (int):
+            Padding applied to the input feature frames against the convolution layers of the feature network.
+            Defaults to 2.
+        """
+
+    rnn_dims: int = 512
+    fc_dims: int = 512
+    compute_dims: int = 128
+    res_out_dims: int = 128
+    num_res_blocks: int = 10
+    use_aux_net: bool = True
+    use_upsample_net: bool = True
+    upsample_factors: List[int] = field(default_factory=lambda: [4, 8, 8])
+    mode: str = "mold"  # mold [string], gauss [string], bits [int]
+    mulaw: bool = True  # apply mulaw if mode is bits
+    pad: int = 2
+    feat_dims: int = 80
+
+
+class Wavernn(BaseVocoder):
+    def __init__(self, config: Coqpit):
+        """🐸 WaveRNN model.
+        Original paper - https://arxiv.org/abs/1802.08435
+        Official implementation - https://github.com/fatchord/WaveRNN
+
+        Args:
+            config (Coqpit): model configuration.
+
+        Raises:
+            RuntimeError: if `config.model_args.mode` is not an integer bit depth, `"mold"` or `"gauss"`.
+
+        Examples:
+            >>> from TTS.vocoder.configs import WavernnConfig
+            >>> config = WavernnConfig()
+            >>> model = Wavernn(config)
+
+        Paper Abstract:
+            Sequential models achieve state-of-the-art results in audio, visual and textual domains with respect to
+            both estimating the data distribution and generating high-quality samples. Efficient sampling for this
+            class of models has however remained an elusive problem. With a focus on text-to-speech synthesis, we
+            describe a set of general techniques for reducing sampling time while maintaining high output quality.
+            We first describe a single-layer recurrent neural network, the WaveRNN, with a dual softmax layer that
+            matches the quality of the state-of-the-art WaveNet model. The compact form of the network makes it
+            possible to generate 24kHz 16-bit audio 4x faster than real time on a GPU. Second, we apply a weight
+            pruning technique to reduce the number of weights in the WaveRNN. We find that, for a constant number of
+            parameters, large sparse networks perform better than small dense networks and this relationship holds for
+            sparsity levels beyond 96%. The small number of weights in a Sparse WaveRNN makes it possible to sample
+            high-fidelity audio on a mobile CPU in real time. Finally, we propose a new generation scheme based on
+            subscaling that folds a long sequence into a batch of shorter sequences and allows one to generate multiple
+            samples at once. The Subscale WaveRNN produces 16 samples per step without loss of quality and offers an
+            orthogonal method for increasing sampling efficiency.
+ """ + super().__init__(config) + + if isinstance(self.args.mode, int): + self.n_classes = 2**self.args.mode + elif self.args.mode == "mold": + self.n_classes = 3 * 10 + elif self.args.mode == "gauss": + self.n_classes = 2 + else: + raise RuntimeError("Unknown model mode value - ", self.args.mode) + + self.ap = AudioProcessor(**config.audio.to_dict()) + self.aux_dims = self.args.res_out_dims // 4 + + if self.args.use_upsample_net: + assert ( + np.cumproduct(self.args.upsample_factors)[-1] == config.audio.hop_length + ), " [!] upsample scales needs to be equal to hop_length" + self.upsample = UpsampleNetwork( + self.args.feat_dims, + self.args.upsample_factors, + self.args.compute_dims, + self.args.num_res_blocks, + self.args.res_out_dims, + self.args.pad, + self.args.use_aux_net, + ) + else: + self.upsample = Upsample( + config.audio.hop_length, + self.args.pad, + self.args.num_res_blocks, + self.args.feat_dims, + self.args.compute_dims, + self.args.res_out_dims, + self.args.use_aux_net, + ) + if self.args.use_aux_net: + self.I = nn.Linear(self.args.feat_dims + self.aux_dims + 1, self.args.rnn_dims) + self.rnn1 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True) + self.rnn2 = nn.GRU(self.args.rnn_dims + self.aux_dims, self.args.rnn_dims, batch_first=True) + self.fc1 = nn.Linear(self.args.rnn_dims + self.aux_dims, self.args.fc_dims) + self.fc2 = nn.Linear(self.args.fc_dims + self.aux_dims, self.args.fc_dims) + self.fc3 = nn.Linear(self.args.fc_dims, self.n_classes) + else: + self.I = nn.Linear(self.args.feat_dims + 1, self.args.rnn_dims) + self.rnn1 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True) + self.rnn2 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True) + self.fc1 = nn.Linear(self.args.rnn_dims, self.args.fc_dims) + self.fc2 = nn.Linear(self.args.fc_dims, self.args.fc_dims) + self.fc3 = nn.Linear(self.args.fc_dims, self.n_classes) + + def forward(self, x, mels): + bsize = x.size(0) + h1 = torch.zeros(1, bsize, self.args.rnn_dims).to(x.device) + h2 = torch.zeros(1, bsize, self.args.rnn_dims).to(x.device) + mels, aux = self.upsample(mels) + + if self.args.use_aux_net: + aux_idx = [self.aux_dims * i for i in range(5)] + a1 = aux[:, :, aux_idx[0] : aux_idx[1]] + a2 = aux[:, :, aux_idx[1] : aux_idx[2]] + a3 = aux[:, :, aux_idx[2] : aux_idx[3]] + a4 = aux[:, :, aux_idx[3] : aux_idx[4]] + + x = ( + torch.cat([x.unsqueeze(-1), mels, a1], dim=2) + if self.args.use_aux_net + else torch.cat([x.unsqueeze(-1), mels], dim=2) + ) + x = self.I(x) + res = x + self.rnn1.flatten_parameters() + x, _ = self.rnn1(x, h1) + + x = x + res + res = x + x = torch.cat([x, a2], dim=2) if self.args.use_aux_net else x + self.rnn2.flatten_parameters() + x, _ = self.rnn2(x, h2) + + x = x + res + x = torch.cat([x, a3], dim=2) if self.args.use_aux_net else x + x = F.relu(self.fc1(x)) + + x = torch.cat([x, a4], dim=2) if self.args.use_aux_net else x + x = F.relu(self.fc2(x)) + return self.fc3(x) + + def inference(self, mels, batched=None, target=None, overlap=None): + self.eval() + output = [] + start = time.time() + rnn1 = self.get_gru_cell(self.rnn1) + rnn2 = self.get_gru_cell(self.rnn2) + + with torch.no_grad(): + if isinstance(mels, np.ndarray): + mels = torch.FloatTensor(mels).to(str(next(self.parameters()).device)) + + if mels.ndim == 2: + mels = mels.unsqueeze(0) + wave_len = (mels.size(-1) - 1) * self.config.audio.hop_length + + mels = self.pad_tensor(mels.transpose(1, 2), pad=self.args.pad, side="both") + mels, aux = self.upsample(mels.transpose(1, 2)) + + if 
batched: + mels = self.fold_with_overlap(mels, target, overlap) + if aux is not None: + aux = self.fold_with_overlap(aux, target, overlap) + + b_size, seq_len, _ = mels.size() + + h1 = torch.zeros(b_size, self.args.rnn_dims).type_as(mels) + h2 = torch.zeros(b_size, self.args.rnn_dims).type_as(mels) + x = torch.zeros(b_size, 1).type_as(mels) + + if self.args.use_aux_net: + d = self.aux_dims + aux_split = [aux[:, :, d * i : d * (i + 1)] for i in range(4)] + + for i in range(seq_len): + m_t = mels[:, i, :] + + if self.args.use_aux_net: + a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split) + + x = torch.cat([x, m_t, a1_t], dim=1) if self.args.use_aux_net else torch.cat([x, m_t], dim=1) + x = self.I(x) + h1 = rnn1(x, h1) + + x = x + h1 + inp = torch.cat([x, a2_t], dim=1) if self.args.use_aux_net else x + h2 = rnn2(inp, h2) + + x = x + h2 + x = torch.cat([x, a3_t], dim=1) if self.args.use_aux_net else x + x = F.relu(self.fc1(x)) + + x = torch.cat([x, a4_t], dim=1) if self.args.use_aux_net else x + x = F.relu(self.fc2(x)) + + logits = self.fc3(x) + + if self.args.mode == "mold": + sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2)) + output.append(sample.view(-1)) + x = sample.transpose(0, 1).type_as(mels) + elif self.args.mode == "gauss": + sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2)) + output.append(sample.view(-1)) + x = sample.transpose(0, 1).type_as(mels) + elif isinstance(self.args.mode, int): + posterior = F.softmax(logits, dim=1) + distrib = torch.distributions.Categorical(posterior) + + sample = 2 * distrib.sample().float() / (self.n_classes - 1.0) - 1.0 + output.append(sample) + x = sample.unsqueeze(-1) + else: + raise RuntimeError("Unknown model mode value - ", self.args.mode) + + if i % 100 == 0: + self.gen_display(i, seq_len, b_size, start) + + output = torch.stack(output).transpose(0, 1) + output = output.cpu() + if batched: + output = output.numpy() + output = output.astype(np.float64) + + output = self.xfade_and_unfold(output, target, overlap) + else: + output = output[0] + + if self.args.mulaw and isinstance(self.args.mode, int): + output = mulaw_decode(wav=output, mulaw_qc=self.args.mode) + + # Fade-out at the end to avoid signal cutting out suddenly + fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length) + output = output[:wave_len] + + if wave_len > len(fade_out): + output[-20 * self.config.audio.hop_length :] *= fade_out + + self.train() + return output + + def gen_display(self, i, seq_len, b_size, start): + gen_rate = (i + 1) / (time.time() - start) * b_size / 1000 + realtime_ratio = gen_rate * 1000 / self.config.audio.sample_rate + stream( + "%i/%i -- batch_size: %i -- gen_rate: %.1f kHz -- x_realtime: %.1f ", + (i * b_size, seq_len * b_size, b_size, gen_rate, realtime_ratio), + ) + + def fold_with_overlap(self, x, target, overlap): + """Fold the tensor with overlap for quick batched inference. + Overlap will be used for crossfading in xfade_and_unfold() + Args: + x (tensor) : Upsampled conditioning features. + shape=(1, timesteps, features) + target (int) : Target timesteps for each index of batch + overlap (int) : Timesteps for both xfade and rnn warmup + Return: + (tensor) : shape=(num_folds, target + 2 * overlap, features) + Details: + x = [[h1, h2, ... 
hn]] + Where each h is a vector of conditioning features + Eg: target=2, overlap=1 with x.size(1)=10 + folded = [[h1, h2, h3, h4], + [h4, h5, h6, h7], + [h7, h8, h9, h10]] + """ + + _, total_len, features = x.size() + + # Calculate variables needed + num_folds = (total_len - overlap) // (target + overlap) + extended_len = num_folds * (overlap + target) + overlap + remaining = total_len - extended_len + + # Pad if some time steps poking out + if remaining != 0: + num_folds += 1 + padding = target + 2 * overlap - remaining + x = self.pad_tensor(x, padding, side="after") + + folded = torch.zeros(num_folds, target + 2 * overlap, features).to(x.device) + + # Get the values for the folded tensor + for i in range(num_folds): + start = i * (target + overlap) + end = start + target + 2 * overlap + folded[i] = x[:, start:end, :] + + return folded + + @staticmethod + def get_gru_cell(gru): + gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size) + gru_cell.weight_hh.data = gru.weight_hh_l0.data + gru_cell.weight_ih.data = gru.weight_ih_l0.data + gru_cell.bias_hh.data = gru.bias_hh_l0.data + gru_cell.bias_ih.data = gru.bias_ih_l0.data + return gru_cell + + @staticmethod + def pad_tensor(x, pad, side="both"): + # NB - this is just a quick method i need right now + # i.e., it won't generalise to other shapes/dims + b, t, c = x.size() + total = t + 2 * pad if side == "both" else t + pad + padded = torch.zeros(b, total, c).to(x.device) + if side in ("before", "both"): + padded[:, pad : pad + t, :] = x + elif side == "after": + padded[:, :t, :] = x + return padded + + @staticmethod + def xfade_and_unfold(y, target, overlap): + """Applies a crossfade and unfolds into a 1d array. + Args: + y (ndarry) : Batched sequences of audio samples + shape=(num_folds, target + 2 * overlap) + dtype=np.float64 + overlap (int) : Timesteps for both xfade and rnn warmup + Return: + (ndarry) : audio samples in a 1d array + shape=(total_len) + dtype=np.float64 + Details: + y = [[seq1], + [seq2], + [seq3]] + Apply a gain envelope at both ends of the sequences + y = [[seq1_in, seq1_target, seq1_out], + [seq2_in, seq2_target, seq2_out], + [seq3_in, seq3_target, seq3_out]] + Stagger and add up the groups of samples: + [seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...] 
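+            For the folding example in fold_with_overlap() (target=2, overlap=1, three folds),
+            the unfolded length is num_folds * (target + overlap) + overlap = 3 * 3 + 1 = 10,
+            i.e. the original number of timesteps is recovered.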
+ """ + + num_folds, length = y.shape + target = length - 2 * overlap + total_len = num_folds * (target + overlap) + overlap + + # Need some silence for the rnn warmup + silence_len = overlap // 2 + fade_len = overlap - silence_len + silence = np.zeros((silence_len), dtype=np.float64) + + # Equal power crossfade + t = np.linspace(-1, 1, fade_len, dtype=np.float64) + fade_in = np.sqrt(0.5 * (1 + t)) + fade_out = np.sqrt(0.5 * (1 - t)) + + # Concat the silence to the fades + fade_in = np.concatenate([silence, fade_in]) + fade_out = np.concatenate([fade_out, silence]) + + # Apply the gain to the overlap samples + y[:, :overlap] *= fade_in + y[:, -overlap:] *= fade_out + + unfolded = np.zeros((total_len), dtype=np.float64) + + # Loop to add up all the samples + for i in range(num_folds): + start = i * (target + overlap) + end = start + target + 2 * overlap + unfolded[start:end] += y[i] + + return unfolded + + def load_checkpoint( + self, config, checkpoint_path, eval=False, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]: + mels = batch["input"] + waveform = batch["waveform"] + waveform_coarse = batch["waveform_coarse"] + + y_hat = self.forward(waveform, mels) + if isinstance(self.args.mode, int): + y_hat = y_hat.transpose(1, 2).unsqueeze(-1) + else: + waveform_coarse = waveform_coarse.float() + waveform_coarse = waveform_coarse.unsqueeze(-1) + # compute losses + loss_dict = criterion(y_hat, waveform_coarse) + return {"model_output": y_hat}, loss_dict + + def eval_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]: + return self.train_step(batch, criterion) + + @torch.no_grad() + def test( + self, assets: Dict, test_loader: "DataLoader", output: Dict # pylint: disable=unused-argument + ) -> Tuple[Dict, Dict]: + ap = self.ap + figures = {} + audios = {} + samples = test_loader.dataset.load_test_samples(1) + for idx, sample in enumerate(samples): + x = torch.FloatTensor(sample[0]) + x = x.to(next(self.parameters()).device) + y_hat = self.inference(x, self.config.batched, self.config.target_samples, self.config.overlap_samples) + x_hat = ap.melspectrogram(y_hat) + figures.update( + { + f"test_{idx}/ground_truth": plot_spectrogram(x.T), + f"test_{idx}/prediction": plot_spectrogram(x_hat.T), + } + ) + audios.update({f"test_{idx}/audio": y_hat}) + # audios.update({f"real_{idx}/audio": y_hat}) + return figures, audios + + def test_log( + self, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument + ) -> Tuple[Dict, np.ndarray]: + figures, audios = outputs + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + @staticmethod + def format_batch(batch: Dict) -> Dict: + waveform = batch[0] + mels = batch[1] + waveform_coarse = batch[2] + return {"input": mels, "waveform": waveform, "waveform_coarse": waveform_coarse} + + def get_data_loader( # pylint: disable=no-self-use + self, + config: Coqpit, + assets: Dict, + is_eval: True, + samples: List, + verbose: bool, + num_gpus: int, + ): + ap = self.ap + dataset = WaveRNNDataset( + ap=ap, + items=samples, + seq_len=config.seq_len, + hop_len=ap.hop_length, + pad=config.model_args.pad, + mode=config.model_args.mode, + mulaw=config.model_args.mulaw, + is_training=not is_eval, + verbose=verbose, + ) + 
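+        # With more than one GPU a DistributedSampler shards the dataset across processes and
+        # performs the shuffling itself, so DataLoader-level shuffling is only enabled in the
+        # non-distributed case (num_gpus == 0) below.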
sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None + loader = DataLoader( + dataset, + batch_size=1 if is_eval else config.batch_size, + shuffle=num_gpus == 0, + collate_fn=dataset.collate, + sampler=sampler, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=True, + ) + return loader + + def get_criterion(self): + # define train functions + return WaveRNNLoss(self.args.mode) + + @staticmethod + def init_from_config(config: "WavernnConfig"): + return Wavernn(config) diff --git a/viXTTS/TTS/vocoder/pqmf_output.wav b/viXTTS/TTS/vocoder/pqmf_output.wav new file mode 100644 index 0000000000000000000000000000000000000000..8a77747b00198a4adfd6c398998517df5b4bdb8d Binary files /dev/null and b/viXTTS/TTS/vocoder/pqmf_output.wav differ diff --git a/viXTTS/TTS/vocoder/utils/__init__.py b/viXTTS/TTS/vocoder/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/viXTTS/TTS/vocoder/utils/distribution.py b/viXTTS/TTS/vocoder/utils/distribution.py new file mode 100644 index 0000000000000000000000000000000000000000..fe706ba9ffbc3f8aad75285bca34a910246666b3 --- /dev/null +++ b/viXTTS/TTS/vocoder/utils/distribution.py @@ -0,0 +1,154 @@ +import math + +import numpy as np +import torch +import torch.nn.functional as F +from torch.distributions.normal import Normal + + +def gaussian_loss(y_hat, y, log_std_min=-7.0): + assert y_hat.dim() == 3 + assert y_hat.size(2) == 2 + mean = y_hat[:, :, :1] + log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min) + # TODO: replace with pytorch dist + log_probs = -0.5 * (-math.log(2.0 * math.pi) - 2.0 * log_std - torch.pow(y - mean, 2) * torch.exp((-2.0 * log_std))) + return log_probs.squeeze().mean() + + +def sample_from_gaussian(y_hat, log_std_min=-7.0, scale_factor=1.0): + assert y_hat.size(2) == 2 + mean = y_hat[:, :, :1] + log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min) + dist = Normal( + mean, + torch.exp(log_std), + ) + sample = dist.sample() + sample = torch.clamp(torch.clamp(sample, min=-scale_factor), max=scale_factor) + del dist + return sample + + +def log_sum_exp(x): + """numerically stable log_sum_exp implementation that prevents overflow""" + # TF ordering + axis = len(x.size()) - 1 + m, _ = torch.max(x, dim=axis) + m2, _ = torch.max(x, dim=axis, keepdim=True) + return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis)) + + +# It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py +def discretized_mix_logistic_loss(y_hat, y, num_classes=65536, log_scale_min=None, reduce=True): + if log_scale_min is None: + log_scale_min = float(np.log(1e-14)) + y_hat = y_hat.permute(0, 2, 1) + assert y_hat.dim() == 3 + assert y_hat.size(1) % 3 == 0 + nr_mix = y_hat.size(1) // 3 + + # (B x T x C) + y_hat = y_hat.transpose(1, 2) + + # unpack parameters. 
(B, T, num_mixtures) x 3 + logit_probs = y_hat[:, :, :nr_mix] + means = y_hat[:, :, nr_mix : 2 * nr_mix] + log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min) + + # B x T x 1 -> B x T x num_mixtures + y = y.expand_as(means) + + centered_y = y - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_y + 1.0 / (num_classes - 1)) + cdf_plus = torch.sigmoid(plus_in) + min_in = inv_stdv * (centered_y - 1.0 / (num_classes - 1)) + cdf_min = torch.sigmoid(min_in) + + # log probability for edge case of 0 (before scaling) + # equivalent: torch.log(F.sigmoid(plus_in)) + log_cdf_plus = plus_in - F.softplus(plus_in) + + # log probability for edge case of 255 (before scaling) + # equivalent: (1 - F.sigmoid(min_in)).log() + log_one_minus_cdf_min = -F.softplus(min_in) + + # probability for all other cases + cdf_delta = cdf_plus - cdf_min + + mid_in = inv_stdv * centered_y + # log probability in the center of the bin, to be used in extreme cases + # (not actually used in our code) + log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in) + + # tf equivalent + + # log_probs = tf.where(x < -0.999, log_cdf_plus, + # tf.where(x > 0.999, log_one_minus_cdf_min, + # tf.where(cdf_delta > 1e-5, + # tf.log(tf.maximum(cdf_delta, 1e-12)), + # log_pdf_mid - np.log(127.5)))) + + # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value + # for num_classes=65536 case? 1e-7? not sure.. + inner_inner_cond = (cdf_delta > 1e-5).float() + + inner_inner_out = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1.0 - inner_inner_cond) * ( + log_pdf_mid - np.log((num_classes - 1) / 2) + ) + inner_cond = (y > 0.999).float() + inner_out = inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out + cond = (y < -0.999).float() + log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out + + log_probs = log_probs + F.log_softmax(logit_probs, -1) + + if reduce: + return -torch.mean(log_sum_exp(log_probs)) + return -log_sum_exp(log_probs).unsqueeze(-1) + + +def sample_from_discretized_mix_logistic(y, log_scale_min=None): + """ + Sample from discretized mixture of logistic distributions + Args: + y (Tensor): :math:`[B, C, T]` + log_scale_min (float): Log scale minimum value + Returns: + Tensor: sample in range of [-1, 1]. 
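+            C must be a multiple of 3 (mixture logits, means and log scales for
+            ``num_mixtures = C // 3`` components) and the returned sample has shape :math:`[B, T]`.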
+ """ + if log_scale_min is None: + log_scale_min = float(np.log(1e-14)) + assert y.size(1) % 3 == 0 + nr_mix = y.size(1) // 3 + + # B x T x C + y = y.transpose(1, 2) + logit_probs = y[:, :, :nr_mix] + + # sample mixture indicator from softmax + temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5) + temp = logit_probs.data - torch.log(-torch.log(temp)) + _, argmax = temp.max(dim=-1) + + # (B, T) -> (B, T, nr_mix) + one_hot = to_one_hot(argmax, nr_mix) + # select logistic parameters + means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1) + log_scales = torch.clamp(torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1), min=log_scale_min) + # sample from logistic & clip to interval + # we don't actually round to the nearest 8bit value when sampling + u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5) + x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1.0 - u)) + + x = torch.clamp(torch.clamp(x, min=-1.0), max=1.0) + + return x + + +def to_one_hot(tensor, n, fill_with=1.0): + # we perform one hot encore with respect to the last axis + one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_().type_as(tensor) + one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with) + return one_hot diff --git a/viXTTS/TTS/vocoder/utils/generic_utils.py b/viXTTS/TTS/vocoder/utils/generic_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..63a0af4445b5684e928b83d2f4fdfaf7e8f5b9a2 --- /dev/null +++ b/viXTTS/TTS/vocoder/utils/generic_utils.py @@ -0,0 +1,72 @@ +from typing import Dict + +import numpy as np +import torch +from matplotlib import pyplot as plt + +from TTS.tts.utils.visual import plot_spectrogram +from TTS.utils.audio import AudioProcessor + + +def interpolate_vocoder_input(scale_factor, spec): + """Interpolate spectrogram by the scale factor. + It is mainly used to match the sampling rates of + the tts and vocoder models. + + Args: + scale_factor (float): scale factor to interpolate the spectrogram + spec (np.array): spectrogram to be interpolated + + Returns: + torch.tensor: interpolated spectrogram. + """ + print(" > before interpolation :", spec.shape) + spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0) # pylint: disable=not-callable + spec = torch.nn.functional.interpolate( + spec, scale_factor=scale_factor, recompute_scale_factor=True, mode="bilinear", align_corners=False + ).squeeze(0) + print(" > after interpolation :", spec.shape) + return spec + + +def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_prefix: str = None) -> Dict: + """Plot the predicted and the real waveform and their spectrograms. + + Args: + y_hat (torch.tensor): Predicted waveform. + y (torch.tensor): Real waveform. + ap (AudioProcessor): Audio processor used to process the waveform. + name_prefix (str, optional): Name prefix used to name the figures. Defaults to None. + + Returns: + Dict: output figures keyed by the name of the figures. 
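+            The keys are ``<name_prefix>spectrogram/fake``, ``<name_prefix>spectrogram/real``,
+            ``<name_prefix>spectrogram/diff`` and ``<name_prefix>speech_comparison``.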
+ """ """Plot vocoder model results""" + if name_prefix is None: + name_prefix = "" + + # select an instance from batch + y_hat = y_hat[0].squeeze().detach().cpu().numpy() + y = y[0].squeeze().detach().cpu().numpy() + + spec_fake = ap.melspectrogram(y_hat).T + spec_real = ap.melspectrogram(y).T + spec_diff = np.abs(spec_fake - spec_real) + + # plot figure and save it + fig_wave = plt.figure() + plt.subplot(2, 1, 1) + plt.plot(y) + plt.title("groundtruth speech") + plt.subplot(2, 1, 2) + plt.plot(y_hat) + plt.title("generated speech") + plt.tight_layout() + plt.close() + + figures = { + name_prefix + "spectrogram/fake": plot_spectrogram(spec_fake), + name_prefix + "spectrogram/real": plot_spectrogram(spec_real), + name_prefix + "spectrogram/diff": plot_spectrogram(spec_diff), + name_prefix + "speech_comparison": fig_wave, + } + return figures diff --git a/viXTTS/main.py b/viXTTS/main.py new file mode 100644 index 0000000000000000000000000000000000000000..3ad50735b4c02c0b1d459955d7a5a5200b1a02a1 --- /dev/null +++ b/viXTTS/main.py @@ -0,0 +1,264 @@ +import hashlib +import os +import string +import subprocess +import sys +from datetime import datetime +import torch +import torchaudio +from huggingface_hub import hf_hub_download, snapshot_download +from underthesea import sent_tokenize +from unidecode import unidecode +from vinorm import TTSnorm + +from TTS.tts.configs.xtts_config import XttsConfig +from TTS.tts.models.xtts import Xtts + +XTTS_MODEL = None +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +MODEL_DIR = os.path.join(SCRIPT_DIR, "model") +OUTPUT_DIR = os.path.join(SCRIPT_DIR, "output") +FILTER_SUFFIX = "_DeepFilterNet3.wav" +os.makedirs(OUTPUT_DIR, exist_ok=True) + + +def clear_gpu_cache(): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def load_model(checkpoint_dir="model/", repo_id="capleaf/viXTTS", use_deepspeed=False): + global XTTS_MODEL + clear_gpu_cache() + os.makedirs(checkpoint_dir, exist_ok=True) + + required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"] + files_in_dir = os.listdir(checkpoint_dir) + if not all(file in files_in_dir for file in required_files): + yield f"Missing model files! Downloading from {repo_id}..." + snapshot_download( + repo_id=repo_id, + repo_type="model", + local_dir=checkpoint_dir, + ) + hf_hub_download( + repo_id="coqui/XTTS-v2", + filename="speakers_xtts.pth", + local_dir=checkpoint_dir, + ) + yield f"Model download finished..." + + xtts_config = os.path.join(checkpoint_dir, "config.json") + config = XttsConfig() + config.load_json(xtts_config) + XTTS_MODEL = Xtts.init_from_config(config) + yield "Loading model..." + XTTS_MODEL.load_checkpoint(config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed) + if torch.cuda.is_available(): + XTTS_MODEL.cuda() + + print("Model Loaded!") + yield "Model Loaded!" 
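+
+
+# load_model() above is a generator: it yields human-readable status strings while it downloads
+# any missing checkpoint files and loads XTTS, and the global XTTS_MODEL only becomes usable once
+# the generator has been fully consumed. A minimal usage sketch (the helper name below is
+# illustrative and not part of the original script; it is never called, so module behaviour is
+# unchanged):
+def _example_load_vixtts(checkpoint_dir: str = MODEL_DIR) -> None:
+    for status in load_model(checkpoint_dir=checkpoint_dir, repo_id="capleaf/viXTTS"):
+        print(status)  # e.g. "Loading model..." followed by "Model Loaded!"
+    assert XTTS_MODEL is not None  # set as a side effect of load_model()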
+ + +# Define dictionaries to store cached results +cache_queue = [] +speaker_audio_cache = {} +filter_cache = {} +conditioning_latents_cache = {} + + +def invalidate_cache(cache_limit=50): + """Invalidate the cache for the oldest key""" + if len(cache_queue) > cache_limit: + key_to_remove = cache_queue.pop(0) + print("Invalidating cache", key_to_remove) + if os.path.exists(key_to_remove): + os.remove(key_to_remove) + if os.path.exists(key_to_remove.replace(".wav", "_DeepFilterNet3.wav")): + os.remove(key_to_remove.replace(".wav", "_DeepFilterNet3.wav")) + if key_to_remove in filter_cache: + del filter_cache[key_to_remove] + if key_to_remove in conditioning_latents_cache: + del conditioning_latents_cache[key_to_remove] + + +def generate_hash(data): + hash_object = hashlib.md5() + hash_object.update(data) + return hash_object.hexdigest() + + +def get_file_name(text, max_char=50): + filename = text[:max_char] + filename = filename.lower() + filename = filename.replace(" ", "_") + filename = filename.translate(str.maketrans("", "", string.punctuation.replace("_", ""))) + filename = unidecode(filename) + current_datetime = datetime.now().strftime("%m%d%H%M%S") + filename = f"{current_datetime}_{filename}" + return filename + +from unicodedata import normalize +def normalize_vietnamese_text(text): + text = ( + normalize("NFC", text) + .replace("..", ".") + .replace("!.", "!") + .replace("?.", "?") + .replace(" .", ".") + .replace(" ,", ",") + .replace('"', "") + .replace("'", "") + .replace("AI", "Ây Ai") + .replace("A.I", "Ây Ai") + ) + return text + + +def calculate_keep_len(text, lang): + """Simple hack for short sentences""" + if lang in ["ja", "zh-cn"]: + return -1 + + word_count = len(text.split()) + num_punct = text.count(".") + text.count("!") + text.count("?") + text.count(",") + + if word_count < 5: + return 15000 * word_count + 2000 * num_punct + elif word_count < 10: + return 13000 * word_count + 2000 * num_punct + return -1 + + +def run_tts(lang, tts_text, speaker_audio_file, use_deepfilter, normalize_text): + global filter_cache, conditioning_latents_cache, cache_queue + + if XTTS_MODEL is None: + return "You need to run the previous step to load the model !!", None, None + + if not speaker_audio_file: + return "You need to provide reference audio!!!", None, None + + # Use the file name as the key, since it's suppose to be unique 💀 + speaker_audio_key = speaker_audio_file + if not speaker_audio_key in cache_queue: + cache_queue.append(speaker_audio_key) + invalidate_cache() + + # Check if filtered reference is cached + if use_deepfilter and speaker_audio_key in filter_cache: + print("Using filter cache...") + speaker_audio_file = filter_cache[speaker_audio_key] + elif use_deepfilter: + print("Running filter...") + subprocess.run( + [ + "deepFilter", + speaker_audio_file, + "-o", + os.path.dirname(speaker_audio_file), + ] + ) + filter_cache[speaker_audio_key] = speaker_audio_file.replace(".wav", FILTER_SUFFIX) + speaker_audio_file = filter_cache[speaker_audio_key] + + # Check if conditioning latents are cached + cache_key = ( + speaker_audio_key, + XTTS_MODEL.config.gpt_cond_len, + XTTS_MODEL.config.max_ref_len, + XTTS_MODEL.config.sound_norm_refs, + ) + if cache_key in conditioning_latents_cache: + print("Using conditioning latents cache...") + gpt_cond_latent, speaker_embedding = conditioning_latents_cache[cache_key] + else: + print("Computing conditioning latents...") + gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents( + audio_path=speaker_audio_file, + 
gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
+            max_ref_length=XTTS_MODEL.config.max_ref_len,
+            sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
+        )
+        conditioning_latents_cache[cache_key] = (gpt_cond_latent, speaker_embedding)
+
+    if normalize_text and lang == "vi":
+        tts_text = normalize_vietnamese_text(tts_text)
+
+    # Split text by sentence
+    if lang in ["ja", "zh-cn"]:
+        sentences = tts_text.split("。")
+    else:
+        sentences = sent_tokenize(tts_text)
+
+    wav_chunks = []
+    for sentence in sentences:
+        if sentence.strip() == "":
+            continue
+        wav_chunk = XTTS_MODEL.inference(
+            text=sentence,
+            language=lang,
+            gpt_cond_latent=gpt_cond_latent,
+            speaker_embedding=speaker_embedding,
+            # The following values are carefully chosen for viXTTS
+            temperature=0.3,
+            length_penalty=1.0,
+            repetition_penalty=10.0,
+            top_k=30,
+            top_p=0.85,
+            enable_text_splitting=True,
+        )
+
+        keep_len = calculate_keep_len(sentence, lang)
+        wav_chunk["wav"] = wav_chunk["wav"][:keep_len]
+
+        wav_chunks.append(torch.tensor(wav_chunk["wav"]))
+
+    out_wav = torch.cat(wav_chunks, dim=0).unsqueeze(0)
+    out_path = os.path.join(OUTPUT_DIR, f"{get_file_name(tts_text)}.wav")
+    print("Saving output to ", out_path)
+    torchaudio.save(out_path, out_wav, 24000)
+
+    return "Speech generated !", out_path
+
+
+
+
+def create_interface():
+    try:
+        # Call load_model to load the model
+        model_loading_gen = load_model(checkpoint_dir=MODEL_DIR, repo_id="capleaf/viXTTS", use_deepspeed=False)
+
+        # Iterate this generator until the model has finished loading
+        for message in model_loading_gen:
+            print(message)  # Print the model loading status messages
+
+        # Reference speaker audio files
+        speaker_audio_files = [
+            r"samples\nu-nhe-nhang.wav",
+            r"samples\nu-nhan-nha.wav",
+            r"samples\nu-luu-loat.wav",
+            r"samples\nu-cham.wav",
+            r"samples\nu-calm.wav",
+            r"samples\nam-truyen-cam.wav",
+            r"samples\nam-nhanh.wav",
+            r"samples\nam-cham.wav",
+            r"samples\nam-calm.wav",
+        ]
+
+        speaker_audio_file = speaker_audio_files[0]
+        # Other parameters
+        lang = "vi"
+        normalize_text = True
+        use_deepfilter = False
+        tts_text = "Chào bạn, tôi là một trợ lý ảo."
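+        # Note: the raw-string sample paths above use Windows-style backslashes and are resolved
+        # relative to the current working directory; on Linux/macOS, forward slashes or
+        # os.path.join("samples", ...) would be needed for the files to be found.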
+ + # Gọi hàm run_tts sau khi mô hình đã được tải + return run_tts(lang, tts_text, speaker_audio_file, use_deepfilter, normalize_text) + except Exception as e: + return f"Error loading model: {str(e)}", None, None + + +# Gọi hàm create_interface để bắt đầu quá trình +print(create_interface()) diff --git a/viXTTS/model/.cache/huggingface/.gitignore b/viXTTS/model/.cache/huggingface/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f59ec20aabf5842d237244ece8c81ab184faeac1 --- /dev/null +++ b/viXTTS/model/.cache/huggingface/.gitignore @@ -0,0 +1 @@ +* \ No newline at end of file diff --git a/viXTTS/model/.cache/huggingface/download/config.json.metadata b/viXTTS/model/.cache/huggingface/download/config.json.metadata new file mode 100644 index 0000000000000000000000000000000000000000..410ebe7d41376c23ffbd4a8a026c1e06c135f203 --- /dev/null +++ b/viXTTS/model/.cache/huggingface/download/config.json.metadata @@ -0,0 +1,3 @@ +c06f4378883110615941aab481532a9802440b05 +92644a3742afbe48ca3f9b89b69a1bfc82b2b645 +1732091204.403793 diff --git a/viXTTS/model/.cache/huggingface/download/model.pth.metadata b/viXTTS/model/.cache/huggingface/download/model.pth.metadata new file mode 100644 index 0000000000000000000000000000000000000000..5a9b9af8cb9e46ce5be6089e1230932dd2312e4b --- /dev/null +++ b/viXTTS/model/.cache/huggingface/download/model.pth.metadata @@ -0,0 +1,3 @@ +c06f4378883110615941aab481532a9802440b05 +534670e4b752002b7d7224e6ea1f467bd608c8dd3c36efaa45e1f4696e8bd1d2 +1732091203.8350449 diff --git a/viXTTS/model/.cache/huggingface/download/vocab.json.metadata b/viXTTS/model/.cache/huggingface/download/vocab.json.metadata new file mode 100644 index 0000000000000000000000000000000000000000..931f3cd628c83c0608cdfa81cceede206f00a958 --- /dev/null +++ b/viXTTS/model/.cache/huggingface/download/vocab.json.metadata @@ -0,0 +1,3 @@ +c06f4378883110615941aab481532a9802440b05 +d2519ef0c4318469467e5449e63f357cd3c165b0 +1732091205.8990037 diff --git a/viXTTS/model/config.json b/viXTTS/model/config.json new file mode 100644 index 0000000000000000000000000000000000000000..92644a3742afbe48ca3f9b89b69a1bfc82b2b645 --- /dev/null +++ b/viXTTS/model/config.json @@ -0,0 +1,176 @@ +{ + "output_path": "output", + "logger_uri": null, + "run_name": "run", + "project_name": null, + "run_description": "viXTTS training", + "print_step": null, + "plot_step": null, + "model_param_stats": false, + "wandb_entity": null, + "dashboard_logger": "tensorboard", + "save_on_interrupt": true, + "log_model_step": null, + "save_step": 24000, + "save_n_checkpoints": 2, + "save_checkpoints": true, + "save_all_best": false, + "save_best_after": 0, + "target_loss": null, + "print_eval": true, + "test_delay_epochs": 0, + "run_eval": true, + "run_eval_steps": null, + "distributed_backend": "nccl", + "distributed_url": "tcp://localhost:54321", + "mixed_precision": false, + "precision": "fp16", + "epochs": 5, + "batch_size": 2, + "eval_batch_size": 2, + "grad_clip": 0.0, + "scheduler_after_epoch": true, + "lr": 5e-06, + "optimizer": "AdamW", + "optimizer_params": { + "betas": [ + 0.9, + 0.96 + ], + "eps": 1e-08, + "weight_decay": 0.01 + }, + "lr_scheduler": "MultiStepLR", + "lr_scheduler_params": { + "milestones": [ + 900000, + 2700000, + 5400000 + ], + "gamma": 0.5, + "last_epoch": -1 + }, + "use_grad_scaler": false, + "allow_tf32": false, + "cudnn_enable": true, + "cudnn_deterministic": false, + "cudnn_benchmark": false, + "training_seed": 1, + "model": "xtts", + "num_loader_workers": 0, + "num_eval_loader_workers": 
0, + "use_noise_augment": false, + "audio": { + "sample_rate": 22050, + "output_sample_rate": 24000, + "dvae_sample_rate": 22050 + }, + "use_phonemes": false, + "phonemizer": null, + "phoneme_language": null, + "compute_input_seq_cache": false, + "text_cleaner": null, + "enable_eos_bos_chars": false, + "test_sentences_file": "", + "phoneme_cache_path": null, + "characters": null, + "add_blank": false, + "batch_group_size": 48, + "loss_masking": null, + "min_audio_len": 1, + "max_audio_len": Infinity, + "min_text_len": 1, + "max_text_len": Infinity, + "compute_f0": false, + "compute_energy": false, + "compute_linear_spec": false, + "precompute_num_workers": 0, + "start_by_longest": false, + "shuffle": false, + "drop_last": false, + "datasets": [ + { + "formatter": "", + "dataset_name": "", + "path": "", + "meta_file_train": "", + "ignored_speakers": null, + "language": "", + "phonemizer": "", + "meta_file_val": "", + "meta_file_attn_mask": "" + } + ], + "test_sentences": [], + "eval_split_max_size": null, + "eval_split_size": 0.01, + "use_speaker_weighted_sampler": false, + "speaker_weighted_sampler_alpha": 1.0, + "use_language_weighted_sampler": false, + "language_weighted_sampler_alpha": 1.0, + "use_length_weighted_sampler": false, + "length_weighted_sampler_alpha": 1.0, + "model_args": { + "gpt_batch_size": 1, + "enable_redaction": false, + "kv_cache": true, + "gpt_checkpoint": null, + "clvp_checkpoint": null, + "decoder_checkpoint": null, + "num_chars": 255, + "tokenizer_file": "", + "gpt_max_audio_tokens": 605, + "gpt_max_text_tokens": 402, + "gpt_max_prompt_tokens": 70, + "gpt_layers": 30, + "gpt_n_model_channels": 1024, + "gpt_n_heads": 16, + "gpt_number_text_tokens": 7544, + "gpt_start_text_token": null, + "gpt_stop_text_token": null, + "gpt_num_audio_tokens": 1026, + "gpt_start_audio_token": 1024, + "gpt_stop_audio_token": 1025, + "gpt_code_stride_len": 1024, + "gpt_use_masking_gt_prompt_approach": true, + "gpt_use_perceiver_resampler": true, + "input_sample_rate": 22050, + "output_sample_rate": 24000, + "output_hop_length": 256, + "decoder_input_dim": 1024, + "d_vector_dim": 512, + "cond_d_vector_in_each_upsampling_layer": true, + "duration_const": 102400 + }, + "model_dir": null, + "languages": [ + "en", + "es", + "fr", + "de", + "it", + "pt", + "pl", + "tr", + "ru", + "nl", + "cs", + "ar", + "zh-cn", + "hu", + "ko", + "ja", + "hi", + "vi" + ], + "temperature": 0.85, + "length_penalty": 1.0, + "repetition_penalty": 2.0, + "top_k": 50, + "top_p": 0.85, + "num_gpt_outputs": 1, + "gpt_cond_len": 12, + "gpt_cond_chunk_len": 4, + "max_ref_len": 10, + "sound_norm_refs": false +} \ No newline at end of file diff --git a/viXTTS/model/model.pth b/viXTTS/model/model.pth new file mode 100644 index 0000000000000000000000000000000000000000..4b810397373b22bf19a1630bbcb92f440dd4529b --- /dev/null +++ b/viXTTS/model/model.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:534670e4b752002b7d7224e6ea1f467bd608c8dd3c36efaa45e1f4696e8bd1d2 +size 1875343894 diff --git a/viXTTS/model/speakers_xtts.pth b/viXTTS/model/speakers_xtts.pth new file mode 100644 index 0000000000000000000000000000000000000000..85e8ff78084dcdf67231297758a9eefb770301c3 --- /dev/null +++ b/viXTTS/model/speakers_xtts.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f0f6137c19a4eab0cbbe4c99b5babacf68b1746e50da90807708c10e645b943b +size 7754818 diff --git a/viXTTS/model/vocab.json b/viXTTS/model/vocab.json new file mode 100644 index 
0000000000000000000000000000000000000000..d2519ef0c4318469467e5449e63f357cd3c165b0 --- /dev/null +++ b/viXTTS/model/vocab.json @@ -0,0 +1,17124 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [ + { + "id": 0, + "special": true, + "content": "[STOP]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 1, + "special": true, + "content": "[UNK]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 2, + "special": true, + "content": "[SPACE]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 259, + "special": true, + "content": "[en]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 260, + "special": true, + "content": "[de]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 261, + "special": true, + "content": "[START]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 262, + "special": true, + "content": "[fr]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 284, + "special": true, + "content": "[es]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 285, + "special": true, + "content": "[it]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 286, + "special": true, + "content": "[pt]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 294, + "special": true, + "content": "[pl]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 295, + "special": true, + "content": "[tr]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 267, + "special": true, + "content": "[ru]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 293, + "special": true, + "content": "[cs]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 297, + "special": true, + "content": "[nl]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 5022, + "special": true, + "content": "[ar]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 5023, + "special": true, + "content": "[zh-cn]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 5412, + "special": true, + "content": "[ja]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 5753, + "special": true, + "content": "[hu]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 6152, + "special": true, + "content": "[ko]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 6680, + "special": true, + "content": "[hi]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 7543, + "special": true, + "content": "[vi]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + } + ], + "normalizer": null, + "pre_tokenizer": { + "type": "Whitespace" + }, + "post_processor": 
null, + "decoder": null, + "model": { + "type": "BPE", + "dropout": null, + "unk_token": "[UNK]", + "continuing_subword_prefix": null, + "end_of_word_suffix": null, + "fuse_unk": false, + "vocab": { + "[STOP]": 0, + "[UNK]": 1, + "[SPACE]": 2, + "!": 3, + "'": 4, + "(": 5, + ")": 6, + ",": 7, + "-": 8, + ".": 9, + "/": 10, + ":": 11, + ";": 12, + "?": 13, + "a": 14, + "b": 15, + "c": 16, + "d": 17, + "e": 18, + "f": 19, + "g": 20, + "h": 21, + "i": 22, + "j": 23, + "k": 24, + "l": 25, + "m": 26, + "n": 27, + "o": 28, + "p": 29, + "q": 30, + "r": 31, + "s": 32, + "t": 33, + "u": 34, + "v": 35, + "w": 36, + "x": 37, + "y": 38, + "z": 39, + "th": 40, + "in": 41, + "the": 42, + "an": 43, + "er": 44, + "ou": 45, + "re": 46, + "on": 47, + "at": 48, + "ed": 49, + "en": 50, + "to": 51, + "ing": 52, + "and": 53, + "is": 54, + "as": 55, + "al": 56, + "or": 57, + "of": 58, + "ar": 59, + "it": 60, + "es": 61, + "he": 62, + "st": 63, + "le": 64, + "om": 65, + "se": 66, + "be": 67, + "ad": 68, + "ow": 69, + "ly": 70, + "ch": 71, + "wh": 72, + "that": 73, + "you": 74, + "li": 75, + "ve": 76, + "ac": 77, + "ti": 78, + "ld": 79, + "me": 80, + "was": 81, + "gh": 82, + "id": 83, + "ll": 84, + "wi": 85, + "ent": 86, + "for": 87, + "ay": 88, + "ro": 89, + "ver": 90, + "ic": 91, + "her": 92, + "ke": 93, + "his": 94, + "no": 95, + "ut": 96, + "un": 97, + "ir": 98, + "lo": 99, + "we": 100, + "ri": 101, + "ha": 102, + "with": 103, + "ght": 104, + "out": 105, + "im": 106, + "ion": 107, + "all": 108, + "ab": 109, + "one": 110, + "ne": 111, + "ge": 112, + "ould": 113, + "ter": 114, + "mo": 115, + "had": 116, + "ce": 117, + "she": 118, + "go": 119, + "sh": 120, + "ur": 121, + "am": 122, + "so": 123, + "pe": 124, + "my": 125, + "de": 126, + "are": 127, + "but": 128, + "ome": 129, + "fr": 130, + "ther": 131, + "fe": 132, + "su": 133, + "do": 134, + "con": 135, + "te": 136, + "ain": 137, + "ere": 138, + "po": 139, + "if": 140, + "they": 141, + "us": 142, + "ag": 143, + "tr": 144, + "now": 145, + "oun": 146, + "this": 147, + "have": 148, + "not": 149, + "sa": 150, + "il": 151, + "up": 152, + "thing": 153, + "from": 154, + "ap": 155, + "him": 156, + "ack": 157, + "ation": 158, + "ant": 159, + "our": 160, + "op": 161, + "like": 162, + "ust": 163, + "ess": 164, + "bo": 165, + "ok": 166, + "ul": 167, + "ind": 168, + "ex": 169, + "com": 170, + "some": 171, + "there": 172, + "ers": 173, + "co": 174, + "res": 175, + "man": 176, + "ard": 177, + "pl": 178, + "wor": 179, + "way": 180, + "tion": 181, + "fo": 182, + "ca": 183, + "were": 184, + "by": 185, + "ate": 186, + "pro": 187, + "ted": 188, + "ound": 189, + "own": 190, + "would": 191, + "ts": 192, + "what": 193, + "qu": 194, + "ally": 195, + "ight": 196, + "ck": 197, + "gr": 198, + "when": 199, + "ven": 200, + "can": 201, + "ough": 202, + "ine": 203, + "end": 204, + "per": 205, + "ous": 206, + "od": 207, + "ide": 208, + "know": 209, + "ty": 210, + "very": 211, + "si": 212, + "ak": 213, + "who": 214, + "about": 215, + "ill": 216, + "them": 217, + "est": 218, + "red": 219, + "ye": 220, + "could": 221, + "ong": 222, + "your": 223, + "their": 224, + "em": 225, + "just": 226, + "other": 227, + "into": 228, + "any": 229, + "whi": 230, + "um": 231, + "tw": 232, + "ast": 233, + "der": 234, + "did": 235, + "ie": 236, + "been": 237, + "ace": 238, + "ink": 239, + "ity": 240, + "back": 241, + "ting": 242, + "br": 243, + "more": 244, + "ake": 245, + "pp": 246, + "then": 247, + "sp": 248, + "el": 249, + "use": 250, + "bl": 251, + "said": 252, + "over": 253, + "get": 254, + "ß": 255, + "ä": 
256, + "ö": 257, + "ü": 258, + "[en]": 259, + "[de]": 260, + "[START]": 261, + "[fr]": 262, + "œ": 263, + "ï": 264, + "ê": 265, + "â": 266, + "[ru]": 267, + "ÿ": 268, + "è": 269, + "à": 270, + "ë": 271, + "ù": 272, + "î": 273, + "ç": 274, + "æ": 275, + "ô": 276, + "û": 277, + "á": 278, + "é": 279, + "í": 280, + "ó": 281, + "ú": 282, + "ñ": 283, + "[es]": 284, + "[it]": 285, + "[pt]": 286, + "ń": 287, + "ś": 288, + "ę": 289, + "ą": 290, + "ż": 291, + "ć": 292, + "[cs]": 293, + "[pl]": 294, + "[tr]": 295, + "ã": 296, + "[nl]": 297, + "ş": 298, + "ğ": 299, + "ı": 300, + "ò": 301, + "ì": 302, + "¿": 303, + "…": 304, + "i̇": 305, + "õ": 306, + "\"": 307, + "´": 308, + "ø": 309, + "č": 310, + "ō": 311, + "š": 312, + "ž": 313, + "̇": 314, + "ei": 315, + "ich": 316, + "ein": 317, + "au": 318, + "sch": 319, + "und": 320, + "die": 321, + "da": 322, + "den": 323, + "gen": 324, + "zu": 325, + "hr": 326, + "ten": 327, + "mi": 328, + "sie": 329, + "das": 330, + "eine": 331, + "icht": 332, + "ber": 333, + "ach": 334, + "auf": 335, + "lich": 336, + "nicht": 337, + "mm": 338, + "ben": 339, + "war": 340, + "mit": 341, + "sich": 342, + "ig": 343, + "aus": 344, + "ist": 345, + "wie": 346, + "och": 347, + "ung": 348, + "ann": 349, + "ür": 350, + "hn": 351, + "ihr": 352, + "sen": 353, + "tz": 354, + "dem": 355, + "eit": 356, + "hat": 357, + "wir": 358, + "von": 359, + "wei": 360, + "ier": 361, + "ra": 362, + "einen": 363, + "vor": 364, + "als": 365, + "wo": 366, + "rei": 367, + "ste": 368, + "lie": 369, + "auch": 370, + "du": 371, + "des": 372, + "ko": 373, + "über": 374, + "bei": 375, + "hen": 376, + "hm": 377, + "lei": 378, + "aber": 379, + "wen": 380, + "hl": 381, + "ger": 382, + "nach": 383, + "ft": 384, + "imm": 385, + "je": 386, + "schen": 387, + "wer": 388, + "ser": 389, + "än": 390, + "sein": 391, + "ol": 392, + "cht": 393, + "für": 394, + "kl": 395, + "ff": 396, + "einem": 397, + "nen": 398, + "ja": 399, + "noch": 400, + "hatte": 401, + "pf": 402, + "hin": 403, + "di": 404, + "chen": 405, + "rü": 406, + "iel": 407, + "sel": 408, + "dass": 409, + "ihn": 410, + "mir": 411, + "schl": 412, + "ön": 413, + "gan": 414, + "gt": 415, + "einer": 416, + "sten": 417, + "mich": 418, + "wenn": 419, + "ell": 420, + "gte": 421, + "mal": 422, + "gel": 423, + "ken": 424, + "nur": 425, + "mmen": 426, + "fü": 427, + "ern": 428, + "ör": 429, + "unter": 430, + "ander": 431, + "dur": 432, + "uch": 433, + "ta": 434, + "men": 435, + "mach": 436, + "doch": 437, + "durch": 438, + "os": 439, + "gl": 440, + "hal": 441, + "ihre": 442, + "wä": 443, + "immer": 444, + "ihm": 445, + "kann": 446, + "ort": 447, + "dann": 448, + "lan": 449, + "tzt": 450, + "oder": 451, + "hren": 452, + "et": 453, + "kön": 454, + "ick": 455, + "fa": 456, + "wieder": 457, + "daß": 458, + "mein": 459, + "fen": 460, + "ganz": 461, + "diese": 462, + "ster": 463, + "dar": 464, + "wa": 465, + "ges": 466, + "na": 467, + "fl": 468, + "igen": 469, + "sche": 470, + "ungen": 471, + "mehr": 472, + "ßen": 473, + "ot": 474, + "kon": 475, + "gew": 476, + "haben": 477, + "geh": 478, + "ät": 479, + "sind": 480, + "dr": 481, + "wel": 482, + "uns": 483, + "vo": 484, + "ma": 485, + "ute": 486, + "schon": 487, + "bes": 488, + "gesch": 489, + "bt": 490, + "che": 491, + "son": 492, + "ob": 493, + "la": 494, + "rück": 495, + "seine": 496, + "kr": 497, + "fre": 498, + "eil": 499, + "zum": 500, + "hier": 501, + "kt": 502, + "ige": 503, + "spr": 504, + "leben": 505, + "bst": 506, + "zeit": 507, + "gro": 508, + "denn": 509, + "ho": 510, + "scha": 511, + "bar": 512, + "alle": 513, + 
"gegen": 514, + "wür": 515, + "mü": 516, + "ze": 517, + "werden": 518, + "jetzt": 519, + "kommen": 520, + "nie": 521, + "sei": 522, + "heit": 523, + "soll": 524, + "glei": 525, + "meine": 526, + "woll": 527, + "ner": 528, + "habe": 529, + "wur": 530, + "lichen": 531, + "assen": 532, + "nte": 533, + "sehen": 534, + "wird": 535, + "bis": 536, + "gar": 537, + "ien": 538, + "mus": 539, + "uß": 540, + "är": 541, + "stell": 542, + "keit": 543, + "zwei": 544, + "selbst": 545, + "sta": 546, + "pa": 547, + "sagte": 548, + "tet": 549, + "kam": 550, + "ssen": 551, + "viel": 552, + "ug": 553, + "zen": 554, + "hei": 555, + "mann": 556, + "will": 557, + "geb": 558, + "waren": 559, + "ück": 560, + "äch": 561, + "mer": 562, + "ru": 563, + "hau": 564, + "eigen": 565, + "ang": 566, + "weg": 567, + "blick": 568, + "fra": 569, + "alles": 570, + "ka": 571, + "augen": 572, + "fin": 573, + "liche": 574, + "unser": 575, + "dern": 576, + "herr": 577, + "nun": 578, + "vie": 579, + "chte": 580, + "wohl": 581, + "fall": 582, + "ht": 583, + "ün": 584, + "etwas": 585, + "stand": 586, + "äu": 587, + "mö": 588, + "tel": 589, + "rie": 590, + "dich": 591, + "dies": 592, + "hand": 593, + "bin": 594, + "ffen": 595, + "nichts": 596, + "dan": 597, + "hne": 598, + "ihnen": 599, + "esen": 600, + "dieser": 601, + "frau": 602, + "art": 603, + "dir": 604, + "isch": 605, + "erst": 606, + "gleich": 607, + "komm": 608, + "hör": 609, + "ße": 610, + "dig": 611, + "sehr": 612, + "zei": 613, + "sam": 614, + "aum": 615, + "hät": 616, + "ingen": 617, + "gut": 618, + "mut": 619, + "cken": 620, + "konnte": 621, + "stimm": 622, + "zur": 623, + "itz": 624, + "weil": 625, + "würde": 626, + "fä": 627, + "können": 628, + "keine": 629, + "fer": 630, + "ischen": 631, + "voll": 632, + "eines": 633, + "setz": 634, + "zie": 635, + "del": 636, + "tete": 637, + "seiner": 638, + "ieren": 639, + "gest": 640, + "zurück": 641, + "wurde": 642, + "schn": 643, + "pr": 644, + "ließ": 645, + "tra": 646, + "mä": 647, + "gend": 648, + "fol": 649, + "ik": 650, + "schla": 651, + "schaft": 652, + "ater": 653, + "weiß": 654, + "seinen": 655, + "lassen": 656, + "lu": 657, + "unden": 658, + "teil": 659, + "neu": 660, + "iert": 661, + "menschen": 662, + "hmen": 663, + "str": 664, + "gi": 665, + "sah": 666, + "ihren": 667, + "eln": 668, + "weiter": 669, + "gehen": 670, + "iger": 671, + "macht": 672, + "tag": 673, + "also": 674, + "halten": 675, + "nis": 676, + "acht": 677, + "geben": 678, + "og": 679, + "nat": 680, + "mar": 681, + "det": 682, + "ohne": 683, + "haus": 684, + "tro": 685, + "ange": 686, + "lau": 687, + "spiel": 688, + "tre": 689, + "schr": 690, + "inn": 691, + "los": 692, + "machen": 693, + "hätte": 694, + "beg": 695, + "wirk": 696, + "alt": 697, + "glich": 698, + "tes": 699, + "richt": 700, + "freund": 701, + "ihrer": 702, + "fel": 703, + "bel": 704, + "sol": 705, + "einmal": 706, + "eben": 707, + "hol": 708, + "hän": 709, + "tern": 710, + "hö": 711, + "schw": 712, + "recht": 713, + "wahr": 714, + "seinem": 715, + "stehen": 716, + "hlen": 717, + "ins": 718, + "ging": 719, + "wollte": 720, + "wissen": 721, + "ungs": 722, + "ald": 723, + "ass": 724, + "jahr": 725, + "mor": 726, + "welt": 727, + "under": 728, + "zusa": 729, + "kopf": 730, + "lang": 731, + "hinter": 732, + "atz": 733, + "stra": 734, + "angen": 735, + "ank": 736, + "ade": 737, + "glau": 738, + "fach": 739, + "hatten": 740, + "fort": 741, + "eicht": 742, + "iff": 743, + "ler": 744, + "mei": 745, + "diesem": 746, + "kein": 747, + "frei": 748, + "führ": 749, + "vom": 750, + "β": 751, + "ai": 752, + 
"ait": 753, + "que": 754, + "les": 755, + "av": 756, + "ais": 757, + "oi": 758, + "eu": 759, + "lle": 760, + "par": 761, + "ans": 762, + "ment": 763, + "ét": 764, + "une": 765, + "pas": 766, + "qui": 767, + "elle": 768, + "dé": 769, + "pour": 770, + "dans": 771, + "ré": 772, + "tou": 773, + "vous": 774, + "vi": 775, + "ouv": 776, + "mon": 777, + "sur": 778, + "ci": 779, + "plu": 780, + "ère": 781, + "mais": 782, + "ois": 783, + "plus": 784, + "ée": 785, + "aient": 786, + "mp": 787, + "lui": 788, + "ave": 789, + "était": 790, + "ses": 791, + "tout": 792, + "oir": 793, + "avait": 794, + "és": 795, + "mes": 796, + "nous": 797, + "eux": 798, + "bi": 799, + "ons": 800, + "pu": 801, + "ces": 802, + "tu": 803, + "leur": 804, + "don": 805, + "eur": 806, + "ette": 807, + "aire": 808, + "avec": 809, + "dit": 810, + "té": 811, + "ille": 812, + "comme": 813, + "cr": 814, + "ux": 815, + "ès": 816, + "aux": 817, + "jour": 818, + "ils": 819, + "bien": 820, + "cou": 821, + "quel": 822, + "peu": 823, + "cette": 824, + "cu": 825, + "mê": 826, + "fait": 827, + "gu": 828, + "être": 829, + "ité": 830, + "ens": 831, + "ni": 832, + "lé": 833, + "dis": 834, + "ble": 835, + "né": 836, + "puis": 837, + "même": 838, + "ques": 839, + "fi": 840, + "age": 841, + "moi": 842, + "ence": 843, + "ont": 844, + "main": 845, + "ors": 846, + "aut": 847, + "ance": 848, + "mé": 849, + "sans": 850, + "sé": 851, + "lon": 852, + "hom": 853, + "car": 854, + "able": 855, + "cher": 856, + "deux": 857, + "enf": 858, + "où": 859, + "ph": 860, + "ure": 861, + "temp": 862, + "pos": 863, + "rent": 864, + "pé": 865, + "faire": 866, + "pi": 867, + "tres": 868, + "ça": 869, + "endre": 870, + "bon": 871, + "sou": 872, + "int": 873, + "pré": 874, + "sent": 875, + "tant": 876, + "cer": 877, + "là": 878, + "lais": 879, + "près": 880, + "bre": 881, + "cour": 882, + "pet": 883, + "comp": 884, + "lait": 885, + "trouv": 886, + "entre": 887, + "sont": 888, + "dev": 889, + "nu": 890, + "temps": 891, + "dou": 892, + "rait": 893, + "bou": 894, + "quand": 895, + "jours": 896, + "avoir": 897, + "été": 898, + "ale": 899, + "pre": 900, + "fois": 901, + "orte": 902, + "vé": 903, + "non": 904, + "tous": 905, + "jus": 906, + "coup": 907, + "homme": 908, + "ête": 909, + "aussi": 910, + "urs": 911, + "seu": 912, + "ord": 913, + "min": 914, + "gé": 915, + "core": 916, + "va": 917, + "vre": 918, + "encore": 919, + "sem": 920, + "ite": 921, + "autre": 922, + "pris": 923, + "peut": 924, + "ue": 925, + "ante": 926, + "gn": 927, + "rép": 928, + "hu": 929, + "sion": 930, + "votre": 931, + "dire": 932, + "ez": 933, + "fem": 934, + "leurs": 935, + "met": 936, + "cri": 937, + "mis": 938, + "tour": 939, + "rai": 940, + "jam": 941, + "regar": 942, + "rien": 943, + "vers": 944, + "suis": 945, + "pouv": 946, + "vis": 947, + "grand": 948, + "ants": 949, + "cor": 950, + "rer": 951, + "cé": 952, + "tent": 953, + "pres": 954, + "vou": 955, + "alors": 956, + "sieur": 957, + "aine": 958, + "quoi": 959, + "fon": 960, + "endant": 961, + "arri": 962, + "eure": 963, + "après": 964, + "donc": 965, + "itu": 966, + "lè": 967, + "sait": 968, + "toi": 969, + "cha": 970, + "ail": 971, + "asse": 972, + "imp": 973, + "voy": 974, + "conn": 975, + "pla": 976, + "petit": 977, + "avant": 978, + "nom": 979, + "tin": 980, + "dont": 981, + "sous": 982, + "emp": 983, + "person": 984, + "elles": 985, + "beau": 986, + "parti": 987, + "cho": 988, + "prit": 989, + "toujours": 990, + "rais": 991, + "jamais": 992, + "trav": 993, + "tions": 994, + "très": 995, + "voi": 996, + "ren": 997, + "yeux": 998, + 
"voir": 999, + "premi": 1000, + "gne": 1001, + "heure": 1002, + "rou": 1003, + "eff": 1004, + "notre": 1005, + "ments": 1006, + "ton": 1007, + "fais": 1008, + "cela": 1009, + "répon": 1010, + "cons": 1011, + "air": 1012, + "ôt": 1013, + "pendant": 1014, + "ici": 1015, + "toute": 1016, + "jet": 1017, + "port": 1018, + "étaient": 1019, + "pen": 1020, + "hé": 1021, + "autres": 1022, + "père": 1023, + "oc": 1024, + "quelques": 1025, + "ique": 1026, + "lis": 1027, + "femme": 1028, + "jou": 1029, + "teur": 1030, + "monde": 1031, + "nes": 1032, + "dre": 1033, + "aff": 1034, + "rap": 1035, + "part": 1036, + "lement": 1037, + "cla": 1038, + "fut": 1039, + "quelque": 1040, + "prendre": 1041, + "rê": 1042, + "aille": 1043, + "sais": 1044, + "ches": 1045, + "let": 1046, + "char": 1047, + "ères": 1048, + "ents": 1049, + "moins": 1050, + "eau": 1051, + "aî": 1052, + "jeu": 1053, + "heur": 1054, + "ées": 1055, + "tri": 1056, + "point": 1057, + "mom": 1058, + "vent": 1059, + "nouv": 1060, + "gran": 1061, + "trois": 1062, + "sant": 1063, + "toutes": 1064, + "contre": 1065, + "èrent": 1066, + "chez": 1067, + "avez": 1068, + "ût": 1069, + "att": 1070, + "pau": 1071, + "porte": 1072, + "ouver": 1073, + "lit": 1074, + "prés": 1075, + "chose": 1076, + "vit": 1077, + "monsieur": 1078, + "hab": 1079, + "tête": 1080, + "ju": 1081, + "tement": 1082, + "ction": 1083, + "vrai": 1084, + "lar": 1085, + "cet": 1086, + "regard": 1087, + "lant": 1088, + "som": 1089, + "moment": 1090, + "illes": 1091, + "ple": 1092, + "ps": 1093, + "mère": 1094, + "cl": 1095, + "sour": 1096, + "ys": 1097, + "trop": 1098, + "enne": 1099, + "jusqu": 1100, + "avaient": 1101, + "avais": 1102, + "jeune": 1103, + "depuis": 1104, + "personne": 1105, + "fit": 1106, + "cert": 1107, + "jo": 1108, + "oui": 1109, + "rest": 1110, + "semb": 1111, + "cap": 1112, + "mat": 1113, + "mu": 1114, + "long": 1115, + "fran": 1116, + "faut": 1117, + "iti": 1118, + "bli": 1119, + "chev": 1120, + "pri": 1121, + "ente": 1122, + "ainsi": 1123, + "cham": 1124, + "lors": 1125, + "cas": 1126, + "ili": 1127, + "bé": 1128, + "nos": 1129, + "sui": 1130, + "rit": 1131, + "cro": 1132, + "gue": 1133, + "ía": 1134, + "por": 1135, + "las": 1136, + "ón": 1137, + "una": 1138, + "aba": 1139, + "dos": 1140, + "era": 1141, + "mb": 1142, + "para": 1143, + "ás": 1144, + "mos": 1145, + "ando": 1146, + "como": 1147, + "más": 1148, + "ción": 1149, + "tan": 1150, + "dad": 1151, + "ado": 1152, + "fu": 1153, + "cia": 1154, + "mente": 1155, + "sus": 1156, + "tar": 1157, + "za": 1158, + "ba": 1159, + "pero": 1160, + "sin": 1161, + "lla": 1162, + "án": 1163, + "ia": 1164, + "ran": 1165, + "ga": 1166, + "yo": 1167, + "tos": 1168, + "cos": 1169, + "ya": 1170, + "ones": 1171, + "había": 1172, + "hi": 1173, + "esta": 1174, + "mas": 1175, + "tor": 1176, + "aban": 1177, + "dor": 1178, + "ían": 1179, + "tas": 1180, + "én": 1181, + "endo": 1182, + "aque": 1183, + "ero": 1184, + "io": 1185, + "qué": 1186, + "cab": 1187, + "tal": 1188, + "señ": 1189, + "ora": 1190, + "todo": 1191, + "sal": 1192, + "cuando": 1193, + "gun": 1194, + "bu": 1195, + "ras": 1196, + "esto": 1197, + "pare": 1198, + "él": 1199, + "tras": 1200, + "jos": 1201, + "mien": 1202, + "pue": 1203, + "cre": 1204, + "pon": 1205, + "día": 1206, + "tros": 1207, + "sab": 1208, + "sobre": 1209, + "ese": 1210, + "mbre": 1211, + "eron": 1212, + "añ": 1213, + "ido": 1214, + "porque": 1215, + "ella": 1216, + "cen": 1217, + "muy": 1218, + "cal": 1219, + "este": 1220, + "has": 1221, + "có": 1222, + "gra": 1223, + "ros": 1224, + "aquel": 1225, + 
"dijo": 1226, + "cía": 1227, + "zo": 1228, + "ciones": 1229, + "mbi": 1230, + "elo": 1231, + "tó": 1232, + "ina": 1233, + "todos": 1234, + "tien": 1235, + "estaba": 1236, + "deci": 1237, + "cio": 1238, + "ño": 1239, + "lor": 1240, + "nues": 1241, + "medi": 1242, + "len": 1243, + "vida": 1244, + "ali": 1245, + "pues": 1246, + "ales": 1247, + "vol": 1248, + "mí": 1249, + "rar": 1250, + "cion": 1251, + "hasta": 1252, + "señor": 1253, + "cono": 1254, + "ah": 1255, + "dios": 1256, + "esa": 1257, + "ún": 1258, + "var": 1259, + "san": 1260, + "gui": 1261, + "otros": 1262, + "tado": 1263, + "buen": 1264, + "ña": 1265, + "tiemp": 1266, + "hacer": 1267, + "jer": 1268, + "vu": 1269, + "ana": 1270, + "así": 1271, + "antes": 1272, + "vez": 1273, + "miento": 1274, + "jar": 1275, + "lab": 1276, + "casa": 1277, + "eso": 1278, + "ego": 1279, + "dió": 1280, + "está": 1281, + "encia": 1282, + "eli": 1283, + "ías": 1284, + "tiempo": 1285, + "zar": 1286, + "van": 1287, + "mun": 1288, + "erta": 1289, + "tambi": 1290, + "sí": 1291, + "aun": 1292, + "mismo": 1293, + "entes": 1294, + "mano": 1295, + "ele": 1296, + "nada": 1297, + "segu": 1298, + "mej": 1299, + "erra": 1300, + "tir": 1301, + "uno": 1302, + "donde": 1303, + "toda": 1304, + "desde": 1305, + "también": 1306, + "cuer": 1307, + "hombre": 1308, + "otro": 1309, + "lib": 1310, + "trar": 1311, + "cual": 1312, + "hay": 1313, + "cada": 1314, + "taba": 1315, + "mento": 1316, + "tenía": 1317, + "quer": 1318, + "eran": 1319, + "siemp": 1320, + "siempre": 1321, + "erto": 1322, + "quí": 1323, + "gos": 1324, + "pués": 1325, + "ellos": 1326, + "después": 1327, + "nue": 1328, + "llo": 1329, + "inter": 1330, + "cómo": 1331, + "ahora": 1332, + "uste": 1333, + "traba": 1334, + "lado": 1335, + "ino": 1336, + "poco": 1337, + "erte": 1338, + "mujer": 1339, + "quier": 1340, + "algun": 1341, + "fue": 1342, + "ojos": 1343, + "enton": 1344, + "vos": 1345, + "esper": 1346, + "much": 1347, + "otra": 1348, + "az": 1349, + "eza": 1350, + "aquí": 1351, + "cias": 1352, + "gua": 1353, + "mucho": 1354, + "decir": 1355, + "esti": 1356, + "idad": 1357, + "algo": 1358, + "ocu": 1359, + "entonces": 1360, + "dido": 1361, + "entos": 1362, + "gri": 1363, + "dado": 1364, + "ios": 1365, + "dose": 1366, + "usted": 1367, + "quien": 1368, + "ami": 1369, + "unto": 1370, + "mejor": 1371, + "bas": 1372, + "solo": 1373, + "pregun": 1374, + "tur": 1375, + "alg": 1376, + "todas": 1377, + "parte": 1378, + "emb": 1379, + "cto": 1380, + "mundo": 1381, + "tiene": 1382, + "tante": 1383, + "palab": 1384, + "tran": 1385, + "aquella": 1386, + "cios": 1387, + "aunque": 1388, + "cuen": 1389, + "tener": 1390, + "fun": 1391, + "respon": 1392, + "allí": 1393, + "xi": 1394, + "han": 1395, + "pens": 1396, + "contra": 1397, + "tura": 1398, + "val": 1399, + "dio": 1400, + "tanto": 1401, + "camin": 1402, + "mó": 1403, + "esp": 1404, + "ada": 1405, + "ío": 1406, + "hacia": 1407, + "dej": 1408, + "estar": 1409, + "ión": 1410, + "gas": 1411, + "vas": 1412, + "noche": 1413, + "ér": 1414, + "años": 1415, + "padre": 1416, + "gus": 1417, + "ár": 1418, + "sino": 1419, + "manos": 1420, + "cido": 1421, + "estu": 1422, + "hubi": 1423, + "vir": 1424, + "bri": 1425, + "raz": 1426, + "chi": 1427, + "puede": 1428, + "menos": 1429, + "habi": 1430, + "homb": 1431, + "neces": 1432, + "may": 1433, + "eros": 1434, + "ría": 1435, + "hecho": 1436, + "escu": 1437, + "lti": 1438, + "ándo": 1439, + "bus": 1440, + "cosas": 1441, + "tú": 1442, + "espa": 1443, + "reci": 1444, + "ctor": 1445, + "prim": 1446, + "dia": 1447, + "dese": 1448, + 
"mientras": 1449, + "hor": 1450, + "fuer": 1451, + "ida": 1452, + "posi": 1453, + "lante": 1454, + "ano": 1455, + "estas": 1456, + "pli": 1457, + "luego": 1458, + "sión": 1459, + "cin": 1460, + "tierra": 1461, + "guar": 1462, + "cado": 1463, + "encon": 1464, + "pren": 1465, + "mayor": 1466, + "fal": 1467, + "ð": 1468, + "ħ": 1469, + "ň": 1470, + "ə": 1471, + "θ": 1472, + "’": 1473, + "“": 1474, + "”": 1475, + "zi": 1476, + "gli": 1477, + "tto": 1478, + "ono": 1479, + "nel": 1480, + "tti": 1481, + "della": 1482, + "zione": 1483, + "tta": 1484, + "tà": 1485, + "uo": 1486, + "come": 1487, + "alla": 1488, + "oni": 1489, + "ggi": 1490, + "ssi": 1491, + "più": 1492, + "ini": 1493, + "bb": 1494, + "sto": 1495, + "sono": 1496, + "eri": 1497, + "sse": 1498, + "sc": 1499, + "sul": 1500, + "vano": 1501, + "sti": 1502, + "suo": 1503, + "cchi": 1504, + "zza": 1505, + "anche": 1506, + "tte": 1507, + "sci": 1508, + "col": 1509, + "sso": 1510, + "ssa": 1511, + "dei": 1512, + "aveva": 1513, + "zz": 1514, + "amo": 1515, + "gno": 1516, + "sua": 1517, + "ria": 1518, + "sì": 1519, + "ché": 1520, + "dal": 1521, + "ona": 1522, + "spe": 1523, + "gni": 1524, + "tt": 1525, + "delle": 1526, + "questo": 1527, + "nella": 1528, + "dere": 1529, + "anno": 1530, + "dell": 1531, + "uni": 1532, + "bbe": 1533, + "anti": 1534, + "ene": 1535, + "gio": 1536, + "uto": 1537, + "qual": 1538, + "glia": 1539, + "quando": 1540, + "tutto": 1541, + "glio": 1542, + "zioni": 1543, + "cam": 1544, + "esso": 1545, + "ss": 1546, + "mol": 1547, + "loro": 1548, + "perché": 1549, + "cosa": 1550, + "due": 1551, + "poi": 1552, + "sco": 1553, + "cco": 1554, + "gna": 1555, + "tem": 1556, + "prima": 1557, + "così": 1558, + "essere": 1559, + "ani": 1560, + "bra": 1561, + "rio": 1562, + "anco": 1563, + "cui": 1564, + "spi": 1565, + "via": 1566, + "gior": 1567, + "bile": 1568, + "ggio": 1569, + "mai": 1570, + "tare": 1571, + "indi": 1572, + "rebbe": 1573, + "senza": 1574, + "zio": 1575, + "tutti": 1576, + "stato": 1577, + "zia": 1578, + "dalla": 1579, + "mia": 1580, + "vita": 1581, + "quella": 1582, + "qua": 1583, + "dove": 1584, + "allo": 1585, + "sempre": 1586, + "zzo": 1587, + "sia": 1588, + "dopo": 1589, + "porta": 1590, + "ccia": 1591, + "erano": 1592, + "anni": 1593, + "chia": 1594, + "enza": 1595, + "propri": 1596, + "anda": 1597, + "cca": 1598, + "occhi": 1599, + "questa": 1600, + "ffi": 1601, + "ron": 1602, + "mio": 1603, + "ris": 1604, + "ogni": 1605, + "rin": 1606, + "far": 1607, + "menti": 1608, + "ancora": 1609, + "fatto": 1610, + "mani": 1611, + "senti": 1612, + "pra": 1613, + "tempo": 1614, + "essi": 1615, + "bbi": 1616, + "lare": 1617, + "pers": 1618, + "sor": 1619, + "anza": 1620, + "pie": 1621, + "verso": 1622, + "altro": 1623, + "tato": 1624, + "cato": 1625, + "ato": 1626, + "volta": 1627, + "cc": 1628, + "fare": 1629, + "ciò": 1630, + "bili": 1631, + "nuo": 1632, + "quello": 1633, + "colo": 1634, + "ppo": 1635, + "trova": 1636, + "ore": 1637, + "rono": 1638, + "molto": 1639, + "almente": 1640, + "sca": 1641, + "vole": 1642, + "tali": 1643, + "sulla": 1644, + "sce": 1645, + "meno": 1646, + "anto": 1647, + "pun": 1648, + "stu": 1649, + "capi": 1650, + "giu": 1651, + "mini": 1652, + "pia": 1653, + "lavo": 1654, + "vero": 1655, + "rsi": 1656, + "altri": 1657, + "scia": 1658, + "suoi": 1659, + "glie": 1660, + "sotto": 1661, + "bene": 1662, + "scri": 1663, + "tale": 1664, + "degli": 1665, + "alc": 1666, + "uomo": 1667, + "pel": 1668, + "pote": 1669, + "essa": 1670, + "scu": 1671, + "signo": 1672, + "stro": 1673, + "uti": 1674, + "sione": 
1675, + "gre": 1676, + "fini": 1677, + "lun": 1678, + "esi": 1679, + "passa": 1680, + "rà": 1681, + "mentre": 1682, + "hanno": 1683, + "usci": 1684, + "gia": 1685, + "già": 1686, + "mina": 1687, + "tica": 1688, + "giorno": 1689, + "esse": 1690, + "modo": 1691, + "spa": 1692, + "proprio": 1693, + "ori": 1694, + "contro": 1695, + "stru": 1696, + "diven": 1697, + "disse": 1698, + "rato": 1699, + "noi": 1700, + "vere": 1701, + "può": 1702, + "dice": 1703, + "cci": 1704, + "secon": 1705, + "ccio": 1706, + "qualche": 1707, + "tutta": 1708, + "gg": 1709, + "mondo": 1710, + "forma": 1711, + "mma": 1712, + "pensa": 1713, + "deva": 1714, + "fosse": 1715, + "sopra": 1716, + "tamente": 1717, + "ness": 1718, + "quanto": 1719, + "raga": 1720, + "unque": 1721, + "care": 1722, + "stre": 1723, + "grande": 1724, + "picco": 1725, + "guarda": 1726, + "nell": 1727, + "possi": 1728, + "presen": 1729, + "rò": 1730, + "paro": 1731, + "tua": 1732, + "vin": 1733, + "ane": 1734, + "stesso": 1735, + "dav": 1736, + "nei": 1737, + "nelle": 1738, + "ghi": 1739, + "pio": 1740, + "lato": 1741, + "sid": 1742, + "fine": 1743, + "fuo": 1744, + "quasi": 1745, + "ulti": 1746, + "ito": 1747, + "sue": 1748, + "fil": 1749, + "allora": 1750, + "veni": 1751, + "tano": 1752, + "ello": 1753, + "ão": 1754, + "não": 1755, + "uma": 1756, + "ela": 1757, + "lh": 1758, + "ção": 1759, + "cê": 1760, + "inha": 1761, + "você": 1762, + "ec": 1763, + "dade": 1764, + "ao": 1765, + "ram": 1766, + "vel": 1767, + "ém": 1768, + "pode": 1769, + "estava": 1770, + "isso": 1771, + "mui": 1772, + "faz": 1773, + "ões": 1774, + "pes": 1775, + "ix": 1776, + "sim": 1777, + "olh": 1778, + "isa": 1779, + "ên": 1780, + "tinha": 1781, + "meu": 1782, + "são": 1783, + "minha": 1784, + "muito": 1785, + "foi": 1786, + "bem": 1787, + "diz": 1788, + "parec": 1789, + "ço": 1790, + "pesso": 1791, + "pois": 1792, + "mesmo": 1793, + "ções": 1794, + "seus": 1795, + "até": 1796, + "ência": 1797, + "lhe": 1798, + "tiv": 1799, + "mã": 1800, + "só": 1801, + "tão": 1802, + "tudo": 1803, + "então": 1804, + "inda": 1805, + "bal": 1806, + "indo": 1807, + "ndo": 1808, + "já": 1809, + "vam": 1810, + "eito": 1811, + "depois": 1812, + "mel": 1813, + "lha": 1814, + "ainda": 1815, + "fazer": 1816, + "pou": 1817, + "pergun": 1818, + "deix": 1819, + "tamb": 1820, + "ala": 1821, + "pelo": 1822, + "também": 1823, + "fica": 1824, + "prec": 1825, + "eles": 1826, + "havia": 1827, + "lá": 1828, + "nas": 1829, + "gem": 1830, + "mem": 1831, + "ós": 1832, + "deu": 1833, + "eiro": 1834, + "..": 1835, + "assim": 1836, + "ior": 1837, + "har": 1838, + "aqui": 1839, + "cul": 1840, + "sar": 1841, + "outra": 1842, + "olhos": 1843, + "ima": 1844, + "mim": 1845, + "ago": 1846, + "pessoas": 1847, + "eram": 1848, + "eira": 1849, + "pela": 1850, + "coisa": 1851, + "mão": 1852, + "conh": 1853, + "agora": 1854, + "iam": 1855, + "há": 1856, + "suas": 1857, + "guém": 1858, + "cabe": 1859, + "nem": 1860, + "ível": 1861, + "consegu": 1862, + "trabal": 1863, + "lev": 1864, + "lem": 1865, + "vai": 1866, + "tei": 1867, + "pró": 1868, + "quem": 1869, + "onde": 1870, + "cabeça": 1871, + "nunca": 1872, + "mentos": 1873, + "hum": 1874, + "dele": 1875, + "verdade": 1876, + "tá": 1877, + "hos": 1878, + "algum": 1879, + "dizer": 1880, + "penas": 1881, + "nós": 1882, + "enquanto": 1883, + "outro": 1884, + "lho": 1885, + "melhor": 1886, + "primei": 1887, + "iu": 1888, + "apenas": 1889, + "estou": 1890, + "conte": 1891, + "homem": 1892, + "dois": 1893, + "ças": 1894, + "pouco": 1895, + "senhor": 1896, + "tando": 1897, + 
"espera": 1898, + "pai": 1899, + "rios": 1900, + "baix": 1901, + "ase": 1902, + "isas": 1903, + "hora": 1904, + "ficar": 1905, + "seja": 1906, + "ân": 1907, + "clar": 1908, + "inc": 1909, + "fos": 1910, + "ouvi": 1911, + "vem": 1912, + "tava": 1913, + "ário": 1914, + "sos": 1915, + "inho": 1916, + "rando": 1917, + "ês": 1918, + "coisas": 1919, + "aconte": 1920, + "lher": 1921, + "anos": 1922, + "talvez": 1923, + "estão": 1924, + "liv": 1925, + "outros": 1926, + "qualquer": 1927, + "gou": 1928, + "lí": 1929, + "tivesse": 1930, + "rado": 1931, + "precisa": 1932, + "mãe": 1933, + "dela": 1934, + "entra": 1935, + "maior": 1936, + "noite": 1937, + "tiva": 1938, + "pala": 1939, + "ração": 1940, + "deus": 1941, + "sas": 1942, + "inte": 1943, + "fei": 1944, + "palav": 1945, + "trás": 1946, + "cidade": 1947, + "lugar": 1948, + "vezes": 1949, + "encontra": 1950, + "tru": 1951, + "eci": 1952, + "ın": 1953, + "bir": 1954, + "yor": 1955, + "ek": 1956, + "dı": 1957, + "ey": 1958, + "tı": 1959, + "mı": 1960, + "iz": 1961, + "ır": 1962, + "gö": 1963, + "sı": 1964, + "bil": 1965, + "lı": 1966, + "üz": 1967, + "iç": 1968, + "iy": 1969, + "ım": 1970, + "uz": 1971, + "cak": 1972, + "iş": 1973, + "ını": 1974, + "iyor": 1975, + "baş": 1976, + "dü": 1977, + "değ": 1978, + "kar": 1979, + "ev": 1980, + "öy": 1981, + "bun": 1982, + "yap": 1983, + "sun": 1984, + "gör": 1985, + "yı": 1986, + "ki": 1987, + "ara": 1988, + "alı": 1989, + "onu": 1990, + "çı": 1991, + "şey": 1992, + "sın": 1993, + "kı": 1994, + "kad": 1995, + "ağ": 1996, + "değil": 1997, + "ük": 1998, + "çok": 1999, + "şı": 2000, + "ül": 2001, + "için": 2002, + "eye": 2003, + "oldu": 2004, + "mış": 2005, + "kal": 2006, + "mek": 2007, + "öyle": 2008, + "yordu": 2009, + "yüz": 2010, + "miş": 2011, + "mak": 2012, + "ola": 2013, + "yan": 2014, + "cek": 2015, + "yorum": 2016, + "bak": 2017, + "üm": 2018, + "ları": 2019, + "oğ": 2020, + "kadar": 2021, + "arı": 2022, + "ında": 2023, + "gün": 2024, + "yok": 2025, + "yer": 2026, + "dım": 2027, + "daha": 2028, + "ına": 2029, + "dim": 2030, + "bilir": 2031, + "iki": 2032, + "siz": 2033, + "diğ": 2034, + "bü": 2035, + "düş": 2036, + "üç": 2037, + "unu": 2038, + "aman": 2039, + "fak": 2040, + "ede": 2041, + "sonra": 2042, + "hiç": 2043, + "aki": 2044, + "ğı": 2045, + "bul": 2046, + "maz": 2047, + "anla": 2048, + "bura": 2049, + "geç": 2050, + "maya": 2051, + "konu": 2052, + "din": 2053, + "tek": 2054, + "zaman": 2055, + "eler": 2056, + "öz": 2057, + "dır": 2058, + "gibi": 2059, + "şa": 2060, + "leri": 2061, + "kim": 2062, + "ku": 2063, + "fakat": 2064, + "yar": 2065, + "göz": 2066, + "cı": 2067, + "yorsun": 2068, + "bek": 2069, + "inde": 2070, + "pek": 2071, + "bunu": 2072, + "lik": 2073, + "iler": 2074, + "edi": 2075, + "öl": 2076, + "sür": 2077, + "sır": 2078, + "çık": 2079, + "sıl": 2080, + "alar": 2081, + "kes": 2082, + "yak": 2083, + "çek": 2084, + "yıl": 2085, + "ecek": 2086, + "ız": 2087, + "git": 2088, + "kap": 2089, + "ama": 2090, + "ıl": 2091, + "ların": 2092, + "biz": 2093, + "tır": 2094, + "oy": 2095, + "ancak": 2096, + "doğ": 2097, + "bana": 2098, + "şim": 2099, + "başla": 2100, + "lü": 2101, + "madı": 2102, + "beni": 2103, + "yük": 2104, + "lık": 2105, + "beş": 2106, + "nasıl": 2107, + "tık": 2108, + "tür": 2109, + "daki": 2110, + "ceğ": 2111, + "zı": 2112, + "iyi": 2113, + "dok": 2114, + "benim": 2115, + "cağ": 2116, + "yen": 2117, + "şu": 2118, + "mez": 2119, + "düşün": 2120, + "kendi": 2121, + "şimdi": 2122, + "yol": 2123, + "yu": 2124, + "iste": 2125, + "sek": 2126, + "mam": 2127, + "söyle": 2128, + 
"dik": 2129, + "kur": 2130, + "olduğ": 2131, + "sını": 2132, + "biliyor": 2133, + "kan": 2134, + "yal": 2135, + "meye": 2136, + "muş": 2137, + "kaç": 2138, + "iye": 2139, + "tü": 2140, + "ef": 2141, + "tım": 2142, + "evet": 2143, + "yet": 2144, + "burada": 2145, + "tim": 2146, + "biraz": 2147, + "kor": 2148, + "doğru": 2149, + "inin": 2150, + "kız": 2151, + "diye": 2152, + "dör": 2153, + "etti": 2154, + "onun": 2155, + "isti": 2156, + "ği": 2157, + "sana": 2158, + "üş": 2159, + "arka": 2160, + "hayır": 2161, + "karşı": 2162, + "ile": 2163, + "hak": 2164, + "ıyor": 2165, + "neden": 2166, + "sev": 2167, + "sız": 2168, + "çocu": 2169, + "çalı": 2170, + "olur": 2171, + "bır": 2172, + "gir": 2173, + "ise": 2174, + "ih": 2175, + "kır": 2176, + "dön": 2177, + "böyle": 2178, + "seni": 2179, + "!\"": 2180, + "dört": 2181, + "söy": 2182, + "oş": 2183, + "musun": 2184, + "laş": 2185, + "ip": 2186, + "kay": 2187, + "hem": 2188, + "büyük": 2189, + "aç": 2190, + "bırak": 2191, + "misin": 2192, + "söz": 2193, + "değiş": 2194, + "ünü": 2195, + "gül": 2196, + "kö": 2197, + "karı": 2198, + "tamam": 2199, + "olu": 2200, + "yeni": 2201, + "lam": 2202, + "mıştı": 2203, + "yaş": 2204, + "iniz": 2205, + "kadın": 2206, + "bunun": 2207, + "mey": 2208, + "altı": 2209, + "yi": 2210, + "inden": 2211, + "senin": 2212, + "yat": 2213, + "top": 2214, + "isi": 2215, + "dün": 2216, + "hiçbir": 2217, + "yon": 2218, + "dın": 2219, + "tün": 2220, + "başka": 2221, + "hep": 2222, + "irmi": 2223, + "devam": 2224, + "olacak": 2225, + "artık": 2226, + "durum": 2227, + "imiz": 2228, + "üzel": 2229, + "lerini": 2230, + "sağ": 2231, + "gerek": 2232, + "yirmi": 2233, + "şek": 2234, + "bağ": 2235, + "lara": 2236, + "yür": 2237, + "ması": 2238, + "katı": 2239, + "dedi": 2240, + "gü": 2241, + "sorun": 2242, + "üne": 2243, + "mız": 2244, + "yapı": 2245, + "mil": 2246, + "ğını": 2247, + "tara": 2248, + "vardı": 2249, + "konuş": 2250, + "arak": 2251, + "larak": 2252, + "çocuk": 2253, + "bütün": 2254, + "ley": 2255, + "dür": 2256, + "güzel": 2257, + "ayı": 2258, + "yapa": 2259, + "nı": 2260, + "ayr": 2261, + "öne": 2262, + "yordum": 2263, + "ban": 2264, + "i̇ş": 2265, + "dum": 2266, + "yorlar": 2267, + "larını": 2268, + "çıkar": 2269, + "zan": 2270, + "seç": 2271, + "liyor": 2272, + "tak": 2273, + "şık": 2274, + "tekrar": 2275, + "aş": 2276, + "eş": 2277, + "mişti": 2278, + "kin": 2279, + "imi": 2280, + "eğ": 2281, + "gidi": 2282, + "leş": 2283, + "başladı": 2284, + "gide": 2285, + "otur": 2286, + "dde": 2287, + "ından": 2288, + "üzer": 2289, + "ının": 2290, + "nız": 2291, + "uy": 2292, + "yedi": 2293, + "kat": 2294, + "olarak": 2295, + "ladı": 2296, + "yalnız": 2297, + "bah": 2298, + "iyet": 2299, + "sak": 2300, + "açık": 2301, + "sında": 2302, + "...": 2303, + "insan": 2304, + "aynı": 2305, + "eder": 2306, + "istan": 2307, + "uzun": 2308, + "geri": 2309, + "erek": 2310, + "olan": 2311, + "gerçek": 2312, + "alan": 2313, + "dış": 2314, + "alık": 2315, + "fark": 2316, + "üst": 2317, + "sade": 2318, + "kiş": 2319, + "ldı": 2320, + "zor": 2321, + "etir": 2322, + "herkes": 2323, + "ömer": 2324, + "unda": 2325, + "haf": 2326, + "buna": 2327, + "ydı": 2328, + "peki": 2329, + "adam": 2330, + "haz": 2331, + "sına": 2332, + "kapı": 2333, + "görüş": 2334, + "sadece": 2335, + "aldı": 2336, + "geldi": 2337, + "rz": 2338, + "sz": 2339, + "cz": 2340, + "ię": 2341, + "dz": 2342, + "ał": 2343, + "się": 2344, + "rze": 2345, + "że": 2346, + "wy": 2347, + "rzy": 2348, + "ła": 2349, + "ło": 2350, + "ny": 2351, + "dzie": 2352, + "dzi": 2353, + "czy": 2354, + 
"cie": 2355, + "prze": 2356, + "dy": 2357, + "kie": 2358, + "ry": 2359, + "ją": 2360, + "ów": 2361, + "przy": 2362, + "mie": 2363, + "szy": 2364, + "cze": 2365, + "bie": 2366, + "cy": 2367, + "nia": 2368, + "ści": 2369, + "sze": 2370, + "jest": 2371, + "ży": 2372, + "ną": 2373, + "któ": 2374, + "ała": 2375, + "mnie": 2376, + "ły": 2377, + "cza": 2378, + "jak": 2379, + "roz": 2380, + "ró": 2381, + "zna": 2382, + "łu": 2383, + "ść": 2384, + "wia": 2385, + "wszy": 2386, + "spo": 2387, + "gdy": 2388, + "wał": 2389, + "wię": 2390, + "łem": 2391, + "ję": 2392, + "sk": 2393, + "rę": 2394, + "dob": 2395, + "już": 2396, + "bę": 2397, + "ałem": 2398, + "sza": 2399, + "pod": 2400, + "dla": 2401, + "pan": 2402, + "nę": 2403, + "może": 2404, + "śli": 2405, + "ało": 2406, + "lko": 2407, + "nych": 2408, + "powie": 2409, + "cię": 2410, + "tylko": 2411, + "naj": 2412, + "tego": 2413, + "ski": 2414, + "nego": 2415, + "wszyst": 2416, + "szcze": 2417, + "jed": 2418, + "jej": 2419, + "two": 2420, + "ąd": 2421, + "śmy": 2422, + "czę": 2423, + "wać": 2424, + "jego": 2425, + "ża": 2426, + "sy": 2427, + "praw": 2428, + "tym": 2429, + "który": 2430, + "ały": 2431, + "trze": 2432, + "niej": 2433, + "nym": 2434, + "gło": 2435, + "jąc": 2436, + "mówi": 2437, + "ska": 2438, + "nej": 2439, + "słu": 2440, + "wła": 2441, + "będzie": 2442, + "dę": 2443, + "pó": 2444, + "bez": 2445, + "nic": 2446, + "pła": 2447, + "ście": 2448, + "są": 2449, + "trzy": 2450, + "kiem": 2451, + "był": 2452, + "mog": 2453, + "robi": 2454, + "tam": 2455, + "mię": 2456, + "zy": 2457, + "pew": 2458, + "myś": 2459, + "przed": 2460, + "sko": 2461, + "które": 2462, + "lę": 2463, + "wsze": 2464, + "ąc": 2465, + "było": 2466, + "sobie": 2467, + "py": 2468, + "cią": 2469, + "jeszcze": 2470, + "tę": 2471, + "czas": 2472, + "szę": 2473, + "gł": 2474, + "kę": 2475, + "czu": 2476, + "przez": 2477, + "sło": 2478, + "wz": 2479, + "kto": 2480, + "ków": 2481, + "czo": 2482, + "liśmy": 2483, + "więc": 2484, + "rą": 2485, + "wó": 2486, + "rza": 2487, + "ności": 2488, + "wet": 2489, + "nął": 2490, + "śmie": 2491, + "nawet": 2492, + "musi": 2493, + "swo": 2494, + "tej": 2495, + "wą": 2496, + "wu": 2497, + "wią": 2498, + "niu": 2499, + "czą": 2500, + "dzo": 2501, + "skie": 2502, + "jeśli": 2503, + "czego": 2504, + "chy": 2505, + "dł": 2506, + "tych": 2507, + "bym": 2508, + "żo": 2509, + "eś": 2510, + "sią": 2511, + "kiedy": 2512, + "wró": 2513, + "dze": 2514, + "dro": 2515, + "rów": 2516, + "pani": 2517, + "kul": 2518, + "nad": 2519, + "chwi": 2520, + "nim": 2521, + "być": 2522, + "chodzi": 2523, + "nio": 2524, + "dobrze": 2525, + "teraz": 2526, + "wokul": 2527, + "coś": 2528, + "kł": 2529, + "pier": 2530, + "gdzie": 2531, + "dzy": 2532, + "pię": 2533, + "dź": 2534, + "ką": 2535, + "gó": 2536, + "zda": 2537, + "chce": 2538, + "stę": 2539, + "świa": 2540, + "wszystko": 2541, + "peł": 2542, + "wiem": 2543, + "wiel": 2544, + "każ": 2545, + "rzu": 2546, + "sły": 2547, + "jedna": 2548, + "myśl": 2549, + "mój": 2550, + "jestem": 2551, + "óż": 2552, + "miej": 2553, + "moż": 2554, + "kła": 2555, + "resz": 2556, + "dłu": 2557, + "stwo": 2558, + "nię": 2559, + "masz": 2560, + "żeby": 2561, + "niem": 2562, + "jakie": 2563, + "sty": 2564, + "nią": 2565, + "wej": 2566, + "oj": 2567, + "sła": 2568, + "ność": 2569, + "zło": 2570, + "szczę": 2571, + "lej": 2572, + "wego": 2573, + "cał": 2574, + "dział": 2575, + "kich": 2576, + "dza": 2577, + "dzię": 2578, + "oczy": 2579, + "zosta": 2580, + "czło": 2581, + "nam": 2582, + "kil": 2583, + "szu": 2584, + "wę": 2585, + "miał": 2586, + 
"strze": 2587, + "cej": 2588, + "ej": 2589, + "znaj": 2590, + "dać": 2591, + "miejs": 2592, + "kró": 2593, + "kry": 2594, + "bardzo": 2595, + "śnie": 2596, + "lą": 2597, + "gie": 2598, + "ciebie": 2599, + "dni": 2600, + "potrze": 2601, + "wokulski": 2602, + "uwa": 2603, + "umie": 2604, + "jednak": 2605, + "kra": 2606, + "wróci": 2607, + "człowie": 2608, + "czyć": 2609, + "była": 2610, + "żeli": 2611, + "mę": 2612, + "cę": 2613, + "zrobi": 2614, + "mogę": 2615, + "prowa": 2616, + "rem": 2617, + "niech": 2618, + "cznie": 2619, + "kro": 2620, + "tą": 2621, + "chci": 2622, + "bro": 2623, + "dzieć": 2624, + "szą": 2625, + "pad": 2626, + "trz": 2627, + "jem": 2628, + "tów": 2629, + "dru": 2630, + "taj": 2631, + "rzekł": 2632, + "niego": 2633, + "takie": 2634, + "wała": 2635, + "towa": 2636, + "kapła": 2637, + "widzi": 2638, + "podob": 2639, + "dzę": 2640, + "tał": 2641, + "stęp": 2642, + "bą": 2643, + "poko": 2644, + "wem": 2645, + "gę": 2646, + "aby": 2647, + "albo": 2648, + "spra": 2649, + "zno": 2650, + "smo": 2651, + "jesz": 2652, + "księ": 2653, + "jesteś": 2654, + "poz": 2655, + "nigdy": 2656, + "ksią": 2657, + "cóż": 2658, + "ws": 2659, + "pow": 2660, + "tka": 2661, + "świe": 2662, + "szka": 2663, + "samo": 2664, + "sł": 2665, + "rzę": 2666, + "nale": 2667, + "chcesz": 2668, + "nik": 2669, + "pę": 2670, + "chyba": 2671, + "ciąg": 2672, + "jący": 2673, + "woj": 2674, + "nasze": 2675, + "mniej": 2676, + "więcej": 2677, + "zwy": 2678, + "osta": 2679, + "waż": 2680, + "śmier": 2681, + "wier": 2682, + "dzą": 2683, + "zaś": 2684, + "gdyby": 2685, + "jaki": 2686, + "wol": 2687, + "win": 2688, + "dą": 2689, + "ścia": 2690, + "rozma": 2691, + "wal": 2692, + "panie": 2693, + "star": 2694, + "kaz": 2695, + "jeżeli": 2696, + "wra": 2697, + "koń": 2698, + "siebie": 2699, + "znowu": 2700, + "czem": 2701, + "stwa": 2702, + "isto": 2703, + "pół": 2704, + "dał": 2705, + "kobie": 2706, + "ałam": 2707, + "wych": 2708, + "cesa": 2709, + "nich": 2710, + "zawsze": 2711, + "dzić": 2712, + "też": 2713, + "lepie": 2714, + "proszę": 2715, + "kre": 2716, + "twa": 2717, + "łą": 2718, + "chu": 2719, + "cą": 2720, + "prz": 2721, + "łe": 2722, + "szedł": 2723, + "odpowie": 2724, + "myśli": 2725, + "świą": 2726, + "ź": 2727, + "ł": 2728, + "&": 2729, + "=": 2730, + "ă": 2731, + "đ": 2732, + "ţ": 2733, + "–": 2734, + "‘": 2735, + "ij": 2736, + "aa": 2737, + "een": 2738, + "het": 2739, + "aar": 2740, + "oor": 2741, + "ijn": 2742, + "dat": 2743, + "oe": 2744, + "ijk": 2745, + "aan": 2746, + "voor": 2747, + "iet": 2748, + "zijn": 2749, + "niet": 2750, + "oo": 2751, + "moet": 2752, + "heb": 2753, + "uit": 2754, + "wij": 2755, + "aat": 2756, + "lijk": 2757, + "sl": 2758, + "daar": 2759, + "deze": 2760, + "worden": 2761, + "moeten": 2762, + "onder": 2763, + "hebben": 2764, + "ook": 2765, + "ct": 2766, + "nog": 2767, + "aal": 2768, + "eer": 2769, + "bij": 2770, + "mijn": 2771, + "kom": 2772, + "atie": 2773, + "eft": 2774, + "kel": 2775, + "rij": 2776, + "heid": 2777, + "af": 2778, + "stel": 2779, + "maar": 2780, + "wee": 2781, + "heeft": 2782, + "waar": 2783, + "eren": 2784, + "wat": 2785, + "wil": 2786, + "aag": 2787, + "bet": 2788, + "hij": 2789, + "kun": 2790, + "uw": 2791, + "dt": 2792, + "door": 2793, + "tij": 2794, + "ond": 2795, + "geen": 2796, + "gev": 2797, + "veel": 2798, + "naar": 2799, + "aten": 2800, + "kunnen": 2801, + "echt": 2802, + "goe": 2803, + "twee": 2804, + "delijk": 2805, + "uur": 2806, + "toe": 2807, + "meer": 2808, + "onze": 2809, + "tijd": 2810, + "hoe": 2811, + "tot": 2812, + "zou": 2813, + "aak": 
2814, + "amen": 2815, + "woor": 2816, + "wordt": 2817, + "gelijk": 2818, + "gaan": 2819, + "ker": 2820, + "eld": 2821, + "hou": 2822, + "zel": 2823, + "tegen": 2824, + "komen": 2825, + "werk": 2826, + "goed": 2827, + "zal": 2828, + "zij": 2829, + "slag": 2830, + "zien": 2831, + "echter": 2832, + "itie": 2833, + "tie": 2834, + "elijk": 2835, + "ische": 2836, + "belan": 2837, + "haar": 2838, + "vr": 2839, + "grijk": 2840, + "doen": 2841, + "land": 2842, + "belangrijk": 2843, + "open": 2844, + "ctie": 2845, + "zelf": 2846, + "mij": 2847, + "iteit": 2848, + "stem": 2849, + "mee": 2850, + "aren": 2851, + "dien": 2852, + "gaat": 2853, + "prob": 2854, + "moe": 2855, + "ullen": 2856, + "zich": 2857, + "daarom": 2858, + "orm": 2859, + "staat": 2860, + "zit": 2861, + "dui": 2862, + "dus": 2863, + "ds": 2864, + "verslag": 2865, + "kelijk": 2866, + "proble": 2867, + "schap": 2868, + "gd": 2869, + "hun": 2870, + "erd": 2871, + "zet": 2872, + "staan": 2873, + "maal": 2874, + "inder": 2875, + "eid": 2876, + "kken": 2877, + "ged": 2878, + "zullen": 2879, + "mensen": 2880, + "jaar": 2881, + "regel": 2882, + "ieder": 2883, + "volgen": 2884, + "geven": 2885, + "even": 2886, + "blij": 2887, + "ië": 2888, + "uwe": 2889, + "maken": 2890, + "oek": 2891, + "nieuwe": 2892, + "baar": 2893, + "andere": 2894, + "ruik": 2895, + "agen": 2896, + "ouw": 2897, + "willen": 2898, + "aakt": 2899, + "hoo": 2900, + "anden": 2901, + "lig": 2902, + "samen": 2903, + "zeer": 2904, + "duidelijk": 2905, + "antwoor": 2906, + "heel": 2907, + "punt": 2908, + "houden": 2909, + "vraag": 2910, + "gele": 2911, + "eens": 2912, + "besch": 2913, + "omen": 2914, + "erg": 2915, + "doel": 2916, + "dag": 2917, + "uren": 2918, + "ings": 2919, + "oren": 2920, + "delen": 2921, + "steun": 2922, + "innen": 2923, + "pol": 2924, + "oon": 2925, + "sn": 2926, + "zonder": 2927, + "nodig": 2928, + "alleen": 2929, + "mid": 2930, + "ragen": 2931, + "iets": 2932, + "versch": 2933, + "gebruik": 2934, + "rouw": 2935, + "stellen": 2936, + "menten": 2937, + "eerste": 2938, + "laat": 2939, + "groot": 2940, + "ood": 2941, + "toch": 2942, + "laten": 2943, + "aard": 2944, + "sle": 2945, + "deel": 2946, + "plaat": 2947, + "ree": 2948, + "betre": 2949, + "lid": 2950, + "uiten": 2951, + "racht": 2952, + "beleid": 2953, + "stie": 2954, + "staten": 2955, + "ggen": 2956, + "reken": 2957, + "alen": 2958, + "ming": 2959, + "mogelijk": 2960, + "grote": 2961, + "altijd": 2962, + "enkel": 2963, + "wik": 2964, + "politie": 2965, + "elk": 2966, + "handel": 2967, + "kwe": 2968, + "maat": 2969, + "elen": 2970, + "vrij": 2971, + "jes": 2972, + "aam": 2973, + "huis": 2974, + "weer": 2975, + "lidstaten": 2976, + "king": 2977, + "kle": 2978, + "bed": 2979, + "geval": 2980, + "wikkel": 2981, + "kwestie": 2982, + "stee": 2983, + "hel": 2984, + "komst": 2985, + "iden": 2986, + "eerd": 2987, + "tweede": 2988, + "probleem": 2989, + "ussen": 2990, + "snel": 2991, + "tig": 2992, + "ult": 2993, + "nemen": 2994, + "commis": 2995, + "verschil": 2996, + "zoek": 2997, + "krij": 2998, + "graag": 2999, + "denk": 3000, + "landen": 3001, + "reden": 3002, + "besl": 3003, + "oeg": 3004, + "beter": 3005, + "heden": 3006, + "mag": 3007, + "boven": 3008, + "cont": 3009, + "fd": 3010, + "hele": 3011, + "vier": 3012, + "gez": 3013, + "kw": 3014, + "aas": 3015, + "ontwikkel": 3016, + "drie": 3017, + "vaak": 3018, + "plaats": 3019, + "gang": 3020, + "ijf": 3021, + "natuur": 3022, + "tussen": 3023, + "bat": 3024, + "komt": 3025, + "wacht": 3026, + "aad": 3027, + "achter": 3028, + "gebie": 3029, + "verk": 3030, + 
"ligt": 3031, + "nieuw": 3032, + "vand": 3033, + "ý": 3034, + "ď": 3035, + "ě": 3036, + "ř": 3037, + "ť": 3038, + "ů": 3039, + "„": 3040, + "ní": 3041, + "ně": 3042, + "ře": 3043, + "ná": 3044, + "vě": 3045, + "vá": 3046, + "rá": 3047, + "vy": 3048, + "mě": 3049, + "ři": 3050, + "ří": 3051, + "že": 3052, + "jí": 3053, + "vý": 3054, + "ji": 3055, + "dě": 3056, + "če": 3057, + "tě": 3058, + "ky": 3059, + "še": 3060, + "ké": 3061, + "ší": 3062, + "pře": 3063, + "ví": 3064, + "ný": 3065, + "ži": 3066, + "má": 3067, + "cí": 3068, + "zá": 3069, + "ské": 3070, + "dá": 3071, + "byl": 3072, + "tí": 3073, + "pří": 3074, + "při": 3075, + "či": 3076, + "vní": 3077, + "ča": 3078, + "dí": 3079, + "dní": 3080, + "ká": 3081, + "nou": 3082, + "vět": 3083, + "pě": 3084, + "kou": 3085, + "ých": 3086, + "bě": 3087, + "prá": 3088, + "jako": 3089, + "ží": 3090, + "zí": 3091, + "jsou": 3092, + "jsem": 3093, + "lní": 3094, + "cké": 3095, + "vat": 3096, + "před": 3097, + "hla": 3098, + "stá": 3099, + "čí": 3100, + "ši": 3101, + "kla": 3102, + "ště": 3103, + "lou": 3104, + "mů": 3105, + "chá": 3106, + "pů": 3107, + "také": 3108, + "dů": 3109, + "nost": 3110, + "tře": 3111, + "sku": 3112, + "vše": 3113, + "tní": 3114, + "byla": 3115, + "ční": 3116, + "jeho": 3117, + "bý": 3118, + "vání": 3119, + "ných": 3120, + "tři": 3121, + "vz": 3122, + "stře": 3123, + "dva": 3124, + "hle": 3125, + "čá": 3126, + "nosti": 3127, + "vš": 3128, + "hra": 3129, + "jen": 3130, + "slo": 3131, + "však": 3132, + "kdy": 3133, + "bylo": 3134, + "bude": 3135, + "jší": 3136, + "vých": 3137, + "ním": 3138, + "sm": 3139, + "koli": 3140, + "rů": 3141, + "může": 3142, + "není": 3143, + "hod": 3144, + "bí": 3145, + "tý": 3146, + "stě": 3147, + "uje": 3148, + "sá": 3149, + "pět": 3150, + "krá": 3151, + "tom": 3152, + "ství": 3153, + "vně": 3154, + "sed": 3155, + "své": 3156, + "pí": 3157, + "musí": 3158, + "už": 3159, + "tím": 3160, + "jící": 3161, + "jedno": 3162, + "čas": 3163, + "čty": 3164, + "ský": 3165, + "evro": 3166, + "toho": 3167, + "hy": 3168, + "kter": 3169, + "rní": 3170, + "stí": 3171, + "svě": 3172, + "pak": 3173, + "všech": 3174, + "ků": 3175, + "ng": 3176, + "ád": 3177, + "chází": 3178, + "být": 3179, + "první": 3180, + "mno": 3181, + "ského": 3182, + "pá": 3183, + "nebo": 3184, + "kem": 3185, + "sla": 3186, + "ného": 3187, + "zde": 3188, + "další": 3189, + "řa": 3190, + "čtyři": 3191, + "hrá": 3192, + "druh": 3193, + "lně": 3194, + "vla": 3195, + "ských": 3196, + "ško": 3197, + "půso": 3198, + "proto": 3199, + "vů": 3200, + "ská": 3201, + "šest": 3202, + "dně": 3203, + "ještě": 3204, + "mezi": 3205, + "několi": 3206, + "již": 3207, + "čně": 3208, + "slu": 3209, + "zná": 3210, + "sedm": 3211, + "vlá": 3212, + "osm": 3213, + "byly": 3214, + "vám": 3215, + "cký": 3216, + "tech": 3217, + "ději": 3218, + "velmi": 3219, + "leži": 3220, + "vala": 3221, + "lý": 3222, + "tvo": 3223, + "spole": 3224, + "stup": 3225, + "mož": 3226, + "evrop": 3227, + "stal": 3228, + "jde": 3229, + "rodi": 3230, + "její": 3231, + "poli": 3232, + "devět": 3233, + "sme": 3234, + "až": 3235, + "této": 3236, + "tento": 3237, + "kaž": 3238, + "nula": 3239, + "bych": 3240, + "moc": 3241, + "stou": 3242, + "kdo": 3243, + "zd": 3244, + "praco": 3245, + "tomu": 3246, + "ným": 3247, + "živo": 3248, + "zem": 3249, + "násle": 3250, + "sky": 3251, + "jich": 3252, + "měl": 3253, + "děla": 3254, + "jsme": 3255, + "nice": 3256, + "stej": 3257, + "stní": 3258, + "náro": 3259, + "nit": 3260, + "později": 3261, + "tako": 3262, + "nce": 3263, + "čer": 3264, + "ším": 3265, + 
"něco": 3266, + "vál": 3267, + "řej": 3268, + "krát": 3269, + "ální": 3270, + "asi": 3271, + "které": 3272, + "stav": 3273, + "mají": 3274, + "mys": 3275, + "době": 3276, + "sně": 3277, + "zku": 3278, + "tů": 3279, + "chod": 3280, + "spě": 3281, + "jejich": 3282, + "součas": 3283, + "vali": 3284, + "kte": 3285, + "prů": 3286, + "zení": 3287, + "pat": 3288, + "potře": 3289, + "dnes": 3290, + "zemí": 3291, + "znam": 3292, + "mám": 3293, + "tedy": 3294, + "hlavní": 3295, + "použí": 3296, + "bní": 3297, + "vede": 3298, + "lep": 3299, + "jek": 3300, + "prav": 3301, + "politi": 3302, + "dne": 3303, + "čení": 3304, + "než": 3305, + "děl": 3306, + "čo": 3307, + "cích": 3308, + "sté": 3309, + "dlou": 3310, + "několik": 3311, + "vyu": 3312, + "ckých": 3313, + "nové": 3314, + "čin": 3315, + "dělá": 3316, + "ký": 3317, + "obla": 3318, + "podle": 3319, + "důleži": 3320, + "poku": 3321, + "kone": 3322, + "dý": 3323, + "dvě": 3324, + "žád": 3325, + "nout": 3326, + "tku": 3327, + "tvr": 3328, + "ckého": 3329, + "rov": 3330, + "tele": 3331, + "psa": 3332, + "svět": 3333, + "tivní": 3334, + "dosta": 3335, + "šel": 3336, + "druhé": 3337, + "skou": 3338, + "žo": 3339, + "jedná": 3340, + "význam": 3341, + "problé": 3342, + "publi": 3343, + "ván": 3344, + "odpo": 3345, + "podpo": 3346, + "dle": 3347, + "jaké": 3348, + "šení": 3349, + "vím": 3350, + "během": 3351, + "nachází": 3352, + "slou": 3353, + "pouze": 3354, + "otá": 3355, + "plo": 3356, + "tové": 3357, + "větši": 3358, + "komi": 3359, + "vají": 3360, + "tyto": 3361, + "zápa": 3362, + "změ": 3363, + "moh": 3364, + "více": 3365, + "společ": 3366, + "auto": 3367, + "proti": 3368, + "dět": 3369, + "cháze": 3370, + "žel": 3371, + "«": 3372, + "»": 3373, + "а": 3374, + "б": 3375, + "в": 3376, + "г": 3377, + "д": 3378, + "е": 3379, + "ж": 3380, + "з": 3381, + "и": 3382, + "й": 3383, + "к": 3384, + "л": 3385, + "м": 3386, + "н": 3387, + "о": 3388, + "п": 3389, + "р": 3390, + "с": 3391, + "т": 3392, + "у": 3393, + "ф": 3394, + "х": 3395, + "ц": 3396, + "ч": 3397, + "ш": 3398, + "щ": 3399, + "ъ": 3400, + "ы": 3401, + "ь": 3402, + "э": 3403, + "ю": 3404, + "я": 3405, + "ё": 3406, + "‑": 3407, + "−": 3408, + "ст": 3409, + "ен": 3410, + "но": 3411, + "на": 3412, + "пр": 3413, + "то": 3414, + "по": 3415, + "ра": 3416, + "го": 3417, + "ко": 3418, + "не": 3419, + "во": 3420, + "ва": 3421, + "ет": 3422, + "ер": 3423, + "ни": 3424, + "ел": 3425, + "ит": 3426, + "ны": 3427, + "за": 3428, + "ро": 3429, + "ени": 3430, + "ка": 3431, + "ли": 3432, + "ем": 3433, + "да": 3434, + "об": 3435, + "ла": 3436, + "до": 3437, + "ся": 3438, + "ть": 3439, + "от": 3440, + "ло": 3441, + "ль": 3442, + "ед": 3443, + "со": 3444, + "ми": 3445, + "ре": 3446, + "мо": 3447, + "ци": 3448, + "про": 3449, + "та": 3450, + "это": 3451, + "ки": 3452, + "ру": 3453, + "при": 3454, + "ти": 3455, + "се": 3456, + "ста": 3457, + "вы": 3458, + "мы": 3459, + "ви": 3460, + "бы": 3461, + "ма": 3462, + "ес": 3463, + "ля": 3464, + "сти": 3465, + "ле": 3466, + "что": 3467, + "ме": 3468, + "ри": 3469, + "ча": 3470, + "од": 3471, + "ей": 3472, + "ель": 3473, + "ения": 3474, + "га": 3475, + "ну": 3476, + "си": 3477, + "па": 3478, + "раз": 3479, + "бо": 3480, + "сто": 3481, + "су": 3482, + "са": 3483, + "ду": 3484, + "его": 3485, + "ест": 3486, + "ин": 3487, + "ить": 3488, + "из": 3489, + "же": 3490, + "му": 3491, + "пер": 3492, + "под": 3493, + "ение": 3494, + "сь": 3495, + "ку": 3496, + "пред": 3497, + "ного": 3498, + "ных": 3499, + "вер": 3500, + "те": 3501, + "ной": 3502, + "ции": 3503, + "де": 3504, + "ры": 3505, + 
"дел": 3506, + "лю": 3507, + "ве": 3508, + "он": 3509, + "мен": 3510, + "ги": 3511, + "ня": 3512, + "бу": 3513, + "пра": 3514, + "все": 3515, + "ется": 3516, + "сть": 3517, + "жа": 3518, + "дол": 3519, + "жи": 3520, + "бе": 3521, + "кон": 3522, + "сл": 3523, + "ши": 3524, + "ди": 3525, + "ств": 3526, + "ско": 3527, + "ные": 3528, + "чи": 3529, + "ют": 3530, + "дер": 3531, + "стра": 3532, + "ты": 3533, + "ход": 3534, + "щи": 3535, + "зо": 3536, + "зна": 3537, + "ности": 3538, + "чес": 3539, + "вля": 3540, + "вать": 3541, + "ор": 3542, + "пол": 3543, + "вет": 3544, + "так": 3545, + "ша": 3546, + "ту": 3547, + "сво": 3548, + "пре": 3549, + "она": 3550, + "итель": 3551, + "ный": 3552, + "сло": 3553, + "как": 3554, + "вл": 3555, + "ность": 3556, + "хо": 3557, + "мож": 3558, + "пе": 3559, + "для": 3560, + "ния": 3561, + "ное": 3562, + "рас": 3563, + "долж": 3564, + "дар": 3565, + "тель": 3566, + "ска": 3567, + "пу": 3568, + "ство": 3569, + "кото": 3570, + "раб": 3571, + "ее": 3572, + "род": 3573, + "эти": 3574, + "соб": 3575, + "ору": 3576, + "жен": 3577, + "ным": 3578, + "ити": 3579, + "ние": 3580, + "ком": 3581, + "дет": 3582, + "сту": 3583, + "гу": 3584, + "пи": 3585, + "меж": 3586, + "ению": 3587, + "тер": 3588, + "работ": 3589, + "воз": 3590, + "ция": 3591, + "кой": 3592, + "щест": 3593, + "гра": 3594, + "зи": 3595, + "ря": 3596, + "между": 3597, + "ства": 3598, + "вс": 3599, + "ело": 3600, + "ше": 3601, + "мер": 3602, + "ба": 3603, + "зы": 3604, + "лу": 3605, + "аль": 3606, + "дей": 3607, + "гла": 3608, + "народ": 3609, + "кти": 3610, + "предста": 3611, + "лся": 3612, + "явля": 3613, + "ски": 3614, + "нов": 3615, + "един": 3616, + "ров": 3617, + "ис": 3618, + "нима": 3619, + "рем": 3620, + "ходи": 3621, + "также": 3622, + "дру": 3623, + "ать": 3624, + "след": 3625, + "гово": 3626, + "ная": 3627, + "ющи": 3628, + "ень": 3629, + "которы": 3630, + "хот": 3631, + "ву": 3632, + "их": 3633, + "ему": 3634, + "чит": 3635, + "важ": 3636, + "орга": 3637, + "чески": 3638, + "ще": 3639, + "ке": 3640, + "ха": 3641, + "пос": 3642, + "том": 3643, + "боль": 3644, + "мне": 3645, + "пас": 3646, + "объ": 3647, + "прав": 3648, + "конф": 3649, + "слу": 3650, + "поддер": 3651, + "стви": 3652, + "наш": 3653, + "лько": 3654, + "стоя": 3655, + "ную": 3656, + "лем": 3657, + "енных": 3658, + "кра": 3659, + "ды": 3660, + "международ": 3661, + "гда": 3662, + "необ": 3663, + "госу": 3664, + "ству": 3665, + "ении": 3666, + "государ": 3667, + "кто": 3668, + "им": 3669, + "чест": 3670, + "рет": 3671, + "вопро": 3672, + "лен": 3673, + "ели": 3674, + "рова": 3675, + "ций": 3676, + "нам": 3677, + "этой": 3678, + "жения": 3679, + "необходи": 3680, + "меня": 3681, + "было": 3682, + "сили": 3683, + "фи": 3684, + "вя": 3685, + "шь": 3686, + "этого": 3687, + "они": 3688, + "органи": 3689, + "безо": 3690, + "проб": 3691, + "име": 3692, + "реш": 3693, + "би": 3694, + "безопас": 3695, + "ются": 3696, + "оста": 3697, + "енно": 3698, + "год": 3699, + "ела": 3700, + "представ": 3701, + "ться": 3702, + "слово": 3703, + "организа": 3704, + "должны": 3705, + "этом": 3706, + "бла": 3707, + "че": 3708, + "чу": 3709, + "благо": 3710, + "этому": 3711, + "врем": 3712, + "спе": 3713, + "ном": 3714, + "ений": 3715, + "спо": 3716, + "нас": 3717, + "нет": 3718, + "зу": 3719, + "вед": 3720, + "еще": 3721, + "сказа": 3722, + "сей": 3723, + "ерен": 3724, + "дан": 3725, + "сам": 3726, + "еля": 3727, + "ран": 3728, + "зыва": 3729, + "является": 3730, + "будет": 3731, + "ктив": 3732, + "тре": 3733, + "деле": 3734, + "мот": 3735, + "конферен": 3736, + 
"лась": 3737, + "час": 3738, + "сторо": 3739, + "кого": 3740, + "ез": 3741, + "ней": 3742, + "ос": 3743, + "лись": 3744, + "разору": 3745, + "пере": 3746, + "сси": 3747, + "ными": 3748, + "проц": 3749, + "голо": 3750, + "чело": 3751, + "боле": 3752, + "челове": 3753, + "сер": 3754, + "пл": 3755, + "чет": 3756, + "стран": 3757, + "пя": 3758, + "был": 3759, + "кла": 3760, + "тов": 3761, + "жд": 3762, + "дела": 3763, + "ера": 3764, + "уже": 3765, + "совет": 3766, + "ген": 3767, + "безопасности": 3768, + "ца": 3769, + "седа": 3770, + "поз": 3771, + "ответ": 3772, + "проблем": 3773, + "нако": 3774, + "тем": 3775, + "доста": 3776, + "пы": 3777, + "ща": 3778, + "вой": 3779, + "сущест": 3780, + "необходимо": 3781, + "быть": 3782, + "может": 3783, + "дем": 3784, + "чтобы": 3785, + "ек": 3786, + "чер": 3787, + "усили": 3788, + "рес": 3789, + "руд": 3790, + "единенных": 3791, + "доб": 3792, + "дости": 3793, + "ствен": 3794, + "ядер": 3795, + "годня": 3796, + "каза": 3797, + "сегодня": 3798, + "сейчас": 3799, + "только": 3800, + "вод": 3801, + "есь": 3802, + "много": 3803, + "буду": 3804, + "ев": 3805, + "есть": 3806, + "три": 3807, + "общест": 3808, + "явл": 3809, + "высту": 3810, + "ред": 3811, + "счит": 3812, + "сит": 3813, + "делега": 3814, + "лож": 3815, + "этот": 3816, + "фор": 3817, + "клю": 3818, + "возмож": 3819, + "вания": 3820, + "бли": 3821, + "или": 3822, + "вз": 3823, + "наций": 3824, + "ского": 3825, + "приня": 3826, + "пла": 3827, + "оч": 3828, + "иться": 3829, + "сте": 3830, + "наши": 3831, + "которые": 3832, + "ар": 3833, + "имеет": 3834, + "сот": 3835, + "знач": 3836, + "перь": 3837, + "следу": 3838, + "ены": 3839, + "таки": 3840, + "объединенных": 3841, + "стро": 3842, + "теперь": 3843, + "бле": 3844, + "благодар": 3845, + "разв": 3846, + "ан": 3847, + "жива": 3848, + "очень": 3849, + "ят": 3850, + "без": 3851, + "обес": 3852, + "гро": 3853, + "лось": 3854, + "сы": 3855, + "организации": 3856, + "член": 3857, + "того": 3858, + "ональ": 3859, + "жда": 3860, + "всех": 3861, + "свя": 3862, + "более": 3863, + "сов": 3864, + "когда": 3865, + "вот": 3866, + "кре": 3867, + "кры": 3868, + "поэтому": 3869, + "воль": 3870, + "ой": 3871, + "генера": 3872, + "чем": 3873, + "лы": 3874, + "полити": 3875, + "вен": 3876, + "конференции": 3877, + "процес": 3878, + "бя": 3879, + "ите": 3880, + "отно": 3881, + "развити": 3882, + "аф": 3883, + "ющ": 3884, + "вно": 3885, + "мир": 3886, + "нии": 3887, + "кая": 3888, + "ас": 3889, + "ительно": 3890, + "вто": 3891, + "ением": 3892, + "генераль": 3893, + "прот": 3894, + "всем": 3895, + "самбле": 3896, + "ассамбле": 3897, + "ом": 3898, + "зд": 3899, + "смот": 3900, + "реги": 3901, + "чего": 3902, + "однако": 3903, + "усилия": 3904, + "действи": 3905, + "чно": 3906, + "уча": 3907, + "образ": 3908, + "вос": 3909, + "эта": 3910, + "перего": 3911, + "говор": 3912, + "вам": 3913, + "моло": 3914, + "время": 3915, + "дь": 3916, + "хотел": 3917, + "гру": 3918, + "заявл": 3919, + "предоста": 3920, + "поль": 3921, + "нее": 3922, + "резо": 3923, + "перегово": 3924, + "резолю": 3925, + "крет": 3926, + "поддерж": 3927, + "обеспе": 3928, + "него": 3929, + "представит": 3930, + "наде": 3931, + "кри": 3932, + "чь": 3933, + "проек": 3934, + "лет": 3935, + "други": 3936, + "_": 3937, + "،": 3938, + "؛": 3939, + "؟": 3940, + "ء": 3941, + "آ": 3942, + "أ": 3943, + "ؤ": 3944, + "إ": 3945, + "ئ": 3946, + "ا": 3947, + "ب": 3948, + "ة": 3949, + "ت": 3950, + "ث": 3951, + "ج": 3952, + "ح": 3953, + "خ": 3954, + "د": 3955, + "ذ": 3956, + "ر": 3957, + "ز": 3958, + "س": 3959, + "ش": 
3960, + "ص": 3961, + "ض": 3962, + "ط": 3963, + "ظ": 3964, + "ع": 3965, + "غ": 3966, + "ـ": 3967, + "ف": 3968, + "ق": 3969, + "ك": 3970, + "ل": 3971, + "م": 3972, + "ن": 3973, + "ه": 3974, + "و": 3975, + "ى": 3976, + "ي": 3977, + "ً": 3978, + "ٌ": 3979, + "ٍ": 3980, + "َ": 3981, + "ُ": 3982, + "ِ": 3983, + "ّ": 3984, + "ْ": 3985, + "ٰ": 3986, + "چ": 3987, + "ڨ": 3988, + "ک": 3989, + "ھ": 3990, + "ی": 3991, + "ۖ": 3992, + "ۗ": 3993, + "ۘ": 3994, + "ۚ": 3995, + "ۛ": 3996, + "—": 3997, + "☭": 3998, + "ﺃ": 3999, + "ﻻ": 4000, + "ال": 4001, + "َا": 4002, + "وَ": 4003, + "َّ": 4004, + "ِي": 4005, + "أَ": 4006, + "لَ": 4007, + "نَ": 4008, + "الْ": 4009, + "هُ": 4010, + "ُو": 4011, + "ما": 4012, + "نْ": 4013, + "من": 4014, + "عَ": 4015, + "نا": 4016, + "لا": 4017, + "مَ": 4018, + "تَ": 4019, + "فَ": 4020, + "أن": 4021, + "لي": 4022, + "مِ": 4023, + "ان": 4024, + "في": 4025, + "رَ": 4026, + "يَ": 4027, + "هِ": 4028, + "مْ": 4029, + "قَ": 4030, + "بِ": 4031, + "لى": 4032, + "ين": 4033, + "إِ": 4034, + "لِ": 4035, + "وا": 4036, + "كَ": 4037, + "ها": 4038, + "ًا": 4039, + "مُ": 4040, + "ون": 4041, + "الم": 4042, + "بَ": 4043, + "يا": 4044, + "ذا": 4045, + "سا": 4046, + "الل": 4047, + "مي": 4048, + "يْ": 4049, + "را": 4050, + "ري": 4051, + "لك": 4052, + "مَا": 4053, + "نَّ": 4054, + "لم": 4055, + "إن": 4056, + "ست": 4057, + "وم": 4058, + "َّا": 4059, + "لَا": 4060, + "هم": 4061, + "ِّ": 4062, + "كُ": 4063, + "كان": 4064, + "سَ": 4065, + "با": 4066, + "دي": 4067, + "حَ": 4068, + "عْ": 4069, + "بي": 4070, + "الأ": 4071, + "ول": 4072, + "فِي": 4073, + "رِ": 4074, + "دا": 4075, + "مِنْ": 4076, + "ُونَ": 4077, + "وْ": 4078, + "هَا": 4079, + "ُّ": 4080, + "الس": 4081, + "الَ": 4082, + "ني": 4083, + "لْ": 4084, + "تُ": 4085, + "هل": 4086, + "رة": 4087, + "دَ": 4088, + "سْ": 4089, + "تِ": 4090, + "نَا": 4091, + "رْ": 4092, + "اللَّ": 4093, + "سامي": 4094, + "كن": 4095, + "كل": 4096, + "هَ": 4097, + "عَلَ": 4098, + "على": 4099, + "مع": 4100, + "إلى": 4101, + "قد": 4102, + "الر": 4103, + "ُوا": 4104, + "ير": 4105, + "عن": 4106, + "يُ": 4107, + "نِ": 4108, + "بْ": 4109, + "الح": 4110, + "هُمْ": 4111, + "قا": 4112, + "ذه": 4113, + "الت": 4114, + "ِينَ": 4115, + "جَ": 4116, + "هذا": 4117, + "عد": 4118, + "الع": 4119, + "دْ": 4120, + "قَالَ": 4121, + "رُ": 4122, + "يم": 4123, + "ية": 4124, + "نُ": 4125, + "خَ": 4126, + "رب": 4127, + "الك": 4128, + "وَا": 4129, + "أنا": 4130, + "ةِ": 4131, + "الن": 4132, + "حد": 4133, + "عِ": 4134, + "تا": 4135, + "هو": 4136, + "فا": 4137, + "عا": 4138, + "الش": 4139, + "لُ": 4140, + "يت": 4141, + "ذَا": 4142, + "يع": 4143, + "الذ": 4144, + "حْ": 4145, + "الص": 4146, + "إِنَّ": 4147, + "جا": 4148, + "علي": 4149, + "كَا": 4150, + "بُ": 4151, + "تع": 4152, + "وق": 4153, + "مل": 4154, + "لَّ": 4155, + "يد": 4156, + "أخ": 4157, + "رف": 4158, + "تي": 4159, + "الِ": 4160, + "ّا": 4161, + "ذلك": 4162, + "أَنْ": 4163, + "سِ": 4164, + "توم": 4165, + "مر": 4166, + "مَنْ": 4167, + "بل": 4168, + "الق": 4169, + "الله": 4170, + "ِيَ": 4171, + "كم": 4172, + "ذَ": 4173, + "عل": 4174, + "حب": 4175, + "سي": 4176, + "عُ": 4177, + "الج": 4178, + "الد": 4179, + "شَ": 4180, + "تك": 4181, + "فْ": 4182, + "صَ": 4183, + "لل": 4184, + "دِ": 4185, + "بر": 4186, + "فِ": 4187, + "ته": 4188, + "أع": 4189, + "تْ": 4190, + "قْ": 4191, + "الْأَ": 4192, + "ئِ": 4193, + "عَنْ": 4194, + "ور": 4195, + "حا": 4196, + "الَّ": 4197, + "مت": 4198, + "فر": 4199, + "دُ": 4200, + "هنا": 4201, + "وَأَ": 4202, + "تب": 4203, + "ةُ": 4204, + "أي": 4205, + "سب": 4206, + "ريد": 4207, + "وج": 4208, + "كُمْ": 4209, + "حِ": 4210, + "كْ": 
4211, + "در": 4212, + "َاء": 4213, + "هذه": 4214, + "الط": 4215, + "الْمُ": 4216, + "دة": 4217, + "قل": 4218, + "غَ": 4219, + "يوم": 4220, + "الَّذ": 4221, + "كر": 4222, + "تر": 4223, + "كِ": 4224, + "كي": 4225, + "عَلَى": 4226, + "رَب": 4227, + "عة": 4228, + "قُ": 4229, + "جْ": 4230, + "فض": 4231, + "لة": 4232, + "هْ": 4233, + "رَا": 4234, + "وَلَ": 4235, + "الْمَ": 4236, + "أَنَّ": 4237, + "يَا": 4238, + "أُ": 4239, + "شي": 4240, + "اللَّهُ": 4241, + "لَى": 4242, + "قِ": 4243, + "أت": 4244, + "عَلَيْ": 4245, + "اللَّهِ": 4246, + "الب": 4247, + "ضَ": 4248, + "ةً": 4249, + "قي": 4250, + "ار": 4251, + "بد": 4252, + "خْ": 4253, + "سْتَ": 4254, + "طَ": 4255, + "قَدْ": 4256, + "ذهب": 4257, + "أم": 4258, + "ماذا": 4259, + "وَإِ": 4260, + "ةٌ": 4261, + "ونَ": 4262, + "ليلى": 4263, + "ولا": 4264, + "حُ": 4265, + "هي": 4266, + "صل": 4267, + "الخ": 4268, + "ود": 4269, + "ليس": 4270, + "لدي": 4271, + "قال": 4272, + "كَانَ": 4273, + "مَّ": 4274, + "حي": 4275, + "تم": 4276, + "لن": 4277, + "وَلَا": 4278, + "بع": 4279, + "يمكن": 4280, + "سُ": 4281, + "ةَ": 4282, + "حت": 4283, + "رًا": 4284, + "كا": 4285, + "شا": 4286, + "هِمْ": 4287, + "لَهُ": 4288, + "زَ": 4289, + "داً": 4290, + "مس": 4291, + "كث": 4292, + "الْعَ": 4293, + "جِ": 4294, + "صْ": 4295, + "فَا": 4296, + "له": 4297, + "وي": 4298, + "عَا": 4299, + "هُوَ": 4300, + "بِي": 4301, + "بَا": 4302, + "أس": 4303, + "ثَ": 4304, + "لِي": 4305, + "رض": 4306, + "الرَّ": 4307, + "لِكَ": 4308, + "تَّ": 4309, + "فُ": 4310, + "قة": 4311, + "فعل": 4312, + "مِن": 4313, + "الآ": 4314, + "ثُ": 4315, + "سم": 4316, + "مَّا": 4317, + "بِهِ": 4318, + "تق": 4319, + "خر": 4320, + "لقد": 4321, + "خل": 4322, + "شر": 4323, + "أنت": 4324, + "لَّا": 4325, + "سن": 4326, + "السَّ": 4327, + "الذي": 4328, + "سَا": 4329, + "وما": 4330, + "زل": 4331, + "وب": 4332, + "أْ": 4333, + "إذا": 4334, + "رِي": 4335, + "حة": 4336, + "نِي": 4337, + "الْحَ": 4338, + "وَقَالَ": 4339, + "به": 4340, + "ةٍ": 4341, + "سأ": 4342, + "رٌ": 4343, + "بال": 4344, + "مة": 4345, + "شْ": 4346, + "وت": 4347, + "عند": 4348, + "فس": 4349, + "بَعْ": 4350, + "هر": 4351, + "قط": 4352, + "أح": 4353, + "إنه": 4354, + "وع": 4355, + "فت": 4356, + "غا": 4357, + "هناك": 4358, + "بت": 4359, + "مِنَ": 4360, + "سر": 4361, + "ذَلِكَ": 4362, + "رس": 4363, + "حدث": 4364, + "غْ": 4365, + "ِّي": 4366, + "الإ": 4367, + "وَيَ": 4368, + "جل": 4369, + "است": 4370, + "قِي": 4371, + "عب": 4372, + "وس": 4373, + "يش": 4374, + "الَّذِينَ": 4375, + "تاب": 4376, + "دِي": 4377, + "جب": 4378, + "كون": 4379, + "بن": 4380, + "الث": 4381, + "لَيْ": 4382, + "بعد": 4383, + "وَالْ": 4384, + "فَأَ": 4385, + "عم": 4386, + "هُم": 4387, + "تن": 4388, + "ذْ": 4389, + "أص": 4390, + "أين": 4391, + "رَبِّ": 4392, + "الذين": 4393, + "إِن": 4394, + "بين": 4395, + "جُ": 4396, + "عَلَيْهِ": 4397, + "حَا": 4398, + "لو": 4399, + "ستط": 4400, + "ظر": 4401, + "لَمْ": 4402, + "ءِ": 4403, + "كُل": 4404, + "طل": 4405, + "تَا": 4406, + "ضُ": 4407, + "كنت": 4408, + "لًا": 4409, + "مٌ": 4410, + "قبل": 4411, + "ــ": 4412, + "ذِ": 4413, + "قَوْ": 4414, + "صِ": 4415, + "مًا": 4416, + "كانت": 4417, + "صا": 4418, + "يق": 4419, + "الف": 4420, + "النا": 4421, + "مٍ": 4422, + "إِنْ": 4423, + "النَّ": 4424, + "جد": 4425, + "وَمَا": 4426, + "تت": 4427, + "بح": 4428, + "مكان": 4429, + "كيف": 4430, + "ّة": 4431, + "الا": 4432, + "جَا": 4433, + "أو": 4434, + "ساعد": 4435, + "ضِ": 4436, + "إلا": 4437, + "راً": 4438, + "قَا": 4439, + "رأ": 4440, + "عت": 4441, + "أحد": 4442, + "هد": 4443, + "ضا": 4444, + "طر": 4445, + "أق": 4446, + "ماء": 4447, + "دَّ": 4448, + "البا": 4449, + 
"مُو": 4450, + "أَوْ": 4451, + "طا": 4452, + "قُو": 4453, + "خِ": 4454, + "تل": 4455, + "ستطيع": 4456, + "دَا": 4457, + "النَّا": 4458, + "إلَى": 4459, + "وَتَ": 4460, + "هَذَا": 4461, + "بة": 4462, + "عليك": 4463, + "جر": 4464, + "المن": 4465, + "زا": 4466, + "رٍ": 4467, + "دع": 4468, + "ًّا": 4469, + "سة": 4470, + "ثُمَّ": 4471, + "شيء": 4472, + "الغ": 4473, + "تح": 4474, + "رُونَ": 4475, + "اليوم": 4476, + "مِي": 4477, + "نُوا": 4478, + "أر": 4479, + "تُمْ": 4480, + "عر": 4481, + "يف": 4482, + "أب": 4483, + "دًا": 4484, + "صَا": 4485, + "التَّ": 4486, + "أريد": 4487, + "الز": 4488, + "يَوْ": 4489, + "إلي": 4490, + "جي": 4491, + "يَعْ": 4492, + "فضل": 4493, + "الإن": 4494, + "أنه": 4495, + "1": 4496, + "2": 4497, + "3": 4498, + "4": 4499, + "5": 4500, + "·": 4501, + "×": 4502, + "̃": 4503, + "̌": 4504, + "ε": 4505, + "λ": 4506, + "μ": 4507, + "•": 4508, + "‧": 4509, + "─": 4510, + "□": 4511, + "、": 4512, + "。": 4513, + "〈": 4514, + "〉": 4515, + "《": 4516, + "》": 4517, + "「": 4518, + "」": 4519, + "『": 4520, + "』": 4521, + "ア": 4522, + "オ": 4523, + "カ": 4524, + "チ": 4525, + "ド": 4526, + "ベ": 4527, + "ャ": 4528, + "ヤ": 4529, + "ン": 4530, + "・": 4531, + "ー": 4532, + "ㄟ": 4533, + "!": 4534, + "(": 4535, + ")": 4536, + ",": 4537, + "-": 4538, + "/": 4539, + ":": 4540, + ";": 4541, + "?": 4542, + "p": 4543, + "i4": 4544, + "zh": 4545, + "i2": 4546, + "ng1": 4547, + "u4": 4548, + "i1": 4549, + "ng2": 4550, + "u3": 4551, + "de5": 4552, + "e4": 4553, + "i3": 4554, + "ng4": 4555, + "an4": 4556, + "shi4": 4557, + "an2": 4558, + "u2": 4559, + "u1": 4560, + "ng3": 4561, + "a1": 4562, + "an1": 4563, + "e2": 4564, + "a4": 4565, + "ei4": 4566, + "ong1": 4567, + "ai4": 4568, + "ao4": 4569, + "ang1": 4570, + "an3": 4571, + "wei4": 4572, + "uo2": 4573, + "n1": 4574, + "en2": 4575, + "ao3": 4576, + "e1": 4577, + "qi": 4578, + "eng2": 4579, + "zho": 4580, + "ang3": 4581, + "ang4": 4582, + "ang2": 4583, + "uo4": 4584, + "ge4": 4585, + "yi1": 4586, + "guo2": 4587, + "a3": 4588, + "he2": 4589, + "e3": 4590, + "yi2": 4591, + "di4": 4592, + "zhong1": 4593, + "bu4": 4594, + "ai2": 4595, + "n2": 4596, + "zai4": 4597, + "shi2": 4598, + "eng1": 4599, + "ren2": 4600, + "ong2": 4601, + "xian4": 4602, + "xu": 4603, + "n4": 4604, + "li4": 4605, + "en4": 4606, + "yu2": 4607, + "ei2": 4608, + "yi2ge4": 4609, + "ou4": 4610, + "ei3": 4611, + "ui4": 4612, + "a2": 4613, + "you3": 4614, + "ao1": 4615, + "da4": 4616, + "cheng2": 4617, + "en1": 4618, + "eng4": 4619, + "yi4": 4620, + "si1": 4621, + "zhi4": 4622, + "jia1": 4623, + "yuan2": 4624, + "ta1": 4625, + "de5yi2ge4": 4626, + "ke1": 4627, + "shu3": 4628, + "xi1": 4629, + "ji2": 4630, + "ao2": 4631, + "ou3": 4632, + "ong4": 4633, + "xia4": 4634, + "ai1": 4635, + "gong1": 4636, + "zhi1": 4637, + "en3": 4638, + "wei2": 4639, + "xue2": 4640, + "qu1": 4641, + "zhou1": 4642, + "er3": 4643, + "ming2": 4644, + "zhong3": 4645, + "li3": 4646, + "wu4": 4647, + "yi3": 4648, + "uo1": 4649, + "e5": 4650, + "ji4": 4651, + "xing2": 4652, + "jian4": 4653, + "hua4": 4654, + "yu3": 4655, + "uo3": 4656, + "ji1": 4657, + "ai3": 4658, + "zuo4": 4659, + "hou4": 4660, + "hui4": 4661, + "ei1": 4662, + "nian2": 4663, + "qi2": 4664, + "dao4": 4665, + "sheng1": 4666, + "de2": 4667, + "dai4": 4668, + "uan2": 4669, + "zhe4": 4670, + "zheng4": 4671, + "ben3": 4672, + "shang4": 4673, + "zhu3": 4674, + "bei4": 4675, + "ye4": 4676, + "chu1": 4677, + "zhan4": 4678, + "le5": 4679, + "lai2": 4680, + "shi3": 4681, + "nan2": 4682, + "ren4": 4683, + "you2": 4684, + "ke4": 4685, + "ba1": 4686, + "fu4": 4687, + 
"dui4": 4688, + "ya4": 4689, + "mei3": 4690, + "zi4": 4691, + "xin1": 4692, + "jing1": 4693, + "zhu": 4694, + "n3": 4695, + "yong4": 4696, + "mu4": 4697, + "jiao4": 4698, + "ye3": 4699, + "jin4": 4700, + "bian4": 4701, + "lu4": 4702, + "qi1": 4703, + "she4": 4704, + "xiang1": 4705, + "ong3": 4706, + "shu4": 4707, + "dong4": 4708, + "suo3": 4709, + "guan1": 4710, + "san1": 4711, + "te4": 4712, + "duo1": 4713, + "fu2": 4714, + "min2": 4715, + "la1": 4716, + "zhi2": 4717, + "zhen4": 4718, + "ou1": 4719, + "wu3": 4720, + "ma3": 4721, + "i5": 4722, + "zi5": 4723, + "ju4": 4724, + "er4": 4725, + "yao4": 4726, + "xia4de5yi2ge4": 4727, + "si4": 4728, + "tu2": 4729, + "shan1": 4730, + "zui4": 4731, + "yin1": 4732, + "er2": 4733, + "tong2": 4734, + "dong1": 4735, + "yu4": 4736, + "yan2": 4737, + "qian2": 4738, + "shu3xia4de5yi2ge4": 4739, + "jun1": 4740, + "ke3": 4741, + "wen2": 4742, + "fa3": 4743, + "luo2": 4744, + "zhu4": 4745, + "xi4": 4746, + "kou3": 4747, + "bei3": 4748, + "jian1": 4749, + "fa1": 4750, + "dian4": 4751, + "jiang1": 4752, + "wei4yu2": 4753, + "xiang4": 4754, + "zhi3": 4755, + "eng3": 4756, + "fang1": 4757, + "lan2": 4758, + "shu": 4759, + "ri4": 4760, + "lian2": 4761, + "shou3": 4762, + "qiu2": 4763, + "jin1": 4764, + "huo4": 4765, + "shu3xia4de5yi2ge4zhong3": 4766, + "fen1": 4767, + "nei4": 4768, + "gai1": 4769, + "mei3guo2": 4770, + "un2": 4771, + "ge2": 4772, + "bao3": 4773, + "qing1": 4774, + "gao1": 4775, + "tai2": 4776, + "xiao3": 4777, + "jie2": 4778, + "tian1": 4779, + "chang2": 4780, + "quan2": 4781, + "lie4": 4782, + "hai3": 4783, + "fei1": 4784, + "ti3": 4785, + "jue2": 4786, + "ou2": 4787, + "ci3": 4788, + "zu2": 4789, + "ni2": 4790, + "biao3": 4791, + "zhong1guo2": 4792, + "du4": 4793, + "yue4": 4794, + "xing4": 4795, + "sheng4": 4796, + "che1": 4797, + "dan1": 4798, + "jie1": 4799, + "lin2": 4800, + "ping2": 4801, + "fu3": 4802, + "gu3": 4803, + "jie4": 4804, + "v3": 4805, + "sheng3": 4806, + "na4": 4807, + "yuan4": 4808, + "zhang3": 4809, + "guan3": 4810, + "dao3": 4811, + "zu3": 4812, + "ding4": 4813, + "dian3": 4814, + "ceng2": 4815, + "ren2kou3": 4816, + "tai4": 4817, + "tong1": 4818, + "guo4": 4819, + "neng2": 4820, + "chang3": 4821, + "hua2": 4822, + "liu2": 4823, + "ying1": 4824, + "xiao4": 4825, + "ci4": 4826, + "bian4hua4": 4827, + "liang3": 4828, + "gong4": 4829, + "zhong4": 4830, + "de5yi1": 4831, + "se4": 4832, + "kai1": 4833, + "wang2": 4834, + "jiu4": 4835, + "shi1": 4836, + "shou4": 4837, + "mei2": 4838, + "feng1": 4839, + "ze2": 4840, + "tu2shi4": 4841, + "ti2": 4842, + "qi4": 4843, + "jiu3": 4844, + "shen1": 4845, + "zhe3": 4846, + "ren2kou3bian4hua4": 4847, + "ren2kou3bian4hua4tu2shi4": 4848, + "di4qu1": 4849, + "yang2": 4850, + "men5": 4851, + "long2": 4852, + "bing4": 4853, + "chan3": 4854, + "zhu1": 4855, + "wei3": 4856, + "wai4": 4857, + "xing1": 4858, + "bo1": 4859, + "bi3": 4860, + "tang2": 4861, + "hua1": 4862, + "bo2": 4863, + "shui3": 4864, + "shu1": 4865, + "dou1": 4866, + "sai4": 4867, + "chao2": 4868, + "bi4": 4869, + "ling2": 4870, + "lei4": 4871, + "da4xue2": 4872, + "fen4": 4873, + "shu3de5": 4874, + "mu3": 4875, + "jiao1": 4876, + "dang1": 4877, + "cheng1": 4878, + "tong3": 4879, + "nv3": 4880, + "qi3": 4881, + "yan3": 4882, + "mian4": 4883, + "luo4": 4884, + "jing4": 4885, + "ge1": 4886, + "ru4": 4887, + "dan4": 4888, + "ri4ben3": 4889, + "pu3": 4890, + "yun4": 4891, + "huang2": 4892, + "wo3": 4893, + "lv": 4894, + "hai2": 4895, + "shi4yi1": 4896, + "xie1": 4897, + "ying3": 4898, + "wu2": 4899, + "shen2": 4900, + "wang3": 4901, + 
"guang3": 4902, + "liu4": 4903, + "su4": 4904, + "shi4zhen4": 4905, + "can1": 4906, + "cao3": 4907, + "xia2": 4908, + "ka3": 4909, + "da2": 4910, + "hu4": 4911, + "ban4": 4912, + "dang3": 4913, + "hu2": 4914, + "zong3": 4915, + "deng3": 4916, + "de5yi2ge4shi4zhen4": 4917, + "chuan2": 4918, + "mo4": 4919, + "zhang1": 4920, + "ban1": 4921, + "mo2": 4922, + "cha2": 4923, + "ce4": 4924, + "zhu3yao4": 4925, + "tou2": 4926, + "ju2": 4927, + "shi4wei4yu2": 4928, + "sa4": 4929, + "un1": 4930, + "ke3yi3": 4931, + "du1": 4932, + "han4": 4933, + "liang4": 4934, + "sha1": 4935, + "jia3": 4936, + "zi1": 4937, + "lv4": 4938, + "fu1": 4939, + "xian1": 4940, + "xu4": 4941, + "guang1": 4942, + "meng2": 4943, + "bao4": 4944, + "you4": 4945, + "rong2": 4946, + "zhi1yi1": 4947, + "wei1": 4948, + "mao2": 4949, + "guo2jia1": 4950, + "cong2": 4951, + "gou4": 4952, + "tie3": 4953, + "zhen1": 4954, + "du2": 4955, + "bian1": 4956, + "ci2": 4957, + "qu3": 4958, + "fan4": 4959, + "xiang3": 4960, + "men2": 4961, + "ju1": 4962, + "hong2": 4963, + "zi3": 4964, + "ta1men5": 4965, + "ji3": 4966, + "zong1": 4967, + "zhou1de5yi2ge4shi4zhen4": 4968, + "tuan2": 4969, + "jing3": 4970, + "gong1si1": 4971, + "xie4": 4972, + "li2": 4973, + "li4shi3": 4974, + "bao1": 4975, + "gang3": 4976, + "gui1": 4977, + "zheng1": 4978, + "zhi2wu4": 4979, + "ta1de5": 4980, + "pin3": 4981, + "zhuan1": 4982, + "chong2": 4983, + "shi3yong4": 4984, + "wa3": 4985, + "shuo1": 4986, + "chuan1": 4987, + "lei2": 4988, + "wan1": 4989, + "huo2": 4990, + "su1": 4991, + "zao3": 4992, + "gai3": 4993, + "qu4": 4994, + "gu4": 4995, + "xi2": 4996, + "hang2": 4997, + "ying4": 4998, + "cun1": 4999, + "gen1": 5000, + "ying2": 5001, + "ting2": 5002, + "cheng2shi4": 5003, + "jiang3": 5004, + "ling3": 5005, + "lun2": 5006, + "bu4fen4": 5007, + "deng1": 5008, + "xuan3": 5009, + "dong4wu4": 5010, + "de2guo2": 5011, + "xian3": 5012, + "fan3": 5013, + "zhe5": 5014, + "han2": 5015, + "hao4": 5016, + "mi4": 5017, + "ran2": 5018, + "qin1": 5019, + "tiao2": 5020, + "zhan3": 5021, + "[ar]": 5022, + "[zh-cn]": 5023, + "¡": 5024, + "é": 5025, + "shi": 5026, + "tsu": 5027, + "teki": 5028, + "nai": 5029, + "aru": 5030, + "uu": 5031, + "kai": 5032, + "shite": 5033, + "mono": 5034, + "koto": 5035, + "kara": 5036, + "shita": 5037, + "suru": 5038, + "masu": 5039, + "tai": 5040, + "ware": 5041, + "shin": 5042, + "oku": 5043, + "yuu": 5044, + "iru": 5045, + "jiko": 5046, + "desu": 5047, + "rare": 5048, + "shou": 5049, + "sha": 5050, + "sekai": 5051, + "kyou": 5052, + "mashita": 5053, + "nara": 5054, + "kei": 5055, + "ita": 5056, + "ari": 5057, + "itsu": 5058, + "kono": 5059, + "naka": 5060, + "chou": 5061, + "sore": 5062, + "naru": 5063, + "gaku": 5064, + "reba": 5065, + "hito": 5066, + "sai": 5067, + "nan": 5068, + "dai": 5069, + "tsuku": 5070, + "shiki": 5071, + "sare": 5072, + "naku": 5073, + "jun": 5074, + "kaku": 5075, + "zai": 5076, + "wata": 5077, + "shuu": 5078, + "ii": 5079, + "kare": 5080, + "shii": 5081, + "made": 5082, + "sho": 5083, + "kereba": 5084, + "shika": 5085, + "ichi": 5086, + "deki": 5087, + "nin": 5088, + "wareware": 5089, + "nakereba": 5090, + "oite": 5091, + "yaku": 5092, + "mujun": 5093, + "yoku": 5094, + "butsu": 5095, + "omo": 5096, + "gae": 5097, + "naranai": 5098, + "tachi": 5099, + "chuu": 5100, + "kangae": 5101, + "toki": 5102, + "koro": 5103, + "mujunteki": 5104, + "naga": 5105, + "jin": 5106, + "shima": 5107, + "iku": 5108, + "imasu": 5109, + "hon": 5110, + "kae": 5111, + "kore": 5112, + "kita": 5113, + "datta": 5114, + "jitsu": 5115, + "mae": 5116, + 
"toku": 5117, + "douitsu": 5118, + "ritsu": 5119, + "kyuu": 5120, + "hyou": 5121, + "rareta": 5122, + "keisei": 5123, + "kkan": 5124, + "rareru": 5125, + "mou": 5126, + "doko": 5127, + "ryou": 5128, + "dake": 5129, + "nakatta": 5130, + "soko": 5131, + "tabe": 5132, + "hana": 5133, + "fuku": 5134, + "yasu": 5135, + "wataku": 5136, + "yama": 5137, + "kyo": 5138, + "genzai": 5139, + "boku": 5140, + "ata": 5141, + "kawa": 5142, + "masen": 5143, + "juu": 5144, + "natte": 5145, + "watakushi": 5146, + "yotte": 5147, + "hai": 5148, + "jishin": 5149, + "rete": 5150, + "oka": 5151, + "kagaku": 5152, + "natta": 5153, + "karu": 5154, + "nari": 5155, + "mata": 5156, + "kuru": 5157, + "gai": 5158, + "kari": 5159, + "shakai": 5160, + "koui": 5161, + "yori": 5162, + "setsu": 5163, + "reru": 5164, + "tokoro": 5165, + "jutsu": 5166, + "saku": 5167, + "ttai": 5168, + "ningen": 5169, + "tame": 5170, + "kankyou": 5171, + "ooku": 5172, + "watashi": 5173, + "tsukuru": 5174, + "sugi": 5175, + "jibun": 5176, + "shitsu": 5177, + "keru": 5178, + "kishi": 5179, + "shikashi": 5180, + "moto": 5181, + "mari": 5182, + "itte": 5183, + "deshita": 5184, + "nde": 5185, + "arimasu": 5186, + "koe": 5187, + "zettai": 5188, + "kkanteki": 5189, + "rekishi": 5190, + "dekiru": 5191, + "tsuka": 5192, + "itta": 5193, + "kobutsu": 5194, + "miru": 5195, + "shoku": 5196, + "shimasu": 5197, + "gijutsu": 5198, + "gyou": 5199, + "joushiki": 5200, + "atta": 5201, + "hodo": 5202, + "koko": 5203, + "tsukurareta": 5204, + "zoku": 5205, + "hitei": 5206, + "koku": 5207, + "rekishiteki": 5208, + "kete": 5209, + "kako": 5210, + "nagara": 5211, + "kakaru": 5212, + "shutai": 5213, + "haji": 5214, + "taku": 5215, + "douitsuteki": 5216, + "mete": 5217, + "tsuu": 5218, + "sarete": 5219, + "genjitsu": 5220, + "bai": 5221, + "nawa": 5222, + "jikan": 5223, + "waru": 5224, + "rt": 5225, + "atsu": 5226, + "soku": 5227, + "kouiteki": 5228, + "kata": 5229, + "tetsu": 5230, + "gawa": 5231, + "kedo": 5232, + "reta": 5233, + "sayou": 5234, + "tteru": 5235, + "tori": 5236, + "kimi": 5237, + "mura": 5238, + "sareru": 5239, + "machi": 5240, + "kya": 5241, + "osa": 5242, + "konna": 5243, + "aku": 5244, + "sareta": 5245, + "ipp": 5246, + "shiku": 5247, + "uchi": 5248, + "hitotsu": 5249, + "hatara": 5250, + "tachiba": 5251, + "shiro": 5252, + "katachi": 5253, + "tomo": 5254, + "ete": 5255, + "meru": 5256, + "nichi": 5257, + "dare": 5258, + "katta": 5259, + "eru": 5260, + "suki": 5261, + "ooki": 5262, + "maru": 5263, + "moku": 5264, + "oko": 5265, + "kangaerareru": 5266, + "oto": 5267, + "tanni": 5268, + "tada": 5269, + "taiteki": 5270, + "motte": 5271, + "kinou": 5272, + "shinai": 5273, + "kki": 5274, + "tari": 5275, + "ranai": 5276, + "kkou": 5277, + "mirai": 5278, + "ppon": 5279, + "goto": 5280, + "hitsu": 5281, + "teru": 5282, + "mochi": 5283, + "katsu": 5284, + "nyuu": 5285, + "zuka": 5286, + "tsuite": 5287, + "nomi": 5288, + "sugu": 5289, + "kuda": 5290, + "tetsugaku": 5291, + "ika": 5292, + "ronri": 5293, + "oki": 5294, + "nippon": 5295, + "shimashita": 5296, + "chishiki": 5297, + "chokkanteki": 5298, + "suko": 5299, + "kuu": 5300, + "arou": 5301, + "katte": 5302, + "kuri": 5303, + "inai": 5304, + "hyougen": 5305, + "ishiki": 5306, + "doku": 5307, + "atte": 5308, + "atara": 5309, + "wari": 5310, + "kao": 5311, + "seisan": 5312, + "hanashi": 5313, + "kake": 5314, + "naji": 5315, + "sunawa": 5316, + "sunawachi": 5317, + "ugo": 5318, + "suu": 5319, + "bara": 5320, + "hiro": 5321, + "iwa": 5322, + "betsu": 5323, + "yoi": 5324, + "seru": 5325, + "shiteru": 5326, + 
"rarete": 5327, + "toshi": 5328, + "seki": 5329, + "tairitsu": 5330, + "wakara": 5331, + "tokyo": 5332, + "kka": 5333, + "kyoku": 5334, + "iro": 5335, + "mite": 5336, + "saki": 5337, + "kanji": 5338, + "mita": 5339, + "sube": 5340, + "ryoku": 5341, + "matta": 5342, + "kudasai": 5343, + "omoi": 5344, + "wareru": 5345, + "hitsuyou": 5346, + "kashi": 5347, + "renai": 5348, + "kankei": 5349, + "gatte": 5350, + "ochi": 5351, + "motsu": 5352, + "sonzai": 5353, + "taishite": 5354, + "ame": 5355, + "seimei": 5356, + "kano": 5357, + "giri": 5358, + "kangaeru": 5359, + "yue": 5360, + "asa": 5361, + "onaji": 5362, + "yoru": 5363, + "niku": 5364, + "osaka": 5365, + "sukoshi": 5366, + "tama": 5367, + "kanojo": 5368, + "kite": 5369, + "mondai": 5370, + "amari": 5371, + "eki": 5372, + "kojin": 5373, + "haya": 5374, + "dete": 5375, + "atarashii": 5376, + "awa": 5377, + "gakkou": 5378, + "tsuzu": 5379, + "shukan": 5380, + "imashita": 5381, + "atae": 5382, + "darou": 5383, + "hataraku": 5384, + "gata": 5385, + "dachi": 5386, + "matsu": 5387, + "arimasen": 5388, + "seibutsu": 5389, + "mitsu": 5390, + "heya": 5391, + "yasui": 5392, + "deni": 5393, + "noko": 5394, + "haha": 5395, + "domo": 5396, + "kami": 5397, + "sudeni": 5398, + "nao": 5399, + "raku": 5400, + "ike": 5401, + "meta": 5402, + "kodomo": 5403, + "soshite": 5404, + "game": 5405, + "bakari": 5406, + "tote": 5407, + "hatsu": 5408, + "mise": 5409, + "mokuteki": 5410, + "dakara": 5411, + "[ja]": 5412, + "ő": 5413, + "ű": 5414, + "そ": 5415, + "な": 5416, + "ん": 5417, + "포": 5418, + "�": 5419, + "gy": 5420, + "eg": 5421, + "cs": 5422, + "ál": 5423, + "egy": 5424, + "át": 5425, + "ott": 5426, + "ett": 5427, + "meg": 5428, + "hogy": 5429, + "ég": 5430, + "ól": 5431, + "nek": 5432, + "volt": 5433, + "ág": 5434, + "nk": 5435, + "ék": 5436, + "ít": 5437, + "ák": 5438, + "ud": 5439, + "szer": 5440, + "mind": 5441, + "oz": 5442, + "ép": 5443, + "ért": 5444, + "mond": 5445, + "szt": 5446, + "nak": 5447, + "ől": 5448, + "csak": 5449, + "oly": 5450, + "áll": 5451, + "ány": 5452, + "mint": 5453, + "már": 5454, + "ött": 5455, + "nagy": 5456, + "ész": 5457, + "azt": 5458, + "elő": 5459, + "tud": 5460, + "ény": 5461, + "áz": 5462, + "még": 5463, + "köz": 5464, + "ely": 5465, + "ség": 5466, + "hoz": 5467, + "uk": 5468, + "kez": 5469, + "ám": 5470, + "aj": 5471, + "unk": 5472, + "vagy": 5473, + "szem": 5474, + "ember": 5475, + "fog": 5476, + "mert": 5477, + "ös": 5478, + "ság": 5479, + "leg": 5480, + "ünk": 5481, + "hát": 5482, + "ony": 5483, + "ezt": 5484, + "minden": 5485, + "ült": 5486, + "jó": 5487, + "kis": 5488, + "áj": 5489, + "úgy": 5490, + "most": 5491, + "ír": 5492, + "itt": 5493, + "elt": 5494, + "mondta": 5495, + "kell": 5496, + "ált": 5497, + "érd": 5498, + "tö": 5499, + "vár": 5500, + "lát": 5501, + "ők": 5502, + "vet": 5503, + "után": 5504, + "két": 5505, + "nap": 5506, + "ív": 5507, + "ály": 5508, + "vég": 5509, + "ök": 5510, + "dul": 5511, + "néz": 5512, + "ában": 5513, + "kül": 5514, + "akkor": 5515, + "szél": 5516, + "új": 5517, + "olyan": 5518, + "ked": 5519, + "hely": 5520, + "tör": 5521, + "ból": 5522, + "elm": 5523, + "ára": 5524, + "ló": 5525, + "volna": 5526, + "lehet": 5527, + "ebb": 5528, + "sok": 5529, + "olt": 5530, + "eket": 5531, + "bor": 5532, + "fej": 5533, + "gond": 5534, + "akar": 5535, + "fél": 5536, + "úl": 5537, + "otta": 5538, + "valami": 5539, + "jel": 5540, + "éd": 5541, + "arc": 5542, + "hall": 5543, + "föl": 5544, + "ába": 5545, + "olg": 5546, + "kir": 5547, + "old": 5548, + "kérd": 5549, + "jár": 5550, + "úr": 5551, + "zs": 
5552, + "élet": 5553, + "ját": 5554, + "ov": 5555, + "éz": 5556, + "vil": 5557, + "őr": 5558, + "ög": 5559, + "lesz": 5560, + "koz": 5561, + "ább": 5562, + "király": 5563, + "eng": 5564, + "igaz": 5565, + "haj": 5566, + "kod": 5567, + "ról": 5568, + "több": 5569, + "szó": 5570, + "ében": 5571, + "öt": 5572, + "nyi": 5573, + "szól": 5574, + "gondol": 5575, + "egész": 5576, + "így": 5577, + "ős": 5578, + "obb": 5579, + "osan": 5580, + "ből": 5581, + "abb": 5582, + "őt": 5583, + "nál": 5584, + "kép": 5585, + "aztán": 5586, + "tart": 5587, + "beszél": 5588, + "előtt": 5589, + "aszt": 5590, + "maj": 5591, + "kör": 5592, + "hang": 5593, + "íz": 5594, + "incs": 5595, + "év": 5596, + "ód": 5597, + "ók": 5598, + "hozz": 5599, + "okat": 5600, + "nagyon": 5601, + "ház": 5602, + "ped": 5603, + "ezte": 5604, + "etlen": 5605, + "neki": 5606, + "majd": 5607, + "szony": 5608, + "ának": 5609, + "felé": 5610, + "egyszer": 5611, + "adt": 5612, + "gyer": 5613, + "amikor": 5614, + "foly": 5615, + "szak": 5616, + "őd": 5617, + "hú": 5618, + "ász": 5619, + "amely": 5620, + "ére": 5621, + "ilyen": 5622, + "oda": 5623, + "ják": 5624, + "tár": 5625, + "ával": 5626, + "lak": 5627, + "gyan": 5628, + "ély": 5629, + "út": 5630, + "kezd": 5631, + "mell": 5632, + "mikor": 5633, + "hez": 5634, + "való": 5635, + "szeret": 5636, + "rend": 5637, + "vissza": 5638, + "fő": 5639, + "asszony": 5640, + "ről": 5641, + "pedig": 5642, + "szép": 5643, + "ták": 5644, + "öv": 5645, + "világ": 5646, + "maga": 5647, + "szik": 5648, + "éj": 5649, + "ént": 5650, + "jött": 5651, + "szí": 5652, + "gat": 5653, + "ettem": 5654, + "hány": 5655, + "ást": 5656, + "ahol": 5657, + "őket": 5658, + "hár": 5659, + "nő": 5660, + "csi": 5661, + "talál": 5662, + "elte": 5663, + "látt": 5664, + "tört": 5665, + "hagy": 5666, + "esz": 5667, + "nél": 5668, + "kut": 5669, + "lány": 5670, + "amit": 5671, + "ső": 5672, + "ellen": 5673, + "magát": 5674, + "ugyan": 5675, + "külön": 5676, + "asz": 5677, + "mindig": 5678, + "lép": 5679, + "talán": 5680, + "szor": 5681, + "illan": 5682, + "nincs": 5683, + "vagyok": 5684, + "telen": 5685, + "ismer": 5686, + "isten": 5687, + "ított": 5688, + "jobb": 5689, + "ves": 5690, + "dult": 5691, + "juk": 5692, + "szen": 5693, + "öm": 5694, + "lett": 5695, + "egyik": 5696, + "bár": 5697, + "szi": 5698, + "szív": 5699, + "azon": 5700, + "eszt": 5701, + "föld": 5702, + "kuty": 5703, + "pillan": 5704, + "fér": 5705, + "től": 5706, + "tű": 5707, + "ébe": 5708, + "tött": 5709, + "barát": 5710, + "íg": 5711, + "ahogy": 5712, + "eh": 5713, + "ep": 5714, + "jelent": 5715, + "tat": 5716, + "szeg": 5717, + "mintha": 5718, + "egyen": 5719, + "szab": 5720, + "bizony": 5721, + "jon": 5722, + "öreg": 5723, + "dolg": 5724, + "csap": 5725, + "tiszt": 5726, + "állt": 5727, + "ancs": 5728, + "idő": 5729, + "ügy": 5730, + "miért": 5731, + "ót": 5732, + "csin": 5733, + "ének": 5734, + "vér": 5735, + "jól": 5736, + "alatt": 5737, + "mely": 5738, + "semmi": 5739, + "nyug": 5740, + "vág": 5741, + "követ": 5742, + "össze": 5743, + "mad": 5744, + "acs": 5745, + "fiú": 5746, + "másik": 5747, + "jön": 5748, + "szám": 5749, + "rész": 5750, + "kér": 5751, + "ével": 5752, + "[hu]": 5753, + "%": 5754, + "0": 5755, + "6": 5756, + "7": 5757, + "8": 5758, + "9": 5759, + "A": 5760, + "B": 5761, + "C": 5762, + "D": 5763, + "E": 5764, + "F": 5765, + "G": 5766, + "H": 5767, + "I": 5768, + "J": 5769, + "K": 5770, + "L": 5771, + "M": 5772, + "N": 5773, + "O": 5774, + "P": 5775, + "Q": 5776, + "R": 5777, + "S": 5778, + "T": 5779, + "U": 5780, + "V": 5781, + "W": 5782, 
+ "X": 5783, + "Y": 5784, + "Z": 5785, + "Ł": 5786, + "α": 5787, + "ς": 5788, + "♥": 5789, + "か": 5790, + "ズ": 5791, + "因": 5792, + "国": 5793, + "怎": 5794, + "抱": 5795, + "推": 5796, + "有": 5797, + "樣": 5798, + "為": 5799, + "群": 5800, + "麼": 5801, + "eo": 5802, + "eul": 5803, + "eun": 5804, + "eon": 5805, + "ae": 5806, + "yeon": 5807, + "yeo": 5808, + "ui": 5809, + "hae": 5810, + "geo": 5811, + "neun": 5812, + "ssda": 5813, + "seo": 5814, + "eong": 5815, + "kk": 5816, + "jeo": 5817, + "deul": 5818, + "eum": 5819, + "yeong": 5820, + "geos": 5821, + "hag": 5822, + "aneun": 5823, + "iss": 5824, + "dae": 5825, + "eob": 5826, + "eol": 5827, + "geu": 5828, + "jeong": 5829, + "sae": 5830, + "doe": 5831, + "geul": 5832, + "eulo": 5833, + "bn": 5834, + "sang": 5835, + "bnida": 5836, + "haneun": 5837, + "jeog": 5838, + "saeng": 5839, + "ineun": 5840, + "anh": 5841, + "salam": 5842, + "eom": 5843, + "nae": 5844, + "gwa": 5845, + "yeol": 5846, + "eseo": 5847, + "myeon": 5848, + "ttae": 5849, + "hw": 5850, + "eobs": 5851, + "jang": 5852, + "gw": 5853, + "ileul": 5854, + "yeog": 5855, + "jeon": 5856, + "sig": 5857, + "jag": 5858, + "hago": 5859, + "deun": 5860, + "seong": 5861, + "gag": 5862, + "ham": 5863, + "dang": 5864, + "leul": 5865, + "sil": 5866, + "dong": 5867, + "handa": 5868, + "eossda": 5869, + "aeg": 5870, + "seon": 5871, + "haessda": 5872, + "issda": 5873, + "ege": 5874, + "mul": 5875, + "jung": 5876, + "jig": 5877, + "issneun": 5878, + "geun": 5879, + "seubnida": 5880, + "won": 5881, + "daneun": 5882, + "eoh": 5883, + "deo": 5884, + "gam": 5885, + "jal": 5886, + "haeng": 5887, + "yang": 5888, + "bang": 5889, + "jae": 5890, + "saenggag": 5891, + "hage": 5892, + "sog": 5893, + "eoss": 5894, + "jasin": 5895, + "jil": 5896, + "eog": 5897, + "gyeong": 5898, + "gong": 5899, + "deon": 5900, + "haess": 5901, + "eung": 5902, + "joh": 5903, + "nal": 5904, + "myeong": 5905, + "eona": 5906, + "igo": 5907, + "gyeol": 5908, + "yag": 5909, + "gwan": 5910, + "uli": 5911, + "yong": 5912, + "lyeo": 5913, + "jog": 5914, + "eohge": 5915, + "bog": 5916, + "tong": 5917, + "manh": 5918, + "jeol": 5919, + "geol": 5920, + "aga": 5921, + "naneun": 5922, + "uneun": 5923, + "cheol": 5924, + "dol": 5925, + "bad": 5926, + "hamyeon": 5927, + "yeossda": 5928, + "ibnida": 5929, + "gye": 5930, + "eos": 5931, + "hwal": 5932, + "salamdeul": 5933, + "jiman": 5934, + "dangsin": 5935, + "jib": 5936, + "ttaemun": 5937, + "ib": 5938, + "eneun": 5939, + "eug": 5940, + "jeom": 5941, + "geuleon": 5942, + "hwa": 5943, + "assda": 5944, + "beob": 5945, + "bae": 5946, + "yeoss": 5947, + "chin": 5948, + "chaeg": 5949, + "geon": 5950, + "naega": 5951, + "iga": 5952, + "sigan": 5953, + "gil": 5954, + "hyeon": 5955, + "lyeog": 5956, + "gug": 5957, + "pyeon": 5958, + "wae": 5959, + "jul": 5960, + "seul": 5961, + "deung": 5962, + "hajiman": 5963, + "eumyeon": 5964, + "pil": 5965, + "nyeon": 5966, + "tae": 5967, + "pyo": 5968, + "jineun": 5969, + "beon": 5970, + "hada": 5971, + "seol": 5972, + "sip": 5973, + "daleun": 5974, + "salm": 5975, + "gyo": 5976, + "cheon": 5977, + "hagi": 5978, + "cheoleom": 5979, + "gal": 5980, + "ila": 5981, + "kkaji": 5982, + "anhneun": 5983, + "habnida": 5984, + "tteon": 5985, + "haeseo": 5986, + "doenda": 5987, + "ttal": 5988, + "ilo": 5989, + "seub": 5990, + "byeon": 5991, + "myeo": 5992, + "beol": 5993, + "jeung": 5994, + "chim": 5995, + "hwang": 5996, + "euneun": 5997, + "jong": 5998, + "boda": 5999, + "nol": 6000, + "neom": 6001, + "buteo": 6002, + "jigeum": 6003, + "eobsda": 6004, + "daelo": 6005, + "yul": 
6006, + "pyeong": 6007, + "seoneun": 6008, + "salang": 6009, + "seut": 6010, + "heom": 6011, + "hyang": 6012, + "gwang": 6013, + "eobsneun": 6014, + "hwag": 6015, + "gess": 6016, + "jagi": 6017, + "ileon": 6018, + "wihae": 6019, + "daehan": 6020, + "gaji": 6021, + "meog": 6022, + "jyeo": 6023, + "chaj": 6024, + "byeong": 6025, + "eod": 6026, + "gyeo": 6027, + "eoji": 6028, + "gul": 6029, + "modeun": 6030, + "insaeng": 6031, + "geulae": 6032, + "sasil": 6033, + "sib": 6034, + "chal": 6035, + "ilago": 6036, + "geum": 6037, + "doeneun": 6038, + "bol": 6039, + "gajang": 6040, + "geuligo": 6041, + "hyeong": 6042, + "haengbog": 6043, + "chul": 6044, + "chae": 6045, + "mang": 6046, + "dam": 6047, + "choe": 6048, + "sijag": 6049, + "cheong": 6050, + "ilaneun": 6051, + "ulineun": 6052, + "aen": 6053, + "kke": 6054, + "munje": 6055, + "teu": 6056, + "geuneun": 6057, + "bge": 6058, + "cheo": 6059, + "baeg": 6060, + "jug": 6061, + "sangdae": 6062, + "geugeos": 6063, + "dog": 6064, + "eus": 6065, + "jab": 6066, + "hyeo": 6067, + "tteohge": 6068, + "chil": 6069, + "swi": 6070, + "jileul": 6071, + "chang": 6072, + "ganeun": 6073, + "iji": 6074, + "dago": 6075, + "yohan": 6076, + "teug": 6077, + "ppun": 6078, + "aleul": 6079, + "haengdong": 6080, + "sesang": 6081, + "edo": 6082, + "mandeul": 6083, + "amyeon": 6084, + "kkae": 6085, + "bag": 6086, + "ideul": 6087, + "pum": 6088, + "meol": 6089, + "neul": 6090, + "hamkke": 6091, + "chung": 6092, + "dab": 6093, + "yug": 6094, + "sag": 6095, + "gwangye": 6096, + "ileohge": 6097, + "balo": 6098, + "neunde": 6099, + "hamyeo": 6100, + "geuleoh": 6101, + "anila": 6102, + "bangbeob": 6103, + "dasi": 6104, + "byeol": 6105, + "gyeon": 6106, + "gamjeong": 6107, + "oneul": 6108, + "janeun": 6109, + "yeom": 6110, + "lago": 6111, + "igi": 6112, + "hwan": 6113, + "teul": 6114, + "eoseo": 6115, + "sik": 6116, + "jaga": 6117, + "geuleom": 6118, + "geuleona": 6119, + "jeongdo": 6120, + "gyeog": 6121, + "geuleohge": 6122, + "geudeul": 6123, + "eut": 6124, + "imyeon": 6125, + "jjae": 6126, + "keun": 6127, + "isang": 6128, + "malhaessda": 6129, + "euge": 6130, + "nop": 6131, + "ingan": 6132, + "bomyeon": 6133, + "taeg": 6134, + "dwi": 6135, + "saneun": 6136, + "wan": 6137, + "anhgo": 6138, + "nugu": 6139, + "sung": 6140, + "damyeon": 6141, + "adeul": 6142, + "peul": 6143, + "ttala": 6144, + "geosdo": 6145, + "aji": 6146, + "meon": 6147, + "eumyeo": 6148, + "dolog": 6149, + "neung": 6150, + "modu": 6151, + "[ko]": 6152, + "\u0014": 6153, + "\u0016": 6154, + "$": 6155, + "*": 6156, + "|": 6157, + "°": 6158, + "º": 6159, + "ँ": 6160, + "ं": 6161, + "ः": 6162, + "अ": 6163, + "आ": 6164, + "इ": 6165, + "ई": 6166, + "उ": 6167, + "ऊ": 6168, + "ऋ": 6169, + "ऎ": 6170, + "ए": 6171, + "ऐ": 6172, + "ऑ": 6173, + "ऒ": 6174, + "ओ": 6175, + "औ": 6176, + "क": 6177, + "ख": 6178, + "ग": 6179, + "घ": 6180, + "ङ": 6181, + "च": 6182, + "छ": 6183, + "ज": 6184, + "झ": 6185, + "ञ": 6186, + "ट": 6187, + "ठ": 6188, + "ड": 6189, + "ढ": 6190, + "ण": 6191, + "त": 6192, + "थ": 6193, + "द": 6194, + "ध": 6195, + "न": 6196, + "ऩ": 6197, + "प": 6198, + "फ": 6199, + "ब": 6200, + "भ": 6201, + "म": 6202, + "य": 6203, + "र": 6204, + "ऱ": 6205, + "ल": 6206, + "ळ": 6207, + "व": 6208, + "श": 6209, + "ष": 6210, + "स": 6211, + "ह": 6212, + "़": 6213, + "ा": 6214, + "ि": 6215, + "ी": 6216, + "ु": 6217, + "ू": 6218, + "ृ": 6219, + "ॄ": 6220, + "ॅ": 6221, + "ॆ": 6222, + "े": 6223, + "ै": 6224, + "ॉ": 6225, + "ॊ": 6226, + "ो": 6227, + "ौ": 6228, + "्": 6229, + "ॐ": 6230, + "ॖ": 6231, + "क़": 6232, + "ख़": 6233, + "ग़": 6234, 
+ "ज़": 6235, + "ड़": 6236, + "ढ़": 6237, + "फ़": 6238, + "य़": 6239, + "ॠ": 6240, + "।": 6241, + "॥": 6242, + "०": 6243, + "१": 6244, + "२": 6245, + "३": 6246, + "४": 6247, + "५": 6248, + "६": 6249, + "७": 6250, + "८": 6251, + "९": 6252, + "॰": 6253, + "ॲ": 6254, + "​": 6255, + "‌": 6256, + "‍": 6257, + "‎": 6258, + "₹": 6259, + "के": 6260, + "है": 6261, + "ें": 6262, + "्र": 6263, + "ार": 6264, + "ने": 6265, + "या": 6266, + "में": 6267, + "से": 6268, + "की": 6269, + "का": 6270, + "ों": 6271, + "ता": 6272, + "कर": 6273, + "स्": 6274, + "कि": 6275, + "को": 6276, + "र्": 6277, + "ना": 6278, + "क्": 6279, + "ही": 6280, + "और": 6281, + "पर": 6282, + "ते": 6283, + "हो": 6284, + "प्र": 6285, + "ान": 6286, + "्य": 6287, + "ला": 6288, + "वा": 6289, + "ले": 6290, + "सा": 6291, + "हैं": 6292, + "लि": 6293, + "जा": 6294, + "हा": 6295, + "भी": 6296, + "वि": 6297, + "इस": 6298, + "ती": 6299, + "न्": 6300, + "रा": 6301, + "मा": 6302, + "दे": 6303, + "दि": 6304, + "बा": 6305, + "ति": 6306, + "था": 6307, + "नि": 6308, + "कार": 6309, + "एक": 6310, + "हीं": 6311, + "हु": 6312, + "ंग": 6313, + "ैं": 6314, + "नी": 6315, + "सी": 6316, + "अप": 6317, + "त्": 6318, + "नहीं": 6319, + "री": 6320, + "मे": 6321, + "मु": 6322, + "ित": 6323, + "तो": 6324, + "पा": 6325, + "ली": 6326, + "लिए": 6327, + "गा": 6328, + "ल्": 6329, + "रह": 6330, + "रे": 6331, + "क्ष": 6332, + "मैं": 6333, + "सम": 6334, + "उस": 6335, + "जि": 6336, + "त्र": 6337, + "मि": 6338, + "चा": 6339, + "ोग": 6340, + "सं": 6341, + "द्": 6342, + "सि": 6343, + "आप": 6344, + "तु": 6345, + "दा": 6346, + "कु": 6347, + "यों": 6348, + "वे": 6349, + "जी": 6350, + "्या": 6351, + "उन": 6352, + "िक": 6353, + "ये": 6354, + "भा": 6355, + "्ट": 6356, + "हम": 6357, + "स्ट": 6358, + "शा": 6359, + "ड़": 6360, + "ंद": 6361, + "खा": 6362, + "म्": 6363, + "श्": 6364, + "यह": 6365, + "सक": 6366, + "पू": 6367, + "किया": 6368, + "अपने": 6369, + "रू": 6370, + "सु": 6371, + "मी": 6372, + "हि": 6373, + "जो": 6374, + "थे": 6375, + "रि": 6376, + "दी": 6377, + "थी": 6378, + "गी": 6379, + "लोग": 6380, + "गया": 6381, + "तर": 6382, + "न्ह": 6383, + "च्": 6384, + "वार": 6385, + "बी": 6386, + "प्": 6387, + "दो": 6388, + "टी": 6389, + "शि": 6390, + "करने": 6391, + "गे": 6392, + "ैसे": 6393, + "इन": 6394, + "ंड": 6395, + "साथ": 6396, + "पु": 6397, + "बे": 6398, + "बार": 6399, + "वी": 6400, + "अन": 6401, + "हर": 6402, + "उन्ह": 6403, + "होता": 6404, + "जब": 6405, + "कुछ": 6406, + "मान": 6407, + "क्र": 6408, + "बि": 6409, + "पह": 6410, + "फि": 6411, + "सर": 6412, + "ारी": 6413, + "रो": 6414, + "दू": 6415, + "कहा": 6416, + "तक": 6417, + "शन": 6418, + "ब्": 6419, + "स्थ": 6420, + "वह": 6421, + "बाद": 6422, + "ओं": 6423, + "गु": 6424, + "ज्": 6425, + "्रे": 6426, + "गर": 6427, + "रहे": 6428, + "वर्": 6429, + "हू": 6430, + "ार्": 6431, + "पी": 6432, + "बहु": 6433, + "मुझ": 6434, + "्रा": 6435, + "दिया": 6436, + "सब": 6437, + "करते": 6438, + "अपनी": 6439, + "बहुत": 6440, + "कह": 6441, + "टे": 6442, + "हुए": 6443, + "किसी": 6444, + "रहा": 6445, + "ष्ट": 6446, + "ज़": 6447, + "बना": 6448, + "सो": 6449, + "डि": 6450, + "कोई": 6451, + "व्य": 6452, + "बात": 6453, + "रु": 6454, + "वो": 6455, + "मुझे": 6456, + "द्ध": 6457, + "चार": 6458, + "मेरे": 6459, + "वर": 6460, + "्री": 6461, + "जाता": 6462, + "नों": 6463, + "प्रा": 6464, + "देख": 6465, + "टा": 6466, + "क्या": 6467, + "अध": 6468, + "लग": 6469, + "लो": 6470, + "पि": 6471, + "यु": 6472, + "चे": 6473, + "जिस": 6474, + "ंत": 6475, + "ानी": 6476, + "पै": 6477, + "जन": 6478, + "ारे": 6479, + "ची": 6480, + "मिल": 6481, + "दु": 6482, + "देश": 6483, + 
"च्छ": 6484, + "ष्": 6485, + "सू": 6486, + "खे": 6487, + "चु": 6488, + "िया": 6489, + "लगा": 6490, + "बु": 6491, + "उनके": 6492, + "ज्ञ": 6493, + "क्षा": 6494, + "तरह": 6495, + "्यादा": 6496, + "वाले": 6497, + "पूर्": 6498, + "मैंने": 6499, + "काम": 6500, + "रूप": 6501, + "होती": 6502, + "उप": 6503, + "जान": 6504, + "प्रकार": 6505, + "भार": 6506, + "मन": 6507, + "हुआ": 6508, + "टर": 6509, + "हूँ": 6510, + "परि": 6511, + "पास": 6512, + "अनु": 6513, + "राज": 6514, + "लोगों": 6515, + "अब": 6516, + "समझ": 6517, + "डी": 6518, + "मौ": 6519, + "शु": 6520, + "चि": 6521, + "पे": 6522, + "कृ": 6523, + "सकते": 6524, + "मह": 6525, + "योग": 6526, + "दर्": 6527, + "उसे": 6528, + "ंध": 6529, + "डा": 6530, + "जाए": 6531, + "बो": 6532, + "ूल": 6533, + "मो": 6534, + "ोंने": 6535, + "ंस": 6536, + "तुम": 6537, + "पहले": 6538, + "बता": 6539, + "तथा": 6540, + "यो": 6541, + "गई": 6542, + "उत्": 6543, + "सकता": 6544, + "कम": 6545, + "ज्यादा": 6546, + "रख": 6547, + "समय": 6548, + "ारा": 6549, + "अगर": 6550, + "स्त": 6551, + "चल": 6552, + "फिर": 6553, + "वारा": 6554, + "करना": 6555, + "शी": 6556, + "गए": 6557, + "बन": 6558, + "ौर": 6559, + "होने": 6560, + "चाह": 6561, + "खु": 6562, + "हाँ": 6563, + "उन्हें": 6564, + "उन्होंने": 6565, + "छो": 6566, + "म्ह": 6567, + "प्रति": 6568, + "निक": 6569, + "वन": 6570, + "्यू": 6571, + "रही": 6572, + "तुम्ह": 6573, + "जैसे": 6574, + "ियों": 6575, + "क्यों": 6576, + "लों": 6577, + "फ़": 6578, + "ंत्र": 6579, + "होते": 6580, + "क्ति": 6581, + "त्य": 6582, + "कर्": 6583, + "कई": 6584, + "वं": 6585, + "किन": 6586, + "पो": 6587, + "कारण": 6588, + "ड़ी": 6589, + "भि": 6590, + "इसके": 6591, + "बर": 6592, + "उसके": 6593, + "द्वारा": 6594, + "शे": 6595, + "कॉ": 6596, + "दिन": 6597, + "न्न": 6598, + "ड़ा": 6599, + "स्व": 6600, + "निर्": 6601, + "मुख": 6602, + "लिया": 6603, + "टि": 6604, + "ज्ञान": 6605, + "क्त": 6606, + "द्र": 6607, + "ग्": 6608, + "क्स": 6609, + "मै": 6610, + "गो": 6611, + "जे": 6612, + "ट्र": 6613, + "मार": 6614, + "त्व": 6615, + "धार": 6616, + "भाव": 6617, + "करता": 6618, + "खि": 6619, + "कं": 6620, + "चाहि": 6621, + "यर": 6622, + "प्त": 6623, + "कों": 6624, + "ंच": 6625, + "जु": 6626, + "मत": 6627, + "अच्छ": 6628, + "हुई": 6629, + "कभी": 6630, + "लेकिन": 6631, + "भू": 6632, + "अपना": 6633, + "दूस": 6634, + "चाहिए": 6635, + "यू": 6636, + "घर": 6637, + "सबसे": 6638, + "मेरी": 6639, + "नाम": 6640, + "ढ़": 6641, + "ंट": 6642, + "ेंगे": 6643, + "बै": 6644, + "फा": 6645, + "एवं": 6646, + "यी": 6647, + "ग्र": 6648, + "क्षे": 6649, + "आज": 6650, + "आपको": 6651, + "भाग": 6652, + "ठा": 6653, + "कै": 6654, + "भारत": 6655, + "उनकी": 6656, + "पहु": 6657, + "सभी": 6658, + "धा": 6659, + "णा": 6660, + "सान": 6661, + "होगा": 6662, + "तब": 6663, + "संग": 6664, + "पर्": 6665, + "अव": 6666, + "तना": 6667, + "गि": 6668, + "यन": 6669, + "स्था": 6670, + "चित": 6671, + "ट्": 6672, + "छा": 6673, + "जाने": 6674, + "क्षेत्र": 6675, + "वाली": 6676, + "पूर्ण": 6677, + "समा": 6678, + "कारी": 6679, + "[hi]": 6680, + "ĩ": 6681, + "ũ": 6682, + "ơ": 6683, + "ư": 6684, + "ạ": 6685, + "ả": 6686, + "ấ": 6687, + "ầ": 6688, + "ẩ": 6689, + "ẫ": 6690, + "ậ": 6691, + "ắ": 6692, + "ằ": 6693, + "ẳ": 6694, + "ẵ": 6695, + "ặ": 6696, + "ẹ": 6697, + "ẻ": 6698, + "ẽ": 6699, + "ế": 6700, + "ề": 6701, + "ể": 6702, + "ễ": 6703, + "ệ": 6704, + "ỉ": 6705, + "ị": 6706, + "ọ": 6707, + "ỏ": 6708, + "ố": 6709, + "ồ": 6710, + "ổ": 6711, + "ỗ": 6712, + "ộ": 6713, + "ớ": 6714, + "ờ": 6715, + "ở": 6716, + "ỡ": 6717, + "ợ": 6718, + "ụ": 6719, + "ủ": 6720, + "ứ": 6721, + "ừ": 6722, + "ử": 6723, + "ữ": 6724, + "ự": 6725, + "ỳ": 
6726, + "ỵ": 6727, + "ỷ": 6728, + "ỹ": 6729, + "nh": 6730, + "kh": 6731, + "ông": 6732, + "ời": 6733, + "ấy": 6734, + "ôi": 6735, + "ột": 6736, + "một": 6737, + "cá": 6738, + "không": 6739, + "ười": 6740, + "tôi": 6741, + "mà": 6742, + "ại": 6743, + "cái": 6744, + "thì": 6745, + "người": 6746, + "iế": 6747, + "nó": 6748, + "như": 6749, + "đã": 6750, + "ào": 6751, + "thế": 6752, + "lại": 6753, + "đư": 6754, + "ồi": 6755, + "ới": 6756, + "cũ": 6757, + "đi": 6758, + "ình": 6759, + "cũng": 6760, + "ải": 6761, + "ủa": 6762, + "ngh": 6763, + "của": 6764, + "iệ": 6765, + "ến": 6766, + "ằng": 6767, + "ợc": 6768, + "được": 6769, + "đến": 6770, + "ất": 6771, + "phải": 6772, + "ững": 6773, + "những": 6774, + "rồi": 6775, + "àng": 6776, + "nói": 6777, + "úc": 6778, + "uố": 6779, + "ữa": 6780, + "ướ": 6781, + "ày": 6782, + "cả": 6783, + "bà": 6784, + "ây": 6785, + "gì": 6786, + "và": 6787, + "ần": 6788, + "iết": 6789, + "làm": 6790, + "để": 6791, + "òn": 6792, + "với": 6793, + "ật": 6794, + "mình": 6795, + "vào": 6796, + "âu": 6797, + "đá": 6798, + "ước": 6799, + "nhà": 6800, + "nữa": 6801, + "còn": 6802, + "ác": 6803, + "nào": 6804, + "chỉ": 6805, + "thấy": 6806, + "ắt": 6807, + "ết": 6808, + "biết": 6809, + "vì": 6810, + "ưa": 6811, + "iề": 6812, + "ắm": 6813, + "ồng": 6814, + "ài": 6815, + "êu": 6816, + "này": 6817, + "mặ": 6818, + "sự": 6819, + "trong": 6820, + "lên": 6821, + "ăn": 6822, + "việ": 6823, + "khi": 6824, + "ái": 6825, + "ầu": 6826, + "ơi": 6827, + "về": 6828, + "ạn": 6829, + "ơng": 6830, + "òng": 6831, + "lúc": 6832, + "đứ": 6833, + "iên": 6834, + "cứ": 6835, + "ỏi": 6836, + "thôi": 6837, + "nghĩ": 6838, + "rằng": 6839, + "lắm": 6840, + "ăm": 6841, + "ận": 6842, + "nhưng": 6843, + "úng": 6844, + "mới": 6845, + "mặt": 6846, + "giờ": 6847, + "êm": 6848, + "inh": 6849, + "việc": 6850, + "ối": 6851, + "áo": 6852, + "ính": 6853, + "sao": 6854, + "ậy": 6855, + "ội": 6856, + "àn": 6857, + "tay": 6858, + "ẫn": 6859, + "ờng": 6860, + "tư": 6861, + "họ": 6862, + "uốn": 6863, + "cách": 6864, + "ẳng": 6865, + "thư": 6866, + "ùng": 6867, + "lấy": 6868, + "đó": 6869, + "chứ": 6870, + "ách": 6871, + "hỏi": 6872, + "cụ": 6873, + "muốn": 6874, + "chàng": 6875, + "nên": 6876, + "thật": 6877, + "đấy": 6878, + "ện": 6879, + "ừng": 6880, + "nhau": 6881, + "tiế": 6882, + "ơn": 6883, + "vợ": 6884, + "đây": 6885, + "ếu": 6886, + "sau": 6887, + "iền": 6888, + "số": 6889, + "bả": 6890, + "đầu": 6891, + "âm": 6892, + "bị": 6893, + "ởng": 6894, + "cô": 6895, + "ổi": 6896, + "ục": 6897, + "nếu": 6898, + "vậy": 6899, + "quan": 6900, + "ãi": 6901, + "ngồi": 6902, + "chưa": 6903, + "vẫn": 6904, + "đời": 6905, + "phúc": 6906, + "uyện": 6907, + "ửa": 6908, + "trước": 6909, + "ìn": 6910, + "ắc": 6911, + "ngay": 6912, + "uống": 6913, + "ức": 6914, + "iể": 6915, + "chẳng": 6916, + "chồng": 6917, + "chúng": 6918, + "ầm": 6919, + "ạc": 6920, + "ộc": 6921, + "chị": 6922, + "phú": 6923, + "lòng": 6924, + "ậu": 6925, + "ành": 6926, + "đứng": 6927, + "ôn": 6928, + "ỗi": 6929, + "thể": 6930, + "sẽ": 6931, + "ừa": 6932, + "từ": 6933, + "rất": 6934, + "óc": 6935, + "tình": 6936, + "yêu": 6937, + "kia": 6938, + "bố": 6939, + "ực": 6940, + "lư": 6941, + "iêm": 6942, + "tuy": 6943, + "đâu": 6944, + "xuống": 6945, + "ầy": 6946, + "ịch": 6947, + "iều": 6948, + "nay": 6949, + "mẹ": 6950, + "bảo": 6951, + "iểu": 6952, + "bao": 6953, + "tự": 6954, + "tiền": 6955, + "nhìn": 6956, + "ôm": 6957, + "mày": 6958, + "bác": 6959, + "ích": 6960, + "xin": 6961, + "uộc": 6962, + "ọi": 6963, + "trên": 6964, + "áng": 6965, + "ọn": 6966, + "đáng": 
6967, + "ều": 6968, + "nước": 6969, + "bằng": 6970, + "bạn": 6971, + "hạ": 6972, + "xe": 6973, + "các": 6974, + "ền": 6975, + "nhiên": 6976, + "cùng": 6977, + "ương": 6978, + "lẽ": 6979, + "chính": 6980, + "mấy": 6981, + "ăng": 6982, + "năm": 6983, + "vừa": 6984, + "cậu": 6985, + "mắt": 6986, + "ọng": 6987, + "tưởng": 6988, + "quá": 6989, + "iện": 6990, + "khác": 6991, + "sợ": 6992, + "nhân": 6993, + "ọc": 6994, + "độ": 6995, + "nàng": 6996, + "ngo": 6997, + "bắt": 6998, + "ngày": 6999, + "ạy": 7000, + "chết": 7001, + "hôm": 7002, + "hơn": 7003, + "thằng": 7004, + "tiếng": 7005, + "xuân": 7006, + "liêm": 7007, + "hiểu": 7008, + "cười": 7009, + "ẩm": 7010, + "bây": 7011, + "đàn": 7012, + "rõ": 7013, + "lời": 7014, + "xem": 7015, + "vô": 7016, + "chuyện": 7017, + "hằng": 7018, + "lần": 7019, + "ắng": 7020, + "mất": 7021, + "nghĩa": 7022, + "nhất": 7023, + "câu": 7024, + "chỗ": 7025, + "bên": 7026, + "báo": 7027, + "ánh": 7028, + "khổ": 7029, + "đủ": 7030, + "ướng": 7031, + "ốc": 7032, + "công": 7033, + "iếu": 7034, + "ấm": 7035, + "hàng": 7036, + "đáp": 7037, + "ngài": 7038, + "ốt": 7039, + "iến": 7040, + "trông": 7041, + "điều": 7042, + "mãi": 7043, + "ặc": 7044, + "ạo": 7045, + "thu": 7046, + "đồng": 7047, + "chủ": 7048, + "nhiều": 7049, + "gái": 7050, + "xong": 7051, + "ồn": 7052, + "cuộc": 7053, + "áu": 7054, + "kho": 7055, + "học": 7056, + "ỳnh": 7057, + "đánh": 7058, + "bạc": 7059, + "thứ": 7060, + "nghe": 7061, + "bọn": 7062, + "mịch": 7063, + "ớn": 7064, + "kẻ": 7065, + "quy": 7066, + "ập": 7067, + "ốn": 7068, + "thành": 7069, + "cơ": 7070, + "ấn": 7071, + "quay": 7072, + "cửa": 7073, + "quỳnh": 7074, + "minh": 7075, + "đị": 7076, + "thầy": 7077, + "tâm": 7078, + "vẻ": 7079, + "èn": 7080, + "văn": 7081, + "theo": 7082, + "hết": 7083, + "iêu": 7084, + "đường": 7085, + "đo": 7086, + "ấu": 7087, + "bẩm": 7088, + "đương": 7089, + "lạ": 7090, + "iệu": 7091, + "gọi": 7092, + "tân": 7093, + "ắn": 7094, + "chơi": 7095, + "uyên": 7096, + "cổ": 7097, + "dám": 7098, + "đồ": 7099, + "ẩn": 7100, + "phòng": 7101, + "chạy": 7102, + "dân": 7103, + "sướng": 7104, + "quả": 7105, + "úi": 7106, + "nghị": 7107, + "cần": 7108, + "bất": 7109, + "bàn": 7110, + "ợi": 7111, + "hư": 7112, + "ẳn": 7113, + "khó": 7114, + "mặc": 7115, + "thân": 7116, + "uốc": 7117, + "ống": 7118, + "óng": 7119, + "ặp": 7120, + "hẳn": 7121, + "tao": 7122, + "ìm": 7123, + "ngoài": 7124, + "hoặc": 7125, + "hội": 7126, + "ắp": 7127, + "chân": 7128, + "bỏ": 7129, + "mười": 7130, + "xưa": 7131, + "lã": 7132, + "lâu": 7133, + "đứa": 7134, + "ường": 7135, + "ợng": 7136, + "khách": 7137, + "tây": 7138, + "ấp": 7139, + "chịu": 7140, + "quần": 7141, + "ợt": 7142, + "thưa": 7143, + "thương": 7144, + "tóc": 7145, + "tìm": 7146, + "trăm": 7147, + "tức": 7148, + "ặng": 7149, + "nhớ": 7150, + "âng": 7151, + "éo": 7152, + "ạt": 7153, + "nỗi": 7154, + "ằm": 7155, + "nhận": 7156, + "ạnh": 7157, + "cảnh": 7158, + "sinh": 7159, + "tử": 7160, + "dài": 7161, + "hơi": 7162, + "ản": 7163, + "ưới": 7164, + "ảng": 7165, + "việt": 7166, + "vị": 7167, + "huyện": 7168, + "quân": 7169, + "thường": 7170, + "lớn": 7171, + "đau": 7172, + "ặt": 7173, + "thuốc": 7174, + "cảm": 7175, + "thời": 7176, + "êng": 7177, + "đưa": 7178, + "tội": 7179, + "đức": 7180, + "mư": 7181, + "động": 7182, + "làng": 7183, + "kêu": 7184, + "uyền": 7185, + "mọi": 7186, + "đối": 7187, + "tài": 7188, + "ỉnh": 7189, + "ín": 7190, + "cử": 7191, + "bốn": 7192, + "vân": 7193, + "tuyết": 7194, + "cơm": 7195, + "bộ": 7196, + "hình": 7197, + "quý": 7198, + "iêng": 7199, + "vài": 7200, + "vội": 
7201, + "trẻ": 7202, + "lão": 7203, + "nằm": 7204, + "đừng": 7205, + "tại": 7206, + "vui": 7207, + "định": 7208, + "đêm": 7209, + "chắc": 7210, + "ua": 7211, + "ẹp": 7212, + "phán": 7213, + "giữa": 7214, + "đem": 7215, + "thêm": 7216, + "bình": 7217, + "tho": 7218, + "sống": 7219, + "cao": 7220, + "bước": 7221, + "ụng": 7222, + "xa": 7223, + "hại": 7224, + "ệt": 7225, + "hách": 7226, + "ớm": 7227, + "ẩy": 7228, + "phá": 7229, + "cố": 7230, + "ãy": 7231, + "ảy": 7232, + "ẹn": 7233, + "hồi": 7234, + "iếc": 7235, + "ngủ": 7236, + "thiên": 7237, + "sáng": 7238, + "nhiêu": 7239, + "dung": 7240, + "bài": 7241, + "tiếp": 7242, + "ệnh": 7243, + "nổi": 7244, + "đình": 7245, + "mời": 7246, + "ạm": 7247, + "tấn": 7248, + "iếm": 7249, + "óa": 7250, + "thiếu": 7251, + "kể": 7252, + "ảnh": 7253, + "dễ": 7254, + "tất": 7255, + "khỏi": 7256, + "quyền": 7257, + "vâng": 7258, + "thực": 7259, + "giá": 7260, + "hộ": 7261, + "phu": 7262, + "mỗi": 7263, + "cầm": 7264, + "ịp": 7265, + "nhỏ": 7266, + "khiến": 7267, + "trọng": 7268, + "giữ": 7269, + "giấy": 7270, + "gặp": 7271, + "sở": 7272, + "ẫu": 7273, + "hồ": 7274, + "đỏ": 7275, + "ặn": 7276, + "càng": 7277, + "đại": 7278, + "thị": 7279, + "trở": 7280, + "iệt": 7281, + "vạn": 7282, + "nữ": 7283, + "trí": 7284, + "áy": 7285, + "tháng": 7286, + "chú": 7287, + "hà": 7288, + "mươi": 7289, + "dưới": 7290, + "ngờ": 7291, + "hoa": 7292, + "đọc": 7293, + "đẹp": 7294, + "mở": 7295, + "tối": 7296, + "đẻ": 7297, + "chờ": 7298, + "kỳ": 7299, + "giận": 7300, + "phần": 7301, + "hạnh": 7302, + "đất": 7303, + "phụ": 7304, + "òa": 7305, + "nạn": 7306, + "buồn": 7307, + "tuổi": 7308, + "sức": 7309, + "thần": 7310, + "quên": 7311, + "ghế": 7312, + "thức": 7313, + "trời": 7314, + "đôi": 7315, + "phút": 7316, + "tính": 7317, + "nhé": 7318, + "ởi": 7319, + "danh": 7320, + "tổ": 7321, + "thơ": 7322, + "tên": 7323, + "ộng": 7324, + "chả": 7325, + "bữa": 7326, + "xã": 7327, + "viên": 7328, + "tha": 7329, + "hóa": 7330, + "ỗng": 7331, + "hưởng": 7332, + "mũ": 7333, + "ỏng": 7334, + "cây": 7335, + "thay": 7336, + "khẽ": 7337, + "lỗi": 7338, + "èo": 7339, + "chục": 7340, + "đúng": 7341, + "khốn": 7342, + "phố": 7343, + "tờ": 7344, + "cháu": 7345, + "khinh": 7346, + "kính": 7347, + "riêng": 7348, + "ạch": 7349, + "ngu": 7350, + "tiết": 7351, + "yên": 7352, + "dùng": 7353, + "vọng": 7354, + "ửi": 7355, + "nửa": 7356, + "iệm": 7357, + "lưng": 7358, + "bỗng": 7359, + "hiệu": 7360, + "chiếc": 7361, + "sắp": 7362, + "tù": 7363, + "gian": 7364, + "đám": 7365, + "giàu": 7366, + "dâm": 7367, + "lặng": 7368, + "gần": 7369, + "giọng": 7370, + "lộ": 7371, + "lính": 7372, + "sáu": 7373, + "ùa": 7374, + "trường": 7375, + "cúi": 7376, + "tinh": 7377, + "lưu": 7378, + "dạ": 7379, + "phó": 7380, + "coi": 7381, + "thở": 7382, + "nghìn": 7383, + "buổi": 7384, + "huyền": 7385, + "xấu": 7386, + "chợt": 7387, + "trò": 7388, + "đều": 7389, + "thanh": 7390, + "bán": 7391, + "đầy": 7392, + "toàn": 7393, + "ờn": 7394, + "phục": 7395, + "nghi": 7396, + "mua": 7397, + "hãy": 7398, + "giả": 7399, + "cãi": 7400, + "hôn": 7401, + "chữ": 7402, + "trả": 7403, + "trị": 7404, + "hành": 7405, + "sư": 7406, + "ồm": 7407, + "ngạc": 7408, + "gớm": 7409, + "thử": 7410, + "dậy": 7411, + "tốt": 7412, + "úp": 7413, + "phép": 7414, + "quê": 7415, + "sách": 7416, + "chừng": 7417, + "chiều": 7418, + "chút": 7419, + "iệng": 7420, + "ứa": 7421, + "pháp": 7422, + "kéo": 7423, + "viết": 7424, + "lôi": 7425, + "đu": 7426, + "ếp": 7427, + "mỹ": 7428, + "chương": 7429, + "chánh": 7430, + "cứu": 7431, + "iếp": 7432, + "òi": 7433, + "ói": 
7434, + "ùi": 7435, + "giai": 7436, + "nghề": 7437, + "hiện": 7438, + "sân": 7439, + "ểm": 7440, + "thất": 7441, + "vật": 7442, + "ảo": 7443, + "chữa": 7444, + "bèn": 7445, + "vả": 7446, + "ưng": 7447, + "nguy": 7448, + "quanh": 7449, + "ngôn": 7450, + "đen": 7451, + "đèn": 7452, + "chào": 7453, + "luôn": 7454, + "cạnh": 7455, + "kinh": 7456, + "đông": 7457, + "phận": 7458, + "vàng": 7459, + "nguyên": 7460, + "thủ": 7461, + "chí": 7462, + "uối": 7463, + "mạnh": 7464, + "miệng": 7465, + "phương": 7466, + "dẫu": 7467, + "gật": 7468, + "xo": 7469, + "đạo": 7470, + "ảm": 7471, + "iệp": 7472, + "địa": 7473, + "đoạn": 7474, + "khóc": 7475, + "vư": 7476, + "đào": 7477, + "ợn": 7478, + "hạng": 7479, + "đê": 7480, + "dại": 7481, + "ngây": 7482, + "trắng": 7483, + "bích": 7484, + "cánh": 7485, + "quang": 7486, + "ruột": 7487, + "hổ": 7488, + "đổi": 7489, + "giời": 7490, + "hở": 7491, + "mật": 7492, + "tế": 7493, + "tỉnh": 7494, + "ụt": 7495, + "quyết": 7496, + "cắt": 7497, + "iễ": 7498, + "đốc": 7499, + "ĩnh": 7500, + "chán": 7501, + "trái": 7502, + "cầu": 7503, + "hút": 7504, + "mừng": 7505, + "ợp": 7506, + "hối": 7507, + "đề": 7508, + "thẳng": 7509, + "gác": 7510, + "ứng": 7511, + "nhỉ": 7512, + "thích": 7513, + "kiếm": 7514, + "phát": 7515, + "giường": 7516, + "mồm": 7517, + "ghen": 7518, + "bát": 7519, + "lễ": 7520, + "tiến": 7521, + "hận": 7522, + "bụng": 7523, + "ãn": 7524, + "cờ": 7525, + "luận": 7526, + "ẩu": 7527, + "hợp": 7528, + "ọt": 7529, + "giơ": 7530, + "lương": 7531, + "ngư": 7532, + "ghê": 7533, + "xanh": 7534, + "dù": 7535, + "giới": 7536, + "mũi": 7537, + "gắt": 7538, + "kết": 7539, + "lối": 7540, + "thù": 7541, + "giúp": 7542, + "[vi]": 7543 + }, + "merges": [ + "t h", + "i n", + "th e", + "a n", + "e r", + "o u", + "r e", + "o n", + "a t", + "e d", + "e n", + "t o", + "in g", + "an d", + "i s", + "a s", + "a l", + "o r", + "o f", + "a r", + "i t", + "e s", + "h e", + "s t", + "l e", + "o m", + "s e", + "b e", + "a d", + "o w", + "l y", + "c h", + "w h", + "th at", + "y ou", + "l i", + "v e", + "a c", + "t i", + "l d", + "m e", + "w as", + "g h", + "i d", + "l l", + "w i", + "en t", + "f or", + "a y", + "r o", + "v er", + "i c", + "h er", + "k e", + "h is", + "n o", + "u t", + "u n", + "i r", + "l o", + "w e", + "r i", + "h a", + "wi th", + "gh t", + "ou t", + "i m", + "i on", + "al l", + "a b", + "on e", + "n e", + "g e", + "ou ld", + "t er", + "m o", + "h ad", + "c e", + "s he", + "g o", + "s h", + "u r", + "a m", + "s o", + "p e", + "m y", + "d e", + "a re", + "b ut", + "om e", + "f r", + "the r", + "f e", + "s u", + "d o", + "c on", + "t e", + "a in", + "er e", + "p o", + "i f", + "the y", + "u s", + "a g", + "t r", + "n ow", + "ou n", + "th is", + "ha ve", + "no t", + "s a", + "i l", + "u p", + "th ing", + "fr om", + "a p", + "h im", + "ac k", + "at ion", + "an t", + "ou r", + "o p", + "li ke", + "u st", + "es s", + "b o", + "o k", + "u l", + "in d", + "e x", + "c om", + "s ome", + "the re", + "er s", + "c o", + "re s", + "m an", + "ar d", + "p l", + "w or", + "w ay", + "ti on", + "f o", + "c a", + "w ere", + "b y", + "at e", + "p ro", + "t ed", + "oun d", + "ow n", + "w ould", + "t s", + "wh at", + "q u", + "al ly", + "i ght", + "c k", + "g r", + "wh en", + "v en", + "c an", + "ou gh", + "in e", + "en d", + "p er", + "ou s", + "o d", + "id e", + "k now", + "t y", + "ver y", + "s i", + "a k", + "wh o", + "ab out", + "i ll", + "the m", + "es t", + "re d", + "y e", + "c ould", + "on g", + "you r", + "the ir", + "e m", + "j ust", + "o ther", + "in to", + "an y", + "wh i", + "u 
m", + "t w", + "as t", + "d er", + "d id", + "i e", + "be en", + "ac e", + "in k", + "it y", + "b ack", + "t ing", + "b r", + "mo re", + "a ke", + "p p", + "the n", + "s p", + "e l", + "u se", + "b l", + "sa id", + "o ver", + "ge t", + "e n", + "e r", + "c h", + "e i", + "i e", + "u n", + "i ch", + "ei n", + "s t", + "a n", + "t e", + "g e", + "a u", + "i n", + "s ch", + "d er", + "un d", + "d ie", + "d a", + "e s", + "a l", + "d en", + "a r", + "g en", + "z u", + "d e", + "h r", + "o n", + "t en", + "e l", + "o r", + "m i", + "s ie", + "da s", + "a t", + "b e", + "ein e", + "ich t", + "b er", + "l e", + "a ch", + "v er", + "s e", + "au f", + "w i", + "s o", + "t er", + "l ich", + "c k", + "u r", + "n icht", + "m m", + "b en", + "a s", + "w ar", + "r e", + "mi t", + "s ich", + "i g", + "l l", + "au s", + "i st", + "w ie", + "o ch", + "un g", + "an n", + "ü r", + "h n", + "i hr", + "s a", + "s en", + "t z", + "de m", + "ei t", + "u m", + "h at", + "wi r", + "v on", + "h a", + "s p", + "w ei", + "i er", + "r o", + "h er", + "r a", + "ein en", + "n e", + "v or", + "al s", + "an d", + "al l", + "w as", + "w o", + "r ei", + "st e", + "l ie", + "au ch", + "d u", + "d es", + "k o", + "ü ber", + "a m", + "b ei", + "h en", + "h m", + "l ei", + "a ber", + "w en", + "h l", + "g er", + "i m", + "u t", + "n ach", + "h e", + "i s", + "b r", + "f t", + "en t", + "i mm", + "j e", + "sch en", + "w er", + "s er", + "a b", + "ä n", + "m e", + "s ein", + "i t", + "o l", + "ch t", + "f ür", + "k l", + "f f", + "eine m", + "n en", + "w e", + "j a", + "u s", + "n och", + "hat te", + "t r", + "p f", + "h in", + "d i", + "ch en", + "b l", + "m an", + "r ü", + "ie l", + "s el", + "das s", + "i hn", + "mi r", + "sch l", + "ö n", + "g an", + "g t", + "ein er", + "st en", + "m ich", + "wen n", + "el l", + "g te", + "in d", + "m al", + "ge l", + "k en", + "n ur", + "mm en", + "f ü", + "er n", + "ö r", + "un ter", + "f r", + "an der", + "g r", + "i l", + "d ur", + "u ch", + "f e", + "t a", + "m en", + "m ach", + "d och", + "t i", + "dur ch", + "o s", + "g l", + "h al", + "ihr e", + "w ä", + "imm er", + "i hm", + "k ann", + "or t", + "d ann", + "l an", + "tz t", + "o der", + "hr en", + "e t", + "k ön", + "i ck", + "f a", + "in g", + "i r", + "wie der", + "da ß", + "m ein", + "f en", + "gan z", + "die se", + "st er", + "da r", + "w a", + "ge s", + "n a", + "f l", + "i gen", + "sch e", + "un gen", + "me hr", + "ß en", + "o t", + "k on", + "ge w", + "ha ben", + "ge h", + "ä t", + "s ind", + "d r", + "w el", + "un s", + "v o", + "m a", + "u te", + "sch on", + "b es", + "ge sch", + "b t", + "ch e", + "s on", + "o b", + "l a", + "p p", + "rü ck", + "s eine", + "k r", + "f re", + "ei l", + "zu m", + "u l", + "h ier", + "k t", + "i ge", + "sp r", + "k e", + "le ben", + "b st", + "z eit", + "i on", + "g ro", + "den n", + "h o", + "sch a", + "b ar", + "al le", + "ge gen", + "w ür", + "m ü", + "z e", + "wer den", + "je tzt", + "ko mmen", + "n ie", + "s ei", + "h eit", + "so ll", + "g lei", + "m eine", + "wo ll", + "n er", + "ha be", + "w ur", + "lich en", + "p er", + "as sen", + "n te", + "se hen", + "wir d", + "b is", + "g ar", + "i en", + "m us", + "u ß", + "ä r", + "st ell", + "k eit", + "z wei", + "sel bst", + "st a", + "p a", + "sa gte", + "te t", + "k am", + "s sen", + "v iel", + "u g", + "z en", + "h ei", + "m ann", + "wi ll", + "ge b", + "war en", + "ü ck", + "ä ch", + "m er", + "r u", + "w or", + "h au", + "ei gen", + "an g", + "we g", + "bl ick", + "f ra", + "all es", + "k a", + "au gen", + "f in", + "lich e", + "t o", + 
"un ser", + "der n", + "her r", + "n un", + "v ie", + "ch te", + "wo hl", + "f all", + "h t", + "ü n", + "et was", + "st and", + "en d", + "ä u", + "e m", + "m ö", + "te l", + "r ie", + "d ich", + "die s", + "h and", + "b in", + "ff en", + "nicht s", + "d an", + "p l", + "hn e", + "ihn en", + "es en", + "die ser", + "fr au", + "an t", + "ar t", + "di r", + "i sch", + "er st", + "glei ch", + "ko mm", + "h ör", + "ß e", + "d ig", + "se hr", + "z ei", + "sa m", + "au m", + "h ät", + "in gen", + "g ut", + "b o", + "m ut", + "ck en", + "kon nte", + "st imm", + "p ro", + "zu r", + "i tz", + "wei l", + "wür de", + "f ä", + "kön nen", + "k eine", + "f er", + "i schen", + "vo ll", + "ein es", + "se tz", + "z ie", + "de l", + "te te", + "sein er", + "ier en", + "ge st", + "zu rück", + "wur de", + "sch n", + "p r", + "lie ß", + "t ra", + "m ä", + "gen d", + "f ol", + "i k", + "schl a", + "scha ft", + "at er", + "wei ß", + "s einen", + "l assen", + "l u", + "und en", + "t eil", + "ne u", + "ier t", + "men schen", + "hm en", + "st r", + "g i", + "sa h", + "ihr en", + "el n", + "wei ter", + "ge hen", + "ig er", + "mach t", + "ta g", + "al so", + "hal ten", + "n is", + "ach t", + "ge ben", + "f or", + "o g", + "n at", + "m ar", + "de t", + "o hne", + "h aus", + "t ro", + "an ge", + "l au", + "sp iel", + "t re", + "sch r", + "in n", + "s u", + "l os", + "mach en", + "hät te", + "be g", + "wir k", + "al t", + "g lich", + "te s", + "r icht", + "fre und", + "m o", + "ihr er", + "f el", + "b el", + "so l", + "ein mal", + "e ben", + "h ol", + "h än", + "q u", + "ter n", + "h ö", + "sch w", + "re cht", + "wa hr", + "s einem", + "ste hen", + "hl en", + "in s", + "g ing", + "woll te", + "wi ssen", + "ung s", + "al d", + "as s", + "ja hr", + "m or", + "wel t", + "un der", + "zu sa", + "at ion", + "ko pf", + "lan g", + "hin ter", + "at z", + "st ra", + "an gen", + "an k", + "a de", + "gl au", + "f ach", + "hat ten", + "l o", + "f ort", + "ei cht", + "i ff", + "l er", + "m ei", + "diese m", + "k ein", + "f rei", + "fü hr", + "vo m", + "e s", + "e n", + "a i", + "o u", + "o n", + "l e", + "d e", + "r e", + "q u", + "a n", + "e r", + "en t", + "e t", + "l a", + "n e", + "i l", + "a r", + "i s", + "ai t", + "t e", + "a u", + "i n", + "qu e", + "i t", + "u r", + "s e", + "l es", + "c h", + "c e", + "m e", + "o r", + "ou r", + "a s", + "p r", + "a v", + "o m", + "ai s", + "u n", + "an t", + "ou s", + "t r", + "t i", + "l u", + "o i", + "e u", + "l le", + "s i", + "p ar", + "d es", + "an s", + "m ent", + "é t", + "es t", + "j e", + "u ne", + "a l", + "p as", + "t re", + "qu i", + "d u", + "r i", + "c on", + "s on", + "c om", + "e lle", + "d é", + "p our", + "d ans", + "l i", + "s a", + "r é", + "t ou", + "v ous", + "d i", + "v i", + "a g", + "a m", + "a t", + "ou v", + "a p", + "ti on", + "m on", + "s ur", + "c i", + "o s", + "p lu", + "s u", + "en d", + "a b", + "è re", + "ai n", + "m ais", + "o is", + "r es", + "plu s", + "é e", + "ai ent", + "m p", + "ch e", + "lu i", + "av e", + "ét ait", + "m a", + "s es", + "tou t", + "i r", + "v o", + "a c", + "s er", + "an d", + "f f", + "oi r", + "g r", + "av ait", + "é s", + "m es", + "n ous", + "eu x", + "b i", + "t er", + "c o", + "on s", + "p u", + "c es", + "g e", + "t u", + "le ur", + "pr o", + "d on", + "e ur", + "et te", + "ai re", + "ave c", + "d it", + "t é", + "i e", + "u s", + "il le", + "p er", + "com me", + "c r", + "or t", + "m i", + "e x", + "u x", + "v er", + "m o", + "è s", + "v e", + "au x", + "r a", + "j our", + "il s", + "bi en", + "c ou", + "p e", + "que l", 
+ "p eu", + "c ette", + "t es", + "p o", + "in s", + "c u", + "m ê", + "s o", + "f ait", + "g u", + "m ar", + "ê tre", + "l o", + "it é", + "f r", + "a tion", + "en s", + "b r", + "n i", + "l é", + "d is", + "b le", + "m an", + "n é", + "pu is", + "mê me", + "qu es", + "f i", + "e l", + "ag e", + "g ar", + "m oi", + "en ce", + "on t", + "m ain", + "or s", + "au t", + "an ce", + "v en", + "m é", + "s ans", + "e m", + "s é", + "l on", + "h om", + "r o", + "u t", + "c ar", + "ab le", + "i m", + "de r", + "ch er", + "n o", + "vi e", + "au s", + "b e", + "de ux", + "en f", + "o ù", + "t en", + "p h", + "u re", + "te mp", + "p os", + "r ent", + "p é", + "f aire", + "p i", + "tr es", + "ç a", + "an g", + "end re", + "f or", + "p a", + "b on", + "s ou", + "in t", + "pr é", + "s ent", + "t ant", + "n er", + "c er", + "l à", + "l ais", + "pr ès", + "b re", + "c our", + "p et", + "i on", + "i ne", + "com p", + "l ait", + "tr ouv", + "t a", + "ent re", + "son t", + "de v", + "n u", + "temp s", + "d ou", + "r ait", + "b ou", + "qu and", + "jour s", + "l an", + "er s", + "av oir", + "ét é", + "a le", + "p re", + "f ois", + "or te", + "v é", + "m er", + "n on", + "t ous", + "j us", + "cou p", + "t s", + "hom me", + "ê te", + "a d", + "aus si", + "ur s", + "se u", + "or d", + "o b", + "m in", + "g é", + "co re", + "v a", + "v re", + "en core", + "se m", + "i te", + "au tre", + "pr is", + "peu t", + "u e", + "an te", + "m al", + "g n", + "ré p", + "h u", + "si on", + "vo tre", + "di re", + "e z", + "f em", + "leur s", + "m et", + "f in", + "c ri", + "m is", + "t our", + "r ai", + "j am", + "re gar", + "ri en", + "ver s", + "su is", + "p ouv", + "o p", + "v is", + "gr and", + "ant s", + "c or", + "re r", + "ar d", + "c é", + "t ent", + "pr es", + "v ou", + "f a", + "al ors", + "si eur", + "ai ne", + "le r", + "qu oi", + "f on", + "end ant", + "ar ri", + "eu re", + "a près", + "don c", + "it u", + "l è", + "s ait", + "t oi", + "ch a", + "ai l", + "as se", + "i mp", + "vo y", + "con n", + "p la", + "pet it", + "av ant", + "n om", + "t in", + "don t", + "d a", + "s ous", + "e mp", + "per son", + "el les", + "be au", + "par ti", + "ch o", + "pr it", + "tou jours", + "m en", + "r ais", + "jam ais", + "tr av", + "tion s", + "tr ès", + "v oi", + "r en", + "y eux", + "f er", + "v oir", + "pre mi", + "c a", + "g ne", + "h eure", + "r ou", + "e ff", + "no tre", + "ment s", + "t on", + "f ais", + "ce la", + "i er", + "rép on", + "con s", + "ai r", + "ô t", + "p endant", + "i ci", + "tou te", + "j et", + "p ort", + "ét aient", + "p en", + "h é", + "au tres", + "p ère", + "o c", + "quel ques", + "i que", + "l is", + "fem me", + "j ou", + "te ur", + "mon de", + "u se", + "n es", + "d re", + "a ff", + "r ap", + "par t", + "le ment", + "c la", + "f ut", + "quel que", + "pr endre", + "r ê", + "ai lle", + "s ais", + "ch es", + "le t", + "ch ar", + "è res", + "ent s", + "b er", + "g er", + "mo ins", + "e au", + "a î", + "j eu", + "h eur", + "é es", + "tr i", + "po int", + "m om", + "v ent", + "n ouv", + "gr an", + "tr ois", + "s ant", + "tout es", + "con tre", + "è rent", + "che z", + "ave z", + "û t", + "a lle", + "at t", + "p au", + "p orte", + "ouv er", + "b ar", + "l it", + "f ort", + "o t", + "as s", + "pr és", + "cho se", + "v it", + "mon sieur", + "h ab", + "t ête", + "j u", + "te ment", + "c tion", + "v rai", + "la r", + "c et", + "regar d", + "l ant", + "de m", + "s om", + "mom ent", + "il les", + "p le", + "p s", + "b es", + "m ère", + "c l", + "s our", + "y s", + "tr op", + "en ne", + "jus qu", + "av aient", + "av 
ais", + "jeu ne", + "de puis", + "person ne", + "f it", + "cer t", + "j o", + "g es", + "ou i", + "r est", + "sem b", + "c ap", + "m at", + "m u", + "lon g", + "fr an", + "f aut", + "it i", + "b li", + "che v", + "pr i", + "ent e", + "ain si", + "ch am", + "l ors", + "c as", + "d o", + "il i", + "b é", + "n os", + "an ge", + "su i", + "r it", + "cr o", + "gu e", + "d e", + "e n", + "e s", + "o s", + "l a", + "e r", + "q u", + "a r", + "a n", + "o n", + "qu e", + "a s", + "o r", + "e l", + "d o", + "a l", + "c i", + "u n", + "r e", + "a b", + "i n", + "t e", + "t o", + "s e", + "d i", + "t r", + "d a", + "c on", + "t a", + "s u", + "m i", + "c o", + "t i", + "l e", + "l os", + "n o", + "l o", + "í a", + "c u", + "c a", + "s i", + "v i", + "m e", + "p or", + "m o", + "p ar", + "r a", + "r i", + "la s", + "c h", + "r o", + "m a", + "p er", + "ó n", + "m en", + "de s", + "un a", + "m p", + "s o", + "ab a", + "p u", + "d os", + "t u", + "g u", + "er a", + "de l", + "h a", + "m u", + "l i", + "en t", + "m b", + "h ab", + "es t", + "g o", + "p a", + "r es", + "par a", + "p o", + "á s", + "m os", + "tr a", + "t en", + "an do", + "p i", + "qu i", + "b i", + "m an", + "co mo", + "v e", + "m ás", + "j o", + "ci ón", + "i s", + "t an", + "v o", + "da d", + "c e", + "a do", + "v er", + "f u", + "ci a", + "c er", + "p e", + "c as", + "c ar", + "men te", + "n i", + "su s", + "t ar", + "n a", + "f i", + "t er", + "z a", + "p ro", + "tr o", + "s a", + "l u", + "b a", + "per o", + "s er", + "c es", + "d as", + "d u", + "s in", + "e mp", + "m ar", + "l la", + "e x", + "á n", + "c or", + "i a", + "v a", + "r an", + "ch o", + "g a", + "y o", + "t os", + "c os", + "mi s", + "l es", + "t es", + "v en", + "h o", + "y a", + "en te", + "on es", + "hab ía", + "n u", + "u s", + "p as", + "h i", + "n os", + "es ta", + "la n", + "m as", + "t or", + "l le", + "h e", + "s on", + "b re", + "p re", + "ab an", + "d or", + "í an", + "i r", + "t as", + "é n", + "r u", + "en do", + "a que", + "er o", + "i o", + "qu é", + "m in", + "c ab", + "j a", + "de r", + "t al", + "é s", + "se ñ", + "or a", + "to do", + "la r", + "d on", + "g ar", + "s al", + "p r", + "cu ando", + "j e", + "h u", + "g un", + "b u", + "g i", + "d ar", + "n e", + "r as", + "de n", + "es to", + "par e", + "p en", + "é l", + "tr as", + "c an", + "b o", + "j os", + "mi en", + "pu e", + "c re", + "co mp", + "p on", + "d ía", + "tr os", + "s ab", + "so bre", + "es e", + "mb re", + "er on", + "a ñ", + "m or", + "f or", + "i do", + "por que", + "el la", + "p ri", + "g ran", + "f a", + "c en", + "di s", + "c ri", + "mu y", + "ch a", + "c al", + "es te", + "h as", + "c ó", + "g ra", + "r os", + "p os", + "o b", + "al l", + "aque l", + "j u", + "p res", + "m er", + "di jo", + "c ía", + "ent re", + "z o", + "ci ones", + "bi en", + "mb i", + "el o", + "t ó", + "in a", + "to dos", + "g en", + "ti en", + "est aba", + "de ci", + "ci o", + "h er", + "ñ o", + "l or", + "nu es", + "me di", + "l en", + "vi da", + "f e", + "al i", + "m on", + "c la", + "d re", + "pu es", + "al es", + "vo l", + "m í", + "r ar", + "b le", + "ci on", + "has ta", + "señ or", + "con o", + "a h", + "di os", + "s en", + "es a", + "ú n", + "v ar", + "s an", + "gu i", + "a c", + "o tros", + "ta do", + "bu en", + "ñ a", + "ti emp", + "ha cer", + "j er", + "f er", + "v u", + "f in", + "an a", + "as í", + "an tes", + "t in", + "ve z", + "mien to", + "j ar", + "la b", + "ch e", + "cas a", + "d r", + "es o", + "e go", + "di ó", + "an te", + "est á", + "m al", + "en cia", + "el i", + "í as", + "tiemp o", + 
"z ar", + "v an", + "m un", + "er ta", + "ta mbi", + "s í", + "b ar", + "a un", + "al e", + "mis mo", + "ent es", + "vi s", + "man o", + "el e", + "na da", + "se gu", + "me j", + "er ra", + "ab le", + "b e", + "ti r", + "un o", + "don de", + "to da", + "des de", + "r en", + "tambi én", + "cu er", + "per son", + "ho mbre", + "o tro", + "li b", + "tr ar", + "cu al", + "ha y", + "a u", + "ca da", + "t aba", + "i mp", + "men to", + "ten ía", + "qu er", + "er an", + "si emp", + "siemp re", + "er to", + "qu í", + "g os", + "pu és", + "el los", + "des pués", + "nu e", + "g an", + "l lo", + "in ter", + "có mo", + "tr i", + "ah ora", + "us te", + "tr aba", + "la do", + "in o", + "po co", + "er te", + "mu jer", + "i m", + "qui er", + "al gun", + "fu e", + "o jos", + "ent on", + "v os", + "es per", + "mu ch", + "o tra", + "a z", + "a d", + "in g", + "e za", + "a quí", + "ci as", + "gu a", + "mu cho", + "deci r", + "es ti", + "i dad", + "al go", + "e z", + "o cu", + "enton ces", + "di do", + "ent os", + "g ri", + "da do", + "i os", + "so l", + "dos e", + "uste d", + "qui en", + "a mi", + "un to", + "f r", + "mi r", + "mej or", + "b as", + "so lo", + "pre gun", + "tu r", + "al g", + "p la", + "to das", + "par te", + "e mb", + "c to", + "mun do", + "tien e", + "tan te", + "pa lab", + "tr an", + "aque lla", + "ci os", + "aun que", + "a y", + "cu en", + "ten er", + "f un", + "res pon", + "all í", + "x i", + "h an", + "pen s", + "con tra", + "tu ra", + "v al", + "di o", + "tr es", + "t re", + "tan to", + "ca min", + "m ó", + "es p", + "a da", + "í o", + "in s", + "ha cia", + "de j", + "est ar", + "i ón", + "g as", + "b er", + "v as", + "no che", + "é r", + "añ os", + "pa dre", + "gu s", + "á r", + "sin o", + "man os", + "ci do", + "es tu", + "a de", + "hu bi", + "vi r", + "b ri", + "ra z", + "ch i", + "pue de", + "men os", + "hab i", + "ho mb", + "ne ces", + "ma y", + "er os", + "r ía", + "he cho", + "es cu", + "l ti", + "án do", + "b us", + "cos as", + "t ú", + "es pa", + "re ci", + "c tor", + "pri m", + "di a", + "de se", + "mien tras", + "h or", + "fu er", + "i da", + "pos i", + "lan te", + "t on", + "an o", + "est as", + "p li", + "ch ar", + "lu ego", + "si ón", + "ci n", + "ti erra", + "m es", + "gu ar", + "ca do", + "en con", + "pr en", + "may or", + "f al", + "e r", + "o n", + "a n", + "t o", + "d i", + "r e", + "l a", + "i n", + "e n", + "a l", + "t a", + "c h", + "e l", + "r i", + "c o", + "t i", + "t e", + "s i", + "r a", + "u n", + "l e", + "l i", + "ch e", + "r o", + "c i", + "c a", + "s e", + "q u", + "m a", + "p o", + "s o", + "i l", + "d o", + "e s", + "v a", + "p er", + "l o", + "c on", + "d el", + "p a", + "m o", + "s a", + "p i", + "d a", + "m i", + "g i", + "s u", + "d e", + "v i", + "z i", + "m e", + "g li", + "n o", + "m en", + "v o", + "t u", + "n on", + "v e", + "t to", + "s t", + "on e", + "an o", + "ch i", + "er a", + "er e", + "f a", + "c e", + "z a", + "un a", + "b i", + "p re", + "s ta", + "o r", + "a r", + "f i", + "on o", + "t ra", + "n a", + "n el", + "n e", + "p ro", + "t ro", + "al e", + "v er", + "n i", + "c u", + "t ti", + "men te", + "del la", + "t er", + "zi one", + "g u", + "p e", + "t ta", + "an do", + "t à", + "al i", + "u o", + "qu el", + "co m", + "s en", + "co me", + "b a", + "al la", + "p ri", + "d u", + "qu es", + "l u", + "on i", + "g gi", + "pa r", + "s si", + "v en", + "in a", + "g a", + "pi ù", + "ci a", + "i m", + "co r", + "m an", + "in o", + "in i", + "t en", + "r an", + "b b", + "g o", + "s to", + "t re", + "a ve", + "a v", + "s ono", + "er i", + "a c", 
+ "s se", + "er o", + "h a", + "s c", + "su l", + "f or", + "v ano", + "po r", + "s ti", + "su o", + "c chi", + "t an", + "z za", + "an che", + "p u", + "i o", + "t te", + "vo l", + "es s", + "s ci", + "co l", + "r u", + "p en", + "f u", + "al l", + "s so", + "s te", + "se m", + "s sa", + "d en", + "a d", + "t ri", + "de i", + "in e", + "ave va", + "men to", + "z z", + "a mo", + "g no", + "f o", + "un o", + "su a", + "g en", + "ri a", + "g e", + "st ra", + "s ì", + "c er", + "ch é", + "b u", + "a p", + "c en", + "d al", + "on a", + "s pe", + "g ni", + "b o", + "t t", + "del le", + "ques to", + "nel la", + "f f", + "d ere", + "an no", + "del l", + "un i", + "bb e", + "an ti", + "g ra", + "s p", + "en e", + "gi o", + "u to", + "qu al", + "gli a", + "qu ando", + "tu tto", + "c an", + "gli o", + "zi oni", + "ca m", + "h o", + "es so", + "s s", + "mo l", + "a t", + "lo ro", + "per ché", + "co sa", + "du e", + "po i", + "ca r", + "s co", + "ci o", + "to r", + "c co", + "c re", + "a m", + "g na", + "te m", + "pri ma", + "lu i", + "co sì", + "qu e", + "gu ar", + "ess ere", + "an i", + "con o", + "b ra", + "al le", + "m on", + "ri o", + "an co", + "cu i", + "s pi", + "vi a", + "g ran", + "gi or", + "a i", + "bi le", + "u l", + "ggi o", + "f e", + "an te", + "ma i", + "ta re", + "in ter", + "in di", + "re bbe", + "sen za", + "so lo", + "zi o", + "e d", + "en te", + "tu tti", + "sta to", + "zi a", + "d alla", + "tu ra", + "mi a", + "vi ta", + "quel la", + "qu a", + "ma r", + "do ve", + "g h", + "al lo", + "sem pre", + "zz o", + "si a", + "mo r", + "do po", + "por ta", + "d re", + "c cia", + "er ano", + "an ni", + "di o", + "chi a", + "en za", + "pro pri", + "qu i", + "m u", + "m b", + "an da", + "c ca", + "o cchi", + "ques ta", + "f fi", + "le i", + "par te", + "d on", + "r on", + "mi o", + "tan to", + "ri s", + "o gni", + "di s", + "r in", + "fa r", + "men ti", + "t el", + "anco ra", + "f ra", + "fa tto", + "man i", + "sen ti", + "p ra", + "tem po", + "es si", + "b bi", + "f in", + "a re", + "la re", + "per s", + "f on", + "b el", + "so r", + "d er", + "pre n", + "an za", + "di re", + "pi e", + "o ra", + "ver so", + "se gu", + "al tro", + "ta to", + "ca to", + "a to", + "vol ta", + "c c", + "fa re", + "pa re", + "ci ò", + "li b", + "bi li", + "n uo", + "s er", + "quel lo", + "co lo", + "p po", + "ca sa", + "tro va", + "o re", + "f er", + "r ono", + "d es", + "mol to", + "al mente", + "s ca", + "vo le", + "t ali", + "sul la", + "s ce", + "men o", + "an to", + "p un", + "s tu", + "ca pi", + "so l", + "gi u", + "m ini", + "m ano", + "z e", + "pi a", + "par ti", + "s al", + "la vo", + "ver o", + "r si", + "al tri", + "es ti", + "s cia", + "suo i", + "gli e", + "so tto", + "b ene", + "sc ri", + "t ale", + "de gli", + "n u", + "al c", + "uo mo", + "p el", + "f re", + "po te", + "es sa", + "s cu", + "si gno", + "el e", + "st ro", + "u ti", + "di a", + "si one", + "g re", + "f ini", + "ar ri", + "l un", + "c ri", + "e si", + "pa ssa", + "r à", + "men tre", + "an d", + "h anno", + "el o", + "u sci", + "gi a", + "gi à", + "di e", + "m ina", + "b e", + "ti ca", + "gior no", + "t in", + "es se", + "mo do", + "c al", + "s pa", + "propri o", + "l en", + "o ri", + "con tro", + "st ru", + "di ven", + "di sse", + "ra to", + "no i", + "v ere", + "pu ò", + "di ce", + "s an", + "es a", + "c ci", + "se con", + "re n", + "c cio", + "qual che", + "tu tta", + "g g", + "mon do", + "for ma", + "p li", + "m ma", + "pen sa", + "de va", + "tu r", + "fo sse", + "so pra", + "ta mente", + "n ess", + "qu anto", + "ra ga", + "un 
que", + "ca re", + "st re", + "gran de", + "pi cco", + "guar da", + "b en", + "nel l", + "a ff", + "po ssi", + "pre sen", + "r ò", + "pa ro", + "tu a", + "v in", + "an e", + "a s", + "ste sso", + "da v", + "ne i", + "nel le", + "gh i", + "pi o", + "ta r", + "an a", + "la to", + "si d", + "f ine", + "f uo", + "m er", + "z o", + "qua si", + "ul ti", + "i to", + "su e", + "si e", + "f il", + "allo ra", + "m in", + "ven i", + "t ano", + "el lo", + "d e", + "r a", + "e s", + "d o", + "e n", + "q u", + "c o", + "a s", + "o s", + "e r", + "a r", + "s e", + "qu e", + "a n", + "i n", + "i s", + "t o", + "ã o", + "t e", + "d a", + "m a", + "e l", + "t a", + "o r", + "i a", + "r e", + "e m", + "a l", + "co m", + "p a", + "o u", + "c a", + "u m", + "r o", + "v a", + "t i", + "s o", + "m en", + "n ão", + "h a", + "co n", + "m e", + "r i", + "pa ra", + "p o", + "d i", + "s a", + "v o", + "u ma", + "c i", + "n a", + "p or", + "n o", + "g u", + "s u", + "h o", + "an do", + "t ra", + "e i", + "v i", + "e u", + "i m", + "do s", + "el e", + "r es", + "m o", + "en t", + "f i", + "l a", + "e ra", + "l e", + "de s", + "el a", + "men te", + "l h", + "p er", + "l i", + "ç ão", + "m as", + "t er", + "m u", + "es t", + "v e", + "g o", + "l o", + "u s", + "ma is", + "v er", + "c ê", + "in ha", + "vo cê", + "f a", + "t u", + "c u", + "p ar", + "com o", + "p ro", + "s i", + "m os", + "e c", + "p re", + "d as", + "ç a", + "es ta", + "s er", + "u n", + "da de", + "d is", + "f o", + "e x", + "c h", + "i r", + "ra n", + "t ar", + "en te", + "g a", + "t r", + "p e", + "t os", + "b o", + "c ia", + "p en", + "c ar", + "s en", + "su a", + "se m", + "c as", + "f or", + "to u", + "n os", + "te m", + "r ia", + "m es", + "se u", + "co r", + "o n", + "a o", + "p os", + "ra m", + "v el", + "é m", + "t en", + "po de", + "t es", + "esta va", + "c e", + "b a", + "qu ando", + "m i", + "qu er", + "men to", + "se gu", + "t as", + "is so", + "mu i", + "g ar", + "t ro", + "d u", + "fa z", + "õ es", + "p es", + "an to", + "l u", + "p i", + "i x", + "ve z", + "s im", + "j a", + "p r", + "m in", + "b e", + "ra s", + "m an", + "p res", + "est á", + "c er", + "b re", + "p as", + "d ia", + "m b", + "dis se", + "n i", + "r os", + "es se", + "v ia", + "o lh", + "is a", + "an te", + "ê n", + "z a", + "qu i", + "b i", + "t inha", + "me u", + "s ão", + "m inha", + "a c", + "ri o", + "m ar", + "a t", + "p el", + "mui to", + "ta l", + "to r", + "fo i", + "h or", + "j o", + "b em", + "g i", + "f al", + "vo l", + "po n", + "di z", + "l ar", + "gu n", + "m or", + "r u", + "par ec", + "ç o", + "do r", + "pes so", + "n e", + "f er", + "b er", + "p u", + "po is", + "in a", + "es p", + "d ar", + "en do", + "de n", + "so bre", + "co s", + "p ri", + "al i", + "mes mo", + "ç ões", + "g ra", + "se us", + "me i", + "b ra", + "vi da", + "an tes", + "b ri", + "at é", + "ên cia", + "lh e", + "ti v", + "m ã", + "al g", + "qu anto", + "s ó", + "g os", + "de r", + "t ão", + "tu do", + "ent ão", + "r ou", + "es s", + "in da", + "b al", + "in do", + "ci o", + "n do", + "j á", + "va m", + "re i", + "l es", + "ei to", + "v is", + "tem po", + "de pois", + "c ha", + "m el", + "ch e", + "l ha", + "a inda", + "faz er", + "con tra", + "p ou", + "per gun", + "de ix", + "ta mb", + "ra r", + "al a", + "v en", + "t in", + "pel o", + "tamb ém", + "fi ca", + "pre c", + "el es", + "tra n", + "ha via", + "l á", + "to dos", + "j u", + "qu al", + "c an", + "ta do", + "cas a", + "es sa", + "n as", + "g em", + "m em", + "se i", + "na da", + "sen ti", + "c ri", + "ó s", + "de u", + "ei ro", 
+ ". .", + "f un", + "as sim", + "s ou", + "ent re", + "com e", + "i or", + "h ar", + "f e", + "por que", + "s or", + "f in", + "ta mente", + "a qui", + "cu l", + "t ó", + "for ma", + "s ar", + "ou tra", + "olh os", + "i ma", + "m im", + "a go", + "in s", + "co u", + "g ran", + "v al", + "pesso as", + "era m", + "ei ra", + "a que", + "com p", + "de i", + "p ela", + "co isa", + "m ão", + "con h", + "ca da", + "ago ra", + "ia m", + "h á", + "con s", + "su as", + "gu ém", + "o b", + "l an", + "es ti", + "á s", + "la do", + "in ter", + "ca be", + "por ta", + "n em", + "í vel", + "r is", + "j e", + "n un", + "sem pre", + "con segu", + "h as", + "tra bal", + "f u", + "le v", + "l em", + "l as", + "va i", + "tr os", + "t ante", + "te i", + "pr ó", + "que m", + "tu ra", + "on de", + "cabe ça", + "nun ca", + "men tos", + "h um", + "de le", + "ver dade", + "t á", + "h os", + "el i", + "ent es", + "m er", + "alg um", + "diz er", + "s in", + "pen as", + "n ós", + "en quanto", + "ou tro", + "l ho", + "es te", + "mel hor", + "est ar", + "g an", + "b ar", + "pri mei", + "a u", + "i u", + "pen sa", + "a penas", + "p ra", + "es tou", + "con te", + "res pon", + "ho mem", + "do is", + "a do", + "c al", + "a b", + "l os", + "ç as", + "pou co", + "sen hor", + "t ando", + "esp era", + "pa i", + "ri os", + "no i", + "i da", + "ba ix", + "as e", + "is as", + "f r", + "ho ra", + "mu ndo", + "pas sa", + "fi car", + "to do", + "se ja", + "al mente", + "â n", + "c lar", + "a d", + "in c", + "f os", + "lo n", + "g ri", + "ou vi", + "v em", + "g e", + "ta va", + "á rio", + "mo n", + "s os", + "in ho", + "ma l", + "t an", + "t re", + "gran de", + "ran do", + "b u", + "v ou", + "ê s", + "co isas", + "a conte", + "lh er", + "g en", + "ci on", + "an os", + "i do", + "tal vez", + "est ão", + "li v", + "sa b", + "su r", + "ou tros", + "c re", + "qual quer", + "g ou", + "t ri", + "l í", + "tiv esse", + "ra do", + "prec isa", + "mã e", + "su s", + "t anto", + "de la", + "men os", + "s al", + "en tra", + "p é", + "ma ior", + "noi te", + "ti va", + "p ala", + "so n", + "ra ção", + "de us", + "s as", + "un i", + "l or", + "u l", + "in te", + "f ei", + "an o", + "par ti", + "pala v", + "tr ás", + "par te", + "b el", + "ci dade", + "lu gar", + "v os", + "vez es", + "do u", + "en contra", + "tr u", + "e ci", + "a r", + "e r", + "a n", + "e n", + "i n", + "i r", + "o r", + "d e", + "a k", + "ı n", + "a l", + "d i", + "d a", + "b u", + "b ir", + "y or", + "i l", + "e k", + "y a", + "m a", + "l a", + "e l", + "u n", + "k a", + "l ar", + "i m", + "d ı", + "e t", + "o n", + "d u", + "o l", + "e y", + "t ı", + "m i", + "h a", + "b a", + "l er", + "ü n", + "m ı", + "i z", + "l e", + "ı r", + "m e", + "i s", + "n e", + "o k", + "t a", + "s a", + "u m", + "r a", + "g ö", + "i k", + "s ı", + "d en", + "e s", + "b il", + "t i", + "l ı", + "ü z", + "i ç", + "ü r", + "g i", + "u r", + "t e", + "b en", + "d an", + "i y", + "ı m", + "u z", + "v e", + "c ak", + "a y", + "c e", + "i ş", + "ın ı", + "i yor", + "ba ş", + "d ü", + "a t", + "a m", + "g el", + "de ğ", + "k ar", + "i ̇", + "m u", + "e v", + "ö y", + "bu n", + "v ar", + "ya p", + "s en", + "an a", + "s un", + "in i", + "gö r", + "y ı", + "k i", + "l i", + "ar a", + "al ı", + "on u", + "ç ı", + "ş ey", + "s ın", + "k ı", + "ka d", + "s e", + "t an", + "a ğ", + "değ il", + "s in", + "ü k", + "a z", + "ç ok", + "s on", + "ş ı", + "b i", + "ü l", + "t u", + "v er", + "iç in", + "g e", + "k en", + "ey e", + "ol du", + "mı ş", + "y e", + "k al", + "m ek", + "l an", + "öy le", + "yor du", + "er 
i", + "y üz", + "mi ş", + "b e", + "m ak", + "o la", + "in e", + "y an", + "h er", + "c ek", + "yor um", + "b ak", + "ü m", + "ö n", + "lar ı", + "o ğ", + "d er", + "kad ar", + "h al", + "ar ı", + "s t", + "s an", + "ın da", + "du r", + "g ün", + "v a", + "y ok", + "y er", + "dı m", + "k o", + "da ha", + "l u", + "ın a", + "di m", + "e m", + "bil ir", + "ik i", + "s iz", + "s i", + "n a", + "di ğ", + "s u", + "b ü", + "ha y", + "s or", + "dü ş", + "ü ç", + "un u", + "ö r", + "d ir", + "m ü", + "c a", + "am an", + "f ak", + "a da", + "e de", + "son ra", + "h iç", + "ak i", + "ğ ı", + "bu l", + "r u", + "ma z", + "an la", + "bu ra", + "ge ç", + "ma ya", + "l en", + "k onu", + "c i", + "c u", + "d in", + "t ek", + "z aman", + "el er", + "ö z", + "dı r", + "gi bi", + "o t", + "ş a", + "g er", + "ler i", + "k im", + "k u", + "fak at", + "y ar", + "gö z", + "c ı", + "yor sun", + "b ek", + "in de", + "r o", + "p ek", + "bun u", + "l ik", + "m an", + "il er", + "e di", + "ö l", + "s ür", + "b in", + "s ır", + "çı k", + "sı l", + "al ar", + "k es", + "y ak", + "ç ek", + "yı l", + "e cek", + "ı z", + "gi t", + "ka p", + "a ma", + "ı l", + "lar ın", + "b iz", + "tı r", + "o y", + "an cak", + "d oğ", + "ç a", + "b ana", + "ş im", + "baş la", + "l ü", + "ma dı", + "ben i", + "t ir", + "y ük", + "lı k", + "be ş", + "b el", + "b er", + "m er", + "na sıl", + "tı k", + "k e", + "t ür", + "a v", + ". .", + "d aki", + "p ar", + "t er", + "ce ğ", + "t en", + "z ı", + "iy i", + "d ok", + "ben im", + "c ağ", + "n er", + "y en", + "ş u", + "me z", + "düş ün", + "ken di", + "şim di", + "y ol", + "y u", + "de v", + "is te", + "s ek", + "ma m", + "s öyle", + "di k", + "t o", + "k ur", + "oldu ğ", + "s ını", + "t ar", + "bil iyor", + "k an", + "y al", + "m eye", + "mu ş", + "f a", + "ka ç", + "bil e", + "iy e", + "t ü", + "e f", + "tı m", + "ev et", + "ç o", + "y et", + "g en", + "bura da", + "t im", + "bir az", + "es i", + "k or", + "doğ ru", + "in in", + "kı z", + "di ye", + "d ör", + "et ti", + "on un", + "is ti", + "ğ i", + "h e", + "s ana", + "ü ş", + "ar ka", + "hay ır", + "kar şı", + "h ar", + "il e", + "h ak", + "ı yor", + "ne den", + "s ev", + "sı z", + "ço cu", + "me m", + "ç alı", + "ol ur", + "b ır", + "g ir", + "is e", + "i h", + "c an", + "k ır", + "d ön", + "b öyle", + "sen i", + "! 
\"", + "al t", + "dör t", + "s öy", + "o ş", + "mu sun", + "la ş", + "h an", + "i p", + "ka y", + "h em", + "bü yük", + "a ç", + "bır ak", + "mi sin", + "s öz", + "u l", + "değ iş", + "ün ü", + "g ül", + "k ö", + "kar ı", + "ta mam", + "ol u", + "r ar", + "yen i", + "la m", + "mış tı", + "ya ş", + "al a", + "in iz", + "kad ın", + "bun un", + "m ey", + "al tı", + "y i", + "s o", + "in den", + "sen in", + "ya t", + "to p", + "s er", + "is i", + "d ün", + "s es", + "hiç bir", + "y on", + "d ın", + "t ün", + "baş ka", + "a s", + "he p", + "i t", + "ir mi", + "dev am", + "ola cak", + "ar tık", + "r e", + "dur um", + "im iz", + "üz el", + "ler ini", + "sa ğ", + "p ro", + "ger ek", + "y irmi", + "ş ek", + "ba ğ", + "me di", + "lar a", + "a h", + "t ur", + "y ür", + "ma sı", + "ka tı", + "de di", + "g ü", + "sor un", + "el i", + "ün e", + "mı z", + "yap ı", + "m il", + "ğ ını", + "t ara", + "m en", + "ha t", + "var dı", + "m et", + "konu ş", + "ar ak", + "lar ak", + "çocu k", + "bü tün", + "l ey", + "d ür", + "g üzel", + "ay ı", + "yap a", + "n ı", + "ay r", + "ö ne", + "yordu m", + "b an", + "i̇ ş", + "du m", + "un a", + "on a", + "yor lar", + "lar ını", + "çı kar", + "z an", + "se ç", + "l iyor", + "t ak", + "şı k", + "tek rar", + "a ş", + "e ş", + "miş ti", + "f ar", + "k in", + "im i", + "i f", + "e ğ", + "gi di", + "le ş", + "başla dı", + "gi de", + "ot ur", + "d de", + "ın dan", + "üz er", + "ın ın", + "n ız", + "u y", + "ye di", + "ka t", + "o larak", + "la dı", + "yal nız", + "ba h", + "iy et", + "m al", + "s ak", + "a çık", + "sın da", + ".. .", + "in san", + "ay nı", + "e der", + "is tan", + "uz un", + "sa h", + "d o", + "g eri", + "er ek", + "ol an", + "ger çek", + "f en", + "al an", + "dı ş", + "alı k", + "far k", + "ü st", + "sa de", + "r i", + "k iş", + "l dı", + "z or", + "et ir", + "her kes", + "s al", + "ö mer", + "s el", + "un da", + "ha f", + "bun a", + "y dı", + "pek i", + "ada m", + "ha z", + "sın a", + "kap ı", + "gör üş", + "sade ce", + "al dı", + "gel di", + "i e", + "n ie", + "n a", + "r z", + "s z", + "c z", + "p o", + "s t", + "c h", + "i ę", + "d z", + "n i", + "a ł", + "r a", + "j e", + "r o", + "d o", + "s ię", + "z a", + "g o", + "e m", + "w i", + "c i", + "rz e", + "k o", + "l e", + "l i", + "w a", + "t o", + "k a", + "m i", + "ż e", + "t a", + "w ie", + "b y", + "m o", + "w y", + "rz y", + "ł a", + "j a", + "n o", + "ł o", + "w o", + "p a", + "m a", + "t e", + "t y", + "n y", + "k i", + "d a", + "n e", + "dz ie", + "dz i", + "cz y", + "c ie", + "m y", + "p rze", + "d y", + "o d", + "l a", + "k ie", + "r y", + "st a", + "j ą", + "ó w", + "c e", + "p rzy", + "c o", + "k u", + "m ie", + "sz y", + "cz e", + "r e", + "b a", + "s i", + "b ie", + "m u", + "w e", + "c y", + "ni a", + "ś ci", + "sz e", + "je st", + "k t", + "s a", + "b o", + "t u", + "ż y", + "n ą", + "b i", + "r u", + "a le", + "kt ó", + "p ra", + "ał a", + "m nie", + "p ie", + "ł y", + "cz a", + "ja k", + "ro z", + "r ó", + "l u", + "z na", + "g a", + "ra z", + "ł u", + "ta k", + "j u", + "p i", + "ś ć", + "s o", + "wi a", + "m ó", + "ch o", + "w szy", + "p e", + "s po", + "c a", + "g dy", + "w ał", + "w ię", + "d e", + "b e", + "p ro", + "ł em", + "j ę", + "s k", + "z e", + "l o", + "g i", + "r ę", + "do b", + "d u", + "ju ż", + "st o", + "b ę", + "ał em", + "sz a", + "m e", + "po d", + "d la", + "pa n", + "n ę", + "z o", + "mo że", + "ś li", + "s ie", + "ał o", + "t em", + "l ko", + "ny ch", + "po wie", + "c ię", + "s u", + "ty lko", + "i n", + "b u", + "na j", + "ch a", + "te go", + "p u", + "s ki", + 
"ne go", + "wszy st", + "sz cze", + "je d", + "je j", + "t wo", + "ą d", + "ś my", + "cz ę", + "wa ć", + "je go", + "ż a", + "i m", + "s y", + "pra w", + "ty m", + "któ ry", + "ał y", + "t rze", + "nie j", + "s e", + "ny m", + "i ch", + "o b", + ". .", + "g ło", + "ją c", + "mó wi", + "s ka", + "o n", + "ne j", + "s łu", + "w ła", + "bę dzie", + "d ę", + "p ó", + "be z", + "ni c", + "p ła", + "ś cie", + "mi a", + "s ą", + "t rzy", + "kie m", + "by ł", + "mo g", + "ro bi", + "ta m", + "c u", + "te n", + "m ię", + "z y", + "pe w", + "ci a", + "my ś", + "prze d", + "s ko", + "n u", + "któ re", + "a l", + "l ę", + "w sze", + "ą c", + "by ło", + "so bie", + "p y", + "ci ą", + "ba r", + "je szcze", + "h a", + "t ę", + "b ra", + "cza s", + "sz ę", + "g ł", + "k ę", + "ma r", + "cz u", + "prze z", + "f i", + "s ło", + "w z", + "k to", + "k ów", + "cz o", + "li śmy", + "st ra", + "wię c", + "r ą", + "ma m", + "w ó", + "rz a", + "g ro", + "no ści", + "f a", + "we t", + "ną ł", + "ś mie", + "na wet", + "mu si", + "s wo", + "te j", + "w ą", + "w u", + "wi ą", + "ni u", + "cz ą", + "b li", + "dz o", + "s kie", + "n em", + "je śli", + "cze go", + "ch y", + "d ł", + "ty ch", + "by m", + "ż o", + "e ś", + "si ą", + "kie dy", + "na s", + "w ró", + "dz e", + "d ro", + "t ra", + "r ów", + "pa ni", + "z ie", + "ku l", + "na d", + "ch wi", + "ni m", + "t ro", + "by ć", + "cho dzi", + "ni o", + "dob rze", + "te raz", + "wo kul", + "co ś", + "k ł", + "pie r", + "h e", + "g dzie", + "dz y", + "p ię", + "d ź", + "k ą", + "g ó", + "z da", + "ch ce", + "st ę", + "o r", + "ś wia", + "wszyst ko", + "st ro", + "pe ł", + "wie m", + "wie l", + "ka ż", + "ki m", + "rz u", + "s ły", + "jed na", + "z u", + "myś l", + "mó j", + "g u", + "wa r", + "jest em", + "ó ż", + "mie j", + "mo ż", + "k ła", + "re sz", + "d łu", + "st wo", + "n ię", + "ma sz", + "że by", + "nie m", + "ja kie", + "st y", + "ni ą", + "we j", + "o j", + "g ra", + "s ła", + "no ść", + "z ło", + "sz czę", + ".. 
.", + "r i", + "le j", + "we go", + "c ał", + "dzi ał", + "ki ch", + "dz a", + "dz ię", + "o czy", + "zo sta", + "cz ło", + "na m", + "ki l", + "o na", + "sz u", + "w ę", + "pa r", + "mi ał", + "st rze", + "ce j", + "e j", + "zna j", + "da ć", + "miej s", + "k ró", + "k ry", + "bar dzo", + "si a", + "z i", + "ś nie", + "l ą", + "g ie", + "cie bie", + "d ni", + "st u", + "po trze", + "wokul ski", + "u wa", + "u mie", + "jedna k", + "k ra", + "wró ci", + "czło wie", + "czy ć", + "by ła", + "że li", + "m ę", + "c ę", + "z robi", + "mog ę", + "pro wa", + "r em", + "nie ch", + "cz nie", + "k ro", + "t ą", + "ch ci", + "b ro", + "dzie ć", + "sz ą", + "pa d", + "t rz", + "t ru", + "je m", + "a ni", + "t ów", + "a r", + "d ru", + "ta j", + "rze kł", + "sa m", + "st e", + "nie go", + "ta kie", + "w ała", + "to wa", + "ka pła", + "wi dzi", + "po dob", + "dz ę", + "t ał", + "stę p", + "b ą", + "po ko", + "w em", + "g ę", + "a by", + "g e", + "al bo", + "s pra", + "z no", + "de n", + "s mo", + "je sz", + "k się", + "jest eś", + "po z", + "ni gdy", + "k sią", + "c óż", + "w s", + "po w", + "t ka", + "ś wie", + "sz ka", + "sa mo", + "s ł", + "rz ę", + "na le", + "chce sz", + "ni k", + "p ę", + "chy ba", + "cią g", + "ją cy", + "wo j", + "na sze", + "mnie j", + "wię cej", + "z wy", + "o sta", + "f e", + "wa ż", + "h o", + "se r", + "śmie r", + "wie r", + "dz ą", + "za ś", + "gdy by", + "ja ki", + "wo l", + "wi n", + "d ą", + "ści a", + "roz ma", + "wa l", + "pa nie", + "sta r", + "ka z", + "je żeli", + "d em", + "w ra", + "ko ń", + "sie bie", + "zno wu", + "p ró", + "cz em", + "st wa", + "i sto", + "pó ł", + "d ał", + "ko bie", + "ała m", + "wy ch", + "ce sa", + "ni ch", + "za wsze", + "dzi ć", + "te ż", + "le pie", + "pro szę", + "k re", + "t wa", + "o t", + "ł ą", + "ch u", + "c ą", + "p rz", + "ł e", + "sze dł", + "od powie", + "my śli", + "ś wią", + "e n", + "e r", + "d e", + "a n", + "e t", + "i j", + "i n", + "e l", + "a a", + "s t", + "o r", + "g e", + "i s", + "a t", + "i e", + "c h", + "o n", + "e en", + "h et", + "i t", + "v er", + "aa r", + "a l", + "o or", + "g en", + "v an", + "o p", + "d en", + "h e", + "o m", + "t e", + "w e", + "i k", + "r e", + "z e", + "ij n", + "d at", + "b e", + "d er", + "in g", + "o e", + "ij k", + "a an", + "ch t", + "v oor", + "l e", + "i et", + "r o", + "m o", + "k en", + "z ijn", + "m en", + "i g", + "j e", + "n iet", + "a r", + "o o", + "i d", + "u n", + "i l", + "s ch", + "mo et", + "st e", + "u r", + "o l", + "he b", + "u it", + "g el", + "w ij", + "a s", + "m e", + "t en", + "w or", + "o u", + "v en", + "l en", + "aa t", + "d it", + "m et", + "r a", + "b en", + "s p", + "o ver", + "d ie", + "n o", + "w er", + "l ijk", + "f t", + "s l", + "an d", + "v e", + "t er", + "i er", + "i en", + "t o", + "d aar", + "g r", + "b el", + "de ze", + "d u", + "a g", + "k an", + "wor den", + "in gen", + "moet en", + "n en", + "on der", + "heb ben", + "r u", + "oo k", + "s en", + "c t", + "k t", + "no g", + "aa l", + "w as", + "u l", + "e er", + "b ij", + "m ijn", + "p ro", + "v ol", + "d o", + "k om", + "at ie", + "e ft", + "k el", + "al s", + "r ij", + "he id", + "a f", + "st el", + "m aar", + "a p", + "we e", + "a d", + "he eft", + "w aar", + "i cht", + "d an", + "er en", + "n e", + "w el", + "w at", + "w il", + "a cht", + "aa g", + "ge b", + "c on", + "z o", + "k e", + "b et", + "h ij", + "d ig", + "k un", + "u w", + "d t", + "d oor", + "t ij", + "a m", + "an g", + "on d", + "er s", + "is ch", + "ge en", + "i ge", + "ge v", + "ve el", + "n u", + "m a", + "on s", + "o f", 
+ "b l", + "n aar", + "g ro", + "p l", + "an der", + "at en", + "kun nen", + "e cht", + "h ier", + "g oe", + "an t", + "u s", + "t wee", + "on t", + "de lijk", + "el e", + "u ur", + "al le", + "t oe", + "me er", + "i st", + "n a", + "n ie", + "on ze", + "l o", + "i m", + "p en", + "h ad", + "tij d", + "h oe", + "to t", + "z ou", + "a k", + "aa k", + "a men", + "d r", + "w oor", + "s e", + "wor dt", + "o t", + "gel ijk", + "g aan", + "i c", + "g er", + "k er", + "el d", + "e m", + "h ou", + "de l", + "z en", + "z el", + "te gen", + "b o", + "kom en", + "c om", + "i gen", + "e it", + "wer k", + "goe d", + "z al", + "z ij", + "sl ag", + "e s", + "z ien", + "a st", + "echt er", + "it ie", + "t ie", + "el ijk", + "m is", + "isch e", + "bel an", + "h aar", + "i ch", + "b er", + "h an", + "v r", + "al e", + "c i", + "gr ijk", + "in d", + "do en", + "l and", + "belan grijk", + "p un", + "op en", + "ct ie", + "zel f", + "m ij", + "it eit", + "ste m", + "me e", + "ar en", + "al l", + "b r", + "re cht", + "d ien", + "h u", + "g aat", + "pro b", + "m oe", + "p er", + "a u", + "ul len", + "z ich", + "daar om", + "or m", + "k l", + "v o", + "en t", + "st aat", + "z it", + "du i", + "n at", + "du s", + "d s", + "ver slag", + "kel ijk", + "prob le", + "w et", + "ge m", + "c r", + "i on", + "p r", + "sch ap", + "g d", + "h un", + "z a", + "er d", + "z et", + "st aan", + "st r", + "m aal", + "in der", + "e id", + "st en", + "p ar", + "k ken", + "ge d", + "z ullen", + "re s", + "men sen", + "j aar", + "re gel", + "ie der", + "vol gen", + "ge ven", + "e ven", + "l u", + "bl ij", + "i ë", + "k o", + "u we", + "m an", + "ma ken", + "l ie", + "g a", + "oe k", + "nie uwe", + "b aar", + "h o", + "h er", + "in ter", + "ander e", + "ru ik", + "s u", + "a gen", + "or t", + "m er", + "ou w", + "st er", + "wil len", + "aa kt", + "h oo", + "an den", + "f f", + "l ig", + "t re", + "s amen", + "ze er", + "dui delijk", + "ant woor", + "he el", + "men t", + "pun t", + "hou den", + "we g", + "vr aag", + "gel e", + "een s", + "be sch", + "om en", + "er g", + "do el", + "d ag", + "sp e", + "ur en", + "ing s", + "or en", + "l ang", + "de len", + "m ar", + "ste un", + "in nen", + "p ol", + "o on", + "i de", + "s n", + "s ie", + "r icht", + "z onder", + "no dig", + "all een", + "m id", + "ra gen", + "iet s", + "ver sch", + "geb ruik", + "st u", + "ro uw", + "stel len", + "be g", + "men ten", + "v in", + "eer ste", + "l aat", + "gro ot", + "oo d", + "to ch", + "l aten", + "aar d", + "s le", + "de el", + "st and", + "pl aat", + "re e", + "bet re", + "d i", + "l id", + "uit en", + "ra cht", + "bel eid", + "g et", + "ar t", + "st ie", + "st aten", + "g gen", + "re ken", + "e in", + "al en", + "m ing", + "mo gelijk", + "gro te", + "al tijd", + "z or", + "en kel", + "w ik", + "pol itie", + "e igen", + "el k", + "han del", + "g t", + "k we", + "m aat", + "el en", + "i p", + "v rij", + "s om", + "je s", + "aa m", + "hu is", + "v al", + "we er", + "lid staten", + "k ing", + "k le", + "be d", + "gev al", + "stel l", + "a i", + "wik kel", + "kwe stie", + "t al", + "ste e", + "a b", + "h el", + "kom st", + "p as", + "s s", + "it u", + "i den", + "eer d", + "m in", + "c e", + "p o", + "twee de", + "proble em", + "w aren", + "us sen", + "sn el", + "t ig", + "ge w", + "j u", + "ul t", + "ne men", + "com mis", + "versch il", + "k on", + "z oek", + "k rij", + "gr aag", + "den k", + "l anden", + "re den", + "be sl", + "oe g", + "bet er", + "he den", + "m ag", + "p e", + "bo ven", + "a c", + "con t", + "f d", + "h ele", + "k r", + "v ier", + "w 
in", + "ge z", + "k w", + "m il", + "v or", + "he m", + "ra m", + "aa s", + "ont wikkel", + "dr ie", + "v aak", + "plaat s", + "l a", + "g ang", + "ij f", + "f in", + "nat uur", + "t ussen", + "u g", + "in e", + "d a", + "b at", + "kom t", + "w acht", + "aa d", + "u t", + "é n", + "acht er", + "geb ie", + "ver k", + "lig t", + "c es", + "nie uw", + "van d", + "s t", + "n í", + "j e", + "p o", + "c h", + "r o", + "n a", + "s e", + "t o", + "n e", + "l e", + "k o", + "l a", + "d o", + "r a", + "n o", + "t e", + "h o", + "n ě", + "v a", + "l i", + "l o", + "ř e", + "c e", + "d e", + "v e", + "b y", + "n i", + "s k", + "t a", + "n á", + "z a", + "p ro", + "v o", + "v ě", + "m e", + "v á", + "s o", + "k a", + "r á", + "v y", + "z e", + "m i", + "p a", + "t i", + "st a", + "m ě", + "n é", + "ř i", + "ř í", + "m o", + "ž e", + "m a", + "j í", + "v ý", + "j i", + "d ě", + "r e", + "d a", + "k u", + "j a", + "c i", + "r u", + "č e", + "o b", + "t ě", + "m u", + "k y", + "d i", + "š e", + "k é", + "š í", + "t u", + "v i", + "p ře", + "v í", + "s i", + "n ý", + "o d", + "so u", + "v é", + "n y", + "r i", + "d y", + "b u", + "b o", + "t y", + "l á", + "l u", + "n u", + "ž i", + "m á", + "st i", + "c í", + "z á", + "p ra", + "sk é", + "m í", + "c o", + "d u", + "d á", + "by l", + "st o", + "s a", + "t í", + "je d", + "p ří", + "p ři", + "t é", + "s í", + "č i", + "v ní", + "č a", + "d í", + "z i", + "st u", + "p e", + "b a", + "d ní", + "ro z", + "va l", + "l í", + "s po", + "k á", + "b e", + "p i", + "no u", + "ta k", + "st e", + "r y", + "l é", + "vě t", + "se m", + "p ě", + "ko n", + "ne j", + "l y", + "ko u", + "ý ch", + "b ě", + "p r", + "f i", + "p rá", + "a le", + "ja ko", + "po d", + "ž í", + "z í", + "j sou", + "j sem", + "ch o", + "l ní", + "c ké", + "t á", + "m y", + "a k", + "h u", + "va t", + "pře d", + "h la", + "k e", + "st á", + "č í", + "š i", + "s le", + "k la", + "š tě", + "lo u", + "m ů", + "z na", + "ch á", + "o r", + "p ů", + "h a", + "b i", + "ta ké", + "d ů", + "no st", + "t ře", + "te r", + "p u", + "i n", + "v r", + "ve l", + "sk u", + "v še", + "t ní", + "do b", + "by la", + "č ní", + "ja k", + "v u", + "je ho", + "b ý", + "vá ní", + "ný ch", + "po u", + "te n", + "t ři", + "v z", + "st ře", + "d va", + "h le", + "č á", + "no sti", + "c k", + "v š", + "vo u", + "s u", + "h e", + "h ra", + "je n", + "s y", + "da l", + "po z", + "s lo", + "te l", + "d ru", + "de n", + "vš ak", + "g i", + "k dy", + "by lo", + "bu de", + "st ra", + "j ší", + "m é", + "me n", + "vý ch", + "ní m", + "s m", + "ko li", + "r ů", + "t ra", + "mů že", + "ne ní", + "ho d", + "b í", + "do u", + "sk a", + "t ý", + "st ě", + "u je", + "s á", + "pě t", + "ne s", + "k rá", + "to m", + "st ví", + "v ně", + "se d", + "s vé", + "p í", + "z o", + "mu sí", + "u ž", + "tí m", + "jí cí", + "jed no", + "t r", + "ča s", + "e v", + "č ty", + "sk ý", + "ni c", + "ev ro", + "to ho", + "h y", + "k ter", + "r ní", + "st í", + "s vě", + "pa k", + "vše ch", + "k ů", + "n g", + "á d", + "chá zí", + "a ni", + "a r", + "jed na", + "bý t", + "t ro", + "k ra", + "pr vní", + "m no", + "ské ho", + "p á", + "p la", + "le m", + "ne bo", + "ke m", + "st ro", + "s la", + "né ho", + "z de", + "dal ší", + "ř a", + "čty ři", + "h rá", + "dru h", + "l ně", + "v la", + "sk ých", + "š ko", + "pů so", + "pro to", + "v ů", + "sk á", + "ve n", + "še st", + "d ně", + "je ště", + "me zi", + "te k", + "s ko", + "ch a", + "ně koli", + "be z", + "g ra", + "ji ž", + "č ně", + "j á", + "s lu", + "z ná", + "ve r", + "sed m", + "k ro", + "ta m", + "a 
no", + "v lá", + "o sm", + "byl y", + "vá m", + "ck ý", + "te ch", + "dě ji", + "vel mi", + "le ži", + "va la", + "l ý", + "t vo", + "spo le", + "ch u", + "stu p", + "mo ž", + "evro p", + "g e", + "sta l", + "j de", + "ch y", + "ro di", + "je jí", + "po li", + "de vět", + "s me", + "a ž", + "té to", + "re m", + "d é", + "f or", + "u ni", + "f o", + "ten to", + "a u", + "ka ž", + "nu la", + "na d", + "by ch", + "mo c", + "sto u", + "e x", + "le n", + "k do", + "z d", + "pra co", + "to mu", + "ný m", + "ži vo", + "ze m", + "f e", + "f u", + "ná sle", + "j o", + "sk y", + "ji ch", + "h á", + "mě l", + "dě la", + "j sme", + "p re", + "ni ce", + "ste j", + "ne m", + "st ní", + "he m", + "ná ro", + "z u", + "b li", + "ni t", + "pa r", + "a l", + "poz ději", + "ta ko", + "n ce", + "če r", + "ší m", + "ně co", + "vá l", + "ře j", + "krá t", + "á lní", + "u r", + ". .", + "a si", + "kter é", + "sta v", + "ma jí", + "my s", + "do bě", + "s ně", + "ce n", + "z y", + "z ku", + "t ů", + "ch od", + "s pě", + "je jich", + "sou čas", + "d r", + "va li", + "ri e", + "k te", + "pr ů", + "ze ní", + "pa t", + "a n", + "po tře", + "de m", + "d nes", + "ze mí", + "sa mo", + "zna m", + "b ra", + "má m", + "te dy", + "g o", + "hla vní", + "pou ží", + "b ní", + "ve de", + "le p", + "je k", + "pra v", + "poli ti", + "d ne", + "je m", + "le t", + "če ní", + "pro b", + "ne ž", + "dě l", + "fi l", + "č o", + "cí ch", + "st é", + "d lou", + "h i", + "a by", + "to u", + "několi k", + "d la", + "vy u", + "vi t", + "ho u", + "ck ých", + "no vé", + "či n", + "st y", + "dě lá", + "k ý", + "ob la", + "pod le", + "ra n", + "dů leži", + "ta to", + "po ku", + "ko ne", + "d ý", + "d vě", + "ž ád", + "nou t", + "t ku", + "t vr", + "cké ho", + "ro v", + "r é", + "te le", + "p sa", + "s vět", + "ti vní", + "do sta", + "te m", + "še l", + "druh é", + "s kou", + "ž o", + "jed ná", + "vý znam", + "prob lé", + "pu bli", + "vá n", + "od po", + "pod po", + "d le", + "ja ké", + "še ní", + "ví m", + "bě hem", + "na chází", + "s lou", + "pou ze", + "o tá", + "p lo", + "to vé", + "vět ši", + "ko mi", + "va jí", + "ty to", + "zá pa", + "z mě", + "mo h", + "ví ce", + "spole č", + "au to", + "pro ti", + "st ru", + "dě t", + "chá ze", + "že l", + "с т", + "е н", + "н о", + "н а", + "п р", + "т о", + "п о", + "р а", + "г о", + "к о", + "н е", + "в о", + "в а", + "е т", + "е р", + "н и", + "е л", + "и т", + "н ы", + "з а", + "р о", + "ен и", + "к а", + "л и", + "е м", + "д а", + "о б", + "л а", + "д о", + "с я", + "т ь", + "о т", + "л о", + "л ь", + "е д", + "с о", + "м и", + "р е", + "м о", + "ц и", + "пр о", + "т а", + "э то", + "к и", + "р у", + "пр и", + "т и", + "с е", + "ст а", + "в ы", + "м ы", + "в и", + "б ы", + "м а", + "е с", + "л я", + "ст и", + "л е", + "ч то", + "м е", + "р и", + "ч а", + "о д", + "е й", + "ел ь", + "ени я", + "г а", + "н у", + "с и", + "п а", + "ра з", + "б о", + "ст о", + "с у", + "с а", + "д у", + "е го", + "е ст", + "и н", + "ит ь", + "и з", + "ж е", + "м у", + "п ер", + "по д", + "ени е", + "с ь", + "к у", + "пр ед", + "но го", + "ны х", + "в ер", + "т е", + "но й", + "ци и", + "д е", + "р ы", + "д ел", + "л ю", + "в е", + "о н", + "м ен", + "г и", + "н я", + "б у", + "пр а", + "в се", + "ет ся", + "ст ь", + "ж а", + "до л", + "ж и", + "б е", + "ко н", + "с л", + "ш и", + "д и", + "ст в", + "с ко", + "ны е", + "ч и", + "ю т", + "д ер", + "ст ра", + "т ы", + "х од", + "щ и", + "з о", + "з на", + "но сти", + "ч ес", + "в ля", + "ва ть", + "о р", + "по л", + "в ет", + "та к", + "ш а", + "т у", + "с во", + "пр е", + 
"о на", + "ит ель", + "ны й", + "с ло", + "ка к", + "в л", + "но сть", + "х о", + "мо ж", + "п е", + "д ля", + "ни я", + "но е", + "ра с", + "дол ж", + "да р", + "т ель", + "с ка", + "п у", + "ст во", + "ко то", + "ра б", + "е е", + "ро д", + "э ти", + "с об", + "о ру", + "ж ен", + "ны м", + "ит и", + "ни е", + "ко м", + "д ет", + "ст у", + "г у", + "п и", + "ме ж", + "ени ю", + "т ер", + "раб от", + "во з", + "ци я", + "ко й", + "щ ест", + "г ра", + "з и", + "р я", + "меж ду", + "ст ва", + "в с", + "ел о", + "ш е", + "м ер", + "б а", + "з ы", + "л у", + "а ль", + "д ей", + "г ла", + "на род", + "к ти", + "пред ста", + "л ся", + "я вля", + "с ки", + "но в", + "ед ин", + "ро в", + "и с", + "ни ма", + "р ем", + "ход и", + "так же", + "д ру", + "а ть", + "сл ед", + "го во", + "на я", + "ю щи", + "ен ь", + "кото ры", + "х от", + "в у", + "и х", + "ем у", + "ч ит", + "ва ж", + "ор га", + "чес ки", + "щ е", + "к е", + "х а", + "по с", + "то м", + "бо ль", + "м не", + "па с", + "об ъ", + "пра в", + "кон ф", + "сл у", + "под дер", + "ст ви", + "на ш", + "ль ко", + "сто я", + "ну ю", + "л ем", + "ен ных", + "к ра", + "д ы", + "между народ", + "г да", + "не об", + "го су", + "ств у", + "ени и", + "госу дар", + "к то", + "и м", + "ч ест", + "р ет", + "во про", + "л ен", + "ел и", + "ро ва", + "ци й", + "на м", + "это й", + "ж ения", + "необ ходи", + "мен я", + "бы ло", + "си ли", + "ф и", + "в я", + "ш ь", + "это го", + "о ни", + "орга ни", + "бе зо", + "пр об", + "и ме", + "ре ш", + "б и", + "безо пас", + "ют ся", + "о ста", + "ен но", + "го д", + "ел а", + "предста в", + "ть ся", + "сло во", + "органи за", + "долж ны", + "это м", + "б ла", + "ч е", + "ч у", + "бла го", + "это му", + "в рем", + "с пе", + "но м", + "ени й", + "с по", + "на с", + "не т", + "з у", + "в ед", + "е ще", + "ска за", + "се й", + "ер ен", + "да н", + "са м", + "ел я", + "ра н", + "зы ва", + "явля ется", + "бу дет", + "кти в", + "т ре", + "дел е", + "м от", + "конф ерен", + "ла сь", + "ча с", + "сто ро", + "ко го", + "е з", + "не й", + "о с", + "ли сь", + "раз ору", + "пер е", + "с си", + "ны ми", + "про ц", + "го ло", + "ч ело", + "бо ле", + "чело ве", + "с ер", + "п л", + "ч ет", + "стра н", + "п я", + "бы л", + "к ла", + "то в", + "ж д", + "дел а", + "е ра", + "у же", + "со вет", + "г ен", + "безопас ности", + "ц а", + "се да", + "по з", + "от вет", + "проб лем", + "на ко", + "т ем", + "до ста", + "п ы", + "щ а", + "во й", + "су щест", + "необходи мо", + "бы ть", + "мож ет", + "д ем", + "что бы", + "е к", + "ч ер", + "у сили", + "ре с", + "ру д", + "един енных", + "д об", + "до сти", + "ств ен", + "я дер", + "год ня", + "ка за", + "се годня", + "сей час", + "то лько", + "во д", + "ес ь", + "м ного", + "бу ду", + "е в", + "ест ь", + "т ри", + "об щест", + ". 
.", + "я вл", + "вы сту", + "р ед", + "с чит", + "с ит", + "деле га", + "ло ж", + "это т", + "ф ор", + "к лю", + "воз мож", + "ва ния", + "б ли", + "и ли", + "в з", + "на ций", + "ско го", + "при ня", + "п ла", + "о ч", + "ить ся", + "ст е", + "на ши", + "которы е", + "а р", + "име ет", + "с от", + "зна ч", + "пер ь", + "след у", + "ен ы", + "та ки", + "объ единенных", + "ст ро", + "те перь", + "б ле", + "благо дар", + "раз в", + "а н", + "жи ва", + "оч ень", + "я т", + "бе з", + "об ес", + "г ро", + "ло сь", + "с ы", + "организа ции", + "ч лен", + "то го", + "она ль", + "ж да", + "все х", + "с вя", + "боле е", + "со в", + "ко гда", + "во т", + "к ре", + "к ры", + "по этому", + "во ль", + "о й", + "ген ера", + "ч ем", + "л ы", + "пол ити", + "в ен", + "конферен ции", + "проц ес", + "б я", + "ит е", + "от но", + "разв ити", + "а ф", + "ю щ", + "в но", + "ми р", + "ни и", + "ка я", + "а с", + "итель но", + "в то", + "ени ем", + "генера ль", + "пр от", + "вс ем", + "сам бле", + "ас самбле", + "о м", + "з д", + "с мот", + "ре ги", + "ч его", + "од нако", + "усили я", + "дей стви", + "ч но", + "у ча", + "об раз", + "во с", + "э та", + "пер его", + "гово р", + "ва м", + "мо ло", + "врем я", + "д ь", + "хот ел", + "г ру", + "за явл", + "пре доста", + "по ль", + "не е", + "ре зо", + "перего во", + "резо лю", + "к рет", + "поддер ж", + "обес пе", + "не го", + "представ ит", + "на де", + "к ри", + "ч ь", + "про ек", + "л ет", + "дру ги", + "ا ل", + "َ ا", + "و َ", + "ّ َ", + "ِ ي", + "أ َ", + "ل َ", + "ن َ", + "ال ْ", + "ه ُ", + "ُ و", + "م ا", + "ن ْ", + "م ن", + "ع َ", + "ن ا", + "ل ا", + "م َ", + "ت َ", + "ف َ", + "أ ن", + "ل ي", + "م ِ", + "ا ن", + "ف ي", + "ر َ", + "ي َ", + "ه ِ", + "م ْ", + "ق َ", + "ب ِ", + "ل ى", + "ي ن", + "إ ِ", + "ل ِ", + "و ا", + "ك َ", + "ه ا", + "ً ا", + "م ُ", + "و ن", + "ال م", + "ب َ", + "ي ا", + "ذ ا", + "س ا", + "ال ل", + "م ي", + "ي ْ", + "ر ا", + "ر ي", + "ل ك", + "م َا", + "ن َّ", + "ل م", + "إ ن", + "س ت", + "و م", + "ّ َا", + "ل َا", + "ه م", + "ّ ِ", + "ك ُ", + "ك ان", + "س َ", + "ب ا", + "د ي", + "ح َ", + "ع ْ", + "ب ي", + "ال أ", + "و ل", + "ف ِي", + "ر ِ", + "د ا", + "مِ نْ", + "ُو نَ", + "و ْ", + "ه َا", + "ّ ُ", + "ال س", + "ال َ", + "ن ي", + "ل ْ", + "ت ُ", + "ه ل", + "ر ة", + "د َ", + "س ْ", + "ت ِ", + "ن َا", + "ر ْ", + "الل َّ", + "سا مي", + "ك ن", + "ك ل", + "ه َ", + "عَ لَ", + "ع لى", + "م ع", + "إ لى", + "ق د", + "ال ر", + "ُو ا", + "ي ر", + "ع ن", + "ي ُ", + "ن ِ", + "ب ْ", + "ال ح", + "هُ مْ", + "ق ا", + "ذ ه", + "ال ت", + "ِي نَ", + "ج َ", + "ه ذا", + "ع د", + "ال ع", + "د ْ", + "قَ الَ", + "ر ُ", + "ي م", + "ي ة", + "ن ُ", + "خ َ", + "ر ب", + "ال ك", + "و َا", + "أ نا", + "ة ِ", + "ال ن", + "ح د", + "ع ِ", + "ت ا", + "ه و", + "ف ا", + "ع ا", + "ال ش", + "ل ُ", + "ي ت", + "ذ َا", + "ي ع", + "ال ذ", + "ح ْ", + "ال ص", + "إِ نَّ", + "ج ا", + "ع لي", + "ك َا", + "ب ُ", + "ت ع", + "و ق", + "م ل", + "ل َّ", + "ي د", + "أ خ", + "ر ف", + "ت ي", + "ال ِ", + "ّ ا", + "ذ لك", + "أَ نْ", + "س ِ", + "ت وم", + "م ر", + "مَ نْ", + "ب ل", + "ال ق", + "الل ه", + "ِي َ", + "ك م", + "ذ َ", + "ع ل", + "ح ب", + "س ي", + "ع ُ", + "ال ج", + "ال د", + "ش َ", + "ت ك", + "ف ْ", + "ص َ", + "ل ل", + "د ِ", + "ب ر", + "ف ِ", + "ت ه", + "أ ع", + "ت ْ", + "ق ْ", + "الْ أَ", + "ئ ِ", + "عَ نْ", + "و ر", + "ح ا", + "ال َّ", + "م ت", + "ف ر", + "د ُ", + "ه نا", + "وَ أَ", + "ت ب", + "ة ُ", + "أ ي", + "س ب", + "ري د", + "و ج", + "كُ مْ", + "ح ِ", + "ك ْ", + "د ر", + "َا ء", + "ه ذه", + "ال ط", + "الْ مُ", + "د ة", + "ق ل", + "غ َ", + "ي وم", + "الَّ ذ", + "ك ر", + "ت ر", + "ك 
ِ", + "ك ي", + "عَلَ ى", + "رَ ب", + "ع ة", + "ق ُ", + "ج ْ", + "ف ض", + "ل ة", + "ه ْ", + "ر َا", + "وَ لَ", + "الْ مَ", + "أَ نَّ", + "ي َا", + "أ ُ", + "ش ي", + "اللَّ هُ", + "لَ ى", + "ق ِ", + "أ ت", + "عَلَ يْ", + "اللَّ هِ", + "ال ب", + "ض َ", + "ة ً", + "ق ي", + "ا ر", + "ب د", + "خ ْ", + "سْ تَ", + "ط َ", + "قَ دْ", + "ذه ب", + "أ م", + "ما ذا", + "وَ إِ", + "ة ٌ", + "و نَ", + "لي لى", + "و لا", + "ح ُ", + "ه ي", + "ص ل", + "ال خ", + "و د", + "لي س", + "ل دي", + "ق ال", + "كَا نَ", + "م َّ", + "ح ي", + "ت م", + "ل ن", + "وَ لَا", + "ب ع", + "يم كن", + "س ُ", + "ة َ", + "ح ت", + "ر ًا", + "ك ا", + "ش ا", + "هِ مْ", + "لَ هُ", + "ز َ", + "دا ً", + "م س", + "ك ث", + "الْ عَ", + "ج ِ", + "ص ْ", + "ف َا", + "ل ه", + "و ي", + "ع َا", + "هُ وَ", + "ب ِي", + "ب َا", + "أ س", + "ث َ", + "ل ِي", + "ر ض", + "الر َّ", + "لِ كَ", + "ت َّ", + "ف ُ", + "ق ة", + "ف عل", + "مِ ن", + "ال آ", + "ث ُ", + "س م", + "م َّا", + "بِ هِ", + "ت ق", + "خ ر", + "ل قد", + "خ ل", + "ش ر", + "أن ت", + "ل َّا", + "س ن", + "الس َّ", + "الذ ي", + "س َا", + "و ما", + "ز ل", + "و ب", + "أ ْ", + "إ ذا", + "ر ِي", + "ح ة", + "ن ِي", + "الْ حَ", + "وَ قَالَ", + "ب ه", + "ة ٍ", + "س أ", + "ر ٌ", + "ب ال", + "م ة", + "ش ْ", + "و ت", + "عن د", + "ف س", + "بَ عْ", + "ه ر", + "ق ط", + "أ ح", + "إن ه", + "و ع", + "ف ت", + "غ ا", + "هنا ك", + "ب ت", + "مِ نَ", + "س ر", + "ذَ لِكَ", + "ر س", + "حد ث", + "غ ْ", + "ّ ِي", + "ال إ", + "وَ يَ", + "ج ل", + "ا ست", + "ق ِي", + "ع ب", + "و س", + "ي ش", + "الَّذ ِينَ", + "تا ب", + "د ِي", + "ج ب", + "ك ون", + "ب ن", + "ال ث", + "لَ يْ", + "ب عد", + "وَ الْ", + "فَ أَ", + "ع م", + "هُ م", + "ت ن", + "ذ ْ", + "أ ص", + "أ ين", + "رَب ِّ", + "الذ ين", + "إِ ن", + "ب ين", + "ج ُ", + "عَلَيْ هِ", + "ح َا", + "ل و", + "ست ط", + "ظ ر", + "لَ مْ", + "ء ِ", + "كُ ل", + "ط ل", + "ت َا", + "ض ُ", + "كن ت", + "ل ًا", + "م ٌ", + "ق بل", + "ـ ـ", + "ذ ِ", + "قَ وْ", + "ص ِ", + "م ًا", + "كان ت", + "ص ا", + "ي ق", + "ال ف", + "ال نا", + "م ٍ", + "إِ نْ", + "ال نَّ", + "ج د", + "وَ مَا", + "ت ت", + "ب ح", + "م كان", + "كي ف", + "ّ ة", + "ال ا", + "ج َا", + "أ و", + "سا عد", + "ض ِ", + "إ لا", + "را ً", + "ق َا", + "ر أ", + "ع ت", + "أ حد", + "ه د", + "ض ا", + "ط ر", + "أ ق", + "ما ء", + "د َّ", + "ال با", + "م ُو", + "أَ وْ", + "ط ا", + "ق ُو", + "خ ِ", + "ت ل", + "ستط يع", + "د َا", + "الن َّا", + "إ لَى", + "وَ تَ", + "هَ ذَا", + "ب ة", + "علي ك", + "ج ر", + "ال من", + "ز ا", + "ر ٍ", + "د ع", + "ّ ًا", + "س ة", + "ثُ مَّ", + "شي ء", + "ال غ", + "ت ح", + "ر ُونَ", + "ال يوم", + "م ِي", + "ن ُوا", + "أ ر", + "تُ مْ", + "ع ر", + "ي ف", + "أ ب", + "د ًا", + "ص َا", + "الت َّ", + "أ ريد", + "ال ز", + "يَ وْ", + "إ لي", + "ج ي", + "يَ عْ", + "فض ل", + "ال إن", + "أن ه", + "n g", + "i 4", + "a n", + "s h", + "z h", + "i 2", + "ng 1", + "u 4", + "i 1", + "ng 2", + "d e", + "j i", + "a o", + "x i", + "u 3", + "de 5", + "e 4", + "i 3", + "ng 4", + "an 4", + "e n", + "u o", + "sh i4", + "an 2", + "u 2", + "c h", + "u 1", + "ng 3", + "a 1", + "an 1", + "e 2", + "a 4", + "e i4", + "o ng1", + "a i4", + "ao 4", + "h u", + "a ng1", + "l i", + "y o", + "an 3", + "w ei4", + "uo 2", + "n 1", + "en 2", + "ao 3", + "e 1", + "y u", + "q i", + "e ng2", + "zh o", + "a ng3", + "a ng4", + "a ng2", + "uo 4", + "m i", + "g e4", + "y i1", + "g uo2", + "e r", + "b i", + "a 3", + "h e2", + "e 3", + "y i2", + "d i4", + "zh ong1", + "b u4", + "g u", + "a i2", + "n 2", + "z ai4", + "sh i2", + "e ng1", + "r en2", + "o ng2", + "xi an4", + "y i", + "x u", + "n 4", + "l i4", + "en 4", + "y u2", + "e i2", + "yi2 ge4", + "o u4", + "e i3", 
+ "d i", + "u i4", + "a 2", + "yo u3", + "ao 1", + "d a4", + "ch eng2", + "en 1", + "e ng4", + "y i4", + "s i1", + "zh i4", + "ji a1", + "yu an2", + "n i", + "t a1", + "de5 yi2ge4", + "k e1", + "sh u3", + "x i1", + "j i2", + "ao 2", + "t i", + "o u3", + "o ng4", + "xi a4", + "a i1", + "g ong1", + "zh i1", + "en 3", + "w ei2", + "j u", + "xu e2", + "q u1", + "zho u1", + "er 3", + "mi ng2", + "zho ng3", + "l i3", + "w u4", + "y i3", + "uo 1", + "e 5", + "j i4", + "xi ng2", + "ji an4", + "hu a4", + "y u3", + "uo 3", + "j i1", + "a i3", + "z uo4", + "h ou4", + "hu i4", + "e i1", + "ni an2", + "q i2", + "p i", + "d ao4", + "sh eng1", + "de 2", + "d ai4", + "u an2", + "zh e4", + "zh eng4", + "b en3", + "sh ang4", + "zh u3", + "b ei4", + "y e4", + "ch u1", + "zh an4", + "l e5", + "l ai2", + "sh i3", + "n an2", + "r en4", + "yo u2", + "k e4", + "b a1", + "f u4", + "d ui4", + "y a4", + "m ei3", + "z i4", + "xi n1", + "ji ng1", + "zh u", + "n 3", + "yo ng4", + "m u4", + "ji ao4", + "y e3", + "ji n4", + "bi an4", + "l u4", + "q i1", + "sh e4", + "xi ang1", + "o ng3", + "sh u4", + "d ong4", + "s uo3", + "gu an1", + "s an1", + "b o", + "t e4", + "d uo1", + "f u2", + "mi n2", + "l a1", + "zh i2", + "zh en4", + "o u1", + "w u3", + "m a3", + "i 5", + "z i5", + "j u4", + "er 4", + "y ao4", + "xia4 de5yi2ge4", + "s i4", + "t u2", + "sh an1", + "z ui4", + "ch u", + "yi n1", + "er 2", + "t ong2", + "d ong1", + "y u4", + "y an2", + "qi an2", + "shu3 xia4de5yi2ge4", + "ju n1", + "k e3", + "w en2", + "f a3", + "l uo2", + "zh u4", + "x i4", + "k ou3", + "b ei3", + "ji an1", + "f a1", + "di an4", + "ji ang1", + "wei4 yu2", + "xi ang4", + "zh i3", + "e ng3", + "f ang1", + "l an2", + "sh u", + "r i4", + "li an2", + "sh ou3", + "m o", + "qi u2", + "ji n1", + "h uo4", + "shu3xia4de5yi2ge4 zhong3", + "f en1", + "n ei4", + "g ai1", + "mei3 guo2", + "u n2", + "g e2", + "b ao3", + "qi ng1", + "g ao1", + "t ai2", + "d u", + "xi ao3", + "ji e2", + "ti an1", + "ch ang2", + "q uan2", + "li e4", + "h ai3", + "f ei1", + "t i3", + "ju e2", + "o u2", + "c i3", + "z u2", + "n i2", + "bi ao3", + "zhong1 guo2", + "d u4", + "yu e4", + "xi ng4", + "sh eng4", + "ch e1", + "d an1", + "ji e1", + "li n2", + "pi ng2", + "f u3", + "g u3", + "ji e4", + "w o", + "v 3", + "sh eng3", + "n a4", + "yu an4", + "zh ang3", + "gu an3", + "d ao3", + "z u3", + "di ng4", + "di an3", + "c eng2", + "ren2 kou3", + "t ai4", + "t ong1", + "g uo4", + "n eng2", + "ch ang3", + "hu a2", + "li u2", + "yi ng1", + "xi ao4", + "c i4", + "bian4 hua4", + "li ang3", + "g ong4", + "zho ng4", + "de5 yi1", + "s e4", + "k ai1", + "w ang2", + "ji u4", + "sh i1", + "sh ou4", + "m ei2", + "k u", + "s u", + "f eng1", + "z e2", + "tu2 shi4", + "t i2", + "q i4", + "ji u3", + "sh en1", + "zh e3", + "ren2kou3 bian4hua4", + "ren2kou3bian4hua4 tu2shi4", + "di4 qu1", + "y ang2", + "m en", + "men 5", + "l ong2", + "bi ng4", + "ch an3", + "zh u1", + "w ei3", + "w ai4", + "xi ng1", + "bo 1", + "b i3", + "t ang2", + "hu a1", + "bo 2", + "shu i3", + "sh u1", + "d ou1", + "s ai4", + "ch ao2", + "b i4", + "li ng2", + "l ei4", + "da4 xue2", + "f en4", + "shu3 de5", + "m u3", + "ji ao1", + "d ang1", + "ch eng1", + "t ong3", + "n v3", + "q i3", + "y an3", + "mi an4", + "l uo4", + "ji ng4", + "g e1", + "r u4", + "d an4", + "ri4 ben3", + "p u3", + "yu n4", + "hu ang2", + "wo 3", + "l v", + "h ai2", + "shi4 yi1", + "xi e1", + "yi ng3", + "w u2", + "sh en2", + "w ang3", + "gu ang3", + "li u4", + "s u4", + "shi4 zhen4", + "c an1", + "c ao3", + "xi a2", + "k a3", + "d a2", + "h u4", + "b an4", + "d 
ang3", + "h u2", + "z ong3", + "de ng3", + "de5yi2ge4 shi4zhen4", + "ch uan2", + "mo 4", + "zh ang1", + "b an1", + "mo 2", + "ch a2", + "c e4", + "zhu3 yao4", + "t ou2", + "j u2", + "shi4 wei4yu2", + "s a4", + "u n1", + "ke3 yi3", + "d u1", + "h an4", + "li ang4", + "sh a1", + "ji a3", + "z i1", + "lv 4", + "f u1", + "xi an1", + "x u4", + "gu ang1", + "m eng2", + "b ao4", + "yo u4", + "r ong2", + "zhi1 yi1", + "w ei1", + "m ao2", + "guo2 jia1", + "c ong2", + "g ou4", + "ti e3", + "zh en1", + "d u2", + "bi an1", + "c i2", + "q u3", + "f an4", + "xi ang3", + "m en2", + "j u1", + "h ong2", + "z i3", + "ta1 men5", + "ji 3", + "z ong1", + "zhou1 de5yi2ge4shi4zhen4", + "t uan2", + "ji ng3", + "gong1 si1", + "xi e4", + "l i2", + "li4 shi3", + "b ao1", + "g ang3", + "gu i1", + "zh eng1", + "zhi2 wu4", + "ta1 de5", + "pi n3", + "zhu an1", + "ch ong2", + "shi3 yong4", + "w a3", + "sh uo1", + "chu an1", + "l ei2", + "w an1", + "h uo2", + "q u", + "s u1", + "z ao3", + "g ai3", + "q u4", + "g u4", + "l u", + "x i2", + "h ang2", + "yi ng4", + "c un1", + "g en1", + "yi ng2", + "ti ng2", + "cheng2 shi4", + "ji ang3", + "li ng3", + "l un2", + "bu4 fen4", + "de ng1", + "xu an3", + "dong4 wu4", + "de2 guo2", + "xi an3", + "f an3", + "zh e5", + "h an2", + "h ao4", + "m i4", + "r an2", + "qi n1", + "ti ao2", + "zh an3", + "h i", + "k a", + "n o", + "t e", + "s u", + "s hi", + "t a", + "t o", + "n a", + "w a", + "o u", + "r u", + "n i", + "k u", + "k i", + "g a", + "d e", + "k o", + "m a", + "r e", + "r a", + "m o", + "t su", + "w o", + "e n", + "r i", + "s a", + "d a", + "s e", + "j i", + "h a", + "c hi", + "k e", + "te ki", + "m i", + "y ou", + "s h", + "s o", + "y o", + "y a", + "na i", + "t te", + "a ru", + "b a", + "u u", + "t ta", + "ka i", + "ka n", + "shi te", + "m e", + "d o", + "mo no", + "se i", + "r o", + "ko to", + "ka ra", + "shi ta", + "b u", + "m u", + "c h", + "su ru", + "k ou", + "g o", + "ma su", + "ta i", + "f u", + "k en", + "i u", + "g en", + "wa re", + "shi n", + "z u", + "a i", + "o n", + "o ku", + "g i", + "d ou", + "n e", + "y uu", + "i ru", + "i te", + "ji ko", + "de su", + "j u", + "ra re", + "sh u", + "b e", + "sh ou", + "s ha", + "se kai", + "s ou", + "k you", + "ma shita", + "s en", + "na ra", + "sa n", + "ke i", + "i ta", + "a ri", + "i tsu", + "ko no", + "j ou", + "na ka", + "ch ou", + "so re", + "g u", + "na ru", + "ga ku", + "re ba", + "g e", + "h o", + "i n", + "hi to", + "sa i", + "na n", + "da i", + "tsu ku", + "shi ki", + "sa re", + "na ku", + "p p", + "bu n", + "ju n", + "so no", + "ka ku", + "z ai", + "b i", + "to u", + "wa ta", + "sh uu", + "i i", + "te i", + "ka re", + "y u", + "shi i", + "ma de", + "sh o", + "a n", + "ke reba", + "shi ka", + "i chi", + "ha n", + "de ki", + "ni n", + "ware ware", + "na kereba", + "o ite", + "h ou", + "ya ku", + "ra i", + "mu jun", + "l e", + "yo ku", + "bu tsu", + "o o", + "ko n", + "o mo", + "ga e", + "nara nai", + "ta chi", + "z en", + "ch uu", + "kan gae", + "ta ra", + "to ki", + "ko ro", + "mujun teki", + "z e", + "na ga", + "ji n", + "shi ma", + "te n", + "i ki", + "i ku", + "no u", + "i masu", + "r ou", + "h on", + "ka e", + "t to", + "ko re", + "ta n", + "ki ta", + "i s", + "da tta", + "ji tsu", + "ma e", + "i e", + "me i", + "da n", + "h e", + "to ku", + "dou itsu", + "ri tsu", + "k yuu", + "h you", + "rare ta", + "kei sei", + "k kan", + "rare ru", + "m ou", + "do ko", + "r you", + "da ke", + "naka tta", + "so ko", + "ta be", + "e r", + "ha na", + "c o", + "fu ku", + "p a", + "so n", + "ya su", + "ch o", + "wata ku", + "ya ma", 
+ "z a", + "k yo", + "gen zai", + "b oku", + "a ta", + "j a", + "ka wa", + "ma sen", + "j uu", + "ro n", + "b o", + "na tte", + "wataku shi", + "yo tte", + "ma i", + "g ou", + "ha i", + "mo n", + "ba n", + "ji shin", + "c a", + "re te", + "n en", + "o ka", + "ka gaku", + "na tta", + "p o", + "ka ru", + "na ri", + "m en", + "ma ta", + "e i", + "ku ru", + "ga i", + "ka ri", + "sha kai", + "kou i", + "yo ri", + "se tsu", + "j o", + "re ru", + "to koro", + "ju tsu", + "i on", + "sa ku", + "tta i", + "c ha", + "nin gen", + "n u", + "c e", + "ta me", + "kan kyou", + "de n", + "o oku", + "i ma", + "wata shi", + "tsuku ru", + "su gi", + "b en", + "ji bun", + "shi tsu", + "ke ru", + "ki n", + "ki shi", + "shika shi", + "mo to", + "ma ri", + "i tte", + "de shita", + "n de", + "ari masu", + "te r", + "z ou", + "ko e", + "ze ttai", + "kkan teki", + "h en", + "re kishi", + "deki ru", + "tsu ka", + "l a", + "i tta", + "o i", + "ko butsu", + "mi ru", + "sh oku", + "shi masu", + "gi jutsu", + "g you", + "jou shiki", + "a tta", + "ho do", + "ko ko", + "tsuku rareta", + "z oku", + "hi tei", + "ko ku", + "rekishi teki", + "ke te", + "o ri", + "i mi", + "ka ko", + "naga ra", + "ka karu", + "shu tai", + "ha ji", + "ma n", + "ta ku", + "ra n", + "douitsu teki", + "z o", + "me te", + "re i", + "tsu u", + "sare te", + "gen jitsu", + "p e", + "s t", + "ba i", + "na wa", + "ji kan", + "wa ru", + "r t", + "a tsu", + "so ku", + "koui teki", + "a ra", + "u ma", + "a no", + "i de", + "ka ta", + "te tsu", + "ga wa", + "ke do", + "re ta", + "mi n", + "sa you", + "tte ru", + "to ri", + "p u", + "ki mi", + "b ou", + "mu ra", + "sare ru", + "ma chi", + "k ya", + "o sa", + "kon na", + "a ku", + "a l", + "sare ta", + "i pp", + "shi ku", + "u chi", + "hito tsu", + "ha tara", + "tachi ba", + "shi ro", + "ka tachi", + "to mo", + "e te", + "me ru", + "ni chi", + "da re", + "ka tta", + "e ru", + "su ki", + "a ge", + "oo ki", + "ma ru", + "mo ku", + "o ko", + "kangae rareru", + "o to", + "tan ni", + "ta da", + "tai teki", + "mo tte", + "ki nou", + "shi nai", + "k ki", + "u e", + "ta ri", + "l i", + "ra nai", + "k kou", + "mi rai", + "pp on", + "go to", + "hi n", + "hi tsu", + "te ru", + "mo chi", + "ka tsu", + "re n", + "n yuu", + "su i", + "zu ka", + "tsu ite", + "no mi", + "su gu", + "ku da", + "tetsu gaku", + "i ka", + "ron ri", + "o ki", + "ni ppon", + "p er", + "shi mashita", + "chi shiki", + "cho kkanteki", + "su ko", + "t ion", + "ku u", + "a na", + "a rou", + "ka tte", + "ku ri", + "i nai", + "hyou gen", + "i shiki", + "do ku", + "a tte", + "a tara", + "to n", + "wa ri", + "ka o", + "sei san", + "hana shi", + "s i", + "ka ke", + "na ji", + "su nawa", + "sunawa chi", + "u go", + "su u", + "ba ra", + "le v", + "hi ro", + "i wa", + "be tsu", + "yo i", + "se ru", + "shite ru", + "rare te", + "to shi", + "se ki", + "tai ritsu", + "wa kara", + "to kyo", + "k ka", + "k yoku", + "u n", + "i ro", + "mi te", + "sa ki", + "kan ji", + "mi ta", + "su be", + "r yoku", + "ma tta", + "kuda sai", + "omo i", + "ta no", + "ware ru", + "co m", + "hitsu you", + "ka shi", + "re nai", + "kan kei", + "a to", + "ga tte", + "o chi", + "mo tsu", + "in g", + "son zai", + "l l", + "o re", + "tai shite", + "a me", + "sei mei", + "ka no", + "gi ri", + "kangae ru", + "yu e", + "a sa", + "o naji", + "yo ru", + "ni ku", + "osa ka", + "suko shi", + "c k", + "ta ma", + "kano jo", + "ki te", + "mon dai", + "a mari", + "e ki", + "ko jin", + "ha ya", + "i t", + "de te", + "atara shii", + "a wa", + "ga kkou", + "tsu zu", + "shu kan", + "i mashita", + "mi na", + 
"ata e", + "da rou", + "hatara ku", + "ga ta", + "da chi", + "ma tsu", + "ari masen", + "sei butsu", + "mi tsu", + "he ya", + "yasu i", + "d i", + "de ni", + "no ko", + "ha ha", + "do mo", + "ka mi", + "su deni", + "na o", + "ra ku", + "i ke", + "a ki", + "me ta", + "l o", + "ko domo", + "so shite", + "ga me", + "ba kari", + "to te", + "ha tsu", + "mi se", + "moku teki", + "da kara", + "s z", + "e l", + "g y", + "e n", + "t t", + "e m", + "a n", + "a k", + "e r", + "a z", + "a l", + "e t", + "o l", + "e g", + "e k", + "m i", + "o n", + "é s", + "c s", + "a t", + "á r", + "h o", + "e z", + "á l", + "i s", + "á n", + "o r", + "a r", + "e gy", + "e s", + "é r", + "á t", + "o tt", + "e tt", + "m eg", + "t a", + "o k", + "o s", + "ho gy", + "n em", + "é g", + "n y", + "k i", + "é l", + "h a", + "á s", + "ü l", + "i n", + "mi n", + "n a", + "e d", + "o m", + "i k", + "k ö", + "m a", + "n i", + "v a", + "v ol", + "é t", + "b b", + "f el", + "i g", + "l e", + "r a", + "é n", + "t e", + "d e", + "a d", + "ó l", + "b e", + "on d", + "j a", + "r e", + "u l", + "b en", + "n ek", + "u t", + "vol t", + "b an", + "ö r", + "o g", + "a p", + "o d", + "á g", + "n k", + "é k", + "v al", + "k or", + "a m", + "i l", + "í t", + "á k", + "b a", + "u d", + "sz er", + "min d", + "o z", + "é p", + "el l", + "ér t", + "m ond", + "i t", + "sz t", + "n ak", + "a mi", + "n e", + "ő l", + "cs ak", + "n é", + "ma g", + "ol y", + "m er", + "ál l", + "án y", + "ö n", + "ö l", + "min t", + "m ár", + "ö tt", + "na gy", + "é sz", + "az t", + "el ő", + "t ud", + "o t", + "é ny", + "á z", + "m ég", + "kö z", + "el y", + "s ég", + "en t", + "s em", + "ta m", + "h et", + "h al", + "f i", + "a s", + "v an", + "ho z", + "v e", + "u k", + "k ez", + "á m", + "v el", + "b er", + "a j", + "u nk", + "i z", + "va gy", + "m os", + "sz em", + "em ber", + "f og", + "mer t", + "ü k", + "l en", + "ö s", + "e j", + "t al", + "h at", + "t ak", + "h i", + "m ás", + "s ág", + "ett e", + "l eg", + "ü nk", + "h át", + "sz a", + "on y", + "ez t", + "mind en", + "en d", + "ül t", + "h an", + "j ó", + "k is", + "á j", + "in t", + "ú gy", + "i d", + "mos t", + "ar t", + "í r", + "k er", + "i tt", + "a tt", + "el t", + "mond ta", + "k ell", + "l á", + "ak i", + "ál t", + "ér d", + "t ö", + "l an", + "v ár", + "h ol", + "t el", + "l át", + "ő k", + "v et", + "s e", + "ut án", + "k ét", + "na p", + "í v", + "ál y", + "v ég", + "ö k", + "i r", + "d ul", + "v is", + "né z", + "t er", + "á ban", + "k ül", + "ak kor", + "k ap", + "sz él", + "y en", + "ú j", + "i m", + "oly an", + "es en", + "k ed", + "h ely", + "t ör", + "b ól", + "el m", + "r á", + "ár a", + "r ó", + "l ó", + "vol na", + "t an", + "le het", + "e bb", + "t en", + "t ek", + "s ok", + "k al", + "f or", + "u g", + "ol t", + "k a", + "ek et", + "b or", + "f ej", + "g ond", + "a g", + "ak ar", + "f él", + "ú l", + "b el", + "ott a", + "mi t", + "val ami", + "j el", + "é d", + "ar c", + "u r", + "hal l", + "t i", + "f öl", + "á ba", + "ol g", + "ki r", + "ol d", + "m ar", + "k érd", + "j ár", + "ú r", + "sz e", + "z s", + "él et", + "j át", + "o v", + "u s", + "é z", + "v il", + "v er", + "ő r", + "á d", + "ö g", + "le sz", + "on t", + "b iz", + "k oz", + "á bb", + "kir ály", + "es t", + "a b", + "en g", + "ig az", + "b ar", + "ha j", + "d i", + "o b", + "k od", + "r ól", + "v ez", + "tö bb", + "sz ó", + "é ben", + "ö t", + "ny i", + "t á", + "sz ól", + "gond ol", + "eg ész", + "í gy", + "ő s", + "o bb", + "os an", + "b ől", + "a bb", + "c i", + "ő t", + "n ál", + "k ép", + "azt án", + "v i", + "t 
art", + "be szél", + "m en", + "elő tt", + "a szt", + "ma j", + "kö r", + "han g", + "í z", + "in cs", + "a i", + "é v", + "ó d", + "ó k", + "hoz z", + "t em", + "ok at", + "an y", + "nagy on", + "h áz", + "p er", + "p ed", + "ez te", + "et len", + "nek i", + "maj d", + "sz ony", + "án ak", + "fel é", + "egy szer", + "j e", + "ad t", + "gy er", + "ami kor", + "f oly", + "sz ak", + "ő d", + "h ú", + "á sz", + "am ely", + "h ar", + "ér e", + "il yen", + "od a", + "j ák", + "t ár", + "á val", + "l ak", + "t ó", + "m ent", + "gy an", + "él y", + "ú t", + "v ar", + "kez d", + "m ell", + "mi kor", + "h ez", + "val ó", + "k o", + "m es", + "szer et", + "r end", + "l et", + "vis sza", + "ig en", + "f ő", + "va s", + "as szony", + "r ől", + "ped ig", + "p i", + "sz ép", + "t ák", + "ö v", + "an i", + "vil ág", + "p en", + "mag a", + "t et", + "sz ik", + "é j", + "én t", + "j ött", + "s an", + "sz í", + "i de", + "g at", + "ett em", + "ul t", + "h ány", + "ás t", + "a hol", + "ők et", + "h ár", + "k el", + "n ő", + "cs i", + "tal ál", + "el te", + "lá tt", + "tör t", + "ha gy", + "e sz", + "s en", + "n él", + "p ar", + "v ál", + "k ut", + "l ány", + "ami t", + "s ő", + "ell en", + "mag át", + "in k", + "u gyan", + "kül ön", + "a sz", + "mind ig", + "l ép", + "tal án", + "u n", + "sz or", + "k e", + "il lan", + "n incs", + "z et", + "vagy ok", + "tel en", + "is mer", + "s or", + "is ten", + "ít ott", + "j obb", + "v es", + "dul t", + "j uk", + "sz en", + "r o", + "ö m", + "l ett", + "k ar", + "egy ik", + "b ár", + "sz i", + "sz ív", + "az on", + "e szt", + "föl d", + "kut y", + "p illan", + "f ér", + "k om", + "t ől", + "t ű", + "é be", + "t ött", + "bar át", + "í g", + "a hogy", + "e h", + "e p", + "s o", + "v en", + "jel ent", + "t at", + "sz eg", + "mint ha", + "f al", + "egy en", + "mi l", + "sza b", + "r i", + "é m", + "biz ony", + "j on", + "ör eg", + "d olg", + "cs ap", + "ti szt", + "áll t", + "an cs", + "id ő", + "k at", + "ü gy", + "mi ért", + "ó t", + "ü r", + "cs in", + "h az", + "b et", + "én ek", + "v ér", + "j ól", + "al att", + "m ely", + "l o", + "sem mi", + "ny ug", + "v ág", + "kö vet", + "ös sze", + "ma d", + "l i", + "a cs", + "fi ú", + "kö n", + "más ik", + "j ön", + "sz ám", + "g er", + "s ó", + "r ész", + "k ér", + "z el", + "é vel", + "e o", + "e u", + "a n", + "eu l", + "eu n", + "eo n", + "a e", + "d a", + "a l", + "s s", + "i n", + "i l", + "a g", + "an g", + "y eon", + "y eo", + "d o", + "c h", + "n g", + "j i", + "h an", + "g a", + "g o", + "u i", + "h ae", + "a m", + "u l", + "u n", + "g eo", + "s i", + "n eun", + "ss da", + "s eo", + "eon g", + "y o", + "i da", + "t t", + "k k", + "j eo", + "d eul", + "w a", + "eu m", + "g e", + "o n", + "o g", + "s al", + "m an", + "yeon g", + "geo s", + "h ag", + "an eun", + "j a", + "g i", + "s u", + "i ss", + "o l", + "d ae", + "eo b", + "h a", + "j u", + "eo l", + "g eu", + "j eong", + "s ae", + "do e", + "g eul", + "s eu", + "s in", + "eul o", + "b n", + "s ang", + "bn ida", + "h al", + "b o", + "han eun", + "m al", + "i m", + "m o", + "b u", + "jeo g", + "sae ng", + "in eun", + "an h", + "m a", + "sal am", + "j o", + "s a", + "eo m", + "n ae", + "w i", + "l o", + "g wa", + "yeo l", + "n a", + "e seo", + "y e", + "m yeon", + "tt ae", + "h w", + "j e", + "eob s", + "j ang", + "g u", + "g w", + "il eul", + "yeo g", + "j eon", + "si g", + "j ag", + "j in", + "y u", + "o e", + "s e", + "hag o", + "d eun", + "y a", + "m un", + "s eong", + "g ag", + "h am", + "d ang", + "b a", + "l eul", + "s il", + "do ng", + "kk a", + "b al", + "da 
l", + "han da", + "eo ssda", + "ae g", + "l i", + "ha ji", + "s eon", + "o ng", + "hae ssda", + "d e", + "i ssda", + "e ge", + "b un", + "m ul", + "ju ng", + "ji g", + "m u", + "iss neun", + "b i", + "g eun", + "seu bnida", + "w on", + "p p", + "d aneun", + "eo h", + "d eo", + "ga m", + "j al", + "hae ng", + "ag o", + "y ang", + "b ul", + "b ang", + "u m", + "s o", + "h i", + "j ae", + "si m", + "saeng gag", + "hag e", + "s og", + "eo ss", + "d an", + "ja sin", + "j il", + "eo g", + "g yeong", + "doe n", + "go ng", + "m i", + "ch i", + "d eu", + "d eon", + "hae ss", + "d u", + "n am", + "eun g", + "jo h", + "n al", + "m yeong", + "w o", + "eon a", + "i go", + "g yeol", + "y ag", + "gw an", + "ul i", + "yo ng", + "n o", + "l yeo", + "j og", + "eoh ge", + "ga t", + "b og", + "mo s", + "t ong", + "ch a", + "man h", + "jeo l", + "geo l", + "h oe", + "ag a", + "n aneun", + "g an", + "un eun", + "ch eol", + "ch e", + "do l", + "b on", + "b an", + "ba d", + "ch u", + "ham yeon", + "yeo ssda", + "i bnida", + "g ye", + "eo s", + "hw al", + "salam deul", + "ji man", + "dang sin", + "ji b", + "ttae mun", + "m ae", + "i b", + "e neun", + "eu g", + "jeo m", + "geul eon", + "h wa", + "a ssda", + "b eob", + "bu t", + "b ae", + "yeo ss", + "ch in", + "ch aeg", + "g eon", + "g ae", + "nae ga", + "i ga", + "m og", + "sig an", + "g il", + "h yeon", + "l yeog", + "gu g", + "p yeon", + "s an", + "w ae", + "j ul", + "s eul", + "deun g", + "haji man", + "eum yeon", + "p il", + "m ol", + "n eu", + "a ss", + "n yeon", + "t ae", + "h u", + "p yo", + "s ul", + "g ang", + "j ineun", + "b eon", + "ha da", + "seo l", + "si p", + "dal eun", + "a p", + "sal m", + "g yo", + "ch eon", + "hag i", + "in a", + "cheol eom", + "g al", + "il a", + "kka ji", + "anh neun", + "ha bnida", + "tt eon", + "n u", + "hae seo", + "doen da", + "s ol", + "tt al", + "l a", + "il o", + "seu b", + "b yeon", + "m yeo", + "b eol", + "s on", + "n un", + "j un", + "j am", + "j eung", + "tt o", + "e n", + "mo m", + "h o", + "ch im", + "hw ang", + "eun eun", + "jo ng", + "bo da", + "n ol", + "n eom", + "but eo", + "jig eum", + "eobs da", + "dae lo", + "i g", + "y ul", + "p yeong", + "seon eun", + "sal ang", + "seu t", + "h im", + "n an", + "h eom", + "h yang", + "p i", + "gw ang", + "eobs neun", + "hw ag", + "ge ss", + "jag i", + "il eon", + "wi hae", + "dae han", + "ga ji", + "m eog", + "j yeo", + "cha j", + "b yeong", + "eo d", + "g yeo", + "do n", + "eo ji", + "g ul", + "mo deun", + "j on", + "in saeng", + "geul ae", + "h ang", + "sa sil", + "si b", + "ch al", + "il ago", + "doe l", + "g eum", + "doe neun", + "b ol", + "ga jang", + "geul igo", + "e l", + "h yeong", + "haeng bog", + "ch ul", + "h on", + "ch ae", + "s am", + "m ang", + "in da", + "da m", + "w ol", + "ch oe", + "d ul", + "si jag", + "ch eong", + "il aneun", + "ul ineun", + "ae n", + "kk e", + "mun je", + "a do", + "t eu", + "g un", + "geun eun", + "b ge", + "ch eo", + "b aeg", + "ju g", + "t a", + "sang dae", + "geu geos", + "do g", + "eu s", + "deu s", + "ja b", + "h yeo", + "tt eohge", + "u g", + "ma j", + "ch il", + "s wi", + "j ileul", + "ch ang", + "g aneun", + "m ag", + "i ji", + "da go", + "m in", + "yo han", + "t eug", + "pp un", + "al eul", + "haeng dong", + "p o", + "m il", + "ch am", + "se sang", + "e do", + "p an", + "man deul", + "am yeon", + "a b", + "kk ae", + "b ag", + "i deul", + "p um", + "m eol", + "s un", + "n eul", + "ham kke", + "chu ng", + "da b", + "yu g", + "s ag", + "gwang ye", + "il eohge", + "bal o", + "neun de", + "ham yeo", + "go s", + "geul eoh", + "an 
ila", + "bang beob", + "da si", + "b yeol", + "g yeon", + "gam jeong", + "on eul", + "j aneun", + "yeo m", + "l ago", + "i gi", + "hw an", + "t eul", + "eo seo", + "si k", + "ch o", + "jag a", + "geul eom", + "geul eona", + "jeong do", + "g yeog", + "geul eohge", + "geu deul", + "eu t", + "im yeon", + "j jae", + "k eun", + "i sang", + "mal haessda", + "eu ge", + "no p", + "in gan", + "bo myeon", + "t aeg", + "seu s", + "d wi", + "s aneun", + "w an", + "anh go", + "t an", + "nu gu", + "su ng", + "da myeon", + "a deul", + "p eul", + "ttal a", + "d i", + "geos do", + "a ji", + "m eon", + "eum yeo", + "dol og", + "neun g", + "mo du", + "क े", + "ह ै", + "े ं", + "् र", + "ा र", + "न े", + "य ा", + "म ें", + "स े", + "क ी", + "क ा", + "ो ं", + "त ा", + "क र", + "स ्", + "क ि", + "क ो", + "र ्", + "न ा", + "क ्", + "ह ी", + "औ र", + "प र", + "त े", + "ह ो", + "प ्र", + "ा न", + "् य", + "ल ा", + "व ा", + "ल े", + "स ा", + "है ं", + "ल ि", + "ज ा", + "ह ा", + "भ ी", + "व ि", + "इ स", + "त ी", + "न ्", + "र ा", + "म ा", + "द े", + "द ि", + "ब ा", + "त ि", + "थ ा", + "न ि", + "क ार", + "ए क", + "ही ं", + "ह ु", + "ं ग", + "ै ं", + "न ी", + "स ी", + "अ प", + "त ्", + "न हीं", + "र ी", + "म े", + "म ु", + "ि त", + "त ो", + "प ा", + "ल ी", + "लि ए", + "ग ा", + "ल ्", + "र ह", + "र े", + "क् ष", + "म ैं", + "स म", + "उ स", + "ज ि", + "त ्र", + "म ि", + "च ा", + "ो ग", + "स ं", + "द ्", + "स ि", + "आ प", + "त ु", + "द ा", + "क ु", + "य ों", + "व े", + "ज ी", + "् या", + "उ न", + "ि क", + "य े", + "भ ा", + "् ट", + "ह म", + "स् ट", + "श ा", + "ड ़", + "ं द", + "ख ा", + "म ्", + "श ्", + "य ह", + "स क", + "प ू", + "कि या", + "अप ने", + "र ू", + "स ु", + "म ी", + "ह ि", + "ज ो", + "थ े", + "र ि", + "द ी", + "थ ी", + "ग ी", + "ल ोग", + "ग या", + "त र", + "न् ह", + "च ्", + "व ार", + "ब ी", + "प ्", + "द ो", + "ट ी", + "श ि", + "कर ने", + "ग े", + "ै से", + "इ न", + "ं ड", + "सा थ", + "प ु", + "ब े", + "ब ार", + "व ी", + "अ न", + "ह र", + "उ न्ह", + "हो ता", + "ज ब", + "कु छ", + "म ान", + "क ्र", + "ब ि", + "प ह", + "फ ि", + "स र", + "ार ी", + "र ो", + "द ू", + "क हा", + "त क", + "श न", + "ब ्", + "स् थ", + "व ह", + "बा द", + "ओ ं", + "ग ु", + "ज ्", + "्र े", + "ग र", + "रह े", + "व र्", + "ह ू", + "ार ्", + "प ी", + "ब हु", + "मु झ", + "्र ा", + "दि या", + "स ब", + "कर ते", + "अप नी", + "बहु त", + "क ह", + "ट े", + "हु ए", + "कि सी", + "र हा", + "ष ्ट", + "ज ़", + "ब ना", + "स ो", + "ड ि", + "को ई", + "व ्य", + "बा त", + "र ु", + "व ो", + "मुझ े", + "द् ध", + "च ार", + "मे रे", + "व र", + "्र ी", + "जा ता", + "न ों", + "प्र ा", + "दे ख", + "ट ा", + "क् या", + "अ ध", + "ल ग", + "ल ो", + "प ि", + "य ु", + "च े", + "जि स", + "ं त", + "ान ी", + "प ै", + "ज न", + "ार े", + "च ी", + "मि ल", + "द ु", + "दे श", + "च् छ", + "ष ्", + "स ू", + "ख े", + "च ु", + "ि या", + "ल गा", + "ब ु", + "उन के", + "ज् ञ", + "क्ष ा", + "त रह", + "्या दा", + "वा ले", + "पू र्", + "मैं ने", + "का म", + "रू प", + "हो ती", + "उ प", + "ज ान", + "प्र कार", + "भ ार", + "म न", + "हु आ", + "ट र", + "हू ँ", + "पर ि", + "पा स", + "अन ु", + "रा ज", + "लोग ों", + "अ ब", + "सम झ", + "ड ी", + "म ौ", + "श ु", + "च ि", + "प े", + "क ृ", + "सक ते", + "म ह", + "य ोग", + "द र्", + "उ से", + "ं ध", + "ड ा", + "जा ए", + "ब ो", + "ू ल", + "म ो", + "ों ने", + "ं स", + "तु म", + "पह ले", + "ब ता", + "त था", + "य ो", + "ग ई", + "उ त्", + "सक ता", + "क म", + "ज ्यादा", + "र ख", + "सम य", + "ार ा", + "अ गर", + "स् त", + "च ल", + "फि र", + "वार ा", + "कर ना", + "श ी", + "ग ए", + "ब न", + "ौ र", + "हो ने", + "चा ह", + "ख ु", + "हा ँ", + "उन्ह ें", + "उन्ह 
ोंने", + "छ ो", + "म् ह", + "प्र ति", + "नि क", + "व न", + "्य ू", + "र ही", + "तु म्ह", + "ज ैसे", + "ि यों", + "क् यों", + "ल ों", + "फ ़", + "ं त्र", + "हो ते", + "क् ति", + "त ्य", + "कर ्", + "क ई", + "व ं", + "कि न", + "प ो", + "कार ण", + "ड़ ी", + "भ ि", + "इस के", + "ब र", + "उस के", + "द् वारा", + "श े", + "क ॉ", + "दि न", + "न् न", + "ड़ ा", + "स् व", + "नि र्", + "मु ख", + "लि या", + "ट ि", + "ज्ञ ान", + "क् त", + "द ्र", + "ग ्", + "क् स", + "म ै", + "ग ो", + "ज े", + "ट ्र", + "म ार", + "त् व", + "ध ार", + "भा व", + "कर ता", + "ख ि", + "क ं", + "चा हि", + "य र", + "प् त", + "क ों", + "ं च", + "ज ु", + "म त", + "अ च्छ", + "हु ई", + "क भी", + "ले किन", + "भ ू", + "अप ना", + "दू स", + "चाहि ए", + "य ू", + "घ र", + "सब से", + "मे री", + "ना म", + "ढ ़", + "ं ट", + "ें गे", + "ब ै", + "फ ा", + "ए वं", + "य ी", + "ग ्र", + "क्ष े", + "आ ज", + "आप को", + "भा ग", + "ठ ा", + "क ै", + "भार त", + "उन की", + "प हु", + "स भी", + "ध ा", + "ण ा", + "स ान", + "हो गा", + "त ब", + "स ंग", + "प र्", + "अ व", + "त ना", + "ग ि", + "य न", + "स् था", + "च ित", + "ट ्", + "छ ा", + "जा ने", + "क्षे त्र", + "वा ली", + "पूर् ण", + "स मा", + "कार ी", + "n h", + "k h", + "ô ng", + "ờ i", + "ấ y", + "ô i", + "ộ t", + "m ột", + "c á", + "kh ông", + "ư ời", + "t ôi", + "m à", + "ạ i", + "cá i", + "a nh", + "th ì", + "ng ười", + "i ế", + "n ó", + "nh ư", + "đ ã", + "à o", + "th ế", + "l ại", + "đ ư", + "ồ i", + "ớ i", + "c ũ", + "đ i", + "ì nh", + "cũ ng", + "ả i", + "ủ a", + "ng h", + "c ủa", + "i ệ", + "ế n", + "ằ ng", + "ợ c", + "đư ợc", + "đ ến", + "ấ t", + "ph ải", + "ữ ng", + "nh ững", + "r ồi", + "à ng", + "nó i", + "ú c", + "u ố", + "ữ a", + "ư ớ", + "à y", + "c ả", + "b à", + "â y", + "g ì", + "v à", + "ầ n", + "iế t", + "là m", + "đ ể", + "ò n", + "v ới", + "ậ t", + "m ình", + "v ào", + "â u", + "đ á", + "ướ c", + "nh à", + "n ữa", + "c òn", + "á c", + "n ào", + "ch ỉ", + "th ấy", + "ắ t", + "h ai", + "ế t", + "b iết", + "v ì", + "ư a", + "i ề", + "ắ m", + "ồ ng", + "à i", + "ê u", + "n ày", + "m ặ", + "s ự", + "tr ong", + "l ên", + "ă n", + "v iệ", + "kh i", + "á i", + "ầ u", + "ơ i", + "v ề", + "ạ n", + "ơ ng", + "ò ng", + "l úc", + "đ ứ", + "i ên", + "c ứ", + "ỏ i", + "th ôi", + "ngh ĩ", + "r ằng", + "l ắm", + "ă m", + "ậ n", + "như ng", + "ú ng", + "m ới", + "mặ t", + "gi ờ", + "ê m", + "i nh", + "việ c", + "ố i", + "á o", + "í nh", + "s ao", + "ậ y", + "ộ i", + "à n", + "t ay", + "ẫ n", + "ờ ng", + "t ư", + "h ọ", + "uố n", + "cá ch", + "a ng", + "ẳ ng", + "th ư", + "ù ng", + "l ấy", + "đ ó", + "ch ứ", + "u ng", + "á ch", + "h ỏi", + "c ụ", + "m uốn", + "ch àng", + "n ên", + "th ật", + "đ ấy", + "ệ n", + "ừ ng", + "nh au", + "t iế", + "ơ n", + "v ợ", + "đ ây", + "ế u", + "s au", + "iề n", + "s ố", + "b ả", + "đ ầu", + "â m", + "b ị", + "ở ng", + "c ô", + "ổ i", + "h ay", + "ụ c", + "n ếu", + "v ậy", + "qu an", + "ã i", + "ng ồi", + "ch ưa", + "v ẫn", + "đ ời", + "ph úc", + "uy ện", + "ử a", + "tr ước", + "ì n", + "ắ c", + "ng ay", + "uố ng", + "ứ c", + "i ể", + "ch ẳng", + "ch ồng", + "ch úng", + "ầ m", + "ạ c", + "ộ c", + "ch ị", + "ph ú", + "l òng", + "ậ u", + "à nh", + "đứ ng", + "ô n", + "ỗ i", + "th ể", + "s ẽ", + "ừ a", + "t ừ", + "r ất", + "ó c", + "t ình", + "y êu", + "ki a", + "b ố", + "ự c", + "l ư", + "i êm", + "t uy", + "đ âu", + "x uống", + "ầ y", + "ị ch", + "iề u", + "n ay", + "m ẹ", + "bả o", + "iể u", + "b ao", + "t ự", + "t iền", + "nh ìn", + "ô m", + "mà y", + "b ác", + "í ch", + "x in", + "u ộc", + "ọ i", + "tr ên", + "á ng", + "ọ n", + "đá ng", + "ề u", + "n ước", + "b 
ằng", + "b ạn", + "h ạ", + "x e", + "cá c", + "ề n", + "nh iên", + "c ùng", + "ư ơng", + "l ẽ", + "ch ính", + "m ấy", + "ă ng", + "n ăm", + "v ừa", + "c ậu", + "m ắt", + "l ong", + "ọ ng", + "tư ởng", + "qu á", + "iệ n", + "kh ác", + "s ợ", + "nh ân", + "ọ c", + "đ ộ", + "n àng", + "ng o", + "b ắt", + "ng ày", + "ạ y", + "ch ết", + "h ôm", + "h ơn", + "th ằng", + "tiế ng", + "xu ân", + "l iêm", + "h iểu", + "c ười", + "ẩ m", + "b ây", + "đ àn", + "r õ", + "l ời", + "x em", + "v ô", + "ch uyện", + "h ằng", + "l ần", + "ắ ng", + "m ất", + "nghĩ a", + "nh ất", + "c âu", + "ch ỗ", + "b ên", + "b áo", + "á nh", + "kh ổ", + "đ ủ", + "ướ ng", + "ố c", + "c ông", + "iế u", + "ấ m", + "h àng", + "đá p", + "ng ài", + "ố t", + "iế n", + "tr ông", + "đi ều", + "m ãi", + "ặ c", + "ạ o", + "th u", + "đ ồng", + "ch ủ", + "nh iều", + "g ái", + "x ong", + "ồ n", + "c uộc", + "á u", + "kh o", + "họ c", + "ỳ nh", + "đá nh", + "b ạc", + "th ứ", + "ngh e", + "b ọn", + "m ịch", + "ớ n", + "k ẻ", + "qu y", + "ậ p", + "ố n", + "th ành", + "c ơ", + "ấ n", + "qu ay", + "c ửa", + "qu ỳnh", + "m inh", + "đ ị", + "th ầy", + "t âm", + "v ẻ", + "è n", + "v ăn", + "th eo", + "h ết", + "i êu", + "đư ờng", + "đ o", + "ấ u", + "b ẩm", + "đư ơng", + "l ạ", + "iệ u", + "g ọi", + "t ân", + "ắ n", + "ch ơi", + "uy ên", + "c ổ", + "d ám", + "đ ồ", + "ẩ n", + "ph òng", + "ch ạy", + "d ân", + "s ướng", + "qu ả", + "ú i", + "ngh ị", + "c ần", + "b ất", + "bà n", + "ợ i", + "h ư", + "ẳ n", + "kh ó", + "mặ c", + "th ân", + "uố c", + "ố ng", + "ó ng", + "ặ p", + "h ẳn", + "t ao", + "ì m", + "ngo ài", + "ho ặc", + "h ội", + "ắ p", + "ch ân", + "b ỏ", + "m ười", + "x ưa", + "l ã", + "l âu", + "đứ a", + "ư ờng", + "ợ ng", + "kh ách", + "t ây", + "ấ p", + "chị u", + "qu ần", + "ợ t", + "th ưa", + "thư ơng", + "m ai", + "t óc", + "t ìm", + "tr ăm", + "t ức", + "ặ ng", + "nh ớ", + "â ng", + "é o", + "ạ t", + "n ỗi", + "ằ m", + "nh ận", + "ạ nh", + "cả nh", + "s inh", + "t ử", + "d ài", + "h ơi", + "ả n", + "ư ới", + "ả ng", + "việ t", + "v ị", + "h uyện", + "qu ân", + "thư ờng", + "l ớn", + "đ au", + "ặ t", + "th uốc", + "cả m", + "th ời", + "ê ng", + "đư a", + "t ội", + "đứ c", + "m ư", + "độ ng", + "là ng", + "k êu", + "uy ền", + "m ọi", + "đ ối", + "t ài", + "ỉ nh", + "í n", + "c ử", + "bố n", + "v ân", + "tuy ết", + "cơ m", + "b ộ", + "h ình", + "qu ý", + "i êng", + "và i", + "v ội", + "tr ẻ", + "lã o", + "n ằm", + "đ ừng", + "t ại", + "vu i", + "đị nh", + "đ êm", + "ch ắc", + "u a", + "ẹ p", + "ph án", + "gi ữa", + "đ em", + "th êm", + "b ình", + "th o", + "số ng", + "c ao", + "b ước", + "ụ ng", + "x a", + "h ại", + "ệ t", + "h ách", + "t ai", + "s ung", + "ớ m", + "ẩ y", + "ph á", + "c ố", + "ã y", + "ả y", + "ẹ n", + "h ồi", + "iế c", + "ng ủ", + "th iên", + "s áng", + "nh iêu", + "d ung", + "bà i", + "tiế p", + "ệ nh", + "m ay", + "n ổi", + "đ ình", + "m ời", + "ạ m", + "t ấn", + "iế m", + "ó a", + "th iếu", + "k ể", + "ả nh", + "d ễ", + "t ất", + "kh ỏi", + "quy ền", + "v âng", + "th ực", + "gi á", + "h ộ", + "ph u", + "m ỗi", + "c ầm", + "ị p", + "nh ỏ", + "kh iến", + "tr ọng", + "gi ữ", + "gi ấy", + "g ặp", + "s ở", + "ẫ u", + "h ồ", + "đ ỏ", + "ặ n", + "c àng", + "đ ại", + "th ị", + "tr ở", + "iệ t", + "v ạn", + "n ữ", + "tr í", + "á y", + "th áng", + "ch ú", + "h à", + "mư ơi", + "d ưới", + "ng ờ", + "ho a", + "đ ọc", + "đ ẹp", + "ch ung", + "m ở", + "t ối", + "đ ẻ", + "ch ờ", + "k ỳ", + "gi ận", + "ph ần", + "hạ nh", + "đ ất", + "ph ụ", + "ò a", + "n ạn", + "bu ồn", + "tu ổi", + "s ức", + "th ần", + "qu ên", + "gh ế", + "th 
ức", + "tr ời", + "đ ôi", + "phú t", + "t ính", + "nh é", + "ở i", + "d anh", + "t ổ", + "th ơ", + "t ên", + "ộ ng", + "ch ả", + "b ữa", + "x ã", + "v iên", + "th a", + "h óa", + "ỗ ng", + "hư ởng", + "m ũ", + "ỏ ng", + "c ây", + "th ay", + "kh ẽ", + "l ỗi", + "è o", + "ch ục", + "đ úng", + "kh ốn", + "ph ố", + "t ờ", + "ch áu", + "khi nh", + "k ính", + "r iêng", + "ạ ch", + "ng u", + "t iết", + "y ên", + "d ùng", + "v ọng", + "ử i", + "n ửa", + "iệ m", + "lư ng", + "b ỗng", + "h iệu", + "ch iếc", + "s ắp", + "t ù", + "gi an", + "đá m", + "già u", + "d âm", + "l ặng", + "g ần", + "gi ọng", + "l ộ", + "l ính", + "s áu", + "ù a", + "tr ường", + "c úi", + "t inh", + "lư u", + "d ạ", + "ph ó", + "c oi", + "th ở", + "ngh ìn", + "bu ổi", + "h uyền", + "x ấu", + "ch ợt", + "tr ò", + "đ ều", + "th anh", + "b án", + "đ ầy", + "to àn", + "ờ n", + "ph ục", + "ngh i", + "m ua", + "h ãy", + "gi ả", + "c ãi", + "h ôn", + "v ai", + "ch ữ", + "tr ả", + "tr ị", + "h ành", + "s ư", + "ồ m", + "ng ạc", + "g ớm", + "th ử", + "d ậy", + "t ốt", + "ú p", + "ph ép", + "qu ê", + "s ách", + "ch ừng", + "ch iều", + "ch út", + "iệ ng", + "ứ a", + "phá p", + "k éo", + "v iết", + "l ôi", + "đ u", + "ế p", + "m ỹ", + "ch ương", + "ch ánh", + "cứ u", + "iế p", + "ò i", + "ó i", + "ù i", + "gi ai", + "ngh ề", + "h iện", + "s ân", + "ể m", + "th ất", + "v ật", + "ả o", + "ch ữa", + "b èn", + "v ả", + "ư ng", + "ng uy", + "qu anh", + "ng ôn", + "đ en", + "đ èn", + "ch ào", + "lu ôn", + "c ạnh", + "k inh", + "đ ông", + "ph ận", + "v àng", + "ng uyên", + "th ủ", + "ch í", + "uố i", + "m ạnh", + "m iệng", + "ph ương", + "d ẫu", + "g ật", + "x o", + "đ ạo", + "ả m", + "iệ p", + "đị a", + "đo ạn", + "kh óc", + "v ư", + "đ ào", + "ợ n", + "hạ ng", + "đ ê", + "d ại", + "ng ây", + "tr ắng", + "b ích", + "cá nh", + "qu ang", + "ru ột", + "h ổ", + "đ ổi", + "gi ời", + "h ở", + "m ật", + "t ế", + "t ỉnh", + "ụ t", + "quy ết", + "c ắt", + "i ễ", + "đ ốc", + "ĩ nh", + "ch án", + "tr ái", + "c ầu", + "h út", + "m ừng", + "ợ p", + "h ối", + "đ ề", + "th ẳng", + "g ác", + "ứ ng", + "nh ỉ", + "th ích", + "k iếm", + "ph át", + "gi ường", + "m ồm", + "gh en", + "b át", + "l ễ", + "tiế n", + "h ận", + "b ụng", + "ã n", + "c ờ", + "lu ận", + "ẩ u", + "h ợp", + "ọ t", + "gi ơ", + "lư ơng", + "ng ư", + "gh ê", + "x anh", + "d ù", + "gi ới", + "mũ i", + "g ắt", + "k ết", + "l ối", + "th ù", + "gi úp" + ] + } +} \ No newline at end of file diff --git a/viXTTS/output/1120155531_chao_ban_toi_la_mot_tro_ly_ao.wav b/viXTTS/output/1120155531_chao_ban_toi_la_mot_tro_ly_ao.wav new file mode 100644 index 0000000000000000000000000000000000000000..4e4e999ca8c46b943530279e3c6f7f33af4fb3ef Binary files /dev/null and b/viXTTS/output/1120155531_chao_ban_toi_la_mot_tro_ly_ao.wav differ diff --git a/viXTTS/output/1120160032_chao_ban_toi_la_mot_tro_ly_ao.wav b/viXTTS/output/1120160032_chao_ban_toi_la_mot_tro_ly_ao.wav new file mode 100644 index 0000000000000000000000000000000000000000..73ec29ac4b083997c6db382831d5387cc2ed4801 Binary files /dev/null and b/viXTTS/output/1120160032_chao_ban_toi_la_mot_tro_ly_ao.wav differ diff --git a/viXTTS/requirements.txt b/viXTTS/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..2d446089cc9c23a222dc9d768ab08a8ce4c70e38 --- /dev/null +++ b/viXTTS/requirements.txt @@ -0,0 +1,108 @@ +# requirements.txt + +absl-py==2.1.0 +annotated-types==0.7.0 +audioread==3.0.1 +blis==1.0.1 +catalogue==2.0.10 +certifi==2024.8.30 +cffi==1.17.1 +charset-normalizer==3.4.0 +click==8.1.7 +cloudpathlib==0.20.0 
+colorama==0.4.6 +confection==0.1.5 +contourpy==1.3.1 +coqpit==0.0.17 +cycler==0.12.1 +cymem==2.0.8 +decorator==5.1.1 +docopt==0.6.2 +einops==0.8.0 +filelock==3.16.1 +fonttools==4.55.0 +fsspec==2024.10.0 +grpcio==1.68.0 +hangul-romanize==0.1.0 +huggingface-hub==0.26.2 +idna==3.10 +Jinja2==3.1.4 +joblib==1.4.2 +kiwisolver==1.4.7 +langcodes==3.5.0 +language_data==1.3.0 +lazy_loader==0.4 +librosa==0.10.2.post1 +llvmlite==0.43.0 +marisa-trie==1.2.1 +Markdown==3.7 +markdown-it-py==3.0.0 +MarkupSafe==3.0.2 +matplotlib==3.9.2 +mdurl==0.1.2 +mpmath==1.3.0 +msgpack==1.1.0 +murmurhash==1.0.10 +mutagen==1.47.0 +networkx==3.4.2 +nltk==3.9.1 +num2words==0.5.13 +numba==0.60.0 +numpy==2.0.2 +packaging==24.2 +pandas==2.2.3 +pillow==10.2.0 +platformdirs==4.3.6 +pooch==1.8.2 +preshed==3.0.9 +protobuf==5.28.3 +psutil==6.1.0 +pycparser==2.22 +pydantic==2.9.2 +pydantic_core==2.23.4 +Pygments==2.18.0 +pyparsing==3.2.0 +pypinyin==0.53.0 +python-crfsuite==0.9.11 +python-dateutil==2.9.0.post0 +pytz==2024.2 +PyYAML==6.0.2 +regex==2024.11.6 +requests==2.32.3 +rich==13.9.4 +safetensors==0.4.5 +scikit-learn==1.5.2 +scipy==1.14.1 +shellingham==1.5.4 +six==1.16.0 +smart-open==7.0.5 +soundfile==0.12.1 +soxr==0.5.0.post1 +spacy==3.8.2 +spacy-legacy==3.0.12 +spacy-loggers==1.0.5 +srsly==2.4.8 +sympy==1.13.1 +tensorboard==2.18.0 +tensorboard-data-server==0.7.2 +thinc==8.3.2 +threadpoolctl==3.5.0 +tokenizers==0.20.3 +torch==2.5.1 +torchaudio==2.5.1+cpu +torchvision==0.20.1+cpu +tqdm==4.67.0 +trainer==0.0.36 +transformers==4.46.3 +typer==0.13.1 +typing_extensions==4.12.2 +tzdata==2024.2 +underthesea==6.8.4 +underthesea_core==1.0.4 +Unidecode==1.3.8 +urllib3==2.2.3 +vinorm==2.0.7 +wasabi==1.1.3 +weasel==0.4.1 +Werkzeug==3.1.3 +wrapt==1.16.0 diff --git a/viXTTS/samples/nam-calm.wav b/viXTTS/samples/nam-calm.wav new file mode 100644 index 0000000000000000000000000000000000000000..c8bbd1806f1b35733cec2550731c1de1a17b8332 Binary files /dev/null and b/viXTTS/samples/nam-calm.wav differ diff --git a/viXTTS/samples/nam-cham.wav b/viXTTS/samples/nam-cham.wav new file mode 100644 index 0000000000000000000000000000000000000000..cd0b285cc746c413b8560ac7b3e29e3c1d057e66 Binary files /dev/null and b/viXTTS/samples/nam-cham.wav differ diff --git a/viXTTS/samples/nam-nhanh.wav b/viXTTS/samples/nam-nhanh.wav new file mode 100644 index 0000000000000000000000000000000000000000..f4cc4875b4548117c7639787500c076dfe34f316 Binary files /dev/null and b/viXTTS/samples/nam-nhanh.wav differ diff --git a/viXTTS/samples/nam-truyen-cam.wav b/viXTTS/samples/nam-truyen-cam.wav new file mode 100644 index 0000000000000000000000000000000000000000..19f956436092b2cb56702be29cb8cdd75b614ab3 Binary files /dev/null and b/viXTTS/samples/nam-truyen-cam.wav differ diff --git a/viXTTS/samples/nu-calm.wav b/viXTTS/samples/nu-calm.wav new file mode 100644 index 0000000000000000000000000000000000000000..b1b9b8bad4c1e1c47216edaa57caaecdfe138238 Binary files /dev/null and b/viXTTS/samples/nu-calm.wav differ diff --git a/viXTTS/samples/nu-cham.wav b/viXTTS/samples/nu-cham.wav new file mode 100644 index 0000000000000000000000000000000000000000..853adef9856314faff39af1e756e0613157ea7e1 Binary files /dev/null and b/viXTTS/samples/nu-cham.wav differ diff --git a/viXTTS/samples/nu-luu-loat.wav b/viXTTS/samples/nu-luu-loat.wav new file mode 100644 index 0000000000000000000000000000000000000000..008a7715fcd171ca944e1d0ad7a9808cab06872c Binary files /dev/null and b/viXTTS/samples/nu-luu-loat.wav differ diff --git a/viXTTS/samples/nu-nhan-nha.wav b/viXTTS/samples/nu-nhan-nha.wav new 
file mode 100644 index 0000000000000000000000000000000000000000..88a2a4e62e23f466c18c6dceb13f34052614da2f Binary files /dev/null and b/viXTTS/samples/nu-nhan-nha.wav differ diff --git a/viXTTS/samples/nu-nhe-nhang.wav b/viXTTS/samples/nu-nhe-nhang.wav new file mode 100644 index 0000000000000000000000000000000000000000..59ea01ef940536f9beea729061730f8a0ce21dbe Binary files /dev/null and b/viXTTS/samples/nu-nhe-nhang.wav differ
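The long merge-pair arrays earlier in this diff (the romanized Japanese/Korean, Hungarian, Hindi and Vietnamese entries that end at the vocabulary file's closing braces) appear to be byte-pair-encoding (BPE) merge rules from the bundled multilingual tokenizer vocabulary, most likely a vocab.json used by XTTS. The sketch below is purely illustrative and is not the tokenizer shipped with the repo; it only shows how a ranked merge table of that shape is applied greedily to a word, using a handful of Vietnamese pairs that do appear in the list ("k h", "n g", "ô ng", "kh ông"). The rank order here is assumed for the example.

```python
# Illustrative only: a toy greedy application of a ranked BPE merge table.
# The real tokenizer is whatever the bundled XTTS vocabulary backend implements;
# this loop just demonstrates what entries like "kh ông" mean.

def apply_bpe(word, merges):
    """Greedily merge the adjacent pair with the lowest rank until none apply."""
    ranks = {pair: i for i, pair in enumerate(merges)}
    symbols = list(word)
    while len(symbols) > 1:
        candidates = [
            (ranks.get((a, b), float("inf")), i)
            for i, (a, b) in enumerate(zip(symbols, symbols[1:]))
        ]
        best_rank, best_i = min(candidates)
        if best_rank == float("inf"):
            break  # no listed pair is adjacent any more
        symbols[best_i:best_i + 2] = [symbols[best_i] + symbols[best_i + 1]]
    return symbols

# Pairs taken from the Vietnamese section of the merge list above,
# ordered by an assumed rank (earlier pairs merge first).
merges = [("k", "h"), ("n", "g"), ("ô", "ng"), ("kh", "ông")]
print(apply_bpe("không", merges))   # -> ['không']
```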
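The requirements.txt and the reference/output wav files added above suggest the usual XTTS workflow: install the pinned dependencies, condition on one of the sample voices, and synthesize Vietnamese text (the output filenames spell out "chào bạn, tôi là một trợ lý ảo" — "hello, I am a virtual assistant"). Note that the +cpu-suffixed torchaudio/torchvision pins are normally installed from the PyTorch wheel index rather than plain PyPI (e.g. pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cpu). The sketch below is a minimal example under stated assumptions, not the repository's actual entry point: it assumes the vendored TTS package is importable (the viXTTS directory on PYTHONPATH) and that a viXTTS checkpoint directory containing model.pth, config.json and vocab.json has been downloaded to a hypothetical local path.

```python
# Minimal sketch (assumptions: standard Coqui XTTS API from the vendored TTS
# package; CHECKPOINT_DIR is a hypothetical local folder holding the viXTTS
# model.pth / config.json / vocab.json).
import torch
import torchaudio

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

CHECKPOINT_DIR = "model"                      # hypothetical checkpoint directory
REFERENCE_WAV = "viXTTS/samples/nu-calm.wav"  # one of the reference samples added above
TEXT = "Chào bạn, tôi là một trợ lý ảo."      # matches the output filenames in this diff

config = XttsConfig()
config.load_json(f"{CHECKPOINT_DIR}/config.json")

model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir=CHECKPOINT_DIR, use_deepspeed=False)
if torch.cuda.is_available():
    model.cuda()

# Clone the reference voice, then run inference in Vietnamese.
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=[REFERENCE_WAV])
out = model.inference(
    text=TEXT,
    language="vi",
    gpt_cond_latent=gpt_cond_latent,
    speaker_embedding=speaker_embedding,
)

# XTTS returns 24 kHz audio as a float array under the "wav" key.
torchaudio.save("viXTTS/output/demo.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
```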