diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..4eca27be6d8721a0bc619989952418132b3d59aa --- /dev/null +++ b/.dockerignore @@ -0,0 +1,8 @@ +docs +logs +output +reference +SoVITS_weights +GPT_weights +TEMP +.git diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..0bb4e0bf22a20558257328203d997569b4a83722 --- /dev/null +++ b/.gitignore @@ -0,0 +1,200 @@ +.DS_Store +.vscode +__pycache__ +*.pyc +env +runtime +.idea +output +logs +reference +GPT_weights +SoVITS_weights +GPT_weights_v2 +SoVITS_weights_v2 +GPT_weights_v3 +SoVITS_weights_v3 +TEMP +weight.json +ffmpeg* +ffprobe* +cfg.json +speakers.json +ref_audios +tools/AP_BWE_main/24kto48k/* +!tools/AP_BWE_main/24kto48k/readme.txt + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc diff --git a/Docker/damo.sha256 b/Docker/damo.sha256 new file mode 100644 index 0000000000000000000000000000000000000000..6e9804da4a02caf45bac51adfb9a9b75de7c16d0 --- /dev/null +++ b/Docker/damo.sha256 @@ -0,0 +1,3 @@ +5bba782a5e9196166233b9ab12ba04cadff9ef9212b4ff6153ed9290ff679025 /workspace/tools/damo_asr/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.pb +b3be75be477f0780277f3bae0fe489f48718f585f3a6e45d7dd1fbb1a4255fc5 /workspace/tools/damo_asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch/model.pb +a5818bb9d933805a916eebe41eb41648f7f9caad30b4bd59d56f3ca135421916 /workspace/tools/damo_asr/models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/model.pb \ No newline at end of file diff --git a/Docker/download.py b/Docker/download.py new file mode 100644 index 0000000000000000000000000000000000000000..952423d1edc0b048db3865a0f25516b9726773ac --- /dev/null +++ b/Docker/download.py @@ -0,0 +1,8 @@ +# Download moda ASR related models +from modelscope import snapshot_download + +model_dir = snapshot_download( + "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", revision="v2.0.4" +) +model_dir = snapshot_download("damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", revision="v2.0.4") +model_dir = snapshot_download("damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", revision="v2.0.4") diff --git a/Docker/download.sh b/Docker/download.sh new file mode 100644 index 0000000000000000000000000000000000000000..447e018ead5725d2a220234398161702ff12c5f8 --- /dev/null +++ b/Docker/download.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash + +set -Eeuo pipefail + +echo "Downloading models..." + +aria2c --disable-ipv6 --input-file /workspace/Docker/links.txt --dir /workspace --continue + +echo "Checking SHA256..." 
+ +parallel --will-cite -a /workspace/Docker/links.sha256 "echo -n {} | sha256sum -c" diff --git a/Docker/links.sha256 b/Docker/links.sha256 new file mode 100644 index 0000000000000000000000000000000000000000..cda6dc1552e87ff035bfb27ffcac76f429c780f8 --- /dev/null +++ b/Docker/links.sha256 @@ -0,0 +1,12 @@ +b1c1e17e9c99547a89388f72048cd6e1b41b5a18b170e86a46dfde0324d63eb1 /workspace/GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt +fc579c1db3c1e21b721001cf99d7a584214280df19b002e200b630a34fa06eb8 /workspace/GPT_SoVITS/pretrained_models/s2D488k.pth +020a014e1e01e550e510f2f61fae5e5f5b6aab40f15c22f1f12f724df507e835 /workspace/GPT_SoVITS/pretrained_models/s2G488k.pth +24164f129c66499d1346e2aa55f183250c223161ec2770c0da3d3b08cf432d3c /workspace/GPT_SoVITS/pretrained_models/chinese-hubert-base/pytorch_model.bin +e53a693acc59ace251d143d068096ae0d7b79e4b1b503fa84c9dcf576448c1d8 /workspace/GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large/pytorch_model.bin +39796caa5db18d7f9382d8ac997ac967bfd85f7761014bb807d2543cc844ef05 /workspace/tools/uvr5/uvr5_weights/HP2_all_vocals.pth +45e6b65199e781b4a6542002699be9f19cd3d1cb7d1558bc2bfbcd84674dfe28 /workspace/tools/uvr5/uvr5_weights/HP3_all_vocals.pth +5908891829634926119720241e8573d97cbeb8277110a7512bdb0bd7563258ee /workspace/tools/uvr5/uvr5_weights/HP5_only_main_vocal.pth +8c8fd1582f9aabc363e47af62ddb88df6cae7e064cae75bbf041a067a5e0aee2 /workspace/tools/uvr5/uvr5_weights/VR-DeEchoAggressive.pth +01376dd2a571bf3cb9cced680732726d2d732609d09216a610b0d110f133febe /workspace/tools/uvr5/uvr5_weights/VR-DeEchoDeReverb.pth +56aba59db3bcdd14a14464e62f3129698ecdea62eee0f003b9360923eb3ac79e /workspace/tools/uvr5/uvr5_weights/VR-DeEchoNormal.pth +233bb5c6aaa365e568659a0a81211746fa881f8f47f82d9e864fce1f7692db80 /workspace/tools/uvr5/uvr5_weights/onnx_dereverb_By_FoxJoy/vocals.onnx \ No newline at end of file diff --git a/Docker/links.txt b/Docker/links.txt new file mode 100644 index 0000000000000000000000000000000000000000..e6603db0c027547b9bf85d1e8f54837fbfe0b076 --- /dev/null +++ b/Docker/links.txt @@ -0,0 +1,34 @@ +# GPT-SoVITS models +https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/s1bert25hz-2kh-longer-epoch%3D68e-step%3D50232.ckpt + out=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt +https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/s2D488k.pth + out=GPT_SoVITS/pretrained_models/s2D488k.pth +https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/s2G488k.pth + out=GPT_SoVITS/pretrained_models/s2G488k.pth +https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-hubert-base/config.json + out=GPT_SoVITS/pretrained_models/chinese-hubert-base/config.json +https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-hubert-base/preprocessor_config.json + out=GPT_SoVITS/pretrained_models/chinese-hubert-base/preprocessor_config.json +https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-hubert-base/pytorch_model.bin + out=GPT_SoVITS/pretrained_models/chinese-hubert-base/pytorch_model.bin +https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-roberta-wwm-ext-large/config.json + out=GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large/config.json +https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-roberta-wwm-ext-large/pytorch_model.bin + out=GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large/pytorch_model.bin +https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-roberta-wwm-ext-large/tokenizer.json + 
out=GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large/tokenizer.json +# UVR5 +https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP2_all_vocals.pth + out=tools/uvr5/uvr5_weights/HP2_all_vocals.pth +https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP3_all_vocals.pth + out=tools/uvr5/uvr5_weights/HP3_all_vocals.pth +https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP5_only_main_vocal.pth + out=tools/uvr5/uvr5_weights/HP5_only_main_vocal.pth +https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/VR-DeEchoAggressive.pth + out=tools/uvr5/uvr5_weights/VR-DeEchoAggressive.pth +https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/VR-DeEchoDeReverb.pth + out=tools/uvr5/uvr5_weights/VR-DeEchoDeReverb.pth +https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/VR-DeEchoNormal.pth + out=tools/uvr5/uvr5_weights/VR-DeEchoNormal.pth +https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/onnx_dereverb_By_FoxJoy/vocals.onnx + out=tools/uvr5/uvr5_weights/onnx_dereverb_By_FoxJoy/vocals.onnx \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..80cd9f3a106c9c997c23c0aa7c8fecab7039001b --- /dev/null +++ b/Dockerfile @@ -0,0 +1,42 @@ +# Base CUDA image +FROM cnstark/pytorch:2.0.1-py3.9.17-cuda11.8.0-ubuntu20.04 + +LABEL maintainer="breakstring@hotmail.com" +LABEL version="dev-20240209" +LABEL description="Docker image for GPT-SoVITS" + + +# Install 3rd party apps +ENV DEBIAN_FRONTEND=noninteractive +ENV TZ=Etc/UTC +RUN apt-get update && \ + apt-get install -y --no-install-recommends tzdata ffmpeg libsox-dev parallel aria2 git git-lfs && \ + git lfs install && \ + rm -rf /var/lib/apt/lists/* + +# Copy only requirements.txt initially to leverage Docker cache +WORKDIR /workspace +COPY requirements.txt /workspace/ +RUN pip install --no-cache-dir -r requirements.txt + +# Define a build-time argument for image type +ARG IMAGE_TYPE=full + +# Conditional logic based on the IMAGE_TYPE argument +# Always copy the Docker directory, but only use it if IMAGE_TYPE is not "elite" +COPY ./Docker /workspace/Docker +# elite 类型的镜像里面不包含额外的模型 +RUN if [ "$IMAGE_TYPE" != "elite" ]; then \ + chmod +x /workspace/Docker/download.sh && \ + /workspace/Docker/download.sh && \ + python /workspace/Docker/download.py && \ + python -m nltk.downloader averaged_perceptron_tagger cmudict; \ + fi + + +# Copy the rest of the application +COPY . /workspace + +EXPOSE 9871 9872 9873 9874 9880 + +CMD ["python", "webui.py"] diff --git a/GPT_SoVITS/BigVGAN/LICENSE b/GPT_SoVITS/BigVGAN/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..4c78361c86d4f685117d60d6623e2197fcfed706 --- /dev/null +++ b/GPT_SoVITS/BigVGAN/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 NVIDIA CORPORATION. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/GPT_SoVITS/BigVGAN/README.md b/GPT_SoVITS/BigVGAN/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2fa70ceea647053933b913b329041ee8c41526db --- /dev/null +++ b/GPT_SoVITS/BigVGAN/README.md @@ -0,0 +1,266 @@ +## BigVGAN: A Universal Neural Vocoder with Large-Scale Training + +#### Sang-gil Lee, Wei Ping, Boris Ginsburg, Bryan Catanzaro, Sungroh Yoon + +[[Paper]](https://arxiv.org/abs/2206.04658) - [[Code]](https://github.com/NVIDIA/BigVGAN) - [[Showcase]](https://bigvgan-demo.github.io/) - [[Project Page]](https://research.nvidia.com/labs/adlr/projects/bigvgan/) - [[Weights]](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a) - [[Demo]](https://huggingface.co/spaces/nvidia/BigVGAN) + +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bigvgan-a-universal-neural-vocoder-with-large/speech-synthesis-on-libritts)](https://paperswithcode.com/sota/speech-synthesis-on-libritts?p=bigvgan-a-universal-neural-vocoder-with-large) + +
+
+## News
+- **Sep 2024 (v2.4):**
+  - We have updated the pretrained checkpoints trained for 5M steps. This is the final release of the BigVGAN-v2 checkpoints.
+
+- **Jul 2024 (v2.3):**
+  - General refactor and code improvements for improved readability.
+  - Fully fused CUDA kernel of anti-aliased activation (upsampling + activation + downsampling) with inference speed benchmark.
+
+- **Jul 2024 (v2.2):** The repository now includes an interactive local demo using gradio.
+
+- **Jul 2024 (v2.1):** BigVGAN is now integrated with 🤗 Hugging Face Hub with easy access to inference using pretrained checkpoints. We also provide an interactive demo on Hugging Face Spaces.
+
+- **Jul 2024 (v2):** We release BigVGAN-v2 along with pretrained checkpoints. Below are the highlights:
+  - Custom CUDA kernel for inference: we provide a fused upsampling + activation kernel written in CUDA for accelerated inference speed. Our test shows 1.5 - 3x faster speed on a single A100 GPU.
+  - Improved discriminator and loss: BigVGAN-v2 is trained using a [multi-scale sub-band CQT discriminator](https://arxiv.org/abs/2311.14957) and a [multi-scale mel spectrogram loss](https://arxiv.org/abs/2306.06546).
+  - Larger training data: BigVGAN-v2 is trained using datasets containing diverse audio types, including speech in multiple languages, environmental sounds, and instruments.
+  - We provide pretrained checkpoints of BigVGAN-v2 using diverse audio configurations, supporting up to 44 kHz sampling rate and 512x upsampling ratio.
+
+## Installation
+
+The codebase has been tested on Python `3.10` and PyTorch `2.3.1` conda packages with either `pytorch-cuda=12.1` or `pytorch-cuda=11.8`. Below is an example command to create the conda environment:
+
+```shell
+conda create -n bigvgan python=3.10 pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
+conda activate bigvgan
+```
+
+Clone the repository and install dependencies:
+
+```shell
+git clone https://github.com/NVIDIA/BigVGAN
+cd BigVGAN
+pip install -r requirements.txt
+```
+
+## Inference Quickstart using 🤗 Hugging Face Hub
+
+The example below describes how to use BigVGAN: load the pretrained BigVGAN generator from Hugging Face Hub, compute a mel spectrogram from an input waveform, and generate a synthesized waveform using the mel spectrogram as the model's input.
+
+```python
+device = 'cuda'
+
+import torch
+import bigvgan
+import librosa
+from meldataset import get_mel_spectrogram
+
+# instantiate the model. You can optionally set use_cuda_kernel=True for faster inference.
+model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_24khz_100band_256x', use_cuda_kernel=False)
+
+# remove weight norm in the model and set to eval mode
+model.remove_weight_norm()
+model = model.eval().to(device)
+
+# load wav file and compute mel spectrogram
+wav_path = '/path/to/your/audio.wav'
+wav, sr = librosa.load(wav_path, sr=model.h.sampling_rate, mono=True) # wav is np.ndarray with shape [T_time] and values in [-1, 1]
+wav = torch.FloatTensor(wav).unsqueeze(0) # wav is FloatTensor with shape [B(1), T_time]
+
+# compute mel spectrogram from the ground truth audio
+mel = get_mel_spectrogram(wav, model.h).to(device) # mel is FloatTensor with shape [B(1), C_mel, T_frame]
+
+# generate waveform from mel
+with torch.inference_mode():
+    wav_gen = model(mel) # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
+wav_gen_float = wav_gen.squeeze(0).cpu() # wav_gen is FloatTensor with shape [1, T_time]
+
+# you can convert the generated waveform to 16 bit linear PCM
+wav_gen_int16 = (wav_gen_float * 32767.0).numpy().astype('int16') # wav_gen is now np.ndarray with shape [1, T_time] and int16 dtype
+```
+
+## Local gradio demo
+
+You can run a local gradio demo using the commands below:
+
+```shell
+pip install -r demo/requirements.txt
+python demo/app.py
+```
+
+## Training
+
+Create symbolic links to the root of the dataset. The codebase uses filelists with paths relative to the dataset root. Below are example commands for the LibriTTS dataset:
+
+```shell
+cd filelists/LibriTTS && \
+ln -s /path/to/your/LibriTTS/train-clean-100 train-clean-100 && \
+ln -s /path/to/your/LibriTTS/train-clean-360 train-clean-360 && \
+ln -s /path/to/your/LibriTTS/train-other-500 train-other-500 && \
+ln -s /path/to/your/LibriTTS/dev-clean dev-clean && \
+ln -s /path/to/your/LibriTTS/dev-other dev-other && \
+ln -s /path/to/your/LibriTTS/test-clean test-clean && \
+ln -s /path/to/your/LibriTTS/test-other test-other && \
+cd ../..
+```
+
+Train the BigVGAN model. Below is an example command for training BigVGAN-v2 on the LibriTTS dataset at 24 kHz with a full 100-band mel spectrogram as input:
+
+```shell
+python train.py \
+--config configs/bigvgan_v2_24khz_100band_256x.json \
+--input_wavs_dir filelists/LibriTTS \
+--input_training_file filelists/LibriTTS/train-full.txt \
+--input_validation_file filelists/LibriTTS/val-full.txt \
+--list_input_unseen_wavs_dir filelists/LibriTTS filelists/LibriTTS \
+--list_input_unseen_validation_file filelists/LibriTTS/dev-clean.txt filelists/LibriTTS/dev-other.txt \
+--checkpoint_path exp/bigvgan_v2_24khz_100band_256x
+```
+
+## Synthesis
+
+Synthesize from the BigVGAN model. Below is an example command for generating audio from the model.
+It computes mel spectrograms using wav files from `--input_wavs_dir` and saves the generated audio to `--output_dir`.
+
+```shell
+python inference.py \
+--checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
+--input_wavs_dir /path/to/your/input_wav \
+--output_dir /path/to/your/output_wav
+```
+
+`inference_e2e.py` supports synthesis directly from a mel spectrogram saved in `.npy` format, with shapes `[1, channel, frame]` or `[channel, frame]`.
+It loads mel spectrograms from `--input_mels_dir` and saves the generated audio to `--output_dir`.
+
+Make sure that the STFT hyperparameters for the mel spectrogram are the same as those of the model; they are defined in the `config.json` of the corresponding model.
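+
+For example, here is a minimal sketch of precomputing such a `.npy` mel spectrogram with `get_mel_spectrogram` from `meldataset.py`, run from the repository root; the config and file paths below are placeholders:
+
+```python
+import json
+
+import librosa
+import numpy as np
+import torch
+
+from env import AttrDict
+from meldataset import get_mel_spectrogram
+
+# load the model's config.json so the STFT hyperparameters match the checkpoint
+with open('/path/to/your/bigvgan_v2_24khz_100band_256x/config.json') as f:
+    h = AttrDict(json.load(f))
+
+# load audio at the model's sampling rate and compute the mel spectrogram
+wav, sr = librosa.load('/path/to/your/audio.wav', sr=h.sampling_rate, mono=True)
+wav = torch.FloatTensor(wav).unsqueeze(0)  # [B(1), T_time]
+mel = get_mel_spectrogram(wav, h)          # [B(1), C_mel, T_frame]
+
+# save with shape [1, channel, frame] so inference_e2e.py can load it
+np.save('/path/to/your/input_mel/sample.npy', mel.cpu().numpy())
+```
+
+Then run `inference_e2e.py` on the directory containing the saved `.npy` files: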
+
+```shell
+python inference_e2e.py \
+--checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
+--input_mels_dir /path/to/your/input_mel \
+--output_dir /path/to/your/output_wav
+```
+
+## Using Custom CUDA Kernel for Synthesis
+
+You can apply the fast CUDA inference kernel by using the parameter `use_cuda_kernel` when instantiating BigVGAN:
+
+```python
+generator = BigVGAN(h, use_cuda_kernel=True)
+```
+
+You can also pass `--use_cuda_kernel` to `inference.py` and `inference_e2e.py` to enable this feature.
+
+When applied for the first time, it builds the kernel using `nvcc` and `ninja`. If the build succeeds, the kernel is saved to `alias_free_activation/cuda/build` and the model automatically loads the kernel. The codebase has been tested using CUDA `12.1`.
+
+Please make sure that both are installed on your system and that the `nvcc` version matches the one your PyTorch build was compiled with.
+
+We recommend running `test_cuda_vs_torch_model.py` first to build and check the correctness of the CUDA kernel. See the example command below and its output, where it returns `[Success] test CUDA fused vs. plain torch BigVGAN inference`:
+
+```shell
+python tests/test_cuda_vs_torch_model.py \
+--checkpoint_file /path/to/your/bigvgan_generator.pt
+```
+
+```shell
+loading plain Pytorch BigVGAN
+...
+loading CUDA kernel BigVGAN with auto-build
+Detected CUDA files, patching ldflags
+Emitting ninja build file /path/to/your/BigVGAN/alias_free_activation/cuda/build/build.ninja..
+Building extension module anti_alias_activation_cuda...
+...
+Loading extension module anti_alias_activation_cuda...
+...
+Loading '/path/to/your/bigvgan_generator.pt'
+...
+[Success] test CUDA fused vs. plain torch BigVGAN inference
+ > mean_difference=0.0007238413265440613
+...
+```
+
+If you see `[Fail] test CUDA fused vs. plain torch BigVGAN inference`, it means that the CUDA kernel inference is incorrect. Please check whether the `nvcc` installed on your system is compatible with your PyTorch version.
+
+## Pretrained Models
+
+We provide the [pretrained models on Hugging Face Collections](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a).
+One can download the checkpoints of the generator weight (named `bigvgan_generator.pt`) and its discriminator/optimizer states (named `bigvgan_discriminator_optimizer.pt`) within the listed model repositories.
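+
+For example, a minimal sketch using `huggingface_hub` to fetch both files from one of the repositories listed in the table below; the target directory is an arbitrary, illustrative choice:
+
+```python
+from huggingface_hub import hf_hub_download
+
+# download the generator weight and the discriminator/optimizer states
+# (the latter is only needed for fine-tuning / resuming training)
+repo_id = "nvidia/bigvgan_v2_24khz_100band_256x"
+local_dir = "exp/bigvgan_v2_24khz_100band_256x"  # illustrative target directory
+generator_ckpt = hf_hub_download(repo_id, "bigvgan_generator.pt", local_dir=local_dir)
+discriminator_ckpt = hf_hub_download(repo_id, "bigvgan_discriminator_optimizer.pt", local_dir=local_dir)
+```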
+
+| Model Name | Sampling Rate | Mel band | fmax | Upsampling Ratio | Params | Dataset | Steps | Fine-Tuned |
+|:--------------------------------------------------------------------------------------------------------:|:-------------:|:--------:|:-----:|:----------------:|:------:|:--------------------------:|:-----:|:----------:|
+| [bigvgan_v2_44khz_128band_512x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_512x) | 44 kHz | 128 | 22050 | 512 | 122M | Large-scale Compilation | 5M | No |
+| [bigvgan_v2_44khz_128band_256x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_256x) | 44 kHz | 128 | 22050 | 256 | 112M | Large-scale Compilation | 5M | No |
+| [bigvgan_v2_24khz_100band_256x](https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x) | 24 kHz | 100 | 12000 | 256 | 112M | Large-scale Compilation | 5M | No |
+| [bigvgan_v2_22khz_80band_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_256x) | 22 kHz | 80 | 11025 | 256 | 112M | Large-scale Compilation | 5M | No |
+| [bigvgan_v2_22khz_80band_fmax8k_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_fmax8k_256x) | 22 kHz | 80 | 8000 | 256 | 112M | Large-scale Compilation | 5M | No |
+| [bigvgan_24khz_100band](https://huggingface.co/nvidia/bigvgan_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 112M | LibriTTS | 5M | No |
+| [bigvgan_base_24khz_100band](https://huggingface.co/nvidia/bigvgan_base_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 14M | LibriTTS | 5M | No |
+| [bigvgan_22khz_80band](https://huggingface.co/nvidia/bigvgan_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 112M | LibriTTS + VCTK + LJSpeech | 5M | No |
+| [bigvgan_base_22khz_80band](https://huggingface.co/nvidia/bigvgan_base_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 14M | LibriTTS + VCTK + LJSpeech | 5M | No |
+
+The paper results are based on the original 24kHz BigVGAN models (`bigvgan_24khz_100band` and `bigvgan_base_24khz_100band`) trained on the LibriTTS dataset.
+We also provide 22kHz BigVGAN models with a band-limited setup (i.e., fmax=8000) for TTS applications.
+Note that the checkpoints use the `snakebeta` activation with log-scale parameterization, which gives the best overall quality.
+
+You can fine-tune the models by:
+
+1. downloading the checkpoints (both the generator weight and its discriminator/optimizer states)
+2. resuming training on your audio dataset by specifying a `--checkpoint_path` that contains the checkpoints when launching `train.py`
+
+## Training Details of BigVGAN-v2
+
+Compared to the original BigVGAN, the pretrained checkpoints of BigVGAN-v2 used `batch_size=32` with a longer `segment_size=65536` and were trained using 8 A100 GPUs.
+
+Note that the BigVGAN-v2 `json` config files in `./configs` use `batch_size=4` as the default to fit training on a single A100 GPU. You can fine-tune the models by adjusting `batch_size` to match your GPUs.
+
+When training BigVGAN-v2 from scratch with a small batch size, it can encounter the early divergence problem mentioned in the paper. In such cases, we recommend lowering the `clip_grad_norm` value (e.g., `100`) for the early training iterations (e.g., the first 20000 steps) and then increasing it back to the default `500`.
+
+## Evaluation Results of BigVGAN-v2
+
+Below are the objective results of the 24kHz model (`bigvgan_v2_24khz_100band_256x`) obtained from the LibriTTS `dev` sets. BigVGAN-v2 shows noticeable improvements in the metrics. The model also exhibits reduced perceptual artifacts, especially for non-speech audio.
+ +| Model | Dataset | Steps | PESQ(↑) | M-STFT(↓) | MCD(↓) | Periodicity(↓) | V/UV F1(↑) | +|:----------:|:-----------------------:|:-----:|:---------:|:----------:|:----------:|:--------------:|:----------:| +| BigVGAN | LibriTTS | 1M | 4.027 | 0.7997 | 0.3745 | 0.1018 | 0.9598 | +| BigVGAN | LibriTTS | 5M | 4.256 | 0.7409 | 0.2988 | 0.0809 | 0.9698 | +| BigVGAN-v2 | Large-scale Compilation | 3M | 4.359 | 0.7134 | 0.3060 | 0.0621 | 0.9777 | +| BigVGAN-v2 | Large-scale Compilation | 5M | **4.362** | **0.7026** | **0.2903** | **0.0593** | **0.9793** | + +## Speed Benchmark + +Below are the speed and VRAM usage benchmark results of BigVGAN from `tests/test_cuda_vs_torch_model.py`, using `bigvgan_v2_24khz_100band_256x` as a reference model. + +| GPU | num_mel_frame | use_cuda_kernel | Speed (kHz) | Real-time Factor | VRAM (GB) | +|:--------------------------:|:-------------:|:---------------:|:-----------:|:----------------:|:---------:| +| NVIDIA A100 | 256 | False | 1672.1 | 69.7x | 1.3 | +| | | True | 3916.5 | 163.2x | 1.3 | +| | 2048 | False | 1899.6 | 79.2x | 1.7 | +| | | True | 5330.1 | 222.1x | 1.7 | +| | 16384 | False | 1973.8 | 82.2x | 5.0 | +| | | True | 5761.7 | 240.1x | 4.4 | +| NVIDIA GeForce RTX 3080 | 256 | False | 841.1 | 35.0x | 1.3 | +| | | True | 1598.1 | 66.6x | 1.3 | +| | 2048 | False | 929.9 | 38.7x | 1.7 | +| | | True | 1971.3 | 82.1x | 1.6 | +| | 16384 | False | 943.4 | 39.3x | 5.0 | +| | | True | 2026.5 | 84.4x | 3.9 | +| NVIDIA GeForce RTX 2080 Ti | 256 | False | 515.6 | 21.5x | 1.3 | +| | | True | 811.3 | 33.8x | 1.3 | +| | 2048 | False | 576.5 | 24.0x | 1.7 | +| | | True | 1023.0 | 42.6x | 1.5 | +| | 16384 | False | 589.4 | 24.6x | 5.0 | +| | | True | 1068.1 | 44.5x | 3.2 | + +## Acknowledgements + +We thank Vijay Anand Korthikanti and Kevin J. Shih for their generous support in implementing the CUDA kernel for inference. + +## References + +- [HiFi-GAN](https://github.com/jik876/hifi-gan) (for generator and multi-period discriminator) +- [Snake](https://github.com/EdwardDixon/snake) (for periodic activation) +- [Alias-free-torch](https://github.com/junjun3518/alias-free-torch) (for anti-aliasing) +- [Julius](https://github.com/adefossez/julius) (for low-pass filter) +- [UnivNet](https://github.com/mindslab-ai/univnet) (for multi-resolution discriminator) +- [descript-audio-codec](https://github.com/descriptinc/descript-audio-codec) and [vocos](https://github.com/gemelo-ai/vocos) (for multi-band multi-scale STFT discriminator and multi-scale mel spectrogram loss) +- [Amphion](https://github.com/open-mmlab/Amphion) (for multi-scale sub-band CQT discriminator) diff --git a/GPT_SoVITS/BigVGAN/activations.py b/GPT_SoVITS/BigVGAN/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..abe3ad9e25c6ab3d4545c6a8c60e1f85a5a8e98e --- /dev/null +++ b/GPT_SoVITS/BigVGAN/activations.py @@ -0,0 +1,122 @@ +# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. +# LICENSE is in incl_licenses directory. 
+ +import torch +from torch import nn, sin, pow +from torch.nn import Parameter + + +class Snake(nn.Module): + """ + Implementation of a sine-based periodic activation function + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter + References: + - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snake(256) + >>> x = torch.randn(256) + >>> x = a1(x) + """ + + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + """ + Initialization. + INPUT: + - in_features: shape of the input + - alpha: trainable parameter + alpha is initialized to 1 by default, higher values = higher-frequency. + alpha will be trained along with the rest of your model. + """ + super(Snake, self).__init__() + self.in_features = in_features + + # Initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # Log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + else: # Linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + """ + Forward pass of the function. + Applies the function to the input elementwise. + Snake ∶= x + 1/a * sin^2 (xa) + """ + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T] + if self.alpha_logscale: + alpha = torch.exp(alpha) + x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x + + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + """ + + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + """ + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + """ + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # Initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # Log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + else: # Linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + self.beta = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + """ + Forward pass of the function. + Applies the function to the input elementwise. 
+ SnakeBeta ∶= x + 1/b * sin^2 (xa) + """ + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x diff --git a/GPT_SoVITS/BigVGAN/bigvgan.py b/GPT_SoVITS/BigVGAN/bigvgan.py new file mode 100644 index 0000000000000000000000000000000000000000..febdf165c354b1fa2932f27e4ef8b7b6da10e2a6 --- /dev/null +++ b/GPT_SoVITS/BigVGAN/bigvgan.py @@ -0,0 +1,461 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. + +import os +import json +from pathlib import Path +from typing import Optional, Union, Dict + +import torch +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import weight_norm, remove_weight_norm + +from . import activations +from .utils0 import init_weights, get_padding +from .alias_free_activation.torch.act import Activation1d as TorchActivation1d +from .env import AttrDict + +from huggingface_hub import PyTorchModelHubMixin, hf_hub_download + + +def load_hparams_from_json(path) -> AttrDict: + with open(path) as f: + data = f.read() + return AttrDict(json.loads(data)) + + +class AMPBlock1(torch.nn.Module): + """ + AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer. + AMPBlock1 has additional self.convs2 that contains additional Conv1d layers with a fixed dilation=1 followed by each layer in self.convs1 + + Args: + h (AttrDict): Hyperparameters. + channels (int): Number of convolution channels. + kernel_size (int): Size of the convolution kernel. Default is 3. + dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5). + activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None. 
+ """ + + def __init__( + self, + h: AttrDict, + channels: int, + kernel_size: int = 3, + dilation: tuple = (1, 3, 5), + activation: str = None, + ): + super().__init__() + + self.h = h + + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=d, + padding=get_padding(kernel_size, d), + ) + ) + for d in dilation + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ) + for _ in range(len(dilation)) + ] + ) + self.convs2.apply(init_weights) + + self.num_layers = len(self.convs1) + len(self.convs2) # Total number of conv layers + + # Select which Activation1d, lazy-load cuda version to ensure backward compatibility + if self.h.get("use_cuda_kernel", False): + from .alias_free_activation.cuda.activation1d import ( + Activation1d as CudaActivation1d, + ) + + Activation1d = CudaActivation1d + else: + Activation1d = TorchActivation1d + + # Activation functions + if activation == "snake": + self.activations = nn.ModuleList( + [ + Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ] + ) + elif activation == "snakebeta": + self.activations = nn.ModuleList( + [ + Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ] + ) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + def forward(self, x): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = xt + x + + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class AMPBlock2(torch.nn.Module): + """ + AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer. + Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1 + + Args: + h (AttrDict): Hyperparameters. + channels (int): Number of convolution channels. + kernel_size (int): Size of the convolution kernel. Default is 3. + dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5). + activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None. 
+ """ + + def __init__( + self, + h: AttrDict, + channels: int, + kernel_size: int = 3, + dilation: tuple = (1, 3, 5), + activation: str = None, + ): + super().__init__() + + self.h = h + + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=d, + padding=get_padding(kernel_size, d), + ) + ) + for d in dilation + ] + ) + self.convs.apply(init_weights) + + self.num_layers = len(self.convs) # Total number of conv layers + + # Select which Activation1d, lazy-load cuda version to ensure backward compatibility + if self.h.get("use_cuda_kernel", False): + from .alias_free_activation.cuda.activation1d import ( + Activation1d as CudaActivation1d, + ) + + Activation1d = CudaActivation1d + else: + Activation1d = TorchActivation1d + + # Activation functions + if activation == "snake": + self.activations = nn.ModuleList( + [ + Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ] + ) + elif activation == "snakebeta": + self.activations = nn.ModuleList( + [ + Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ] + ) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + def forward(self, x): + for c, a in zip(self.convs, self.activations): + xt = a(x) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class BigVGAN( + torch.nn.Module, + PyTorchModelHubMixin, + # library_name="bigvgan", + # repo_url="https://github.com/NVIDIA/BigVGAN", + # docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md", + # pipeline_tag="audio-to-audio", + # license="mit", + # tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"], +): + """ + BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks). + New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks. + + Args: + h (AttrDict): Hyperparameters. + use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. This should be used for inference only, as training is not supported with CUDA kernels. + + Note: + - The `use_cuda_kernel` parameter should be used for inference only, as training with CUDA kernels is not supported. + - Ensure that the activation function is correctly specified in the hyperparameters (h.activation). + """ + + def __init__(self, h: AttrDict, use_cuda_kernel: bool = False): + super().__init__() + self.h = h + self.h["use_cuda_kernel"] = use_cuda_kernel + + # Select which Activation1d, lazy-load cuda version to ensure backward compatibility + if self.h.get("use_cuda_kernel", False): + from .alias_free_activation.cuda.activation1d import ( + Activation1d as CudaActivation1d, + ) + + Activation1d = CudaActivation1d + else: + Activation1d = TorchActivation1d + + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + + # Pre-conv + self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)) + + # Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default + if h.resblock == "1": + resblock_class = AMPBlock1 + elif h.resblock == "2": + resblock_class = AMPBlock2 + else: + raise ValueError(f"Incorrect resblock class specified in hyperparameters. 
Got {h.resblock}") + + # Transposed conv-based upsamplers. does not apply anti-aliasing + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + nn.ModuleList( + [ + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ] + ) + ) + + # Residual blocks using anti-aliased multi-periodicity composition modules (AMP) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + self.resblocks.append(resblock_class(h, ch, k, d, activation=h.activation)) + + # Post-conv + activation_post = ( + activations.Snake(ch, alpha_logscale=h.snake_logscale) + if h.activation == "snake" + else (activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) if h.activation == "snakebeta" else None) + ) + if activation_post is None: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + self.activation_post = Activation1d(activation=activation_post) + + # Whether to use bias for the final conv_post. Default to True for backward compatibility + self.use_bias_at_final = h.get("use_bias_at_final", True) + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)) + + # Weight initialization + for i in range(len(self.ups)): + self.ups[i].apply(init_weights) + self.conv_post.apply(init_weights) + + # Final tanh activation. Defaults to True for backward compatibility + self.use_tanh_at_final = h.get("use_tanh_at_final", True) + + def forward(self, x): + # Pre-conv + x = self.conv_pre(x) + + for i in range(self.num_upsamples): + # Upsampling + for i_up in range(len(self.ups[i])): + x = self.ups[i][i_up](x) + # AMP blocks + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + # Post-conv + x = self.activation_post(x) + x = self.conv_post(x) + # Final tanh activation + if self.use_tanh_at_final: + x = torch.tanh(x) + else: + x = torch.clamp(x, min=-1.0, max=1.0) # Bound the output to [-1, 1] + + return x + + def remove_weight_norm(self): + try: + # print("Removing weight norm...") + for l in self.ups: + for l_i in l: + remove_weight_norm(l_i) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + except ValueError: + print("[INFO] Model already removed weight norm. 
Skipping!") + pass + + # Additional methods for huggingface_hub support + def _save_pretrained(self, save_directory: Path) -> None: + """Save weights and config.json from a Pytorch model to a local directory.""" + + model_path = save_directory / "bigvgan_generator.pt" + torch.save({"generator": self.state_dict()}, model_path) + + config_path = save_directory / "config.json" + with open(config_path, "w") as config_file: + json.dump(self.h, config_file, indent=4) + + @classmethod + def _from_pretrained( + cls, + *, + model_id: str, + revision: str, + cache_dir: str, + force_download: bool, + proxies: Optional[Dict], + resume_download: bool, + local_files_only: bool, + token: Union[str, bool, None], + map_location: str = "cpu", # Additional argument + strict: bool = False, # Additional argument + use_cuda_kernel: bool = False, + **model_kwargs, + ): + """Load Pytorch pretrained weights and return the loaded model.""" + + # Download and load hyperparameters (h) used by BigVGAN + if os.path.isdir(model_id): + # print("Loading config.json from local directory") + config_file = os.path.join(model_id, "config.json") + else: + config_file = hf_hub_download( + repo_id=model_id, + filename="config.json", + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + h = load_hparams_from_json(config_file) + + # instantiate BigVGAN using h + if use_cuda_kernel: + print( + "[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!" + ) + print( + "[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!" + ) + print( + "[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis" + ) + model = cls(h, use_cuda_kernel=use_cuda_kernel) + + # Download and load pretrained generator weight + if os.path.isdir(model_id): + # print("Loading weights from local directory") + model_file = os.path.join(model_id, "bigvgan_generator.pt") + else: + # print(f"Loading weights from {model_id}") + model_file = hf_hub_download( + repo_id=model_id, + filename="bigvgan_generator.pt", + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + + checkpoint_dict = torch.load(model_file, map_location=map_location) + + try: + model.load_state_dict(checkpoint_dict["generator"]) + except RuntimeError: + print( + "[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!" 
+ ) + model.remove_weight_norm() + model.load_state_dict(checkpoint_dict["generator"]) + + return model diff --git a/GPT_SoVITS/BigVGAN/configs/bigvgan_22khz_80band.json b/GPT_SoVITS/BigVGAN/configs/bigvgan_22khz_80band.json new file mode 100644 index 0000000000000000000000000000000000000000..64bca7846edb4e86d7ee22d9ca7a1554cf7f1042 --- /dev/null +++ b/GPT_SoVITS/BigVGAN/configs/bigvgan_22khz_80band.json @@ -0,0 +1,45 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 32, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.9999996, + "seed": 1234, + + "upsample_rates": [4,4,2,2,2,2], + "upsample_kernel_sizes": [8,8,4,4,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "activation": "snakebeta", + "snake_logscale": true, + + "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]], + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "segment_size": 8192, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 22050, + + "fmin": 0, + "fmax": 8000, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/GPT_SoVITS/BigVGAN/configs/bigvgan_24khz_100band.json b/GPT_SoVITS/BigVGAN/configs/bigvgan_24khz_100band.json new file mode 100644 index 0000000000000000000000000000000000000000..e7f7ff08f6697a4640d8e28c0b3fe7e62d0c3fc7 --- /dev/null +++ b/GPT_SoVITS/BigVGAN/configs/bigvgan_24khz_100band.json @@ -0,0 +1,45 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 32, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.9999996, + "seed": 1234, + + "upsample_rates": [4,4,2,2,2,2], + "upsample_kernel_sizes": [8,8,4,4,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "activation": "snakebeta", + "snake_logscale": true, + + "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]], + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "segment_size": 8192, + "num_mels": 100, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 24000, + + "fmin": 0, + "fmax": 12000, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/GPT_SoVITS/BigVGAN/configs/bigvgan_base_22khz_80band.json b/GPT_SoVITS/BigVGAN/configs/bigvgan_base_22khz_80band.json new file mode 100644 index 0000000000000000000000000000000000000000..fd244848308917f4df7ce49bf6b76530fd04cbc2 --- /dev/null +++ b/GPT_SoVITS/BigVGAN/configs/bigvgan_base_22khz_80band.json @@ -0,0 +1,45 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 32, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.9999996, + "seed": 1234, + + "upsample_rates": [8,8,2,2], + "upsample_kernel_sizes": [16,16,4,4], + "upsample_initial_channel": 512, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "activation": "snakebeta", + "snake_logscale": true, + + "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]], + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + 
"discriminator_channel_mult": 1, + + "segment_size": 8192, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 22050, + + "fmin": 0, + "fmax": 8000, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/GPT_SoVITS/BigVGAN/configs/bigvgan_base_24khz_100band.json b/GPT_SoVITS/BigVGAN/configs/bigvgan_base_24khz_100band.json new file mode 100644 index 0000000000000000000000000000000000000000..0911508cac4a9346ada8c196bfcc228998da6f42 --- /dev/null +++ b/GPT_SoVITS/BigVGAN/configs/bigvgan_base_24khz_100band.json @@ -0,0 +1,45 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 32, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.9999996, + "seed": 1234, + + "upsample_rates": [8,8,2,2], + "upsample_kernel_sizes": [16,16,4,4], + "upsample_initial_channel": 512, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "activation": "snakebeta", + "snake_logscale": true, + + "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]], + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "segment_size": 8192, + "num_mels": 100, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 24000, + + "fmin": 0, + "fmax": 12000, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/GPT_SoVITS/BigVGAN/configs/bigvgan_v2_22khz_80band_256x.json b/GPT_SoVITS/BigVGAN/configs/bigvgan_v2_22khz_80band_256x.json new file mode 100644 index 0000000000000000000000000000000000000000..e96bd5fdd5b99767adba7f13bfcd1f777d5c599a --- /dev/null +++ b/GPT_SoVITS/BigVGAN/configs/bigvgan_v2_22khz_80band_256x.json @@ -0,0 +1,61 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 4, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.9999996, + "seed": 1234, + + "upsample_rates": [4,4,2,2,2,2], + "upsample_kernel_sizes": [8,8,4,4,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "use_tanh_at_final": false, + "use_bias_at_final": false, + + "activation": "snakebeta", + "snake_logscale": true, + + "use_cqtd_instead_of_mrd": true, + "cqtd_filters": 128, + "cqtd_max_filters": 1024, + "cqtd_filters_scale": 1, + "cqtd_dilations": [1, 2, 4], + "cqtd_hop_lengths": [512, 256, 256], + "cqtd_n_octaves": [9, 9, 9], + "cqtd_bins_per_octaves": [24, 36, 48], + + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "use_multiscale_melloss": true, + "lambda_melloss": 15, + + "clip_grad_norm": 500, + + "segment_size": 65536, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 22050, + + "fmin": 0, + "fmax": null, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/GPT_SoVITS/BigVGAN/configs/bigvgan_v2_22khz_80band_fmax8k_256x.json b/GPT_SoVITS/BigVGAN/configs/bigvgan_v2_22khz_80band_fmax8k_256x.json new file mode 100644 index 0000000000000000000000000000000000000000..a3c9699fbe11948f4fd7e3434d2e623a00c802dd --- /dev/null +++ 
b/GPT_SoVITS/BigVGAN/configs/bigvgan_v2_22khz_80band_fmax8k_256x.json @@ -0,0 +1,61 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 4, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.9999996, + "seed": 1234, + + "upsample_rates": [4,4,2,2,2,2], + "upsample_kernel_sizes": [8,8,4,4,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "use_tanh_at_final": false, + "use_bias_at_final": false, + + "activation": "snakebeta", + "snake_logscale": true, + + "use_cqtd_instead_of_mrd": true, + "cqtd_filters": 128, + "cqtd_max_filters": 1024, + "cqtd_filters_scale": 1, + "cqtd_dilations": [1, 2, 4], + "cqtd_hop_lengths": [512, 256, 256], + "cqtd_n_octaves": [9, 9, 9], + "cqtd_bins_per_octaves": [24, 36, 48], + + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "use_multiscale_melloss": true, + "lambda_melloss": 15, + + "clip_grad_norm": 500, + + "segment_size": 65536, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 22050, + + "fmin": 0, + "fmax": 8000, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/GPT_SoVITS/BigVGAN/configs/bigvgan_v2_24khz_100band_256x.json b/GPT_SoVITS/BigVGAN/configs/bigvgan_v2_24khz_100band_256x.json new file mode 100644 index 0000000000000000000000000000000000000000..8057ee267c8ed80615362a41892b923a3ccd27e5 --- /dev/null +++ b/GPT_SoVITS/BigVGAN/configs/bigvgan_v2_24khz_100band_256x.json @@ -0,0 +1,61 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 4, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.9999996, + "seed": 1234, + + "upsample_rates": [4,4,2,2,2,2], + "upsample_kernel_sizes": [8,8,4,4,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "use_tanh_at_final": false, + "use_bias_at_final": false, + + "activation": "snakebeta", + "snake_logscale": true, + + "use_cqtd_instead_of_mrd": true, + "cqtd_filters": 128, + "cqtd_max_filters": 1024, + "cqtd_filters_scale": 1, + "cqtd_dilations": [1, 2, 4], + "cqtd_hop_lengths": [512, 256, 256], + "cqtd_n_octaves": [9, 9, 9], + "cqtd_bins_per_octaves": [24, 36, 48], + + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "use_multiscale_melloss": true, + "lambda_melloss": 15, + + "clip_grad_norm": 500, + + "segment_size": 65536, + "num_mels": 100, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 24000, + + "fmin": 0, + "fmax": null, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/GPT_SoVITS/BigVGAN/configs/bigvgan_v2_44khz_128band_256x.json b/GPT_SoVITS/BigVGAN/configs/bigvgan_v2_44khz_128band_256x.json new file mode 100644 index 0000000000000000000000000000000000000000..b6999d3028e5d741ec99b16b34f153e763d0cfec --- /dev/null +++ b/GPT_SoVITS/BigVGAN/configs/bigvgan_v2_44khz_128band_256x.json @@ -0,0 +1,61 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 4, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.9999996, + "seed": 1234, + + "upsample_rates": [4,4,2,2,2,2], + 
"upsample_kernel_sizes": [8,8,4,4,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "use_tanh_at_final": false, + "use_bias_at_final": false, + + "activation": "snakebeta", + "snake_logscale": true, + + "use_cqtd_instead_of_mrd": true, + "cqtd_filters": 128, + "cqtd_max_filters": 1024, + "cqtd_filters_scale": 1, + "cqtd_dilations": [1, 2, 4], + "cqtd_hop_lengths": [512, 256, 256], + "cqtd_n_octaves": [9, 9, 9], + "cqtd_bins_per_octaves": [24, 36, 48], + + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "use_multiscale_melloss": true, + "lambda_melloss": 15, + + "clip_grad_norm": 500, + + "segment_size": 65536, + "num_mels": 128, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + + "sampling_rate": 44100, + + "fmin": 0, + "fmax": null, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/GPT_SoVITS/BigVGAN/configs/bigvgan_v2_44khz_128band_512x.json b/GPT_SoVITS/BigVGAN/configs/bigvgan_v2_44khz_128band_512x.json new file mode 100644 index 0000000000000000000000000000000000000000..2d7176c910ae0969f208f6d28e3f14abca2dbc7f --- /dev/null +++ b/GPT_SoVITS/BigVGAN/configs/bigvgan_v2_44khz_128band_512x.json @@ -0,0 +1,61 @@ +{ + "resblock": "1", + "num_gpus": 0, + "batch_size": 4, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.9999996, + "seed": 1234, + + "upsample_rates": [8,4,2,2,2,2], + "upsample_kernel_sizes": [16,8,4,4,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], + + "use_tanh_at_final": false, + "use_bias_at_final": false, + + "activation": "snakebeta", + "snake_logscale": true, + + "use_cqtd_instead_of_mrd": true, + "cqtd_filters": 128, + "cqtd_max_filters": 1024, + "cqtd_filters_scale": 1, + "cqtd_dilations": [1, 2, 4], + "cqtd_hop_lengths": [512, 256, 256], + "cqtd_n_octaves": [9, 9, 9], + "cqtd_bins_per_octaves": [24, 36, 48], + + "mpd_reshapes": [2, 3, 5, 7, 11], + "use_spectral_norm": false, + "discriminator_channel_mult": 1, + + "use_multiscale_melloss": true, + "lambda_melloss": 15, + + "clip_grad_norm": 500, + + "segment_size": 65536, + "num_mels": 128, + "num_freq": 2049, + "n_fft": 2048, + "hop_size": 512, + "win_size": 2048, + + "sampling_rate": 44100, + + "fmin": 0, + "fmax": null, + "fmax_for_loss": null, + + "num_workers": 4, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/GPT_SoVITS/BigVGAN/discriminators.py b/GPT_SoVITS/BigVGAN/discriminators.py new file mode 100644 index 0000000000000000000000000000000000000000..2d44c7983955a1be15a4520f6730de272f799128 --- /dev/null +++ b/GPT_SoVITS/BigVGAN/discriminators.py @@ -0,0 +1,625 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. 
+ + +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv2d +from torch.nn.utils import weight_norm, spectral_norm +from torchaudio.transforms import Spectrogram, Resample + +from env import AttrDict +from utils import get_padding +import typing +from typing import List, Tuple + + +class DiscriminatorP(torch.nn.Module): + def __init__( + self, + h: AttrDict, + period: List[int], + kernel_size: int = 5, + stride: int = 3, + use_spectral_norm: bool = False, + ): + super().__init__() + self.period = period + self.d_mult = h.discriminator_channel_mult + norm_f = weight_norm if not use_spectral_norm else spectral_norm + + self.convs = nn.ModuleList( + [ + norm_f( + Conv2d( + 1, + int(32 * self.d_mult), + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + int(32 * self.d_mult), + int(128 * self.d_mult), + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + int(128 * self.d_mult), + int(512 * self.d_mult), + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + int(512 * self.d_mult), + int(1024 * self.d_mult), + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + int(1024 * self.d_mult), + int(1024 * self.d_mult), + (kernel_size, 1), + 1, + padding=(2, 0), + ) + ), + ] + ) + self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, 0.1) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, h: AttrDict): + super().__init__() + self.mpd_reshapes = h.mpd_reshapes + print(f"mpd_reshapes: {self.mpd_reshapes}") + self.discriminators = nn.ModuleList( + [DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes] + ) + + def forward( + self, y: torch.Tensor, y_hat: torch.Tensor + ) -> Tuple[ + List[torch.Tensor], + List[torch.Tensor], + List[List[torch.Tensor]], + List[List[torch.Tensor]], + ]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorR(nn.Module): + def __init__(self, cfg: AttrDict, resolution: List[List[int]]): + super().__init__() + + self.resolution = resolution + assert len(self.resolution) == 3, f"MRD layer requires list with len=3, got {self.resolution}" + self.lrelu_slope = 0.1 + + norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm + if hasattr(cfg, "mrd_use_spectral_norm"): + print(f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}") + norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm + self.d_mult = cfg.discriminator_channel_mult + if hasattr(cfg, "mrd_channel_mult"): + print(f"[INFO] overriding mrd channel multiplier as {cfg.mrd_channel_mult}") + self.d_mult = 
cfg.mrd_channel_mult + + self.convs = nn.ModuleList( + [ + norm_f(nn.Conv2d(1, int(32 * self.d_mult), (3, 9), padding=(1, 4))), + norm_f( + nn.Conv2d( + int(32 * self.d_mult), + int(32 * self.d_mult), + (3, 9), + stride=(1, 2), + padding=(1, 4), + ) + ), + norm_f( + nn.Conv2d( + int(32 * self.d_mult), + int(32 * self.d_mult), + (3, 9), + stride=(1, 2), + padding=(1, 4), + ) + ), + norm_f( + nn.Conv2d( + int(32 * self.d_mult), + int(32 * self.d_mult), + (3, 9), + stride=(1, 2), + padding=(1, 4), + ) + ), + norm_f( + nn.Conv2d( + int(32 * self.d_mult), + int(32 * self.d_mult), + (3, 3), + padding=(1, 1), + ) + ), + ] + ) + self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1))) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: + fmap = [] + + x = self.spectrogram(x) + x = x.unsqueeze(1) + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, self.lrelu_slope) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + def spectrogram(self, x: torch.Tensor) -> torch.Tensor: + n_fft, hop_length, win_length = self.resolution + x = F.pad( + x, + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + x = x.squeeze(1) + x = torch.stft( + x, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + center=False, + return_complex=True, + ) + x = torch.view_as_real(x) # [B, F, TT, 2] + mag = torch.norm(x, p=2, dim=-1) # [B, F, TT] + + return mag + + +class MultiResolutionDiscriminator(nn.Module): + def __init__(self, cfg, debug=False): + super().__init__() + self.resolutions = cfg.resolutions + assert len(self.resolutions) == 3, ( + f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}" + ) + self.discriminators = nn.ModuleList([DiscriminatorR(cfg, resolution) for resolution in self.resolutions]) + + def forward( + self, y: torch.Tensor, y_hat: torch.Tensor + ) -> Tuple[ + List[torch.Tensor], + List[torch.Tensor], + List[List[torch.Tensor]], + List[List[torch.Tensor]], + ]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(x=y) + y_d_g, fmap_g = d(x=y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +# Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec +# Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license. +# LICENSE is in incl_licenses directory. +class DiscriminatorB(nn.Module): + def __init__( + self, + window_length: int, + channels: int = 32, + hop_factor: float = 0.25, + bands: Tuple[Tuple[float, float], ...] 
= ( + (0.0, 0.1), + (0.1, 0.25), + (0.25, 0.5), + (0.5, 0.75), + (0.75, 1.0), + ), + ): + super().__init__() + self.window_length = window_length + self.hop_factor = hop_factor + self.spec_fn = Spectrogram( + n_fft=window_length, + hop_length=int(window_length * hop_factor), + win_length=window_length, + power=None, + ) + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + convs = lambda: nn.ModuleList( + [ + weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))), + ] + ) + self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) + + self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))) + + def spectrogram(self, x: torch.Tensor) -> List[torch.Tensor]: + # Remove DC offset + x = x - x.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + x = self.spec_fn(x) + x = torch.view_as_real(x) + x = x.permute(0, 3, 2, 1) # [B, F, T, C] -> [B, C, T, F] + # Split into bands + x_bands = [x[..., b[0] : b[1]] for b in self.bands] + return x_bands + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: + x_bands = self.spectrogram(x.squeeze(1)) + fmap = [] + x = [] + + for band, stack in zip(x_bands, self.band_convs): + for i, layer in enumerate(stack): + band = layer(band) + band = torch.nn.functional.leaky_relu(band, 0.1) + if i > 0: + fmap.append(band) + x.append(band) + + x = torch.cat(x, dim=-1) + x = self.conv_post(x) + fmap.append(x) + + return x, fmap + + +# Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec +# Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license. +# LICENSE is in incl_licenses directory. +class MultiBandDiscriminator(nn.Module): + def __init__( + self, + h, + ): + """ + Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec. + and the modified code adapted from https://github.com/gemelo-ai/vocos. + """ + super().__init__() + # fft_sizes (list[int]): Tuple of window lengths for FFT. Defaults to [2048, 1024, 512] if not set in h. + self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512]) + self.discriminators = nn.ModuleList([DiscriminatorB(window_length=w) for w in self.fft_sizes]) + + def forward( + self, y: torch.Tensor, y_hat: torch.Tensor + ) -> Tuple[ + List[torch.Tensor], + List[torch.Tensor], + List[List[torch.Tensor]], + List[List[torch.Tensor]], + ]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for d in self.discriminators: + y_d_r, fmap_r = d(x=y) + y_d_g, fmap_g = d(x=y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +# Adapted from https://github.com/open-mmlab/Amphion/blob/main/models/vocoders/gan/discriminator/mssbcqtd.py under the MIT license. +# LICENSE is in incl_licenses directory. 
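For reference, the fractional `bands` taken by `DiscriminatorB` above are converted to STFT-bin ranges against `n_fft // 2 + 1` bins before the per-band convolution stacks are applied. A small illustrative sketch (values only, not part of the diff):

```python
# How DiscriminatorB maps fractional band edges to FFT-bin ranges
window_length = 2048
n_bins = window_length // 2 + 1  # 1025 frequency bins from the STFT
bands = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
bin_ranges = [(int(lo * n_bins), int(hi * n_bins)) for lo, hi in bands]
print(bin_ranges)  # [(0, 102), (102, 256), (256, 512), (512, 768), (768, 1025)]
```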
+class DiscriminatorCQT(nn.Module): + def __init__(self, cfg: AttrDict, hop_length: int, n_octaves: int, bins_per_octave: int): + super().__init__() + self.cfg = cfg + + self.filters = cfg["cqtd_filters"] + self.max_filters = cfg["cqtd_max_filters"] + self.filters_scale = cfg["cqtd_filters_scale"] + self.kernel_size = (3, 9) + self.dilations = cfg["cqtd_dilations"] + self.stride = (1, 2) + + self.in_channels = cfg["cqtd_in_channels"] + self.out_channels = cfg["cqtd_out_channels"] + self.fs = cfg["sampling_rate"] + self.hop_length = hop_length + self.n_octaves = n_octaves + self.bins_per_octave = bins_per_octave + + # Lazy-load + from nnAudio import features + + self.cqt_transform = features.cqt.CQT2010v2( + sr=self.fs * 2, + hop_length=self.hop_length, + n_bins=self.bins_per_octave * self.n_octaves, + bins_per_octave=self.bins_per_octave, + output_format="Complex", + pad_mode="constant", + ) + + self.conv_pres = nn.ModuleList() + for _ in range(self.n_octaves): + self.conv_pres.append( + nn.Conv2d( + self.in_channels * 2, + self.in_channels * 2, + kernel_size=self.kernel_size, + padding=self.get_2d_padding(self.kernel_size), + ) + ) + + self.convs = nn.ModuleList() + + self.convs.append( + nn.Conv2d( + self.in_channels * 2, + self.filters, + kernel_size=self.kernel_size, + padding=self.get_2d_padding(self.kernel_size), + ) + ) + + in_chs = min(self.filters_scale * self.filters, self.max_filters) + for i, dilation in enumerate(self.dilations): + out_chs = min((self.filters_scale ** (i + 1)) * self.filters, self.max_filters) + self.convs.append( + weight_norm( + nn.Conv2d( + in_chs, + out_chs, + kernel_size=self.kernel_size, + stride=self.stride, + dilation=(dilation, 1), + padding=self.get_2d_padding(self.kernel_size, (dilation, 1)), + ) + ) + ) + in_chs = out_chs + out_chs = min( + (self.filters_scale ** (len(self.dilations) + 1)) * self.filters, + self.max_filters, + ) + self.convs.append( + weight_norm( + nn.Conv2d( + in_chs, + out_chs, + kernel_size=(self.kernel_size[0], self.kernel_size[0]), + padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])), + ) + ) + ) + + self.conv_post = weight_norm( + nn.Conv2d( + out_chs, + self.out_channels, + kernel_size=(self.kernel_size[0], self.kernel_size[0]), + padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])), + ) + ) + + self.activation = torch.nn.LeakyReLU(negative_slope=0.1) + self.resample = Resample(orig_freq=self.fs, new_freq=self.fs * 2) + + self.cqtd_normalize_volume = self.cfg.get("cqtd_normalize_volume", False) + if self.cqtd_normalize_volume: + print( + "[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!" 
+ ) + + def get_2d_padding( + self, + kernel_size: typing.Tuple[int, int], + dilation: typing.Tuple[int, int] = (1, 1), + ): + return ( + ((kernel_size[0] - 1) * dilation[0]) // 2, + ((kernel_size[1] - 1) * dilation[1]) // 2, + ) + + def forward(self, x: torch.tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: + fmap = [] + + if self.cqtd_normalize_volume: + # Remove DC offset + x = x - x.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + + x = self.resample(x) + + z = self.cqt_transform(x) + + z_amplitude = z[:, :, :, 0].unsqueeze(1) + z_phase = z[:, :, :, 1].unsqueeze(1) + + z = torch.cat([z_amplitude, z_phase], dim=1) + z = torch.permute(z, (0, 1, 3, 2)) # [B, C, W, T] -> [B, C, T, W] + + latent_z = [] + for i in range(self.n_octaves): + latent_z.append( + self.conv_pres[i]( + z[ + :, + :, + :, + i * self.bins_per_octave : (i + 1) * self.bins_per_octave, + ] + ) + ) + latent_z = torch.cat(latent_z, dim=-1) + + for i, l in enumerate(self.convs): + latent_z = l(latent_z) + + latent_z = self.activation(latent_z) + fmap.append(latent_z) + + latent_z = self.conv_post(latent_z) + + return latent_z, fmap + + +class MultiScaleSubbandCQTDiscriminator(nn.Module): + def __init__(self, cfg: AttrDict): + super().__init__() + + self.cfg = cfg + # Using get with defaults + self.cfg["cqtd_filters"] = self.cfg.get("cqtd_filters", 32) + self.cfg["cqtd_max_filters"] = self.cfg.get("cqtd_max_filters", 1024) + self.cfg["cqtd_filters_scale"] = self.cfg.get("cqtd_filters_scale", 1) + self.cfg["cqtd_dilations"] = self.cfg.get("cqtd_dilations", [1, 2, 4]) + self.cfg["cqtd_in_channels"] = self.cfg.get("cqtd_in_channels", 1) + self.cfg["cqtd_out_channels"] = self.cfg.get("cqtd_out_channels", 1) + # Multi-scale params to loop over + self.cfg["cqtd_hop_lengths"] = self.cfg.get("cqtd_hop_lengths", [512, 256, 256]) + self.cfg["cqtd_n_octaves"] = self.cfg.get("cqtd_n_octaves", [9, 9, 9]) + self.cfg["cqtd_bins_per_octaves"] = self.cfg.get("cqtd_bins_per_octaves", [24, 36, 48]) + + self.discriminators = nn.ModuleList( + [ + DiscriminatorCQT( + self.cfg, + hop_length=self.cfg["cqtd_hop_lengths"][i], + n_octaves=self.cfg["cqtd_n_octaves"][i], + bins_per_octave=self.cfg["cqtd_bins_per_octaves"][i], + ) + for i in range(len(self.cfg["cqtd_hop_lengths"])) + ] + ) + + def forward( + self, y: torch.Tensor, y_hat: torch.Tensor + ) -> Tuple[ + List[torch.Tensor], + List[torch.Tensor], + List[List[torch.Tensor]], + List[List[torch.Tensor]], + ]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for disc in self.discriminators: + y_d_r, fmap_r = disc(y) + y_d_g, fmap_g = disc(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class CombinedDiscriminator(nn.Module): + """ + Wrapper of chaining multiple discrimiantor architectures. 
+ Example: combine mbd and cqtd as a single class + """ + + def __init__(self, list_discriminator: List[nn.Module]): + super().__init__() + self.discrimiantor = nn.ModuleList(list_discriminator) + + def forward( + self, y: torch.Tensor, y_hat: torch.Tensor + ) -> Tuple[ + List[torch.Tensor], + List[torch.Tensor], + List[List[torch.Tensor]], + List[List[torch.Tensor]], + ]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for disc in self.discrimiantor: + y_d_r, y_d_g, fmap_r, fmap_g = disc(y, y_hat) + y_d_rs.extend(y_d_r) + fmap_rs.extend(fmap_r) + y_d_gs.extend(y_d_g) + fmap_gs.extend(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs diff --git a/GPT_SoVITS/BigVGAN/env.py b/GPT_SoVITS/BigVGAN/env.py new file mode 100644 index 0000000000000000000000000000000000000000..cf8ac6cea644c78d115dd3902b902993f366ee61 --- /dev/null +++ b/GPT_SoVITS/BigVGAN/env.py @@ -0,0 +1,18 @@ +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. + +import os +import shutil + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) diff --git a/GPT_SoVITS/BigVGAN/inference.py b/GPT_SoVITS/BigVGAN/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..5f892a3c807a7020eff7fea35179b0f6e5f991c9 --- /dev/null +++ b/GPT_SoVITS/BigVGAN/inference.py @@ -0,0 +1,85 @@ +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import os +import argparse +import json +import torch +import librosa +from utils import load_checkpoint +from meldataset import get_mel_spectrogram +from scipy.io.wavfile import write +from env import AttrDict +from meldataset import MAX_WAV_VALUE +from bigvgan import BigVGAN as Generator + +h = None +device = None +torch.backends.cudnn.benchmark = False + + +def inference(a, h): + generator = Generator(h, use_cuda_kernel=a.use_cuda_kernel).to(device) + + state_dict_g = load_checkpoint(a.checkpoint_file, device) + generator.load_state_dict(state_dict_g["generator"]) + + filelist = os.listdir(a.input_wavs_dir) + + os.makedirs(a.output_dir, exist_ok=True) + + generator.eval() + generator.remove_weight_norm() + with torch.no_grad(): + for i, filname in enumerate(filelist): + # Load the ground truth audio and resample if necessary + wav, sr = librosa.load(os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True) + wav = torch.FloatTensor(wav).to(device) + # Compute mel spectrogram from the ground truth audio + x = get_mel_spectrogram(wav.unsqueeze(0), generator.h) + + y_g_hat = generator(x) + + audio = y_g_hat.squeeze() + audio = audio * MAX_WAV_VALUE + audio = audio.cpu().numpy().astype("int16") + + output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated.wav") + write(output_file, h.sampling_rate, audio) + print(output_file) + + +def main(): + print("Initializing Inference Process..") + + parser = argparse.ArgumentParser() + parser.add_argument("--input_wavs_dir", default="test_files") + parser.add_argument("--output_dir", default="generated_files") + parser.add_argument("--checkpoint_file", required=True) + 
parser.add_argument("--use_cuda_kernel", action="store_true", default=False) + + a = parser.parse_args() + + config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json") + with open(config_file) as f: + data = f.read() + + global h + json_config = json.loads(data) + h = AttrDict(json_config) + + torch.manual_seed(h.seed) + global device + if torch.cuda.is_available(): + torch.cuda.manual_seed(h.seed) + device = torch.device("cuda") + else: + device = torch.device("cpu") + + inference(a, h) + + +if __name__ == "__main__": + main() diff --git a/GPT_SoVITS/BigVGAN/inference_e2e.py b/GPT_SoVITS/BigVGAN/inference_e2e.py new file mode 100644 index 0000000000000000000000000000000000000000..9c0df77435e91935beaca365dd5fd38d76098a4a --- /dev/null +++ b/GPT_SoVITS/BigVGAN/inference_e2e.py @@ -0,0 +1,100 @@ +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import glob +import os +import numpy as np +import argparse +import json +import torch +from scipy.io.wavfile import write +from env import AttrDict +from meldataset import MAX_WAV_VALUE +from bigvgan import BigVGAN as Generator + +h = None +device = None +torch.backends.cudnn.benchmark = False + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print(f"Loading '{filepath}'") + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + "*") + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return "" + return sorted(cp_list)[-1] + + +def inference(a, h): + generator = Generator(h, use_cuda_kernel=a.use_cuda_kernel).to(device) + + state_dict_g = load_checkpoint(a.checkpoint_file, device) + generator.load_state_dict(state_dict_g["generator"]) + + filelist = os.listdir(a.input_mels_dir) + + os.makedirs(a.output_dir, exist_ok=True) + + generator.eval() + generator.remove_weight_norm() + with torch.no_grad(): + for i, filname in enumerate(filelist): + # Load the mel spectrogram in .npy format + x = np.load(os.path.join(a.input_mels_dir, filname)) + x = torch.FloatTensor(x).to(device) + if len(x.shape) == 2: + x = x.unsqueeze(0) + + y_g_hat = generator(x) + + audio = y_g_hat.squeeze() + audio = audio * MAX_WAV_VALUE + audio = audio.cpu().numpy().astype("int16") + + output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav") + write(output_file, h.sampling_rate, audio) + print(output_file) + + +def main(): + print("Initializing Inference Process..") + + parser = argparse.ArgumentParser() + parser.add_argument("--input_mels_dir", default="test_mel_files") + parser.add_argument("--output_dir", default="generated_files_from_mel") + parser.add_argument("--checkpoint_file", required=True) + parser.add_argument("--use_cuda_kernel", action="store_true", default=False) + + a = parser.parse_args() + + config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json") + with open(config_file) as f: + data = f.read() + + global h + json_config = json.loads(data) + h = AttrDict(json_config) + + torch.manual_seed(h.seed) + global device + if torch.cuda.is_available(): + torch.cuda.manual_seed(h.seed) + device = torch.device("cuda") + else: + device = torch.device("cpu") + + inference(a, h) + + +if __name__ == "__main__": + main() diff --git a/GPT_SoVITS/BigVGAN/loss.py 
b/GPT_SoVITS/BigVGAN/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..c295a144ff7bcfc0d91d9d4676bedfa7015cdb79 --- /dev/null +++ b/GPT_SoVITS/BigVGAN/loss.py @@ -0,0 +1,238 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. + + +import torch +import torch.nn as nn +from librosa.filters import mel as librosa_mel_fn +from scipy import signal + +import typing +from typing import List, Tuple +from collections import namedtuple +import math +import functools + + +# Adapted from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py under the MIT license. +# LICENSE is in incl_licenses directory. +class MultiScaleMelSpectrogramLoss(nn.Module): + """Compute distance between mel spectrograms. Can be used + in a multi-scale way. + + Parameters + ---------- + n_mels : List[int] + Number of mels per STFT, by default [5, 10, 20, 40, 80, 160, 320], + window_lengths : List[int], optional + Length of each window of each STFT, by default [32, 64, 128, 256, 512, 1024, 2048] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 0.0 (no ampliciation on mag part) + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 1.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + Additional code copied and modified from https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py + """ + + def __init__( + self, + sampling_rate: int, + n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320], + window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 0.0, + log_weight: float = 1.0, + pow: float = 1.0, + weight: float = 1.0, + match_stride: bool = False, + mel_fmin: List[float] = [0, 0, 0, 0, 0, 0, 0], + mel_fmax: List[float] = [None, None, None, None, None, None, None], + window_type: str = "hann", + ): + super().__init__() + self.sampling_rate = sampling_rate + + STFTParams = namedtuple( + "STFTParams", + ["window_length", "hop_length", "window_type", "match_stride"], + ) + + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.n_mels = n_mels + self.loss_fn = loss_fn + self.clamp_eps = clamp_eps + self.log_weight = log_weight + self.mag_weight = mag_weight + self.weight = weight + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.pow = pow + + @staticmethod + @functools.lru_cache(None) + def get_window( + window_type, + window_length, + ): + return signal.get_window(window_type, window_length) + + @staticmethod + @functools.lru_cache(None) + def get_mel_filters(sr, n_fft, n_mels, fmin, fmax): + return librosa_mel_fn(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) + + def 
mel_spectrogram( + self, + wav, + n_mels, + fmin, + fmax, + window_length, + hop_length, + match_stride, + window_type, + ): + """ + Mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from: + https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py + """ + B, C, T = wav.shape + + if match_stride: + assert hop_length == window_length // 4, "For match_stride, hop must equal n_fft // 4" + right_pad = math.ceil(T / hop_length) * hop_length - T + pad = (window_length - hop_length) // 2 + else: + right_pad = 0 + pad = 0 + + wav = torch.nn.functional.pad(wav, (pad, pad + right_pad), mode="reflect") + + window = self.get_window(window_type, window_length) + window = torch.from_numpy(window).to(wav.device).float() + + stft = torch.stft( + wav.reshape(-1, T), + n_fft=window_length, + hop_length=hop_length, + window=window, + return_complex=True, + center=True, + ) + _, nf, nt = stft.shape + stft = stft.reshape(B, C, nf, nt) + if match_stride: + """ + Drop first two and last two frames, which are added, because of padding. Now num_frames * hop_length = num_samples. + """ + stft = stft[..., 2:-2] + magnitude = torch.abs(stft) + + nf = magnitude.shape[2] + mel_basis = self.get_mel_filters(self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax) + mel_basis = torch.from_numpy(mel_basis).to(wav.device) + mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T + mel_spectrogram = mel_spectrogram.transpose(-1, 2) + + return mel_spectrogram + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Computes mel loss between an estimate and a reference + signal. + + Parameters + ---------- + x : torch.Tensor + Estimate signal + y : torch.Tensor + Reference signal + + Returns + ------- + torch.Tensor + Mel loss. 
+ """ + + loss = 0.0 + for n_mels, fmin, fmax, s in zip(self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params): + kwargs = { + "n_mels": n_mels, + "fmin": fmin, + "fmax": fmax, + "window_length": s.window_length, + "hop_length": s.hop_length, + "match_stride": s.match_stride, + "window_type": s.window_type, + } + + x_mels = self.mel_spectrogram(x, **kwargs) + y_mels = self.mel_spectrogram(y, **kwargs) + x_logmels = torch.log(x_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0)) + y_logmels = torch.log(y_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0)) + + loss += self.log_weight * self.loss_fn(x_logmels, y_logmels) + loss += self.mag_weight * self.loss_fn(x_logmels, y_logmels) + + return loss + + +# Loss functions +def feature_loss(fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor: + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 # This equates to lambda=2.0 for the feature matching loss + + +def discriminator_loss( + disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor] +) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg**2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss( + disc_outputs: List[torch.Tensor], +) -> Tuple[torch.Tensor, List[torch.Tensor]]: + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses diff --git a/GPT_SoVITS/BigVGAN/meldataset.py b/GPT_SoVITS/BigVGAN/meldataset.py new file mode 100644 index 0000000000000000000000000000000000000000..dc12c9874cfb9958d6f4842cc067ffda66a390eb --- /dev/null +++ b/GPT_SoVITS/BigVGAN/meldataset.py @@ -0,0 +1,370 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. 
+ +import math +import os +import random +import torch +import torch.utils.data +import numpy as np +import librosa +from librosa.filters import mel as librosa_mel_fn +import pathlib +from tqdm import tqdm +from typing import List, Tuple, Optional +from .env import AttrDict + +MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases) + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + return dynamic_range_compression_torch(magnitudes) + + +def spectral_de_normalize_torch(magnitudes): + return dynamic_range_decompression_torch(magnitudes) + + +mel_basis_cache = {} +hann_window_cache = {} + + +def mel_spectrogram( + y: torch.Tensor, + n_fft: int, + num_mels: int, + sampling_rate: int, + hop_size: int, + win_size: int, + fmin: int, + fmax: int = None, + center: bool = False, +) -> torch.Tensor: + """ + Calculate the mel spectrogram of an input signal. + This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft). + + Args: + y (torch.Tensor): Input signal. + n_fft (int): FFT size. + num_mels (int): Number of mel bins. + sampling_rate (int): Sampling rate of the input signal. + hop_size (int): Hop size for STFT. + win_size (int): Window size for STFT. + fmin (int): Minimum frequency for mel filterbank. + fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn + center (bool): Whether to pad the input to center the frames. Default is False. + + Returns: + torch.Tensor: Mel spectrogram. + """ + if torch.min(y) < -1.0: + print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}") + if torch.max(y) > 1.0: + print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}") + + device = y.device + key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}" + + if key not in mel_basis_cache: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) + hann_window_cache[key] = torch.hann_window(win_size).to(device) + + mel_basis = mel_basis_cache[key] + hann_window = hann_window_cache[key] + + padding = (n_fft - hop_size) // 2 + y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window, + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9) + + mel_spec = torch.matmul(mel_basis, spec) + mel_spec = spectral_normalize_torch(mel_spec) + + return mel_spec + + +def get_mel_spectrogram(wav, h): + """ + Generate mel spectrogram from a waveform using given hyperparameters. + + Args: + wav (torch.Tensor): Input waveform. + h: Hyperparameters object with attributes n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax. + + Returns: + torch.Tensor: Mel spectrogram. 
+ """ + return mel_spectrogram( + wav, + h.n_fft, + h.num_mels, + h.sampling_rate, + h.hop_size, + h.win_size, + h.fmin, + h.fmax, + ) + + +def get_dataset_filelist(a): + training_files = [] + validation_files = [] + list_unseen_validation_files = [] + + with open(a.input_training_file, "r", encoding="utf-8") as fi: + training_files = [ + os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 + ] + print(f"first training file: {training_files[0]}") + + with open(a.input_validation_file, "r", encoding="utf-8") as fi: + validation_files = [ + os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 + ] + print(f"first validation file: {validation_files[0]}") + + for i in range(len(a.list_input_unseen_validation_file)): + with open(a.list_input_unseen_validation_file[i], "r", encoding="utf-8") as fi: + unseen_validation_files = [ + os.path.join(a.list_input_unseen_wavs_dir[i], x.split("|")[0] + ".wav") + for x in fi.read().split("\n") + if len(x) > 0 + ] + print(f"first unseen {i}th validation fileset: {unseen_validation_files[0]}") + list_unseen_validation_files.append(unseen_validation_files) + + return training_files, validation_files, list_unseen_validation_files + + +class MelDataset(torch.utils.data.Dataset): + def __init__( + self, + training_files: List[str], + hparams: AttrDict, + segment_size: int, + n_fft: int, + num_mels: int, + hop_size: int, + win_size: int, + sampling_rate: int, + fmin: int, + fmax: Optional[int], + split: bool = True, + shuffle: bool = True, + device: str = None, + fmax_loss: Optional[int] = None, + fine_tuning: bool = False, + base_mels_path: str = None, + is_seen: bool = True, + ): + self.audio_files = training_files + random.seed(1234) + if shuffle: + random.shuffle(self.audio_files) + self.hparams = hparams + self.is_seen = is_seen + if self.is_seen: + self.name = pathlib.Path(self.audio_files[0]).parts[0] + else: + self.name = "-".join(pathlib.Path(self.audio_files[0]).parts[:2]).strip("/") + + self.segment_size = segment_size + self.sampling_rate = sampling_rate + self.split = split + self.n_fft = n_fft + self.num_mels = num_mels + self.hop_size = hop_size + self.win_size = win_size + self.fmin = fmin + self.fmax = fmax + self.fmax_loss = fmax_loss + self.device = device + self.fine_tuning = fine_tuning + self.base_mels_path = base_mels_path + + print("[INFO] checking dataset integrity...") + for i in tqdm(range(len(self.audio_files))): + assert os.path.exists(self.audio_files[i]), f"{self.audio_files[i]} not found" + + def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]: + try: + filename = self.audio_files[index] + + # Use librosa.load that ensures loading waveform into mono with [-1, 1] float values + # Audio is ndarray with shape [T_time]. 
Disable auto-resampling here to minimize overhead + # The on-the-fly resampling during training will be done only for the obtained random chunk + audio, source_sampling_rate = librosa.load(filename, sr=None, mono=True) + + # Main logic that uses pair for training BigVGAN + if not self.fine_tuning: + if self.split: # Training step + # Obtain randomized audio chunk + if source_sampling_rate != self.sampling_rate: + # Adjust segment size to crop if the source sr is different + target_segment_size = math.ceil(self.segment_size * (source_sampling_rate / self.sampling_rate)) + else: + target_segment_size = self.segment_size + + # Compute upper bound index for the random chunk + random_chunk_upper_bound = max(0, audio.shape[0] - target_segment_size) + + # Crop or pad audio to obtain random chunk with target_segment_size + if audio.shape[0] >= target_segment_size: + audio_start = random.randint(0, random_chunk_upper_bound) + audio = audio[audio_start : audio_start + target_segment_size] + else: + audio = np.pad( + audio, + (0, target_segment_size - audio.shape[0]), + mode="constant", + ) + + # Resample audio chunk to self.sampling rate + if source_sampling_rate != self.sampling_rate: + audio = librosa.resample( + audio, + orig_sr=source_sampling_rate, + target_sr=self.sampling_rate, + ) + if audio.shape[0] > self.segment_size: + # trim last elements to match self.segment_size (e.g., 16385 for 44khz downsampled to 24khz -> 16384) + audio = audio[: self.segment_size] + + else: # Validation step + # Resample full audio clip to target sampling rate + if source_sampling_rate != self.sampling_rate: + audio = librosa.resample( + audio, + orig_sr=source_sampling_rate, + target_sr=self.sampling_rate, + ) + # Trim last elements to match audio length to self.hop_size * n for evaluation + if (audio.shape[0] % self.hop_size) != 0: + audio = audio[: -(audio.shape[0] % self.hop_size)] + + # BigVGAN is trained using volume-normalized waveform + audio = librosa.util.normalize(audio) * 0.95 + + # Cast ndarray to torch tensor + audio = torch.FloatTensor(audio) + audio = audio.unsqueeze(0) # [B(1), self.segment_size] + + # Compute mel spectrogram corresponding to audio + mel = mel_spectrogram( + audio, + self.n_fft, + self.num_mels, + self.sampling_rate, + self.hop_size, + self.win_size, + self.fmin, + self.fmax, + center=False, + ) # [B(1), self.num_mels, self.segment_size // self.hop_size] + + # Fine-tuning logic that uses pre-computed mel. 
Example: Using TTS model-generated mel as input + else: + # For fine-tuning, assert that the waveform is in the defined sampling_rate + # Fine-tuning won't support on-the-fly resampling to be fool-proof (the dataset should have been prepared properly) + assert source_sampling_rate == self.sampling_rate, ( + f"For fine_tuning, waveform must be in the spcified sampling rate {self.sampling_rate}, got {source_sampling_rate}" + ) + + # Cast ndarray to torch tensor + audio = torch.FloatTensor(audio) + audio = audio.unsqueeze(0) # [B(1), T_time] + + # Load pre-computed mel from disk + mel = np.load( + os.path.join( + self.base_mels_path, + os.path.splitext(os.path.split(filename)[-1])[0] + ".npy", + ) + ) + mel = torch.from_numpy(mel) + + if len(mel.shape) < 3: + mel = mel.unsqueeze(0) # ensure [B, C, T] + + if self.split: + frames_per_seg = math.ceil(self.segment_size / self.hop_size) + + if audio.size(1) >= self.segment_size: + mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1) + mel = mel[:, :, mel_start : mel_start + frames_per_seg] + audio = audio[ + :, + mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size, + ] + + # Pad pre-computed mel and audio to match length to ensuring fine-tuning without error. + # NOTE: this may introduce a single-frame misalignment of the + # To remove possible misalignment, it is recommended to prepare the pair where the audio length is the integer multiple of self.hop_size + mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant") + audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") + + # Compute mel_loss used by spectral regression objective. Uses self.fmax_loss instead (usually None) + mel_loss = mel_spectrogram( + audio, + self.n_fft, + self.num_mels, + self.sampling_rate, + self.hop_size, + self.win_size, + self.fmin, + self.fmax_loss, + center=False, + ) # [B(1), self.num_mels, self.segment_size // self.hop_size] + + # Shape sanity checks + assert ( + audio.shape[1] == mel.shape[2] * self.hop_size and audio.shape[1] == mel_loss.shape[2] * self.hop_size + ), ( + f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}" + ) + + return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) + + # If it encounters error during loading the data, skip this sample and load random other sample to the batch + except Exception as e: + if self.fine_tuning: + raise e # Terminate training if it is fine-tuning. The dataset should have been prepared properly. + else: + print(f"[WARNING] Failed to load waveform, skipping! 
filename: {filename} Error: {e}") + return self[random.randrange(len(self))] + + def __len__(self): + return len(self.audio_files) diff --git a/GPT_SoVITS/BigVGAN/nv-modelcard++/.gitkeep b/GPT_SoVITS/BigVGAN/nv-modelcard++/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/GPT_SoVITS/BigVGAN/nv-modelcard++/.gitkeep @@ -0,0 +1 @@ + diff --git a/GPT_SoVITS/BigVGAN/nv-modelcard++/bias.md b/GPT_SoVITS/BigVGAN/nv-modelcard++/bias.md new file mode 100644 index 0000000000000000000000000000000000000000..4b388c28d09b8ca3aab5096304c52e1a5dac0e16 --- /dev/null +++ b/GPT_SoVITS/BigVGAN/nv-modelcard++/bias.md @@ -0,0 +1,4 @@ +| Field | Response | +| :--------------------------------------------------------------------------------------------------------- | :--------------------------------------------------- | +| Participation considerations from adversely impacted groups protected classes in model design and testing: | None | +| Measures taken to mitigate against unwanted bias: | No measures taken to mitigate against unwanted bias. | diff --git a/GPT_SoVITS/BigVGAN/nv-modelcard++/explainability.md b/GPT_SoVITS/BigVGAN/nv-modelcard++/explainability.md new file mode 100644 index 0000000000000000000000000000000000000000..6f1a16676e438ba95f9d411a19e04a0f13409e54 --- /dev/null +++ b/GPT_SoVITS/BigVGAN/nv-modelcard++/explainability.md @@ -0,0 +1,13 @@ +| Field | Response | +| :---------------------------------------------------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| Intended Application & Domain: | Generating waveform from mel spectrogram. | +| Model Type: | Convolutional Neural Network (CNN) | +| Intended Users: | This model is intended for developers to synthesize and generate waveforms from the AI-generated mel spectrograms. | +| Output: | Audio Waveform | +| Describe how the model works: | Model generates audio waveform corresponding to the input mel spectrogram. | +| Name the adversely impacted groups this has been tested to deliver comparable outcomes regardless of: | Not Applicable | +| Technical Limitations: | This may not perform well on synthetically-generated mel spectrograms that deviate significantly from the profile of mel spectrograms on which this was trained. | +| Verified to have met prescribed NVIDIA quality standards: | Yes | +| Performance Metrics: | Perceptual Evaluation of Speech Quality (PESQ), Virtual Speech Quality Objective Listener (VISQOL), Multi-resolution STFT (MRSTFT), Mel cepstral distortion (MCD), Periodicity RMSE, Voice/Unvoiced F1 Score (V/UV F1) | +| Potential Known Risks: | This model may generate low-quality or distorted soundwaves. | +| Licensing: | https://github.com/NVIDIA/BigVGAN/blob/main/LICENSE | diff --git a/GPT_SoVITS/BigVGAN/nv-modelcard++/overview.md b/GPT_SoVITS/BigVGAN/nv-modelcard++/overview.md new file mode 100644 index 0000000000000000000000000000000000000000..a39cba0b49a4a32a37afa90f2baf4630dcd9cadc --- /dev/null +++ b/GPT_SoVITS/BigVGAN/nv-modelcard++/overview.md @@ -0,0 +1,126 @@ +# Model Overview + +## Description: + +BigVGAN is a generative AI model specialized in synthesizing audio waveforms using Mel spectrogram as inputs. + +
+ +BigVGAN is a fully convolutional architecture with several upsampling blocks using transposed convolution followed by multiple residual dilated convolution layers. + +BigVGAN consists of a novel module, called anti-aliased multi-periodicity composition (AMP), which is specifically designed for generating waveforms. AMP is specialized in synthesizing high-frequency and periodic soundwaves drawing inspiration from audio signal processing principles. + +It applies a periodic activation function, called Snake, which provides an inductive bias to the architecture in generating periodic soundwaves. It also applies anti-aliasing filters to reduce undesired artifacts in the generated waveforms.
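The Snake activation mentioned above has a simple closed form, x + (1/α)·sin²(αx). A minimal sketch (in BigVGAN α is a learnable per-channel parameter, and the `snakebeta` variant named in the configs uses a separate learnable β for the 1/β magnitude term):

```python
import torch

def snake(x: torch.Tensor, alpha: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    # Snake(x) = x + (1/alpha) * sin^2(alpha * x)
    return x + (1.0 / (alpha + eps)) * torch.sin(alpha * x) ** 2

x = torch.linspace(-3.0, 3.0, steps=7)
print(snake(x, alpha=torch.tensor(1.0)))
```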
+ +This model is ready for commercial use.
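A minimal sketch of the mel-to-waveform flow described above, condensed from the `inference.py` script added earlier in this diff; the checkpoint name `g_05000000`, `config.json`, and `reference.wav` are placeholders, and imports assume the script runs from the `GPT_SoVITS/BigVGAN` directory as `inference.py` does:

```python
import json
import torch
import librosa
from env import AttrDict
from meldataset import get_mel_spectrogram, MAX_WAV_VALUE
from utils import load_checkpoint
from bigvgan import BigVGAN as Generator

with open("config.json") as f:                     # config shipped next to the checkpoint
    h = AttrDict(json.load(f))

generator = Generator(h, use_cuda_kernel=False)
generator.load_state_dict(load_checkpoint("g_05000000", "cpu")["generator"])
generator.eval()
generator.remove_weight_norm()

wav, _ = librosa.load("reference.wav", sr=h.sampling_rate, mono=True)  # float32 in [-1, 1]
mel = get_mel_spectrogram(torch.FloatTensor(wav).unsqueeze(0), h)      # [1, num_mels, frames]
with torch.no_grad():
    audio = generator(mel).squeeze()                                   # frames * hop_size samples
pcm16 = (audio * MAX_WAV_VALUE).cpu().numpy().astype("int16")
```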
+ +## References: + +- [BigVGAN: A Universal Neural Vocoder with Large-Scale Training](https://arxiv.org/abs/2206.04658)
+- [Project Page](https://research.nvidia.com/labs/adlr/projects/bigvgan/)
+- [Audio Demo](https://bigvgan-demo.github.io/)
+ +## Model Architecture: + +**Architecture Type:** Convolutional Neural Network (CNN)
+**Network Architecture:** See https://github.com/NVIDIA/BigVGAN for the model details; the related paper is available at https://arxiv.org/abs/2206.04658
+**Model Version:** 2.0
+ +## Input: + +**Input Type:** Audio
+**Input Format:** Mel Spectrogram
+**Input Parameters:** None
+**Other Properties Related to Input:** The input mel spectrogram has shape `[batch, channels, frames]`, where `channels` refers to the number of mel bands defined by the model and `frames` refers to the temporal length. The model supports arbitrarily long `frames` that fit into GPU memory. + +## Output: + +**Output Type:** Audio
+**Output Format:** Audio Waveform
+**Output Parameters:** None
+**Other Properties Related to Output:** The output audio waveform has shape `[batch, 1, time]`, where `1` refers to the mono audio channel and `time` refers to the temporal length. `time` is a fixed integer multiple of the input `frames`, given by the upsampling ratio of the model (`time = upsampling ratio * frames`). The output audio waveform consists of float values in the range `[-1, 1]`. + +## Software Integration: + +**Runtime Engine(s):** PyTorch + +**Supported Hardware Microarchitecture Compatibility:** NVIDIA Ampere, NVIDIA Hopper, NVIDIA Lovelace, NVIDIA Turing, NVIDIA Volta
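A short sketch of the shape contract stated above, using one of the `256x` configs from this diff with a randomly initialized generator (run from the `GPT_SoVITS/BigVGAN` directory; shapes only, not a quality test):

```python
import json
import torch
from env import AttrDict
from bigvgan import BigVGAN

with open("configs/bigvgan_v2_24khz_100band_256x.json") as f:
    h = AttrDict(json.load(f))

generator = BigVGAN(h).eval()            # random weights; used only to check shapes
mel = torch.randn(1, h.num_mels, 50)     # [batch, channels, frames]
with torch.no_grad():
    audio = generator(mel)               # [batch, 1, time]
print(audio.shape)                       # expected: torch.Size([1, 1, 12800]) = 50 * 256
```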
+ +## Preferred/Supported Operating System(s): + +Linux + +## Model Version(s): + +v2.0 + +## Training, Testing, and Evaluation Datasets: + +### Training Dataset: + +The dataset contains diverse audio types, including speech in multiple languages, environmental sounds, and instruments. + +**Links:** + +- [AAM: Artificial Audio Multitracks Dataset](https://zenodo.org/records/5794629) +- [AudioCaps](https://audiocaps.github.io/) +- [AudioSet](https://research.google.com/audioset/index.html) +- [common-accent](https://huggingface.co/datasets/DTU54DL/common-accent) +- [Crowd Sourced Emotional Multimodal Actors Dataset (CREMA-D)](https://ieeexplore.ieee.org/document/6849440) +- [DCASE2017 Challenge, Task 4: Large-scale weakly supervised sound event detection for smart cars](https://dcase.community/challenge2017/task-large-scale-sound-event-detection) +- [FSDnoisy18k](https://zenodo.org/records/2529934) +- [Free Universal Sound Separation Dataset](https://zenodo.org/records/3694384) +- [Greatest Hits dataset](https://andrewowens.com/vis/) +- [GTZAN](https://ieeexplore.ieee.org/document/1021072) +- [JL corpus](https://www.kaggle.com/datasets/tli725/jl-corpus) +- [Medley-solos-DB: a cross-collection dataset for musical instrument recognition](https://zenodo.org/records/3464194) +- [MUSAN: A Music, Speech, and Noise Corpus](https://www.openslr.org/17/) +- [MusicBench](https://huggingface.co/datasets/amaai-lab/MusicBench) +- [MusicCaps](https://www.kaggle.com/datasets/googleai/musiccaps) +- [MusicNet](https://www.kaggle.com/datasets/imsparsh/musicnet-dataset) +- [NSynth](https://magenta.tensorflow.org/datasets/nsynth) +- [OnAir-Music-Dataset](https://github.com/sevagh/OnAir-Music-Dataset) +- [Audio Piano Triads Dataset](https://zenodo.org/records/4740877) +- [Pitch Audio Dataset (Surge synthesizer)](https://zenodo.org/records/4677097) +- [SONYC Urban Sound Tagging (SONYC-UST): a multilabel dataset from an urban acoustic sensor network](https://zenodo.org/records/3966543) +- [VocalSound: A Dataset for Improving Human Vocal Sounds Recognition](https://arxiv.org/abs/2205.03433) +- [WavText5K](https://github.com/microsoft/WavText5K) +- [CSS10: A Collection of Single Speaker Speech Datasets for 10 Languages](https://github.com/Kyubyong/css10) +- [Hi-Fi Multi-Speaker English TTS Dataset (Hi-Fi TTS)](https://www.openslr.org/109/) +- [IIIT-H Indic Speech Databases](http://festvox.org/databases/iiit_voices/) +- [Libri-Light: A Benchmark for ASR with Limited or No Supervision](https://arxiv.org/abs/1912.07875) +- [LibriTTS: A Corpus Derived from LibriSpeech for Text-to-Speech](https://www.openslr.org/60) +- [LibriTTS-R: A Restored Multi-Speaker Text-to-Speech Corpus](https://www.openslr.org/141/) +- [The SIWIS French Speech Synthesis Database](https://datashare.ed.ac.uk/handle/10283/2353) +- [Crowdsourced high-quality Colombian Spanish speech data set](https://openslr.org/72/) +- [TTS-Portuguese Corpus](https://github.com/Edresson/TTS-Portuguese-Corpus) +- [CSTR VCTK Corpus: English Multi-speaker Corpus for CSTR Voice Cloning Toolkit](https://datashare.ed.ac.uk/handle/10283/3443) + +\*\* Data Collection Method by dataset
+ +- Human
+ +\*\* Labeling Method by dataset (for those with labels)
+ +- Hybrid: Automated, Human, Unknown
+ +### Evaluating Dataset: + +Properties: The audio generation quality of BigVGAN is evaluated using `dev` splits of the [LibriTTS dataset](https://www.openslr.org/60/) and [Hi-Fi TTS dataset](https://www.openslr.org/109/). The datasets include speech in English language with equal balance of genders. + +\*\* Data Collection Method by dataset
+ +- Human
+ +\*\* Labeling Method by dataset
+ +- Automated
+ +## Inference: + +**Engine:** PyTorch
+**Test Hardware:** NVIDIA A100 GPU
+ +## Ethical Considerations: + +NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse. For more detailed information on ethical considerations for this model, please see the Model Card++ Explainability, Bias, Safety & Security, and Privacy Subcards. Please report security vulnerabilities or NVIDIA AI Concerns [here](https://www.nvidia.com/en-us/support/submit-security-vulnerability/). diff --git a/GPT_SoVITS/BigVGAN/nv-modelcard++/privacy.md b/GPT_SoVITS/BigVGAN/nv-modelcard++/privacy.md new file mode 100644 index 0000000000000000000000000000000000000000..73554a998384ca1b1050239ebd51bda46aec1878 --- /dev/null +++ b/GPT_SoVITS/BigVGAN/nv-modelcard++/privacy.md @@ -0,0 +1,14 @@ +| Field | Response | +| :------------------------------------------------------------------------------------------------------------------------------------- | :--------------------------------------------- | +| Generatable or reverse engineerable personal information? | None | +| Protected class data used to create this model? | None | +| Was consent obtained for any personal data used? | Not Applicable (No Personal Data) | +| How often is dataset reviewed? | Before Release | +| Is a mechanism in place to honor data subject right of access or deletion of personal data? | Not Applicable | +| If personal collected for the development of the model, was it collected directly by NVIDIA? | Not Applicable | +| If personal collected for the development of the model by NVIDIA, do you maintain or have access to disclosures made to data subjects? | Not Applicable | +| If personal collected for the development of this AI model, was it minimized to only what was required? | Not Applicable | +| Is data in dataset traceable? | Yes | +| Is there provenance for all datasets used in training? | Yes | +| Does data labeling (annotation, metadata) comply with privacy laws? | Yes | +| Is data compliant with data subject requests for data correction or removal, if such a request was made? | No, not possible with externally-sourced data. | diff --git a/GPT_SoVITS/BigVGAN/nv-modelcard++/safety.md b/GPT_SoVITS/BigVGAN/nv-modelcard++/safety.md new file mode 100644 index 0000000000000000000000000000000000000000..ed30370dfedbbb49748706034a7153d54f1a668f --- /dev/null +++ b/GPT_SoVITS/BigVGAN/nv-modelcard++/safety.md @@ -0,0 +1,6 @@ +| Field | Response | +| :---------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| Model Application(s): | Synethic Audio Generation | +| Describe the life critical impact (if present). | Not Applicable | +| Use Case Restrictions: | None | +| Model and dataset restrictions: | The Principle of least privilege (PoLP) is applied limiting access for dataset generation and model development. Restrictions enforce dataset access during training, and dataset license constraints adhered to. 
| diff --git a/GPT_SoVITS/BigVGAN/requirements.txt b/GPT_SoVITS/BigVGAN/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..6e61d3203966612e6ad193bbabdef10b1d3fed84 --- /dev/null +++ b/GPT_SoVITS/BigVGAN/requirements.txt @@ -0,0 +1,13 @@ +torch +numpy +librosa>=0.8.1 +scipy +tensorboard +soundfile +matplotlib +pesq +auraloss +tqdm +nnAudio +ninja +huggingface_hub>=0.23.4 \ No newline at end of file diff --git a/GPT_SoVITS/BigVGAN/train.py b/GPT_SoVITS/BigVGAN/train.py new file mode 100644 index 0000000000000000000000000000000000000000..39718cdb33d2e9a88ec9b98dd2032bdce83a4231 --- /dev/null +++ b/GPT_SoVITS/BigVGAN/train.py @@ -0,0 +1,716 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. + + +import warnings + +warnings.simplefilter(action="ignore", category=FutureWarning) +import itertools +import os +import time +import argparse +import json +import torch +import torch.nn.functional as F +from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import DistributedSampler, DataLoader +import torch.multiprocessing as mp +from torch.distributed import init_process_group +from torch.nn.parallel import DistributedDataParallel +from env import AttrDict, build_env +from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist, MAX_WAV_VALUE + +from bigvgan import BigVGAN +from discriminators import ( + MultiPeriodDiscriminator, + MultiResolutionDiscriminator, + MultiBandDiscriminator, + MultiScaleSubbandCQTDiscriminator, +) +from loss import ( + feature_loss, + generator_loss, + discriminator_loss, + MultiScaleMelSpectrogramLoss, +) + +from utils import ( + plot_spectrogram, + plot_spectrogram_clipped, + scan_checkpoint, + load_checkpoint, + save_checkpoint, + save_audio, +) +import torchaudio as ta +from pesq import pesq +from tqdm import tqdm +import auraloss + +torch.backends.cudnn.benchmark = False + + +def train(rank, a, h): + if h.num_gpus > 1: + # initialize distributed + init_process_group( + backend=h.dist_config["dist_backend"], + init_method=h.dist_config["dist_url"], + world_size=h.dist_config["world_size"] * h.num_gpus, + rank=rank, + ) + + # Set seed and device + torch.cuda.manual_seed(h.seed) + torch.cuda.set_device(rank) + device = torch.device(f"cuda:{rank:d}") + + # Define BigVGAN generator + generator = BigVGAN(h).to(device) + + # Define discriminators. MPD is used by default + mpd = MultiPeriodDiscriminator(h).to(device) + + # Define additional discriminators. 
BigVGAN-v1 uses UnivNet's MRD as default + # New in BigVGAN-v2: option to switch to new discriminators: MultiBandDiscriminator / MultiScaleSubbandCQTDiscriminator + if h.get("use_mbd_instead_of_mrd", False): # Switch to MBD + print("[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator") + # Variable name is kept as "mrd" for backward compatibility & minimal code change + mrd = MultiBandDiscriminator(h).to(device) + elif h.get("use_cqtd_instead_of_mrd", False): # Switch to CQTD + print("[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator") + mrd = MultiScaleSubbandCQTDiscriminator(h).to(device) + else: # Fallback to original MRD in BigVGAN-v1 + mrd = MultiResolutionDiscriminator(h).to(device) + + # New in BigVGAN-v2: option to switch to multi-scale L1 mel loss + if h.get("use_multiscale_melloss", False): + print("[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss") + fn_mel_loss_multiscale = MultiScaleMelSpectrogramLoss( + sampling_rate=h.sampling_rate + ) # NOTE: accepts waveform as input + else: + fn_mel_loss_singlescale = F.l1_loss + + # Print the model & number of parameters, and create or scan the latest checkpoint from checkpoints directory + if rank == 0: + print(generator) + print(mpd) + print(mrd) + print(f"Generator params: {sum(p.numel() for p in generator.parameters())}") + print(f"Discriminator mpd params: {sum(p.numel() for p in mpd.parameters())}") + print(f"Discriminator mrd params: {sum(p.numel() for p in mrd.parameters())}") + os.makedirs(a.checkpoint_path, exist_ok=True) + print(f"Checkpoints directory: {a.checkpoint_path}") + + if os.path.isdir(a.checkpoint_path): + # New in v2.1: If the step prefix pattern-based checkpoints are not found, also check for renamed files in Hugging Face Hub to resume training + cp_g = scan_checkpoint(a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt") + cp_do = scan_checkpoint( + a.checkpoint_path, + prefix="do_", + renamed_file="bigvgan_discriminator_optimizer.pt", + ) + + # Load the latest checkpoint if exists + steps = 0 + if cp_g is None or cp_do is None: + state_dict_do = None + last_epoch = -1 + else: + state_dict_g = load_checkpoint(cp_g, device) + state_dict_do = load_checkpoint(cp_do, device) + generator.load_state_dict(state_dict_g["generator"]) + mpd.load_state_dict(state_dict_do["mpd"]) + mrd.load_state_dict(state_dict_do["mrd"]) + steps = state_dict_do["steps"] + 1 + last_epoch = state_dict_do["epoch"] + + # Initialize DDP, optimizers, and schedulers + if h.num_gpus > 1: + generator = DistributedDataParallel(generator, device_ids=[rank]).to(device) + mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device) + mrd = DistributedDataParallel(mrd, device_ids=[rank]).to(device) + + optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]) + optim_d = torch.optim.AdamW( + itertools.chain(mrd.parameters(), mpd.parameters()), + h.learning_rate, + betas=[h.adam_b1, h.adam_b2], + ) + + if state_dict_do is not None: + optim_g.load_state_dict(state_dict_do["optim_g"]) + optim_d.load_state_dict(state_dict_do["optim_d"]) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch) + + # Define training and validation datasets + + """ + unseen_validation_filelist will contain sample filepaths outside the 
seen training & validation dataset + Example: trained on LibriTTS, validate on VCTK + """ + training_filelist, validation_filelist, list_unseen_validation_filelist = get_dataset_filelist(a) + + trainset = MelDataset( + training_filelist, + h, + h.segment_size, + h.n_fft, + h.num_mels, + h.hop_size, + h.win_size, + h.sampling_rate, + h.fmin, + h.fmax, + shuffle=False if h.num_gpus > 1 else True, + fmax_loss=h.fmax_for_loss, + device=device, + fine_tuning=a.fine_tuning, + base_mels_path=a.input_mels_dir, + is_seen=True, + ) + + train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None + + train_loader = DataLoader( + trainset, + num_workers=h.num_workers, + shuffle=False, + sampler=train_sampler, + batch_size=h.batch_size, + pin_memory=True, + drop_last=True, + ) + + if rank == 0: + validset = MelDataset( + validation_filelist, + h, + h.segment_size, + h.n_fft, + h.num_mels, + h.hop_size, + h.win_size, + h.sampling_rate, + h.fmin, + h.fmax, + False, + False, + fmax_loss=h.fmax_for_loss, + device=device, + fine_tuning=a.fine_tuning, + base_mels_path=a.input_mels_dir, + is_seen=True, + ) + validation_loader = DataLoader( + validset, + num_workers=1, + shuffle=False, + sampler=None, + batch_size=1, + pin_memory=True, + drop_last=True, + ) + + list_unseen_validset = [] + list_unseen_validation_loader = [] + for i in range(len(list_unseen_validation_filelist)): + unseen_validset = MelDataset( + list_unseen_validation_filelist[i], + h, + h.segment_size, + h.n_fft, + h.num_mels, + h.hop_size, + h.win_size, + h.sampling_rate, + h.fmin, + h.fmax, + False, + False, + fmax_loss=h.fmax_for_loss, + device=device, + fine_tuning=a.fine_tuning, + base_mels_path=a.input_mels_dir, + is_seen=False, + ) + unseen_validation_loader = DataLoader( + unseen_validset, + num_workers=1, + shuffle=False, + sampler=None, + batch_size=1, + pin_memory=True, + drop_last=True, + ) + list_unseen_validset.append(unseen_validset) + list_unseen_validation_loader.append(unseen_validation_loader) + + # Tensorboard logger + sw = SummaryWriter(os.path.join(a.checkpoint_path, "logs")) + if a.save_audio: # Also save audio to disk if --save_audio is set to True + os.makedirs(os.path.join(a.checkpoint_path, "samples"), exist_ok=True) + + """ + Validation loop, "mode" parameter is automatically defined as (seen or unseen)_(name of the dataset). 
+ If the name of the dataset contains "nonspeech", it skips PESQ calculation to prevent errors + """ + + def validate(rank, a, h, loader, mode="seen"): + assert rank == 0, "validate should only run on rank=0" + generator.eval() + torch.cuda.empty_cache() + + val_err_tot = 0 + val_pesq_tot = 0 + val_mrstft_tot = 0 + + # Modules for evaluation metrics + pesq_resampler = ta.transforms.Resample(h.sampling_rate, 16000).cuda() + loss_mrstft = auraloss.freq.MultiResolutionSTFTLoss(device="cuda") + + if a.save_audio: # Also save audio to disk if --save_audio is set to True + os.makedirs( + os.path.join(a.checkpoint_path, "samples", f"gt_{mode}"), + exist_ok=True, + ) + os.makedirs( + os.path.join(a.checkpoint_path, "samples", f"{mode}_{steps:08d}"), + exist_ok=True, + ) + + with torch.no_grad(): + print(f"step {steps} {mode} speaker validation...") + + # Loop over validation set and compute metrics + for j, batch in enumerate(tqdm(loader)): + x, y, _, y_mel = batch + y = y.to(device) + if hasattr(generator, "module"): + y_g_hat = generator.module(x.to(device)) + else: + y_g_hat = generator(x.to(device)) + y_mel = y_mel.to(device, non_blocking=True) + y_g_hat_mel = mel_spectrogram( + y_g_hat.squeeze(1), + h.n_fft, + h.num_mels, + h.sampling_rate, + h.hop_size, + h.win_size, + h.fmin, + h.fmax_for_loss, + ) + min_t = min(y_mel.size(-1), y_g_hat_mel.size(-1)) + val_err_tot += F.l1_loss(y_mel[..., :min_t], y_g_hat_mel[..., :min_t]).item() + + # PESQ calculation. only evaluate PESQ if it's speech signal (nonspeech PESQ will error out) + if "nonspeech" not in mode: # Skips if the name of dataset (in mode string) contains "nonspeech" + # Resample to 16000 for pesq + y_16k = pesq_resampler(y) + y_g_hat_16k = pesq_resampler(y_g_hat.squeeze(1)) + y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() + y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy() + val_pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, "wb") + + # MRSTFT calculation + min_t = min(y.size(-1), y_g_hat.size(-1)) + val_mrstft_tot += loss_mrstft(y_g_hat[..., :min_t], y[..., :min_t]).item() + + # Log audio and figures to Tensorboard + if j % a.eval_subsample == 0: # Subsample every nth from validation set + if steps >= 0: + sw.add_audio(f"gt_{mode}/y_{j}", y[0], steps, h.sampling_rate) + if a.save_audio: # Also save audio to disk if --save_audio is set to True + save_audio( + y[0], + os.path.join( + a.checkpoint_path, + "samples", + f"gt_{mode}", + f"{j:04d}.wav", + ), + h.sampling_rate, + ) + sw.add_figure( + f"gt_{mode}/y_spec_{j}", + plot_spectrogram(x[0]), + steps, + ) + + sw.add_audio( + f"generated_{mode}/y_hat_{j}", + y_g_hat[0], + steps, + h.sampling_rate, + ) + if a.save_audio: # Also save audio to disk if --save_audio is set to True + save_audio( + y_g_hat[0, 0], + os.path.join( + a.checkpoint_path, + "samples", + f"{mode}_{steps:08d}", + f"{j:04d}.wav", + ), + h.sampling_rate, + ) + # Spectrogram of synthesized audio + y_hat_spec = mel_spectrogram( + y_g_hat.squeeze(1), + h.n_fft, + h.num_mels, + h.sampling_rate, + h.hop_size, + h.win_size, + h.fmin, + h.fmax, + ) + sw.add_figure( + f"generated_{mode}/y_hat_spec_{j}", + plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), + steps, + ) + + """ + Visualization of spectrogram difference between GT and synthesized audio, difference higher than 1 is clipped for better visualization. 
+ """ + spec_delta = torch.clamp( + torch.abs(x[0] - y_hat_spec.squeeze(0).cpu()), + min=1e-6, + max=1.0, + ) + sw.add_figure( + f"delta_dclip1_{mode}/spec_{j}", + plot_spectrogram_clipped(spec_delta.numpy(), clip_max=1.0), + steps, + ) + + val_err = val_err_tot / (j + 1) + val_pesq = val_pesq_tot / (j + 1) + val_mrstft = val_mrstft_tot / (j + 1) + # Log evaluation metrics to Tensorboard + sw.add_scalar(f"validation_{mode}/mel_spec_error", val_err, steps) + sw.add_scalar(f"validation_{mode}/pesq", val_pesq, steps) + sw.add_scalar(f"validation_{mode}/mrstft", val_mrstft, steps) + + generator.train() + + # If the checkpoint is loaded, start with validation loop + if steps != 0 and rank == 0 and not a.debug: + if not a.skip_seen: + validate( + rank, + a, + h, + validation_loader, + mode=f"seen_{train_loader.dataset.name}", + ) + for i in range(len(list_unseen_validation_loader)): + validate( + rank, + a, + h, + list_unseen_validation_loader[i], + mode=f"unseen_{list_unseen_validation_loader[i].dataset.name}", + ) + # Exit the script if --evaluate is set to True + if a.evaluate: + exit() + + # Main training loop + generator.train() + mpd.train() + mrd.train() + for epoch in range(max(0, last_epoch), a.training_epochs): + if rank == 0: + start = time.time() + print(f"Epoch: {epoch + 1}") + + if h.num_gpus > 1: + train_sampler.set_epoch(epoch) + + for i, batch in enumerate(train_loader): + if rank == 0: + start_b = time.time() + x, y, _, y_mel = batch + + x = x.to(device, non_blocking=True) + y = y.to(device, non_blocking=True) + y_mel = y_mel.to(device, non_blocking=True) + y = y.unsqueeze(1) + + y_g_hat = generator(x) + y_g_hat_mel = mel_spectrogram( + y_g_hat.squeeze(1), + h.n_fft, + h.num_mels, + h.sampling_rate, + h.hop_size, + h.win_size, + h.fmin, + h.fmax_for_loss, + ) + + optim_d.zero_grad() + + # MPD + y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach()) + loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g) + + # MRD + y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g_hat.detach()) + loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g) + + loss_disc_all = loss_disc_s + loss_disc_f + + # Set clip_grad_norm value + clip_grad_norm = h.get("clip_grad_norm", 1000.0) # Default to 1000 + + # Whether to freeze D for initial training steps + if steps >= a.freeze_step: + loss_disc_all.backward() + grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), clip_grad_norm) + grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), clip_grad_norm) + optim_d.step() + else: + print(f"[WARNING] skipping D training for the first {a.freeze_step} steps") + grad_norm_mpd = 0.0 + grad_norm_mrd = 0.0 + + # Generator + optim_g.zero_grad() + + # L1 Mel-Spectrogram Loss + lambda_melloss = h.get("lambda_melloss", 45.0) # Defaults to 45 in BigVGAN-v1 if not set + if h.get("use_multiscale_melloss", False): # uses wav for loss + loss_mel = fn_mel_loss_multiscale(y, y_g_hat) * lambda_melloss + else: # Uses mel for loss + loss_mel = fn_mel_loss_singlescale(y_mel, y_g_hat_mel) * lambda_melloss + + # MPD loss + y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat) + loss_fm_f = feature_loss(fmap_f_r, fmap_f_g) + loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g) + + # MRD loss + y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = mrd(y, y_g_hat) + loss_fm_s = feature_loss(fmap_s_r, fmap_s_g) + loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g) + + if steps >= a.freeze_step: + loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + 
loss_fm_f + loss_mel + else: + print(f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps") + loss_gen_all = loss_mel + + loss_gen_all.backward() + grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), clip_grad_norm) + optim_g.step() + + if rank == 0: + # STDOUT logging + if steps % a.stdout_interval == 0: + mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to stdout + print( + f"Steps: {steps:d}, " + f"Gen Loss Total: {loss_gen_all:4.3f}, " + f"Mel Error: {mel_error:4.3f}, " + f"s/b: {time.time() - start_b:4.3f} " + f"lr: {optim_g.param_groups[0]['lr']:4.7f} " + f"grad_norm_g: {grad_norm_g:4.3f}" + ) + + # Checkpointing + if steps % a.checkpoint_interval == 0 and steps != 0: + checkpoint_path = f"{a.checkpoint_path}/g_{steps:08d}" + save_checkpoint( + checkpoint_path, + {"generator": (generator.module if h.num_gpus > 1 else generator).state_dict()}, + ) + checkpoint_path = f"{a.checkpoint_path}/do_{steps:08d}" + save_checkpoint( + checkpoint_path, + { + "mpd": (mpd.module if h.num_gpus > 1 else mpd).state_dict(), + "mrd": (mrd.module if h.num_gpus > 1 else mrd).state_dict(), + "optim_g": optim_g.state_dict(), + "optim_d": optim_d.state_dict(), + "steps": steps, + "epoch": epoch, + }, + ) + + # Tensorboard summary logging + if steps % a.summary_interval == 0: + mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to tensorboard + sw.add_scalar("training/gen_loss_total", loss_gen_all.item(), steps) + sw.add_scalar("training/mel_spec_error", mel_error, steps) + sw.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps) + sw.add_scalar("training/gen_loss_mpd", loss_gen_f.item(), steps) + sw.add_scalar("training/disc_loss_mpd", loss_disc_f.item(), steps) + sw.add_scalar("training/grad_norm_mpd", grad_norm_mpd, steps) + sw.add_scalar("training/fm_loss_mrd", loss_fm_s.item(), steps) + sw.add_scalar("training/gen_loss_mrd", loss_gen_s.item(), steps) + sw.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps) + sw.add_scalar("training/grad_norm_mrd", grad_norm_mrd, steps) + sw.add_scalar("training/grad_norm_g", grad_norm_g, steps) + sw.add_scalar("training/learning_rate_d", scheduler_d.get_last_lr()[0], steps) + sw.add_scalar("training/learning_rate_g", scheduler_g.get_last_lr()[0], steps) + sw.add_scalar("training/epoch", epoch + 1, steps) + + # Validation + if steps % a.validation_interval == 0: + # Plot training input x so far used + for i_x in range(x.shape[0]): + sw.add_figure( + f"training_input/x_{i_x}", + plot_spectrogram(x[i_x].cpu()), + steps, + ) + sw.add_audio( + f"training_input/y_{i_x}", + y[i_x][0], + steps, + h.sampling_rate, + ) + + # Seen and unseen speakers validation loops + if not a.debug and steps != 0: + validate( + rank, + a, + h, + validation_loader, + mode=f"seen_{train_loader.dataset.name}", + ) + for i in range(len(list_unseen_validation_loader)): + validate( + rank, + a, + h, + list_unseen_validation_loader[i], + mode=f"unseen_{list_unseen_validation_loader[i].dataset.name}", + ) + steps += 1 + + # BigVGAN-v2 learning rate scheduler is changed from epoch-level to step-level + scheduler_g.step() + scheduler_d.step() + + if rank == 0: + print(f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n") + + +def main(): + print("Initializing Training Process..") + + parser = argparse.ArgumentParser() + + parser.add_argument("--group_name", default=None) + + parser.add_argument("--input_wavs_dir", default="LibriTTS") + 
parser.add_argument("--input_mels_dir", default="ft_dataset") + parser.add_argument("--input_training_file", default="tests/LibriTTS/train-full.txt") + parser.add_argument("--input_validation_file", default="tests/LibriTTS/val-full.txt") + + parser.add_argument( + "--list_input_unseen_wavs_dir", + nargs="+", + default=["tests/LibriTTS", "tests/LibriTTS"], + ) + parser.add_argument( + "--list_input_unseen_validation_file", + nargs="+", + default=["tests/LibriTTS/dev-clean.txt", "tests/LibriTTS/dev-other.txt"], + ) + + parser.add_argument("--checkpoint_path", default="exp/bigvgan") + parser.add_argument("--config", default="") + + parser.add_argument("--training_epochs", default=100000, type=int) + parser.add_argument("--stdout_interval", default=5, type=int) + parser.add_argument("--checkpoint_interval", default=50000, type=int) + parser.add_argument("--summary_interval", default=100, type=int) + parser.add_argument("--validation_interval", default=50000, type=int) + + parser.add_argument( + "--freeze_step", + default=0, + type=int, + help="freeze D for the first specified steps. G only uses regression loss for these steps.", + ) + + parser.add_argument("--fine_tuning", default=False, type=bool) + + parser.add_argument( + "--debug", + default=False, + type=bool, + help="debug mode. skips validation loop throughout training", + ) + parser.add_argument( + "--evaluate", + default=False, + type=bool, + help="only run evaluation from checkpoint and exit", + ) + parser.add_argument( + "--eval_subsample", + default=5, + type=int, + help="subsampling during evaluation loop", + ) + parser.add_argument( + "--skip_seen", + default=False, + type=bool, + help="skip seen dataset. useful for test set inference", + ) + parser.add_argument( + "--save_audio", + default=False, + type=bool, + help="save audio of test set inference to disk", + ) + + a = parser.parse_args() + + with open(a.config) as f: + data = f.read() + + json_config = json.loads(data) + h = AttrDict(json_config) + + build_env(a.config, "config.json", a.checkpoint_path) + + torch.manual_seed(h.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(h.seed) + h.num_gpus = torch.cuda.device_count() + h.batch_size = int(h.batch_size / h.num_gpus) + print(f"Batch size per GPU: {h.batch_size}") + else: + pass + + if h.num_gpus > 1: + mp.spawn( + train, + nprocs=h.num_gpus, + args=( + a, + h, + ), + ) + else: + train(0, a, h) + + +if __name__ == "__main__": + main() diff --git a/GPT_SoVITS/BigVGAN/utils0.py b/GPT_SoVITS/BigVGAN/utils0.py new file mode 100644 index 0000000000000000000000000000000000000000..da98a24cf1447778305563f8e909f30b06e06b26 --- /dev/null +++ b/GPT_SoVITS/BigVGAN/utils0.py @@ -0,0 +1,99 @@ +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. 
+ +import glob +import os +import matplotlib +import torch +from torch.nn.utils import weight_norm + +matplotlib.use("Agg") +import matplotlib.pylab as plt +from .meldataset import MAX_WAV_VALUE +from scipy.io.wavfile import write + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def plot_spectrogram_clipped(spectrogram, clip_max=2.0): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow( + spectrogram, + aspect="auto", + origin="lower", + interpolation="none", + vmin=1e-6, + vmax=clip_max, + ) + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print(f"Loading '{filepath}'") + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + print(f"Saving checkpoint to {filepath}") + torch.save(obj, filepath) + print("Complete.") + + +def scan_checkpoint(cp_dir, prefix, renamed_file=None): + # Fallback to original scanning logic first + pattern = os.path.join(cp_dir, prefix + "????????") + cp_list = glob.glob(pattern) + + if len(cp_list) > 0: + last_checkpoint_path = sorted(cp_list)[-1] + print(f"[INFO] Resuming from checkpoint: '{last_checkpoint_path}'") + return last_checkpoint_path + + # If no pattern-based checkpoints are found, check for renamed file + if renamed_file: + renamed_path = os.path.join(cp_dir, renamed_file) + if os.path.isfile(renamed_path): + print(f"[INFO] Resuming from renamed checkpoint: '{renamed_file}'") + return renamed_path + + return None + + +def save_audio(audio, path, sr): + # wav: torch with 1d shape + audio = audio * MAX_WAV_VALUE + audio = audio.cpu().numpy().astype("int16") + write(path, sr, audio) diff --git a/GPT_SoVITS/download.py b/GPT_SoVITS/download.py new file mode 100644 index 0000000000000000000000000000000000000000..fc4ead63bfe3c15326212a6ebabe2dac166e0ff2 --- /dev/null +++ b/GPT_SoVITS/download.py @@ -0,0 +1,13 @@ +import os +import sys + +now_dir = os.getcwd() +sys.path.insert(0, now_dir) +from text.g2pw import G2PWPinyin + +g2pw = G2PWPinyin( + model_dir="GPT_SoVITS/text/G2PWModel", + model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", + v_to_u=False, + neutral_tone_with_five=True, +) diff --git a/GPT_SoVITS/export_torch_script.py b/GPT_SoVITS/export_torch_script.py new file mode 100644 index 0000000000000000000000000000000000000000..69817a3763b140430ad68907be0766823c170056 --- /dev/null +++ b/GPT_SoVITS/export_torch_script.py @@ -0,0 +1,861 @@ +# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py +# reference: https://github.com/lifeiteng/vall-e +import argparse +from typing import Optional +from my_utils import load_audio +import torch +import torchaudio + +from torch import IntTensor, LongTensor, Tensor, nn +from torch.nn import functional as F + +from transformers import AutoModelForMaskedLM, AutoTokenizer +from 
feature_extractor import cnhubert + +from AR.models.t2s_lightning_module import Text2SemanticLightningModule +from module.models_onnx import SynthesizerTrn + +from inference_webui import get_phones_and_bert + +import os +import soundfile + +default_config = { + "embedding_dim": 512, + "hidden_dim": 512, + "num_head": 8, + "num_layers": 12, + "num_codebook": 8, + "p_dropout": 0.0, + "vocab_size": 1024 + 1, + "phoneme_vocab_size": 512, + "EOS": 1024, +} + + +def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule: + config = dict_s1["config"] + config["model"]["dropout"] = float(config["model"]["dropout"]) + t2s_model = Text2SemanticLightningModule(config, "****", is_train=False) + t2s_model.load_state_dict(dict_s1["weight"]) + t2s_model = t2s_model.eval() + return t2s_model + + +@torch.jit.script +def logits_to_probs( + logits, + previous_tokens: Optional[torch.Tensor] = None, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: Optional[int] = None, + repetition_penalty: float = 1.0, +): + # if previous_tokens is not None: + # previous_tokens = previous_tokens.squeeze() + # print(logits.shape,previous_tokens.shape) + # pdb.set_trace() + if previous_tokens is not None and repetition_penalty != 1.0: + previous_tokens = previous_tokens.long() + score = torch.gather(logits, dim=1, index=previous_tokens) + score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty) + logits.scatter_(dim=1, index=previous_tokens, src=score) + + if top_p is not None and top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cum_probs > top_p + sorted_indices_to_remove[:, 0] = False # keep at least one option + indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove) + logits = logits.masked_fill(indices_to_remove, -float("Inf")) + + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + pivot = v[:, -1].unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +@torch.jit.script +def multinomial_sample_one_no_sync(probs_sort): + # Does multinomial sampling without a cuda synchronization + q = torch.randn_like(probs_sort) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +@torch.jit.script +def sample( + logits, + previous_tokens, + temperature: float = 1.0, + top_k: Optional[int] = None, + top_p: Optional[int] = None, + repetition_penalty: float = 1.0, +): + probs = logits_to_probs( + logits=logits, + previous_tokens=previous_tokens, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + ) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + + +@torch.jit.script +def spectrogram_torch(y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False): + hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype) + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window, + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + 
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +class DictToAttrRecursive(dict): + def __init__(self, input_dict): + super().__init__(input_dict) + for key, value in input_dict.items(): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + self[key] = value + setattr(self, key, value) + + def __getattr__(self, item): + try: + return self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + def __setattr__(self, key, value): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + super(DictToAttrRecursive, self).__setitem__(key, value) + super().__setattr__(key, value) + + def __delattr__(self, item): + try: + del self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + +@torch.jit.script +class T2SMLP: + def __init__(self, w1, b1, w2, b2): + self.w1 = w1 + self.b1 = b1 + self.w2 = w2 + self.b2 = b2 + + def forward(self, x): + x = F.relu(F.linear(x, self.w1, self.b1)) + x = F.linear(x, self.w2, self.b2) + return x + + +@torch.jit.script +class T2SBlock: + def __init__( + self, + num_heads: int, + hidden_dim: int, + mlp: T2SMLP, + qkv_w, + qkv_b, + out_w, + out_b, + norm_w1, + norm_b1, + norm_eps1: float, + norm_w2, + norm_b2, + norm_eps2: float, + ): + self.num_heads = num_heads + self.mlp = mlp + self.hidden_dim: int = hidden_dim + self.qkv_w = qkv_w + self.qkv_b = qkv_b + self.out_w = out_w + self.out_b = out_b + self.norm_w1 = norm_w1 + self.norm_b1 = norm_b1 + self.norm_eps1 = norm_eps1 + self.norm_w2 = norm_w2 + self.norm_b2 = norm_b2 + self.norm_eps2 = norm_eps2 + + self.false = torch.tensor(False, dtype=torch.bool) + + @torch.jit.ignore + def to_mask(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor]): + if padding_mask is None: + return x + + if padding_mask.dtype == torch.bool: + return x.masked_fill(padding_mask, 0) + else: + return x * padding_mask + + def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): + q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1) + + batch_size = q.shape[0] + q_len = q.shape[1] + kv_len = k.shape[1] + + q = self.to_mask(q, padding_mask) + k_cache = self.to_mask(k, padding_mask) + v_cache = self.to_mask(v, padding_mask) + + q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2) + k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) + v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) + + attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask) + + attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim) + attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0) + attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b) + + if padding_mask is not None: + for i in range(batch_size): + # mask = padding_mask[i,:,0] + if self.false.device != padding_mask.device: + self.false = self.false.to(padding_mask.device) + idx = torch.where(padding_mask[i, :, 0] == self.false)[0] + x_item = x[i, idx, :].unsqueeze(0) + attn_item = attn[i, idx, :].unsqueeze(0) + x_item = x_item + attn_item + x_item = F.layer_norm(x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1) + x_item = x_item + self.mlp.forward(x_item) + x_item = F.layer_norm( + x_item, + [self.hidden_dim], + self.norm_w2, + self.norm_b2, + self.norm_eps2, + ) + x[i, idx, :] = x_item.squeeze(0) + x = self.to_mask(x, padding_mask) + else: + x = x + attn + x = F.layer_norm(x, 
[self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1) + x = x + self.mlp.forward(x) + x = F.layer_norm( + x, + [self.hidden_dim], + self.norm_w2, + self.norm_b2, + self.norm_eps2, + ) + return x, k_cache, v_cache + + def decode_next_token(self, x: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor): + q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1) + + k_cache = torch.cat([k_cache, k], dim=1) + v_cache = torch.cat([v_cache, v], dim=1) + + batch_size = q.shape[0] + q_len = q.shape[1] + kv_len = k_cache.shape[1] + + q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2) + k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) + v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2) + + attn = F.scaled_dot_product_attention(q, k, v) + + attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim) + attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0) + attn = F.linear(attn, self.out_w, self.out_b) + + x = x + attn + x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1) + x = x + self.mlp.forward(x) + x = F.layer_norm( + x, + [self.hidden_dim], + self.norm_w2, + self.norm_b2, + self.norm_eps2, + ) + return x, k_cache, v_cache + + +@torch.jit.script +class T2STransformer: + def __init__(self, num_blocks: int, blocks: list[T2SBlock]): + self.num_blocks: int = num_blocks + self.blocks = blocks + + def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): + k_cache: list[torch.Tensor] = [] + v_cache: list[torch.Tensor] = [] + for i in range(self.num_blocks): + x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask) + k_cache.append(k_cache_) + v_cache.append(v_cache_) + return x, k_cache, v_cache + + def decode_next_token(self, x: torch.Tensor, k_cache: list[torch.Tensor], v_cache: list[torch.Tensor]): + for i in range(self.num_blocks): + x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i]) + return x, k_cache, v_cache + + +class VitsModel(nn.Module): + def __init__(self, vits_path): + super().__init__() + # dict_s2 = torch.load(vits_path,map_location="cpu") + dict_s2 = torch.load(vits_path) + self.hps = dict_s2["config"] + if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322: + self.hps["model"]["version"] = "v1" + else: + self.hps["model"]["version"] = "v2" + + self.hps = DictToAttrRecursive(self.hps) + self.hps.model.semantic_frame_rate = "25hz" + self.vq_model = SynthesizerTrn( + self.hps.data.filter_length // 2 + 1, + self.hps.train.segment_size // self.hps.data.hop_length, + n_speakers=self.hps.data.n_speakers, + **self.hps.model, + ) + self.vq_model.eval() + self.vq_model.load_state_dict(dict_s2["weight"], strict=False) + + def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0): + refer = spectrogram_torch( + ref_audio, + self.hps.data.filter_length, + self.hps.data.sampling_rate, + self.hps.data.hop_length, + self.hps.data.win_length, + center=False, + ) + return self.vq_model(pred_semantic, text_seq, refer, speed)[0, 0] + + +class T2SModel(nn.Module): + def __init__(self, raw_t2s: Text2SemanticLightningModule): + super(T2SModel, self).__init__() + self.model_dim = raw_t2s.model.model_dim + self.embedding_dim = raw_t2s.model.embedding_dim + self.num_head = raw_t2s.model.num_head + self.num_layers = raw_t2s.model.num_layers + self.vocab_size = raw_t2s.model.vocab_size + self.phoneme_vocab_size = 
raw_t2s.model.phoneme_vocab_size + # self.p_dropout = float(raw_t2s.model.p_dropout) + self.EOS: int = int(raw_t2s.model.EOS) + self.norm_first = raw_t2s.model.norm_first + assert self.EOS == self.vocab_size - 1 + self.hz = 50 + + self.bert_proj = raw_t2s.model.bert_proj + self.ar_text_embedding = raw_t2s.model.ar_text_embedding + self.ar_text_position = raw_t2s.model.ar_text_position + self.ar_audio_embedding = raw_t2s.model.ar_audio_embedding + self.ar_audio_position = raw_t2s.model.ar_audio_position + + # self.t2s_transformer = T2STransformer(self.num_layers, blocks) + # self.t2s_transformer = raw_t2s.model.t2s_transformer + + blocks = [] + h = raw_t2s.model.h + + for i in range(self.num_layers): + layer = h.layers[i] + t2smlp = T2SMLP(layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias) + + block = T2SBlock( + self.num_head, + self.model_dim, + t2smlp, + layer.self_attn.in_proj_weight, + layer.self_attn.in_proj_bias, + layer.self_attn.out_proj.weight, + layer.self_attn.out_proj.bias, + layer.norm1.weight, + layer.norm1.bias, + layer.norm1.eps, + layer.norm2.weight, + layer.norm2.bias, + layer.norm2.eps, + ) + + blocks.append(block) + + self.t2s_transformer = T2STransformer(self.num_layers, blocks) + + # self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False) + self.ar_predict_layer = raw_t2s.model.ar_predict_layer + # self.loss_fct = nn.CrossEntropyLoss(reduction="sum") + self.max_sec = raw_t2s.config["data"]["max_sec"] + self.top_k = int(raw_t2s.config["inference"]["top_k"]) + self.early_stop_num = torch.LongTensor([self.hz * self.max_sec]) + + def forward( + self, + prompts: LongTensor, + ref_seq: LongTensor, + text_seq: LongTensor, + ref_bert: torch.Tensor, + text_bert: torch.Tensor, + top_k: LongTensor, + ): + bert = torch.cat([ref_bert.T, text_bert.T], 1) + all_phoneme_ids = torch.cat([ref_seq, text_seq], 1) + bert = bert.unsqueeze(0) + + x = self.ar_text_embedding(all_phoneme_ids) + x = x + self.bert_proj(bert.transpose(1, 2)) + x: torch.Tensor = self.ar_text_position(x) + + early_stop_num = self.early_stop_num + + # [1,N,512] [1,N] + # y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts) + y = prompts + # x_example = x[:,:,0] * 0.0 + + x_len = x.shape[1] + x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool) + + y_emb = self.ar_audio_embedding(y) + y_len = y_emb.shape[1] + prefix_len = y.shape[1] + y_pos = self.ar_audio_position(y_emb) + xy_pos = torch.concat([x, y_pos], dim=1) + + bsz = x.shape[0] + src_len = x_len + y_len + x_attn_mask_pad = F.pad( + x_attn_mask, + (0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y) + value=True, + ) + y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y) + torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), + (x_len, 0), + value=False, + ) + xy_attn_mask = ( + torch.concat([x_attn_mask_pad, y_attn_mask], dim=0) + .unsqueeze(0) + .expand(bsz * self.num_head, -1, -1) + .view(bsz, self.num_head, src_len, src_len) + .to(device=x.device, dtype=torch.bool) + ) + + idx = 0 + top_k = int(top_k) + + xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None) + + logits = self.ar_predict_layer(xy_dec[:, -1]) + logits = logits[:, :-1] + samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0] + y = torch.concat([y, samples], dim=1) + y_emb = self.ar_audio_embedding(y[:, -1:]) + xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[ + :, y_len + idx + 
].to(dtype=y_emb.dtype, device=y_emb.device) + + stop = False + # for idx in range(1, 50): + for idx in range(1, 1500): + # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N] + # y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example) + xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache) + logits = self.ar_predict_layer(xy_dec[:, -1]) + + if idx < 11: ###至少预测出10个token不然不给停止(0.4s) + logits = logits[:, :-1] + + samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0] + + y = torch.concat([y, samples], dim=1) + + if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: + stop = True + if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: + stop = True + if stop: + if y.shape[1] == 0: + y = torch.concat([y, torch.zeros_like(samples)], dim=1) + break + + y_emb = self.ar_audio_embedding(y[:, -1:]) + xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[ + :, y_len + idx + ].to(dtype=y_emb.dtype, device=y_emb.device) + + y[0, -1] = 0 + + return y[:, -idx:].unsqueeze(0) + + +bert_path = os.environ.get("bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large") +cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base" +cnhubert.cnhubert_base_path = cnhubert_base_path + + +@torch.jit.script +def build_phone_level_feature(res: Tensor, word2ph: IntTensor): + phone_level_feature = [] + for i in range(word2ph.shape[0]): + repeat_feature = res[i].repeat(word2ph[i].item(), 1) + phone_level_feature.append(repeat_feature) + phone_level_feature = torch.cat(phone_level_feature, dim=0) + # [sum(word2ph), 1024] + return phone_level_feature + + +class MyBertModel(torch.nn.Module): + def __init__(self, bert_model): + super(MyBertModel, self).__init__() + self.bert = bert_model + + def forward( + self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor, word2ph: IntTensor + ): + outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids) + # res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1] + res = torch.cat(outputs[1][-3:-2], -1)[0][1:-1] + return build_phone_level_feature(res, word2ph) + + +class SSLModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.ssl = cnhubert.get_model().model + + def forward(self, ref_audio_16k) -> torch.Tensor: + ssl_content = self.ssl(ref_audio_16k)["last_hidden_state"].transpose(1, 2) + return ssl_content + + +class ExportSSLModel(torch.nn.Module): + def __init__(self, ssl: SSLModel): + super().__init__() + self.ssl = ssl + + def forward(self, ref_audio: torch.Tensor): + return self.ssl(ref_audio) + + @torch.jit.export + def resample(self, ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor: + audio = resamplex(ref_audio, src_sr, dst_sr).float() + return audio + + +def export_bert(output_path): + tokenizer = AutoTokenizer.from_pretrained(bert_path) + + text = "叹息声一声接着一声传出,木兰对着房门织布.听不见织布机织布的声音,只听见木兰在叹息.问木兰在想什么?问木兰在惦记什么?木兰答道,我也没有在想什么,也没有在惦记什么." 
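+    # Editor note: word2ph assigns a phone count to each character of the example
+    # sentence above (1 for punctuation, 2 for every other character); it only
+    # shapes the inputs used to trace the BERT front-end.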
+ ref_bert_inputs = tokenizer(text, return_tensors="pt") + word2ph = [] + for c in text: + if c in [",", "。", ":", "?", ",", ".", "?"]: + word2ph.append(1) + else: + word2ph.append(2) + ref_bert_inputs["word2ph"] = torch.Tensor(word2ph).int() + + bert_model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True, torchscript=True) + my_bert_model = MyBertModel(bert_model) + + ref_bert_inputs = { + "input_ids": ref_bert_inputs["input_ids"], + "attention_mask": ref_bert_inputs["attention_mask"], + "token_type_ids": ref_bert_inputs["token_type_ids"], + "word2ph": ref_bert_inputs["word2ph"], + } + + torch._dynamo.mark_dynamic(ref_bert_inputs["input_ids"], 1) + torch._dynamo.mark_dynamic(ref_bert_inputs["attention_mask"], 1) + torch._dynamo.mark_dynamic(ref_bert_inputs["token_type_ids"], 1) + torch._dynamo.mark_dynamic(ref_bert_inputs["word2ph"], 0) + + my_bert_model = torch.jit.trace(my_bert_model, example_kwarg_inputs=ref_bert_inputs) + output_path = os.path.join(output_path, "bert_model.pt") + my_bert_model.save(output_path) + print("#### exported bert ####") + + +def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_bert_and_ssl=False, device="cpu"): + if not os.path.exists(output_path): + os.makedirs(output_path) + print(f"目录已创建: {output_path}") + else: + print(f"目录已存在: {output_path}") + + ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float() + ssl = SSLModel() + if export_bert_and_ssl: + s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio))) + ssl_path = os.path.join(output_path, "ssl_model.pt") + torch.jit.script(s).save(ssl_path) + print("#### exported ssl ####") + export_bert(output_path) + else: + s = ExportSSLModel(ssl) + + print(f"device: {device}") + + ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2") + ref_seq = torch.LongTensor([ref_seq_id]).to(device) + ref_bert = ref_bert_T.T.to(ref_seq.device) + text_seq_id, text_bert_T, norm_text = get_phones_and_bert( + "这是一条测试语音,说什么无所谓,只是给它一个例子", "all_zh", "v2" + ) + text_seq = torch.LongTensor([text_seq_id]).to(device) + text_bert = text_bert_T.T.to(text_seq.device) + + ssl_content = ssl(ref_audio).to(device) + + # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth" + vits = VitsModel(vits_path).to(device) + vits.eval() + + # gpt_path = "GPT_weights_v2/xw-e15.ckpt" + # dict_s1 = torch.load(gpt_path, map_location=device) + dict_s1 = torch.load(gpt_path) + raw_t2s = get_raw_t2s_model(dict_s1).to(device) + print("#### get_raw_t2s_model ####") + print(raw_t2s.config) + t2s_m = T2SModel(raw_t2s) + t2s_m.eval() + t2s = torch.jit.script(t2s_m).to(device) + print("#### script t2s_m ####") + + print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate) + gpt_sovits = GPT_SoVITS(t2s, vits).to(device) + gpt_sovits.eval() + + ref_audio_sr = s.resample(ref_audio, 16000, 32000).to(device) + + torch._dynamo.mark_dynamic(ssl_content, 2) + torch._dynamo.mark_dynamic(ref_audio_sr, 1) + torch._dynamo.mark_dynamic(ref_seq, 1) + torch._dynamo.mark_dynamic(text_seq, 1) + torch._dynamo.mark_dynamic(ref_bert, 0) + torch._dynamo.mark_dynamic(text_bert, 0) + + top_k = torch.LongTensor([5]).to(device) + + with torch.no_grad(): + gpt_sovits_export = torch.jit.trace( + gpt_sovits, example_inputs=(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k) + ) + + gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt") + gpt_sovits_export.save(gpt_sovits_path) + print("#### exported gpt_sovits ####") + + +@torch.jit.script +def 
parse_audio(ref_audio): + ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float() # .to(ref_audio.device) + ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, 32000).float() # .to(ref_audio.device) + return ref_audio_16k, ref_audio_sr + + +@torch.jit.script +def resamplex(ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor: + return torchaudio.functional.resample(ref_audio, src_sr, dst_sr).float() + + +class GPT_SoVITS(nn.Module): + def __init__(self, t2s: T2SModel, vits: VitsModel): + super().__init__() + self.t2s = t2s + self.vits = vits + + def forward( + self, + ssl_content: torch.Tensor, + ref_audio_sr: torch.Tensor, + ref_seq: Tensor, + text_seq: Tensor, + ref_bert: Tensor, + text_bert: Tensor, + top_k: LongTensor, + speed=1.0, + ): + codes = self.vits.vq_model.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + prompts = prompt_semantic.unsqueeze(0) + + pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k) + audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed) + return audio + + +def test(): + parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool") + parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file") + parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file") + parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file") + parser.add_argument("--ref_text", required=True, help="Path to the reference text file") + parser.add_argument("--output_path", required=True, help="Path to the output directory") + + args = parser.parse_args() + gpt_path = args.gpt_model + vits_path = args.sovits_model + ref_audio_path = args.ref_audio + ref_text = args.ref_text + + tokenizer = AutoTokenizer.from_pretrained(bert_path) + # bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True) + # bert = MyBertModel(bert_model) + my_bert = torch.jit.load("onnx/bert_model.pt", map_location="cuda") + + # dict_s1 = torch.load(gpt_path, map_location="cuda") + # raw_t2s = get_raw_t2s_model(dict_s1) + # t2s = T2SModel(raw_t2s) + # t2s.eval() + # t2s = torch.jit.load("onnx/xw/t2s_model.pt",map_location='cuda') + + # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth" + # vits = VitsModel(vits_path) + # vits.eval() + + # ssl = ExportSSLModel(SSLModel()).to('cuda') + # ssl.eval() + ssl = torch.jit.load("onnx/by/ssl_model.pt", map_location="cuda") + + # gpt_sovits = GPT_SoVITS(t2s,vits) + gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt", map_location="cuda") + + ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2") + ref_seq = torch.LongTensor([ref_seq_id]) + ref_bert = ref_bert_T.T.to(ref_seq.device) + # text_seq_id,text_bert_T,norm_text = get_phones_and_bert("昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字.","all_zh",'v2') + text = "昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字." 
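+    # Editor note: the rest of test() re-encodes the same sentence with the traced
+    # BERT module and compares it against get_phones_and_bert's features via
+    # torch.allclose before running the exported GPT-SoVITS pipeline end to end.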
+ + text_seq_id, text_bert_T, norm_text = get_phones_and_bert(text, "all_zh", "v2") + + test_bert = tokenizer(text, return_tensors="pt") + word2ph = [] + for c in text: + if c in [",", "。", ":", "?", "?", ",", "."]: + word2ph.append(1) + else: + word2ph.append(2) + test_bert["word2ph"] = torch.Tensor(word2ph).int() + + test_bert = my_bert( + test_bert["input_ids"].to("cuda"), + test_bert["attention_mask"].to("cuda"), + test_bert["token_type_ids"].to("cuda"), + test_bert["word2ph"].to("cuda"), + ) + + text_seq = torch.LongTensor([text_seq_id]) + text_bert = text_bert_T.T.to(text_seq.device) + + print("text_bert:", text_bert.shape, text_bert) + print("test_bert:", test_bert.shape, test_bert) + print(torch.allclose(text_bert.to("cuda"), test_bert)) + + print("text_seq:", text_seq.shape) + print("text_bert:", text_bert.shape, text_bert.type()) + + # [1,N] + ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float().to("cuda") + print("ref_audio:", ref_audio.shape) + + ref_audio_sr = ssl.resample(ref_audio, 16000, 32000) + print("start ssl") + ssl_content = ssl(ref_audio) + + print("start gpt_sovits:") + print("ssl_content:", ssl_content.shape) + print("ref_audio_sr:", ref_audio_sr.shape) + print("ref_seq:", ref_seq.shape) + ref_seq = ref_seq.to("cuda") + print("text_seq:", text_seq.shape) + text_seq = text_seq.to("cuda") + print("ref_bert:", ref_bert.shape) + ref_bert = ref_bert.to("cuda") + print("text_bert:", text_bert.shape) + text_bert = text_bert.to("cuda") + + top_k = torch.LongTensor([5]).to("cuda") + + with torch.no_grad(): + audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert, top_k) + print("start write wav") + soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000) + + +import text +import json + + +def export_symbel(version="v2"): + if version == "v1": + symbols = text._symbol_to_id_v1 + with open("onnx/symbols_v1.json", "w") as file: + json.dump(symbols, file, indent=4) + else: + symbols = text._symbol_to_id_v2 + with open("onnx/symbols_v2.json", "w") as file: + json.dump(symbols, file, indent=4) + + +def main(): + parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool") + parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file") + parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file") + parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file") + parser.add_argument("--ref_text", required=True, help="Path to the reference text file") + parser.add_argument("--output_path", required=True, help="Path to the output directory") + parser.add_argument("--export_common_model", action="store_true", help="Export Bert and SSL model") + parser.add_argument("--device", help="Device to use") + + args = parser.parse_args() + export( + gpt_path=args.gpt_model, + vits_path=args.sovits_model, + ref_audio_path=args.ref_audio, + ref_text=args.ref_text, + output_path=args.output_path, + device=args.device, + export_bert_and_ssl=args.export_common_model, + ) + + +import inference_webui + +if __name__ == "__main__": + inference_webui.is_half = False + inference_webui.dtype = torch.float32 + main() + # test() diff --git a/GPT_SoVITS/export_torch_script_v3.py b/GPT_SoVITS/export_torch_script_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..b34495a7adc3486ea2eb482c43c41e20ddbc7c48 --- /dev/null +++ b/GPT_SoVITS/export_torch_script_v3.py @@ -0,0 +1,1035 @@ +import os +from export_torch_script import ( + T2SModel, + 
get_raw_t2s_model, + resamplex, + spectrogram_torch, +) +from f5_tts.model.backbones.dit import DiT +from inference_webui import get_phones_and_bert +import librosa +from module import commons +from module.mel_processing import mel_spectrogram_torch +from module.models_onnx import CFM, SynthesizerTrnV3 +import numpy as np +import torch._dynamo.config +import torchaudio +import logging +import uvicorn +import torch +import soundfile +from librosa.filters import mel as librosa_mel_fn + + +from inference_webui import get_spepc, norm_spec, resample, ssl_model + +logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG) +logger = logging.getLogger("uvicorn") + +is_half = True +device = "cuda" if torch.cuda.is_available() else "cpu" +now_dir = os.getcwd() + + +class MelSpectrgram(torch.nn.Module): + def __init__( + self, + dtype, + device, + n_fft, + num_mels, + sampling_rate, + hop_size, + win_size, + fmin, + fmax, + center=False, + ): + super().__init__() + self.hann_window = torch.hann_window(1024).to(device=device, dtype=dtype) + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + self.mel_basis = torch.from_numpy(mel).to(dtype=dtype, device=device) + self.n_fft: int = n_fft + self.hop_size: int = hop_size + self.win_size: int = win_size + self.center: bool = center + + def forward(self, y): + y = torch.nn.functional.pad( + y.unsqueeze(1), + ( + int((self.n_fft - self.hop_size) / 2), + int((self.n_fft - self.hop_size) / 2), + ), + mode="reflect", + ) + y = y.squeeze(1) + spec = torch.stft( + y, + self.n_fft, + hop_length=self.hop_size, + win_length=self.win_size, + window=self.hann_window, + center=self.center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-9) + spec = torch.matmul(self.mel_basis, spec) + # spec = spectral_normalize_torch(spec) + spec = torch.log(torch.clamp(spec, min=1e-5)) + return spec + + +class ExportDitBlocks(torch.nn.Module): + def __init__(self, dit: DiT): + super().__init__() + self.transformer_blocks = dit.transformer_blocks + self.norm_out = dit.norm_out + self.proj_out = dit.proj_out + self.depth = dit.depth + + def forward(self, x, t, mask, rope): + for block in self.transformer_blocks: + x = block(x, t, mask=mask, rope=(rope, 1.0)) + x = self.norm_out(x, t) + output = self.proj_out(x) + return output + + +class ExportDitEmbed(torch.nn.Module): + def __init__(self, dit: DiT): + super().__init__() + self.time_embed = dit.time_embed + self.d_embed = dit.d_embed + self.text_embed = dit.text_embed + self.input_embed = dit.input_embed + self.rotary_embed = dit.rotary_embed + self.rotary_embed.inv_freq.to(device) + + def forward( + self, + x0: torch.Tensor, # nosied input audio # noqa: F722 + cond0: torch.Tensor, # masked cond audio # noqa: F722 + x_lens: torch.Tensor, + time: torch.Tensor, # time step # noqa: F821 F722 + dt_base_bootstrap: torch.Tensor, + text0: torch.Tensor, # noqa: F722#####condition feature + ): + x = x0.transpose(2, 1) + cond = cond0.transpose(2, 1) + text = text0.transpose(2, 1) + mask = commons.sequence_mask(x_lens, max_length=x.size(1)).to(x.device) + + t = self.time_embed(time) + self.d_embed(dt_base_bootstrap) + text_embed = self.text_embed(text, x.shape[1]) + rope_t = torch.arange(x.shape[1], device=device) + rope, _ = self.rotary_embed(rope_t) + x = self.input_embed(x, cond, text_embed) + return x, t, mask, rope + + +class ExportDiT(torch.nn.Module): + def __init__(self, dit: DiT): + super().__init__() + if dit != 
None: + self.embed = ExportDitEmbed(dit) + self.blocks = ExportDitBlocks(dit) + else: + self.embed = None + self.blocks = None + + def forward( # x, prompt_x, x_lens, t, style,cond + self, # d is channel,n is T + x0: torch.Tensor, # nosied input audio # noqa: F722 + cond0: torch.Tensor, # masked cond audio # noqa: F722 + x_lens: torch.Tensor, + time: torch.Tensor, # time step # noqa: F821 F722 + dt_base_bootstrap: torch.Tensor, + text0: torch.Tensor, # noqa: F722#####condition feature + ): + x, t, mask, rope = self.embed(x0, cond0, x_lens, time, dt_base_bootstrap, text0) + output = self.blocks(x, t, mask, rope) + return output + + +class ExportCFM(torch.nn.Module): + def __init__(self, cfm: CFM): + super().__init__() + self.cfm = cfm + + def forward( + self, + fea_ref: torch.Tensor, + fea_todo_chunk: torch.Tensor, + mel2: torch.Tensor, + sample_steps: torch.LongTensor, + ): + T_min = fea_ref.size(2) + fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1) + cfm_res = self.cfm(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps) + cfm_res = cfm_res[:, :, mel2.shape[2] :] + mel2 = cfm_res[:, :, -T_min:] + fea_ref = fea_todo_chunk[:, :, -T_min:] + return cfm_res, fea_ref, mel2 + + +mel_fn = lambda x: mel_spectrogram_torch( + x, + **{ + "n_fft": 1024, + "win_size": 1024, + "hop_size": 256, + "num_mels": 100, + "sampling_rate": 24000, + "fmin": 0, + "fmax": None, + "center": False, + }, +) + +spec_min = -12 +spec_max = 2 + + +@torch.jit.script +def norm_spec(x): + spec_min = -12 + spec_max = 2 + return (x - spec_min) / (spec_max - spec_min) * 2 - 1 + + +def denorm_spec(x): + spec_min = -12 + spec_max = 2 + return (x + 1) / 2 * (spec_max - spec_min) + spec_min + + +class ExportGPTSovitsHalf(torch.nn.Module): + def __init__(self, hps, t2s_m: T2SModel, vq_model: SynthesizerTrnV3): + super().__init__() + self.hps = hps + self.t2s_m = t2s_m + self.vq_model = vq_model + self.mel2 = MelSpectrgram( + dtype=torch.float32, + device=device, + n_fft=1024, + num_mels=100, + sampling_rate=24000, + hop_size=256, + win_size=1024, + fmin=0, + fmax=None, + center=False, + ) + # self.dtype = dtype + self.filter_length: int = hps.data.filter_length + self.sampling_rate: int = hps.data.sampling_rate + self.hop_length: int = hps.data.hop_length + self.win_length: int = hps.data.win_length + + def forward( + self, + ssl_content, + ref_audio_32k: torch.FloatTensor, + phoneme_ids0, + phoneme_ids1, + bert1, + bert2, + top_k, + ): + refer = spectrogram_torch( + ref_audio_32k, + self.filter_length, + self.sampling_rate, + self.hop_length, + self.win_length, + center=False, + ).to(ssl_content.dtype) + + codes = self.vq_model.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + prompt = prompt_semantic.unsqueeze(0) + # print('extract_latent',codes.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + + pred_semantic = self.t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k) + # print('t2s_m',pred_semantic.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + + ge = self.vq_model.create_ge(refer) + # print('create_ge',datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + + prompt_ = prompt.unsqueeze(0) + fea_ref = self.vq_model(prompt_, phoneme_ids0, ge) + # print('fea_ref',datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + # print(prompt_.shape, phoneme_ids0.shape, ge.shape) + # print(fea_ref.shape) + + ref_24k = resamplex(ref_audio_32k, 32000, 24000) + mel2 = norm_spec(self.mel2(ref_24k)).to(ssl_content.dtype) + T_min = min(mel2.shape[2], fea_ref.shape[2]) + mel2 = mel2[:, :, :T_min] + 
fea_ref = fea_ref[:, :, :T_min] + if T_min > 468: + mel2 = mel2[:, :, -468:] + fea_ref = fea_ref[:, :, -468:] + T_min = 468 + + fea_todo = self.vq_model(pred_semantic, phoneme_ids1, ge) + # print('fea_todo',datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + # print(pred_semantic.shape, phoneme_ids1.shape, ge.shape) + # print(fea_todo.shape) + + return fea_ref, fea_todo, mel2 + + +class GPTSoVITSV3(torch.nn.Module): + def __init__(self, gpt_sovits_half, cfm, bigvgan): + super().__init__() + self.gpt_sovits_half = gpt_sovits_half + self.cfm = cfm + self.bigvgan = bigvgan + + def forward( + self, + ssl_content, + ref_audio_32k: torch.FloatTensor, + phoneme_ids0: torch.LongTensor, + phoneme_ids1: torch.LongTensor, + bert1, + bert2, + top_k: torch.LongTensor, + sample_steps: torch.LongTensor, + ): + # current_time = datetime.now() + # print("gpt_sovits_half",current_time.strftime("%Y-%m-%d %H:%M:%S")) + fea_ref, fea_todo, mel2 = self.gpt_sovits_half( + ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k + ) + chunk_len = 934 - fea_ref.shape[2] + wav_gen_list = [] + idx = 0 + wav_gen_length = fea_todo.shape[2] * 256 + while 1: + # current_time = datetime.now() + # print("idx:",idx,current_time.strftime("%Y-%m-%d %H:%M:%S")) + fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len] + if fea_todo_chunk.shape[-1] == 0: + break + + # 因为导出的模型在不同shape时会重新编译还是怎么的,会卡顿10s这样, + # 所以在这里补0让他shape维持不变 + # 但是这样会导致生成的音频长度不对,所以在最后截取一下。 + # 经过 bigvgan 之后音频长度就是 fea_todo.shape[2] * 256 + complete_len = chunk_len - fea_todo_chunk.shape[-1] + if complete_len != 0: + fea_todo_chunk = torch.cat( + [ + fea_todo_chunk, + torch.zeros(1, 512, complete_len).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype), + ], + 2, + ) + + cfm_res, fea_ref, mel2 = self.cfm(fea_ref, fea_todo_chunk, mel2, sample_steps) + idx += chunk_len + + cfm_res = denorm_spec(cfm_res) + bigvgan_res = self.bigvgan(cfm_res) + wav_gen_list.append(bigvgan_res) + + wav_gen = torch.cat(wav_gen_list, 2) + return wav_gen[0][0][:wav_gen_length] + + +def init_bigvgan(): + global bigvgan_model + from BigVGAN import bigvgan + + bigvgan_model = bigvgan.BigVGAN.from_pretrained( + "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), + use_cuda_kernel=False, + ) # if True, RuntimeError: Ninja is required to load C++ extensions + # remove weight norm in the model and set to eval mode + bigvgan_model.remove_weight_norm() + bigvgan_model = bigvgan_model.eval() + if is_half == True: + bigvgan_model = bigvgan_model.half().to(device) + else: + bigvgan_model = bigvgan_model.to(device) + + +class Sovits: + def __init__(self, vq_model: SynthesizerTrnV3, cfm: CFM, hps): + self.vq_model = vq_model + self.hps = hps + cfm.estimator = ExportDiT(cfm.estimator) + self.cfm = cfm + + +class DictToAttrRecursive(dict): + def __init__(self, input_dict): + super().__init__(input_dict) + for key, value in input_dict.items(): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + self[key] = value + setattr(self, key, value) + + def __getattr__(self, item): + try: + return self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + def __setattr__(self, key, value): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + super(DictToAttrRecursive, self).__setitem__(key, value) + super().__setattr__(key, value) + + def __delattr__(self, item): + try: + del self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + +from process_ckpt import 
get_sovits_version_from_path_fast, load_sovits_new + + +def get_sovits_weights(sovits_path): + path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth" + is_exist_s2gv3 = os.path.exists(path_sovits_v3) + + version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path) + if if_lora_v3 == True and is_exist_s2gv3 == False: + logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重") + + dict_s2 = load_sovits_new(sovits_path) + hps = dict_s2["config"] + hps = DictToAttrRecursive(hps) + hps.model.semantic_frame_rate = "25hz" + if "enc_p.text_embedding.weight" not in dict_s2["weight"]: + hps.model.version = "v2" # v3model,v2sybomls + elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322: + hps.model.version = "v1" + else: + hps.model.version = "v2" + + if model_version == "v3": + hps.model.version = "v3" + + logger.info(f"hps: {hps}") + + vq_model = SynthesizerTrnV3( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ) + # init_bigvgan() + model_version = hps.model.version + logger.info(f"模型版本: {model_version}") + + if is_half == True: + vq_model = vq_model.half().to(device) + else: + vq_model = vq_model.to(device) + vq_model.load_state_dict(dict_s2["weight"], strict=False) + vq_model.eval() + + cfm = vq_model.cfm + del vq_model.cfm + + sovits = Sovits(vq_model, cfm, hps) + return sovits + + +logger.info(f"torch version {torch.__version__}") +# ssl_model = cnhubert.get_model() +# if is_half: +# ssl_model = ssl_model.half().to(device) +# else: +# ssl_model = ssl_model.to(device) + + +def export_cfm( + e_cfm: ExportCFM, + mu: torch.Tensor, + x_lens: torch.LongTensor, + prompt: torch.Tensor, + n_timesteps: torch.IntTensor, + temperature=1.0, +): + cfm = e_cfm.cfm + + B, T = mu.size(0), mu.size(1) + x = torch.randn([B, cfm.in_channels, T], device=mu.device, dtype=mu.dtype) * temperature + print("x:", x.shape, x.dtype) + prompt_len = prompt.size(-1) + prompt_x = torch.zeros_like(x, dtype=mu.dtype) + prompt_x[..., :prompt_len] = prompt[..., :prompt_len] + x[..., :prompt_len] = 0.0 + mu = mu.transpose(2, 1) + + ntimestep = int(n_timesteps) + + t = torch.tensor(0.0, dtype=x.dtype, device=x.device) + d = torch.tensor(1.0 / ntimestep, dtype=x.dtype, device=x.device) + + t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t + d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d + + print( + "cfm input shapes:", + x.shape, + prompt_x.shape, + x_lens.shape, + t_tensor.shape, + d_tensor.shape, + mu.shape, + ) + + print("cfm input dtypes:", x.dtype, prompt_x.dtype, x_lens.dtype, t_tensor.dtype, d_tensor.dtype, mu.dtype) + + estimator: ExportDiT = torch.jit.trace( + cfm.estimator, + optimize=True, + example_inputs=(x, prompt_x, x_lens, t_tensor, d_tensor, mu), + ) + estimator.save("onnx/ad/estimator.pt") + # torch.onnx.export( + # cfm.estimator, + # (x, prompt_x, x_lens, t_tensor, d_tensor, mu), + # "onnx/ad/dit.onnx", + # input_names=["x", "prompt_x", "x_lens", "t", "d", "mu"], + # output_names=["output"], + # dynamic_axes={ + # "x": [2], + # "prompt_x": [2], + # "mu": [2], + # }, + # ) + print("save estimator ok") + cfm.estimator = estimator + export_cfm = torch.jit.script(e_cfm) + export_cfm.save("onnx/ad/cfm.pt") + # sovits.cfm = cfm + # cfm.save("onnx/ad/cfm.pt") + return export_cfm + + +def export(): + sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth") + + init_bigvgan() + + dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt") + raw_t2s 
= get_raw_t2s_model(dict_s1).to(device) + print("#### get_raw_t2s_model ####") + print(raw_t2s.config) + + if is_half: + raw_t2s = raw_t2s.half().to(device) + + t2s_m = T2SModel(raw_t2s) + t2s_m.eval() + script_t2s = torch.jit.script(t2s_m).to(device) + + hps = sovits.hps + ref_wav_path = "onnx/ad/ref.wav" + speed = 1.0 + sample_steps = 32 + dtype = torch.float16 if is_half == True else torch.float32 + refer = get_spepc(hps, ref_wav_path).to(device).to(dtype) + zero_wav = np.zeros( + int(hps.data.sampling_rate * 0.3), + dtype=np.float16 if is_half == True else np.float32, + ) + + with torch.no_grad(): + wav16k, sr = librosa.load(ref_wav_path, sr=16000) + wav16k = torch.from_numpy(wav16k) + zero_wav_torch = torch.from_numpy(zero_wav) + + if is_half == True: + wav16k = wav16k.half().to(device) + zero_wav_torch = zero_wav_torch.half().to(device) + else: + wav16k = wav16k.to(device) + zero_wav_torch = zero_wav_torch.to(device) + wav16k = torch.cat([wav16k, zero_wav_torch]) + ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float() + codes = sovits.vq_model.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + prompt = prompt_semantic.unsqueeze(0).to(device) + + phones1, bert1, norm_text1 = get_phones_and_bert( + "你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "all_zh", "v3" + ) + phones2, bert2, norm_text2 = get_phones_and_bert( + "这是一个简单的示例,真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.", + "auto", + "v3", + ) + phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0) + phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0) + + # codes = sovits.vq_model.extract_latent(ssl_content) + # prompt_semantic = codes[0, 0] + # prompts = prompt_semantic.unsqueeze(0) + + top_k = torch.LongTensor([15]).to(device) + print("topk", top_k) + + bert1 = bert1.T.to(device) + bert2 = bert2.T.to(device) + print( + prompt.dtype, + phoneme_ids0.dtype, + phoneme_ids1.dtype, + bert1.dtype, + bert2.dtype, + top_k.dtype, + ) + print( + prompt.shape, + phoneme_ids0.shape, + phoneme_ids1.shape, + bert1.shape, + bert2.shape, + top_k.shape, + ) + pred_semantic = t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k) + + ge = sovits.vq_model.create_ge(refer) + prompt_ = prompt.unsqueeze(0) + + torch._dynamo.mark_dynamic(prompt_, 2) + torch._dynamo.mark_dynamic(phoneme_ids0, 1) + + fea_ref = sovits.vq_model(prompt_, phoneme_ids0, ge) + + inputs = { + "forward": (prompt_, phoneme_ids0, ge), + "extract_latent": ssl_content, + "create_ge": refer, + } + + trace_vq_model = torch.jit.trace_module(sovits.vq_model, inputs, optimize=True) + trace_vq_model.save("onnx/ad/vq_model.pt") + + print(fea_ref.shape, fea_ref.dtype, ge.shape) + print(prompt_.shape, phoneme_ids0.shape, ge.shape) + + # vq_model = torch.jit.trace( + # sovits.vq_model, + # optimize=True, + # # strict=False, + # example_inputs=(prompt_, phoneme_ids0, ge), + # ) + # vq_model = sovits.vq_model + vq_model = trace_vq_model + + gpt_sovits_half = ExportGPTSovitsHalf(sovits.hps, script_t2s, trace_vq_model) + torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v3_half.pt") + + ref_audio, sr = torchaudio.load(ref_wav_path) + ref_audio = ref_audio.to(device).float() + if ref_audio.shape[0] == 2: + ref_audio = ref_audio.mean(0).unsqueeze(0) + if sr != 24000: + ref_audio = 
resample(ref_audio, sr) + # mel2 = mel_fn(ref_audio) + mel2 = norm_spec(mel_fn(ref_audio)) + T_min = min(mel2.shape[2], fea_ref.shape[2]) + fea_ref = fea_ref[:, :, :T_min] + print("fea_ref:", fea_ref.shape, T_min) + if T_min > 468: + mel2 = mel2[:, :, -468:] + fea_ref = fea_ref[:, :, -468:] + T_min = 468 + chunk_len = 934 - T_min + mel2 = mel2.to(dtype) + + # fea_todo, ge = sovits.vq_model(pred_semantic,y_lengths, phoneme_ids1, ge) + fea_todo = vq_model(pred_semantic, phoneme_ids1, ge) + + cfm_resss = [] + idx = 0 + sample_steps = torch.LongTensor([sample_steps]).to(device) + export_cfm_ = ExportCFM(sovits.cfm) + while 1: + print("idx:", idx) + fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len] + if fea_todo_chunk.shape[-1] == 0: + break + + print( + "export_cfm:", + fea_ref.shape, + fea_todo_chunk.shape, + mel2.shape, + sample_steps.shape, + ) + if idx == 0: + fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1) + export_cfm_ = export_cfm( + export_cfm_, + fea, + torch.LongTensor([fea.size(1)]).to(fea.device), + mel2, + sample_steps, + ) + # torch.onnx.export( + # export_cfm_, + # ( + # fea_ref, + # fea_todo_chunk, + # mel2, + # sample_steps, + # ), + # "onnx/ad/cfm.onnx", + # input_names=["fea_ref", "fea_todo_chunk", "mel2", "sample_steps"], + # output_names=["cfm_res", "fea_ref_", "mel2_"], + # dynamic_axes={ + # "fea_ref": [2], + # "fea_todo_chunk": [2], + # "mel2": [2], + # }, + # ) + + idx += chunk_len + + cfm_res, fea_ref, mel2 = export_cfm_(fea_ref, fea_todo_chunk, mel2, sample_steps) + cfm_resss.append(cfm_res) + continue + + cmf_res = torch.cat(cfm_resss, 2) + cmf_res = denorm_spec(cmf_res).to(device) + print("cmf_res:", cmf_res.shape, cmf_res.dtype) + with torch.inference_mode(): + cmf_res_rand = torch.randn(1, 100, 934).to(device).to(dtype) + torch._dynamo.mark_dynamic(cmf_res_rand, 2) + bigvgan_model_ = torch.jit.trace(bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,)) + bigvgan_model_.save("onnx/ad/bigvgan_model.pt") + wav_gen = bigvgan_model(cmf_res) + print("wav_gen:", wav_gen.shape, wav_gen.dtype) + audio = wav_gen[0][0].cpu().detach().numpy() + + sr = 24000 + soundfile.write("out.export.wav", (audio * 32768).astype(np.int16), sr) + + +from datetime import datetime + + +def test_export( + todo_text, + gpt_sovits_v3_half, + cfm, + bigvgan, + output, +): + # hps = sovits.hps + ref_wav_path = "onnx/ad/ref.wav" + speed = 1.0 + sample_steps = 8 + + dtype = torch.float16 if is_half == True else torch.float32 + + zero_wav = np.zeros( + int(16000 * 0.3), + dtype=np.float16 if is_half == True else np.float32, + ) + + with torch.no_grad(): + wav16k, sr = librosa.load(ref_wav_path, sr=16000) + wav16k = torch.from_numpy(wav16k) + zero_wav_torch = torch.from_numpy(zero_wav) + + if is_half == True: + wav16k = wav16k.half().to(device) + zero_wav_torch = zero_wav_torch.half().to(device) + else: + wav16k = wav16k.to(device) + zero_wav_torch = zero_wav_torch.to(device) + wav16k = torch.cat([wav16k, zero_wav_torch]) + ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float() + + ref_audio_32k, _ = librosa.load(ref_wav_path, sr=32000) + ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float() + + phones1, bert1, norm_text1 = get_phones_and_bert( + "你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "all_zh", "v3" + ) + phones2, bert2, norm_text2 = get_phones_and_bert( + todo_text, + "zh", + "v3", + ) + phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0) + phoneme_ids1 = 
torch.LongTensor(phones2).to(device).unsqueeze(0) + + bert1 = bert1.T.to(device) + bert2 = bert2.T.to(device) + top_k = torch.LongTensor([15]).to(device) + + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + logger.info("start inference %s", current_time) + print( + ssl_content.shape, + ref_audio_32k.shape, + phoneme_ids0.shape, + phoneme_ids1.shape, + bert1.shape, + bert2.shape, + top_k.shape, + ) + fea_ref, fea_todo, mel2 = gpt_sovits_v3_half( + ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k + ) + chunk_len = 934 - fea_ref.shape[2] + print(fea_ref.shape, fea_todo.shape, mel2.shape) + + cfm_resss = [] + sample_steps = torch.LongTensor([sample_steps]) + idx = 0 + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + logger.info("start cfm %s", current_time) + wav_gen_length = fea_todo.shape[2] * 256 + + while 1: + current_time = datetime.now() + print("idx:", idx, current_time.strftime("%Y-%m-%d %H:%M:%S")) + fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len] + if fea_todo_chunk.shape[-1] == 0: + break + + complete_len = chunk_len - fea_todo_chunk.shape[-1] + if complete_len != 0: + fea_todo_chunk = torch.cat([fea_todo_chunk, torch.zeros(1, 512, complete_len).to(device).to(dtype)], 2) + + cfm_res, fea_ref, mel2 = cfm(fea_ref, fea_todo_chunk, mel2, sample_steps) + # if complete_len > 0 : + # cfm_res = cfm_res[:, :, :-complete_len] + # fea_ref = fea_ref[:, :, :-complete_len] + # mel2 = mel2[:, :, :-complete_len] + + idx += chunk_len + + current_time = datetime.now() + print("cfm end", current_time.strftime("%Y-%m-%d %H:%M:%S")) + cfm_res = denorm_spec(cfm_res).to(device) + bigvgan_res = bigvgan(cfm_res) + cfm_resss.append(bigvgan_res) + + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + logger.info("start bigvgan %s", current_time) + wav_gen = torch.cat(cfm_resss, 2) + # cmf_res = denorm_spec(cmf_res) + # cmf_res = cmf_res.to(device) + # print("cmf_res:", cmf_res.shape) + + # cmf_res = torch.cat([cmf_res,torch.zeros([1,100,2000-cmf_res.size(2)],device=device,dtype=cmf_res.dtype)], 2) + + # wav_gen = bigvgan(cmf_res) + print("wav_gen:", wav_gen.shape, wav_gen.dtype) + wav_gen = wav_gen[:, :, :wav_gen_length] + + audio = wav_gen[0][0].cpu().detach().numpy() + logger.info("end bigvgan %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + sr = 24000 + soundfile.write(output, (audio * 32768).astype(np.int16), sr) + + +def test_export1( + todo_text, + gpt_sovits_v3, + output, +): + # hps = sovits.hps + ref_wav_path = "onnx/ad/ref.wav" + speed = 1.0 + sample_steps = torch.LongTensor([16]) + + dtype = torch.float16 if is_half == True else torch.float32 + + zero_wav = np.zeros( + int(24000 * 0.3), + dtype=np.float16 if is_half == True else np.float32, + ) + + with torch.no_grad(): + wav16k, sr = librosa.load(ref_wav_path, sr=16000) + wav16k = torch.from_numpy(wav16k) + zero_wav_torch = torch.from_numpy(zero_wav) + + if is_half == True: + wav16k = wav16k.half().to(device) + zero_wav_torch = zero_wav_torch.half().to(device) + else: + wav16k = wav16k.to(device) + zero_wav_torch = zero_wav_torch.to(device) + wav16k = torch.cat([wav16k, zero_wav_torch]) + ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float() + print("ssl_content:", ssl_content.shape, ssl_content.dtype) + + ref_audio_32k, _ = librosa.load(ref_wav_path, sr=32000) + ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float() + + phones1, bert1, norm_text1 = get_phones_and_bert( + "你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", 
"all_zh", "v3" + ) + phones2, bert2, norm_text2 = get_phones_and_bert( + todo_text, + "zh", + "v3", + ) + phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0) + phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0) + + bert1 = bert1.T.to(device) + bert2 = bert2.T.to(device) + top_k = torch.LongTensor([15]).to(device) + + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + logger.info("start inference %s", current_time) + print( + ssl_content.shape, + ref_audio_32k.shape, + phoneme_ids0.shape, + phoneme_ids1.shape, + bert1.shape, + bert2.shape, + top_k.shape, + ) + wav_gen = gpt_sovits_v3(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps) + print("wav_gen:", wav_gen.shape, wav_gen.dtype) + + wav_gen = torch.cat([wav_gen, zero_wav_torch], 0) + + audio = wav_gen.cpu().detach().numpy() + logger.info("end bigvgan %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + sr = 24000 + soundfile.write(output, (audio * 32768).astype(np.int16), sr) + + +import time + + +def test_(): + sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth") + + # cfm = ExportCFM(sovits.cfm) + # cfm.cfm.estimator = dit + sovits.cfm = None + + cfm = torch.jit.load("onnx/ad/cfm.pt", map_location=device) + # cfm = torch.jit.optimize_for_inference(cfm) + cfm = cfm.half().to(device) + + cfm.eval() + + logger.info("cfm ok") + + dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt") + # v2 的 gpt 也可以用 + # dict_s1 = torch.load("GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt") + raw_t2s = get_raw_t2s_model(dict_s1).to(device) + print("#### get_raw_t2s_model ####") + print(raw_t2s.config) + if is_half: + raw_t2s = raw_t2s.half().to(device) + t2s_m = T2SModel(raw_t2s).half().to(device) + t2s_m.eval() + t2s_m = torch.jit.script(t2s_m) + t2s_m.eval() + # t2s_m.top_k = 15 + logger.info("t2s_m ok") + + vq_model: torch.jit.ScriptModule = torch.jit.load("onnx/ad/vq_model.pt", map_location=device) + # vq_model = torch.jit.optimize_for_inference(vq_model) + # vq_model = vq_model.half().to(device) + vq_model.eval() + # vq_model = sovits.vq_model + logger.info("vq_model ok") + + # gpt_sovits_v3_half = torch.jit.load("onnx/ad/gpt_sovits_v3_half.pt") + # gpt_sovits_v3_half = torch.jit.optimize_for_inference(gpt_sovits_v3_half) + # gpt_sovits_v3_half = gpt_sovits_v3_half.half() + # gpt_sovits_v3_half = gpt_sovits_v3_half.cuda() + # gpt_sovits_v3_half.eval() + gpt_sovits_v3_half = ExportGPTSovitsHalf(sovits.hps, t2s_m, vq_model) + logger.info("gpt_sovits_v3_half ok") + + # init_bigvgan() + # global bigvgan_model + bigvgan_model = torch.jit.load("onnx/ad/bigvgan_model.pt") + # bigvgan_model = torch.jit.optimize_for_inference(bigvgan_model) + bigvgan_model = bigvgan_model.half() + bigvgan_model = bigvgan_model.cuda() + bigvgan_model.eval() + + logger.info("bigvgan ok") + + gpt_sovits_v3 = GPTSoVITSV3(gpt_sovits_v3_half, cfm, bigvgan_model) + gpt_sovits_v3 = torch.jit.script(gpt_sovits_v3) + gpt_sovits_v3.save("onnx/ad/gpt_sovits_v3.pt") + gpt_sovits_v3 = gpt_sovits_v3.half().to(device) + gpt_sovits_v3.eval() + print("save gpt_sovits_v3 ok") + + time.sleep(5) + # print("thread:", torch.get_num_threads()) + # print("thread:", torch.get_num_interop_threads()) + # torch.set_num_interop_threads(1) + # torch.set_num_threads(1) + + test_export1( + "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. 
He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....", + gpt_sovits_v3, + "out.wav", + ) + + test_export1( + "你小子是什么来路.汗流浃背了呀!老弟~ My uncle has two dogs. He is very happy with them. 最后还是我得了 MVP!", + gpt_sovits_v3, + "out2.wav", + ) + + # test_export( + # "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP. 哈哈哈...", + # gpt_sovits_v3_half, + # cfm, + # bigvgan_model, + # "out2.wav", + # ) + + +def test_export_gpt_sovits_v3(): + gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt", map_location=device) + # test_export1( + # "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....", + # gpt_sovits_v3, + # "out3.wav", + # ) + # test_export1( + # "你小子是什么来路.汗流浃背了呀!老弟~ My uncle has two dogs. He is very happy with them. 最后还是我得了 MVP!", + # gpt_sovits_v3, + # "out4.wav", + # ) + test_export1( + "风萧萧兮易水寒,壮士一去兮不复还.", + gpt_sovits_v3, + "out5.wav", + ) + + +with torch.no_grad(): + # export() + test_() + # test_export_gpt_sovits_v3() diff --git a/GPT_SoVITS/inference_cli.py b/GPT_SoVITS/inference_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..459a3d3632f599768465c16f6d889f47af5fe271 --- /dev/null +++ b/GPT_SoVITS/inference_cli.py @@ -0,0 +1,86 @@ +import argparse +import os +import soundfile as sf + +from tools.i18n.i18n import I18nAuto +from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav + +i18n = I18nAuto() + + +def synthesize( + GPT_model_path, + SoVITS_model_path, + ref_audio_path, + ref_text_path, + ref_language, + target_text_path, + target_language, + output_path, +): + # Read reference text + with open(ref_text_path, "r", encoding="utf-8") as file: + ref_text = file.read() + + # Read target text + with open(target_text_path, "r", encoding="utf-8") as file: + target_text = file.read() + + # Change model weights + change_gpt_weights(gpt_path=GPT_model_path) + change_sovits_weights(sovits_path=SoVITS_model_path) + + # Synthesize audio + synthesis_result = get_tts_wav( + ref_wav_path=ref_audio_path, + prompt_text=ref_text, + prompt_language=i18n(ref_language), + text=target_text, + text_language=i18n(target_language), + top_p=1, + temperature=1, + ) + + result_list = list(synthesis_result) + + if result_list: + last_sampling_rate, last_audio_data = result_list[-1] + output_wav_path = os.path.join(output_path, "output.wav") + sf.write(output_wav_path, last_audio_data, last_sampling_rate) + print(f"Audio saved to {output_wav_path}") + + +def main(): + parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool") + parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file") + parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file") + parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file") + parser.add_argument("--ref_text", required=True, help="Path to the reference text file") + parser.add_argument( + "--ref_language", required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio" + ) + parser.add_argument("--target_text", required=True, help="Path to the target text file") + parser.add_argument( + "--target_language", + 
required=True, + choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"], + help="Language of the target text", + ) + parser.add_argument("--output_path", required=True, help="Path to the output directory") + + args = parser.parse_args() + + synthesize( + args.gpt_model, + args.sovits_model, + args.ref_audio, + args.ref_text, + args.ref_language, + args.target_text, + args.target_language, + args.output_path, + ) + + +if __name__ == "__main__": + main() diff --git a/GPT_SoVITS/inference_gui.py b/GPT_SoVITS/inference_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..379f7fa8cdb32b4b56db8b242717c23bdb51eca0 --- /dev/null +++ b/GPT_SoVITS/inference_gui.py @@ -0,0 +1,316 @@ +import os +import sys +from PyQt5.QtCore import QEvent +from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QLineEdit, QPushButton, QTextEdit +from PyQt5.QtWidgets import QGridLayout, QVBoxLayout, QWidget, QFileDialog, QStatusBar, QComboBox +import soundfile as sf + +from tools.i18n.i18n import I18nAuto + +i18n = I18nAuto() + +from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav + + +class GPTSoVITSGUI(QMainWindow): + GPT_Path = gpt_path + SoVITS_Path = sovits_path + + def __init__(self): + super().__init__() + + self.setWindowTitle("GPT-SoVITS GUI") + self.setGeometry(800, 450, 950, 850) + + self.setStyleSheet(""" + QWidget { + background-color: #a3d3b1; + } + + QTabWidget::pane { + background-color: #a3d3b1; + } + + QTabWidget::tab-bar { + alignment: left; + } + + QTabBar::tab { + background: #8da4bf; + color: #ffffff; + padding: 8px; + } + + QTabBar::tab:selected { + background: #2a3f54; + } + + QLabel { + color: #000000; + } + + QPushButton { + background-color: #4CAF50; + color: white; + padding: 8px; + border: 1px solid #4CAF50; + border-radius: 4px; + } + + QPushButton:hover { + background-color: #45a049; + border: 1px solid #45a049; + box-shadow: 2px 2px 2px rgba(0, 0, 0, 0.1); + } + """) + + license_text = ( + "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. " + "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE." 
+ ) + license_label = QLabel(license_text) + license_label.setWordWrap(True) + + self.GPT_model_label = QLabel("选择GPT模型:") + self.GPT_model_input = QLineEdit() + self.GPT_model_input.setPlaceholderText("拖拽或选择文件") + self.GPT_model_input.setText(self.GPT_Path) + self.GPT_model_input.setReadOnly(True) + self.GPT_model_button = QPushButton("选择GPT模型文件") + self.GPT_model_button.clicked.connect(self.select_GPT_model) + + self.SoVITS_model_label = QLabel("选择SoVITS模型:") + self.SoVITS_model_input = QLineEdit() + self.SoVITS_model_input.setPlaceholderText("拖拽或选择文件") + self.SoVITS_model_input.setText(self.SoVITS_Path) + self.SoVITS_model_input.setReadOnly(True) + self.SoVITS_model_button = QPushButton("选择SoVITS模型文件") + self.SoVITS_model_button.clicked.connect(self.select_SoVITS_model) + + self.ref_audio_label = QLabel("上传参考音频:") + self.ref_audio_input = QLineEdit() + self.ref_audio_input.setPlaceholderText("拖拽或选择文件") + self.ref_audio_input.setReadOnly(True) + self.ref_audio_button = QPushButton("选择音频文件") + self.ref_audio_button.clicked.connect(self.select_ref_audio) + + self.ref_text_label = QLabel("参考音频文本:") + self.ref_text_input = QLineEdit() + self.ref_text_input.setPlaceholderText("直接输入文字或上传文本") + self.ref_text_button = QPushButton("上传文本") + self.ref_text_button.clicked.connect(self.upload_ref_text) + + self.ref_language_label = QLabel("参考音频语言:") + self.ref_language_combobox = QComboBox() + self.ref_language_combobox.addItems(["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"]) + self.ref_language_combobox.setCurrentText("多语种混合") + + self.target_text_label = QLabel("合成目标文本:") + self.target_text_input = QLineEdit() + self.target_text_input.setPlaceholderText("直接输入文字或上传文本") + self.target_text_button = QPushButton("上传文本") + self.target_text_button.clicked.connect(self.upload_target_text) + + self.target_language_label = QLabel("合成音频语言:") + self.target_language_combobox = QComboBox() + self.target_language_combobox.addItems(["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"]) + self.target_language_combobox.setCurrentText("多语种混合") + + self.output_label = QLabel("输出音频路径:") + self.output_input = QLineEdit() + self.output_input.setPlaceholderText("拖拽或选择文件") + self.output_input.setReadOnly(True) + self.output_button = QPushButton("选择文件夹") + self.output_button.clicked.connect(self.select_output_path) + + self.output_text = QTextEdit() + self.output_text.setReadOnly(True) + + self.add_drag_drop_events( + [ + self.GPT_model_input, + self.SoVITS_model_input, + self.ref_audio_input, + self.ref_text_input, + self.target_text_input, + self.output_input, + ] + ) + + self.synthesize_button = QPushButton("合成") + self.synthesize_button.clicked.connect(self.synthesize) + + self.clear_output_button = QPushButton("清空输出") + self.clear_output_button.clicked.connect(self.clear_output) + + self.status_bar = QStatusBar() + + main_layout = QVBoxLayout() + + input_layout = QGridLayout(self) + input_layout.setSpacing(10) + + input_layout.addWidget(license_label, 0, 0, 1, 3) + + input_layout.addWidget(self.GPT_model_label, 1, 0) + input_layout.addWidget(self.GPT_model_input, 2, 0, 1, 2) + input_layout.addWidget(self.GPT_model_button, 2, 2) + + input_layout.addWidget(self.SoVITS_model_label, 3, 0) + input_layout.addWidget(self.SoVITS_model_input, 4, 0, 1, 2) + input_layout.addWidget(self.SoVITS_model_button, 4, 2) + + input_layout.addWidget(self.ref_audio_label, 5, 0) + input_layout.addWidget(self.ref_audio_input, 6, 0, 1, 2) + input_layout.addWidget(self.ref_audio_button, 6, 2) + + input_layout.addWidget(self.ref_language_label, 7, 0) + 
input_layout.addWidget(self.ref_language_combobox, 8, 0, 1, 1) + input_layout.addWidget(self.ref_text_label, 9, 0) + input_layout.addWidget(self.ref_text_input, 10, 0, 1, 2) + input_layout.addWidget(self.ref_text_button, 10, 2) + + input_layout.addWidget(self.target_language_label, 11, 0) + input_layout.addWidget(self.target_language_combobox, 12, 0, 1, 1) + input_layout.addWidget(self.target_text_label, 13, 0) + input_layout.addWidget(self.target_text_input, 14, 0, 1, 2) + input_layout.addWidget(self.target_text_button, 14, 2) + + input_layout.addWidget(self.output_label, 15, 0) + input_layout.addWidget(self.output_input, 16, 0, 1, 2) + input_layout.addWidget(self.output_button, 16, 2) + + main_layout.addLayout(input_layout) + + output_layout = QVBoxLayout() + output_layout.addWidget(self.output_text) + main_layout.addLayout(output_layout) + + main_layout.addWidget(self.synthesize_button) + + main_layout.addWidget(self.clear_output_button) + + main_layout.addWidget(self.status_bar) + + self.central_widget = QWidget() + self.central_widget.setLayout(main_layout) + self.setCentralWidget(self.central_widget) + + def dragEnterEvent(self, event): + if event.mimeData().hasUrls(): + event.acceptProposedAction() + + def dropEvent(self, event): + if event.mimeData().hasUrls(): + file_paths = [url.toLocalFile() for url in event.mimeData().urls()] + if len(file_paths) == 1: + self.update_ref_audio(file_paths[0]) + else: + self.update_ref_audio(", ".join(file_paths)) + + def add_drag_drop_events(self, widgets): + for widget in widgets: + widget.setAcceptDrops(True) + widget.installEventFilter(self) + + def eventFilter(self, obj, event): + if event.type() in (QEvent.DragEnter, QEvent.Drop): + mime_data = event.mimeData() + if mime_data.hasUrls(): + event.acceptProposedAction() + + return super().eventFilter(obj, event) + + def select_GPT_model(self): + file_path, _ = QFileDialog.getOpenFileName(self, "选择GPT模型文件", "", "GPT Files (*.ckpt)") + if file_path: + self.GPT_model_input.setText(file_path) + + def select_SoVITS_model(self): + file_path, _ = QFileDialog.getOpenFileName(self, "选择SoVITS模型文件", "", "SoVITS Files (*.pth)") + if file_path: + self.SoVITS_model_input.setText(file_path) + + def select_ref_audio(self): + file_path, _ = QFileDialog.getOpenFileName(self, "选择参考音频文件", "", "Audio Files (*.wav *.mp3)") + if file_path: + self.update_ref_audio(file_path) + + def upload_ref_text(self): + file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)") + if file_path: + with open(file_path, "r", encoding="utf-8") as file: + content = file.read() + self.ref_text_input.setText(content) + + def upload_target_text(self): + file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)") + if file_path: + with open(file_path, "r", encoding="utf-8") as file: + content = file.read() + self.target_text_input.setText(content) + + def select_output_path(self): + options = QFileDialog.Options() + options |= QFileDialog.DontUseNativeDialog + options |= QFileDialog.ShowDirsOnly + + folder_dialog = QFileDialog() + folder_dialog.setOptions(options) + folder_dialog.setFileMode(QFileDialog.Directory) + + if folder_dialog.exec_(): + folder_path = folder_dialog.selectedFiles()[0] + self.output_input.setText(folder_path) + + def update_ref_audio(self, file_path): + self.ref_audio_input.setText(file_path) + + def clear_output(self): + self.output_text.clear() + + def synthesize(self): + GPT_model_path = self.GPT_model_input.text() + SoVITS_model_path = 
self.SoVITS_model_input.text() + ref_audio_path = self.ref_audio_input.text() + language_combobox = self.ref_language_combobox.currentText() + language_combobox = i18n(language_combobox) + ref_text = self.ref_text_input.text() + target_language_combobox = self.target_language_combobox.currentText() + target_language_combobox = i18n(target_language_combobox) + target_text = self.target_text_input.text() + output_path = self.output_input.text() + + if GPT_model_path != self.GPT_Path: + change_gpt_weights(gpt_path=GPT_model_path) + self.GPT_Path = GPT_model_path + if SoVITS_model_path != self.SoVITS_Path: + change_sovits_weights(sovits_path=SoVITS_model_path) + self.SoVITS_Path = SoVITS_model_path + + synthesis_result = get_tts_wav( + ref_wav_path=ref_audio_path, + prompt_text=ref_text, + prompt_language=language_combobox, + text=target_text, + text_language=target_language_combobox, + ) + + result_list = list(synthesis_result) + + if result_list: + last_sampling_rate, last_audio_data = result_list[-1] + output_wav_path = os.path.join(output_path, "output.wav") + sf.write(output_wav_path, last_audio_data, last_sampling_rate) + + result = "Audio saved to " + output_wav_path + + self.status_bar.showMessage("合成完成!输出路径:" + output_wav_path, 5000) + self.output_text.append("处理结果:\n" + result) + + +if __name__ == "__main__": + app = QApplication(sys.argv) + mainWin = GPTSoVITSGUI() + mainWin.show() + sys.exit(app.exec_()) diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py new file mode 100644 index 0000000000000000000000000000000000000000..68648014207e393f239fa041223a0935111280ac --- /dev/null +++ b/GPT_SoVITS/inference_webui.py @@ -0,0 +1,1280 @@ +""" +按中英混合识别 +按日英混合识别 +多语种启动切分识别语种 +全部按中文识别 +全部按英文识别 +全部按日文识别 +""" + +import logging +import traceback +import warnings + +import torchaudio + +logging.getLogger("markdown_it").setLevel(logging.ERROR) +logging.getLogger("urllib3").setLevel(logging.ERROR) +logging.getLogger("httpcore").setLevel(logging.ERROR) +logging.getLogger("httpx").setLevel(logging.ERROR) +logging.getLogger("asyncio").setLevel(logging.ERROR) +logging.getLogger("charset_normalizer").setLevel(logging.ERROR) +logging.getLogger("torchaudio._extension").setLevel(logging.ERROR) +logging.getLogger("multipart.multipart").setLevel(logging.ERROR) +warnings.simplefilter(action="ignore", category=FutureWarning) + +import json +import os +import re +import sys + +import torch +from text.LangSegmenter import LangSegmenter + +try: + import gradio.analytics as analytics + + analytics.version_check = lambda: None +except: + ... 
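+# The runtime configuration below is environment-variable driven: "version" selects the
+# model generation, while "gpt_path"/"sovits_path" override the checkpoint paths; when no
+# override is given, the entries recorded in ./weight.json (if present) are used before
+# falling back to the pretrained checkpoints listed below.
+# A minimal, illustrative launch (the weight file names here are assumptions, not shipped
+# defaults):
+#   version=v2 gpt_path=GPT_weights_v2/my_model.ckpt sovits_path=SoVITS_weights_v2/my_model.pth \
+#   python GPT_SoVITS/inference_webui.py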
+version = model_version = os.environ.get("version", "v2") +path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth" +path_sovits_v4 = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth" +is_exist_s2gv3 = os.path.exists(path_sovits_v3) +is_exist_s2gv4 = os.path.exists(path_sovits_v4) +pretrained_sovits_name = [ + "GPT_SoVITS/pretrained_models/s2G488k.pth", + "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", + "GPT_SoVITS/pretrained_models/s2Gv3.pth", + "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth", +] +pretrained_gpt_name = [ + "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", + "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", + "GPT_SoVITS/pretrained_models/s1v3.ckpt", + "GPT_SoVITS/pretrained_models/s1v3.ckpt", +] + + +_ = [[], []] +for i in range(4): + if os.path.exists(pretrained_gpt_name[i]): + _[0].append(pretrained_gpt_name[i]) + if os.path.exists(pretrained_sovits_name[i]): + _[-1].append(pretrained_sovits_name[i]) +pretrained_gpt_name, pretrained_sovits_name = _ + + +if os.path.exists("./weight.json"): + pass +else: + with open("./weight.json", "w", encoding="utf-8") as file: + json.dump({"GPT": {}, "SoVITS": {}}, file) + +with open("./weight.json", "r", encoding="utf-8") as file: + weight_data = file.read() + weight_data = json.loads(weight_data) + gpt_path = os.environ.get("gpt_path", weight_data.get("GPT", {}).get(version, pretrained_gpt_name)) + sovits_path = os.environ.get("sovits_path", weight_data.get("SoVITS", {}).get(version, pretrained_sovits_name)) + if isinstance(gpt_path, list): + gpt_path = gpt_path[0] + if isinstance(sovits_path, list): + sovits_path = sovits_path[0] + +# gpt_path = os.environ.get( +# "gpt_path", pretrained_gpt_name +# ) +# sovits_path = os.environ.get("sovits_path", pretrained_sovits_name) +cnhubert_base_path = os.environ.get("cnhubert_base_path", "GPT_SoVITS/pretrained_models/chinese-hubert-base") +bert_path = os.environ.get("bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large") +infer_ttswebui = os.environ.get("infer_ttswebui", 9872) +infer_ttswebui = int(infer_ttswebui) +is_share = os.environ.get("is_share", "False") +is_share = eval(is_share) +if "_CUDA_VISIBLE_DEVICES" in os.environ: + os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] +is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available() +# is_half=False +punctuation = set(["!", "?", "…", ",", ".", "-", " "]) +import gradio as gr +import librosa +import numpy as np +from feature_extractor import cnhubert +from transformers import AutoModelForMaskedLM, AutoTokenizer + +cnhubert.cnhubert_base_path = cnhubert_base_path + +import random + +from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3,Generator + + +def set_seed(seed): + if seed == -1: + seed = random.randint(0, 1000000) + seed = int(seed) + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + +# set_seed(42) + +from time import time as ttime + +from AR.models.t2s_lightning_module import Text2SemanticLightningModule +from peft import LoraConfig, get_peft_model +from text import cleaned_text_to_sequence +from text.cleaner import clean_text + +from tools.i18n.i18n import I18nAuto, scan_language_list + +language = os.environ.get("language", "Auto") +language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language +i18n = 
I18nAuto(language=language) + +# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。 + +if torch.cuda.is_available(): + device = "cuda" +else: + device = "cpu" + +dict_language_v1 = { + i18n("中文"): "all_zh", # 全部按中文识别 + i18n("英文"): "en", # 全部按英文识别#######不变 + i18n("日文"): "all_ja", # 全部按日文识别 + i18n("中英混合"): "zh", # 按中英混合识别####不变 + i18n("日英混合"): "ja", # 按日英混合识别####不变 + i18n("多语种混合"): "auto", # 多语种启动切分识别语种 +} +dict_language_v2 = { + i18n("中文"): "all_zh", # 全部按中文识别 + i18n("英文"): "en", # 全部按英文识别#######不变 + i18n("日文"): "all_ja", # 全部按日文识别 + i18n("粤语"): "all_yue", # 全部按中文识别 + i18n("韩文"): "all_ko", # 全部按韩文识别 + i18n("中英混合"): "zh", # 按中英混合识别####不变 + i18n("日英混合"): "ja", # 按日英混合识别####不变 + i18n("粤英混合"): "yue", # 按粤英混合识别####不变 + i18n("韩英混合"): "ko", # 按韩英混合识别####不变 + i18n("多语种混合"): "auto", # 多语种启动切分识别语种 + i18n("多语种混合(粤语)"): "auto_yue", # 多语种启动切分识别语种 +} +dict_language = dict_language_v1 if version == "v1" else dict_language_v2 + +tokenizer = AutoTokenizer.from_pretrained(bert_path) +bert_model = AutoModelForMaskedLM.from_pretrained(bert_path) +if is_half == True: + bert_model = bert_model.half().to(device) +else: + bert_model = bert_model.to(device) + + +def get_bert_feature(text, word2ph): + with torch.no_grad(): + inputs = tokenizer(text, return_tensors="pt") + for i in inputs: + inputs[i] = inputs[i].to(device) + res = bert_model(**inputs, output_hidden_states=True) + res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + assert len(word2ph) == len(text) + phone_level_feature = [] + for i in range(len(word2ph)): + repeat_feature = res[i].repeat(word2ph[i], 1) + phone_level_feature.append(repeat_feature) + phone_level_feature = torch.cat(phone_level_feature, dim=0) + return phone_level_feature.T + + +class DictToAttrRecursive(dict): + def __init__(self, input_dict): + super().__init__(input_dict) + for key, value in input_dict.items(): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + self[key] = value + setattr(self, key, value) + + def __getattr__(self, item): + try: + return self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + def __setattr__(self, key, value): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + super(DictToAttrRecursive, self).__setitem__(key, value) + super().__setattr__(key, value) + + def __delattr__(self, item): + try: + del self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + +ssl_model = cnhubert.get_model() +if is_half == True: + ssl_model = ssl_model.half().to(device) +else: + ssl_model = ssl_model.to(device) + +resample_transform_dict = {} + + +def resample(audio_tensor, sr0,sr1): + global resample_transform_dict + key="%s-%s"%(sr0,sr1) + if key not in resample_transform_dict: + resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device) + return resample_transform_dict[key](audio_tensor) + + +###todo:put them to process_ckpt and modify my_save func (save sovits weights), gpt save weights use my_save in process_ckpt +# symbol_version-model_version-if_lora_v3 +from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new + +v3v4set={"v3","v4"} +def change_sovits_weights(sovits_path, prompt_language=None, text_language=None): + global vq_model, hps, version, model_version, dict_language, if_lora_v3 + version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path) + print(sovits_path,version, model_version, if_lora_v3) + is_exist=is_exist_s2gv3 if model_version=="v3"else is_exist_s2gv4 + if 
if_lora_v3 == True and is_exist == False: + info = "GPT_SoVITS/pretrained_models/s2Gv3.pth" + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重") + gr.Warning(info) + raise FileExistsError(info) + dict_language = dict_language_v1 if version == "v1" else dict_language_v2 + if prompt_language is not None and text_language is not None: + if prompt_language in list(dict_language.keys()): + prompt_text_update, prompt_language_update = ( + {"__type__": "update"}, + {"__type__": "update", "value": prompt_language}, + ) + else: + prompt_text_update = {"__type__": "update", "value": ""} + prompt_language_update = {"__type__": "update", "value": i18n("中文")} + if text_language in list(dict_language.keys()): + text_update, text_language_update = {"__type__": "update"}, {"__type__": "update", "value": text_language} + else: + text_update = {"__type__": "update", "value": ""} + text_language_update = {"__type__": "update", "value": i18n("中文")} + if model_version in v3v4set: + visible_sample_steps = True + visible_inp_refs = False + else: + visible_sample_steps = False + visible_inp_refs = True + yield ( + {"__type__": "update", "choices": list(dict_language.keys())}, + {"__type__": "update", "choices": list(dict_language.keys())}, + prompt_text_update, + prompt_language_update, + text_update, + text_language_update, + {"__type__": "update", "visible": visible_sample_steps, "value": 32 if model_version=="v3"else 8,"choices":[4, 8, 16, 32,64,128]if model_version=="v3"else [4, 8, 16, 32]}, + {"__type__": "update", "visible": visible_inp_refs}, + {"__type__": "update", "value": False, "interactive": True if model_version not in v3v4set else False}, + {"__type__": "update", "visible": True if model_version =="v3" else False}, + {"__type__": "update", "value": i18n("模型加载中,请等待"), "interactive": False}, + ) + + dict_s2 = load_sovits_new(sovits_path) + hps = dict_s2["config"] + hps = DictToAttrRecursive(hps) + hps.model.semantic_frame_rate = "25hz" + if "enc_p.text_embedding.weight" not in dict_s2["weight"]: + hps.model.version = "v2" # v3model,v2sybomls + elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322: + hps.model.version = "v1" + else: + hps.model.version = "v2" + version = hps.model.version + # print("sovits版本:",hps.model.version) + if model_version not in v3v4set: + vq_model = SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ) + model_version = version + else: + vq_model = SynthesizerTrnV3( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ) + if "pretrained" not in sovits_path: + try: + del vq_model.enc_q + except: + pass + if is_half == True: + vq_model = vq_model.half().to(device) + else: + vq_model = vq_model.to(device) + vq_model.eval() + if if_lora_v3 == False: + print("loading sovits_%s" % model_version, vq_model.load_state_dict(dict_s2["weight"], strict=False)) + else: + path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4 + print( + "loading sovits_%spretrained_G"%model_version, + vq_model.load_state_dict(load_sovits_new(path_sovits)["weight"], strict=False), + ) + lora_rank = dict_s2["lora_rank"] + lora_config = LoraConfig( + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + r=lora_rank, + lora_alpha=lora_rank, + init_lora_weights=True, + ) + vq_model.cfm = get_peft_model(vq_model.cfm, lora_config) + print("loading sovits_%s_lora%s" % (model_version,lora_rank)) + 
vq_model.load_state_dict(dict_s2["weight"], strict=False) + vq_model.cfm = vq_model.cfm.merge_and_unload() + # torch.save(vq_model.state_dict(),"merge_win.pth") + vq_model.eval() + + yield ( + {"__type__": "update", "choices": list(dict_language.keys())}, + {"__type__": "update", "choices": list(dict_language.keys())}, + prompt_text_update, + prompt_language_update, + text_update, + text_language_update, + {"__type__": "update", "visible": visible_sample_steps, "value":32 if model_version=="v3"else 8,"choices":[4, 8, 16, 32,64,128]if model_version=="v3"else [4, 8, 16, 32]}, + {"__type__": "update", "visible": visible_inp_refs}, + {"__type__": "update", "value": False, "interactive": True if model_version not in v3v4set else False}, + {"__type__": "update", "visible": True if model_version =="v3" else False}, + {"__type__": "update", "value": i18n("合成语音"), "interactive": True}, + ) + with open("./weight.json") as f: + data = f.read() + data = json.loads(data) + data["SoVITS"][version] = sovits_path + with open("./weight.json", "w") as f: + f.write(json.dumps(data)) + + +try: + next(change_sovits_weights(sovits_path)) +except: + pass + + +def change_gpt_weights(gpt_path): + global hz, max_sec, t2s_model, config + hz = 50 + dict_s1 = torch.load(gpt_path, map_location="cpu") + config = dict_s1["config"] + max_sec = config["data"]["max_sec"] + t2s_model = Text2SemanticLightningModule(config, "****", is_train=False) + t2s_model.load_state_dict(dict_s1["weight"]) + if is_half == True: + t2s_model = t2s_model.half() + t2s_model = t2s_model.to(device) + t2s_model.eval() + # total = sum([param.nelement() for param in t2s_model.parameters()]) + # print("Number of parameter: %.2fM" % (total / 1e6)) + with open("./weight.json") as f: + data = f.read() + data = json.loads(data) + data["GPT"][version] = gpt_path + with open("./weight.json", "w") as f: + f.write(json.dumps(data)) + + +change_gpt_weights(gpt_path) +os.environ["HF_ENDPOINT"] = "https://hf-mirror.com" +import torch + +now_dir = os.getcwd() + + +def init_bigvgan(): + global bigvgan_model,hifigan_model + from BigVGAN import bigvgan + + bigvgan_model = bigvgan.BigVGAN.from_pretrained( + "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), + use_cuda_kernel=False, + ) # if True, RuntimeError: Ninja is required to load C++ extensions + # remove weight norm in the model and set to eval mode + bigvgan_model.remove_weight_norm() + bigvgan_model = bigvgan_model.eval() + if hifigan_model: + hifigan_model=hifigan_model.cpu() + hifigan_model=None + try:torch.cuda.empty_cache() + except:pass + if is_half == True: + bigvgan_model = bigvgan_model.half().to(device) + else: + bigvgan_model = bigvgan_model.to(device) + +def init_hifigan(): + global hifigan_model,bigvgan_model + hifigan_model = Generator( + initial_channel=100, + resblock="1", + resblock_kernel_sizes=[3, 7, 11], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_rates=[10, 6, 2, 2, 2], + upsample_initial_channel=512, + upsample_kernel_sizes=[20, 12, 4, 4, 4], + gin_channels=0, is_bias=True + ) + hifigan_model.eval() + hifigan_model.remove_weight_norm() + state_dict_g = torch.load("%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu") + print("loading vocoder",hifigan_model.load_state_dict(state_dict_g)) + if bigvgan_model: + bigvgan_model=bigvgan_model.cpu() + bigvgan_model=None + try:torch.cuda.empty_cache() + except:pass + if is_half == True: + hifigan_model = 
hifigan_model.half().to(device) + else: + hifigan_model = hifigan_model.to(device) + +bigvgan_model=hifigan_model=None +if model_version=="v3": + init_bigvgan() +if model_version=="v4": + init_hifigan() + + +def get_spepc(hps, filename): + # audio = load_audio(filename, int(hps.data.sampling_rate)) + audio, sampling_rate = librosa.load(filename, sr=int(hps.data.sampling_rate)) + audio = torch.FloatTensor(audio) + maxx = audio.abs().max() + if maxx > 1: + audio /= min(2, maxx) + audio_norm = audio + audio_norm = audio_norm.unsqueeze(0) + spec = spectrogram_torch( + audio_norm, + hps.data.filter_length, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + center=False, + ) + return spec + + +def clean_text_inf(text, language, version): + language = language.replace("all_", "") + phones, word2ph, norm_text = clean_text(text, language, version) + phones = cleaned_text_to_sequence(phones, version) + return phones, word2ph, norm_text + + +dtype = torch.float16 if is_half == True else torch.float32 + + +def get_bert_inf(phones, word2ph, norm_text, language): + language = language.replace("all_", "") + if language == "zh": + bert = get_bert_feature(norm_text, word2ph).to(device) # .to(dtype) + else: + bert = torch.zeros( + (1024, len(phones)), + dtype=torch.float16 if is_half == True else torch.float32, + ).to(device) + + return bert + + +splits = { + ",", + "。", + "?", + "!", + ",", + ".", + "?", + "!", + "~", + ":", + ":", + "—", + "…", +} + + +def get_first(text): + pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]" + text = re.split(pattern, text)[0].strip() + return text + + +from text import chinese + + +def get_phones_and_bert(text, language, version, final=False): + if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: + formattext = text + while " " in formattext: + formattext = formattext.replace(" ", " ") + if language == "all_zh": + if re.search(r"[A-Za-z]", formattext): + formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext) + formattext = chinese.mix_text_normalize(formattext) + return get_phones_and_bert(formattext, "zh", version) + else: + phones, word2ph, norm_text = clean_text_inf(formattext, language, version) + bert = get_bert_feature(norm_text, word2ph).to(device) + elif language == "all_yue" and re.search(r"[A-Za-z]", formattext): + formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext) + formattext = chinese.mix_text_normalize(formattext) + return get_phones_and_bert(formattext, "yue", version) + else: + phones, word2ph, norm_text = clean_text_inf(formattext, language, version) + bert = torch.zeros( + (1024, len(phones)), + dtype=torch.float16 if is_half == True else torch.float32, + ).to(device) + elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}: + textlist = [] + langlist = [] + if language == "auto": + for tmp in LangSegmenter.getTexts(text): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "auto_yue": + for tmp in LangSegmenter.getTexts(text): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + else: + for tmp in LangSegmenter.getTexts(text): + if tmp["lang"] == "en": + langlist.append(tmp["lang"]) + else: + # 因无法区别中日韩文汉字,以用户输入为准 + langlist.append(language) + textlist.append(tmp["text"]) + print(textlist) + print(langlist) + phones_list = [] + bert_list = [] + norm_text_list = [] + for i in range(len(textlist)): + lang = langlist[i] + phones, word2ph, norm_text = 
clean_text_inf(textlist[i], lang, version) + bert = get_bert_inf(phones, word2ph, norm_text, lang) + phones_list.append(phones) + norm_text_list.append(norm_text) + bert_list.append(bert) + bert = torch.cat(bert_list, dim=1) + phones = sum(phones_list, []) + norm_text = "".join(norm_text_list) + + if not final and len(phones) < 6: + return get_phones_and_bert("." + text, language, version, final=True) + + return phones, bert.to(dtype), norm_text + + +from module.mel_processing import mel_spectrogram_torch, spectrogram_torch + +spec_min = -12 +spec_max = 2 + + +def norm_spec(x): + return (x - spec_min) / (spec_max - spec_min) * 2 - 1 + + +def denorm_spec(x): + return (x + 1) / 2 * (spec_max - spec_min) + spec_min + + +mel_fn = lambda x: mel_spectrogram_torch( + x, + **{ + "n_fft": 1024, + "win_size": 1024, + "hop_size": 256, + "num_mels": 100, + "sampling_rate": 24000, + "fmin": 0, + "fmax": None, + "center": False, + }, +) +mel_fn_v4 = lambda x: mel_spectrogram_torch( + x, + **{ + "n_fft": 1280, + "win_size": 1280, + "hop_size": 320, + "num_mels": 100, + "sampling_rate": 32000, + "fmin": 0, + "fmax": None, + "center": False, + }, +) + + +def merge_short_text_in_array(texts, threshold): + if (len(texts)) < 2: + return texts + result = [] + text = "" + for ele in texts: + text += ele + if len(text) >= threshold: + result.append(text) + text = "" + if len(text) > 0: + if len(result) == 0: + result.append(text) + else: + result[len(result) - 1] += text + return result + + +sr_model = None + + +def audio_sr(audio, sr): + global sr_model + if sr_model == None: + from tools.audio_sr import AP_BWE + + try: + sr_model = AP_BWE(device, DictToAttrRecursive) + except FileNotFoundError: + gr.Warning(i18n("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载好")) + return audio.cpu().detach().numpy(), sr + return sr_model(audio, sr) + + +##ref_wav_path+prompt_text+prompt_language+text(单个)+text_language+top_k+top_p+temperature +# cache_tokens={}#暂未实现清理机制 +cache = {} + + +def get_tts_wav( + ref_wav_path, + prompt_text, + prompt_language, + text, + text_language, + how_to_cut=i18n("不切"), + top_k=20, + top_p=0.6, + temperature=0.6, + ref_free=False, + speed=1, + if_freeze=False, + inp_refs=None, + sample_steps=8, + if_sr=False, + pause_second=0.3, +): + global cache + if ref_wav_path: + pass + else: + gr.Warning(i18n("请上传参考音频")) + if text: + pass + else: + gr.Warning(i18n("请填入推理文本")) + t = [] + if prompt_text is None or len(prompt_text) == 0: + ref_free = True + if model_version in v3v4set: + ref_free = False # s2v3暂不支持ref_free + else: + if_sr = False + t0 = ttime() + prompt_language = dict_language[prompt_language] + text_language = dict_language[text_language] + + if not ref_free: + prompt_text = prompt_text.strip("\n") + if prompt_text[-1] not in splits: + prompt_text += "。" if prompt_language != "en" else "." + print(i18n("实际输入的参考文本:"), prompt_text) + text = text.strip("\n") + # if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." 
+ text + + print(i18n("实际输入的目标文本:"), text) + zero_wav = np.zeros( + int(hps.data.sampling_rate * pause_second), + dtype=np.float16 if is_half == True else np.float32, + ) + zero_wav_torch = torch.from_numpy(zero_wav) + if is_half == True: + zero_wav_torch = zero_wav_torch.half().to(device) + else: + zero_wav_torch = zero_wav_torch.to(device) + if not ref_free: + with torch.no_grad(): + wav16k, sr = librosa.load(ref_wav_path, sr=16000) + if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000: + gr.Warning(i18n("参考音频在3~10秒范围外,请更换!")) + raise OSError(i18n("参考音频在3~10秒范围外,请更换!")) + wav16k = torch.from_numpy(wav16k) + if is_half == True: + wav16k = wav16k.half().to(device) + else: + wav16k = wav16k.to(device) + wav16k = torch.cat([wav16k, zero_wav_torch]) + ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float() + codes = vq_model.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + prompt = prompt_semantic.unsqueeze(0).to(device) + + t1 = ttime() + t.append(t1 - t0) + + if how_to_cut == i18n("凑四句一切"): + text = cut1(text) + elif how_to_cut == i18n("凑50字一切"): + text = cut2(text) + elif how_to_cut == i18n("按中文句号。切"): + text = cut3(text) + elif how_to_cut == i18n("按英文句号.切"): + text = cut4(text) + elif how_to_cut == i18n("按标点符号切"): + text = cut5(text) + while "\n\n" in text: + text = text.replace("\n\n", "\n") + print(i18n("实际输入的目标文本(切句后):"), text) + texts = text.split("\n") + texts = process_text(texts) + texts = merge_short_text_in_array(texts, 5) + audio_opt = [] + ###s2v3暂不支持ref_free + if not ref_free: + phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version) + + for i_text, text in enumerate(texts): + # 解决输入目标文本的空行导致报错的问题 + if len(text.strip()) == 0: + continue + if text[-1] not in splits: + text += "。" if text_language != "en" else "." 
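        # --- Illustrative note (not part of the original patch) ----------------
        # At this point the target text has been cut (cut1..cut5 per how_to_cut),
        # split on newlines, stripped of empty lines by process_text(), and short
        # sentences have been folded together by merge_short_text_in_array(texts, 5).
        # For example (values are illustrative only):
        #     merge_short_text_in_array(["你好。", "嗯。", "今天天气不错。"], 5)
        #     -> ["你好。嗯。", "今天天气不错。"]
        # so no fragment shorter than the threshold is synthesized on its own; a
        # trailing leftover shorter than the threshold is appended to the previous
        # chunk instead.
        # ------------------------------------------------------------------------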
+ print(i18n("实际输入的目标文本(每句):"), text) + phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version) + print(i18n("前端处理后的文本(每句):"), norm_text2) + if not ref_free: + bert = torch.cat([bert1, bert2], 1) + all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0) + else: + bert = bert2 + all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0) + + bert = bert.to(device).unsqueeze(0) + all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) + + t2 = ttime() + # cache_key="%s-%s-%s-%s-%s-%s-%s-%s"%(ref_wav_path,prompt_text,prompt_language,text,text_language,top_k,top_p,temperature) + # print(cache.keys(),if_freeze) + if i_text in cache and if_freeze == True: + pred_semantic = cache[i_text] + else: + with torch.no_grad(): + pred_semantic, idx = t2s_model.model.infer_panel( + all_phoneme_ids, + all_phoneme_len, + None if ref_free else prompt, + bert, + # prompt_phone_len=ph_offset, + top_k=top_k, + top_p=top_p, + temperature=temperature, + early_stop_num=hz * max_sec, + ) + pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) + cache[i_text] = pred_semantic + t3 = ttime() + ###v3不存在以下逻辑和inp_refs + if model_version not in v3v4set: + refers = [] + if inp_refs: + for path in inp_refs: + try: + refer = get_spepc(hps, path.name).to(dtype).to(device) + refers.append(refer) + except: + traceback.print_exc() + if len(refers) == 0: + refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)] + audio = vq_model.decode( + pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed + )[0][0] # .cpu().detach().numpy() + else: + refer = get_spepc(hps, ref_wav_path).to(device).to(dtype) + phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0) + phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0) + # print(11111111, phoneme_ids0, phoneme_ids1) + fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer) + ref_audio, sr = torchaudio.load(ref_wav_path) + ref_audio = ref_audio.to(device).float() + if ref_audio.shape[0] == 2: + ref_audio = ref_audio.mean(0).unsqueeze(0) + tgt_sr=24000 if model_version=="v3"else 32000 + if sr != tgt_sr: + ref_audio = resample(ref_audio, sr,tgt_sr) + # print("ref_audio",ref_audio.abs().mean()) + mel2 = mel_fn(ref_audio)if model_version=="v3"else mel_fn_v4(ref_audio) + mel2 = norm_spec(mel2) + T_min = min(mel2.shape[2], fea_ref.shape[2]) + mel2 = mel2[:, :, :T_min] + fea_ref = fea_ref[:, :, :T_min] + Tref=468 if model_version=="v3"else 500 + Tchunk=934 if model_version=="v3"else 1000 + if T_min > Tref: + mel2 = mel2[:, :, -Tref:] + fea_ref = fea_ref[:, :, -Tref:] + T_min = Tref + chunk_len = Tchunk - T_min + mel2 = mel2.to(dtype) + fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed) + cfm_resss = [] + idx = 0 + while 1: + fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len] + if fea_todo_chunk.shape[-1] == 0: + break + idx += chunk_len + fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1) + cfm_res = vq_model.cfm.inference( + fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0 + ) + cfm_res = cfm_res[:, :, mel2.shape[2] :] + mel2 = cfm_res[:, :, -T_min:] + fea_ref = fea_todo_chunk[:, :, -T_min:] + cfm_resss.append(cfm_res) + cfm_res = torch.cat(cfm_resss, 2) + cfm_res = denorm_spec(cfm_res) + if model_version=="v3": + if bigvgan_model == None: + init_bigvgan() + else:#v4 + if hifigan_model == None: + init_hifigan() + vocoder_model=bigvgan_model if model_version=="v3"else 
hifigan_model + with torch.inference_mode(): + wav_gen = vocoder_model(cfm_res) + audio = wav_gen[0][0] # .cpu().detach().numpy() + max_audio = torch.abs(audio).max() # 简单防止16bit爆音 + if max_audio > 1: + audio = audio / max_audio + audio_opt.append(audio) + audio_opt.append(zero_wav_torch) # zero_wav + t4 = ttime() + t.extend([t2 - t1, t3 - t2, t4 - t3]) + t1 = ttime() + print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3]))) + audio_opt = torch.cat(audio_opt, 0) # np.concatenate + if model_version in {"v1","v2"}:opt_sr=32000 + elif model_version=="v3":opt_sr=24000 + else:opt_sr=48000#v4 + if if_sr == True and opt_sr == 24000: + print(i18n("音频超分中")) + audio_opt, opt_sr = audio_sr(audio_opt.unsqueeze(0), opt_sr) + max_audio = np.abs(audio_opt).max() + if max_audio > 1: + audio_opt /= max_audio + else: + audio_opt = audio_opt.cpu().detach().numpy() + yield opt_sr, (audio_opt * 32767).astype(np.int16) + + +def split(todo_text): + todo_text = todo_text.replace("……", "。").replace("——", ",") + if todo_text[-1] not in splits: + todo_text += "。" + i_split_head = i_split_tail = 0 + len_text = len(todo_text) + todo_texts = [] + while 1: + if i_split_head >= len_text: + break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入 + if todo_text[i_split_head] in splits: + i_split_head += 1 + todo_texts.append(todo_text[i_split_tail:i_split_head]) + i_split_tail = i_split_head + else: + i_split_head += 1 + return todo_texts + + +def cut1(inp): + inp = inp.strip("\n") + inps = split(inp) + split_idx = list(range(0, len(inps), 4)) + split_idx[-1] = None + if len(split_idx) > 1: + opts = [] + for idx in range(len(split_idx) - 1): + opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]])) + else: + opts = [inp] + opts = [item for item in opts if not set(item).issubset(punctuation)] + return "\n".join(opts) + + +def cut2(inp): + inp = inp.strip("\n") + inps = split(inp) + if len(inps) < 2: + return inp + opts = [] + summ = 0 + tmp_str = "" + for i in range(len(inps)): + summ += len(inps[i]) + tmp_str += inps[i] + if summ > 50: + summ = 0 + opts.append(tmp_str) + tmp_str = "" + if tmp_str != "": + opts.append(tmp_str) + # print(opts) + if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起 + opts[-2] = opts[-2] + opts[-1] + opts = opts[:-1] + opts = [item for item in opts if not set(item).issubset(punctuation)] + return "\n".join(opts) + + +def cut3(inp): + inp = inp.strip("\n") + opts = ["%s" % item for item in inp.strip("。").split("。")] + opts = [item for item in opts if not set(item).issubset(punctuation)] + return "\n".join(opts) + + +def cut4(inp): + inp = inp.strip("\n") + opts = re.split(r"(? 
0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit(): + items.append(char) + else: + items.append(char) + mergeitems.append("".join(items)) + items = [] + else: + items.append(char) + + if items: + mergeitems.append("".join(items)) + + opt = [item for item in mergeitems if not set(item).issubset(punds)] + return "\n".join(opt) + + +def custom_sort_key(s): + # 使用正则表达式提取字符串中的数字部分和非数字部分 + parts = re.split("(\d+)", s) + # 将数字部分转换为整数,非数字部分保持不变 + parts = [int(part) if part.isdigit() else part for part in parts] + return parts + + +def process_text(texts): + _text = [] + if all(text in [None, " ", "\n", ""] for text in texts): + raise ValueError(i18n("请输入有效文本")) + for text in texts: + if text in [None, " ", ""]: + pass + else: + _text.append(text) + return _text + + +def change_choices(): + SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root) + return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, { + "choices": sorted(GPT_names, key=custom_sort_key), + "__type__": "update", + } + + +SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3", "SoVITS_weights_v4"] +GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3", "GPT_weights_v4"] +for path in SoVITS_weight_root + GPT_weight_root: + os.makedirs(path, exist_ok=True) + + +def get_weights_names(GPT_weight_root, SoVITS_weight_root): + SoVITS_names = [i for i in pretrained_sovits_name] + for path in SoVITS_weight_root: + for name in os.listdir(path): + if name.endswith(".pth"): + SoVITS_names.append("%s/%s" % (path, name)) + GPT_names = [i for i in pretrained_gpt_name] + for path in GPT_weight_root: + for name in os.listdir(path): + if name.endswith(".ckpt"): + GPT_names.append("%s/%s" % (path, name)) + return SoVITS_names, GPT_names + + +SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root) + + +def html_center(text, label="p"): + return f"""
+                <div style="text-align: center;">
+                <{label} style="margin: 0; padding: 0;">{text}</{label}>
+                </div>"""
+
+
+def html_left(text, label="p"):
+    return f"""<div style="text-align: left;">
+                <{label} style="margin: 0; padding: 0;">{text}</{label}>
+                </div>
""" + + +with gr.Blocks(title="GPT-SoVITS WebUI") as app: + gr.Markdown( + value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.") + + "
" + + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.") + ) + with gr.Group(): + gr.Markdown(html_center(i18n("模型切换"), "h3")) + with gr.Row(): + GPT_dropdown = gr.Dropdown( + label=i18n("GPT模型列表"), + choices=sorted(GPT_names, key=custom_sort_key), + value=gpt_path, + interactive=True, + scale=14, + ) + SoVITS_dropdown = gr.Dropdown( + label=i18n("SoVITS模型列表"), + choices=sorted(SoVITS_names, key=custom_sort_key), + value=sovits_path, + interactive=True, + scale=14, + ) + refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary", scale=14) + refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown]) + gr.Markdown(html_center(i18n("*请上传并填写参考信息"), "h3")) + with gr.Row(): + inp_ref = gr.Audio(label=i18n("请上传3~10秒内参考音频,超过会报错!"), type="filepath", scale=13) + with gr.Column(scale=13): + ref_text_free = gr.Checkbox( + label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。") + + i18n("v3暂不支持该模式,使用了会报错。"), + value=False, + interactive=True if model_version not in v3v4set else False, + show_label=True, + scale=1, + ) + gr.Markdown( + html_left( + i18n("使用无参考文本模式时建议使用微调的GPT") + + "
" + + i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。") + ) + ) + prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="", lines=5, max_lines=5, scale=1) + with gr.Column(scale=14): + prompt_language = gr.Dropdown( + label=i18n("参考音频的语种"), + choices=list(dict_language.keys()), + value=i18n("中文"), + ) + inp_refs = ( + gr.File( + label=i18n( + "可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。" + ), + file_count="multiple", + ) + if model_version not in v3v4set + else gr.File( + label=i18n( + "可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。" + ), + file_count="multiple", + visible=False, + ) + ) + sample_steps = ( + gr.Radio( + label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"), + value=32 if model_version=="v3"else 8, + choices=[4, 8, 16, 32,64,128]if model_version=="v3"else [4, 8, 16, 32,64,128], + visible=True, + ) + if model_version in v3v4set + else gr.Radio( + label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"), + choices=[4, 8, 16, 32,64,128]if model_version=="v3"else [4, 8, 16, 32,64,128], + visible=False, + value=32 if model_version=="v3"else 8, + ) + ) + if_sr_Checkbox = gr.Checkbox( + label=i18n("v3输出如果觉得闷可以试试开超分"), + value=False, + interactive=True, + show_label=True, + visible=False if model_version !="v3" else True, + ) + gr.Markdown(html_center(i18n("*请填写需要合成的目标文本和语种模式"), "h3")) + with gr.Row(): + with gr.Column(scale=13): + text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=26, max_lines=26) + with gr.Column(scale=7): + text_language = gr.Dropdown( + label=i18n("需要合成的语种") + i18n(".限制范围越小判别效果越好。"), + choices=list(dict_language.keys()), + value=i18n("中文"), + scale=1, + ) + how_to_cut = gr.Dropdown( + label=i18n("怎么切"), + choices=[ + i18n("不切"), + i18n("凑四句一切"), + i18n("凑50字一切"), + i18n("按中文句号。切"), + i18n("按英文句号.切"), + i18n("按标点符号切"), + ], + value=i18n("凑四句一切"), + interactive=True, + scale=1, + ) + gr.Markdown(value=html_center(i18n("语速调整,高为更快"))) + if_freeze = gr.Checkbox( + label=i18n("是否直接对上次合成结果调整语速和音色。防止随机性。"), + value=False, + interactive=True, + show_label=True, + scale=1, + ) + with gr.Row(): + speed = gr.Slider( + minimum=0.6, maximum=1.65, step=0.05, label=i18n("语速"), value=1, interactive=True, scale=1 + ) + pause_second_slider = gr.Slider( + minimum=0.1, + maximum=0.5, + step=0.01, + label=i18n("句间停顿秒数"), + value=0.3, + interactive=True, + scale=1, + ) + gr.Markdown(html_center(i18n("GPT采样参数(无参考文本时不要太低。不懂就用默认):"))) + top_k = gr.Slider( + minimum=1, maximum=100, step=1, label=i18n("top_k"), value=15, interactive=True, scale=1 + ) + top_p = gr.Slider( + minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True, scale=1 + ) + temperature = gr.Slider( + minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True, scale=1 + ) + # with gr.Column(): + # gr.Markdown(value=i18n("手工调整音素。当音素框不为空时使用手工音素输入推理,无视目标文本框。")) + # phoneme=gr.Textbox(label=i18n("音素框"), value="") + # get_phoneme_button = gr.Button(i18n("目标文本转音素"), variant="primary") + with gr.Row(): + inference_button = gr.Button(value=i18n("合成语音"), variant="primary", size="lg", scale=25) + output = gr.Audio(label=i18n("输出的语音"), scale=14) + + inference_button.click( + get_tts_wav, + [ + inp_ref, + prompt_text, + prompt_language, + text, + text_language, + how_to_cut, + top_k, + top_p, + temperature, + ref_text_free, + speed, + if_freeze, + inp_refs, + sample_steps, + if_sr_Checkbox, + pause_second_slider, + ], + [output], + ) + SoVITS_dropdown.change( + change_sovits_weights, + [SoVITS_dropdown, prompt_language, 
text_language], + [ + prompt_language, + text_language, + prompt_text, + prompt_language, + text, + text_language, + sample_steps, + inp_refs, + ref_text_free, + if_sr_Checkbox, + inference_button, + ], + ) + GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], []) + + # gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")) + # with gr.Row(): + # text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="") + # button1 = gr.Button(i18n("凑四句一切"), variant="primary") + # button2 = gr.Button(i18n("凑50字一切"), variant="primary") + # button3 = gr.Button(i18n("按中文句号。切"), variant="primary") + # button4 = gr.Button(i18n("按英文句号.切"), variant="primary") + # button5 = gr.Button(i18n("按标点符号切"), variant="primary") + # text_opt = gr.Textbox(label=i18n("切分后文本"), value="") + # button1.click(cut1, [text_inp], [text_opt]) + # button2.click(cut2, [text_inp], [text_opt]) + # button3.click(cut3, [text_inp], [text_opt]) + # button4.click(cut4, [text_inp], [text_opt]) + # button5.click(cut5, [text_inp], [text_opt]) + # gr.Markdown(html_center(i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))) + +if __name__ == "__main__": + app.queue().launch( # concurrency_count=511, max_size=1022 + server_name="0.0.0.0", + inbrowser=True, + share=True, + server_port=infer_ttswebui, + quiet=True, + ) diff --git a/GPT_SoVITS/inference_webui_fast.py b/GPT_SoVITS/inference_webui_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..837a2e49c56bbec9a08f4d47ccba200fe7cd5789 --- /dev/null +++ b/GPT_SoVITS/inference_webui_fast.py @@ -0,0 +1,540 @@ +""" +按中英混合识别 +按日英混合识别 +多语种启动切分识别语种 +全部按中文识别 +全部按英文识别 +全部按日文识别 +""" + +import json +import logging +import os +import random +import re +import sys + +now_dir = os.getcwd() +sys.path.append(now_dir) +sys.path.append("%s/GPT_SoVITS" % (now_dir)) + +logging.getLogger("markdown_it").setLevel(logging.ERROR) +logging.getLogger("urllib3").setLevel(logging.ERROR) +logging.getLogger("httpcore").setLevel(logging.ERROR) +logging.getLogger("httpx").setLevel(logging.ERROR) +logging.getLogger("asyncio").setLevel(logging.ERROR) +logging.getLogger("charset_normalizer").setLevel(logging.ERROR) +logging.getLogger("torchaudio._extension").setLevel(logging.ERROR) +import torch + +try: + import gradio.analytics as analytics + + analytics.version_check = lambda: None +except: + ... 
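# --- Illustrative sketch (assumption, not part of the original patch) ------------
# The block below is configured entirely through environment variables. A small
# helper like this shows the convention it relies on; the variable names are the
# ones actually read below, while the default values here are only examples:
def _example_fast_webui_env(port: int = 9873, half: bool = False, share: bool = False) -> dict:
    """Environment overrides one could export before launching this script."""
    return {
        "infer_ttswebui": str(port),  # HTTP port for the Gradio app
        "is_half": str(half),         # eval()-ed below; fp16 is used only if CUDA is available
        "is_share": str(share),       # eval()-ed below; whether to create a public share link
    }
# ----------------------------------------------------------------------------------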
+ + +infer_ttswebui = os.environ.get("infer_ttswebui", 9872) +infer_ttswebui = int(infer_ttswebui) +is_share = os.environ.get("is_share", "False") +is_share = eval(is_share) +if "_CUDA_VISIBLE_DEVICES" in os.environ: + os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] + +is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available() +gpt_path = os.environ.get("gpt_path", None) +sovits_path = os.environ.get("sovits_path", None) +cnhubert_base_path = os.environ.get("cnhubert_base_path", None) +bert_path = os.environ.get("bert_path", None) +version = model_version = os.environ.get("version", "v2") + +import gradio as gr +from TTS_infer_pack.text_segmentation_method import get_method +from TTS_infer_pack.TTS import NO_PROMPT_ERROR, TTS, TTS_Config + +from tools.i18n.i18n import I18nAuto, scan_language_list + +language = os.environ.get("language", "Auto") +language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language +i18n = I18nAuto(language=language) + + +# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。 + +if torch.cuda.is_available(): + device = "cuda" +# elif torch.backends.mps.is_available(): +# device = "mps" +else: + device = "cpu" + +# is_half = False +# device = "cpu" + +dict_language_v1 = { + i18n("中文"): "all_zh", # 全部按中文识别 + i18n("英文"): "en", # 全部按英文识别#######不变 + i18n("日文"): "all_ja", # 全部按日文识别 + i18n("中英混合"): "zh", # 按中英混合识别####不变 + i18n("日英混合"): "ja", # 按日英混合识别####不变 + i18n("多语种混合"): "auto", # 多语种启动切分识别语种 +} +dict_language_v2 = { + i18n("中文"): "all_zh", # 全部按中文识别 + i18n("英文"): "en", # 全部按英文识别#######不变 + i18n("日文"): "all_ja", # 全部按日文识别 + i18n("粤语"): "all_yue", # 全部按中文识别 + i18n("韩文"): "all_ko", # 全部按韩文识别 + i18n("中英混合"): "zh", # 按中英混合识别####不变 + i18n("日英混合"): "ja", # 按日英混合识别####不变 + i18n("粤英混合"): "yue", # 按粤英混合识别####不变 + i18n("韩英混合"): "ko", # 按韩英混合识别####不变 + i18n("多语种混合"): "auto", # 多语种启动切分识别语种 + i18n("多语种混合(粤语)"): "auto_yue", # 多语种启动切分识别语种 +} +dict_language = dict_language_v1 if version == "v1" else dict_language_v2 + +cut_method = { + i18n("不切"): "cut0", + i18n("凑四句一切"): "cut1", + i18n("凑50字一切"): "cut2", + i18n("按中文句号。切"): "cut3", + i18n("按英文句号.切"): "cut4", + i18n("按标点符号切"): "cut5", +} + +tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml") +tts_config.device = device +tts_config.is_half = is_half +tts_config.version = version +if gpt_path is not None: + tts_config.t2s_weights_path = gpt_path +if sovits_path is not None: + tts_config.vits_weights_path = sovits_path +if cnhubert_base_path is not None: + tts_config.cnhuhbert_base_path = cnhubert_base_path +if bert_path is not None: + tts_config.bert_base_path = bert_path + +print(tts_config) +tts_pipeline = TTS(tts_config) +gpt_path = tts_config.t2s_weights_path +sovits_path = tts_config.vits_weights_path +version = tts_config.version + + +def inference( + text, + text_lang, + ref_audio_path, + aux_ref_audio_paths, + prompt_text, + prompt_lang, + top_k, + top_p, + temperature, + text_split_method, + batch_size, + speed_factor, + ref_text_free, + split_bucket, + fragment_interval, + seed, + keep_random, + parallel_infer, + repetition_penalty, + sample_steps, + super_sampling, +): + seed = -1 if keep_random else seed + actual_seed = seed if seed not in [-1, "", None] else random.randint(0, 2**32 - 1) + inputs = { + "text": text, + "text_lang": dict_language[text_lang], + "ref_audio_path": ref_audio_path, + "aux_ref_audio_paths": [item.name for item in aux_ref_audio_paths] if aux_ref_audio_paths is not None else [], + "prompt_text": prompt_text if not 
ref_text_free else "", + "prompt_lang": dict_language[prompt_lang], + "top_k": top_k, + "top_p": top_p, + "temperature": temperature, + "text_split_method": cut_method[text_split_method], + "batch_size": int(batch_size), + "speed_factor": float(speed_factor), + "split_bucket": split_bucket, + "return_fragment": False, + "fragment_interval": fragment_interval, + "seed": actual_seed, + "parallel_infer": parallel_infer, + "repetition_penalty": repetition_penalty, + "sample_steps": int(sample_steps), + "super_sampling": super_sampling, + } + try: + for item in tts_pipeline.run(inputs): + yield item, actual_seed + except NO_PROMPT_ERROR: + gr.Warning(i18n("V3不支持无参考文本模式,请填写参考文本!")) + + +def custom_sort_key(s): + # 使用正则表达式提取字符串中的数字部分和非数字部分 + parts = re.split("(\d+)", s) + # 将数字部分转换为整数,非数字部分保持不变 + parts = [int(part) if part.isdigit() else part for part in parts] + return parts + + +def change_choices(): + SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root) + return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, { + "choices": sorted(GPT_names, key=custom_sort_key), + "__type__": "update", + } + + +path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth" +pretrained_sovits_name = [ + "GPT_SoVITS/pretrained_models/s2G488k.pth", + "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", + path_sovits_v3, +] +pretrained_gpt_name = [ + "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", + "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", + "GPT_SoVITS/pretrained_models/s1v3.ckpt", +] + +_ = [[], []] +for i in range(3): + if os.path.exists(pretrained_gpt_name[i]): + _[0].append(pretrained_gpt_name[i]) + if os.path.exists(pretrained_sovits_name[i]): + _[-1].append(pretrained_sovits_name[i]) +pretrained_gpt_name, pretrained_sovits_name = _ + + +if os.path.exists("./weight.json"): + pass +else: + with open("./weight.json", "w", encoding="utf-8") as file: + json.dump({"GPT": {}, "SoVITS": {}}, file) + +with open("./weight.json", "r", encoding="utf-8") as file: + weight_data = file.read() + weight_data = json.loads(weight_data) + gpt_path = os.environ.get("gpt_path", weight_data.get("GPT", {}).get(version, pretrained_gpt_name)) + sovits_path = os.environ.get("sovits_path", weight_data.get("SoVITS", {}).get(version, pretrained_sovits_name)) + if isinstance(gpt_path, list): + gpt_path = gpt_path[0] + if isinstance(sovits_path, list): + sovits_path = sovits_path[0] + + +SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3"] +GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3"] +for path in SoVITS_weight_root + GPT_weight_root: + os.makedirs(path, exist_ok=True) + + +def get_weights_names(GPT_weight_root, SoVITS_weight_root): + SoVITS_names = [i for i in pretrained_sovits_name] + for path in SoVITS_weight_root: + for name in os.listdir(path): + if name.endswith(".pth"): + SoVITS_names.append("%s/%s" % (path, name)) + GPT_names = [i for i in pretrained_gpt_name] + for path in GPT_weight_root: + for name in os.listdir(path): + if name.endswith(".ckpt"): + GPT_names.append("%s/%s" % (path, name)) + return SoVITS_names, GPT_names + + +SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root) + + +from process_ckpt import get_sovits_version_from_path_fast + + +def change_sovits_weights(sovits_path, prompt_language=None, text_language=None): + global version, model_version, dict_language, if_lora_v3 + 
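    # --- Illustrative note (assumption, not part of the original patch) ---------
    # get_sovits_version_from_path_fast() returns a (version, model_version,
    # if_lora_v3) triple, e.g. ("v2", "v2", False) for a plain v2 checkpoint or
    # ("v2", "v3", True) for a v3 LoRA checkpoint. In the LoRA case the v3 base
    # model at path_sovits_v3 must exist so the LoRA weights can be applied,
    # which is what the check just below enforces.
    # -----------------------------------------------------------------------------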
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path) + # print(sovits_path,version, model_version, if_lora_v3) + if if_lora_v3 and not os.path.exists(path_sovits_v3): + info = path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重") + gr.Warning(info) + raise FileExistsError(info) + dict_language = dict_language_v1 if version == "v1" else dict_language_v2 + if prompt_language is not None and text_language is not None: + if prompt_language in list(dict_language.keys()): + prompt_text_update, prompt_language_update = ( + {"__type__": "update"}, + {"__type__": "update", "value": prompt_language}, + ) + else: + prompt_text_update = {"__type__": "update", "value": ""} + prompt_language_update = {"__type__": "update", "value": i18n("中文")} + if text_language in list(dict_language.keys()): + text_update, text_language_update = {"__type__": "update"}, {"__type__": "update", "value": text_language} + else: + text_update = {"__type__": "update", "value": ""} + text_language_update = {"__type__": "update", "value": i18n("中文")} + if model_version == "v3": + visible_sample_steps = True + visible_inp_refs = False + else: + visible_sample_steps = False + visible_inp_refs = True + # prompt_language,text_language,prompt_text,prompt_language,text,text_language,inp_refs,ref_text_free, + yield ( + {"__type__": "update", "choices": list(dict_language.keys())}, + {"__type__": "update", "choices": list(dict_language.keys())}, + prompt_text_update, + prompt_language_update, + text_update, + text_language_update, + {"__type__": "update", "interactive": visible_sample_steps, "value": 32}, + {"__type__": "update", "visible": visible_inp_refs}, + {"__type__": "update", "interactive": True if model_version != "v3" else False}, + {"__type__": "update", "value": i18n("模型加载中,请等待"), "interactive": False}, + ) + + tts_pipeline.init_vits_weights(sovits_path) + yield ( + {"__type__": "update", "choices": list(dict_language.keys())}, + {"__type__": "update", "choices": list(dict_language.keys())}, + prompt_text_update, + prompt_language_update, + text_update, + text_language_update, + {"__type__": "update", "interactive": visible_sample_steps, "value": 32}, + {"__type__": "update", "visible": visible_inp_refs}, + {"__type__": "update", "interactive": True if model_version != "v3" else False}, + {"__type__": "update", "value": i18n("合成语音"), "interactive": True}, + ) + with open("./weight.json") as f: + data = f.read() + data = json.loads(data) + data["SoVITS"][version] = sovits_path + with open("./weight.json", "w") as f: + f.write(json.dumps(data)) + + +with gr.Blocks(title="GPT-SoVITS WebUI") as app: + gr.Markdown( + value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.") + + "
" + + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.") + ) + + with gr.Column(): + # with gr.Group(): + gr.Markdown(value=i18n("模型切换")) + with gr.Row(): + GPT_dropdown = gr.Dropdown( + label=i18n("GPT模型列表"), + choices=sorted(GPT_names, key=custom_sort_key), + value=gpt_path, + interactive=True, + ) + SoVITS_dropdown = gr.Dropdown( + label=i18n("SoVITS模型列表"), + choices=sorted(SoVITS_names, key=custom_sort_key), + value=sovits_path, + interactive=True, + ) + refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary") + refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown]) + + with gr.Row(): + with gr.Column(): + gr.Markdown(value=i18n("*请上传并填写参考信息")) + with gr.Row(): + inp_ref = gr.Audio(label=i18n("主参考音频(请上传3~10秒内参考音频,超过会报错!)"), type="filepath") + inp_refs = gr.File( + label=i18n("辅参考音频(可选多个,或不选)"), + file_count="multiple", + visible=True if model_version != "v3" else False, + ) + prompt_text = gr.Textbox(label=i18n("主参考音频的文本"), value="", lines=2) + with gr.Row(): + prompt_language = gr.Dropdown( + label=i18n("主参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文") + ) + with gr.Column(): + ref_text_free = gr.Checkbox( + label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), + value=False, + interactive=True if model_version != "v3" else False, + show_label=True, + ) + gr.Markdown( + i18n("使用无参考文本模式时建议使用微调的GPT") + + "
" + + i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。") + ) + + with gr.Column(): + gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式")) + text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=20, max_lines=20) + text_language = gr.Dropdown( + label=i18n("需要合成的文本的语种"), choices=list(dict_language.keys()), value=i18n("中文") + ) + + with gr.Group(): + gr.Markdown(value=i18n("推理设置")) + with gr.Row(): + with gr.Column(): + with gr.Row(): + batch_size = gr.Slider( + minimum=1, maximum=200, step=1, label=i18n("batch_size"), value=20, interactive=True + ) + sample_steps = gr.Radio( + label=i18n("采样步数(仅对V3生效)"), value=32, choices=[4, 8, 16, 32], visible=True + ) + with gr.Row(): + fragment_interval = gr.Slider( + minimum=0.01, maximum=1, step=0.01, label=i18n("分段间隔(秒)"), value=0.3, interactive=True + ) + speed_factor = gr.Slider( + minimum=0.6, maximum=1.65, step=0.05, label="语速", value=1.0, interactive=True + ) + with gr.Row(): + top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=5, interactive=True) + top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True) + with gr.Row(): + temperature = gr.Slider( + minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True + ) + repetition_penalty = gr.Slider( + minimum=0, maximum=2, step=0.05, label=i18n("重复惩罚"), value=1.35, interactive=True + ) + + with gr.Column(): + with gr.Row(): + how_to_cut = gr.Dropdown( + label=i18n("怎么切"), + choices=[ + i18n("不切"), + i18n("凑四句一切"), + i18n("凑50字一切"), + i18n("按中文句号。切"), + i18n("按英文句号.切"), + i18n("按标点符号切"), + ], + value=i18n("凑四句一切"), + interactive=True, + scale=1, + ) + super_sampling = gr.Checkbox( + label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True + ) + + with gr.Row(): + parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True) + split_bucket = gr.Checkbox( + label=i18n("数据分桶(并行推理时会降低一点计算量)"), + value=True, + interactive=True, + show_label=True, + ) + + with gr.Row(): + seed = gr.Number(label=i18n("随机种子"), value=-1) + keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True) + + output = gr.Audio(label=i18n("输出的语音")) + with gr.Row(): + inference_button = gr.Button(i18n("合成语音"), variant="primary") + stop_infer = gr.Button(i18n("终止合成"), variant="primary") + + inference_button.click( + inference, + [ + text, + text_language, + inp_ref, + inp_refs, + prompt_text, + prompt_language, + top_k, + top_p, + temperature, + how_to_cut, + batch_size, + speed_factor, + ref_text_free, + split_bucket, + fragment_interval, + seed, + keep_random, + parallel_infer, + repetition_penalty, + sample_steps, + super_sampling, + ], + [output, seed], + ) + stop_infer.click(tts_pipeline.stop, [], []) + SoVITS_dropdown.change( + change_sovits_weights, + [SoVITS_dropdown, prompt_language, text_language], + [ + prompt_language, + text_language, + prompt_text, + prompt_language, + text, + text_language, + sample_steps, + inp_refs, + ref_text_free, + inference_button, + ], + ) # + GPT_dropdown.change(tts_pipeline.init_t2s_weights, [GPT_dropdown], []) + + with gr.Group(): + gr.Markdown( + value=i18n( + "文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。" + ) + ) + with gr.Row(): + text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4) + with gr.Column(): + _how_to_cut = gr.Radio( + label=i18n("怎么切"), + choices=[ + i18n("不切"), + i18n("凑四句一切"), + i18n("凑50字一切"), + i18n("按中文句号。切"), + i18n("按英文句号.切"), + i18n("按标点符号切"), + ], + value=i18n("凑四句一切"), + 
interactive=True, + ) + cut_text = gr.Button(i18n("切分"), variant="primary") + + def to_cut(text_inp, how_to_cut): + if len(text_inp.strip()) == 0 or text_inp == []: + return "" + method = get_method(cut_method[how_to_cut]) + return method(text_inp) + + text_opt = gr.Textbox(label=i18n("切分后文本"), value="", lines=4) + cut_text.click(to_cut, [text_inp, _how_to_cut], [text_opt]) + gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。")) + +if __name__ == "__main__": + app.queue().launch( # concurrency_count=511, max_size=1022 + server_name="0.0.0.0", + inbrowser=True, + share=is_share, + server_port=infer_ttswebui, + quiet=True, + ) diff --git a/GPT_SoVITS/onnx_export.py b/GPT_SoVITS/onnx_export.py new file mode 100644 index 0000000000000000000000000000000000000000..fd680135fb7d71afb4680b05a62b9874c39ad21c --- /dev/null +++ b/GPT_SoVITS/onnx_export.py @@ -0,0 +1,398 @@ +import torch +import torchaudio +from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule +from feature_extractor import cnhubert +from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2 +from torch import nn + +cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base" +cnhubert.cnhubert_base_path = cnhubert_base_path +ssl_model = cnhubert.get_model() +import json +import os + +import soundfile +from text import cleaned_text_to_sequence + + +def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): + hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window, + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +class DictToAttrRecursive(dict): + def __init__(self, input_dict): + super().__init__(input_dict) + for key, value in input_dict.items(): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + self[key] = value + setattr(self, key, value) + + def __getattr__(self, item): + try: + return self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + def __setattr__(self, key, value): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + super(DictToAttrRecursive, self).__setitem__(key, value) + super().__setattr__(key, value) + + def __delattr__(self, item): + try: + del self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + +class T2SEncoder(nn.Module): + def __init__(self, t2s, vits): + super().__init__() + self.encoder = t2s.onnx_encoder + self.vits = vits + + def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content): + codes = self.vits.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1) + all_phoneme_ids = torch.cat([ref_seq, text_seq], 1) + bert = bert.unsqueeze(0) + prompt = prompt_semantic.unsqueeze(0) + return self.encoder(all_phoneme_ids, bert), prompt + + +class T2SModel(nn.Module): + def __init__(self, t2s_path, vits_model): + super().__init__() + dict_s1 = torch.load(t2s_path, map_location="cpu") + self.config = dict_s1["config"] + self.t2s_model = Text2SemanticLightningModule(self.config, "ojbk", is_train=False) + self.t2s_model.load_state_dict(dict_s1["weight"]) + 
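        # --- Illustrative note (assumption, not part of the original patch) -----
        # For ONNX export the autoregressive T2S model is split into three graphs
        # (see export() below): an encoder that turns phonemes + BERT features
        # into (x, prompts), a first-stage decoder that builds the initial KV
        # cache, and a stage decoder that is invoked once per generated token.
        # top_k and early_stop_num are wrapped in LongTensors below, presumably
        # so they are traced as constants in the exported graphs.
        # -------------------------------------------------------------------------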
self.t2s_model.eval() + self.vits_model = vits_model.vq_model + self.hz = 50 + self.max_sec = self.config["data"]["max_sec"] + self.t2s_model.model.top_k = torch.LongTensor([self.config["inference"]["top_k"]]) + self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec]) + self.t2s_model = self.t2s_model.model + self.t2s_model.init_onnx() + self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model) + self.first_stage_decoder = self.t2s_model.first_stage_decoder + self.stage_decoder = self.t2s_model.stage_decoder + # self.t2s_model = torch.jit.script(self.t2s_model) + + def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content): + early_stop_num = self.t2s_model.early_stop_num + + # [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N] + x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content) + + prefix_len = prompts.shape[1] + + # [1,N,512] [1,N] + y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts) + + stop = False + for idx in range(1, 1500): + # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N] + enco = self.stage_decoder(y, k, v, y_emb, x_example) + y, k, v, y_emb, logits, samples = enco + if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: + stop = True + if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS: + stop = True + if stop: + break + y[0, -1] = 0 + + return y[:, -idx:].unsqueeze(0) + + def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False): + # self.onnx_encoder = torch.jit.script(self.onnx_encoder) + if dynamo: + export_options = torch.onnx.ExportOptions(dynamic_shapes=True) + onnx_encoder_export_output = torch.onnx.dynamo_export( + self.onnx_encoder, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), export_options=export_options + ) + onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx") + return + + torch.onnx.export( + self.onnx_encoder, + (ref_seq, text_seq, ref_bert, text_bert, ssl_content), + f"onnx/{project_name}/{project_name}_t2s_encoder.onnx", + input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"], + output_names=["x", "prompts"], + dynamic_axes={ + "ref_seq": {1: "ref_length"}, + "text_seq": {1: "text_length"}, + "ref_bert": {0: "ref_length"}, + "text_bert": {0: "text_length"}, + "ssl_content": {2: "ssl_length"}, + }, + opset_version=16, + ) + x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content) + + torch.onnx.export( + self.first_stage_decoder, + (x, prompts), + f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx", + input_names=["x", "prompts"], + output_names=["y", "k", "v", "y_emb", "x_example"], + dynamic_axes={ + "x": {1: "x_length"}, + "prompts": {1: "prompts_length"}, + }, + verbose=False, + opset_version=16, + ) + y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts) + + torch.onnx.export( + self.stage_decoder, + (y, k, v, y_emb, x_example), + f"onnx/{project_name}/{project_name}_t2s_sdec.onnx", + input_names=["iy", "ik", "iv", "iy_emb", "ix_example"], + output_names=["y", "k", "v", "y_emb", "logits", "samples"], + dynamic_axes={ + "iy": {1: "iy_length"}, + "ik": {1: "ik_length"}, + "iv": {1: "iv_length"}, + "iy_emb": {1: "iy_emb_length"}, + "ix_example": {1: "ix_example_length"}, + }, + verbose=False, + opset_version=16, + ) + + +class VitsModel(nn.Module): + def __init__(self, vits_path): + super().__init__() + dict_s2 = torch.load(vits_path, 
map_location="cpu") + self.hps = dict_s2["config"] + if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322: + self.hps["model"]["version"] = "v1" + else: + self.hps["model"]["version"] = "v2" + + self.hps = DictToAttrRecursive(self.hps) + self.hps.model.semantic_frame_rate = "25hz" + self.vq_model = SynthesizerTrn( + self.hps.data.filter_length // 2 + 1, + self.hps.train.segment_size // self.hps.data.hop_length, + n_speakers=self.hps.data.n_speakers, + **self.hps.model, + ) + self.vq_model.eval() + self.vq_model.load_state_dict(dict_s2["weight"], strict=False) + + def forward(self, text_seq, pred_semantic, ref_audio): + refer = spectrogram_torch( + ref_audio, + self.hps.data.filter_length, + self.hps.data.sampling_rate, + self.hps.data.hop_length, + self.hps.data.win_length, + center=False, + ) + return self.vq_model(pred_semantic, text_seq, refer)[0, 0] + + +class GptSoVits(nn.Module): + def __init__(self, vits, t2s): + super().__init__() + self.vits = vits + self.t2s = t2s + + def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False): + pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content) + audio = self.vits(text_seq, pred_semantic, ref_audio) + if debug: + import onnxruntime + + sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"]) + audio1 = sess.run( + None, + { + "text_seq": text_seq.detach().cpu().numpy(), + "pred_semantic": pred_semantic.detach().cpu().numpy(), + "ref_audio": ref_audio.detach().cpu().numpy(), + }, + ) + return audio, audio1 + return audio + + def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name): + self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name) + pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content) + torch.onnx.export( + self.vits, + (text_seq, pred_semantic, ref_audio), + f"onnx/{project_name}/{project_name}_vits.onnx", + input_names=["text_seq", "pred_semantic", "ref_audio"], + output_names=["audio"], + dynamic_axes={ + "text_seq": {1: "text_length"}, + "pred_semantic": {2: "pred_length"}, + "ref_audio": {1: "audio_length"}, + }, + opset_version=17, + verbose=False, + ) + + +class SSLModel(nn.Module): + def __init__(self): + super().__init__() + self.ssl = ssl_model + + def forward(self, ref_audio_16k): + return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2) + + +def export(vits_path, gpt_path, project_name, vits_model="v2"): + vits = VitsModel(vits_path) + gpt = T2SModel(gpt_path, vits) + gpt_sovits = GptSoVits(vits, gpt) + ssl = SSLModel() + ref_seq = torch.LongTensor( + [ + cleaned_text_to_sequence( + [ + "n", + "i2", + "h", + "ao3", + ",", + "w", + "o3", + "sh", + "i4", + "b", + "ai2", + "y", + "e4", + ], + version=vits_model, + ) + ] + ) + text_seq = torch.LongTensor( + [ + cleaned_text_to_sequence( + [ + "w", + "o3", + "sh", + "i4", + "b", + "ai2", + "y", + "e4", + "w", + "o3", + "sh", + "i4", + "b", + "ai2", + "y", + "e4", + "w", + "o3", + "sh", + "i4", + "b", + "ai2", + "y", + "e4", + ], + version=vits_model, + ) + ] + ) + ref_bert = torch.randn((ref_seq.shape[1], 1024)).float() + text_bert = torch.randn((text_seq.shape[1], 1024)).float() + ref_audio = torch.randn((1, 48000 * 5)).float() + # ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float() + ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float() + ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, 
vits.hps.data.sampling_rate).float() + + try: + os.mkdir(f"onnx/{project_name}") + except: + pass + + ssl_content = ssl(ref_audio_16k).float() + + # debug = False + debug = True + + # gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name) + + if debug: + a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug) + soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate) + soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate) + else: + a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy() + soundfile.write("out.wav", a, vits.hps.data.sampling_rate) + + if vits_model == "v1": + symbols = symbols_v1 + else: + symbols = symbols_v2 + + MoeVSConf = { + "Folder": f"{project_name}", + "Name": f"{project_name}", + "Type": "GPT-SoVits", + "Rate": vits.hps.data.sampling_rate, + "NumLayers": gpt.t2s_model.num_layers, + "EmbeddingDim": gpt.t2s_model.embedding_dim, + "Dict": "BasicDict", + "BertPath": "chinese-roberta-wwm-ext-large", + # "Symbol": symbols, + "AddBlank": False, + } + + MoeVSConfJson = json.dumps(MoeVSConf) + with open(f"onnx/{project_name}.json", "w") as MoeVsConfFile: + json.dump(MoeVSConf, MoeVsConfFile, indent=4) + + +if __name__ == "__main__": + try: + os.mkdir("onnx") + except: + pass + + gpt_path = "GPT_weights/nahida-e25.ckpt" + vits_path = "SoVITS_weights/nahida_e30_s3930.pth" + exp_path = "nahida" + export(vits_path, gpt_path, exp_path) + + # soundfile.write("out.wav", a, vits.hps.data.sampling_rate) diff --git a/GPT_SoVITS/process_ckpt.py b/GPT_SoVITS/process_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..4a2a1bacd6c5c67b370f5716127a9d38e77c012a --- /dev/null +++ b/GPT_SoVITS/process_ckpt.py @@ -0,0 +1,124 @@ +import traceback +from collections import OrderedDict +from time import time as ttime +import shutil +import os +import torch +from tools.i18n.i18n import I18nAuto + +i18n = I18nAuto() + + +def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path + dir = os.path.dirname(path) + name = os.path.basename(path) + tmp_path = "%s.pth" % (ttime()) + torch.save(fea, tmp_path) + shutil.move(tmp_path, "%s/%s" % (dir, name)) + + +""" +00:v1 +01:v2 +02:v3 +03:v3lora +04:v4lora + +""" +from io import BytesIO + + +def my_save2(fea, path,cfm_version): + bio = BytesIO() + torch.save(fea, bio) + bio.seek(0) + data = bio.getvalue() + byte=b"03" if cfm_version=="v3"else b"04" + data = byte + data[2:] + with open(path, "wb") as f: + f.write(data) + + +def savee(ckpt, name, epoch, steps, hps, cfm_version=None,lora_rank=None): + try: + opt = OrderedDict() + opt["weight"] = {} + for key in ckpt.keys(): + if "enc_q" in key: + continue + opt["weight"][key] = ckpt[key].half() + opt["config"] = hps + opt["info"] = "%sepoch_%siteration" % (epoch, steps) + if lora_rank: + opt["lora_rank"] = lora_rank + my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name),cfm_version) + else: + my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name)) + return "Success." 
+ except: + return traceback.format_exc() + + +head2version = { + b"00": ["v1", "v1", False], + b"01": ["v2", "v2", False], + b"02": ["v2", "v3", False], + b"03": ["v2", "v3", True], + b"04": ["v2", "v4", True], +} +hash_pretrained_dict = { + "dc3c97e17592963677a4a1681f30c653": ["v2", "v2", False], # s2G488k.pth#sovits_v1_pretrained + "43797be674a37c1c83ee81081941ed0f": ["v2", "v3", False], # s2Gv3.pth#sovits_v3_pretrained + "6642b37f3dbb1f76882b69937c95a5f3": ["v2", "v2", False], # s2G2333K.pth#sovits_v2_pretrained + "4f26b9476d0c5033e04162c486074374": ["v2", "v4", False], # s2Gv4.pth#sovits_v4_pretrained +} +import hashlib + + +def get_hash_from_file(sovits_path): + with open(sovits_path, "rb") as f: + data = f.read(8192) + hash_md5 = hashlib.md5() + hash_md5.update(data) + return hash_md5.hexdigest() + + +def get_sovits_version_from_path_fast(sovits_path): + ###1-if it is pretrained sovits models, by hash + hash = get_hash_from_file(sovits_path) + if hash in hash_pretrained_dict: + return hash_pretrained_dict[hash] + ###2-new weights, by head + with open(sovits_path, "rb") as f: + version = f.read(2) + if version != b"PK": + return head2version[version] + ###3-old weights, by file size + if_lora_v3 = False + size = os.path.getsize(sovits_path) + """ + v1weights:about 82942KB + half thr:82978KB + v2weights:about 83014KB + v3weights:about 750MB + """ + if size < 82978 * 1024: + model_version = version = "v1" + elif size < 700 * 1024 * 1024: + model_version = version = "v2" + else: + version = "v2" + model_version = "v3" + return version, model_version, if_lora_v3 + + +def load_sovits_new(sovits_path): + f = open(sovits_path, "rb") + meta = f.read(2) + if meta != "PK": + data = b"PK" + f.read() + bio = BytesIO() + bio.write(data) + bio.seek(0) + return torch.load(bio, map_location="cpu", weights_only=False) + return torch.load(sovits_path, map_location="cpu", weights_only=False) diff --git a/GPT_SoVITS/s1_train.py b/GPT_SoVITS/s1_train.py new file mode 100644 index 0000000000000000000000000000000000000000..1176f0bcef869d2f66574d47d9a55d0cc7e1ac0c --- /dev/null +++ b/GPT_SoVITS/s1_train.py @@ -0,0 +1,171 @@ +# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py +import os + +if "_CUDA_VISIBLE_DEVICES" in os.environ: + os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] +import argparse +import logging +import platform +from pathlib import Path + +import torch +from AR.data.data_module import Text2SemanticDataModule +from AR.models.t2s_lightning_module import Text2SemanticLightningModule +from AR.utils.io import load_yaml_config +from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger +from pytorch_lightning.strategies import DDPStrategy + +logging.getLogger("numba").setLevel(logging.WARNING) +logging.getLogger("matplotlib").setLevel(logging.WARNING) +torch.set_float32_matmul_precision("high") +from collections import OrderedDict + +from AR.utils import get_newest_ckpt +from process_ckpt import my_save + + +class my_model_ckpt(ModelCheckpoint): + def __init__( + self, + config, + if_save_latest, + if_save_every_weights, + half_weights_save_dir, + exp_name, + **kwargs, + ): + super().__init__(**kwargs) + self.if_save_latest = if_save_latest + self.if_save_every_weights = if_save_every_weights + self.half_weights_save_dir = half_weights_save_dir + self.exp_name = exp_name + self.config = config + + def 
on_train_epoch_end(self, trainer, pl_module): + # if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer): + if self._should_save_on_train_epoch_end(trainer): + monitor_candidates = self._monitor_candidates(trainer) + if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0: + if ( + self.if_save_latest == True + ): ####如果设置只保存最后一个ckpt,在保存下一个ckpt后要清理掉之前的所有ckpt + to_clean = list(os.listdir(self.dirpath)) + self._save_topk_checkpoint(trainer, monitor_candidates) + if self.if_save_latest == True: + for name in to_clean: + try: + os.remove("%s/%s" % (self.dirpath, name)) + except: + pass + if self.if_save_every_weights == True: + to_save_od = OrderedDict() + to_save_od["weight"] = OrderedDict() + dictt = trainer.strategy._lightning_module.state_dict() + for key in dictt: + to_save_od["weight"][key] = dictt[key].half() + to_save_od["config"] = self.config + to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1) + # torch.save( + # print(os.environ) + if os.environ.get("LOCAL_RANK", "0") == "0": + my_save( + to_save_od, + "%s/%s-e%s.ckpt" + % ( + self.half_weights_save_dir, + self.exp_name, + trainer.current_epoch + 1, + ), + ) + self._save_last_checkpoint(trainer, monitor_candidates) + + +def main(args): + config = load_yaml_config(args.config_file) + + output_dir = Path(config["output_dir"]) + output_dir.mkdir(parents=True, exist_ok=True) + + ckpt_dir = output_dir / "ckpt" + ckpt_dir.mkdir(parents=True, exist_ok=True) + + seed_everything(config["train"]["seed"], workers=True) + ckpt_callback: ModelCheckpoint = my_model_ckpt( + config=config, + if_save_latest=config["train"]["if_save_latest"], + if_save_every_weights=config["train"]["if_save_every_weights"], + half_weights_save_dir=config["train"]["half_weights_save_dir"], + exp_name=config["train"]["exp_name"], + save_top_k=-1, + monitor="top_3_acc", + mode="max", + save_on_train_epoch_end=True, + every_n_epochs=config["train"]["save_every_n_epoch"], + dirpath=ckpt_dir, + ) + logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["USE_LIBUV"] = "0" + trainer: Trainer = Trainer( + max_epochs=config["train"]["epochs"], + accelerator="gpu" if torch.cuda.is_available() else "cpu", + # val_check_interval=9999999999999999999999,###不要验证 + # check_val_every_n_epoch=None, + limit_val_batches=0, + devices=-1 if torch.cuda.is_available() else 1, + benchmark=False, + fast_dev_run=False, + strategy=DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo") + if torch.cuda.is_available() + else "auto", + precision=config["train"]["precision"], + logger=logger, + num_sanity_val_steps=0, + callbacks=[ckpt_callback], + use_distributed_sampler=False, # 非常简单的修改,但解决了采用自定义的 bucket_sampler 下训练步数不一致的问题! 
+ ) + + model: Text2SemanticLightningModule = Text2SemanticLightningModule(config, output_dir) + + data_module: Text2SemanticDataModule = Text2SemanticDataModule( + config, + train_semantic_path=config["train_semantic_path"], + train_phoneme_path=config["train_phoneme_path"], + # dev_semantic_path=args.dev_semantic_path, + # dev_phoneme_path=args.dev_phoneme_path + ) + + try: + # 使用正则表达式匹配文件名中的数字部分,并按数字大小进行排序 + newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir)) + ckpt_path = ckpt_dir / newest_ckpt_name + except Exception: + ckpt_path = None + print("ckpt_path:", ckpt_path) + trainer.fit(model, data_module, ckpt_path=ckpt_path) + + +# srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-c", + "--config_file", + type=str, + default="configs/s1longer.yaml", + help="path of config file", + ) + # args for dataset + # parser.add_argument('--train_semantic_path',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/6-name2semantic.tsv') + # parser.add_argument('--train_phoneme_path', type=str, default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/2-name2text.txt') + + # parser.add_argument('--dev_semantic_path', type=str, default='dump_mix/semantic_dev.tsv') + # parser.add_argument('--dev_phoneme_path', type=str, default='dump_mix/phoneme_dev.npy') + # parser.add_argument('--output_dir',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/logs_s1',help='directory to save the results') + # parser.add_argument('--output_dir',type=str,default='/liujing04/gpt_logs/s1/xuangou_ft',help='directory to save the results') + + args = parser.parse_args() + logging.info(str(args)) + main(args) diff --git a/GPT_SoVITS/s2_train.py b/GPT_SoVITS/s2_train.py new file mode 100644 index 0000000000000000000000000000000000000000..ab4611862ad1c82269030b9da899d2efd467e20e --- /dev/null +++ b/GPT_SoVITS/s2_train.py @@ -0,0 +1,680 @@ +import warnings + +warnings.filterwarnings("ignore") +import os + +import utils + +hps = utils.get_hparams(stage=2) +os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",") +import logging + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.cuda.amp import GradScaler, autocast +from torch.nn import functional as F +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm + +logging.getLogger("matplotlib").setLevel(logging.INFO) +logging.getLogger("h5py").setLevel(logging.INFO) +logging.getLogger("numba").setLevel(logging.INFO) +from random import randint + +from module import commons +from module.data_utils import ( + DistributedBucketSampler, + TextAudioSpeakerCollate, + TextAudioSpeakerLoader, +) +from module.losses import discriminator_loss, feature_loss, generator_loss, kl_loss +from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch +from module.models import ( + MultiPeriodDiscriminator, + SynthesizerTrn, +) +from process_ckpt import savee + +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = False +###反正A100fp32更快,那试试tf32吧 +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +torch.set_float32_matmul_precision("medium") # 最低精度但最快(也就快一丁点),对于结果造成不了影响 +# from config import pretrained_s2G,pretrained_s2D +global_step = 0 
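# --- Illustrative sketch (assumption, not part of the original patch) ------------
# savee() (imported above from process_ckpt) follows a simple export pattern:
# copy the generator state_dict to fp16, skip the train-time-only posterior
# encoder ("enc_q"), attach the hparams plus an epoch/step tag, and save.
# A minimal standalone version of that pattern, with placeholder names:
from collections import OrderedDict as _OrderedDict


def _export_half_weights_sketch(state_dict, hps_config, path, tag):
    opt = _OrderedDict(weight={}, config=hps_config, info=tag)
    for key, tensor in state_dict.items():
        if "enc_q" in key:  # only needed during training, not exported
            continue
        opt["weight"][key] = tensor.half()
    torch.save(opt, path)  # the real savee() additionally routes through my_save/my_save2
# ----------------------------------------------------------------------------------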
+ +device = "cpu" # cuda以外的设备,等mps优化后加入 + + +def main(): + if torch.cuda.is_available(): + n_gpus = torch.cuda.device_count() + else: + n_gpus = 1 + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(randint(20000, 55555)) + + mp.spawn( + run, + nprocs=n_gpus, + args=( + n_gpus, + hps, + ), + ) + + +def run(rank, n_gpus, hps): + global global_step + if rank == 0: + logger = utils.get_logger(hps.data.exp_dir) + logger.info(hps) + # utils.check_git_hash(hps.s2_ckpt_dir) + writer = SummaryWriter(log_dir=hps.s2_ckpt_dir) + writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval")) + + dist.init_process_group( + backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", + init_method="env://?use_libuv=False", + world_size=n_gpus, + rank=rank, + ) + torch.manual_seed(hps.train.seed) + if torch.cuda.is_available(): + torch.cuda.set_device(rank) + + train_dataset = TextAudioSpeakerLoader(hps.data) ######## + train_sampler = DistributedBucketSampler( + train_dataset, + hps.train.batch_size, + [ + 32, + 300, + 400, + 500, + 600, + 700, + 800, + 900, + 1000, + 1100, + 1200, + 1300, + 1400, + 1500, + 1600, + 1700, + 1800, + 1900, + ], + num_replicas=n_gpus, + rank=rank, + shuffle=True, + ) + collate_fn = TextAudioSpeakerCollate() + train_loader = DataLoader( + train_dataset, + num_workers=6, + shuffle=False, + pin_memory=True, + collate_fn=collate_fn, + batch_sampler=train_sampler, + persistent_workers=True, + prefetch_factor=4, + ) + # if rank == 0: + # eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True) + # eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False, + # batch_size=1, pin_memory=True, + # drop_last=False, collate_fn=collate_fn) + + net_g = ( + SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ).cuda(rank) + if torch.cuda.is_available() + else SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ).to(device) + ) + + net_d = ( + MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) + if torch.cuda.is_available() + else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device) + ) + for name, param in net_g.named_parameters(): + if not param.requires_grad: + print(name, "not requires_grad") + + te_p = list(map(id, net_g.enc_p.text_embedding.parameters())) + et_p = list(map(id, net_g.enc_p.encoder_text.parameters())) + mrte_p = list(map(id, net_g.enc_p.mrte.parameters())) + base_params = filter( + lambda p: id(p) not in te_p + et_p + mrte_p and p.requires_grad, + net_g.parameters(), + ) + + # te_p=net_g.enc_p.text_embedding.parameters() + # et_p=net_g.enc_p.encoder_text.parameters() + # mrte_p=net_g.enc_p.mrte.parameters() + + optim_g = torch.optim.AdamW( + # filter(lambda p: p.requires_grad, net_g.parameters()),###默认所有层lr一致 + [ + {"params": base_params, "lr": hps.train.learning_rate}, + { + "params": net_g.enc_p.text_embedding.parameters(), + "lr": hps.train.learning_rate * hps.train.text_low_lr_rate, + }, + { + "params": net_g.enc_p.encoder_text.parameters(), + "lr": hps.train.learning_rate * hps.train.text_low_lr_rate, + }, + { + "params": net_g.enc_p.mrte.parameters(), + "lr": hps.train.learning_rate * hps.train.text_low_lr_rate, + }, + ], + hps.train.learning_rate, + betas=hps.train.betas, + eps=hps.train.eps, + ) + optim_d = torch.optim.AdamW( + net_d.parameters(), + 
hps.train.learning_rate, + betas=hps.train.betas, + eps=hps.train.eps, + ) + if torch.cuda.is_available(): + net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True) + net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True) + else: + net_g = net_g.to(device) + net_d = net_d.to(device) + + try: # 如果能加载自动resume + _, _, _, epoch_str = utils.load_checkpoint( + utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "D_*.pth"), + net_d, + optim_d, + ) # D多半加载没事 + if rank == 0: + logger.info("loaded D") + # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0) + _, _, _, epoch_str = utils.load_checkpoint( + utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "G_*.pth"), + net_g, + optim_g, + ) + epoch_str += 1 + global_step = (epoch_str - 1) * len(train_loader) + # epoch_str = 1 + # global_step = 0 + except: # 如果首次不能加载,加载pretrain + # traceback.print_exc() + epoch_str = 1 + global_step = 0 + if ( + hps.train.pretrained_s2G != "" + and hps.train.pretrained_s2G != None + and os.path.exists(hps.train.pretrained_s2G) + ): + if rank == 0: + logger.info("loaded pretrained %s" % hps.train.pretrained_s2G) + print( + "loaded pretrained %s" % hps.train.pretrained_s2G, + net_g.module.load_state_dict( + torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"], + strict=False, + ) + if torch.cuda.is_available() + else net_g.load_state_dict( + torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"], + strict=False, + ), + ) ##测试不加载优化器 + if ( + hps.train.pretrained_s2D != "" + and hps.train.pretrained_s2D != None + and os.path.exists(hps.train.pretrained_s2D) + ): + if rank == 0: + logger.info("loaded pretrained %s" % hps.train.pretrained_s2D) + print( + "loaded pretrained %s" % hps.train.pretrained_s2D, + net_d.module.load_state_dict( + torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"], + ) + if torch.cuda.is_available() + else net_d.load_state_dict( + torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"], + ), + ) + + # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) + # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR( + optim_g, + gamma=hps.train.lr_decay, + last_epoch=-1, + ) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR( + optim_d, + gamma=hps.train.lr_decay, + last_epoch=-1, + ) + for _ in range(epoch_str): + scheduler_g.step() + scheduler_d.step() + + scaler = GradScaler(enabled=hps.train.fp16_run) + + print("start training from epoch %s" % epoch_str) + for epoch in range(epoch_str, hps.train.epochs + 1): + if rank == 0: + train_and_evaluate( + rank, + epoch, + hps, + [net_g, net_d], + [optim_g, optim_d], + [scheduler_g, scheduler_d], + scaler, + # [train_loader, eval_loader], logger, [writer, writer_eval]) + [train_loader, None], + logger, + [writer, writer_eval], + ) + else: + train_and_evaluate( + rank, + epoch, + hps, + [net_g, net_d], + [optim_g, optim_d], + [scheduler_g, scheduler_d], + scaler, + [train_loader, None], + None, + None, + ) + scheduler_g.step() + scheduler_d.step() + print("training done") + + +def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers): + net_g, net_d = nets + optim_g, optim_d = optims + # scheduler_g, scheduler_d = schedulers + 
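+    # One training epoch: for each batch, update the discriminator first (adversarial loss on real
+    # vs. generated slices), then the generator (adversarial + feature-matching + mel L1 + KL terms),
+    # both under autocast/GradScaler when fp16_run is enabled. On rank 0, scalars and spectrograms
+    # are written to TensorBoard, and G_*.pth / D_*.pth checkpoints are saved every save_every_epoch epochs.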
train_loader, eval_loader = loaders + if writers is not None: + writer, writer_eval = writers + + train_loader.batch_sampler.set_epoch(epoch) + global global_step + + net_g.train() + net_d.train() + for batch_idx, ( + ssl, + ssl_lengths, + spec, + spec_lengths, + y, + y_lengths, + text, + text_lengths, + ) in enumerate(tqdm(train_loader)): + if torch.cuda.is_available(): + spec, spec_lengths = ( + spec.cuda( + rank, + non_blocking=True, + ), + spec_lengths.cuda( + rank, + non_blocking=True, + ), + ) + y, y_lengths = ( + y.cuda( + rank, + non_blocking=True, + ), + y_lengths.cuda( + rank, + non_blocking=True, + ), + ) + ssl = ssl.cuda(rank, non_blocking=True) + ssl.requires_grad = False + # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True) + text, text_lengths = ( + text.cuda( + rank, + non_blocking=True, + ), + text_lengths.cuda( + rank, + non_blocking=True, + ), + ) + else: + spec, spec_lengths = spec.to(device), spec_lengths.to(device) + y, y_lengths = y.to(device), y_lengths.to(device) + ssl = ssl.to(device) + ssl.requires_grad = False + # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True) + text, text_lengths = text.to(device), text_lengths.to(device) + + with autocast(enabled=hps.train.fp16_run): + ( + y_hat, + kl_ssl, + ids_slice, + x_mask, + z_mask, + (z, z_p, m_p, logs_p, m_q, logs_q), + stats_ssl, + ) = net_g(ssl, spec, spec_lengths, text, text_lengths) + + mel = spec_to_mel_torch( + spec, + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.mel_fmin, + hps.data.mel_fmax, + ) + y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length) + y_hat_mel = mel_spectrogram_torch( + y_hat.squeeze(1), + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + hps.data.mel_fmin, + hps.data.mel_fmax, + ) + + y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice + + # Discriminator + y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) + with autocast(enabled=False): + loss_disc, losses_disc_r, losses_disc_g = discriminator_loss( + y_d_hat_r, + y_d_hat_g, + ) + loss_disc_all = loss_disc + optim_d.zero_grad() + scaler.scale(loss_disc_all).backward() + scaler.unscale_(optim_d) + grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None) + scaler.step(optim_d) + + with autocast(enabled=hps.train.fp16_run): + # Generator + y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat) + with autocast(enabled=False): + loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel + loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl + + loss_fm = feature_loss(fmap_r, fmap_g) + loss_gen, losses_gen = generator_loss(y_d_hat_g) + loss_gen_all = loss_gen + loss_fm + loss_mel + kl_ssl * 1 + loss_kl + + optim_g.zero_grad() + scaler.scale(loss_gen_all).backward() + scaler.unscale_(optim_g) + grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) + scaler.step(optim_g) + scaler.update() + + if rank == 0: + if global_step % hps.train.log_interval == 0: + lr = optim_g.param_groups[0]["lr"] + losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl] + logger.info( + "Train Epoch: {} [{:.0f}%]".format( + epoch, + 100.0 * batch_idx / len(train_loader), + ) + ) + logger.info([x.item() for x in losses] + [global_step, lr]) + + scalar_dict = { + "loss/g/total": loss_gen_all, + "loss/d/total": loss_disc_all, + "learning_rate": lr, + "grad_norm_d": grad_norm_d, + "grad_norm_g": grad_norm_g, + } + 
scalar_dict.update( + { + "loss/g/fm": loss_fm, + "loss/g/mel": loss_mel, + "loss/g/kl_ssl": kl_ssl, + "loss/g/kl": loss_kl, + } + ) + + # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}) + # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}) + # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}) + image_dict = None + try: ###Some people installed the wrong version of matplotlib. + image_dict = { + "slice/mel_org": utils.plot_spectrogram_to_numpy( + y_mel[0].data.cpu().numpy(), + ), + "slice/mel_gen": utils.plot_spectrogram_to_numpy( + y_hat_mel[0].data.cpu().numpy(), + ), + "all/mel": utils.plot_spectrogram_to_numpy( + mel[0].data.cpu().numpy(), + ), + "all/stats_ssl": utils.plot_spectrogram_to_numpy( + stats_ssl[0].data.cpu().numpy(), + ), + } + except: + pass + if image_dict: + utils.summarize( + writer=writer, + global_step=global_step, + images=image_dict, + scalars=scalar_dict, + ) + else: + utils.summarize( + writer=writer, + global_step=global_step, + scalars=scalar_dict, + ) + global_step += 1 + if epoch % hps.train.save_every_epoch == 0 and rank == 0: + if hps.train.if_save_latest == 0: + utils.save_checkpoint( + net_g, + optim_g, + hps.train.learning_rate, + epoch, + os.path.join( + "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), + "G_{}.pth".format(global_step), + ), + ) + utils.save_checkpoint( + net_d, + optim_d, + hps.train.learning_rate, + epoch, + os.path.join( + "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), + "D_{}.pth".format(global_step), + ), + ) + else: + utils.save_checkpoint( + net_g, + optim_g, + hps.train.learning_rate, + epoch, + os.path.join( + "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), + "G_{}.pth".format(233333333333), + ), + ) + utils.save_checkpoint( + net_d, + optim_d, + hps.train.learning_rate, + epoch, + os.path.join( + "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), + "D_{}.pth".format(233333333333), + ), + ) + if rank == 0 and hps.train.if_save_every_weights == True: + if hasattr(net_g, "module"): + ckpt = net_g.module.state_dict() + else: + ckpt = net_g.state_dict() + logger.info( + "saving ckpt %s_e%s:%s" + % ( + hps.name, + epoch, + savee( + ckpt, + hps.name + "_e%s_s%s" % (epoch, global_step), + epoch, + global_step, + hps, + ), + ) + ) + + if rank == 0: + logger.info("====> Epoch: {}".format(epoch)) + + +def evaluate(hps, generator, eval_loader, writer_eval): + generator.eval() + image_dict = {} + audio_dict = {} + print("Evaluating ...") + with torch.no_grad(): + for batch_idx, ( + ssl, + ssl_lengths, + spec, + spec_lengths, + y, + y_lengths, + text, + text_lengths, + ) in enumerate(eval_loader): + print(111) + if torch.cuda.is_available(): + spec, spec_lengths = spec.cuda(), spec_lengths.cuda() + y, y_lengths = y.cuda(), y_lengths.cuda() + ssl = ssl.cuda() + text, text_lengths = text.cuda(), text_lengths.cuda() + else: + spec, spec_lengths = spec.to(device), spec_lengths.to(device) + y, y_lengths = y.to(device), y_lengths.to(device) + ssl = ssl.to(device) + text, text_lengths = text.to(device), text_lengths.to(device) + for test in [0, 1]: + y_hat, mask, *_ = ( + generator.module.infer( + ssl, + spec, + spec_lengths, + text, + text_lengths, + test=test, + ) + if torch.cuda.is_available() + else generator.infer( + ssl, + spec, + spec_lengths, + text, + text_lengths, + test=test, + ) + ) + y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length + + mel = spec_to_mel_torch( + spec, + 
hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.mel_fmin, + hps.data.mel_fmax, + ) + y_hat_mel = mel_spectrogram_torch( + y_hat.squeeze(1).float(), + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + hps.data.mel_fmin, + hps.data.mel_fmax, + ) + image_dict.update( + { + f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy( + y_hat_mel[0].cpu().numpy(), + ), + } + ) + audio_dict.update( + { + f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]], + }, + ) + image_dict.update( + { + f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy()), + }, + ) + audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]}) + + # y_hat, mask, *_ = generator.module.infer(ssl, spec_lengths, speakers, y=None) + # audio_dict.update({ + # f"gen/audio_{batch_idx}_style_pred": y_hat[0, :, :] + # }) + + utils.summarize( + writer=writer_eval, + global_step=global_step, + images=image_dict, + audios=audio_dict, + audio_sampling_rate=hps.data.sampling_rate, + ) + generator.train() + + +if __name__ == "__main__": + main() diff --git a/GPT_SoVITS/s2_train_v3.py b/GPT_SoVITS/s2_train_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..71d21967eb252faa37a8f55578705b0ae9fe19ba --- /dev/null +++ b/GPT_SoVITS/s2_train_v3.py @@ -0,0 +1,467 @@ +import warnings + +warnings.filterwarnings("ignore") +import os + +import utils + +hps = utils.get_hparams(stage=2) +os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",") +import logging + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm + +logging.getLogger("matplotlib").setLevel(logging.INFO) +logging.getLogger("h5py").setLevel(logging.INFO) +logging.getLogger("numba").setLevel(logging.INFO) +from random import randint + +from module import commons +from module.data_utils import ( + DistributedBucketSampler, +) +from module.data_utils import ( + TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate, +) +from module.data_utils import ( + TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader, +) +from module.models import ( + SynthesizerTrnV3 as SynthesizerTrn, +) +from process_ckpt import savee + +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = False +###反正A100fp32更快,那试试tf32吧 +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +torch.set_float32_matmul_precision("medium") # 最低精度但最快(也就快一丁点),对于结果造成不了影响 +# from config import pretrained_s2G,pretrained_s2D +global_step = 0 + +device = "cpu" # cuda以外的设备,等mps优化后加入 + + +def main(): + if torch.cuda.is_available(): + n_gpus = torch.cuda.device_count() + else: + n_gpus = 1 + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(randint(20000, 55555)) + + mp.spawn( + run, + nprocs=n_gpus, + args=( + n_gpus, + hps, + ), + ) + + +def run(rank, n_gpus, hps): + global global_step + if rank == 0: + logger = utils.get_logger(hps.data.exp_dir) + logger.info(hps) + # utils.check_git_hash(hps.s2_ckpt_dir) + writer = SummaryWriter(log_dir=hps.s2_ckpt_dir) + writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval")) + + dist.init_process_group( + backend="gloo" if os.name == "nt" or not 
torch.cuda.is_available() else "nccl", + init_method="env://?use_libuv=False", + world_size=n_gpus, + rank=rank, + ) + torch.manual_seed(hps.train.seed) + if torch.cuda.is_available(): + torch.cuda.set_device(rank) + + train_dataset = TextAudioSpeakerLoader(hps.data) ######## + train_sampler = DistributedBucketSampler( + train_dataset, + hps.train.batch_size, + [ + 32, + 300, + 400, + 500, + 600, + 700, + 800, + 900, + 1000, + # 1100, + # 1200, + # 1300, + # 1400, + # 1500, + # 1600, + # 1700, + # 1800, + # 1900, + ], + num_replicas=n_gpus, + rank=rank, + shuffle=True, + ) + collate_fn = TextAudioSpeakerCollate() + train_loader = DataLoader( + train_dataset, + num_workers=6, + shuffle=False, + pin_memory=True, + collate_fn=collate_fn, + batch_sampler=train_sampler, + persistent_workers=True, + prefetch_factor=4, + ) + # if rank == 0: + # eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True) + # eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False, + # batch_size=1, pin_memory=True, + # drop_last=False, collate_fn=collate_fn) + + net_g = ( + SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ).cuda(rank) + if torch.cuda.is_available() + else SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ).to(device) + ) + + # net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device) + # for name, param in net_g.named_parameters(): + # if not param.requires_grad: + # print(name, "not requires_grad") + + optim_g = torch.optim.AdamW( + filter(lambda p: p.requires_grad, net_g.parameters()), ###默认所有层lr一致 + hps.train.learning_rate, + betas=hps.train.betas, + eps=hps.train.eps, + ) + # optim_d = torch.optim.AdamW( + # net_d.parameters(), + # hps.train.learning_rate, + # betas=hps.train.betas, + # eps=hps.train.eps, + # ) + if torch.cuda.is_available(): + net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True) + # net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True) + else: + net_g = net_g.to(device) + # net_d = net_d.to(device) + + try: # 如果能加载自动resume + # _, _, _, epoch_str = utils.load_checkpoint( + # utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_*.pth"), + # net_d, + # optim_d, + # ) # D多半加载没事 + # if rank == 0: + # logger.info("loaded D") + # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0) + _, _, _, epoch_str = utils.load_checkpoint( + utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "G_*.pth"), + net_g, + optim_g, + ) + epoch_str += 1 + global_step = (epoch_str - 1) * len(train_loader) + # epoch_str = 1 + # global_step = 0 + except: # 如果首次不能加载,加载pretrain + # traceback.print_exc() + epoch_str = 1 + global_step = 0 + if ( + hps.train.pretrained_s2G != "" + and hps.train.pretrained_s2G != None + and os.path.exists(hps.train.pretrained_s2G) + ): + if rank == 0: + logger.info("loaded pretrained %s" % hps.train.pretrained_s2G) + print( + "loaded pretrained %s" % hps.train.pretrained_s2G, + net_g.module.load_state_dict( + torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"], + strict=False, + ) + if torch.cuda.is_available() + else net_g.load_state_dict( + 
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"], + strict=False, + ), + ) ##测试不加载优化器 + # if hps.train.pretrained_s2D != ""and hps.train.pretrained_s2D != None and os.path.exists(hps.train.pretrained_s2D): + # if rank == 0: + # logger.info("loaded pretrained %s" % hps.train.pretrained_s2D) + # print( + # net_d.module.load_state_dict( + # torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"] + # ) if torch.cuda.is_available() else net_d.load_state_dict( + # torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"] + # ) + # ) + + # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) + # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1) + # scheduler_d = torch.optim.lr_scheduler.ExponentialLR( + # optim_d, gamma=hps.train.lr_decay, last_epoch=-1 + # ) + for _ in range(epoch_str): + scheduler_g.step() + # scheduler_d.step() + + scaler = GradScaler(enabled=hps.train.fp16_run) + + net_d = optim_d = scheduler_d = None + print("start training from epoch %s" % epoch_str) + for epoch in range(epoch_str, hps.train.epochs + 1): + if rank == 0: + train_and_evaluate( + rank, + epoch, + hps, + [net_g, net_d], + [optim_g, optim_d], + [scheduler_g, scheduler_d], + scaler, + # [train_loader, eval_loader], logger, [writer, writer_eval]) + [train_loader, None], + logger, + [writer, writer_eval], + ) + else: + train_and_evaluate( + rank, + epoch, + hps, + [net_g, net_d], + [optim_g, optim_d], + [scheduler_g, scheduler_d], + scaler, + [train_loader, None], + None, + None, + ) + scheduler_g.step() + # scheduler_d.step() + print("training done") + + +def train_and_evaluate( + rank, + epoch, + hps, + nets, + optims, + schedulers, + scaler, + loaders, + logger, + writers, +): + net_g, net_d = nets + optim_g, optim_d = optims + # scheduler_g, scheduler_d = schedulers + train_loader, eval_loader = loaders + if writers is not None: + writer, writer_eval = writers + + train_loader.batch_sampler.set_epoch(epoch) + global global_step + + net_g.train() + # net_d.train() + # for batch_idx, ( + # ssl, + # ssl_lengths, + # spec, + # spec_lengths, + # y, + # y_lengths, + # text, + # text_lengths, + # ) in enumerate(tqdm(train_loader)): + for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate( + tqdm(train_loader) + ): + if torch.cuda.is_available(): + spec, spec_lengths = ( + spec.cuda( + rank, + non_blocking=True, + ), + spec_lengths.cuda( + rank, + non_blocking=True, + ), + ) + mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(rank, non_blocking=True) + ssl = ssl.cuda(rank, non_blocking=True) + ssl.requires_grad = False + # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True) + text, text_lengths = ( + text.cuda( + rank, + non_blocking=True, + ), + text_lengths.cuda( + rank, + non_blocking=True, + ), + ) + else: + spec, spec_lengths = spec.to(device), spec_lengths.to(device) + mel, mel_lengths = mel.to(device), mel_lengths.to(device) + ssl = ssl.to(device) + ssl.requires_grad = False + # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True) + text, text_lengths = text.to(device), text_lengths.to(device) + + with autocast(enabled=hps.train.fp16_run): + cfm_loss = net_g( + ssl, + spec, + mel, + ssl_lengths, + spec_lengths, + text, + text_lengths, + mel_lengths, + 
use_grad_ckpt=hps.train.grad_ckpt, + ) + loss_gen_all = cfm_loss + optim_g.zero_grad() + scaler.scale(loss_gen_all).backward() + scaler.unscale_(optim_g) + grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) + scaler.step(optim_g) + scaler.update() + + if rank == 0: + if global_step % hps.train.log_interval == 0: + lr = optim_g.param_groups[0]["lr"] + # losses = [commit_loss,cfm_loss,mel_loss,loss_disc, loss_gen, loss_fm, loss_mel, loss_kl] + losses = [cfm_loss] + logger.info( + "Train Epoch: {} [{:.0f}%]".format( + epoch, + 100.0 * batch_idx / len(train_loader), + ) + ) + logger.info([x.item() for x in losses] + [global_step, lr]) + + scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g} + # image_dict = { + # "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), + # "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), + # "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()), + # "all/stats_ssl": utils.plot_spectrogram_to_numpy(stats_ssl[0].data.cpu().numpy()), + # } + utils.summarize( + writer=writer, + global_step=global_step, + # images=image_dict, + scalars=scalar_dict, + ) + + # if global_step % hps.train.eval_interval == 0: + # # evaluate(hps, net_g, eval_loader, writer_eval) + # utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch,os.path.join(hps.s2_ckpt_dir, "G_{}.pth".format(global_step)),scaler) + # # utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch,os.path.join(hps.s2_ckpt_dir, "D_{}.pth".format(global_step)),scaler) + # # keep_ckpts = getattr(hps.train, 'keep_ckpts', 3) + # # if keep_ckpts > 0: + # # utils.clean_checkpoints(path_to_models=hps.s2_ckpt_dir, n_ckpts_to_keep=keep_ckpts, sort_by_time=True) + + global_step += 1 + if epoch % hps.train.save_every_epoch == 0 and rank == 0: + if hps.train.if_save_latest == 0: + utils.save_checkpoint( + net_g, + optim_g, + hps.train.learning_rate, + epoch, + os.path.join( + "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), + "G_{}.pth".format(global_step), + ), + ) + # utils.save_checkpoint( + # net_d, + # optim_d, + # hps.train.learning_rate, + # epoch, + # os.path.join( + # "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(global_step) + # ), + # ) + else: + utils.save_checkpoint( + net_g, + optim_g, + hps.train.learning_rate, + epoch, + os.path.join( + "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), + "G_{}.pth".format(233333333333), + ), + ) + # utils.save_checkpoint( + # net_d, + # optim_d, + # hps.train.learning_rate, + # epoch, + # os.path.join( + # "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(233333333333) + # ), + # ) + if rank == 0 and hps.train.if_save_every_weights == True: + if hasattr(net_g, "module"): + ckpt = net_g.module.state_dict() + else: + ckpt = net_g.state_dict() + logger.info( + "saving ckpt %s_e%s:%s" + % ( + hps.name, + epoch, + savee( + ckpt, + hps.name + "_e%s_s%s" % (epoch, global_step), + epoch, + global_step, + hps, + ), + ) + ) + + if rank == 0: + logger.info("====> Epoch: {}".format(epoch)) + + +if __name__ == "__main__": + main() diff --git a/GPT_SoVITS/s2_train_v3_lora.py b/GPT_SoVITS/s2_train_v3_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..ddeec4fcb0e046ca0d7a38cc11298b6791417b0b --- /dev/null +++ b/GPT_SoVITS/s2_train_v3_lora.py @@ -0,0 +1,379 @@ +import warnings + +warnings.filterwarnings("ignore") +import os + +import utils + +hps = 
utils.get_hparams(stage=2) +os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",") +import logging + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm + +logging.getLogger("matplotlib").setLevel(logging.INFO) +logging.getLogger("h5py").setLevel(logging.INFO) +logging.getLogger("numba").setLevel(logging.INFO) +from collections import OrderedDict as od +from random import randint + +from module import commons +from module.data_utils import ( + DistributedBucketSampler, + TextAudioSpeakerCollateV3, + TextAudioSpeakerLoaderV3, + TextAudioSpeakerCollateV4, + TextAudioSpeakerLoaderV4, + +) +from module.models import ( + SynthesizerTrnV3 as SynthesizerTrn, +) +from peft import LoraConfig, get_peft_model +from process_ckpt import savee + +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = False +###反正A100fp32更快,那试试tf32吧 +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +torch.set_float32_matmul_precision("medium") # 最低精度但最快(也就快一丁点),对于结果造成不了影响 +# from config import pretrained_s2G,pretrained_s2D +global_step = 0 + +device = "cpu" # cuda以外的设备,等mps优化后加入 + + +def main(): + if torch.cuda.is_available(): + n_gpus = torch.cuda.device_count() + else: + n_gpus = 1 + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(randint(20000, 55555)) + + mp.spawn( + run, + nprocs=n_gpus, + args=( + n_gpus, + hps, + ), + ) + + +def run(rank, n_gpus, hps): + global global_step, no_grad_names, save_root, lora_rank + if rank == 0: + logger = utils.get_logger(hps.data.exp_dir) + logger.info(hps) + # utils.check_git_hash(hps.s2_ckpt_dir) + writer = SummaryWriter(log_dir=hps.s2_ckpt_dir) + writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval")) + + dist.init_process_group( + backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", + init_method="env://?use_libuv=False", + world_size=n_gpus, + rank=rank, + ) + torch.manual_seed(hps.train.seed) + if torch.cuda.is_available(): + torch.cuda.set_device(rank) + + TextAudioSpeakerLoader=TextAudioSpeakerLoaderV3 if hps.model.version=="v3"else TextAudioSpeakerLoaderV4 + TextAudioSpeakerCollate=TextAudioSpeakerCollateV3 if hps.model.version=="v3"else TextAudioSpeakerCollateV4 + train_dataset = TextAudioSpeakerLoader(hps.data) ######## + train_sampler = DistributedBucketSampler( + train_dataset, + hps.train.batch_size, + [ + 32, + 300, + 400, + 500, + 600, + 700, + 800, + 900, + 1000, + # 1100, + # 1200, + # 1300, + # 1400, + # 1500, + # 1600, + # 1700, + # 1800, + # 1900, + ], + num_replicas=n_gpus, + rank=rank, + shuffle=True, + ) + collate_fn = TextAudioSpeakerCollate() + train_loader = DataLoader( + train_dataset, + num_workers=6, + shuffle=False, + pin_memory=True, + collate_fn=collate_fn, + batch_sampler=train_sampler, + persistent_workers=True, + prefetch_factor=4, + ) + save_root = "%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir, hps.model.version, hps.train.lora_rank) + os.makedirs(save_root, exist_ok=True) + lora_rank = int(hps.train.lora_rank) + lora_config = LoraConfig( + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + r=lora_rank, + lora_alpha=lora_rank, + init_lora_weights=True, + ) + + def get_model(hps): + return SynthesizerTrn( + hps.data.filter_length // 2 + 1, + 
hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ) + + def get_optim(net_g): + return torch.optim.AdamW( + filter(lambda p: p.requires_grad, net_g.parameters()), ###默认所有层lr一致 + hps.train.learning_rate, + betas=hps.train.betas, + eps=hps.train.eps, + ) + + def model2cuda(net_g, rank): + if torch.cuda.is_available(): + net_g = DDP(net_g.cuda(rank), device_ids=[rank], find_unused_parameters=True) + else: + net_g = net_g.to(device) + return net_g + + try: # 如果能加载自动resume + net_g = get_model(hps) + net_g.cfm = get_peft_model(net_g.cfm, lora_config) + net_g = model2cuda(net_g, rank) + optim_g = get_optim(net_g) + # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0) + _, _, _, epoch_str = utils.load_checkpoint( + utils.latest_checkpoint_path(save_root, "G_*.pth"), + net_g, + optim_g, + ) + epoch_str += 1 + global_step = (epoch_str - 1) * len(train_loader) + except: # 如果首次不能加载,加载pretrain + # traceback.print_exc() + epoch_str = 1 + global_step = 0 + net_g = get_model(hps) + if ( + hps.train.pretrained_s2G != "" + and hps.train.pretrained_s2G != None + and os.path.exists(hps.train.pretrained_s2G) + ): + if rank == 0: + logger.info("loaded pretrained %s" % hps.train.pretrained_s2G) + print( + "loaded pretrained %s" % hps.train.pretrained_s2G, + net_g.load_state_dict( + torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"], + strict=False, + ), + ) + net_g.cfm = get_peft_model(net_g.cfm, lora_config) + net_g = model2cuda(net_g, rank) + optim_g = get_optim(net_g) + + no_grad_names = set() + for name, param in net_g.named_parameters(): + if not param.requires_grad: + no_grad_names.add(name.replace("module.", "")) + # print(name, "not requires_grad") + # print(no_grad_names) + # os._exit(233333) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1) + for _ in range(epoch_str): + scheduler_g.step() + + scaler = GradScaler(enabled=hps.train.fp16_run) + + net_d = optim_d = scheduler_d = None + print("start training from epoch %s" % epoch_str) + for epoch in range(epoch_str, hps.train.epochs + 1): + if rank == 0: + train_and_evaluate( + rank, + epoch, + hps, + [net_g, net_d], + [optim_g, optim_d], + [scheduler_g, scheduler_d], + scaler, + # [train_loader, eval_loader], logger, [writer, writer_eval]) + [train_loader, None], + logger, + [writer, writer_eval], + ) + else: + train_and_evaluate( + rank, + epoch, + hps, + [net_g, net_d], + [optim_g, optim_d], + [scheduler_g, scheduler_d], + scaler, + [train_loader, None], + None, + None, + ) + scheduler_g.step() + print("training done") + + +def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers): + net_g, net_d = nets + optim_g, optim_d = optims + # scheduler_g, scheduler_d = schedulers + train_loader, eval_loader = loaders + if writers is not None: + writer, writer_eval = writers + + train_loader.batch_sampler.set_epoch(epoch) + global global_step + + net_g.train() + for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate( + tqdm(train_loader) + ): + if torch.cuda.is_available(): + spec, spec_lengths = ( + spec.cuda( + rank, + non_blocking=True, + ), + spec_lengths.cuda( + rank, + non_blocking=True, + ), + ) + mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(rank, non_blocking=True) + ssl = ssl.cuda(rank, non_blocking=True) + ssl.requires_grad = False + text, 
text_lengths = ( + text.cuda( + rank, + non_blocking=True, + ), + text_lengths.cuda( + rank, + non_blocking=True, + ), + ) + else: + spec, spec_lengths = spec.to(device), spec_lengths.to(device) + mel, mel_lengths = mel.to(device), mel_lengths.to(device) + ssl = ssl.to(device) + ssl.requires_grad = False + text, text_lengths = text.to(device), text_lengths.to(device) + + with autocast(enabled=hps.train.fp16_run): + cfm_loss = net_g( + ssl, + spec, + mel, + ssl_lengths, + spec_lengths, + text, + text_lengths, + mel_lengths, + use_grad_ckpt=hps.train.grad_ckpt, + ) + loss_gen_all = cfm_loss + optim_g.zero_grad() + scaler.scale(loss_gen_all).backward() + scaler.unscale_(optim_g) + grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) + scaler.step(optim_g) + scaler.update() + + if rank == 0: + if global_step % hps.train.log_interval == 0: + lr = optim_g.param_groups[0]["lr"] + losses = [cfm_loss] + logger.info("Train Epoch: {} [{:.0f}%]".format(epoch, 100.0 * batch_idx / len(train_loader))) + logger.info([x.item() for x in losses] + [global_step, lr]) + + scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g} + utils.summarize( + writer=writer, + global_step=global_step, + scalars=scalar_dict, + ) + + global_step += 1 + if epoch % hps.train.save_every_epoch == 0 and rank == 0: + if hps.train.if_save_latest == 0: + utils.save_checkpoint( + net_g, + optim_g, + hps.train.learning_rate, + epoch, + os.path.join(save_root, "G_{}.pth".format(global_step)), + ) + else: + utils.save_checkpoint( + net_g, + optim_g, + hps.train.learning_rate, + epoch, + os.path.join(save_root, "G_{}.pth".format(233333333333)), + ) + if rank == 0 and hps.train.if_save_every_weights == True: + if hasattr(net_g, "module"): + ckpt = net_g.module.state_dict() + else: + ckpt = net_g.state_dict() + sim_ckpt = od() + for key in ckpt: + # if "cfm"not in key: + # print(key) + if key not in no_grad_names: + sim_ckpt[key] = ckpt[key].half().cpu() + logger.info( + "saving ckpt %s_e%s:%s" + % ( + hps.name, + epoch, + savee( + sim_ckpt, + hps.name + "_e%s_s%s_l%s" % (epoch, global_step, lora_rank), + epoch, + global_step, + hps,cfm_version=hps.model.version, + lora_rank=lora_rank, + ), + ) + ) + + if rank == 0: + logger.info("====> Epoch: {}".format(epoch)) + + +if __name__ == "__main__": + main() diff --git a/GPT_SoVITS/utils.py b/GPT_SoVITS/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1cc2d97f1ccba11802be23cb38a7703740a4e0ec --- /dev/null +++ b/GPT_SoVITS/utils.py @@ -0,0 +1,361 @@ +import argparse +import glob +import json +import logging +import os +import subprocess +import sys +import traceback + +import librosa +import numpy as np +import torch + +logging.getLogger("numba").setLevel(logging.ERROR) +logging.getLogger("matplotlib").setLevel(logging.ERROR) + +MATPLOTLIB_FLAG = False + +logging.basicConfig(stream=sys.stdout, level=logging.ERROR) +logger = logging + + +def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False): + assert os.path.isfile(checkpoint_path) + checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") + iteration = checkpoint_dict["iteration"] + learning_rate = checkpoint_dict["learning_rate"] + if optimizer is not None and not skip_optimizer and checkpoint_dict["optimizer"] is not None: + optimizer.load_state_dict(checkpoint_dict["optimizer"]) + saved_state_dict = checkpoint_dict["model"] + if hasattr(model, "module"): + state_dict = model.module.state_dict() + else: + state_dict = 
model.state_dict() + new_state_dict = {} + for k, v in state_dict.items(): + try: + # assert "quantizer" not in k + # print("load", k) + new_state_dict[k] = saved_state_dict[k] + assert saved_state_dict[k].shape == v.shape, ( + saved_state_dict[k].shape, + v.shape, + ) + except: + traceback.print_exc() + print("error, %s is not in the checkpoint" % k) # shape不对也会,比如text_embedding当cleaner修改时 + new_state_dict[k] = v + if hasattr(model, "module"): + model.module.load_state_dict(new_state_dict) + else: + model.load_state_dict(new_state_dict) + print("load ") + logger.info( + "Loaded checkpoint '{}' (iteration {})".format( + checkpoint_path, + iteration, + ) + ) + return model, optimizer, learning_rate, iteration + + +import shutil +from time import time as ttime + + +def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path + dir = os.path.dirname(path) + name = os.path.basename(path) + tmp_path = "%s.pth" % (ttime()) + torch.save(fea, tmp_path) + shutil.move(tmp_path, "%s/%s" % (dir, name)) + + +def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): + logger.info("Saving model and optimizer state at iteration {} to {}".format(iteration, checkpoint_path)) + if hasattr(model, "module"): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + # torch.save( + my_save( + { + "model": state_dict, + "iteration": iteration, + "optimizer": optimizer.state_dict(), + "learning_rate": learning_rate, + }, + checkpoint_path, + ) + + +def summarize( + writer, + global_step, + scalars={}, + histograms={}, + images={}, + audios={}, + audio_sampling_rate=22050, +): + for k, v in scalars.items(): + writer.add_scalar(k, v, global_step) + for k, v in histograms.items(): + writer.add_histogram(k, v, global_step) + for k, v in images.items(): + writer.add_image(k, v, global_step, dataformats="HWC") + for k, v in audios.items(): + writer.add_audio(k, v, global_step, audio_sampling_rate) + + +def latest_checkpoint_path(dir_path, regex="G_*.pth"): + f_list = glob.glob(os.path.join(dir_path, regex)) + f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) + x = f_list[-1] + print(x) + return x + + +def plot_spectrogram_to_numpy(spectrogram): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger("matplotlib") + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data + + +def plot_alignment_to_numpy(alignment, info=None): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger("matplotlib") + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + + fig, ax = plt.subplots(figsize=(6, 4)) + im = ax.imshow( + alignment.transpose(), + aspect="auto", + origin="lower", + interpolation="none", + ) + fig.colorbar(im, ax=ax) + xlabel = "Decoder timestep" + if info is not None: + xlabel += "\n\n" + info + plt.xlabel(xlabel) + plt.ylabel("Encoder timestep") + plt.tight_layout() + + fig.canvas.draw() + data = 
np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data + + +def load_wav_to_torch(full_path): + data, sampling_rate = librosa.load(full_path, sr=None) + return torch.FloatTensor(data), sampling_rate + + +def load_filepaths_and_text(filename, split="|"): + with open(filename, encoding="utf-8") as f: + filepaths_and_text = [line.strip().split(split) for line in f] + return filepaths_and_text + + +def get_hparams(init=True, stage=1): + parser = argparse.ArgumentParser() + parser.add_argument( + "-c", + "--config", + type=str, + default="./configs/s2.json", + help="JSON file for configuration", + ) + parser.add_argument("-p", "--pretrain", type=str, required=False, default=None, help="pretrain dir") + parser.add_argument( + "-rs", + "--resume_step", + type=int, + required=False, + default=None, + help="resume step", + ) + # parser.add_argument('-e', '--exp_dir', type=str, required=False,default=None,help='experiment directory') + # parser.add_argument('-g', '--pretrained_s2G', type=str, required=False,default=None,help='pretrained sovits gererator weights') + # parser.add_argument('-d', '--pretrained_s2D', type=str, required=False,default=None,help='pretrained sovits discriminator weights') + + args = parser.parse_args() + + config_path = args.config + with open(config_path, "r") as f: + data = f.read() + config = json.loads(data) + + hparams = HParams(**config) + hparams.pretrain = args.pretrain + hparams.resume_step = args.resume_step + # hparams.data.exp_dir = args.exp_dir + if stage == 1: + model_dir = hparams.s1_ckpt_dir + else: + model_dir = hparams.s2_ckpt_dir + config_save_path = os.path.join(model_dir, "config.json") + + if not os.path.exists(model_dir): + os.makedirs(model_dir) + + with open(config_save_path, "w") as f: + f.write(data) + return hparams + + +def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True): + """Freeing up space by deleting saved ckpts + + Arguments: + path_to_models -- Path to the model directory + n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth + sort_by_time -- True -> chronologically delete ckpts + False -> lexicographically delete ckpts + """ + import re + + ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))] + name_key = lambda _f: int(re.compile("._(\d+)\.pth").match(_f).group(1)) + time_key = lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)) + sort_key = time_key if sort_by_time else name_key + x_sorted = lambda _x: sorted( + [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")], + key=sort_key, + ) + to_del = [ + os.path.join(path_to_models, fn) for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep]) + ] + del_info = lambda fn: logger.info(f".. 
Free up space by deleting ckpt {fn}") + del_routine = lambda x: [os.remove(x), del_info(x)] + rs = [del_routine(fn) for fn in to_del] + + +def get_hparams_from_dir(model_dir): + config_save_path = os.path.join(model_dir, "config.json") + with open(config_save_path, "r") as f: + data = f.read() + config = json.loads(data) + + hparams = HParams(**config) + hparams.model_dir = model_dir + return hparams + + +def get_hparams_from_file(config_path): + with open(config_path, "r") as f: + data = f.read() + config = json.loads(data) + + hparams = HParams(**config) + return hparams + + +def check_git_hash(model_dir): + source_dir = os.path.dirname(os.path.realpath(__file__)) + if not os.path.exists(os.path.join(source_dir, ".git")): + logger.warn( + "{} is not a git repository, therefore hash value comparison will be ignored.".format( + source_dir, + ) + ) + return + + cur_hash = subprocess.getoutput("git rev-parse HEAD") + + path = os.path.join(model_dir, "githash") + if os.path.exists(path): + saved_hash = open(path).read() + if saved_hash != cur_hash: + logger.warn( + "git hash values are different. {}(saved) != {}(current)".format( + saved_hash[:8], + cur_hash[:8], + ) + ) + else: + open(path, "w").write(cur_hash) + + +def get_logger(model_dir, filename="train.log"): + global logger + logger = logging.getLogger(os.path.basename(model_dir)) + logger.setLevel(logging.ERROR) + + formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") + if not os.path.exists(model_dir): + os.makedirs(model_dir) + h = logging.FileHandler(os.path.join(model_dir, filename)) + h.setLevel(logging.ERROR) + h.setFormatter(formatter) + logger.addHandler(h) + return logger + + +class HParams: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if type(v) == dict: + v = HParams(**v) + self[k] = v + + def keys(self): + return self.__dict__.keys() + + def items(self): + return self.__dict__.items() + + def values(self): + return self.__dict__.values() + + def __len__(self): + return len(self.__dict__) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + return setattr(self, key, value) + + def __contains__(self, key): + return key in self.__dict__ + + def __repr__(self): + return self.__dict__.__repr__() + + +if __name__ == "__main__": + print( + load_wav_to_torch( + "/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac", + ) + ) diff --git a/GPT_SoVITS_Inference.ipynb b/GPT_SoVITS_Inference.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..1b8ec64de5313603853be41c3d88e8f1fef9eb1c --- /dev/null +++ b/GPT_SoVITS_Inference.ipynb @@ -0,0 +1,153 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "himHYZmra7ix" + }, + "source": [ + "# Credits for bubarino giving me the huggingface import code (感谢 bubarino 给了我 huggingface 导入代码)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "e9b7iFV3dm1f" + }, + "outputs": [], + "source": [ + "!git clone https://github.com/RVC-Boss/GPT-SoVITS.git\n", + "%cd GPT-SoVITS\n", + "!apt-get update && apt-get install -y --no-install-recommends tzdata ffmpeg libsox-dev parallel aria2 git git-lfs && git lfs install\n", + "!pip install -r extra-req.txt --no-deps\n", + "!pip install -r requirements.txt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "0NgxXg5sjv7z" + }, + "outputs": [], + "source": [ + "# @title Download pretrained models 下载预训练模型\n", + "!mkdir 
-p /content/GPT-SoVITS/GPT_SoVITS/pretrained_models\n", + "!mkdir -p /content/GPT-SoVITS/tools/damo_asr/models\n", + "!mkdir -p /content/GPT-SoVITS/tools/uvr5\n", + "%cd /content/GPT-SoVITS/GPT_SoVITS/pretrained_models\n", + "!git clone https://huggingface.co/lj1995/GPT-SoVITS\n", + "%cd /content/GPT-SoVITS/tools/damo_asr/models\n", + "!git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch.git\n", + "!git clone https://www.modelscope.cn/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch.git\n", + "!git clone https://www.modelscope.cn/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch.git\n", + "# @title UVR5 pretrains 安装uvr5模型\n", + "%cd /content/GPT-SoVITS/tools/uvr5\n", + "!git clone https://huggingface.co/Delik/uvr5_weights\n", + "!git config core.sparseCheckout true\n", + "!mv /content/GPT-SoVITS/GPT_SoVITS/pretrained_models/GPT-SoVITS/* /content/GPT-SoVITS/GPT_SoVITS/pretrained_models/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "cPDEH-9czOJF" + }, + "outputs": [], + "source": [ + "#@title Create folder models 创建文件夹模型\n", + "import os\n", + "base_directory = \"/content/GPT-SoVITS\"\n", + "folder_names = [\"SoVITS_weights\", \"GPT_weights\"]\n", + "\n", + "for folder_name in folder_names:\n", + " if os.path.exists(os.path.join(base_directory, folder_name)):\n", + " print(f\"The folder '{folder_name}' already exists. (文件夹'{folder_name}'已经存在。)\")\n", + " else:\n", + " os.makedirs(os.path.join(base_directory, folder_name))\n", + " print(f\"The folder '{folder_name}' was created successfully! (文件夹'{folder_name}'已成功创建!)\")\n", + "\n", + "print(\"All folders have been created. (所有文件夹均已创建。)\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "vbZY-LnM0tzq" + }, + "outputs": [], + "source": [ + "import requests\n", + "import zipfile\n", + "import shutil\n", + "import os\n", + "\n", + "#@title Import model 导入模型 (HuggingFace)\n", + "hf_link = 'https://huggingface.co/modelloosrvcc/Nagisa_Shingetsu_GPT-SoVITS/resolve/main/Nagisa.zip' #@param {type: \"string\"}\n", + "\n", + "output_path = '/content/'\n", + "\n", + "response = requests.get(hf_link)\n", + "with open(output_path + 'file.zip', 'wb') as file:\n", + " file.write(response.content)\n", + "\n", + "with zipfile.ZipFile(output_path + 'file.zip', 'r') as zip_ref:\n", + " zip_ref.extractall(output_path)\n", + "\n", + "os.remove(output_path + \"file.zip\")\n", + "\n", + "source_directory = output_path\n", + "SoVITS_destination_directory = '/content/GPT-SoVITS/SoVITS_weights'\n", + "GPT_destination_directory = '/content/GPT-SoVITS/GPT_weights'\n", + "\n", + "for filename in os.listdir(source_directory):\n", + " if filename.endswith(\".pth\"):\n", + " source_path = os.path.join(source_directory, filename)\n", + " destination_path = os.path.join(SoVITS_destination_directory, filename)\n", + " shutil.move(source_path, destination_path)\n", + "\n", + "for filename in os.listdir(source_directory):\n", + " if filename.endswith(\".ckpt\"):\n", + " source_path = os.path.join(source_directory, filename)\n", + " destination_path = os.path.join(GPT_destination_directory, filename)\n", + " shutil.move(source_path, destination_path)\n", + "\n", + "print(f'Model downloaded. 
(模型已下载。)')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "4oRGUzkrk8C7" + }, + "outputs": [], + "source": [ + "# @title launch WebUI 启动WebUI\n", + "!/usr/local/bin/pip install ipykernel\n", + "!sed -i '10s/False/True/' /content/GPT-SoVITS/config.py\n", + "%cd /content/GPT-SoVITS/\n", + "!/usr/local/bin/python webui.py" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f1d3206c3b40e0bcdc73384635ab8c59d377c997 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 RVC-Boss + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 7be5fc7f47d5db027d120b8024982df93db95b74..463649acccc7935d67633a3df19a6986aa19c14d 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,409 @@ ---- -license: mit ---- +
+# GPT-SoVITS-WebUI
+
+A Powerful Few-shot Voice Conversion and Text-to-Speech WebUI.
+
+[![madewithlove](https://img.shields.io/badge/made_with-%E2%9D%A4-red?style=for-the-badge&labelColor=orange)](https://github.com/RVC-Boss/GPT-SoVITS)
+[![Open In Colab](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/colab_webui.ipynb)
+[![License](https://img.shields.io/badge/LICENSE-MIT-green.svg?style=for-the-badge)](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/LICENSE)
+[![Huggingface](https://img.shields.io/badge/🤗%20-online%20demo-yellow.svg?style=for-the-badge)](https://huggingface.co/spaces/lj1995/GPT-SoVITS-v2)
+[![Discord](https://img.shields.io/discord/1198701940511617164?color=%23738ADB&label=Discord&style=for-the-badge)](https://discord.gg/dnrgs5GHfG)
+
+**English** | [**中文简体**](./docs/cn/README.md) | [**日本語**](./docs/ja/README.md) | [**한국어**](./docs/ko/README.md) | [**Türkçe**](./docs/tr/README.md)
+ +--- + +## Features: + +1. **Zero-shot TTS:** Input a 5-second vocal sample and experience instant text-to-speech conversion. + +2. **Few-shot TTS:** Fine-tune the model with just 1 minute of training data for improved voice similarity and realism. + +3. **Cross-lingual Support:** Inference in languages different from the training dataset, currently supporting English, Japanese, Korean, Cantonese and Chinese. + +4. **WebUI Tools:** Integrated tools include voice accompaniment separation, automatic training set segmentation, Chinese ASR, and text labeling, assisting beginners in creating training datasets and GPT/SoVITS models. + +**Check out our [demo video](https://www.bilibili.com/video/BV12g4y1m7Uw) here!** + +Unseen speakers few-shot fine-tuning demo: + +https://github.com/RVC-Boss/GPT-SoVITS/assets/129054828/05bee1fa-bdd8-4d85-9350-80c060ab47fb + +**User guide: [简体中文](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e) | [English](https://rentry.co/GPT-SoVITS-guide#/)** + +## Installation + +For users in China, you can [click here](https://www.codewithgpu.com/i/RVC-Boss/GPT-SoVITS/GPT-SoVITS-Official) to use AutoDL Cloud Docker to experience the full functionality online. + +### Tested Environments + +| Python Version | PyTorch Version | Device | +|----------------|------------------|-----------------| +| Python 3.9 | PyTorch 2.0.1 | CUDA 11.8 | +| Python 3.10.13 | PyTorch 2.1.2 | CUDA 12.3 | +| Python 3.10.17 | PyTorch 2.5.1 | CUDA 12.4 | +| Python 3.9 | PyTorch 2.5.1 | Apple silicon | +| Python 3.11 | PyTorch 2.6.0 | Apple silicon | +| Python 3.9 | PyTorch 2.2.2 | CPU | +| Python 3.9 | PyTorch 2.8.0dev | CUDA12.8(for Nvidia50x0) | + +### Windows + +If you are a Windows user (tested with win>=10), you can [download the integrated package](https://huggingface.co/lj1995/GPT-SoVITS-windows-package/resolve/main/GPT-SoVITS-v3lora-20250228.7z?download=true) and double-click on _go-webui.bat_ to start GPT-SoVITS-WebUI. + +**Users in China can [download the package here](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e/dkxgpiy9zb96hob4#KTvnO).** + +### Linux + +```bash +conda create -n GPTSoVits python=3.9 +conda activate GPTSoVits +bash install.sh --source [--download-uvr5] +``` + +### macOS + +**Note: The models trained with GPUs on Macs result in significantly lower quality compared to those trained on other devices, so we are temporarily using CPUs instead.** + +1. Install Xcode command-line tools by running `xcode-select --install`. +2. Install the program by running the following commands: + +```bash +conda create -n GPTSoVits python=3.9 +conda activate GPTSoVits +bash install.sh --source [--download-uvr5] +``` + +### Install Manually + +#### Install FFmpeg + +##### Conda Users + +```bash +conda install ffmpeg +``` + +##### Ubuntu/Debian Users + +```bash +sudo apt install ffmpeg +sudo apt install libsox-dev +conda install -c conda-forge 'ffmpeg<7' +``` + +##### Windows Users + +Download and place [ffmpeg.exe](https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/ffmpeg.exe) and [ffprobe.exe](https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/ffprobe.exe) in the GPT-SoVITS root. + +Install [Visual Studio 2017](https://aka.ms/vs/17/release/vc_redist.x86.exe) (Korean TTS Only) + +##### MacOS Users + +```bash +brew install ffmpeg +``` + +#### Install Dependences + +```bash +pip install -r extra-req.txt --no-deps +pip install -r requirements.txt +``` + +### Using Docker + +#### docker-compose.yaml configuration + +0. 
Regarding image tags: Due to rapid updates in the codebase and the slow process of packaging and testing images, please check [Docker Hub](https://hub.docker.com/r/breakstring/gpt-sovits)(outdated) for the currently packaged latest images and select as per your situation, or alternatively, build locally using a Dockerfile according to your own needs. +1. Environment Variables: + - is_half: Controls half-precision/double-precision. This is typically the cause if the content under the directories 4-cnhubert/5-wav32k is not generated correctly during the "SSL extracting" step. Adjust to True or False based on your actual situation. +2. Volumes Configuration, The application's root directory inside the container is set to /workspace. The default docker-compose.yaml lists some practical examples for uploading/downloading content. +3. shm_size: The default available memory for Docker Desktop on Windows is too small, which can cause abnormal operations. Adjust according to your own situation. +4. Under the deploy section, GPU-related settings should be adjusted cautiously according to your system and actual circumstances. + +#### Running with docker compose + +``` +docker compose -f "docker-compose.yaml" up -d +``` + +#### Running with docker command + +As above, modify the corresponding parameters based on your actual situation, then run the following command: + +``` +docker run --rm -it --gpus=all --env=is_half=False --volume=G:\GPT-SoVITS-DockerTest\output:/workspace/output --volume=G:\GPT-SoVITS-DockerTest\logs:/workspace/logs --volume=G:\GPT-SoVITS-DockerTest\SoVITS_weights:/workspace/SoVITS_weights --workdir=/workspace -p 9880:9880 -p 9871:9871 -p 9872:9872 -p 9873:9873 -p 9874:9874 --shm-size="16G" -d breakstring/gpt-sovits:xxxxx +``` + +## Pretrained Models + +**If `install.sh` runs successfully, you may skip No.1,2,3** + +**Users in China can [download all these models here](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e/dkxgpiy9zb96hob4#nVNhX).** + +1. Download pretrained models from [GPT-SoVITS Models](https://huggingface.co/lj1995/GPT-SoVITS) and place them in `GPT_SoVITS/pretrained_models`. + +2. Download G2PW models from [G2PWModel.zip(HF)](https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/G2PWModel.zip)| [G2PWModel.zip(ModelScope)](https://www.modelscope.cn/models/XXXXRT/GPT-SoVITS-Pretrained/resolve/master/G2PWModel.zip), unzip and rename to `G2PWModel`, and then place them in `GPT_SoVITS/text`.(Chinese TTS Only) + +3. For UVR5 (Vocals/Accompaniment Separation & Reverberation Removal, additionally), download models from [UVR5 Weights](https://huggingface.co/lj1995/VoiceConversionWebUI/tree/main/uvr5_weights) and place them in `tools/uvr5/uvr5_weights`. + + - If you want to use `bs_roformer` or `mel_band_roformer` models for UVR5, you can manually download the model and corresponding configuration file, and put them in `tools/uvr5/uvr5_weights`. **Rename the model file and configuration file, ensure that the model and configuration files have the same and corresponding names except for the suffix**. In addition, the model and configuration file names **must include `roformer`** in order to be recognized as models of the roformer class. + + - The suggestion is to **directly specify the model type** in the model name and configuration file name, such as `mel_mand_roformer`, `bs_roformer`. If not specified, the features will be compared from the configuration file to determine which type of model it is. 
+## Finetune and Inference
+
+### Open WebUI
+
+#### Integrated Package Users
+
+Double-click `go-webui.bat` or use `go-webui.ps1`.
+If you want to switch to V1, double-click `go-webui-v1.bat` or use `go-webui-v1.ps1`.
+
+#### Others
+
+```bash
+python webui.py
+```
+
+If you want to switch to V1, run
+
+```bash
+python webui.py v1
+```
+
+or manually switch the version in the WebUI.
+
+### Finetune
+
+#### Path Auto-filling is now supported
+
+    1. Fill in the audio path
+    2. Slice the audio into small chunks
+    3. Denoise (optional)
+    4. Run ASR
+    5. Proofread the ASR transcriptions
+    6. Go to the next tab, then finetune the model
+
+### Open Inference WebUI
+
+#### Integrated Package Users
+
+Double-click `go-webui-v2.bat` or use `go-webui-v2.ps1`, then open the inference WebUI at `1-GPT-SoVITS-TTS/1C-inference`.
+
+#### Others
+
+```bash
+python GPT_SoVITS/inference_webui.py
+```
+
+OR
+
+```bash
+python webui.py
+```
+
+then open the inference WebUI at `1-GPT-SoVITS-TTS/1C-inference`.
+
+## V2 Release Notes
+
+New Features:
+
+1. Support Korean and Cantonese
+
+2. An optimized text frontend
+
+3. Pre-trained model extended from 2k hours to 5k hours
+
+4. Improved synthesis quality for low-quality reference audio
+
+   [more details]()
+
+Use v2 from the v1 environment:
+
+1. `pip install -r requirements.txt` to update some packages
+
+2. Clone the latest code from GitHub.
+
+3. Download v2 pretrained models from [huggingface](https://huggingface.co/lj1995/GPT-SoVITS/tree/main/gsv-v2final-pretrained) and put them into `GPT_SoVITS\pretrained_models\gsv-v2final-pretrained`.
+
+   Chinese v2 additional: [G2PWModel.zip(HF)](https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/G2PWModel.zip) | [G2PWModel.zip(ModelScope)](https://www.modelscope.cn/models/XXXXRT/GPT-SoVITS-Pretrained/resolve/master/G2PWModel.zip) (download the G2PW models, unzip and rename the folder to `G2PWModel`, then place it in `GPT_SoVITS/text`).
+
+## V3 Release Notes
+
+New Features:
+
+1. The timbre similarity is higher, requiring less training data to approximate the target speaker (the timbre similarity is significantly improved when using the base model directly without fine-tuning).
+
+2. The GPT model is more stable, with fewer repetitions and omissions, and it is easier to generate speech with richer emotional expression.
+
+   [more details]()
+
+Use v3 from the v2 environment:
+
+1. `pip install -r requirements.txt` to update some packages
+
+2. Clone the latest code from GitHub.
+
+3. Download v3 pretrained models (s1v3.ckpt, s2Gv3.pth and the models--nvidia--bigvgan_v2_24khz_100band_256x folder) from [huggingface](https://huggingface.co/lj1995/GPT-SoVITS/tree/main) and put them into `GPT_SoVITS\pretrained_models`.
+
+   Additional: for the audio super-resolution model, see [how to download](./tools/AP_BWE_main/24kto48k/readme.txt).
+
+## V4 Release Notes
+
+New Features:
+
+1. Version 4 fixes the issue of metallic artifacts in Version 3 caused by non-integer multiple upsampling, and natively outputs 48k audio to prevent muffled sound (whereas Version 3 only natively outputs 24k audio). The author considers Version 4 a direct replacement for Version 3, though further testing is still needed.
+   [more details]()
+
+Use v4 from the v1/v2/v3 environment:
+
+1. `pip install -r requirements.txt` to update some packages
+
+2. Clone the latest code from GitHub.
+
+3. Download v4 pretrained models (gsv-v4-pretrained/s2v4.ckpt and gsv-v4-pretrained/vocoder.pth) from [huggingface](https://huggingface.co/lj1995/GPT-SoVITS/tree/main) and put them into `GPT_SoVITS\pretrained_models`.
+
+## Todo List
+
+- [x] **High Priority:**
+
+  - [x] Localization in Japanese and English.
+  - [x] User guide.
+  - [x] Japanese and English dataset fine-tune training.
+
+- [ ] **Features:**
+  - [x] Zero-shot voice conversion (5s) / few-shot voice conversion (1min).
+  - [x] TTS speaking speed control.
+  - [ ] ~~Enhanced TTS emotion control.~~ Maybe use pretrained fine-tuned preset GPT models for better emotion.
+  - [ ] Experiment with changing SoVITS token inputs to a probability distribution over the GPT vocabulary (transformer latent).
+  - [x] Improve the English and Japanese text frontend.
+  - [ ] Develop tiny and larger-sized TTS models.
+  - [x] Colab scripts.
+  - [x] Try to expand the training dataset (2k hours -> 10k hours).
+  - [x] Better SoVITS base model (enhanced audio quality).
+  - [ ] Model mix.
+
+## (Additional) Method for running from the command line
+
+Use the command line to open the WebUI for UVR5:
+
+```
+python tools/uvr5/webui.py "<infer_device>" <is_half> <webui_port_uvr5>
+```
+
+This is how audio segmentation of the dataset is done using the command line:
+
+```
+python audio_slicer.py \
+    --input_path "<path_to_original_audio_file_or_directory>" \
+    --output_root "<directory_where_subdivided_audio_clips_will_be_saved>" \
+    --threshold <volume_threshold> \
+    --min_length <minimum_duration_of_each_subclip> \
+    --min_interval <shortest_time_gap_between_adjacent_subclips> \
+    --hop_size <step_size_for_computing_volume_curve>
+```
+
+This is how dataset ASR processing is done using the command line (Chinese only):
+
+```
+python tools/asr/funasr_asr.py -i <input> -o <output>
+```
+
+ASR processing is performed through Faster Whisper (ASR annotation for languages other than Chinese).
+
+(There are no progress bars, and GPU performance may cause time delays.)
+
+```
+python ./tools/asr/fasterwhisper_asr.py -i <input> -o <output> -l <language> -p <precision>
+```
+
+A custom .list save path is supported.
+
+## Credits
+
+Special thanks to the following projects and contributors:
+
+### Theoretical Research
+
+- [ar-vits](https://github.com/innnky/ar-vits)
+- [SoundStorm](https://github.com/yangdongchao/SoundStorm/tree/master/soundstorm/s1/AR)
+- [vits](https://github.com/jaywalnut310/vits)
+- [TransferTTS](https://github.com/hcy71o/TransferTTS/blob/master/models.py#L556)
+- [contentvec](https://github.com/auspicious3000/contentvec/)
+- [hifi-gan](https://github.com/jik876/hifi-gan)
+- [fish-speech](https://github.com/fishaudio/fish-speech/blob/main/tools/llama/generate.py#L41)
+- [f5-TTS](https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/model/backbones/dit.py)
+- [shortcut flow matching](https://github.com/kvfrans/shortcut-models/blob/main/targets_shortcut.py)
+
+### Pretrained Models
+
+- [Chinese Speech Pretrain](https://github.com/TencentGameMate/chinese_speech_pretrain)
+- [Chinese-Roberta-WWM-Ext-Large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large)
+- [BigVGAN](https://github.com/NVIDIA/BigVGAN)
+
+### Text Frontend for Inference
+
+- [paddlespeech zh_normalization](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization)
+- [split-lang](https://github.com/DoodleBears/split-lang)
+- [g2pW](https://github.com/GitYCC/g2pW)
+- [pypinyin-g2pW](https://github.com/mozillazg/pypinyin-g2pW)
+- [paddlespeech g2pw](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/g2pw)
+
+### WebUI Tools
+
+- [ultimatevocalremovergui](https://github.com/Anjok07/ultimatevocalremovergui)
+- [audio-slicer](https://github.com/openvpi/audio-slicer)
+- [SubFix](https://github.com/cronrpc/SubFix)
+- [FFmpeg](https://github.com/FFmpeg/FFmpeg)
+- [gradio](https://github.com/gradio-app/gradio)
+- [faster-whisper](https://github.com/SYSTRAN/faster-whisper)
+- [FunASR](https://github.com/alibaba-damo-academy/FunASR)
+- [AP-BWE](https://github.com/yxlu-0102/AP-BWE)
+
+Thanks to @Naozumi520 for providing the Cantonese training set and for guidance on Cantonese-related knowledge.
+ +## Thanks to all contributors for their efforts + + + + diff --git a/api.py b/api.py new file mode 100644 index 0000000000000000000000000000000000000000..c1c917a01e3dd82961eb255e3f8a80aa35662d8e --- /dev/null +++ b/api.py @@ -0,0 +1,1236 @@ +""" +# api.py usage + +` python api.py -dr "123.wav" -dt "一二三。" -dl "zh" ` + +## 执行参数: + +`-s` - `SoVITS模型路径, 可在 config.py 中指定` +`-g` - `GPT模型路径, 可在 config.py 中指定` + +调用请求缺少参考音频时使用 +`-dr` - `默认参考音频路径` +`-dt` - `默认参考音频文本` +`-dl` - `默认参考音频语种, "中文","英文","日文","韩文","粤语,"zh","en","ja","ko","yue"` + +`-d` - `推理设备, "cuda","cpu"` +`-a` - `绑定地址, 默认"127.0.0.1"` +`-p` - `绑定端口, 默认9880, 可在 config.py 中指定` +`-fp` - `覆盖 config.py 使用全精度` +`-hp` - `覆盖 config.py 使用半精度` +`-sm` - `流式返回模式, 默认不启用, "close","c", "normal","n", "keepalive","k"` +·-mt` - `返回的音频编码格式, 流式默认ogg, 非流式默认wav, "wav", "ogg", "aac"` +·-st` - `返回的音频数据类型, 默认int16, "int16", "int32"` +·-cp` - `文本切分符号设定, 默认为空, 以",.,。"字符串的方式传入` + +`-hb` - `cnhubert路径` +`-b` - `bert路径` + +## 调用: + +### 推理 + +endpoint: `/` + +使用执行参数指定的参考音频: +GET: + `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh` +POST: +```json +{ + "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", + "text_language": "zh" +} +``` + +使用执行参数指定的参考音频并设定分割符号: +GET: + `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&cut_punc=,。` +POST: +```json +{ + "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", + "text_language": "zh", + "cut_punc": ",。", +} +``` + +手动指定当次推理所使用的参考音频: +GET: + `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh` +POST: +```json +{ + "refer_wav_path": "123.wav", + "prompt_text": "一二三。", + "prompt_language": "zh", + "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", + "text_language": "zh" +} +``` + +RESP: +成功: 直接返回 wav 音频流, http code 200 +失败: 返回包含错误信息的 json, http code 400 + +手动指定当次推理所使用的参考音频,并提供参数: +GET: + `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&top_k=20&top_p=0.6&temperature=0.6&speed=1&inp_refs="456.wav"&inp_refs="789.wav"` +POST: +```json +{ + "refer_wav_path": "123.wav", + "prompt_text": "一二三。", + "prompt_language": "zh", + "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", + "text_language": "zh", + "top_k": 20, + "top_p": 0.6, + "temperature": 0.6, + "speed": 1, + "inp_refs": ["456.wav","789.wav"] +} +``` + +RESP: +成功: 直接返回 wav 音频流, http code 200 +失败: 返回包含错误信息的 json, http code 400 + + +### 更换默认参考音频 + +endpoint: `/change_refer` + +key与推理端一样 + +GET: + `http://127.0.0.1:9880/change_refer?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh` +POST: +```json +{ + "refer_wav_path": "123.wav", + "prompt_text": "一二三。", + "prompt_language": "zh" +} +``` + +RESP: +成功: json, http code 200 +失败: json, 400 + + +### 命令控制 + +endpoint: `/control` + +command: +"restart": 重新运行 +"exit": 结束运行 + +GET: + `http://127.0.0.1:9880/control?command=restart` +POST: +```json +{ + "command": "restart" +} +``` + +RESP: 无 + +""" + +import argparse +import os +import re +import sys + +now_dir = os.getcwd() +sys.path.append(now_dir) +sys.path.append("%s/GPT_SoVITS" % (now_dir)) + +import signal +from text.LangSegmenter import LangSegmenter +from time import time as ttime +import torch +import torchaudio +import librosa +import soundfile as sf +from fastapi import FastAPI, Request, Query +from fastapi.responses import StreamingResponse, JSONResponse +import uvicorn +from transformers import AutoModelForMaskedLM, AutoTokenizer +import numpy as np +from 
feature_extractor import cnhubert +from io import BytesIO +from module.models import SynthesizerTrn, SynthesizerTrnV3 +from peft import LoraConfig, get_peft_model +from AR.models.t2s_lightning_module import Text2SemanticLightningModule +from text import cleaned_text_to_sequence +from text.cleaner import clean_text +from module.mel_processing import spectrogram_torch +import config as global_config +import logging +import subprocess + + +class DefaultRefer: + def __init__(self, path, text, language): + self.path = args.default_refer_path + self.text = args.default_refer_text + self.language = args.default_refer_language + + def is_ready(self) -> bool: + return is_full(self.path, self.text, self.language) + + +def is_empty(*items): # 任意一项不为空返回False + for item in items: + if item is not None and item != "": + return False + return True + + +def is_full(*items): # 任意一项为空返回False + for item in items: + if item is None or item == "": + return False + return True + + +def init_bigvgan(): + global bigvgan_model + from BigVGAN import bigvgan + + bigvgan_model = bigvgan.BigVGAN.from_pretrained( + "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), + use_cuda_kernel=False, + ) # if True, RuntimeError: Ninja is required to load C++ extensions + # remove weight norm in the model and set to eval mode + bigvgan_model.remove_weight_norm() + bigvgan_model = bigvgan_model.eval() + if is_half == True: + bigvgan_model = bigvgan_model.half().to(device) + else: + bigvgan_model = bigvgan_model.to(device) + + +resample_transform_dict = {} + + +def resample(audio_tensor, sr0): + global resample_transform_dict + if sr0 not in resample_transform_dict: + resample_transform_dict[sr0] = torchaudio.transforms.Resample(sr0, 24000).to(device) + return resample_transform_dict[sr0](audio_tensor) + + +from module.mel_processing import mel_spectrogram_torch + +spec_min = -12 +spec_max = 2 + + +def norm_spec(x): + return (x - spec_min) / (spec_max - spec_min) * 2 - 1 + + +def denorm_spec(x): + return (x + 1) / 2 * (spec_max - spec_min) + spec_min + + +mel_fn = lambda x: mel_spectrogram_torch( + x, + **{ + "n_fft": 1024, + "win_size": 1024, + "hop_size": 256, + "num_mels": 100, + "sampling_rate": 24000, + "fmin": 0, + "fmax": None, + "center": False, + }, +) + + +sr_model = None + + +def audio_sr(audio, sr): + global sr_model + if sr_model == None: + from tools.audio_sr import AP_BWE + + try: + sr_model = AP_BWE(device, DictToAttrRecursive) + except FileNotFoundError: + logger.info("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载") + return audio.cpu().detach().numpy(), sr + return sr_model(audio, sr) + + +class Speaker: + def __init__(self, name, gpt, sovits, phones=None, bert=None, prompt=None): + self.name = name + self.sovits = sovits + self.gpt = gpt + self.phones = phones + self.bert = bert + self.prompt = prompt + + +speaker_list = {} + + +class Sovits: + def __init__(self, vq_model, hps): + self.vq_model = vq_model + self.hps = hps + + +from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new + + +def get_sovits_weights(sovits_path): + path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth" + is_exist_s2gv3 = os.path.exists(path_sovits_v3) + + version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path) + if if_lora_v3 == True and is_exist_s2gv3 == False: + logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重") + + dict_s2 = load_sovits_new(sovits_path) + hps = dict_s2["config"] + hps = DictToAttrRecursive(hps) + hps.model.semantic_frame_rate = "25hz" + if 
"enc_p.text_embedding.weight" not in dict_s2["weight"]: + hps.model.version = "v2" # v3model,v2sybomls + elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322: + hps.model.version = "v1" + else: + hps.model.version = "v2" + + if model_version == "v3": + hps.model.version = "v3" + + model_params_dict = vars(hps.model) + if model_version != "v3": + vq_model = SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **model_params_dict, + ) + else: + vq_model = SynthesizerTrnV3( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **model_params_dict, + ) + init_bigvgan() + model_version = hps.model.version + logger.info(f"模型版本: {model_version}") + if "pretrained" not in sovits_path: + try: + del vq_model.enc_q + except: + pass + if is_half == True: + vq_model = vq_model.half().to(device) + else: + vq_model = vq_model.to(device) + vq_model.eval() + if if_lora_v3 == False: + vq_model.load_state_dict(dict_s2["weight"], strict=False) + else: + vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False) + lora_rank = dict_s2["lora_rank"] + lora_config = LoraConfig( + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + r=lora_rank, + lora_alpha=lora_rank, + init_lora_weights=True, + ) + vq_model.cfm = get_peft_model(vq_model.cfm, lora_config) + vq_model.load_state_dict(dict_s2["weight"], strict=False) + vq_model.cfm = vq_model.cfm.merge_and_unload() + # torch.save(vq_model.state_dict(),"merge_win.pth") + vq_model.eval() + + sovits = Sovits(vq_model, hps) + return sovits + + +class Gpt: + def __init__(self, max_sec, t2s_model): + self.max_sec = max_sec + self.t2s_model = t2s_model + + +global hz +hz = 50 + + +def get_gpt_weights(gpt_path): + dict_s1 = torch.load(gpt_path, map_location="cpu") + config = dict_s1["config"] + max_sec = config["data"]["max_sec"] + t2s_model = Text2SemanticLightningModule(config, "****", is_train=False) + t2s_model.load_state_dict(dict_s1["weight"]) + if is_half == True: + t2s_model = t2s_model.half() + t2s_model = t2s_model.to(device) + t2s_model.eval() + # total = sum([param.nelement() for param in t2s_model.parameters()]) + # logger.info("Number of parameter: %.2fM" % (total / 1e6)) + + gpt = Gpt(max_sec, t2s_model) + return gpt + + +def change_gpt_sovits_weights(gpt_path, sovits_path): + try: + gpt = get_gpt_weights(gpt_path) + sovits = get_sovits_weights(sovits_path) + except Exception as e: + return JSONResponse({"code": 400, "message": str(e)}, status_code=400) + + speaker_list["default"] = Speaker(name="default", gpt=gpt, sovits=sovits) + return JSONResponse({"code": 0, "message": "Success"}, status_code=200) + + +def get_bert_feature(text, word2ph): + with torch.no_grad(): + inputs = tokenizer(text, return_tensors="pt") + for i in inputs: + inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题,精度随bert_model + res = bert_model(**inputs, output_hidden_states=True) + res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + assert len(word2ph) == len(text) + phone_level_feature = [] + for i in range(len(word2ph)): + repeat_feature = res[i].repeat(word2ph[i], 1) + phone_level_feature.append(repeat_feature) + phone_level_feature = torch.cat(phone_level_feature, dim=0) + # if(is_half==True):phone_level_feature=phone_level_feature.half() + return phone_level_feature.T + + +def clean_text_inf(text, language, version): + language = language.replace("all_", "") + phones, word2ph, 
norm_text = clean_text(text, language, version) + phones = cleaned_text_to_sequence(phones, version) + return phones, word2ph, norm_text + + +def get_bert_inf(phones, word2ph, norm_text, language): + language = language.replace("all_", "") + if language == "zh": + bert = get_bert_feature(norm_text, word2ph).to(device) # .to(dtype) + else: + bert = torch.zeros( + (1024, len(phones)), + dtype=torch.float16 if is_half == True else torch.float32, + ).to(device) + + return bert + + +from text import chinese + + +def get_phones_and_bert(text, language, version, final=False): + if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: + formattext = text + while " " in formattext: + formattext = formattext.replace(" ", " ") + if language == "all_zh": + if re.search(r"[A-Za-z]", formattext): + formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext) + formattext = chinese.mix_text_normalize(formattext) + return get_phones_and_bert(formattext, "zh", version) + else: + phones, word2ph, norm_text = clean_text_inf(formattext, language, version) + bert = get_bert_feature(norm_text, word2ph).to(device) + elif language == "all_yue" and re.search(r"[A-Za-z]", formattext): + formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext) + formattext = chinese.mix_text_normalize(formattext) + return get_phones_and_bert(formattext, "yue", version) + else: + phones, word2ph, norm_text = clean_text_inf(formattext, language, version) + bert = torch.zeros( + (1024, len(phones)), + dtype=torch.float16 if is_half == True else torch.float32, + ).to(device) + elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}: + textlist = [] + langlist = [] + if language == "auto": + for tmp in LangSegmenter.getTexts(text): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "auto_yue": + for tmp in LangSegmenter.getTexts(text): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + else: + for tmp in LangSegmenter.getTexts(text): + if tmp["lang"] == "en": + langlist.append(tmp["lang"]) + else: + # 因无法区别中日韩文汉字,以用户输入为准 + langlist.append(language) + textlist.append(tmp["text"]) + phones_list = [] + bert_list = [] + norm_text_list = [] + for i in range(len(textlist)): + lang = langlist[i] + phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version) + bert = get_bert_inf(phones, word2ph, norm_text, lang) + phones_list.append(phones) + norm_text_list.append(norm_text) + bert_list.append(bert) + bert = torch.cat(bert_list, dim=1) + phones = sum(phones_list, []) + norm_text = "".join(norm_text_list) + + if not final and len(phones) < 6: + return get_phones_and_bert("." 
+ text, language, version, final=True) + + return phones, bert.to(torch.float16 if is_half == True else torch.float32), norm_text + + +class DictToAttrRecursive(dict): + def __init__(self, input_dict): + super().__init__(input_dict) + for key, value in input_dict.items(): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + self[key] = value + setattr(self, key, value) + + def __getattr__(self, item): + try: + return self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + def __setattr__(self, key, value): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + super(DictToAttrRecursive, self).__setitem__(key, value) + super().__setattr__(key, value) + + def __delattr__(self, item): + try: + del self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + +def get_spepc(hps, filename): + audio, _ = librosa.load(filename, int(hps.data.sampling_rate)) + audio = torch.FloatTensor(audio) + maxx = audio.abs().max() + if maxx > 1: + audio /= min(2, maxx) + audio_norm = audio + audio_norm = audio_norm.unsqueeze(0) + spec = spectrogram_torch( + audio_norm, + hps.data.filter_length, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + center=False, + ) + return spec + + +def pack_audio(audio_bytes, data, rate): + if media_type == "ogg": + audio_bytes = pack_ogg(audio_bytes, data, rate) + elif media_type == "aac": + audio_bytes = pack_aac(audio_bytes, data, rate) + else: + # wav无法流式, 先暂存raw + audio_bytes = pack_raw(audio_bytes, data, rate) + + return audio_bytes + + +def pack_ogg(audio_bytes, data, rate): + # Author: AkagawaTsurunaki + # Issue: + # Stack overflow probabilistically occurs + # when the function `sf_writef_short` of `libsndfile_64bit.dll` is called + # using the Python library `soundfile` + # Note: + # This is an issue related to `libsndfile`, not this project itself. + # It happens when you generate a large audio tensor (about 499804 frames in my PC) + # and try to convert it to an ogg file. + # Related: + # https://github.com/RVC-Boss/GPT-SoVITS/issues/1199 + # https://github.com/libsndfile/libsndfile/issues/1023 + # https://github.com/bastibe/python-soundfile/issues/396 + # Suggestion: + # Or split the whole audio data into smaller audio segment to avoid stack overflow? + + def handle_pack_ogg(): + with sf.SoundFile(audio_bytes, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: + audio_file.write(data) + + import threading + + # See: https://docs.python.org/3/library/threading.html + # The stack size of this thread is at least 32768 + # If stack overflow error still occurs, just modify the `stack_size`. + # stack_size = n * 4096, where n should be a positive integer. + # Here we chose n = 4096. + stack_size = 4096 * 4096 + try: + threading.stack_size(stack_size) + pack_ogg_thread = threading.Thread(target=handle_pack_ogg) + pack_ogg_thread.start() + pack_ogg_thread.join() + except RuntimeError as e: + # If changing the thread stack size is unsupported, a RuntimeError is raised. + print("RuntimeError: {}".format(e)) + print("Changing the thread stack size is unsupported.") + except ValueError as e: + # If the specified stack size is invalid, a ValueError is raised and the stack size is unmodified. 
+ print("ValueError: {}".format(e)) + print("The specified stack size is invalid.") + + return audio_bytes + + +def pack_raw(audio_bytes, data, rate): + audio_bytes.write(data.tobytes()) + + return audio_bytes + + +def pack_wav(audio_bytes, rate): + if is_int32: + data = np.frombuffer(audio_bytes.getvalue(), dtype=np.int32) + wav_bytes = BytesIO() + sf.write(wav_bytes, data, rate, format="WAV", subtype="PCM_32") + else: + data = np.frombuffer(audio_bytes.getvalue(), dtype=np.int16) + wav_bytes = BytesIO() + sf.write(wav_bytes, data, rate, format="WAV") + return wav_bytes + + +def pack_aac(audio_bytes, data, rate): + if is_int32: + pcm = "s32le" + bit_rate = "256k" + else: + pcm = "s16le" + bit_rate = "128k" + process = subprocess.Popen( + [ + "ffmpeg", + "-f", + pcm, # 输入16位有符号小端整数PCM + "-ar", + str(rate), # 设置采样率 + "-ac", + "1", # 单声道 + "-i", + "pipe:0", # 从管道读取输入 + "-c:a", + "aac", # 音频编码器为AAC + "-b:a", + bit_rate, # 比特率 + "-vn", # 不包含视频 + "-f", + "adts", # 输出AAC数据流格式 + "pipe:1", # 将输出写入管道 + ], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + out, _ = process.communicate(input=data.tobytes()) + audio_bytes.write(out) + + return audio_bytes + + +def read_clean_buffer(audio_bytes): + audio_chunk = audio_bytes.getvalue() + audio_bytes.truncate(0) + audio_bytes.seek(0) + + return audio_bytes, audio_chunk + + +def cut_text(text, punc): + punc_list = [p for p in punc if p in {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}] + if len(punc_list) > 0: + punds = r"[" + "".join(punc_list) + r"]" + text = text.strip("\n") + items = re.split(f"({punds})", text) + mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])] + # 在句子不存在符号或句尾无符号的时候保证文本完整 + if len(items) % 2 == 1: + mergeitems.append(items[-1]) + text = "\n".join(mergeitems) + + while "\n\n" in text: + text = text.replace("\n\n", "\n") + + return text + + +def only_punc(text): + return not any(t.isalnum() or t.isalpha() for t in text) + + +splits = { + ",", + "。", + "?", + "!", + ",", + ".", + "?", + "!", + "~", + ":", + ":", + "—", + "…", +} + + +def get_tts_wav( + ref_wav_path, + prompt_text, + prompt_language, + text, + text_language, + top_k=15, + top_p=0.6, + temperature=0.6, + speed=1, + inp_refs=None, + sample_steps=32, + if_sr=False, + spk="default", +): + infer_sovits = speaker_list[spk].sovits + vq_model = infer_sovits.vq_model + hps = infer_sovits.hps + version = vq_model.version + + infer_gpt = speaker_list[spk].gpt + t2s_model = infer_gpt.t2s_model + max_sec = infer_gpt.max_sec + + t0 = ttime() + prompt_text = prompt_text.strip("\n") + if prompt_text[-1] not in splits: + prompt_text += "。" if prompt_language != "en" else "." 
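+    # Normalize the target text below, then load the 16 kHz reference audio and extract
+    # its semantic prompt codes (cnhubert SSL features passed through the SoVITS quantizer).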
+ prompt_language, text = prompt_language, text.strip("\n") + dtype = torch.float16 if is_half == True else torch.float32 + zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32) + with torch.no_grad(): + wav16k, sr = librosa.load(ref_wav_path, sr=16000) + wav16k = torch.from_numpy(wav16k) + zero_wav_torch = torch.from_numpy(zero_wav) + if is_half == True: + wav16k = wav16k.half().to(device) + zero_wav_torch = zero_wav_torch.half().to(device) + else: + wav16k = wav16k.to(device) + zero_wav_torch = zero_wav_torch.to(device) + wav16k = torch.cat([wav16k, zero_wav_torch]) + ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float() + codes = vq_model.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + prompt = prompt_semantic.unsqueeze(0).to(device) + + if version != "v3": + refers = [] + if inp_refs: + for path in inp_refs: + try: + refer = get_spepc(hps, path).to(dtype).to(device) + refers.append(refer) + except Exception as e: + logger.error(e) + if len(refers) == 0: + refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)] + else: + refer = get_spepc(hps, ref_wav_path).to(device).to(dtype) + + t1 = ttime() + # os.environ['version'] = version + prompt_language = dict_language[prompt_language.lower()] + text_language = dict_language[text_language.lower()] + phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version) + texts = text.split("\n") + audio_bytes = BytesIO() + + for text in texts: + # 简单防止纯符号引发参考音频泄露 + if only_punc(text): + continue + + audio_opt = [] + if text[-1] not in splits: + text += "。" if text_language != "en" else "." + phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version) + bert = torch.cat([bert1, bert2], 1) + + all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0) + bert = bert.to(device).unsqueeze(0) + all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) + t2 = ttime() + with torch.no_grad(): + pred_semantic, idx = t2s_model.model.infer_panel( + all_phoneme_ids, + all_phoneme_len, + prompt, + bert, + # prompt_phone_len=ph_offset, + top_k=top_k, + top_p=top_p, + temperature=temperature, + early_stop_num=hz * max_sec, + ) + pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) + t3 = ttime() + + if version != "v3": + audio = ( + vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed) + .detach() + .cpu() + .numpy()[0, 0] + ) ###试试重建不带上prompt部分 + else: + phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0) + phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0) + # print(11111111, phoneme_ids0, phoneme_ids1) + fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer) + ref_audio, sr = torchaudio.load(ref_wav_path) + ref_audio = ref_audio.to(device).float() + if ref_audio.shape[0] == 2: + ref_audio = ref_audio.mean(0).unsqueeze(0) + if sr != 24000: + ref_audio = resample(ref_audio, sr) + # print("ref_audio",ref_audio.abs().mean()) + mel2 = mel_fn(ref_audio) + mel2 = norm_spec(mel2) + T_min = min(mel2.shape[2], fea_ref.shape[2]) + mel2 = mel2[:, :, :T_min] + fea_ref = fea_ref[:, :, :T_min] + if T_min > 468: + mel2 = mel2[:, :, -468:] + fea_ref = fea_ref[:, :, -468:] + T_min = 468 + chunk_len = 934 - T_min + # print("fea_ref",fea_ref,fea_ref.shape) + # print("mel2",mel2) + mel2 = mel2.to(dtype) + fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed) + # 
print("fea_todo",fea_todo) + # print("ge",ge.abs().mean()) + cfm_resss = [] + idx = 0 + while 1: + fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len] + if fea_todo_chunk.shape[-1] == 0: + break + idx += chunk_len + fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1) + # set_seed(123) + cfm_res = vq_model.cfm.inference( + fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0 + ) + cfm_res = cfm_res[:, :, mel2.shape[2] :] + mel2 = cfm_res[:, :, -T_min:] + # print("fea", fea) + # print("mel2in", mel2) + fea_ref = fea_todo_chunk[:, :, -T_min:] + cfm_resss.append(cfm_res) + cmf_res = torch.cat(cfm_resss, 2) + cmf_res = denorm_spec(cmf_res) + if bigvgan_model == None: + init_bigvgan() + with torch.inference_mode(): + wav_gen = bigvgan_model(cmf_res) + audio = wav_gen[0][0].cpu().detach().numpy() + + max_audio = np.abs(audio).max() + if max_audio > 1: + audio /= max_audio + audio_opt.append(audio) + audio_opt.append(zero_wav) + audio_opt = np.concatenate(audio_opt, 0) + t4 = ttime() + + sr = hps.data.sampling_rate if version != "v3" else 24000 + if if_sr and sr == 24000: + audio_opt = torch.from_numpy(audio_opt).float().to(device) + audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr) + max_audio = np.abs(audio_opt).max() + if max_audio > 1: + audio_opt /= max_audio + sr = 48000 + + if is_int32: + audio_bytes = pack_audio(audio_bytes, (audio_opt * 2147483647).astype(np.int32), sr) + else: + audio_bytes = pack_audio(audio_bytes, (audio_opt * 32768).astype(np.int16), sr) + # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) + if stream_mode == "normal": + audio_bytes, audio_chunk = read_clean_buffer(audio_bytes) + yield audio_chunk + + if not stream_mode == "normal": + if media_type == "wav": + sr = 48000 if if_sr else 24000 + sr = hps.data.sampling_rate if version != "v3" else sr + audio_bytes = pack_wav(audio_bytes, sr) + yield audio_bytes.getvalue() + + +def handle_control(command): + if command == "restart": + os.execl(g_config.python_exec, g_config.python_exec, *sys.argv) + elif command == "exit": + os.kill(os.getpid(), signal.SIGTERM) + exit(0) + + +def handle_change(path, text, language): + if is_empty(path, text, language): + return JSONResponse( + {"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language"'}, status_code=400 + ) + + if path != "" or path is not None: + default_refer.path = path + if text != "" or text is not None: + default_refer.text = text + if language != "" or language is not None: + default_refer.language = language + + logger.info(f"当前默认参考音频路径: {default_refer.path}") + logger.info(f"当前默认参考音频文本: {default_refer.text}") + logger.info(f"当前默认参考音频语种: {default_refer.language}") + logger.info(f"is_ready: {default_refer.is_ready()}") + + return JSONResponse({"code": 0, "message": "Success"}, status_code=200) + + +def handle( + refer_wav_path, + prompt_text, + prompt_language, + text, + text_language, + cut_punc, + top_k, + top_p, + temperature, + speed, + inp_refs, + sample_steps, + if_sr, +): + if ( + refer_wav_path == "" + or refer_wav_path is None + or prompt_text == "" + or prompt_text is None + or prompt_language == "" + or prompt_language is None + ): + refer_wav_path, prompt_text, prompt_language = ( + default_refer.path, + default_refer.text, + default_refer.language, + ) + if not default_refer.is_ready(): + return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400) + + if sample_steps not in [4, 8, 16, 32]: + sample_steps = 32 + + if cut_punc == None: + text = 
cut_text(text, default_cut_punc) + else: + text = cut_text(text, cut_punc) + + return StreamingResponse( + get_tts_wav( + refer_wav_path, + prompt_text, + prompt_language, + text, + text_language, + top_k, + top_p, + temperature, + speed, + inp_refs, + sample_steps, + if_sr, + ), + media_type="audio/" + media_type, + ) + + +# -------------------------------- +# 初始化部分 +# -------------------------------- +dict_language = { + "中文": "all_zh", + "粤语": "all_yue", + "英文": "en", + "日文": "all_ja", + "韩文": "all_ko", + "中英混合": "zh", + "粤英混合": "yue", + "日英混合": "ja", + "韩英混合": "ko", + "多语种混合": "auto", # 多语种启动切分识别语种 + "多语种混合(粤语)": "auto_yue", + "all_zh": "all_zh", + "all_yue": "all_yue", + "en": "en", + "all_ja": "all_ja", + "all_ko": "all_ko", + "zh": "zh", + "yue": "yue", + "ja": "ja", + "ko": "ko", + "auto": "auto", + "auto_yue": "auto_yue", +} + +# logger +logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG) +logger = logging.getLogger("uvicorn") + +# 获取配置 +g_config = global_config.Config() + +# 获取参数 +parser = argparse.ArgumentParser(description="GPT-SoVITS api") + +parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS模型路径") +parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, help="GPT模型路径") +parser.add_argument("-dr", "--default_refer_path", type=str, default="", help="默认参考音频路径") +parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本") +parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种") +parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu") +parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0") +parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880") +parser.add_argument( + "-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度" +) +parser.add_argument( + "-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度" +) +# bool值的用法为 `python ./api.py -fp ...` +# 此时 full_precision==True, half_precision==False +parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive") +parser.add_argument("-mt", "--media_type", type=str, default="wav", help="音频编码格式, wav / ogg / aac") +parser.add_argument("-st", "--sub_type", type=str, default="int16", help="音频数据类型, int16 / int32") +parser.add_argument("-cp", "--cut_punc", type=str, default="", help="文本切分符号设定, 符号范围,.;?!、,。?!;:…") +# 切割常用分句符为 `python ./api.py -cp ".?!。?!"` +parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path") +parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path") + +args = parser.parse_args() +sovits_path = args.sovits_path +gpt_path = args.gpt_path +device = args.device +port = args.port +host = args.bind_addr +cnhubert_base_path = args.hubert_path +bert_path = args.bert_path +default_cut_punc = args.cut_punc + +# 应用参数配置 +default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language) + +# 模型路径检查 +if sovits_path == "": + sovits_path = g_config.pretrained_sovits_path + logger.warn(f"未指定SoVITS模型路径, fallback后当前值: {sovits_path}") +if gpt_path == "": + gpt_path = g_config.pretrained_gpt_path + logger.warn(f"未指定GPT模型路径, fallback后当前值: {gpt_path}") + +# 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用 +if 
default_refer.path == "" or default_refer.text == "" or default_refer.language == "": + default_refer.path, default_refer.text, default_refer.language = "", "", "" + logger.info("未指定默认参考音频") +else: + logger.info(f"默认参考音频路径: {default_refer.path}") + logger.info(f"默认参考音频文本: {default_refer.text}") + logger.info(f"默认参考音频语种: {default_refer.language}") + +# 获取半精度 +is_half = g_config.is_half +if args.full_precision: + is_half = False +if args.half_precision: + is_half = True +if args.full_precision and args.half_precision: + is_half = g_config.is_half # 炒饭fallback +logger.info(f"半精: {is_half}") + +# 流式返回模式 +if args.stream_mode.lower() in ["normal", "n"]: + stream_mode = "normal" + logger.info("流式返回已开启") +else: + stream_mode = "close" + +# 音频编码格式 +if args.media_type.lower() in ["aac", "ogg"]: + media_type = args.media_type.lower() +elif stream_mode == "close": + media_type = "wav" +else: + media_type = "ogg" +logger.info(f"编码格式: {media_type}") + +# 音频数据类型 +if args.sub_type.lower() == "int32": + is_int32 = True + logger.info("数据类型: int32") +else: + is_int32 = False + logger.info("数据类型: int16") + +# 初始化模型 +cnhubert.cnhubert_base_path = cnhubert_base_path +tokenizer = AutoTokenizer.from_pretrained(bert_path) +bert_model = AutoModelForMaskedLM.from_pretrained(bert_path) +ssl_model = cnhubert.get_model() +if is_half: + bert_model = bert_model.half().to(device) + ssl_model = ssl_model.half().to(device) +else: + bert_model = bert_model.to(device) + ssl_model = ssl_model.to(device) +change_gpt_sovits_weights(gpt_path=gpt_path, sovits_path=sovits_path) + + +# -------------------------------- +# 接口部分 +# -------------------------------- +app = FastAPI() + + +@app.post("/set_model") +async def set_model(request: Request): + json_post_raw = await request.json() + return change_gpt_sovits_weights( + gpt_path=json_post_raw.get("gpt_model_path"), sovits_path=json_post_raw.get("sovits_model_path") + ) + + +@app.get("/set_model") +async def set_model( + gpt_model_path: str = None, + sovits_model_path: str = None, +): + return change_gpt_sovits_weights(gpt_path=gpt_model_path, sovits_path=sovits_model_path) + + +@app.post("/control") +async def control(request: Request): + json_post_raw = await request.json() + return handle_control(json_post_raw.get("command")) + + +@app.get("/control") +async def control(command: str = None): + return handle_control(command) + + +@app.post("/change_refer") +async def change_refer(request: Request): + json_post_raw = await request.json() + return handle_change( + json_post_raw.get("refer_wav_path"), json_post_raw.get("prompt_text"), json_post_raw.get("prompt_language") + ) + + +@app.get("/change_refer") +async def change_refer(refer_wav_path: str = None, prompt_text: str = None, prompt_language: str = None): + return handle_change(refer_wav_path, prompt_text, prompt_language) + + +@app.post("/") +async def tts_endpoint(request: Request): + json_post_raw = await request.json() + return handle( + json_post_raw.get("refer_wav_path"), + json_post_raw.get("prompt_text"), + json_post_raw.get("prompt_language"), + json_post_raw.get("text"), + json_post_raw.get("text_language"), + json_post_raw.get("cut_punc"), + json_post_raw.get("top_k", 15), + json_post_raw.get("top_p", 1.0), + json_post_raw.get("temperature", 1.0), + json_post_raw.get("speed", 1.0), + json_post_raw.get("inp_refs", []), + json_post_raw.get("sample_steps", 32), + json_post_raw.get("if_sr", False), + ) + + +@app.get("/") +async def tts_endpoint( + refer_wav_path: str = None, + prompt_text: str = None, + prompt_language: 
str = None, + text: str = None, + text_language: str = None, + cut_punc: str = None, + top_k: int = 15, + top_p: float = 1.0, + temperature: float = 1.0, + speed: float = 1.0, + inp_refs: list = Query(default=[]), + sample_steps: int = 32, + if_sr: bool = False, +): + return handle( + refer_wav_path, + prompt_text, + prompt_language, + text, + text_language, + cut_punc, + top_k, + top_p, + temperature, + speed, + inp_refs, + sample_steps, + if_sr, + ) + + +if __name__ == "__main__": + uvicorn.run(app, host=host, port=port, workers=1) diff --git a/api_v2.py b/api_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..8708207432d2ec53a99b4e4e53cfbdaa57dd2769 --- /dev/null +++ b/api_v2.py @@ -0,0 +1,500 @@ +""" +# WebAPI文档 + +` python api_v2.py -a 127.0.0.1 -p 9880 -c GPT_SoVITS/configs/tts_infer.yaml ` + +## 执行参数: + `-a` - `绑定地址, 默认"127.0.0.1"` + `-p` - `绑定端口, 默认9880` + `-c` - `TTS配置文件路径, 默认"GPT_SoVITS/configs/tts_infer.yaml"` + +## 调用: + +### 推理 + +endpoint: `/tts` +GET: +``` +http://127.0.0.1:9880/tts?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_lang=zh&ref_audio_path=archive_jingyuan_1.wav&prompt_lang=zh&prompt_text=我是「罗浮」云骑将军景元。不必拘谨,「将军」只是一时的身份,你称呼我景元便可&text_split_method=cut5&batch_size=1&media_type=wav&streaming_mode=true +``` + +POST: +```json +{ + "text": "", # str.(required) text to be synthesized + "text_lang: "", # str.(required) language of the text to be synthesized + "ref_audio_path": "", # str.(required) reference audio path + "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker tone fusion + "prompt_text": "", # str.(optional) prompt text for the reference audio + "prompt_lang": "", # str.(required) language of the prompt text for the reference audio + "top_k": 5, # int. top k sampling + "top_p": 1, # float. top p sampling + "temperature": 1, # float. temperature for sampling + "text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details. + "batch_size": 1, # int. batch size for inference + "batch_threshold": 0.75, # float. threshold for batch splitting. + "split_bucket: True, # bool. whether to split the batch into multiple buckets. + "speed_factor":1.0, # float. control the speed of the synthesized audio. + "streaming_mode": False, # bool. whether to return a streaming response. + "seed": -1, # int. random seed for reproducibility. + "parallel_infer": True, # bool. whether to use parallel inference. + "repetition_penalty": 1.35 # float. repetition penalty for T2S model. + "sample_steps": 32, # int. number of sampling steps for VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. 
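+    "fragment_interval": 0.3,      # float. to control the interval of the audio fragment.
+    "media_type": "wav",           # str. media type of the output audio, support "wav", "raw", "ogg", "aac".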
+} +``` + +RESP: +成功: 直接返回 wav 音频流, http code 200 +失败: 返回包含错误信息的 json, http code 400 + +### 命令控制 + +endpoint: `/control` + +command: +"restart": 重新运行 +"exit": 结束运行 + +GET: +``` +http://127.0.0.1:9880/control?command=restart +``` +POST: +```json +{ + "command": "restart" +} +``` + +RESP: 无 + + +### 切换GPT模型 + +endpoint: `/set_gpt_weights` + +GET: +``` +http://127.0.0.1:9880/set_gpt_weights?weights_path=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt +``` +RESP: +成功: 返回"success", http code 200 +失败: 返回包含错误信息的 json, http code 400 + + +### 切换Sovits模型 + +endpoint: `/set_sovits_weights` + +GET: +``` +http://127.0.0.1:9880/set_sovits_weights?weights_path=GPT_SoVITS/pretrained_models/s2G488k.pth +``` + +RESP: +成功: 返回"success", http code 200 +失败: 返回包含错误信息的 json, http code 400 + +""" + +import os +import sys +import traceback +from typing import Generator + +now_dir = os.getcwd() +sys.path.append(now_dir) +sys.path.append("%s/GPT_SoVITS" % (now_dir)) + +import argparse +import subprocess +import wave +import signal +import numpy as np +import soundfile as sf +from fastapi import FastAPI, Response +from fastapi.responses import StreamingResponse, JSONResponse +import uvicorn +from io import BytesIO +from tools.i18n.i18n import I18nAuto +from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config +from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method_names as get_cut_method_names +from pydantic import BaseModel + +# print(sys.path) +i18n = I18nAuto() +cut_method_names = get_cut_method_names() + +parser = argparse.ArgumentParser(description="GPT-SoVITS api") +parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径") +parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1") +parser.add_argument("-p", "--port", type=int, default="9880", help="default: 9880") +args = parser.parse_args() +config_path = args.tts_config +# device = args.device +port = args.port +host = args.bind_addr +argv = sys.argv + +if config_path in [None, ""]: + config_path = "GPT-SoVITS/configs/tts_infer.yaml" + +tts_config = TTS_Config(config_path) +print(tts_config) +tts_pipeline = TTS(tts_config) + +APP = FastAPI() + + +class TTS_Request(BaseModel): + text: str = None + text_lang: str = None + ref_audio_path: str = None + aux_ref_audio_paths: list = None + prompt_lang: str = None + prompt_text: str = "" + top_k: int = 5 + top_p: float = 1 + temperature: float = 1 + text_split_method: str = "cut5" + batch_size: int = 1 + batch_threshold: float = 0.75 + split_bucket: bool = True + speed_factor: float = 1.0 + fragment_interval: float = 0.3 + seed: int = -1 + media_type: str = "wav" + streaming_mode: bool = False + parallel_infer: bool = True + repetition_penalty: float = 1.35 + sample_steps: int = 32 + super_sampling: bool = False + + +### modify from https://github.com/RVC-Boss/GPT-SoVITS/pull/894/files +def pack_ogg(io_buffer: BytesIO, data: np.ndarray, rate: int): + with sf.SoundFile(io_buffer, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: + audio_file.write(data) + return io_buffer + + +def pack_raw(io_buffer: BytesIO, data: np.ndarray, rate: int): + io_buffer.write(data.tobytes()) + return io_buffer + + +def pack_wav(io_buffer: BytesIO, data: np.ndarray, rate: int): + io_buffer = BytesIO() + sf.write(io_buffer, data, rate, format="wav") + return io_buffer + + +def pack_aac(io_buffer: BytesIO, data: np.ndarray, rate: int): + process = subprocess.Popen( + [ + 
"ffmpeg", + "-f", + "s16le", # 输入16位有符号小端整数PCM + "-ar", + str(rate), # 设置采样率 + "-ac", + "1", # 单声道 + "-i", + "pipe:0", # 从管道读取输入 + "-c:a", + "aac", # 音频编码器为AAC + "-b:a", + "192k", # 比特率 + "-vn", # 不包含视频 + "-f", + "adts", # 输出AAC数据流格式 + "pipe:1", # 将输出写入管道 + ], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + out, _ = process.communicate(input=data.tobytes()) + io_buffer.write(out) + return io_buffer + + +def pack_audio(io_buffer: BytesIO, data: np.ndarray, rate: int, media_type: str): + if media_type == "ogg": + io_buffer = pack_ogg(io_buffer, data, rate) + elif media_type == "aac": + io_buffer = pack_aac(io_buffer, data, rate) + elif media_type == "wav": + io_buffer = pack_wav(io_buffer, data, rate) + else: + io_buffer = pack_raw(io_buffer, data, rate) + io_buffer.seek(0) + return io_buffer + + +# from https://huggingface.co/spaces/coqui/voice-chat-with-mistral/blob/main/app.py +def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=32000): + # This will create a wave header then append the frame input + # It should be first on a streaming wav file + # Other frames better should not have it (else you will hear some artifacts each chunk start) + wav_buf = BytesIO() + with wave.open(wav_buf, "wb") as vfout: + vfout.setnchannels(channels) + vfout.setsampwidth(sample_width) + vfout.setframerate(sample_rate) + vfout.writeframes(frame_input) + + wav_buf.seek(0) + return wav_buf.read() + + +def handle_control(command: str): + if command == "restart": + os.execl(sys.executable, sys.executable, *argv) + elif command == "exit": + os.kill(os.getpid(), signal.SIGTERM) + exit(0) + + +def check_params(req: dict): + text: str = req.get("text", "") + text_lang: str = req.get("text_lang", "") + ref_audio_path: str = req.get("ref_audio_path", "") + streaming_mode: bool = req.get("streaming_mode", False) + media_type: str = req.get("media_type", "wav") + prompt_lang: str = req.get("prompt_lang", "") + text_split_method: str = req.get("text_split_method", "cut5") + + if ref_audio_path in [None, ""]: + return JSONResponse(status_code=400, content={"message": "ref_audio_path is required"}) + if text in [None, ""]: + return JSONResponse(status_code=400, content={"message": "text is required"}) + if text_lang in [None, ""]: + return JSONResponse(status_code=400, content={"message": "text_lang is required"}) + elif text_lang.lower() not in tts_config.languages: + return JSONResponse( + status_code=400, + content={"message": f"text_lang: {text_lang} is not supported in version {tts_config.version}"}, + ) + if prompt_lang in [None, ""]: + return JSONResponse(status_code=400, content={"message": "prompt_lang is required"}) + elif prompt_lang.lower() not in tts_config.languages: + return JSONResponse( + status_code=400, + content={"message": f"prompt_lang: {prompt_lang} is not supported in version {tts_config.version}"}, + ) + if media_type not in ["wav", "raw", "ogg", "aac"]: + return JSONResponse(status_code=400, content={"message": f"media_type: {media_type} is not supported"}) + elif media_type == "ogg" and not streaming_mode: + return JSONResponse(status_code=400, content={"message": "ogg format is not supported in non-streaming mode"}) + + if text_split_method not in cut_method_names: + return JSONResponse( + status_code=400, content={"message": f"text_split_method:{text_split_method} is not supported"} + ) + + return None + + +async def tts_handle(req: dict): + """ + Text to speech handler. 
+ + Args: + req (dict): + { + "text": "", # str.(required) text to be synthesized + "text_lang: "", # str.(required) language of the text to be synthesized + "ref_audio_path": "", # str.(required) reference audio path + "aux_ref_audio_paths": [], # list.(optional) auxiliary reference audio paths for multi-speaker synthesis + "prompt_text": "", # str.(optional) prompt text for the reference audio + "prompt_lang": "", # str.(required) language of the prompt text for the reference audio + "top_k": 5, # int. top k sampling + "top_p": 1, # float. top p sampling + "temperature": 1, # float. temperature for sampling + "text_split_method": "cut5", # str. text split method, see text_segmentation_method.py for details. + "batch_size": 1, # int. batch size for inference + "batch_threshold": 0.75, # float. threshold for batch splitting. + "split_bucket: True, # bool. whether to split the batch into multiple buckets. + "speed_factor":1.0, # float. control the speed of the synthesized audio. + "fragment_interval":0.3, # float. to control the interval of the audio fragment. + "seed": -1, # int. random seed for reproducibility. + "media_type": "wav", # str. media type of the output audio, support "wav", "raw", "ogg", "aac". + "streaming_mode": False, # bool. whether to return a streaming response. + "parallel_infer": True, # bool.(optional) whether to use parallel inference. + "repetition_penalty": 1.35 # float.(optional) repetition penalty for T2S model. + "sample_steps": 32, # int. number of sampling steps for VITS model V3. + "super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. + } + returns: + StreamingResponse: audio stream response. + """ + + streaming_mode = req.get("streaming_mode", False) + return_fragment = req.get("return_fragment", False) + media_type = req.get("media_type", "wav") + + check_res = check_params(req) + if check_res is not None: + return check_res + + if streaming_mode or return_fragment: + req["return_fragment"] = True + + try: + tts_generator = tts_pipeline.run(req) + + if streaming_mode: + + def streaming_generator(tts_generator: Generator, media_type: str): + if_frist_chunk = True + for sr, chunk in tts_generator: + if if_frist_chunk and media_type == "wav": + yield wave_header_chunk(sample_rate=sr) + media_type = "raw" + if_frist_chunk = False + yield pack_audio(BytesIO(), chunk, sr, media_type).getvalue() + + # _media_type = f"audio/{media_type}" if not (streaming_mode and media_type in ["wav", "raw"]) else f"audio/x-{media_type}" + return StreamingResponse( + streaming_generator( + tts_generator, + media_type, + ), + media_type=f"audio/{media_type}", + ) + + else: + sr, audio_data = next(tts_generator) + audio_data = pack_audio(BytesIO(), audio_data, sr, media_type).getvalue() + return Response(audio_data, media_type=f"audio/{media_type}") + except Exception as e: + return JSONResponse(status_code=400, content={"message": "tts failed", "Exception": str(e)}) + + +@APP.get("/control") +async def control(command: str = None): + if command is None: + return JSONResponse(status_code=400, content={"message": "command is required"}) + handle_control(command) + + +@APP.get("/tts") +async def tts_get_endpoint( + text: str = None, + text_lang: str = None, + ref_audio_path: str = None, + aux_ref_audio_paths: list = None, + prompt_lang: str = None, + prompt_text: str = "", + top_k: int = 5, + top_p: float = 1, + temperature: float = 1, + text_split_method: str = "cut0", + batch_size: int = 1, + batch_threshold: float = 0.75, + 
split_bucket: bool = True, + speed_factor: float = 1.0, + fragment_interval: float = 0.3, + seed: int = -1, + media_type: str = "wav", + streaming_mode: bool = False, + parallel_infer: bool = True, + repetition_penalty: float = 1.35, + sample_steps: int = 32, + super_sampling: bool = False, +): + req = { + "text": text, + "text_lang": text_lang.lower(), + "ref_audio_path": ref_audio_path, + "aux_ref_audio_paths": aux_ref_audio_paths, + "prompt_text": prompt_text, + "prompt_lang": prompt_lang.lower(), + "top_k": top_k, + "top_p": top_p, + "temperature": temperature, + "text_split_method": text_split_method, + "batch_size": int(batch_size), + "batch_threshold": float(batch_threshold), + "speed_factor": float(speed_factor), + "split_bucket": split_bucket, + "fragment_interval": fragment_interval, + "seed": seed, + "media_type": media_type, + "streaming_mode": streaming_mode, + "parallel_infer": parallel_infer, + "repetition_penalty": float(repetition_penalty), + "sample_steps": int(sample_steps), + "super_sampling": super_sampling, + } + return await tts_handle(req) + + +@APP.post("/tts") +async def tts_post_endpoint(request: TTS_Request): + req = request.dict() + return await tts_handle(req) + + +@APP.get("/set_refer_audio") +async def set_refer_aduio(refer_audio_path: str = None): + try: + tts_pipeline.set_ref_audio(refer_audio_path) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "set refer audio failed", "Exception": str(e)}) + return JSONResponse(status_code=200, content={"message": "success"}) + + +# @APP.post("/set_refer_audio") +# async def set_refer_aduio_post(audio_file: UploadFile = File(...)): +# try: +# # 检查文件类型,确保是音频文件 +# if not audio_file.content_type.startswith("audio/"): +# return JSONResponse(status_code=400, content={"message": "file type is not supported"}) + +# os.makedirs("uploaded_audio", exist_ok=True) +# save_path = os.path.join("uploaded_audio", audio_file.filename) +# # 保存音频文件到服务器上的一个目录 +# with open(save_path , "wb") as buffer: +# buffer.write(await audio_file.read()) + +# tts_pipeline.set_ref_audio(save_path) +# except Exception as e: +# return JSONResponse(status_code=400, content={"message": f"set refer audio failed", "Exception": str(e)}) +# return JSONResponse(status_code=200, content={"message": "success"}) + + +@APP.get("/set_gpt_weights") +async def set_gpt_weights(weights_path: str = None): + try: + if weights_path in ["", None]: + return JSONResponse(status_code=400, content={"message": "gpt weight path is required"}) + tts_pipeline.init_t2s_weights(weights_path) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "change gpt weight failed", "Exception": str(e)}) + + return JSONResponse(status_code=200, content={"message": "success"}) + + +@APP.get("/set_sovits_weights") +async def set_sovits_weights(weights_path: str = None): + try: + if weights_path in ["", None]: + return JSONResponse(status_code=400, content={"message": "sovits weight path is required"}) + tts_pipeline.init_vits_weights(weights_path) + except Exception as e: + return JSONResponse(status_code=400, content={"message": "change sovits weight failed", "Exception": str(e)}) + return JSONResponse(status_code=200, content={"message": "success"}) + + +if __name__ == "__main__": + try: + if host == "None": # 在调用时使用 -a None 参数,可以让api监听双栈 + host = None + uvicorn.run(app=APP, host=host, port=port, workers=1) + except Exception: + traceback.print_exc() + os.kill(os.getpid(), signal.SIGTERM) + exit(0) diff --git a/colab_webui.ipynb 
b/colab_webui.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..c44ea211be8839948fd099b4a1be7610b76ccbae --- /dev/null +++ b/colab_webui.ipynb @@ -0,0 +1,106 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_o6a8GS2lWQM" + }, + "source": [ + "# Env Setup (Run Once Only)\n", + "# 环境配置, 只需运行一次" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%writefile /content/setup.sh\n", + "set -e\n", + "cd /content\n", + "rm -rf GPT-SoVITS\n", + "git clone https://github.com/RVC-Boss/GPT-SoVITS.git\n", + "cd GPT-SoVITS\n", + "\n", + "if conda env list | awk '{print $1}' | grep -Fxq \"GPTSoVITS\"; then\n", + " :\n", + "else\n", + " conda create -n GPTSoVITS python=3.10 -y\n", + "fi\n", + "\n", + "source activate GPTSoVITS\n", + "\n", + "bash install.sh --source HF --download-uvr5" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -q condacolab\n", + "import condacolab\n", + "condacolab.install_from_url(\"https://repo.anaconda.com/archive/Anaconda3-2024.10-1-Linux-x86_64.sh\")\n", + "!cd /content && bash setup.sh" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Launch WebUI\n", + "# 启动 WebUI" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4oRGUzkrk8C7" + }, + "outputs": [], + "source": [ + "!cd /content/GPT-SoVITS && source activate GPTSoVITS && export is_share=True && python webui.py" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "include_colab_link": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/config.py b/config.py new file mode 100644 index 0000000000000000000000000000000000000000..5f90c5cd60994b3f9e064681cd67f80d738a06f6 --- /dev/null +++ b/config.py @@ -0,0 +1,69 @@ +import sys +import os + +import torch + +# 推理用的指定模型 +sovits_path = "" +gpt_path = "" +is_half_str = os.environ.get("is_half", "True") +is_half = True if is_half_str.lower() == "true" else False +is_share_str = os.environ.get("is_share", "False") +is_share = True if is_share_str.lower() == "true" else False + +cnhubert_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base" +bert_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large" +pretrained_sovits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth" +pretrained_gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt" + +exp_root = "logs" +python_exec = sys.executable or "python" +if torch.cuda.is_available(): + infer_device = "cuda" +else: + infer_device = "cpu" + +webui_port_main = 9874 +webui_port_uvr5 = 9873 +webui_port_infer_tts = 9872 +webui_port_subfix = 9871 + +api_port = 9880 + +if infer_device == "cuda": + gpu_name = torch.cuda.get_device_name(0) + if ( + ("16" in gpu_name and "V100" not in gpu_name.upper()) + or "P40" in gpu_name.upper() + or "P10" in gpu_name.upper() + or "1060" in gpu_name + or "1070" in gpu_name + or "1080" in gpu_name + ): + is_half = False + +if infer_device == "cpu": + is_half = False + + +class Config: + def 
__init__(self): + self.sovits_path = sovits_path + self.gpt_path = gpt_path + self.is_half = is_half + + self.cnhubert_path = cnhubert_path + self.bert_path = bert_path + self.pretrained_sovits_path = pretrained_sovits_path + self.pretrained_gpt_path = pretrained_gpt_path + + self.exp_root = exp_root + self.python_exec = python_exec + self.infer_device = infer_device + + self.webui_port_main = webui_port_main + self.webui_port_uvr5 = webui_port_uvr5 + self.webui_port_infer_tts = webui_port_infer_tts + self.webui_port_subfix = webui_port_subfix + + self.api_port = api_port diff --git a/docker-compose.yaml b/docker-compose.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aca8ab9ed1b6938016e1fcacf1d24d673255576e --- /dev/null +++ b/docker-compose.yaml @@ -0,0 +1,32 @@ +version: '3.8' + +services: + gpt-sovits: + image: breakstring/gpt-sovits:latest # please change the image name and tag base your environment. If the tag contains the word 'elite', such as "latest-elite", it indicates that the image does not include the necessary models such as GPT-SoVITS, UVR5, Damo ASR, etc. You will need to download them yourself and map them into the container. + container_name: gpt-sovits-container + environment: + - is_half=False + - is_share=False + volumes: + - ./output:/workspace/output + - ./logs:/workspace/logs + - ./SoVITS_weights:/workspace/SoVITS_weights + - ./reference:/workspace/reference + working_dir: /workspace + ports: + - "9880:9880" + - "9871:9871" + - "9872:9872" + - "9873:9873" + - "9874:9874" + shm_size: 16G + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: "all" + capabilities: [gpu] + stdin_open: true + tty: true + restart: unless-stopped diff --git a/dockerbuild.sh b/dockerbuild.sh new file mode 100644 index 0000000000000000000000000000000000000000..3a4a1e183d88ac6c0167d5d5e05519cf6da13b9a --- /dev/null +++ b/dockerbuild.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# 获取当前日期,格式为 YYYYMMDD +DATE=$(date +%Y%m%d) +# 获取最新的 Git commit 哈希值的前 7 位 +COMMIT_HASH=$(git rev-parse HEAD | cut -c 1-7) + +# 构建 full 版本的镜像 +docker build --build-arg IMAGE_TYPE=full -t breakstring/gpt-sovits:latest . +# 为同一个镜像添加带日期的标签 +docker tag breakstring/gpt-sovits:latest breakstring/gpt-sovits:dev-$DATE +# 为同一个镜像添加带当前代码库Commit哈希值的标签 +docker tag breakstring/gpt-sovits:latest breakstring/gpt-sovits:dev-$COMMIT_HASH + + +# 构建 elite 版本的镜像(无模型下载步骤,需手工将模型下载安装进容器) +docker build --build-arg IMAGE_TYPE=elite -t breakstring/gpt-sovits:latest-elite . 
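Stepping back to config.py above: the Config class simply snapshots the module-level settings, so downstream scripts can read ports and device flags from a single object. A minimal, illustrative usage sketch (assuming it is run from the repository root so config is importable; the values noted in the comments are the defaults from the diff):

from config import Config

cfg = Config()
print("inference device:", cfg.infer_device)    # "cuda" when a usable GPU is detected, otherwise "cpu"
print("half precision:", cfg.is_half)           # forced to False on CPU and on several 10/16-series cards
print("API port:", cfg.api_port)                # 9880
print("main WebUI port:", cfg.webui_port_main)  # 9874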
+# 为同一个镜像添加带日期的标签 +docker tag breakstring/gpt-sovits:latest-elite breakstring/gpt-sovits:dev-$DATE-elite +# 为同一个镜像添加带当前代码库Commit哈希值的标签 +docker tag breakstring/gpt-sovits:latest-elite breakstring/gpt-sovits:dev-$COMMIT_HASH-elite diff --git a/extra-req.txt b/extra-req.txt new file mode 100644 index 0000000000000000000000000000000000000000..8d2324117f5edb41b3c5f3c13b2b0bf2dd395942 --- /dev/null +++ b/extra-req.txt @@ -0,0 +1 @@ +faster-whisper diff --git a/go-webui.bat b/go-webui.bat new file mode 100644 index 0000000000000000000000000000000000000000..a2dfff6c0a5444ac2524cd981078d56399b79505 --- /dev/null +++ b/go-webui.bat @@ -0,0 +1,2 @@ +runtime\python.exe -I webui.py zh_CN +pause diff --git a/go-webui.ps1 b/go-webui.ps1 new file mode 100644 index 0000000000000000000000000000000000000000..f9427263d1a1402a5d09923f18b19af3d35e8b2f --- /dev/null +++ b/go-webui.ps1 @@ -0,0 +1,4 @@ +$ErrorActionPreference = "SilentlyContinue" +chcp 65001 +& "$PSScriptRoot\runtime\python.exe" -I "$PSScriptRoot\webui.py" zh_CN +pause diff --git a/gpt-sovits_kaggle.ipynb b/gpt-sovits_kaggle.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..9f28f6f46ce7e2fe1b78a49b2eae4f5ecf971808 --- /dev/null +++ b/gpt-sovits_kaggle.ipynb @@ -0,0 +1,235 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "45857cb2", + "metadata": { + "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", + "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", + "execution": { + "iopub.execute_input": "2024-02-18T14:43:46.735480Z", + "iopub.status.busy": "2024-02-18T14:43:46.735183Z", + "iopub.status.idle": "2024-02-18T14:48:10.724175Z", + "shell.execute_reply": "2024-02-18T14:48:10.723059Z" + }, + "papermill": { + "duration": 263.994935, + "end_time": "2024-02-18T14:48:10.726613", + "exception": false, + "start_time": "2024-02-18T14:43:46.731678", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "!git clone https://github.com/RVC-Boss/GPT-SoVITS.git\n", + "%cd GPT-SoVITS\n", + "!apt-get update && apt-get install -y --no-install-recommends tzdata ffmpeg libsox-dev parallel aria2 git git-lfs && git lfs install\n", + "!pip install -r requirements.txt\n", + "!pip install -r extra-req.txt --no-deps" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b9d346b4", + "metadata": { + "execution": { + "iopub.execute_input": "2024-02-18T14:48:10.815802Z", + "iopub.status.busy": "2024-02-18T14:48:10.814899Z", + "iopub.status.idle": "2024-02-18T14:50:31.253276Z", + "shell.execute_reply": "2024-02-18T14:50:31.252024Z" + }, + "papermill": { + "duration": 140.484893, + "end_time": "2024-02-18T14:50:31.255720", + "exception": false, + "start_time": "2024-02-18T14:48:10.770827", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# @title Download pretrained models 下载预训练模型\n", + "!mkdir -p /kaggle/working/GPT-SoVITS/GPT_SoVITS/pretrained_models\n", + "!mkdir -p /kaggle/working/GPT-SoVITS/tools/asr/models\n", + "!mkdir -p /kaggle/working/GPT-SoVITS/tools/uvr5\n", + "%cd /kaggle/working/GPT-SoVITS/GPT_SoVITS/pretrained_models\n", + "!git clone https://huggingface.co/lj1995/GPT-SoVITS\n", + "%cd /kaggle/working/GPT-SoVITS/tools/asr/models\n", + "!git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch.git\n", + "!git clone https://www.modelscope.cn/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch.git\n", + "!git clone 
https://www.modelscope.cn/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch.git\n", + "# # @title UVR5 pretrains 安装uvr5模型\n", + "%cd /kaggle/working/GPT-SoVITS/tools/uvr5\n", + "!git clone https://huggingface.co/Delik/uvr5_weights\n", + "!git config core.sparseCheckout true\n", + "!mv /kaggle/working/GPT-SoVITS/GPT_SoVITS/pretrained_models/GPT-SoVITS/* /kaggle/working/GPT-SoVITS/GPT_SoVITS/pretrained_models/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ea94d245", + "metadata": { + "execution": { + "iopub.execute_input": "2024-02-18T14:29:01.071549Z", + "iopub.status.busy": "2024-02-18T14:29:01.070592Z", + "iopub.status.idle": "2024-02-18T14:40:45.318368Z", + "shell.execute_reply": "2024-02-18T14:40:45.317130Z", + "shell.execute_reply.started": "2024-02-18T14:29:01.071512Z" + }, + "papermill": { + "duration": null, + "end_time": null, + "exception": false, + "start_time": "2024-02-18T14:50:31.309013", + "status": "running" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# @title launch WebUI 启动WebUI\n", + "%cd /kaggle/working/GPT-SoVITS/\n", + "!npm install -g localtunnel\n", + "import subprocess\n", + "import threading\n", + "import time\n", + "import socket\n", + "import urllib.request\n", + "\n", + "\n", + "def iframe_thread(port):\n", + " while True:\n", + " time.sleep(0.5)\n", + " sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n", + " result = sock.connect_ex((\"127.0.0.1\", port))\n", + " if result == 0:\n", + " break\n", + " sock.close()\n", + "\n", + " from colorama import Fore, Style\n", + " print(\n", + " Fore.GREEN + \"\\nIP: \",\n", + " Fore.RED,\n", + " urllib.request.urlopen(\"https://ipv4.icanhazip.com\").read().decode(\"utf8\").strip(\"\\n\"),\n", + " \"\\n\",\n", + " Style.RESET_ALL,\n", + " )\n", + " p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n", + " for line in p.stdout:\n", + " print(line.decode(), end=\"\")\n", + "\n", + "\n", + "threading.Thread(target=iframe_thread, daemon=True, args=(9874,)).start()\n", + "\n", + "!python webui.py" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dda88a6d", + "metadata": { + "execution": { + "iopub.execute_input": "2024-02-18T14:40:56.880608Z", + "iopub.status.busy": "2024-02-18T14:40:56.879879Z" + }, + "papermill": { + "duration": null, + "end_time": null, + "exception": null, + "start_time": null, + "status": "pending" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# 开启推理页面\n", + "%cd /kaggle/working/GPT-SoVITS/\n", + "!npm install -g localtunnel\n", + "import threading\n", + "\n", + "\n", + "def iframe_thread(port):\n", + " while True:\n", + " time.sleep(0.5)\n", + " sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n", + " result = sock.connect_ex((\"127.0.0.1\", port))\n", + " if result == 0:\n", + " break\n", + " sock.close()\n", + "\n", + " from colorama import Fore, Style\n", + " print(\n", + " Fore.GREEN + \"\\nIP: \",\n", + " Fore.RED,\n", + " urllib.request.urlopen(\"https://ipv4.icanhazip.com\").read().decode(\"utf8\").strip(\"\\n\"),\n", + " \"\\n\",\n", + " Style.RESET_ALL,\n", + " )\n", + " p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n", + " for line in p.stdout:\n", + " print(line.decode(), end=\"\")\n", + "\n", + "\n", + "threading.Thread(target=iframe_thread, daemon=True, args=(9872,)).start()\n", + "\n", + "!python ./GPT_SoVITS/inference_webui.py" + ] + } + ], + "metadata": { + "kaggle": { + "accelerator": "nvidiaTeslaT4", + 
"dataSources": [ + { + "datasetId": 4459328, + "sourceId": 7649639, + "sourceType": "datasetVersion" + } + ], + "dockerImageVersionId": 30646, + "isGpuEnabled": true, + "isInternetEnabled": true, + "language": "python", + "sourceType": "notebook" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, + "papermill": { + "default_parameters": {}, + "duration": null, + "end_time": null, + "environment_variables": {}, + "exception": null, + "input_path": "__notebook__.ipynb", + "output_path": "__notebook__.ipynb", + "parameters": {}, + "start_time": "2024-02-18T14:43:44.011910", + "version": "2.5.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/install.sh b/install.sh new file mode 100644 index 0000000000000000000000000000000000000000..31b4761b273f643064e1a10d06c70c4984f944d2 --- /dev/null +++ b/install.sh @@ -0,0 +1,213 @@ +#!/bin/bash + +# cd into GPT-SoVITS Base Path +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" + +cd "$SCRIPT_DIR" || exit 1 + +set -e + +if ! command -v conda &>/dev/null; then + echo "Conda Not Found" + exit 1 +fi + +trap 'echo "Error Occured at \"$BASH_COMMAND\" with exit code $?"; exit 1' ERR + +is_HF=false +is_HF_MIRROR=false +is_MODELSCOPE=false +DOWNLOAD_UVR5=false + +print_help() { + echo "Usage: bash install.sh [OPTIONS]" + echo "" + echo "Options:" + echo " --source HF|HF-Mirror|ModelScope Specify the model source (REQUIRED)" + echo " --download-uvr5 Enable downloading the UVR5 model" + echo " -h, --help Show this help message and exit" + echo "" + echo "Examples:" + echo " bash install.sh --source HF --download-uvr5" + echo " bash install.sh --source ModelScope" +} + +# Show help if no arguments provided +if [[ $# -eq 0 ]]; then + print_help + exit 0 +fi + +# Parse arguments +while [[ $# -gt 0 ]]; do + case "$1" in + --source) + case "$2" in + HF) + is_HF=true + ;; + HF-Mirror) + is_HF_MIRROR=true + ;; + ModelScope) + is_MODELSCOPE=true + ;; + *) + echo "Error: Invalid Download Source: $2" + echo "Choose From: [HF, HF-Mirror, ModelScope]" + exit 1 + ;; + esac + shift 2 + ;; + --download-uvr5) + DOWNLOAD_UVR5=true + shift + ;; + -h|--help) + print_help + exit 0 + ;; + *) + echo "Unknown Argument: $1" + echo "Use -h or --help to see available options." + exit 1 + ;; + esac +done + +if ! $is_HF && ! $is_HF_MIRROR && ! 
$is_MODELSCOPE; then + echo "Error: Download Source is REQUIRED" + echo "" + print_help + exit 1 +fi + +if [ "$is_HF" = "true" ]; then + echo "Download Model From HuggingFace" + PRETRINED_URL="https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/pretrained_models.zip" + G2PW_URL="https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/G2PWModel.zip" + UVR5_URL="https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/uvr5_weights.zip" +elif [ "$is_HF_MIRROR" = "true" ]; then + echo "Download Model From HuggingFace-Mirror" + PRETRINED_URL="https://hf-mirror.com/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/pretrained_models.zip" + G2PW_URL="https://hf-mirror.com/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/G2PWModel.zip" + UVR5_URL="https://hf-mirror.com/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/uvr5_weights.zip" +elif [ "$is_MODELSCOPE" = "true" ]; then + echo "Download Model From ModelScope" + PRETRINED_URL="https://www.modelscope.cn/models/XXXXRT/GPT-SoVITS-Pretrained/resolve/master/pretrained_models.zip" + G2PW_URL="https://www.modelscope.cn/models/XXXXRT/GPT-SoVITS-Pretrained/resolve/master/G2PWModel.zip" + UVR5_URL="https://www.modelscope.cn/models/XXXXRT/GPT-SoVITS-Pretrained/resolve/master/uvr5_weights.zip" +fi + +if find "GPT_SoVITS/pretrained_models" -mindepth 1 ! -name '.gitignore' | grep -q .; then + echo "Pretrained Model Exists" +else + echo "Download Pretrained Models" + wget --tries=25 --wait=5 --read-timeout=40 --retry-on-http-error=404 "$PRETRINED_URL" + + unzip pretrained_models.zip + rm -rf pretrained_models.zip + mv pretrained_models/* GPT_SoVITS/pretrained_models + rm -rf pretrained_models +fi + +if [ ! -d "GPT_SoVITS/text/G2PWModel" ]; then + echo "Download G2PWModel" + wget --tries=25 --wait=5 --read-timeout=40 --retry-on-http-error=404 "$G2PW_URL" + + unzip G2PWModel.zip + rm -rf G2PWModel.zip + mv G2PWModel GPT_SoVITS/text/G2PWModel +else + echo "G2PWModel Exists" +fi + +if [ "$DOWNLOAD_UVR5" = "true" ];then + if find "tools/uvr5/uvr5_weights" -mindepth 1 ! -name '.gitignore' | grep -q .; then + echo "UVR5 Model Exists" + else + echo "Download UVR5 Model" + wget --tries=25 --wait=5 --read-timeout=40 --retry-on-http-error=404 "$UVR5_URL" + + unzip uvr5_weights.zip + rm -rf uvr5_weights.zip + mv uvr5_weights/* tools/uvr5/uvr5_weights + rm -rf uvr5_weights + fi +fi + +# 安装构建工具 +# Install build tools +echo "Installing GCC..." +conda install -c conda-forge gcc=14 -y + +echo "Installing G++..." +conda install -c conda-forge gxx -y + +echo "Installing ffmpeg and cmake..." +conda install ffmpeg cmake -y + +echo "Installing git-lfs and zip..." +conda install git-lfs -y +conda install zip -y + +git-lfs install + +echo "Checking for CUDA installation..." +if command -v nvidia-smi &>/dev/null; then + USE_CUDA=true + echo "CUDA found." +else + echo "CUDA not found." + USE_CUDA=false +fi + +if [ "$USE_CUDA" = false ]; then + echo "Checking for ROCm installation..." + if [ -d "/opt/rocm" ]; then + USE_ROCM=true + echo "ROCm found." + if grep -qi "microsoft" /proc/version; then + echo "You are running WSL." + IS_WSL=true + else + echo "You are NOT running WSL." + IS_WSL=false + fi + else + echo "ROCm not found." + USE_ROCM=false + fi +fi + +if [ "$USE_CUDA" = true ]; then + echo "Installing PyTorch with CUDA support..." + pip install torch==2.5.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124 +elif [ "$USE_ROCM" = true ]; then + echo "Installing PyTorch with ROCm support..." 
+ pip install torch==2.5.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/rocm6.2 +else + echo "Installing PyTorch for CPU..." + pip install torch==2.5.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cpu +fi + +echo "Installing Python dependencies from requirements.txt..." + +# 刷新环境 +# Refresh environment +hash -r + +pip install -r extra-req.txt --no-deps + +pip install -r requirements.txt + +if [ "$USE_ROCM" = true ] && [ "$IS_WSL" = true ]; then + echo "Update to WSL compatible runtime lib..." + location=$(pip show torch | grep Location | awk -F ": " '{print $2}') + cd "${location}"/torch/lib/ || exit + rm libhsa-runtime64.so* + cp /opt/rocm/lib/libhsa-runtime64.so.1.2 libhsa-runtime64.so +fi + +echo "Installation completed successfully!" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..30d25825b3fb15519235f73a033ce897d0c17c0c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,45 @@ +numpy<2.0 +scipy +tensorboard +librosa==0.9.2 +numba +pytorch-lightning>=2.4 +gradio>=4.41,<5 +ffmpeg-python +onnxruntime; sys_platform == 'darwin' +onnxruntime-gpu; sys_platform != 'darwin' +tqdm +funasr==1.0.27 +cn2an +pypinyin +pyopenjtalk>=0.4.1 +g2p_en +torchaudio +modelscope==1.10.0 +sentencepiece +transformers>=4.43 +peft +chardet +PyYAML +psutil +jieba_fast +jieba +split-lang +fast_langdetect>=0.3.1 +wordsegment +rotary_embedding_torch +ToJyutping +g2pk2 +ko_pron +opencc; sys_platform != 'linux' +opencc==1.1.1; sys_platform == 'linux' +python_mecab_ko; sys_platform != 'win32' +fastapi[standard]>=0.115.1 +x_transformers +torchmetrics<=1.5 +pydantic<=2.10.6 +ctranslate2>=4.0,<5 +huggingface_hub>=0.13 +tokenizers>=0.13,<1 +av>=11 +tqdm diff --git a/webui.py b/webui.py new file mode 100644 index 0000000000000000000000000000000000000000..e229b216476c01422c170803119e87d138c7306e --- /dev/null +++ b/webui.py @@ -0,0 +1,1963 @@ +import os +import sys + +if len(sys.argv) == 1: + sys.argv.append("v2") +version = "v1" if sys.argv[1] == "v1" else "v2" +os.environ["version"] = version +now_dir = os.getcwd() +sys.path.insert(0, now_dir) +import warnings + +warnings.filterwarnings("ignore") +import json +import platform +import re +import shutil +import signal + +import psutil +import torch +import yaml + +os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO" +torch.manual_seed(233333) +tmp = os.path.join(now_dir, "TEMP") +os.makedirs(tmp, exist_ok=True) +os.environ["TEMP"] = tmp +if os.path.exists(tmp): + for name in os.listdir(tmp): + if name == "jieba.cache": + continue + path = "%s/%s" % (tmp, name) + delete = os.remove if os.path.isfile(path) else shutil.rmtree + try: + delete(path) + except Exception as e: + print(str(e)) + pass +import site +import traceback + +site_packages_roots = [] +for path in site.getsitepackages(): + if "packages" in path: + site_packages_roots.append(path) +if site_packages_roots == []: + site_packages_roots = ["%s/runtime/Lib/site-packages" % now_dir] +# os.environ["OPENBLAS_NUM_THREADS"] = "4" +os.environ["no_proxy"] = "localhost, 127.0.0.1, ::1" +os.environ["all_proxy"] = "" +for site_packages_root in site_packages_roots: + if os.path.exists(site_packages_root): + try: + with open("%s/users.pth" % (site_packages_root), "w") as f: + f.write( + # "%s\n%s/runtime\n%s/tools\n%s/tools/asr\n%s/GPT_SoVITS\n%s/tools/uvr5" + "%s\n%s/GPT_SoVITS/BigVGAN\n%s/tools\n%s/tools/asr\n%s/GPT_SoVITS\n%s/tools/uvr5" + % (now_dir, now_dir, now_dir, now_dir, now_dir, now_dir) + ) + break + except 
PermissionError: + traceback.print_exc() +import shutil +import subprocess +from subprocess import Popen + +from config import ( + exp_root, + infer_device, + is_half, + is_share, + python_exec, + webui_port_infer_tts, + webui_port_main, + webui_port_subfix, + webui_port_uvr5, +) +from tools import my_utils +from tools.i18n.i18n import I18nAuto, scan_language_list + +language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else "Auto" +os.environ["language"] = language +i18n = I18nAuto(language=language) +from multiprocessing import cpu_count + +from tools.my_utils import check_details, check_for_existance + +# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 当遇到mps不支持的步骤时使用cpu +try: + import gradio.analytics as analytics + + analytics.version_check = lambda: None +except: + ... +import gradio as gr + +n_cpu = cpu_count() + +ngpu = torch.cuda.device_count() +gpu_infos = [] +mem = [] +if_gpu_ok = False + +# 判断是否有能用来训练和加速推理的N卡 +ok_gpu_keywords = { + "10", + "16", + "20", + "30", + "40", + "A2", + "A3", + "A4", + "P4", + "A50", + "500", + "A60", + "70", + "80", + "90", + "M4", + "T4", + "TITAN", + "L4", + "4060", + "H", + "600", + "506", + "507", + "508", + "509", +} +set_gpu_numbers = set() +if torch.cuda.is_available() or ngpu != 0: + for i in range(ngpu): + gpu_name = torch.cuda.get_device_name(i) + if any(value in gpu_name.upper() for value in ok_gpu_keywords): + # A10#A100#V100#A40#P40#M40#K80#A4500 + if_gpu_ok = True # 至少有一张能用的N卡 + gpu_infos.append("%s\t%s" % (i, gpu_name)) + set_gpu_numbers.add(i) + mem.append(int(torch.cuda.get_device_properties(i).total_memory / 1024 / 1024 / 1024 + 0.4)) +# # 判断是否支持mps加速 +# if torch.backends.mps.is_available(): +# if_gpu_ok = True +# gpu_infos.append("%s\t%s" % ("0", "Apple GPU")) +# mem.append(psutil.virtual_memory().total/ 1024 / 1024 / 1024) # 实测使用系统内存作为显存不会爆显存 + + +v3v4set={"v3","v4"} +def set_default(): + global \ + default_batch_size, \ + default_max_batch_size, \ + gpu_info, \ + default_sovits_epoch, \ + default_sovits_save_every_epoch, \ + max_sovits_epoch, \ + max_sovits_save_every_epoch, \ + default_batch_size_s1, \ + if_force_ckpt + if_force_ckpt = False + if if_gpu_ok and len(gpu_infos) > 0: + gpu_info = "\n".join(gpu_infos) + minmem = min(mem) + # if version == "v3" and minmem < 14: + # # API读取不到共享显存,直接填充确认 + # try: + # torch.zeros((1024,1024,1024,14),dtype=torch.int8,device="cuda") + # torch.cuda.empty_cache() + # minmem = 14 + # except RuntimeError as _: + # # 强制梯度检查只需要12G显存 + # if minmem >= 12 : + # if_force_ckpt = True + # minmem = 14 + # else: + # try: + # torch.zeros((1024,1024,1024,12),dtype=torch.int8,device="cuda") + # torch.cuda.empty_cache() + # if_force_ckpt = True + # minmem = 14 + # except RuntimeError as _: + # print("显存不足以开启V3训练") + default_batch_size = minmem // 2 if version not in v3v4set else minmem // 8 + default_batch_size_s1 = minmem // 2 + else: + gpu_info = "%s\t%s" % ("0", "CPU") + gpu_infos.append("%s\t%s" % ("0", "CPU")) + set_gpu_numbers.add(0) + default_batch_size = default_batch_size_s1 = int(psutil.virtual_memory().total / 1024 / 1024 / 1024 / 4) + if version not in v3v4set: + default_sovits_epoch = 8 + default_sovits_save_every_epoch = 4 + max_sovits_epoch = 25 # 40 + max_sovits_save_every_epoch = 25 # 10 + else: + default_sovits_epoch = 2 + default_sovits_save_every_epoch = 1 + max_sovits_epoch = 50 # 40 # 3 + max_sovits_save_every_epoch = 10 # 10 # 3 + + default_batch_size = max(1, default_batch_size) + default_batch_size_s1 = max(1, default_batch_size_s1) + default_max_batch_size = 
default_batch_size * 3 + + +set_default() + +gpus = "-".join([i[0] for i in gpu_infos]) +default_gpu_numbers = str(sorted(list(set_gpu_numbers))[0]) + + +def fix_gpu_number(input): # 将越界的number强制改到界内 + try: + if int(input) not in set_gpu_numbers: + return default_gpu_numbers + except: + return input + return input + + +def fix_gpu_numbers(inputs): + output = [] + try: + for input in inputs.split(","): + output.append(str(fix_gpu_number(input))) + return ",".join(output) + except: + return inputs + + +pretrained_sovits_name = [ + "GPT_SoVITS/pretrained_models/s2G488k.pth", + "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", + "GPT_SoVITS/pretrained_models/s2Gv3.pth", + "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth", +] +pretrained_gpt_name = [ + "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", + "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", + "GPT_SoVITS/pretrained_models/s1v3.ckpt", + "GPT_SoVITS/pretrained_models/s1v3.ckpt", +] + +pretrained_model_list = ( + pretrained_sovits_name[int(version[-1]) - 1], + pretrained_sovits_name[int(version[-1]) - 1].replace("s2G", "s2D"), + pretrained_gpt_name[int(version[-1]) - 1], + "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", + "GPT_SoVITS/pretrained_models/chinese-hubert-base", +) + +_ = "" +for i in pretrained_model_list: + if "s2Dv3" not in i and os.path.exists(i) == False: + _ += f"\n {i}" +if _: + print("warning: ", i18n("以下模型不存在:") + _) + +_ = [[], []] +for i in range(4): + if os.path.exists(pretrained_gpt_name[i]): + _[0].append(pretrained_gpt_name[i]) + else: + _[0].append("") ##没有下pretrained模型的,说不定他们是想自己从零训底模呢 + if os.path.exists(pretrained_sovits_name[i]): + _[-1].append(pretrained_sovits_name[i]) + else: + _[-1].append("") +pretrained_gpt_name, pretrained_sovits_name = _ + +SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3", "SoVITS_weights_v4"] +GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3", "GPT_weights_v4"] +for root in SoVITS_weight_root + GPT_weight_root: + os.makedirs(root, exist_ok=True) + + +def get_weights_names(): + SoVITS_names = [name for name in pretrained_sovits_name if name != ""] + for path in SoVITS_weight_root: + for name in os.listdir(path): + if name.endswith(".pth"): + SoVITS_names.append("%s/%s" % (path, name)) + GPT_names = [name for name in pretrained_gpt_name if name != ""] + for path in GPT_weight_root: + for name in os.listdir(path): + if name.endswith(".ckpt"): + GPT_names.append("%s/%s" % (path, name)) + return SoVITS_names, GPT_names + + +SoVITS_names, GPT_names = get_weights_names() +for path in SoVITS_weight_root + GPT_weight_root: + os.makedirs(path, exist_ok=True) + + +def custom_sort_key(s): + # 使用正则表达式提取字符串中的数字部分和非数字部分 + parts = re.split("(\d+)", s) + # 将数字部分转换为整数,非数字部分保持不变 + parts = [int(part) if part.isdigit() else part for part in parts] + return parts + + +def change_choices(): + SoVITS_names, GPT_names = get_weights_names() + return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, { + "choices": sorted(GPT_names, key=custom_sort_key), + "__type__": "update", + } + + +p_label = None +p_uvr5 = None +p_asr = None +p_denoise = None +p_tts_inference = None + + +def kill_proc_tree(pid, including_parent=True): + try: + parent = psutil.Process(pid) + except psutil.NoSuchProcess: + # Process already terminated + return + + children = parent.children(recursive=True) + for child in children: + try: 
+ os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL + except OSError: + pass + if including_parent: + try: + os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL + except OSError: + pass + + +system = platform.system() + + +def kill_process(pid, process_name=""): + if system == "Windows": + cmd = "taskkill /t /f /pid %s" % pid + # os.system(cmd) + subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + else: + kill_proc_tree(pid) + print(process_name + i18n("进程已终止")) + + +def process_info(process_name="", indicator=""): + if indicator == "opened": + return process_name + i18n("已开启") + elif indicator == "open": + return i18n("开启") + process_name + elif indicator == "closed": + return process_name + i18n("已关闭") + elif indicator == "close": + return i18n("关闭") + process_name + elif indicator == "running": + return process_name + i18n("运行中") + elif indicator == "occupy": + return process_name + i18n("占用中") + "," + i18n("需先终止才能开启下一次任务") + elif indicator == "finish": + return process_name + i18n("已完成") + elif indicator == "failed": + return process_name + i18n("失败") + elif indicator == "info": + return process_name + i18n("进程输出信息") + else: + return process_name + + +process_name_subfix = i18n("音频标注WebUI") + + +def change_label(path_list): + global p_label + if p_label is None: + check_for_existance([path_list]) + path_list = my_utils.clean_path(path_list) + cmd = '"%s" tools/subfix_webui.py --load_list "%s" --webui_port %s --is_share %s' % ( + python_exec, + path_list, + webui_port_subfix, + is_share, + ) + yield ( + process_info(process_name_subfix, "opened"), + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + ) + print(cmd) + p_label = Popen(cmd, shell=True) + else: + kill_process(p_label.pid, process_name_subfix) + p_label = None + yield ( + process_info(process_name_subfix, "closed"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + ) + + +process_name_uvr5 = i18n("人声分离WebUI") + + +def change_uvr5(): + global p_uvr5 + if p_uvr5 is None: + cmd = '"%s" tools/uvr5/webui.py "%s" %s %s %s' % (python_exec, infer_device, is_half, webui_port_uvr5, is_share) + yield ( + process_info(process_name_uvr5, "opened"), + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + ) + print(cmd) + p_uvr5 = Popen(cmd, shell=True) + else: + kill_process(p_uvr5.pid, process_name_uvr5) + p_uvr5 = None + yield ( + process_info(process_name_uvr5, "closed"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + ) + + +process_name_tts = i18n("TTS推理WebUI") + + +def change_tts_inference(bert_path, cnhubert_base_path, gpu_number, gpt_path, sovits_path, batched_infer_enabled): + global p_tts_inference + if batched_infer_enabled: + cmd = '"%s" GPT_SoVITS/inference_webui_fast.py "%s"' % (python_exec, language) + else: + cmd = '"%s" GPT_SoVITS/inference_webui.py "%s"' % (python_exec, language) + # #####v3暂不支持加速推理 + # if version=="v3": + # cmd = '"%s" GPT_SoVITS/inference_webui.py "%s"'%(python_exec, language) + if p_tts_inference is None: + os.environ["gpt_path"] = gpt_path if "/" in gpt_path else "%s/%s" % (GPT_weight_root, gpt_path) + os.environ["sovits_path"] = sovits_path if "/" in sovits_path else "%s/%s" % (SoVITS_weight_root, sovits_path) + os.environ["cnhubert_base_path"] = cnhubert_base_path + os.environ["bert_path"] = bert_path + os.environ["_CUDA_VISIBLE_DEVICES"] = fix_gpu_number(gpu_number) + os.environ["is_half"] = str(is_half) + 
os.environ["infer_ttswebui"] = str(webui_port_infer_tts) + os.environ["is_share"] = str(is_share) + yield ( + process_info(process_name_tts, "opened"), + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + ) + print(cmd) + p_tts_inference = Popen(cmd, shell=True) + else: + kill_process(p_tts_inference.pid, process_name_tts) + p_tts_inference = None + yield ( + process_info(process_name_tts, "closed"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + ) + + +from tools.asr.config import asr_dict + +process_name_asr = i18n("语音识别") + + +def open_asr(asr_inp_dir, asr_opt_dir, asr_model, asr_model_size, asr_lang, asr_precision): + global p_asr + if p_asr is None: + asr_inp_dir = my_utils.clean_path(asr_inp_dir) + asr_opt_dir = my_utils.clean_path(asr_opt_dir) + check_for_existance([asr_inp_dir]) + cmd = f'"{python_exec}" tools/asr/{asr_dict[asr_model]["path"]}' + cmd += f' -i "{asr_inp_dir}"' + cmd += f' -o "{asr_opt_dir}"' + cmd += f" -s {asr_model_size}" + cmd += f" -l {asr_lang}" + cmd += f" -p {asr_precision}" + output_file_name = os.path.basename(asr_inp_dir) + output_folder = asr_opt_dir or "output/asr_opt" + output_file_path = os.path.abspath(f"{output_folder}/{output_file_name}.list") + yield ( + process_info(process_name_asr, "opened"), + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + {"__type__": "update"}, + {"__type__": "update"}, + {"__type__": "update"}, + ) + print(cmd) + p_asr = Popen(cmd, shell=True) + p_asr.wait() + p_asr = None + yield ( + process_info(process_name_asr, "finish"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + {"__type__": "update", "value": output_file_path}, + {"__type__": "update", "value": output_file_path}, + {"__type__": "update", "value": asr_inp_dir}, + ) + else: + yield ( + process_info(process_name_asr, "occupy"), + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + {"__type__": "update"}, + {"__type__": "update"}, + {"__type__": "update"}, + ) + + +def close_asr(): + global p_asr + if p_asr is not None: + kill_process(p_asr.pid, process_name_asr) + p_asr = None + return ( + process_info(process_name_asr, "closed"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + ) + + +process_name_denoise = i18n("语音降噪") + + +def open_denoise(denoise_inp_dir, denoise_opt_dir): + global p_denoise + if p_denoise == None: + denoise_inp_dir = my_utils.clean_path(denoise_inp_dir) + denoise_opt_dir = my_utils.clean_path(denoise_opt_dir) + check_for_existance([denoise_inp_dir]) + cmd = '"%s" tools/cmd-denoise.py -i "%s" -o "%s" -p %s' % ( + python_exec, + denoise_inp_dir, + denoise_opt_dir, + "float16" if is_half == True else "float32", + ) + + yield ( + process_info(process_name_denoise, "opened"), + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + {"__type__": "update"}, + {"__type__": "update"}, + ) + print(cmd) + p_denoise = Popen(cmd, shell=True) + p_denoise.wait() + p_denoise = None + yield ( + process_info(process_name_denoise, "finish"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + {"__type__": "update", "value": denoise_opt_dir}, + {"__type__": "update", "value": denoise_opt_dir}, + ) + else: + yield ( + process_info(process_name_denoise, "occupy"), + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + {"__type__": 
"update"}, + {"__type__": "update"}, + ) + + +def close_denoise(): + global p_denoise + if p_denoise is not None: + kill_process(p_denoise.pid, process_name_denoise) + p_denoise = None + return ( + process_info(process_name_denoise, "closed"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + ) + + +p_train_SoVITS = None +process_name_sovits = i18n("SoVITS训练") + + +def open1Ba( + batch_size, + total_epoch, + exp_name, + text_low_lr_rate, + if_save_latest, + if_save_every_weights, + save_every_epoch, + gpu_numbers1Ba, + pretrained_s2G, + pretrained_s2D, + if_grad_ckpt, + lora_rank, +): + global p_train_SoVITS + if p_train_SoVITS == None: + with open("GPT_SoVITS/configs/s2.json") as f: + data = f.read() + data = json.loads(data) + s2_dir = "%s/%s" % (exp_root, exp_name) + os.makedirs("%s/logs_s2_%s" % (s2_dir, version), exist_ok=True) + if check_for_existance([s2_dir], is_train=True): + check_details([s2_dir], is_train=True) + if is_half == False: + data["train"]["fp16_run"] = False + batch_size = max(1, batch_size // 2) + data["train"]["batch_size"] = batch_size + data["train"]["epochs"] = total_epoch + data["train"]["text_low_lr_rate"] = text_low_lr_rate + data["train"]["pretrained_s2G"] = pretrained_s2G + data["train"]["pretrained_s2D"] = pretrained_s2D + data["train"]["if_save_latest"] = if_save_latest + data["train"]["if_save_every_weights"] = if_save_every_weights + data["train"]["save_every_epoch"] = save_every_epoch + data["train"]["gpu_numbers"] = gpu_numbers1Ba + data["train"]["grad_ckpt"] = if_grad_ckpt + data["train"]["lora_rank"] = lora_rank + data["model"]["version"] = version + data["data"]["exp_dir"] = data["s2_ckpt_dir"] = s2_dir + data["save_weight_dir"] = SoVITS_weight_root[int(version[-1]) - 1] + data["name"] = exp_name + data["version"] = version + tmp_config_path = "%s/tmp_s2.json" % tmp + with open(tmp_config_path, "w") as f: + f.write(json.dumps(data)) + if version in ["v1", "v2"]: + cmd = '"%s" GPT_SoVITS/s2_train.py --config "%s"' % (python_exec, tmp_config_path) + else: + cmd = '"%s" GPT_SoVITS/s2_train_v3_lora.py --config "%s"' % (python_exec, tmp_config_path) + yield ( + process_info(process_name_sovits, "opened"), + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + ) + print(cmd) + p_train_SoVITS = Popen(cmd, shell=True) + p_train_SoVITS.wait() + p_train_SoVITS = None + yield ( + process_info(process_name_sovits, "finish"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + ) + else: + yield ( + process_info(process_name_sovits, "occupy"), + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + ) + + +def close1Ba(): + global p_train_SoVITS + if p_train_SoVITS is not None: + kill_process(p_train_SoVITS.pid, process_name_sovits) + p_train_SoVITS = None + return ( + process_info(process_name_sovits, "closed"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + ) + + +p_train_GPT = None +process_name_gpt = i18n("GPT训练") + + +def open1Bb( + batch_size, + total_epoch, + exp_name, + if_dpo, + if_save_latest, + if_save_every_weights, + save_every_epoch, + gpu_numbers, + pretrained_s1, +): + global p_train_GPT + if p_train_GPT == None: + with open( + "GPT_SoVITS/configs/s1longer.yaml" if version == "v1" else "GPT_SoVITS/configs/s1longer-v2.yaml" + ) as f: + data = f.read() + data = yaml.load(data, Loader=yaml.FullLoader) + s1_dir = "%s/%s" % (exp_root, exp_name) + os.makedirs("%s/logs_s1" 
% (s1_dir), exist_ok=True) + if check_for_existance([s1_dir], is_train=True): + check_details([s1_dir], is_train=True) + if is_half == False: + data["train"]["precision"] = "32" + batch_size = max(1, batch_size // 2) + data["train"]["batch_size"] = batch_size + data["train"]["epochs"] = total_epoch + data["pretrained_s1"] = pretrained_s1 + data["train"]["save_every_n_epoch"] = save_every_epoch + data["train"]["if_save_every_weights"] = if_save_every_weights + data["train"]["if_save_latest"] = if_save_latest + data["train"]["if_dpo"] = if_dpo + data["train"]["half_weights_save_dir"] = GPT_weight_root[int(version[-1]) - 1] + data["train"]["exp_name"] = exp_name + data["train_semantic_path"] = "%s/6-name2semantic.tsv" % s1_dir + data["train_phoneme_path"] = "%s/2-name2text.txt" % s1_dir + data["output_dir"] = "%s/logs_s1_%s" % (s1_dir, version) + # data["version"]=version + + os.environ["_CUDA_VISIBLE_DEVICES"] = fix_gpu_numbers(gpu_numbers.replace("-", ",")) + os.environ["hz"] = "25hz" + tmp_config_path = "%s/tmp_s1.yaml" % tmp + with open(tmp_config_path, "w") as f: + f.write(yaml.dump(data, default_flow_style=False)) + # cmd = '"%s" GPT_SoVITS/s1_train.py --config_file "%s" --train_semantic_path "%s/6-name2semantic.tsv" --train_phoneme_path "%s/2-name2text.txt" --output_dir "%s/logs_s1"'%(python_exec,tmp_config_path,s1_dir,s1_dir,s1_dir) + cmd = '"%s" GPT_SoVITS/s1_train.py --config_file "%s" ' % (python_exec, tmp_config_path) + yield ( + process_info(process_name_gpt, "opened"), + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + ) + print(cmd) + p_train_GPT = Popen(cmd, shell=True) + p_train_GPT.wait() + p_train_GPT = None + yield ( + process_info(process_name_gpt, "finish"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + ) + else: + yield ( + process_info(process_name_gpt, "occupy"), + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + ) + + +def close1Bb(): + global p_train_GPT + if p_train_GPT is not None: + kill_process(p_train_GPT.pid, process_name_gpt) + p_train_GPT = None + return ( + process_info(process_name_gpt, "closed"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + ) + + +ps_slice = [] +process_name_slice = i18n("语音切分") + + +def open_slice(inp, opt_root, threshold, min_length, min_interval, hop_size, max_sil_kept, _max, alpha, n_parts): + global ps_slice + inp = my_utils.clean_path(inp) + opt_root = my_utils.clean_path(opt_root) + check_for_existance([inp]) + if os.path.exists(inp) == False: + yield ( + i18n("输入路径不存在"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + {"__type__": "update"}, + {"__type__": "update"}, + {"__type__": "update"}, + ) + return + if os.path.isfile(inp): + n_parts = 1 + elif os.path.isdir(inp): + pass + else: + yield ( + i18n("输入路径存在但不可用"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + {"__type__": "update"}, + {"__type__": "update"}, + {"__type__": "update"}, + ) + return + if ps_slice == []: + for i_part in range(n_parts): + cmd = '"%s" tools/slice_audio.py "%s" "%s" %s %s %s %s %s %s %s %s %s' % ( + python_exec, + inp, + opt_root, + threshold, + min_length, + min_interval, + hop_size, + max_sil_kept, + _max, + alpha, + i_part, + n_parts, + ) + print(cmd) + p = Popen(cmd, shell=True) + ps_slice.append(p) + yield ( + process_info(process_name_slice, "opened"), + {"__type__": "update", "visible": False}, + 
{"__type__": "update", "visible": True}, + {"__type__": "update"}, + {"__type__": "update"}, + {"__type__": "update"}, + ) + for p in ps_slice: + p.wait() + ps_slice = [] + yield ( + process_info(process_name_slice, "finish"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + {"__type__": "update", "value": opt_root}, + {"__type__": "update", "value": opt_root}, + {"__type__": "update", "value": opt_root}, + ) + else: + yield ( + process_info(process_name_slice, "occupy"), + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + {"__type__": "update"}, + {"__type__": "update"}, + {"__type__": "update"}, + ) + + +def close_slice(): + global ps_slice + if ps_slice != []: + for p_slice in ps_slice: + try: + kill_process(p_slice.pid, process_name_slice) + except: + traceback.print_exc() + ps_slice = [] + return ( + process_info(process_name_slice, "closed"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + ) + + +ps1a = [] +process_name_1a = i18n("文本分词与特征提取") + + +def open1a(inp_text, inp_wav_dir, exp_name, gpu_numbers, bert_pretrained_dir): + global ps1a + inp_text = my_utils.clean_path(inp_text) + inp_wav_dir = my_utils.clean_path(inp_wav_dir) + if check_for_existance([inp_text, inp_wav_dir], is_dataset_processing=True): + check_details([inp_text, inp_wav_dir], is_dataset_processing=True) + if ps1a == []: + opt_dir = "%s/%s" % (exp_root, exp_name) + config = { + "inp_text": inp_text, + "inp_wav_dir": inp_wav_dir, + "exp_name": exp_name, + "opt_dir": opt_dir, + "bert_pretrained_dir": bert_pretrained_dir, + } + gpu_names = gpu_numbers.split("-") + all_parts = len(gpu_names) + for i_part in range(all_parts): + config.update( + { + "i_part": str(i_part), + "all_parts": str(all_parts), + "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]), + "is_half": str(is_half), + } + ) + os.environ.update(config) + cmd = '"%s" GPT_SoVITS/prepare_datasets/1-get-text.py' % python_exec + print(cmd) + p = Popen(cmd, shell=True) + ps1a.append(p) + yield ( + process_info(process_name_1a, "running"), + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + ) + for p in ps1a: + p.wait() + opt = [] + for i_part in range(all_parts): + txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part) + with open(txt_path, "r", encoding="utf8") as f: + opt += f.read().strip("\n").split("\n") + os.remove(txt_path) + path_text = "%s/2-name2text.txt" % opt_dir + with open(path_text, "w", encoding="utf8") as f: + f.write("\n".join(opt) + "\n") + ps1a = [] + if len("".join(opt)) > 0: + yield ( + process_info(process_name_1a, "finish"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + ) + else: + yield ( + process_info(process_name_1a, "failed"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + ) + else: + yield ( + process_info(process_name_1a, "occupy"), + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + ) + + +def close1a(): + global ps1a + if ps1a != []: + for p1a in ps1a: + try: + kill_process(p1a.pid, process_name_1a) + except: + traceback.print_exc() + ps1a = [] + return ( + process_info(process_name_1a, "closed"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + ) + + +ps1b = [] +process_name_1b = i18n("语音自监督特征提取") + + +def open1b(inp_text, inp_wav_dir, exp_name, gpu_numbers, ssl_pretrained_dir): + global ps1b + inp_text = 
my_utils.clean_path(inp_text) + inp_wav_dir = my_utils.clean_path(inp_wav_dir) + if check_for_existance([inp_text, inp_wav_dir], is_dataset_processing=True): + check_details([inp_text, inp_wav_dir], is_dataset_processing=True) + if ps1b == []: + config = { + "inp_text": inp_text, + "inp_wav_dir": inp_wav_dir, + "exp_name": exp_name, + "opt_dir": "%s/%s" % (exp_root, exp_name), + "cnhubert_base_dir": ssl_pretrained_dir, + "is_half": str(is_half), + } + gpu_names = gpu_numbers.split("-") + all_parts = len(gpu_names) + for i_part in range(all_parts): + config.update( + { + "i_part": str(i_part), + "all_parts": str(all_parts), + "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]), + } + ) + os.environ.update(config) + cmd = '"%s" GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py' % python_exec + print(cmd) + p = Popen(cmd, shell=True) + ps1b.append(p) + yield ( + process_info(process_name_1b, "running"), + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + ) + for p in ps1b: + p.wait() + ps1b = [] + yield ( + process_info(process_name_1b, "finish"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + ) + else: + yield ( + process_info(process_name_1b, "occupy"), + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + ) + + +def close1b(): + global ps1b + if ps1b != []: + for p1b in ps1b: + try: + kill_process(p1b.pid, process_name_1b) + except: + traceback.print_exc() + ps1b = [] + return ( + process_info(process_name_1b, "closed"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + ) + + +ps1c = [] +process_name_1c = i18n("语义Token提取") + + +def open1c(inp_text, exp_name, gpu_numbers, pretrained_s2G_path): + global ps1c + inp_text = my_utils.clean_path(inp_text) + if check_for_existance([inp_text, ""], is_dataset_processing=True): + check_details([inp_text, ""], is_dataset_processing=True) + if ps1c == []: + opt_dir = "%s/%s" % (exp_root, exp_name) + config = { + "inp_text": inp_text, + "exp_name": exp_name, + "opt_dir": opt_dir, + "pretrained_s2G": pretrained_s2G_path, + "s2config_path": "GPT_SoVITS/configs/s2.json", + "is_half": str(is_half), + } + gpu_names = gpu_numbers.split("-") + all_parts = len(gpu_names) + for i_part in range(all_parts): + config.update( + { + "i_part": str(i_part), + "all_parts": str(all_parts), + "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]), + } + ) + os.environ.update(config) + cmd = '"%s" GPT_SoVITS/prepare_datasets/3-get-semantic.py' % python_exec + print(cmd) + p = Popen(cmd, shell=True) + ps1c.append(p) + yield ( + process_info(process_name_1c, "running"), + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + ) + for p in ps1c: + p.wait() + opt = ["item_name\tsemantic_audio"] + path_semantic = "%s/6-name2semantic.tsv" % opt_dir + for i_part in range(all_parts): + semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part) + with open(semantic_path, "r", encoding="utf8") as f: + opt += f.read().strip("\n").split("\n") + os.remove(semantic_path) + with open(path_semantic, "w", encoding="utf8") as f: + f.write("\n".join(opt) + "\n") + ps1c = [] + yield ( + process_info(process_name_1c, "finish"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + ) + else: + yield ( + process_info(process_name_1c, "occupy"), + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + ) + + +def close1c(): + global 
ps1c + if ps1c != []: + for p1c in ps1c: + try: + kill_process(p1c.pid, process_name_1c) + except: + traceback.print_exc() + ps1c = [] + return ( + process_info(process_name_1c, "closed"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + ) + + +ps1abc = [] +process_name_1abc = i18n("训练集格式化一键三连") + + +def open1abc( + inp_text, + inp_wav_dir, + exp_name, + gpu_numbers1a, + gpu_numbers1Ba, + gpu_numbers1c, + bert_pretrained_dir, + ssl_pretrained_dir, + pretrained_s2G_path, +): + global ps1abc + inp_text = my_utils.clean_path(inp_text) + inp_wav_dir = my_utils.clean_path(inp_wav_dir) + if check_for_existance([inp_text, inp_wav_dir], is_dataset_processing=True): + check_details([inp_text, inp_wav_dir], is_dataset_processing=True) + if ps1abc == []: + opt_dir = "%s/%s" % (exp_root, exp_name) + try: + #############################1a + path_text = "%s/2-name2text.txt" % opt_dir + if os.path.exists(path_text) == False or ( + os.path.exists(path_text) == True + and len(open(path_text, "r", encoding="utf8").read().strip("\n").split("\n")) < 2 + ): + config = { + "inp_text": inp_text, + "inp_wav_dir": inp_wav_dir, + "exp_name": exp_name, + "opt_dir": opt_dir, + "bert_pretrained_dir": bert_pretrained_dir, + "is_half": str(is_half), + } + gpu_names = gpu_numbers1a.split("-") + all_parts = len(gpu_names) + for i_part in range(all_parts): + config.update( + { + "i_part": str(i_part), + "all_parts": str(all_parts), + "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]), + } + ) + os.environ.update(config) + cmd = '"%s" GPT_SoVITS/prepare_datasets/1-get-text.py' % python_exec + print(cmd) + p = Popen(cmd, shell=True) + ps1abc.append(p) + yield ( + i18n("进度") + ": 1A-Doing", + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + ) + for p in ps1abc: + p.wait() + + opt = [] + for i_part in range(all_parts): # txt_path="%s/2-name2text-%s.txt"%(opt_dir,i_part) + txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part) + with open(txt_path, "r", encoding="utf8") as f: + opt += f.read().strip("\n").split("\n") + os.remove(txt_path) + with open(path_text, "w", encoding="utf8") as f: + f.write("\n".join(opt) + "\n") + assert len("".join(opt)) > 0, process_info(process_name_1a, "failed") + yield ( + i18n("进度") + ": 1A-Done", + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + ) + ps1abc = [] + #############################1b + config = { + "inp_text": inp_text, + "inp_wav_dir": inp_wav_dir, + "exp_name": exp_name, + "opt_dir": opt_dir, + "cnhubert_base_dir": ssl_pretrained_dir, + } + gpu_names = gpu_numbers1Ba.split("-") + all_parts = len(gpu_names) + for i_part in range(all_parts): + config.update( + { + "i_part": str(i_part), + "all_parts": str(all_parts), + "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]), + } + ) + os.environ.update(config) + cmd = '"%s" GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py' % python_exec + print(cmd) + p = Popen(cmd, shell=True) + ps1abc.append(p) + yield ( + i18n("进度") + ": 1A-Done, 1B-Doing", + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + ) + for p in ps1abc: + p.wait() + yield ( + i18n("进度") + ": 1A-Done, 1B-Done", + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + ) + ps1abc = [] + #############################1c + path_semantic = "%s/6-name2semantic.tsv" % opt_dir + if os.path.exists(path_semantic) == False or ( + os.path.exists(path_semantic) == True and 
os.path.getsize(path_semantic) < 31 + ): + config = { + "inp_text": inp_text, + "exp_name": exp_name, + "opt_dir": opt_dir, + "pretrained_s2G": pretrained_s2G_path, + "s2config_path": "GPT_SoVITS/configs/s2.json", + } + gpu_names = gpu_numbers1c.split("-") + all_parts = len(gpu_names) + for i_part in range(all_parts): + config.update( + { + "i_part": str(i_part), + "all_parts": str(all_parts), + "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]), + } + ) + os.environ.update(config) + cmd = '"%s" GPT_SoVITS/prepare_datasets/3-get-semantic.py' % python_exec + print(cmd) + p = Popen(cmd, shell=True) + ps1abc.append(p) + yield ( + i18n("进度") + ": 1A-Done, 1B-Done, 1C-Doing", + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + ) + for p in ps1abc: + p.wait() + + opt = ["item_name\tsemantic_audio"] + for i_part in range(all_parts): + semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part) + with open(semantic_path, "r", encoding="utf8") as f: + opt += f.read().strip("\n").split("\n") + os.remove(semantic_path) + with open(path_semantic, "w", encoding="utf8") as f: + f.write("\n".join(opt) + "\n") + yield ( + i18n("进度") + ": 1A-Done, 1B-Done, 1C-Done", + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + ) + ps1abc = [] + yield ( + process_info(process_name_1abc, "finish"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + ) + except: + traceback.print_exc() + close1abc() + yield ( + process_info(process_name_1abc, "failed"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + ) + else: + yield ( + process_info(process_name_1abc, "occupy"), + {"__type__": "update", "visible": False}, + {"__type__": "update", "visible": True}, + ) + + +def close1abc(): + global ps1abc + if ps1abc != []: + for p1abc in ps1abc: + try: + kill_process(p1abc.pid, process_name_1abc) + except: + traceback.print_exc() + ps1abc = [] + return ( + process_info(process_name_1abc, "closed"), + {"__type__": "update", "visible": True}, + {"__type__": "update", "visible": False}, + ) + +def switch_version(version_): + os.environ["version"] = version_ + global version + version = version_ + if pretrained_sovits_name[int(version[-1]) - 1] != "" and pretrained_gpt_name[int(version[-1]) - 1] != "": + ... 
+    else:
+        gr.Warning(i18n("未下载模型") + ": " + version.upper())
+        set_default()
+    return (
+        {"__type__": "update", "value": pretrained_sovits_name[int(version[-1]) - 1]},
+        {"__type__": "update", "value": pretrained_sovits_name[int(version[-1]) - 1].replace("s2G", "s2D")},
+        {"__type__": "update", "value": pretrained_gpt_name[int(version[-1]) - 1]},
+        {"__type__": "update", "value": pretrained_gpt_name[int(version[-1]) - 1]},
+        {"__type__": "update", "value": pretrained_sovits_name[int(version[-1]) - 1]},
+        {"__type__": "update", "value": default_batch_size, "maximum": default_max_batch_size},
+        {"__type__": "update", "value": default_sovits_epoch, "maximum": max_sovits_epoch},
+        {"__type__": "update", "value": default_sovits_save_every_epoch, "maximum": max_sovits_save_every_epoch},
+        {"__type__": "update", "visible": True if version not in v3v4set else False},
+        {
+            "__type__": "update",
+            "value": False if not if_force_ckpt else True,
+            "interactive": True if not if_force_ckpt else False,
+        },
+        {"__type__": "update", "interactive": True, "value": False},
+        {"__type__": "update", "visible": True if version in v3v4set else False},
+    )  # {'__type__': 'update', "interactive": False if version in v3v4set else True, "value": False}, \ ####batch infer
+
+
+if os.path.exists("GPT_SoVITS/text/G2PWModel"):
+    ...
+else:
+    cmd = '"%s" GPT_SoVITS/download.py' % python_exec
+    p = Popen(cmd, shell=True)
+    p.wait()
+
+
+def sync(text):
+    return {"__type__": "update", "value": text}
+
+
+with gr.Blocks(title="GPT-SoVITS WebUI") as app:
+    gr.Markdown(
+        value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.")
+        + "<br>"
+        + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
+    )
+    gr.Markdown(value=i18n("中文教程文档") + ": " + "https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e")
+
+    with gr.Tabs():
+        with gr.TabItem("0-" + i18n("前置数据集获取工具")):  # Slice randomly in advance so UVR5 does not run out of memory -> uvr5 -> slicer -> asr -> labeling
+            gr.Markdown(value="0a-" + i18n("UVR5人声伴奏分离&去混响去延迟工具"))
+            with gr.Row():
+                with gr.Column(scale=3):
+                    with gr.Row():
+                        uvr5_info = gr.Textbox(label=process_info(process_name_uvr5, "info"))
+                open_uvr5 = gr.Button(value=process_info(process_name_uvr5, "open"), variant="primary", visible=True)
+                close_uvr5 = gr.Button(value=process_info(process_name_uvr5, "close"), variant="primary", visible=False)
+
+            gr.Markdown(value="0b-" + i18n("语音切分工具"))
+            with gr.Row():
+                with gr.Column(scale=3):
+                    with gr.Row():
+                        slice_inp_path = gr.Textbox(label=i18n("音频自动切分输入路径,可文件可文件夹"), value="")
+                        slice_opt_root = gr.Textbox(label=i18n("切分后的子音频的输出根目录"), value="output/slicer_opt")
+                    with gr.Row():
+                        threshold = gr.Textbox(label=i18n("threshold:音量小于这个值视作静音的备选切割点"), value="-34")
+                        min_length = gr.Textbox(
+                            label=i18n("min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值"),
+                            value="4000",
+                        )
+                        min_interval = gr.Textbox(label=i18n("min_interval:最短切割间隔"), value="300")
+                        hop_size = gr.Textbox(
+                            label=i18n("hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)"),
+                            value="10",
+                        )
+                        max_sil_kept = gr.Textbox(label=i18n("max_sil_kept:切完后静音最多留多长"), value="500")
+                    with gr.Row():
+                        _max = gr.Slider(
+                            minimum=0,
+                            maximum=1,
+                            step=0.05,
+                            label=i18n("max:归一化后最大值多少"),
+                            value=0.9,
+                            interactive=True,
+                        )
+                        alpha = gr.Slider(
+                            minimum=0,
+                            maximum=1,
+                            step=0.05,
+                            label=i18n("alpha_mix:混多少比例归一化后音频进来"),
+                            value=0.25,
+                            interactive=True,
+                        )
+                    with gr.Row():
+                        n_process = gr.Slider(
+                            minimum=1, maximum=n_cpu, step=1, label=i18n("切割使用的进程数"), value=4, interactive=True
+                        )
+                        slicer_info = gr.Textbox(label=process_info(process_name_slice, "info"))
+                open_slicer_button = gr.Button(
+                    value=process_info(process_name_slice, "open"), variant="primary", visible=True
+                )
+                close_slicer_button = gr.Button(
+                    value=process_info(process_name_slice, "close"), variant="primary", visible=False
+                )
+
+            gr.Markdown(value="0bb-" + i18n("语音降噪工具"))
+            with gr.Row():
+                with gr.Column(scale=3):
+                    with gr.Row():
+                        denoise_input_dir = gr.Textbox(label=i18n("输入文件夹路径"), value="")
+                        denoise_output_dir = gr.Textbox(label=i18n("输出文件夹路径"), value="output/denoise_opt")
+                    with gr.Row():
+                        denoise_info = gr.Textbox(label=process_info(process_name_denoise, "info"))
+                open_denoise_button = gr.Button(
+                    value=process_info(process_name_denoise, "open"), variant="primary", visible=True
+                )
+                close_denoise_button = gr.Button(
+                    value=process_info(process_name_denoise, "close"), variant="primary", visible=False
+                )
+
+            gr.Markdown(value="0c-" + i18n("语音识别工具"))
+            with gr.Row():
+                with gr.Column(scale=3):
+                    with gr.Row():
+                        asr_inp_dir = gr.Textbox(
+                            label=i18n("输入文件夹路径"), value="D:\\GPT-SoVITS\\raw\\xxx", interactive=True
+                        )
+                        asr_opt_dir = gr.Textbox(label=i18n("输出文件夹路径"), value="output/asr_opt", interactive=True)
+                    with gr.Row():
+                        asr_model = gr.Dropdown(
+                            label=i18n("ASR 模型"),
+                            choices=list(asr_dict.keys()),
+                            interactive=True,
+                            value="达摩 ASR (中文)",
+                        )
+                        asr_size = gr.Dropdown(
+                            label=i18n("ASR 模型尺寸"), choices=["large"], interactive=True, value="large"
+                        )
+                        asr_lang = gr.Dropdown(
+                            label=i18n("ASR 语言设置"), choices=["zh", "yue"], interactive=True, value="zh"
+                        )
+                        asr_precision = gr.Dropdown(
+                            label=i18n("数据类型精度"), choices=["float32"], interactive=True, value="float32"
+                        )
+                    with gr.Row():
+                        asr_info = 
gr.Textbox(label=process_info(process_name_asr, "info")) + open_asr_button = gr.Button( + value=process_info(process_name_asr, "open"), variant="primary", visible=True + ) + close_asr_button = gr.Button( + value=process_info(process_name_asr, "close"), variant="primary", visible=False + ) + + def change_lang_choices(key): # 根据选择的模型修改可选的语言 + return {"__type__": "update", "choices": asr_dict[key]["lang"], "value": asr_dict[key]["lang"][0]} + + def change_size_choices(key): # 根据选择的模型修改可选的模型尺寸 + return {"__type__": "update", "choices": asr_dict[key]["size"], "value": asr_dict[key]["size"][-1]} + + def change_precision_choices(key): # 根据选择的模型修改可选的语言 + if key == "Faster Whisper (多语种)": + if default_batch_size <= 4: + precision = "int8" + elif is_half: + precision = "float16" + else: + precision = "float32" + else: + precision = "float32" + return {"__type__": "update", "choices": asr_dict[key]["precision"], "value": precision} + + asr_model.change(change_lang_choices, [asr_model], [asr_lang]) + asr_model.change(change_size_choices, [asr_model], [asr_size]) + asr_model.change(change_precision_choices, [asr_model], [asr_precision]) + + gr.Markdown(value="0d-" + i18n("语音文本校对标注工具")) + with gr.Row(): + with gr.Column(scale=3): + with gr.Row(): + path_list = gr.Textbox( + label=i18n("标注文件路径 (含文件后缀 *.list)"), + value="D:\\RVC1006\\GPT-SoVITS\\raw\\xxx.list", + interactive=True, + ) + label_info = gr.Textbox(label=process_info(process_name_subfix, "info")) + open_label = gr.Button(value=process_info(process_name_subfix, "open"), variant="primary", visible=True) + close_label = gr.Button( + value=process_info(process_name_subfix, "close"), variant="primary", visible=False + ) + + open_label.click(change_label, [path_list], [label_info, open_label, close_label]) + close_label.click(change_label, [path_list], [label_info, open_label, close_label]) + open_uvr5.click(change_uvr5, [], [uvr5_info, open_uvr5, close_uvr5]) + close_uvr5.click(change_uvr5, [], [uvr5_info, open_uvr5, close_uvr5]) + + with gr.TabItem(i18n("1-GPT-SoVITS-TTS")): + with gr.Row(): + with gr.Row(): + exp_name = gr.Textbox(label=i18n("*实验/模型名"), value="xxx", interactive=True) + gpu_info = gr.Textbox(label=i18n("显卡信息"), value=gpu_info, visible=True, interactive=False) + version_checkbox = gr.Radio(label=i18n("版本"), value=version, choices=["v1", "v2", "v4"])#, "v3" + with gr.Row(): + pretrained_s2G = gr.Textbox( + label=i18n("预训练SoVITS-G模型路径"), + value=pretrained_sovits_name[int(version[-1]) - 1], + interactive=True, + lines=2, + max_lines=3, + scale=9, + ) + pretrained_s2D = gr.Textbox( + label=i18n("预训练SoVITS-D模型路径"), + value=pretrained_sovits_name[int(version[-1]) - 1].replace("s2G", "s2D"), + interactive=True, + lines=2, + max_lines=3, + scale=9, + ) + pretrained_s1 = gr.Textbox( + label=i18n("预训练GPT模型路径"), + value=pretrained_gpt_name[int(version[-1]) - 1], + interactive=True, + lines=2, + max_lines=3, + scale=10, + ) + + with gr.TabItem("1A-" + i18n("训练集格式化工具")): + gr.Markdown(value=i18n("输出logs/实验名目录下应有23456开头的文件和文件夹")) + with gr.Row(): + with gr.Row(): + inp_text = gr.Textbox( + label=i18n("*文本标注文件"), + value=r"D:\RVC1006\GPT-SoVITS\raw\xxx.list", + interactive=True, + scale=10, + ) + with gr.Row(): + inp_wav_dir = gr.Textbox( + label=i18n("*训练集音频文件目录"), + # value=r"D:\RVC1006\GPT-SoVITS\raw\xxx", + interactive=True, + placeholder=i18n( + "填切割后音频所在目录!读取的音频文件完整路径=该目录-拼接-list文件里波形对应的文件名(不是全路径)。如果留空则使用.list文件里的绝对全路径。" + ), + scale=10, + ) + + gr.Markdown(value="1Aa-" + process_name_1a) + with gr.Row(): + with gr.Row(): + gpu_numbers1a = 
gr.Textbox( + label=i18n("GPU卡号以-分割,每个卡号一个进程"), + value="%s-%s" % (gpus, gpus), + interactive=True, + ) + with gr.Row(): + bert_pretrained_dir = gr.Textbox( + label=i18n("预训练中文BERT模型路径"), + value="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", + interactive=False, + lines=2, + ) + with gr.Row(): + button1a_open = gr.Button( + value=process_info(process_name_1a, "open"), variant="primary", visible=True + ) + button1a_close = gr.Button( + value=process_info(process_name_1a, "close"), variant="primary", visible=False + ) + with gr.Row(): + info1a = gr.Textbox(label=process_info(process_name_1a, "info")) + + gr.Markdown(value="1Ab-" + process_name_1b) + with gr.Row(): + with gr.Row(): + gpu_numbers1Ba = gr.Textbox( + label=i18n("GPU卡号以-分割,每个卡号一个进程"), + value="%s-%s" % (gpus, gpus), + interactive=True, + ) + with gr.Row(): + cnhubert_base_dir = gr.Textbox( + label=i18n("预训练SSL模型路径"), + value="GPT_SoVITS/pretrained_models/chinese-hubert-base", + interactive=False, + lines=2, + ) + with gr.Row(): + button1b_open = gr.Button( + value=process_info(process_name_1b, "open"), variant="primary", visible=True + ) + button1b_close = gr.Button( + value=process_info(process_name_1b, "close"), variant="primary", visible=False + ) + with gr.Row(): + info1b = gr.Textbox(label=process_info(process_name_1b, "info")) + + gr.Markdown(value="1Ac-" + process_name_1c) + with gr.Row(): + with gr.Row(): + gpu_numbers1c = gr.Textbox( + label=i18n("GPU卡号以-分割,每个卡号一个进程"), + value="%s-%s" % (gpus, gpus), + interactive=True, + ) + with gr.Row(): + pretrained_s2G_ = gr.Textbox( + label=i18n("预训练SoVITS-G模型路径"), + value=pretrained_sovits_name[int(version[-1]) - 1], + interactive=False, + lines=2, + ) + with gr.Row(): + button1c_open = gr.Button( + value=process_info(process_name_1c, "open"), variant="primary", visible=True + ) + button1c_close = gr.Button( + value=process_info(process_name_1c, "close"), variant="primary", visible=False + ) + with gr.Row(): + info1c = gr.Textbox(label=process_info(process_name_1c, "info")) + + gr.Markdown(value="1Aabc-" + process_name_1abc) + with gr.Row(): + with gr.Row(): + button1abc_open = gr.Button( + value=process_info(process_name_1abc, "open"), variant="primary", visible=True + ) + button1abc_close = gr.Button( + value=process_info(process_name_1abc, "close"), variant="primary", visible=False + ) + with gr.Row(): + info1abc = gr.Textbox(label=process_info(process_name_1abc, "info")) + + pretrained_s2G.change(sync, [pretrained_s2G], [pretrained_s2G_]) + open_asr_button.click( + open_asr, + [asr_inp_dir, asr_opt_dir, asr_model, asr_size, asr_lang, asr_precision], + [asr_info, open_asr_button, close_asr_button, path_list, inp_text, inp_wav_dir], + ) + close_asr_button.click(close_asr, [], [asr_info, open_asr_button, close_asr_button]) + open_slicer_button.click( + open_slice, + [ + slice_inp_path, + slice_opt_root, + threshold, + min_length, + min_interval, + hop_size, + max_sil_kept, + _max, + alpha, + n_process, + ], + [slicer_info, open_slicer_button, close_slicer_button, asr_inp_dir, denoise_input_dir, inp_wav_dir], + ) + close_slicer_button.click(close_slice, [], [slicer_info, open_slicer_button, close_slicer_button]) + open_denoise_button.click( + open_denoise, + [denoise_input_dir, denoise_output_dir], + [denoise_info, open_denoise_button, close_denoise_button, asr_inp_dir, inp_wav_dir], + ) + close_denoise_button.click(close_denoise, [], [denoise_info, open_denoise_button, close_denoise_button]) + + button1a_open.click( + open1a, + [inp_text, inp_wav_dir, exp_name, 
gpu_numbers1a, bert_pretrained_dir], + [info1a, button1a_open, button1a_close], + ) + button1a_close.click(close1a, [], [info1a, button1a_open, button1a_close]) + button1b_open.click( + open1b, + [inp_text, inp_wav_dir, exp_name, gpu_numbers1Ba, cnhubert_base_dir], + [info1b, button1b_open, button1b_close], + ) + button1b_close.click(close1b, [], [info1b, button1b_open, button1b_close]) + button1c_open.click( + open1c, [inp_text, exp_name, gpu_numbers1c, pretrained_s2G], [info1c, button1c_open, button1c_close] + ) + button1c_close.click(close1c, [], [info1c, button1c_open, button1c_close]) + button1abc_open.click( + open1abc, + [ + inp_text, + inp_wav_dir, + exp_name, + gpu_numbers1a, + gpu_numbers1Ba, + gpu_numbers1c, + bert_pretrained_dir, + cnhubert_base_dir, + pretrained_s2G, + ], + [info1abc, button1abc_open, button1abc_close], + ) + button1abc_close.click(close1abc, [], [info1abc, button1abc_open, button1abc_close]) + + with gr.TabItem("1B-" + i18n("微调训练")): + gr.Markdown(value="1Ba-" + i18n("SoVITS 训练: 模型权重文件在 SoVITS_weights/")) + with gr.Row(): + with gr.Column(): + with gr.Row(): + batch_size = gr.Slider( + minimum=1, + maximum=default_max_batch_size, + step=1, + label=i18n("每张显卡的batch_size"), + value=default_batch_size, + interactive=True, + ) + total_epoch = gr.Slider( + minimum=1, + maximum=max_sovits_epoch, + step=1, + label=i18n("总训练轮数total_epoch,不建议太高"), + value=default_sovits_epoch, + interactive=True, + ) + with gr.Row(): + text_low_lr_rate = gr.Slider( + minimum=0.2, + maximum=0.6, + step=0.05, + label=i18n("文本模块学习率权重"), + value=0.4, + visible=True if version not in v3v4set else False, + ) # v3 not need + lora_rank = gr.Radio( + label=i18n("LoRA秩"), + value="32", + choices=["16", "32", "64", "128"], + visible=True if version in v3v4set else False, + ) # v1v2 not need + save_every_epoch = gr.Slider( + minimum=1, + maximum=max_sovits_save_every_epoch, + step=1, + label=i18n("保存频率save_every_epoch"), + value=default_sovits_save_every_epoch, + interactive=True, + ) + with gr.Column(): + with gr.Column(): + if_save_latest = gr.Checkbox( + label=i18n("是否仅保存最新的权重文件以节省硬盘空间"), + value=True, + interactive=True, + show_label=True, + ) + if_save_every_weights = gr.Checkbox( + label=i18n("是否在每次保存时间点将最终小模型保存至weights文件夹"), + value=True, + interactive=True, + show_label=True, + ) + if_grad_ckpt = gr.Checkbox( + label="v3是否开启梯度检查点节省显存占用", + value=False, + interactive=True if version in v3v4set else False, + show_label=True, + visible=False, + ) # 只有V3s2可以用 + with gr.Row(): + gpu_numbers1Ba = gr.Textbox( + label=i18n("GPU卡号以-分割,每个卡号一个进程"), value="%s" % (gpus), interactive=True + ) + with gr.Row(): + with gr.Row(): + button1Ba_open = gr.Button( + value=process_info(process_name_sovits, "open"), variant="primary", visible=True + ) + button1Ba_close = gr.Button( + value=process_info(process_name_sovits, "close"), variant="primary", visible=False + ) + with gr.Row(): + info1Ba = gr.Textbox(label=process_info(process_name_sovits, "info")) + gr.Markdown(value="1Bb-" + i18n("GPT 训练: 模型权重文件在 GPT_weights/")) + with gr.Row(): + with gr.Column(): + with gr.Row(): + batch_size1Bb = gr.Slider( + minimum=1, + maximum=40, + step=1, + label=i18n("每张显卡的batch_size"), + value=default_batch_size_s1, + interactive=True, + ) + total_epoch1Bb = gr.Slider( + minimum=2, + maximum=50, + step=1, + label=i18n("总训练轮数total_epoch"), + value=15, + interactive=True, + ) + with gr.Row(): + save_every_epoch1Bb = gr.Slider( + minimum=1, + maximum=50, + step=1, + label=i18n("保存频率save_every_epoch"), + value=5, + 
interactive=True, + ) + if_dpo = gr.Checkbox( + label=i18n("是否开启DPO训练选项(实验性)"), + value=False, + interactive=True, + show_label=True, + ) + with gr.Column(): + with gr.Column(): + if_save_latest1Bb = gr.Checkbox( + label=i18n("是否仅保存最新的权重文件以节省硬盘空间"), + value=True, + interactive=True, + show_label=True, + ) + if_save_every_weights1Bb = gr.Checkbox( + label=i18n("是否在每次保存时间点将最终小模型保存至weights文件夹"), + value=True, + interactive=True, + show_label=True, + ) + with gr.Row(): + gpu_numbers1Bb = gr.Textbox( + label=i18n("GPU卡号以-分割,每个卡号一个进程"), value="%s" % (gpus), interactive=True + ) + with gr.Row(): + with gr.Row(): + button1Bb_open = gr.Button( + value=process_info(process_name_gpt, "open"), variant="primary", visible=True + ) + button1Bb_close = gr.Button( + value=process_info(process_name_gpt, "close"), variant="primary", visible=False + ) + with gr.Row(): + info1Bb = gr.Textbox(label=process_info(process_name_gpt, "info")) + + button1Ba_open.click( + open1Ba, + [ + batch_size, + total_epoch, + exp_name, + text_low_lr_rate, + if_save_latest, + if_save_every_weights, + save_every_epoch, + gpu_numbers1Ba, + pretrained_s2G, + pretrained_s2D, + if_grad_ckpt, + lora_rank, + ], + [info1Ba, button1Ba_open, button1Ba_close], + ) + button1Ba_close.click(close1Ba, [], [info1Ba, button1Ba_open, button1Ba_close]) + button1Bb_open.click( + open1Bb, + [ + batch_size1Bb, + total_epoch1Bb, + exp_name, + if_dpo, + if_save_latest1Bb, + if_save_every_weights1Bb, + save_every_epoch1Bb, + gpu_numbers1Bb, + pretrained_s1, + ], + [info1Bb, button1Bb_open, button1Bb_close], + ) + button1Bb_close.click(close1Bb, [], [info1Bb, button1Bb_open, button1Bb_close]) + + with gr.TabItem("1C-" + i18n("推理")): + gr.Markdown( + value=i18n( + "选择训练完存放在SoVITS_weights和GPT_weights下的模型。默认的一个是底模,体验5秒Zero Shot TTS用。" + ) + ) + with gr.Row(): + with gr.Row(): + GPT_dropdown = gr.Dropdown( + label=i18n("GPT模型列表"), + choices=sorted(GPT_names, key=custom_sort_key), + value=pretrained_gpt_name[0], + interactive=True, + ) + SoVITS_dropdown = gr.Dropdown( + label=i18n("SoVITS模型列表"), + choices=sorted(SoVITS_names, key=custom_sort_key), + value=pretrained_sovits_name[0], + interactive=True, + ) + with gr.Row(): + gpu_number_1C = gr.Textbox(label=i18n("GPU卡号,只能填1个整数"), value=gpus, interactive=True) + refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary") + refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown]) + with gr.Row(): + with gr.Row(): + batched_infer_enabled = gr.Checkbox( + label=i18n("启用并行推理版本"), value=False, interactive=True, show_label=True + ) + with gr.Row(): + open_tts = gr.Button( + value=process_info(process_name_tts, "open"), variant="primary", visible=True + ) + close_tts = gr.Button( + value=process_info(process_name_tts, "close"), variant="primary", visible=False + ) + with gr.Row(): + tts_info = gr.Textbox(label=process_info(process_name_tts, "info")) + open_tts.click( + change_tts_inference, + [ + bert_pretrained_dir, + cnhubert_base_dir, + gpu_number_1C, + GPT_dropdown, + SoVITS_dropdown, + batched_infer_enabled, + ], + [tts_info, open_tts, close_tts], + ) + close_tts.click( + change_tts_inference, + [ + bert_pretrained_dir, + cnhubert_base_dir, + gpu_number_1C, + GPT_dropdown, + SoVITS_dropdown, + batched_infer_enabled, + ], + [tts_info, open_tts, close_tts], + ) + + version_checkbox.change( + switch_version, + [version_checkbox], + [ + pretrained_s2G, + pretrained_s2D, + pretrained_s1, + GPT_dropdown, + SoVITS_dropdown, + batch_size, + total_epoch, + save_every_epoch, + 
text_low_lr_rate, + if_grad_ckpt, + batched_infer_enabled, + lora_rank, + ], + ) + + with gr.TabItem(i18n("2-GPT-SoVITS-变声")): + gr.Markdown(value=i18n("施工中,请静候佳音")) + + app.queue().launch( # concurrency_count=511, max_size=1022 + server_name="0.0.0.0", + inbrowser=True, + share=is_share, + server_port=webui_port_main, + quiet=True, + )
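+    # Note: the WebUI above binds to all interfaces (0.0.0.0) on webui_port_main; share=is_share additionally requests a public Gradio share link when enabled.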