zfwcpc kevinwang676 commited on
Commit
dab0054
0 Parent(s):

Duplicate from kevinwang676/ChatGLM2-SadTalker-VC

Browse files

Co-authored-by: Kevin Wang <[email protected]>

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .flake8 +21 -0
  2. .gitattributes +52 -0
  3. .gitignore +159 -0
  4. Dockerfile +59 -0
  5. LICENSE +21 -0
  6. README.md +15 -0
  7. app.py +608 -0
  8. checkpoint/__init__.py +0 -0
  9. checkpoint/freevc-24.pth +3 -0
  10. checkpoints/BFM_Fitting/01_MorphableModel.mat +1 -0
  11. checkpoints/BFM_Fitting/BFM09_model_info.mat +1 -0
  12. checkpoints/BFM_Fitting/BFM_exp_idx.mat +1 -0
  13. checkpoints/BFM_Fitting/BFM_front_idx.mat +1 -0
  14. checkpoints/BFM_Fitting/Exp_Pca.bin +1 -0
  15. checkpoints/BFM_Fitting/facemodel_info.mat +1 -0
  16. checkpoints/BFM_Fitting/select_vertex_id.mat +1 -0
  17. checkpoints/BFM_Fitting/similarity_Lm3D_all.mat +1 -0
  18. checkpoints/BFM_Fitting/std_exp.txt +1 -0
  19. checkpoints/auido2exp_00300-model.pth +1 -0
  20. checkpoints/auido2pose_00140-model.pth +1 -0
  21. checkpoints/epoch_20.pth +1 -0
  22. checkpoints/facevid2vid_00189-model.pth.tar +1 -0
  23. checkpoints/hub/checkpoints/2DFAN4-cd938726ad.zip +1 -0
  24. checkpoints/hub/checkpoints/s3fd-619a316812.pth +1 -0
  25. checkpoints/mapping_00229-model.pth.tar +1 -0
  26. checkpoints/shape_predictor_68_face_landmarks.dat +1 -0
  27. checkpoints/wav2lip.pth +1 -0
  28. commons.py +171 -0
  29. configs/freevc-24.json +54 -0
  30. mel_processing.py +112 -0
  31. models.py +351 -0
  32. modules.py +342 -0
  33. packages.txt +2 -0
  34. requirements.txt +35 -0
  35. speaker_encoder/__init__.py +1 -0
  36. speaker_encoder/audio.py +107 -0
  37. speaker_encoder/ckpt/__init__.py +1 -0
  38. speaker_encoder/ckpt/pretrained_bak_5805000.pt +3 -0
  39. speaker_encoder/compute_embed.py +40 -0
  40. speaker_encoder/config.py +45 -0
  41. speaker_encoder/data_objects/__init__.py +2 -0
  42. speaker_encoder/data_objects/random_cycler.py +37 -0
  43. speaker_encoder/data_objects/speaker.py +40 -0
  44. speaker_encoder/data_objects/speaker_batch.py +12 -0
  45. speaker_encoder/data_objects/speaker_verification_dataset.py +56 -0
  46. speaker_encoder/data_objects/utterance.py +26 -0
  47. speaker_encoder/hparams.py +31 -0
  48. speaker_encoder/inference.py +177 -0
  49. speaker_encoder/model.py +135 -0
  50. speaker_encoder/params_data.py +29 -0
.flake8 ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [flake8]
2
+ ignore =
3
+ # E203 whitespace before ':'
4
+ E203
5
+ D203,
6
+ # line too long
7
+ E501
8
+ per-file-ignores =
9
+ # imported but unused
10
+ # __init__.py: F401
11
+ test_*.py: F401
12
+ exclude =
13
+ .git,
14
+ __pycache__,
15
+ docs/source/conf.py,
16
+ old,
17
+ build,
18
+ dist,
19
+ .venv
20
+ pad*.py
21
+ max-complexity = 25
.gitattributes ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ checkpoints/BFM_Fitting/01_MorphableModel.mat filter=lfs diff=lfs merge=lfs -text
36
+ checkpoints/BFM_Fitting/BFM09_model_info.mat filter=lfs diff=lfs merge=lfs -text
37
+ checkpoints/facevid2vid_00189-model.pth.tar filter=lfs diff=lfs merge=lfs -text
38
+ checkpoints/mapping_00229-model.pth.tar filter=lfs diff=lfs merge=lfs -text
39
+ checkpoints/shape_predictor_68_face_landmarks.dat filter=lfs diff=lfs merge=lfs -text
40
+ examples/driven_audio/chinese_news.wav filter=lfs diff=lfs merge=lfs -text
41
+ examples/driven_audio/deyu.wav filter=lfs diff=lfs merge=lfs -text
42
+ examples/driven_audio/eluosi.wav filter=lfs diff=lfs merge=lfs -text
43
+ examples/driven_audio/fayu.wav filter=lfs diff=lfs merge=lfs -text
44
+ examples/driven_audio/imagine.wav filter=lfs diff=lfs merge=lfs -text
45
+ examples/driven_audio/japanese.wav filter=lfs diff=lfs merge=lfs -text
46
+ examples/source_image/art_16.png filter=lfs diff=lfs merge=lfs -text
47
+ examples/source_image/art_17.png filter=lfs diff=lfs merge=lfs -text
48
+ examples/source_image/art_3.png filter=lfs diff=lfs merge=lfs -text
49
+ examples/source_image/art_4.png filter=lfs diff=lfs merge=lfs -text
50
+ examples/source_image/art_5.png filter=lfs diff=lfs merge=lfs -text
51
+ examples/source_image/art_8.png filter=lfs diff=lfs merge=lfs -text
52
+ examples/source_image/art_9.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ results/
156
+ checkpoints/
157
+ gradio_cached_examples/
158
+ gfpgan/
159
+ start.sh
Dockerfile ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04
2
+ ENV DEBIAN_FRONTEND=noninteractive
3
+ RUN apt-get update && \
4
+ apt-get upgrade -y && \
5
+ apt-get install -y --no-install-recommends \
6
+ git \
7
+ zip \
8
+ unzip \
9
+ git-lfs \
10
+ wget \
11
+ curl \
12
+ # ffmpeg \
13
+ ffmpeg \
14
+ x264 \
15
+ # python build dependencies \
16
+ build-essential \
17
+ libssl-dev \
18
+ zlib1g-dev \
19
+ libbz2-dev \
20
+ libreadline-dev \
21
+ libsqlite3-dev \
22
+ libncursesw5-dev \
23
+ xz-utils \
24
+ tk-dev \
25
+ libxml2-dev \
26
+ libxmlsec1-dev \
27
+ libffi-dev \
28
+ liblzma-dev && \
29
+ apt-get clean && \
30
+ rm -rf /var/lib/apt/lists/*
31
+
32
+ RUN useradd -m -u 1000 user
33
+ USER user
34
+ ENV HOME=/home/user \
35
+ PATH=/home/user/.local/bin:${PATH}
36
+ WORKDIR ${HOME}/app
37
+
38
+ RUN curl https://pyenv.run | bash
39
+ ENV PATH=${HOME}/.pyenv/shims:${HOME}/.pyenv/bin:${PATH}
40
+ ENV PYTHON_VERSION=3.10.9
41
+ RUN pyenv install ${PYTHON_VERSION} && \
42
+ pyenv global ${PYTHON_VERSION} && \
43
+ pyenv rehash && \
44
+ pip install --no-cache-dir -U pip setuptools wheel
45
+
46
+ RUN pip install --no-cache-dir -U torch==1.12.1 torchvision==0.13.1
47
+ COPY --chown=1000 requirements.txt /tmp/requirements.txt
48
+ RUN pip install --no-cache-dir -U -r /tmp/requirements.txt
49
+
50
+ COPY --chown=1000 . ${HOME}/app
51
+ RUN ls -a
52
+ ENV PYTHONPATH=${HOME}/app \
53
+ PYTHONUNBUFFERED=1 \
54
+ GRADIO_ALLOW_FLAGGING=never \
55
+ GRADIO_NUM_PORTS=1 \
56
+ GRADIO_SERVER_NAME=0.0.0.0 \
57
+ GRADIO_THEME=huggingface \
58
+ SYSTEM=spaces
59
+ CMD ["python", "app.py"]
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Tencent AI Lab
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: ChatGLM2-SadTalker
3
+ emoji: 📺
4
+ colorFrom: purple
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 3.23.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ duplicated_from: kevinwang676/ChatGLM2-SadTalker-VC
12
+ ---
13
+
14
+
15
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,608 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ import tempfile
3
+ import gradio as gr
4
+ from src.gradio_demo import SadTalker
5
+ # from src.utils.text2speech import TTSTalker
6
+ from huggingface_hub import snapshot_download
7
+
8
+ import torch
9
+ import librosa
10
+ from scipy.io.wavfile import write
11
+ from transformers import WavLMModel
12
+
13
+ import utils
14
+ from models import SynthesizerTrn
15
+ from mel_processing import mel_spectrogram_torch
16
+ from speaker_encoder.voice_encoder import SpeakerEncoder
17
+
18
+ import time
19
+ from textwrap import dedent
20
+
21
+ import mdtex2html
22
+ from loguru import logger
23
+ from transformers import AutoModel, AutoTokenizer
24
+
25
+ from tts_voice import tts_order_voice
26
+ import edge_tts
27
+ import tempfile
28
+ import anyio
29
+
30
+
31
+ def get_source_image(image):
32
+ return image
33
+
34
+ try:
35
+ import webui # in webui
36
+ in_webui = True
37
+ except:
38
+ in_webui = False
39
+
40
+
41
+ def toggle_audio_file(choice):
42
+ if choice == False:
43
+ return gr.update(visible=True), gr.update(visible=False)
44
+ else:
45
+ return gr.update(visible=False), gr.update(visible=True)
46
+
47
+ def ref_video_fn(path_of_ref_video):
48
+ if path_of_ref_video is not None:
49
+ return gr.update(value=True)
50
+ else:
51
+ return gr.update(value=False)
52
+
53
+ def download_model():
54
+ REPO_ID = 'vinthony/SadTalker-V002rc'
55
+ snapshot_download(repo_id=REPO_ID, local_dir='./checkpoints', local_dir_use_symlinks=True)
56
+
57
+ def sadtalker_demo():
58
+
59
+ download_model()
60
+
61
+ sad_talker = SadTalker(lazy_load=True)
62
+ # tts_talker = TTSTalker()
63
+
64
+ download_model()
65
+ sad_talker = SadTalker(lazy_load=True)
66
+
67
+
68
+ # ChatGLM2 & FreeVC
69
+
70
+ '''
71
+ def get_wavlm():
72
+ os.system('gdown https://drive.google.com/uc?id=12-cB34qCTvByWT-QtOcZaqwwO21FLSqU')
73
+ shutil.move('WavLM-Large.pt', 'wavlm')
74
+ '''
75
+
76
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
77
+
78
+ smodel = SpeakerEncoder('speaker_encoder/ckpt/pretrained_bak_5805000.pt')
79
+
80
+ print("Loading FreeVC(24k)...")
81
+ hps = utils.get_hparams_from_file("configs/freevc-24.json")
82
+ freevc_24 = SynthesizerTrn(
83
+ hps.data.filter_length // 2 + 1,
84
+ hps.train.segment_size // hps.data.hop_length,
85
+ **hps.model).to(device)
86
+ _ = freevc_24.eval()
87
+ _ = utils.load_checkpoint("checkpoint/freevc-24.pth", freevc_24, None)
88
+
89
+ print("Loading WavLM for content...")
90
+ cmodel = WavLMModel.from_pretrained("microsoft/wavlm-large").to(device)
91
+
92
+ def convert(model, src, tgt):
93
+ with torch.no_grad():
94
+ # tgt
95
+ wav_tgt, _ = librosa.load(tgt, sr=hps.data.sampling_rate)
96
+ wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
97
+ if model == "FreeVC" or model == "FreeVC (24kHz)":
98
+ g_tgt = smodel.embed_utterance(wav_tgt)
99
+ g_tgt = torch.from_numpy(g_tgt).unsqueeze(0).to(device)
100
+ else:
101
+ wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(device)
102
+ mel_tgt = mel_spectrogram_torch(
103
+ wav_tgt,
104
+ hps.data.filter_length,
105
+ hps.data.n_mel_channels,
106
+ hps.data.sampling_rate,
107
+ hps.data.hop_length,
108
+ hps.data.win_length,
109
+ hps.data.mel_fmin,
110
+ hps.data.mel_fmax
111
+ )
112
+ # src
113
+ wav_src, _ = librosa.load(src, sr=hps.data.sampling_rate)
114
+ wav_src = torch.from_numpy(wav_src).unsqueeze(0).to(device)
115
+ c = cmodel(wav_src).last_hidden_state.transpose(1, 2).to(device)
116
+ # infer
117
+ if model == "FreeVC":
118
+ audio = freevc.infer(c, g=g_tgt)
119
+ elif model == "FreeVC-s":
120
+ audio = freevc_s.infer(c, mel=mel_tgt)
121
+ else:
122
+ audio = freevc_24.infer(c, g=g_tgt)
123
+ audio = audio[0][0].data.cpu().float().numpy()
124
+ if model == "FreeVC" or model == "FreeVC-s":
125
+ write("out.wav", hps.data.sampling_rate, audio)
126
+ else:
127
+ write("out.wav", 24000, audio)
128
+ out = "out.wav"
129
+ return out
130
+
131
+ # GLM2
132
+
133
+ language_dict = tts_order_voice
134
+
135
+ # fix timezone in Linux
136
+ os.environ["TZ"] = "Asia/Shanghai"
137
+ try:
138
+ time.tzset() # type: ignore # pylint: disable=no-member
139
+ except Exception:
140
+ # Windows
141
+ logger.warning("Windows, cant run time.tzset()")
142
+
143
+ # model_name = "THUDM/chatglm2-6b"
144
+ model_name = "THUDM/chatglm2-6b-int4"
145
+
146
+ RETRY_FLAG = False
147
+
148
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
149
+
150
+ # model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()
151
+
152
+ # 4/8 bit
153
+ # model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).quantize(4).cuda()
154
+
155
+ has_cuda = torch.cuda.is_available()
156
+
157
+ # has_cuda = False # force cpu
158
+
159
+ if has_cuda:
160
+ model_glm = (
161
+ AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda().half()
162
+ ) # 3.92G
163
+ else:
164
+ model_glm = AutoModel.from_pretrained(
165
+ model_name, trust_remote_code=True
166
+ ).float() # .float() .half().float()
167
+
168
+ model_glm = model_glm.eval()
169
+
170
+ _ = """Override Chatbot.postprocess"""
171
+
172
+
173
+ def postprocess(self, y):
174
+ if y is None:
175
+ return []
176
+ for i, (message, response) in enumerate(y):
177
+ y[i] = (
178
+ None if message is None else mdtex2html.convert((message)),
179
+ None if response is None else mdtex2html.convert(response),
180
+ )
181
+ return y
182
+
183
+
184
+ gr.Chatbot.postprocess = postprocess
185
+
186
+
187
+ def parse_text(text):
188
+ """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
189
+ lines = text.split("\n")
190
+ lines = [line for line in lines if line != ""]
191
+ count = 0
192
+ for i, line in enumerate(lines):
193
+ if "```" in line:
194
+ count += 1
195
+ items = line.split("`")
196
+ if count % 2 == 1:
197
+ lines[i] = f'<pre><code class="language-{items[-1]}">'
198
+ else:
199
+ lines[i] = "<br></code></pre>"
200
+ else:
201
+ if i > 0:
202
+ if count % 2 == 1:
203
+ line = line.replace("`", r"\`")
204
+ line = line.replace("<", "&lt;")
205
+ line = line.replace(">", "&gt;")
206
+ line = line.replace(" ", "&nbsp;")
207
+ line = line.replace("*", "&ast;")
208
+ line = line.replace("_", "&lowbar;")
209
+ line = line.replace("-", "&#45;")
210
+ line = line.replace(".", "&#46;")
211
+ line = line.replace("!", "&#33;")
212
+ line = line.replace("(", "&#40;")
213
+ line = line.replace(")", "&#41;")
214
+ line = line.replace("$", "&#36;")
215
+ lines[i] = "<br>" + line
216
+ text = "".join(lines)
217
+ return text
218
+
219
+
220
+ def predict(
221
+ RETRY_FLAG, input, chatbot, max_length, top_p, temperature, history, past_key_values
222
+ ):
223
+ try:
224
+ chatbot.append((parse_text(input), ""))
225
+ except Exception as exc:
226
+ logger.error(exc)
227
+ logger.debug(f"{chatbot=}")
228
+ _ = """
229
+ if chatbot:
230
+ chatbot[-1] = (parse_text(input), str(exc))
231
+ yield chatbot, history, past_key_values
232
+ # """
233
+ yield chatbot, history, past_key_values
234
+
235
+ for response, history, past_key_values in model_glm.stream_chat(
236
+ tokenizer,
237
+ input,
238
+ history,
239
+ past_key_values=past_key_values,
240
+ return_past_key_values=True,
241
+ max_length=max_length,
242
+ top_p=top_p,
243
+ temperature=temperature,
244
+ ):
245
+ chatbot[-1] = (parse_text(input), parse_text(response))
246
+ # chatbot[-1][-1] = parse_text(response)
247
+
248
+ yield chatbot, history, past_key_values, parse_text(response)
249
+
250
+
251
+ def trans_api(input, max_length=4096, top_p=0.8, temperature=0.2):
252
+ if max_length < 10:
253
+ max_length = 4096
254
+ if top_p < 0.1 or top_p > 1:
255
+ top_p = 0.85
256
+ if temperature <= 0 or temperature > 1:
257
+ temperature = 0.01
258
+ try:
259
+ res, _ = model_glm.chat(
260
+ tokenizer,
261
+ input,
262
+ history=[],
263
+ past_key_values=None,
264
+ max_length=max_length,
265
+ top_p=top_p,
266
+ temperature=temperature,
267
+ )
268
+ # logger.debug(f"{res=} \n{_=}")
269
+ except Exception as exc:
270
+ logger.error(f"{exc=}")
271
+ res = str(exc)
272
+
273
+ return res
274
+
275
+
276
+ def reset_user_input():
277
+ return gr.update(value="")
278
+
279
+
280
+ def reset_state():
281
+ return [], [], None, ""
282
+
283
+
284
+ # Delete last turn
285
+ def delete_last_turn(chat, history):
286
+ if chat and history:
287
+ chat.pop(-1)
288
+ history.pop(-1)
289
+ return chat, history
290
+
291
+
292
+ # Regenerate response
293
+ def retry_last_answer(
294
+ user_input, chatbot, max_length, top_p, temperature, history, past_key_values
295
+ ):
296
+ if chatbot and history:
297
+ # Removing the previous conversation from chat
298
+ chatbot.pop(-1)
299
+ # Setting up a flag to capture a retry
300
+ RETRY_FLAG = True
301
+ # Getting last message from user
302
+ user_input = history[-1][0]
303
+ # Removing bot response from the history
304
+ history.pop(-1)
305
+
306
+ yield from predict(
307
+ RETRY_FLAG, # type: ignore
308
+ user_input,
309
+ chatbot,
310
+ max_length,
311
+ top_p,
312
+ temperature,
313
+ history,
314
+ past_key_values,
315
+ )
316
+
317
+ # print
318
+
319
+ def print(text):
320
+ return text
321
+
322
+ # TTS
323
+
324
+ async def text_to_speech_edge(text, language_code):
325
+ voice = language_dict[language_code]
326
+ communicate = edge_tts.Communicate(text, voice)
327
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
328
+ tmp_path = tmp_file.name
329
+
330
+ await communicate.save(tmp_path)
331
+
332
+ return tmp_path
333
+
334
+
335
+ with gr.Blocks(title="ChatGLM2-6B-int4", theme=gr.themes.Soft(text_size="sm"), analytics_enabled=False) as demo:
336
+ gr.HTML("<center>"
337
+ "<h1>📺💕🎶 - ChatGLM2+声音克隆+视频对话:和喜欢的角色畅所欲言吧!</h1>"
338
+ "</center>")
339
+ gr.Markdown("## <center>🥳 - ChatGLM2+FreeVC+SadTalker,为您打造沉浸式的视频对话体验,支持中英双语</center>")
340
+ gr.Markdown("## <center>🌊 - 更多精彩应用,尽在[滔滔AI](http://www.talktalkai.com);滔滔AI,为爱滔滔!💕</center>")
341
+ gr.Markdown("### <center>⭐ - 如果您喜欢这个程序,欢迎给我的[GitHub项目](https://github.com/KevinWang676/ChatGLM2-Voice-Cloning)点赞支持!</center>")
342
+
343
+ with gr.Tab("🍻 - ChatGLM2聊天区"):
344
+ with gr.Accordion("📒 相关信息", open=False):
345
+ _ = f""" ChatGLM2的可选参数信息:
346
+ * Low temperature: responses will be more deterministic and focused; High temperature: responses more creative.
347
+ * Suggested temperatures -- translation: up to 0.3; chatting: > 0.4
348
+ * Top P controls dynamic vocabulary selection based on context.\n
349
+ 如果您想让ChatGLM2进行角色扮演并与之对话,请先输入恰当的提示词,如“请你扮演成动漫角色蜡笔小新并和我进行对话”;您也可以为ChatGLM2提供自定义的角色设定\n
350
+ 当您使用声音克隆功能时,请先在此程序的对应位置上传一段您喜欢的音频
351
+ """
352
+ gr.Markdown(dedent(_))
353
+ chatbot = gr.Chatbot(height=300)
354
+ with gr.Row():
355
+ with gr.Column(scale=4):
356
+ with gr.Column(scale=12):
357
+ user_input = gr.Textbox(
358
+ label="请在此处和GLM2聊天 (按回车键即可发送)",
359
+ placeholder="聊点什么吧",
360
+ )
361
+ RETRY_FLAG = gr.Checkbox(value=False, visible=False)
362
+ with gr.Column(min_width=32, scale=1):
363
+ with gr.Row():
364
+ submitBtn = gr.Button("开始和GLM2交流吧", variant="primary")
365
+ deleteBtn = gr.Button("删除最新一轮对话", variant="secondary")
366
+ retryBtn = gr.Button("重新生成最新一轮对话", variant="secondary")
367
+
368
+ with gr.Accordion("🔧 更多设置", open=False):
369
+ with gr.Row():
370
+ emptyBtn = gr.Button("清空所有聊天记录")
371
+ max_length = gr.Slider(
372
+ 0,
373
+ 32768,
374
+ value=8192,
375
+ step=1.0,
376
+ label="Maximum length",
377
+ interactive=True,
378
+ )
379
+ top_p = gr.Slider(
380
+ 0, 1, value=0.85, step=0.01, label="Top P", interactive=True
381
+ )
382
+ temperature = gr.Slider(
383
+ 0.01, 1, value=0.95, step=0.01, label="Temperature", interactive=True
384
+ )
385
+
386
+
387
+ with gr.Row():
388
+ test1 = gr.Textbox(label="GLM2的最新回答 (可编辑)", lines = 3)
389
+ with gr.Column():
390
+ language = gr.Dropdown(choices=list(language_dict.keys()), value="普通话 (中国大陆)-Xiaoxiao-女", label="请选择文本对应的语言及您喜欢的说话人")
391
+ tts_btn = gr.Button("生成对应的音频吧", variant="primary")
392
+ output_audio = gr.Audio(type="filepath", label="为您生成的音频", interactive=False)
393
+
394
+ tts_btn.click(text_to_speech_edge, inputs=[test1, language], outputs=[output_audio])
395
+
396
+ with gr.Row():
397
+ model_choice = gr.Dropdown(choices=["FreeVC", "FreeVC-s", "FreeVC (24kHz)"], value="FreeVC (24kHz)", label="Model", visible=False)
398
+ audio1 = output_audio
399
+ audio2 = gr.Audio(label="请上传您喜欢的声音进行声音克隆", type='filepath')
400
+ clone_btn = gr.Button("开始AI声音克隆吧", variant="primary")
401
+ audio_cloned = gr.Audio(label="为您生成的专属声音克隆音频", type='filepath')
402
+
403
+ clone_btn.click(convert, inputs=[model_choice, audio1, audio2], outputs=[audio_cloned])
404
+
405
+ history = gr.State([])
406
+ past_key_values = gr.State(None)
407
+
408
+ user_input.submit(
409
+ predict,
410
+ [
411
+ RETRY_FLAG,
412
+ user_input,
413
+ chatbot,
414
+ max_length,
415
+ top_p,
416
+ temperature,
417
+ history,
418
+ past_key_values,
419
+ ],
420
+ [chatbot, history, past_key_values, test1],
421
+ show_progress="full",
422
+ )
423
+ submitBtn.click(
424
+ predict,
425
+ [
426
+ RETRY_FLAG,
427
+ user_input,
428
+ chatbot,
429
+ max_length,
430
+ top_p,
431
+ temperature,
432
+ history,
433
+ past_key_values,
434
+ ],
435
+ [chatbot, history, past_key_values, test1],
436
+ show_progress="full",
437
+ api_name="predict",
438
+ )
439
+ submitBtn.click(reset_user_input, [], [user_input])
440
+
441
+ emptyBtn.click(
442
+ reset_state, outputs=[chatbot, history, past_key_values, test1], show_progress="full"
443
+ )
444
+
445
+ retryBtn.click(
446
+ retry_last_answer,
447
+ inputs=[
448
+ user_input,
449
+ chatbot,
450
+ max_length,
451
+ top_p,
452
+ temperature,
453
+ history,
454
+ past_key_values,
455
+ ],
456
+ # outputs = [chatbot, history, last_user_message, user_message]
457
+ outputs=[chatbot, history, past_key_values, test1],
458
+ )
459
+ deleteBtn.click(delete_last_turn, [chatbot, history], [chatbot, history])
460
+
461
+ with gr.Accordion("📔 提示词示例", open=False):
462
+ etext = """In America, where cars are an important part of the national psyche, a decade ago people had suddenly started to drive less, which had not happened since the oil shocks of the 1970s. """
463
+ examples = gr.Examples(
464
+ examples=[
465
+ ["Explain the plot of Cinderella in a sentence."],
466
+ [
467
+ "How long does it take to become proficient in French, and what are the best methods for retaining information?"
468
+ ],
469
+ ["What are some common mistakes to avoid when writing code?"],
470
+ ["Build a prompt to generate a beautiful portrait of a horse"],
471
+ ["Suggest four metaphors to describe the benefits of AI"],
472
+ ["Write a pop song about leaving home for the sandy beaches."],
473
+ ["Write a summary demonstrating my ability to tame lions"],
474
+ ["鲁迅和周树人什么关系"],
475
+ ["从前有一头牛,这头牛后面有什么?"],
476
+ ["正无穷大加一大于正无穷大吗?"],
477
+ ["正无穷大加正无穷大大于正无穷大吗?"],
478
+ ["-2的平方根等于什么"],
479
+ ["树上有5只鸟,猎人开枪打死了一只。树上还有几只鸟?"],
480
+ ["树上有11只鸟,猎人开枪打死了一只。树上还有几只鸟?提示:需考虑鸟可能受惊吓飞走。"],
481
+ ["鲁迅和周树人什么关系 用英文回答"],
482
+ ["以红楼梦的行文风格写一张委婉的请假条。不少于320字。"],
483
+ [f"{etext} 翻成中文,列出3个版本"],
484
+ [f"{etext} \n 翻成中文,保留原意,但使用文学性的语言。不要写解释。列出3个版本"],
485
+ ["js 判断一个数是不是质数"],
486
+ ["js 实现python 的 range(10)"],
487
+ ["js 实现python 的 [*(range(10)]"],
488
+ ["假定 1 + 2 = 4, 试求 7 + 8"],
489
+ ["Erkläre die Handlung von Cinderella in einem Satz."],
490
+ ["Erkläre die Handlung von Cinderella in einem Satz. Auf Deutsch"],
491
+ ],
492
+ inputs=[user_input],
493
+ examples_per_page=30,
494
+ )
495
+
496
+ with gr.Accordion("For Chat/Translation API", open=False, visible=False):
497
+ input_text = gr.Text()
498
+ tr_btn = gr.Button("Go", variant="primary")
499
+ out_text = gr.Text()
500
+ tr_btn.click(
501
+ trans_api,
502
+ [input_text, max_length, top_p, temperature],
503
+ out_text,
504
+ # show_progress="full",
505
+ api_name="tr",
506
+ )
507
+ _ = """
508
+ input_text.submit(
509
+ trans_api,
510
+ [input_text, max_length, top_p, temperature],
511
+ out_text,
512
+ show_progress="full",
513
+ api_name="tr1",
514
+ )
515
+ # """
516
+ with gr.Tab("📺 - 视频聊天区"):
517
+ with gr.Row().style(equal_height=False):
518
+ with gr.Column(variant='panel'):
519
+ with gr.Tabs(elem_id="sadtalker_source_image"):
520
+ with gr.TabItem('图片上传'):
521
+ with gr.Row():
522
+ source_image = gr.Image(label="请上传一张您喜欢角色的图片", source="upload", type="filepath", elem_id="img2img_image").style(width=512)
523
+
524
+
525
+ with gr.Tabs(elem_id="sadtalker_driven_audio"):
526
+ with gr.TabItem('💡您还可以将视频下载到本地'):
527
+
528
+ with gr.Row():
529
+ driven_audio = audio_cloned
530
+ driven_audio_no = gr.Audio(label="Use IDLE mode, no audio is required", source="upload", type="filepath", visible=False)
531
+
532
+ with gr.Column():
533
+ use_idle_mode = gr.Checkbox(label="Use Idle Animation", visible=False)
534
+ length_of_audio = gr.Number(value=5, label="The length(seconds) of the generated video.", visible=False)
535
+ use_idle_mode.change(toggle_audio_file, inputs=use_idle_mode, outputs=[driven_audio, driven_audio_no]) # todo
536
+
537
+ with gr.Row():
538
+ ref_video = gr.Video(label="Reference Video", source="upload", type="filepath", elem_id="vidref", visible=False).style(width=512)
539
+
540
+ with gr.Column():
541
+ use_ref_video = gr.Checkbox(label="Use Reference Video", visible=False)
542
+ ref_info = gr.Radio(['pose', 'blink','pose+blink', 'all'], value='pose', label='Reference Video',info="How to borrow from reference Video?((fully transfer, aka, video driving mode))", visible=False)
543
+
544
+ ref_video.change(ref_video_fn, inputs=ref_video, outputs=[use_ref_video]) # todo
545
+
546
+
547
+ with gr.Column(variant='panel'):
548
+ with gr.Tabs(elem_id="sadtalker_checkbox"):
549
+ with gr.TabItem('视频设置'):
550
+ with gr.Column(variant='panel'):
551
+ # width = gr.Slider(minimum=64, elem_id="img2img_width", maximum=2048, step=8, label="Manually Crop Width", value=512) # img2img_width
552
+ # height = gr.Slider(minimum=64, elem_id="img2img_height", maximum=2048, step=8, label="Manually Crop Height", value=512) # img2img_width
553
+ with gr.Row():
554
+ pose_style = gr.Slider(minimum=0, maximum=45, step=1, label="Pose style", value=0, visible=False) #
555
+ exp_weight = gr.Slider(minimum=0, maximum=3, step=0.1, label="expression scale", value=1, visible=False) #
556
+ blink_every = gr.Checkbox(label="use eye blink", value=True, visible=False)
557
+
558
+ with gr.Row():
559
+ size_of_image = gr.Radio([256, 512], value=256, label='face model resolution', info="use 256/512 model?", visible=False) #
560
+ preprocess_type = gr.Radio(['crop', 'full'], value='crop', label='是否聚焦角色面部', info="crop:视频会聚焦角色面部;full:视频会显示图片全貌")
561
+
562
+ with gr.Row():
563
+ is_still_mode = gr.Checkbox(label="静态模式 (开启静态模式,角色的面部动作会减少;默认开启)", value=True)
564
+ facerender = gr.Radio(['facevid2vid','pirender'], value='facevid2vid', label='facerender', info="which face render?", visible=False)
565
+
566
+ with gr.Row():
567
+ batch_size = gr.Slider(label="Batch size (数值越大,生成速度越快;若显卡性能好,可增大数值)", step=1, maximum=32, value=2)
568
+ enhancer = gr.Checkbox(label="GFPGAN as Face enhancer", value=True, visible=False)
569
+
570
+ submit = gr.Button('开始视频聊天吧', elem_id="sadtalker_generate", variant='primary')
571
+
572
+ with gr.Tabs(elem_id="sadtalker_genearted"):
573
+ gen_video = gr.Video(label="为您生成的专属视频", format="mp4").style(width=256)
574
+
575
+
576
+
577
+ submit.click(
578
+ fn=sad_talker.test,
579
+ inputs=[source_image,
580
+ driven_audio,
581
+ preprocess_type,
582
+ is_still_mode,
583
+ enhancer,
584
+ batch_size,
585
+ size_of_image,
586
+ pose_style,
587
+ facerender,
588
+ exp_weight,
589
+ use_ref_video,
590
+ ref_video,
591
+ ref_info,
592
+ use_idle_mode,
593
+ length_of_audio,
594
+ blink_every
595
+ ],
596
+ outputs=[gen_video]
597
+ )
598
+ gr.Markdown("### <center>注意❗:请不要生成会对个人以及组织造成侵害的内容,此程序仅供科研、学习及个人娱乐使用。</center>")
599
+ gr.Markdown("<center>💡- 如何使用此程序:输入您对ChatGLM的提问后,依次点击“开始和GLM2交流吧”、“生成对应的音频吧”、“开始AI声音克隆吧”、“开始视频聊天吧”四个按键即可;使用声音克隆功能时,请先上传一段您喜欢的音频</center>")
600
+ gr.HTML('''
601
+ <div class="footer">
602
+ <p>🌊🏞️🎶 - 江水东流急,滔滔无尽声。 明·顾璘
603
+ </p>
604
+ </div>
605
+ ''')
606
+
607
+
608
+ demo.queue().launch(show_error=True, debug=True)
checkpoint/__init__.py ADDED
File without changes
checkpoint/freevc-24.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b39a86fefbc9ec6e30be8d26ee2a6aa5ffe6d235f6ab15773d01cdf348e5b20
3
+ size 472644351
checkpoints/BFM_Fitting/01_MorphableModel.mat ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/37b1f0742db356a3b1568a8365a06f5b0fe0ab687ac1c3068c803666cbd4d8e2
checkpoints/BFM_Fitting/BFM09_model_info.mat ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/db8d00544f0b0182f1b8430a3bb87662b3ff674eb33c84e6f52dbe2971adb81b
checkpoints/BFM_Fitting/BFM_exp_idx.mat ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/1146e4e9c3bef303a497383aa7974c014fe945c7
checkpoints/BFM_Fitting/BFM_front_idx.mat ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/b9d7b0953dd1dc5b1e28144610485409ac321f9b
checkpoints/BFM_Fitting/Exp_Pca.bin ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/e7f31380e6cbdaf2aeec698db220bac4f221946e4d551d88c092d47ec49b1726
checkpoints/BFM_Fitting/facemodel_info.mat ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/3e516ec7297fa3248098f49ecea10579f4831c0a
checkpoints/BFM_Fitting/select_vertex_id.mat ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/5b8b220093d93b133acc94ffed159f31a74854cd
checkpoints/BFM_Fitting/similarity_Lm3D_all.mat ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/a0e23588302bc71fc899eef53ff06df5f4df4c1d
checkpoints/BFM_Fitting/std_exp.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/767b8de4ea1ca78b6f22b98ff2dee4fa345500bb
checkpoints/auido2exp_00300-model.pth ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/b7608f0e6b477e50e03ca569ac5b04a841b9217f89d502862fc78fda4e46dec4
checkpoints/auido2pose_00140-model.pth ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/4fba6701852dc57efbed25b1e4276e4ff752941860d69fc4429f08a02326ebce
checkpoints/epoch_20.pth ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/6d17a6b23457b521801baae583cb6a58f7238fe6721fc3d65d76407460e9149b
checkpoints/facevid2vid_00189-model.pth.tar ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/fbad01d46f0510276dc4521322dde6824a873a4222cd0740c85762e7067ea71d
checkpoints/hub/checkpoints/2DFAN4-cd938726ad.zip ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/cd938726adb1f15f361263cce2db9cb820c42585fa8796ec72ce19107f369a46
checkpoints/hub/checkpoints/s3fd-619a316812.pth ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/619a31681264d3f7f7fc7a16a42cbbe8b23f31a256f75a366e5a1bcd59b33543
checkpoints/mapping_00229-model.pth.tar ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker-V002rc/blobs/62a1e06006cc963220f6477438518ed86e9788226c62ae382ddc42fbcefb83f1
checkpoints/shape_predictor_68_face_landmarks.dat ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/fbdc2cb80eb9aa7a758672cbfdda32ba6300efe9b6e6c7a299ff7e736b11b92f
checkpoints/wav2lip.pth ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../../../root/.cache/huggingface/hub/models--vinthony--SadTalker/blobs/b78b681b68ad9fe6c6fb1debc6ff43ad05834a8af8a62ffc4167b7b34ef63c37
commons.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+
8
+ def init_weights(m, mean=0.0, std=0.01):
9
+ classname = m.__class__.__name__
10
+ if classname.find("Conv") != -1:
11
+ m.weight.data.normal_(mean, std)
12
+
13
+
14
+ def get_padding(kernel_size, dilation=1):
15
+ return int((kernel_size*dilation - dilation)/2)
16
+
17
+
18
+ def convert_pad_shape(pad_shape):
19
+ l = pad_shape[::-1]
20
+ pad_shape = [item for sublist in l for item in sublist]
21
+ return pad_shape
22
+
23
+
24
+ def intersperse(lst, item):
25
+ result = [item] * (len(lst) * 2 + 1)
26
+ result[1::2] = lst
27
+ return result
28
+
29
+
30
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
31
+ """KL(P||Q)"""
32
+ kl = (logs_q - logs_p) - 0.5
33
+ kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
34
+ return kl
35
+
36
+
37
+ def rand_gumbel(shape):
38
+ """Sample from the Gumbel distribution, protect from overflows."""
39
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
40
+ return -torch.log(-torch.log(uniform_samples))
41
+
42
+
43
+ def rand_gumbel_like(x):
44
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
45
+ return g
46
+
47
+
48
+ def slice_segments(x, ids_str, segment_size=4):
49
+ ret = torch.zeros_like(x[:, :, :segment_size])
50
+ for i in range(x.size(0)):
51
+ idx_str = ids_str[i]
52
+ idx_end = idx_str + segment_size
53
+ ret[i] = x[i, :, idx_str:idx_end]
54
+ return ret
55
+
56
+
57
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
58
+ b, d, t = x.size()
59
+ if x_lengths is None:
60
+ x_lengths = t
61
+ ids_str_max = x_lengths - segment_size + 1
62
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
63
+ ret = slice_segments(x, ids_str, segment_size)
64
+ return ret, ids_str
65
+
66
+
67
+ def rand_spec_segments(x, x_lengths=None, segment_size=4):
68
+ b, d, t = x.size()
69
+ if x_lengths is None:
70
+ x_lengths = t
71
+ ids_str_max = x_lengths - segment_size
72
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
73
+ ret = slice_segments(x, ids_str, segment_size)
74
+ return ret, ids_str
75
+
76
+
77
+ def get_timing_signal_1d(
78
+ length, channels, min_timescale=1.0, max_timescale=1.0e4):
79
+ position = torch.arange(length, dtype=torch.float)
80
+ num_timescales = channels // 2
81
+ log_timescale_increment = (
82
+ math.log(float(max_timescale) / float(min_timescale)) /
83
+ (num_timescales - 1))
84
+ inv_timescales = min_timescale * torch.exp(
85
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
86
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
87
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
88
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
89
+ signal = signal.view(1, channels, length)
90
+ return signal
91
+
92
+
93
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
94
+ b, channels, length = x.size()
95
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
96
+ return x + signal.to(dtype=x.dtype, device=x.device)
97
+
98
+
99
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
100
+ b, channels, length = x.size()
101
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
102
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
103
+
104
+
105
+ def subsequent_mask(length):
106
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
107
+ return mask
108
+
109
+
110
+ @torch.jit.script
111
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
112
+ n_channels_int = n_channels[0]
113
+ in_act = input_a + input_b
114
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
115
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
116
+ acts = t_act * s_act
117
+ return acts
118
+
119
+
120
+ def convert_pad_shape(pad_shape):
121
+ l = pad_shape[::-1]
122
+ pad_shape = [item for sublist in l for item in sublist]
123
+ return pad_shape
124
+
125
+
126
+ def shift_1d(x):
127
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
128
+ return x
129
+
130
+
131
+ def sequence_mask(length, max_length=None):
132
+ if max_length is None:
133
+ max_length = length.max()
134
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
135
+ return x.unsqueeze(0) < length.unsqueeze(1)
136
+
137
+
138
+ def generate_path(duration, mask):
139
+ """
140
+ duration: [b, 1, t_x]
141
+ mask: [b, 1, t_y, t_x]
142
+ """
143
+ device = duration.device
144
+
145
+ b, _, t_y, t_x = mask.shape
146
+ cum_duration = torch.cumsum(duration, -1)
147
+
148
+ cum_duration_flat = cum_duration.view(b * t_x)
149
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
150
+ path = path.view(b, t_x, t_y)
151
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
152
+ path = path.unsqueeze(1).transpose(2,3) * mask
153
+ return path
154
+
155
+
156
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
157
+ if isinstance(parameters, torch.Tensor):
158
+ parameters = [parameters]
159
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
160
+ norm_type = float(norm_type)
161
+ if clip_value is not None:
162
+ clip_value = float(clip_value)
163
+
164
+ total_norm = 0
165
+ for p in parameters:
166
+ param_norm = p.grad.data.norm(norm_type)
167
+ total_norm += param_norm.item() ** norm_type
168
+ if clip_value is not None:
169
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
170
+ total_norm = total_norm ** (1. / norm_type)
171
+ return total_norm
configs/freevc-24.json ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "eval_interval": 10000,
5
+ "seed": 1234,
6
+ "epochs": 10000,
7
+ "learning_rate": 2e-4,
8
+ "betas": [0.8, 0.99],
9
+ "eps": 1e-9,
10
+ "batch_size": 64,
11
+ "fp16_run": false,
12
+ "lr_decay": 0.999875,
13
+ "segment_size": 8640,
14
+ "init_lr_ratio": 1,
15
+ "warmup_epochs": 0,
16
+ "c_mel": 45,
17
+ "c_kl": 1.0,
18
+ "use_sr": true,
19
+ "max_speclen": 128,
20
+ "port": "8008"
21
+ },
22
+ "data": {
23
+ "training_files":"filelists/train.txt",
24
+ "validation_files":"filelists/val.txt",
25
+ "max_wav_value": 32768.0,
26
+ "sampling_rate": 16000,
27
+ "filter_length": 1280,
28
+ "hop_length": 320,
29
+ "win_length": 1280,
30
+ "n_mel_channels": 80,
31
+ "mel_fmin": 0.0,
32
+ "mel_fmax": null
33
+ },
34
+ "model": {
35
+ "inter_channels": 192,
36
+ "hidden_channels": 192,
37
+ "filter_channels": 768,
38
+ "n_heads": 2,
39
+ "n_layers": 6,
40
+ "kernel_size": 3,
41
+ "p_dropout": 0.1,
42
+ "resblock": "1",
43
+ "resblock_kernel_sizes": [3,7,11],
44
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
45
+ "upsample_rates": [10,6,4,2],
46
+ "upsample_initial_channel": 512,
47
+ "upsample_kernel_sizes": [16,16,4,4],
48
+ "n_layers_q": 3,
49
+ "use_spectral_norm": false,
50
+ "gin_channels": 256,
51
+ "ssl_dim": 1024,
52
+ "use_spk": true
53
+ }
54
+ }
mel_processing.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import random
4
+ import torch
5
+ from torch import nn
6
+ import torch.nn.functional as F
7
+ import torch.utils.data
8
+ import numpy as np
9
+ import librosa
10
+ import librosa.util as librosa_util
11
+ from librosa.util import normalize, pad_center, tiny
12
+ from scipy.signal import get_window
13
+ from scipy.io.wavfile import read
14
+ from librosa.filters import mel as librosa_mel_fn
15
+
16
+ MAX_WAV_VALUE = 32768.0
17
+
18
+
19
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
20
+ """
21
+ PARAMS
22
+ ------
23
+ C: compression factor
24
+ """
25
+ return torch.log(torch.clamp(x, min=clip_val) * C)
26
+
27
+
28
+ def dynamic_range_decompression_torch(x, C=1):
29
+ """
30
+ PARAMS
31
+ ------
32
+ C: compression factor used to compress
33
+ """
34
+ return torch.exp(x) / C
35
+
36
+
37
+ def spectral_normalize_torch(magnitudes):
38
+ output = dynamic_range_compression_torch(magnitudes)
39
+ return output
40
+
41
+
42
+ def spectral_de_normalize_torch(magnitudes):
43
+ output = dynamic_range_decompression_torch(magnitudes)
44
+ return output
45
+
46
+
47
+ mel_basis = {}
48
+ hann_window = {}
49
+
50
+
51
+ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
52
+ if torch.min(y) < -1.:
53
+ print('min value is ', torch.min(y))
54
+ if torch.max(y) > 1.:
55
+ print('max value is ', torch.max(y))
56
+
57
+ global hann_window
58
+ dtype_device = str(y.dtype) + '_' + str(y.device)
59
+ wnsize_dtype_device = str(win_size) + '_' + dtype_device
60
+ if wnsize_dtype_device not in hann_window:
61
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
62
+
63
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
64
+ y = y.squeeze(1)
65
+
66
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
67
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
68
+
69
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
70
+ return spec
71
+
72
+
73
+ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
74
+ global mel_basis
75
+ dtype_device = str(spec.dtype) + '_' + str(spec.device)
76
+ fmax_dtype_device = str(fmax) + '_' + dtype_device
77
+ if fmax_dtype_device not in mel_basis:
78
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
79
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
80
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
81
+ spec = spectral_normalize_torch(spec)
82
+ return spec
83
+
84
+
85
+ def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
86
+ if torch.min(y) < -1.:
87
+ print('min value is ', torch.min(y))
88
+ if torch.max(y) > 1.:
89
+ print('max value is ', torch.max(y))
90
+
91
+ global mel_basis, hann_window
92
+ dtype_device = str(y.dtype) + '_' + str(y.device)
93
+ fmax_dtype_device = str(fmax) + '_' + dtype_device
94
+ wnsize_dtype_device = str(win_size) + '_' + dtype_device
95
+ if fmax_dtype_device not in mel_basis:
96
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
97
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
98
+ if wnsize_dtype_device not in hann_window:
99
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
100
+
101
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
102
+ y = y.squeeze(1)
103
+
104
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
105
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
106
+
107
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
108
+
109
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
110
+ spec = spectral_normalize_torch(spec)
111
+
112
+ return spec
models.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ import commons
8
+ import modules
9
+
10
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
+ from commons import init_weights, get_padding
13
+
14
+
15
+ class ResidualCouplingBlock(nn.Module):
16
+ def __init__(self,
17
+ channels,
18
+ hidden_channels,
19
+ kernel_size,
20
+ dilation_rate,
21
+ n_layers,
22
+ n_flows=4,
23
+ gin_channels=0):
24
+ super().__init__()
25
+ self.channels = channels
26
+ self.hidden_channels = hidden_channels
27
+ self.kernel_size = kernel_size
28
+ self.dilation_rate = dilation_rate
29
+ self.n_layers = n_layers
30
+ self.n_flows = n_flows
31
+ self.gin_channels = gin_channels
32
+
33
+ self.flows = nn.ModuleList()
34
+ for i in range(n_flows):
35
+ self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
36
+ self.flows.append(modules.Flip())
37
+
38
+ def forward(self, x, x_mask, g=None, reverse=False):
39
+ if not reverse:
40
+ for flow in self.flows:
41
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
42
+ else:
43
+ for flow in reversed(self.flows):
44
+ x = flow(x, x_mask, g=g, reverse=reverse)
45
+ return x
46
+
47
+
48
+ class Encoder(nn.Module):
49
+ def __init__(self,
50
+ in_channels,
51
+ out_channels,
52
+ hidden_channels,
53
+ kernel_size,
54
+ dilation_rate,
55
+ n_layers,
56
+ gin_channels=0):
57
+ super().__init__()
58
+ self.in_channels = in_channels
59
+ self.out_channels = out_channels
60
+ self.hidden_channels = hidden_channels
61
+ self.kernel_size = kernel_size
62
+ self.dilation_rate = dilation_rate
63
+ self.n_layers = n_layers
64
+ self.gin_channels = gin_channels
65
+
66
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
67
+ self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
68
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
69
+
70
+ def forward(self, x, x_lengths, g=None):
71
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
72
+ x = self.pre(x) * x_mask
73
+ x = self.enc(x, x_mask, g=g)
74
+ stats = self.proj(x) * x_mask
75
+ m, logs = torch.split(stats, self.out_channels, dim=1)
76
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
77
+ return z, m, logs, x_mask
78
+
79
+
80
+ class Generator(torch.nn.Module):
81
+ def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
82
+ super(Generator, self).__init__()
83
+ self.num_kernels = len(resblock_kernel_sizes)
84
+ self.num_upsamples = len(upsample_rates)
85
+ self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
86
+ resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
87
+
88
+ self.ups = nn.ModuleList()
89
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
90
+ self.ups.append(weight_norm(
91
+ ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
92
+ k, u, padding=(k-u)//2)))
93
+
94
+ self.resblocks = nn.ModuleList()
95
+ for i in range(len(self.ups)):
96
+ ch = upsample_initial_channel//(2**(i+1))
97
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
98
+ self.resblocks.append(resblock(ch, k, d))
99
+
100
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
101
+ self.ups.apply(init_weights)
102
+
103
+ if gin_channels != 0:
104
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
105
+
106
+ def forward(self, x, g=None):
107
+ x = self.conv_pre(x)
108
+ if g is not None:
109
+ x = x + self.cond(g)
110
+
111
+ for i in range(self.num_upsamples):
112
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
113
+ x = self.ups[i](x)
114
+ xs = None
115
+ for j in range(self.num_kernels):
116
+ if xs is None:
117
+ xs = self.resblocks[i*self.num_kernels+j](x)
118
+ else:
119
+ xs += self.resblocks[i*self.num_kernels+j](x)
120
+ x = xs / self.num_kernels
121
+ x = F.leaky_relu(x)
122
+ x = self.conv_post(x)
123
+ x = torch.tanh(x)
124
+
125
+ return x
126
+
127
+ def remove_weight_norm(self):
128
+ print('Removing weight norm...')
129
+ for l in self.ups:
130
+ remove_weight_norm(l)
131
+ for l in self.resblocks:
132
+ l.remove_weight_norm()
133
+
134
+
135
+ class DiscriminatorP(torch.nn.Module):
136
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
137
+ super(DiscriminatorP, self).__init__()
138
+ self.period = period
139
+ self.use_spectral_norm = use_spectral_norm
140
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
141
+ self.convs = nn.ModuleList([
142
+ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
143
+ norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
144
+ norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
145
+ norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
146
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
147
+ ])
148
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
149
+
150
+ def forward(self, x):
151
+ fmap = []
152
+
153
+ # 1d to 2d
154
+ b, c, t = x.shape
155
+ if t % self.period != 0: # pad first
156
+ n_pad = self.period - (t % self.period)
157
+ x = F.pad(x, (0, n_pad), "reflect")
158
+ t = t + n_pad
159
+ x = x.view(b, c, t // self.period, self.period)
160
+
161
+ for l in self.convs:
162
+ x = l(x)
163
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
164
+ fmap.append(x)
165
+ x = self.conv_post(x)
166
+ fmap.append(x)
167
+ x = torch.flatten(x, 1, -1)
168
+
169
+ return x, fmap
170
+
171
+
172
+ class DiscriminatorS(torch.nn.Module):
173
+ def __init__(self, use_spectral_norm=False):
174
+ super(DiscriminatorS, self).__init__()
175
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
176
+ self.convs = nn.ModuleList([
177
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
178
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
179
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
180
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
181
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
182
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
183
+ ])
184
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
185
+
186
+ def forward(self, x):
187
+ fmap = []
188
+
189
+ for l in self.convs:
190
+ x = l(x)
191
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
192
+ fmap.append(x)
193
+ x = self.conv_post(x)
194
+ fmap.append(x)
195
+ x = torch.flatten(x, 1, -1)
196
+
197
+ return x, fmap
198
+
199
+
200
+ class MultiPeriodDiscriminator(torch.nn.Module):
201
+ def __init__(self, use_spectral_norm=False):
202
+ super(MultiPeriodDiscriminator, self).__init__()
203
+ periods = [2,3,5,7,11]
204
+
205
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
206
+ discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
207
+ self.discriminators = nn.ModuleList(discs)
208
+
209
+ def forward(self, y, y_hat):
210
+ y_d_rs = []
211
+ y_d_gs = []
212
+ fmap_rs = []
213
+ fmap_gs = []
214
+ for i, d in enumerate(self.discriminators):
215
+ y_d_r, fmap_r = d(y)
216
+ y_d_g, fmap_g = d(y_hat)
217
+ y_d_rs.append(y_d_r)
218
+ y_d_gs.append(y_d_g)
219
+ fmap_rs.append(fmap_r)
220
+ fmap_gs.append(fmap_g)
221
+
222
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
223
+
224
+
225
+ class SpeakerEncoder(torch.nn.Module):
226
+ def __init__(self, mel_n_channels=80, model_num_layers=3, model_hidden_size=256, model_embedding_size=256):
227
+ super(SpeakerEncoder, self).__init__()
228
+ self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
229
+ self.linear = nn.Linear(model_hidden_size, model_embedding_size)
230
+ self.relu = nn.ReLU()
231
+
232
+ def forward(self, mels):
233
+ self.lstm.flatten_parameters()
234
+ _, (hidden, _) = self.lstm(mels)
235
+ embeds_raw = self.relu(self.linear(hidden[-1]))
236
+ return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
237
+
238
+ def compute_partial_slices(self, total_frames, partial_frames, partial_hop):
239
+ mel_slices = []
240
+ for i in range(0, total_frames-partial_frames, partial_hop):
241
+ mel_range = torch.arange(i, i+partial_frames)
242
+ mel_slices.append(mel_range)
243
+
244
+ return mel_slices
245
+
246
+ def embed_utterance(self, mel, partial_frames=128, partial_hop=64):
247
+ mel_len = mel.size(1)
248
+ last_mel = mel[:,-partial_frames:]
249
+
250
+ if mel_len > partial_frames:
251
+ mel_slices = self.compute_partial_slices(mel_len, partial_frames, partial_hop)
252
+ mels = list(mel[:,s] for s in mel_slices)
253
+ mels.append(last_mel)
254
+ mels = torch.stack(tuple(mels), 0).squeeze(1)
255
+
256
+ with torch.no_grad():
257
+ partial_embeds = self(mels)
258
+ embed = torch.mean(partial_embeds, axis=0).unsqueeze(0)
259
+ #embed = embed / torch.linalg.norm(embed, 2)
260
+ else:
261
+ with torch.no_grad():
262
+ embed = self(last_mel)
263
+
264
+ return embed
265
+
266
+
267
+ class SynthesizerTrn(nn.Module):
268
+ """
269
+ Synthesizer for Training
270
+ """
271
+
272
+ def __init__(self,
273
+ spec_channels,
274
+ segment_size,
275
+ inter_channels,
276
+ hidden_channels,
277
+ filter_channels,
278
+ n_heads,
279
+ n_layers,
280
+ kernel_size,
281
+ p_dropout,
282
+ resblock,
283
+ resblock_kernel_sizes,
284
+ resblock_dilation_sizes,
285
+ upsample_rates,
286
+ upsample_initial_channel,
287
+ upsample_kernel_sizes,
288
+ gin_channels,
289
+ ssl_dim,
290
+ use_spk,
291
+ **kwargs):
292
+
293
+ super().__init__()
294
+ self.spec_channels = spec_channels
295
+ self.inter_channels = inter_channels
296
+ self.hidden_channels = hidden_channels
297
+ self.filter_channels = filter_channels
298
+ self.n_heads = n_heads
299
+ self.n_layers = n_layers
300
+ self.kernel_size = kernel_size
301
+ self.p_dropout = p_dropout
302
+ self.resblock = resblock
303
+ self.resblock_kernel_sizes = resblock_kernel_sizes
304
+ self.resblock_dilation_sizes = resblock_dilation_sizes
305
+ self.upsample_rates = upsample_rates
306
+ self.upsample_initial_channel = upsample_initial_channel
307
+ self.upsample_kernel_sizes = upsample_kernel_sizes
308
+ self.segment_size = segment_size
309
+ self.gin_channels = gin_channels
310
+ self.ssl_dim = ssl_dim
311
+ self.use_spk = use_spk
312
+
313
+ self.enc_p = Encoder(ssl_dim, inter_channels, hidden_channels, 5, 1, 16)
314
+ self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
315
+ self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
316
+ self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
317
+
318
+ if not self.use_spk:
319
+ self.enc_spk = SpeakerEncoder(model_hidden_size=gin_channels, model_embedding_size=gin_channels)
320
+
321
+ def forward(self, c, spec, g=None, mel=None, c_lengths=None, spec_lengths=None):
322
+ if c_lengths == None:
323
+ c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
324
+ if spec_lengths == None:
325
+ spec_lengths = (torch.ones(spec.size(0)) * spec.size(-1)).to(spec.device)
326
+
327
+ if not self.use_spk:
328
+ g = self.enc_spk(mel.transpose(1,2))
329
+ g = g.unsqueeze(-1)
330
+
331
+ _, m_p, logs_p, _ = self.enc_p(c, c_lengths)
332
+ z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g)
333
+ z_p = self.flow(z, spec_mask, g=g)
334
+
335
+ z_slice, ids_slice = commons.rand_slice_segments(z, spec_lengths, self.segment_size)
336
+ o = self.dec(z_slice, g=g)
337
+
338
+ return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
339
+
340
+ def infer(self, c, g=None, mel=None, c_lengths=None):
341
+ if c_lengths == None:
342
+ c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
343
+ if not self.use_spk:
344
+ g = self.enc_spk.embed_utterance(mel.transpose(1,2))
345
+ g = g.unsqueeze(-1)
346
+
347
+ z_p, m_p, logs_p, c_mask = self.enc_p(c, c_lengths)
348
+ z = self.flow(z_p, c_mask, g=g, reverse=True)
349
+ o = self.dec(z * c_mask, g=g)
350
+
351
+ return o
modules.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import numpy as np
4
+ import scipy
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10
+ from torch.nn.utils import weight_norm, remove_weight_norm
11
+
12
+ import commons
13
+ from commons import init_weights, get_padding
14
+
15
+
16
+ LRELU_SLOPE = 0.1
17
+
18
+
19
+ class LayerNorm(nn.Module):
20
+ def __init__(self, channels, eps=1e-5):
21
+ super().__init__()
22
+ self.channels = channels
23
+ self.eps = eps
24
+
25
+ self.gamma = nn.Parameter(torch.ones(channels))
26
+ self.beta = nn.Parameter(torch.zeros(channels))
27
+
28
+ def forward(self, x):
29
+ x = x.transpose(1, -1)
30
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
31
+ return x.transpose(1, -1)
32
+
33
+
34
+ class ConvReluNorm(nn.Module):
35
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
36
+ super().__init__()
37
+ self.in_channels = in_channels
38
+ self.hidden_channels = hidden_channels
39
+ self.out_channels = out_channels
40
+ self.kernel_size = kernel_size
41
+ self.n_layers = n_layers
42
+ self.p_dropout = p_dropout
43
+ assert n_layers > 1, "Number of layers should be larger than 0."
44
+
45
+ self.conv_layers = nn.ModuleList()
46
+ self.norm_layers = nn.ModuleList()
47
+ self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
48
+ self.norm_layers.append(LayerNorm(hidden_channels))
49
+ self.relu_drop = nn.Sequential(
50
+ nn.ReLU(),
51
+ nn.Dropout(p_dropout))
52
+ for _ in range(n_layers-1):
53
+ self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
54
+ self.norm_layers.append(LayerNorm(hidden_channels))
55
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
56
+ self.proj.weight.data.zero_()
57
+ self.proj.bias.data.zero_()
58
+
59
+ def forward(self, x, x_mask):
60
+ x_org = x
61
+ for i in range(self.n_layers):
62
+ x = self.conv_layers[i](x * x_mask)
63
+ x = self.norm_layers[i](x)
64
+ x = self.relu_drop(x)
65
+ x = x_org + self.proj(x)
66
+ return x * x_mask
67
+
68
+
69
+ class DDSConv(nn.Module):
70
+ """
71
+ Dialted and Depth-Separable Convolution
72
+ """
73
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
74
+ super().__init__()
75
+ self.channels = channels
76
+ self.kernel_size = kernel_size
77
+ self.n_layers = n_layers
78
+ self.p_dropout = p_dropout
79
+
80
+ self.drop = nn.Dropout(p_dropout)
81
+ self.convs_sep = nn.ModuleList()
82
+ self.convs_1x1 = nn.ModuleList()
83
+ self.norms_1 = nn.ModuleList()
84
+ self.norms_2 = nn.ModuleList()
85
+ for i in range(n_layers):
86
+ dilation = kernel_size ** i
87
+ padding = (kernel_size * dilation - dilation) // 2
88
+ self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
89
+ groups=channels, dilation=dilation, padding=padding
90
+ ))
91
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
92
+ self.norms_1.append(LayerNorm(channels))
93
+ self.norms_2.append(LayerNorm(channels))
94
+
95
+ def forward(self, x, x_mask, g=None):
96
+ if g is not None:
97
+ x = x + g
98
+ for i in range(self.n_layers):
99
+ y = self.convs_sep[i](x * x_mask)
100
+ y = self.norms_1[i](y)
101
+ y = F.gelu(y)
102
+ y = self.convs_1x1[i](y)
103
+ y = self.norms_2[i](y)
104
+ y = F.gelu(y)
105
+ y = self.drop(y)
106
+ x = x + y
107
+ return x * x_mask
108
+
109
+
110
+ class WN(torch.nn.Module):
111
+ def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
112
+ super(WN, self).__init__()
113
+ assert(kernel_size % 2 == 1)
114
+ self.hidden_channels =hidden_channels
115
+ self.kernel_size = kernel_size,
116
+ self.dilation_rate = dilation_rate
117
+ self.n_layers = n_layers
118
+ self.gin_channels = gin_channels
119
+ self.p_dropout = p_dropout
120
+
121
+ self.in_layers = torch.nn.ModuleList()
122
+ self.res_skip_layers = torch.nn.ModuleList()
123
+ self.drop = nn.Dropout(p_dropout)
124
+
125
+ if gin_channels != 0:
126
+ cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
127
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
128
+
129
+ for i in range(n_layers):
130
+ dilation = dilation_rate ** i
131
+ padding = int((kernel_size * dilation - dilation) / 2)
132
+ in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
133
+ dilation=dilation, padding=padding)
134
+ in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
135
+ self.in_layers.append(in_layer)
136
+
137
+ # last one is not necessary
138
+ if i < n_layers - 1:
139
+ res_skip_channels = 2 * hidden_channels
140
+ else:
141
+ res_skip_channels = hidden_channels
142
+
143
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
144
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
145
+ self.res_skip_layers.append(res_skip_layer)
146
+
147
+ def forward(self, x, x_mask, g=None, **kwargs):
148
+ output = torch.zeros_like(x)
149
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
150
+
151
+ if g is not None:
152
+ g = self.cond_layer(g)
153
+
154
+ for i in range(self.n_layers):
155
+ x_in = self.in_layers[i](x)
156
+ if g is not None:
157
+ cond_offset = i * 2 * self.hidden_channels
158
+ g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
159
+ else:
160
+ g_l = torch.zeros_like(x_in)
161
+
162
+ acts = commons.fused_add_tanh_sigmoid_multiply(
163
+ x_in,
164
+ g_l,
165
+ n_channels_tensor)
166
+ acts = self.drop(acts)
167
+
168
+ res_skip_acts = self.res_skip_layers[i](acts)
169
+ if i < self.n_layers - 1:
170
+ res_acts = res_skip_acts[:,:self.hidden_channels,:]
171
+ x = (x + res_acts) * x_mask
172
+ output = output + res_skip_acts[:,self.hidden_channels:,:]
173
+ else:
174
+ output = output + res_skip_acts
175
+ return output * x_mask
176
+
177
+ def remove_weight_norm(self):
178
+ if self.gin_channels != 0:
179
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
180
+ for l in self.in_layers:
181
+ torch.nn.utils.remove_weight_norm(l)
182
+ for l in self.res_skip_layers:
183
+ torch.nn.utils.remove_weight_norm(l)
184
+
185
+
186
+ class ResBlock1(torch.nn.Module):
187
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
188
+ super(ResBlock1, self).__init__()
189
+ self.convs1 = nn.ModuleList([
190
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
191
+ padding=get_padding(kernel_size, dilation[0]))),
192
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
193
+ padding=get_padding(kernel_size, dilation[1]))),
194
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
195
+ padding=get_padding(kernel_size, dilation[2])))
196
+ ])
197
+ self.convs1.apply(init_weights)
198
+
199
+ self.convs2 = nn.ModuleList([
200
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
201
+ padding=get_padding(kernel_size, 1))),
202
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
203
+ padding=get_padding(kernel_size, 1))),
204
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
205
+ padding=get_padding(kernel_size, 1)))
206
+ ])
207
+ self.convs2.apply(init_weights)
208
+
209
+ def forward(self, x, x_mask=None):
210
+ for c1, c2 in zip(self.convs1, self.convs2):
211
+ xt = F.leaky_relu(x, LRELU_SLOPE)
212
+ if x_mask is not None:
213
+ xt = xt * x_mask
214
+ xt = c1(xt)
215
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
216
+ if x_mask is not None:
217
+ xt = xt * x_mask
218
+ xt = c2(xt)
219
+ x = xt + x
220
+ if x_mask is not None:
221
+ x = x * x_mask
222
+ return x
223
+
224
+ def remove_weight_norm(self):
225
+ for l in self.convs1:
226
+ remove_weight_norm(l)
227
+ for l in self.convs2:
228
+ remove_weight_norm(l)
229
+
230
+
231
+ class ResBlock2(torch.nn.Module):
232
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
233
+ super(ResBlock2, self).__init__()
234
+ self.convs = nn.ModuleList([
235
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
236
+ padding=get_padding(kernel_size, dilation[0]))),
237
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
238
+ padding=get_padding(kernel_size, dilation[1])))
239
+ ])
240
+ self.convs.apply(init_weights)
241
+
242
+ def forward(self, x, x_mask=None):
243
+ for c in self.convs:
244
+ xt = F.leaky_relu(x, LRELU_SLOPE)
245
+ if x_mask is not None:
246
+ xt = xt * x_mask
247
+ xt = c(xt)
248
+ x = xt + x
249
+ if x_mask is not None:
250
+ x = x * x_mask
251
+ return x
252
+
253
+ def remove_weight_norm(self):
254
+ for l in self.convs:
255
+ remove_weight_norm(l)
256
+
257
+
258
+ class Log(nn.Module):
259
+ def forward(self, x, x_mask, reverse=False, **kwargs):
260
+ if not reverse:
261
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
262
+ logdet = torch.sum(-y, [1, 2])
263
+ return y, logdet
264
+ else:
265
+ x = torch.exp(x) * x_mask
266
+ return x
267
+
268
+
269
+ class Flip(nn.Module):
270
+ def forward(self, x, *args, reverse=False, **kwargs):
271
+ x = torch.flip(x, [1])
272
+ if not reverse:
273
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
274
+ return x, logdet
275
+ else:
276
+ return x
277
+
278
+
279
+ class ElementwiseAffine(nn.Module):
280
+ def __init__(self, channels):
281
+ super().__init__()
282
+ self.channels = channels
283
+ self.m = nn.Parameter(torch.zeros(channels,1))
284
+ self.logs = nn.Parameter(torch.zeros(channels,1))
285
+
286
+ def forward(self, x, x_mask, reverse=False, **kwargs):
287
+ if not reverse:
288
+ y = self.m + torch.exp(self.logs) * x
289
+ y = y * x_mask
290
+ logdet = torch.sum(self.logs * x_mask, [1,2])
291
+ return y, logdet
292
+ else:
293
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
294
+ return x
295
+
296
+
297
+ class ResidualCouplingLayer(nn.Module):
298
+ def __init__(self,
299
+ channels,
300
+ hidden_channels,
301
+ kernel_size,
302
+ dilation_rate,
303
+ n_layers,
304
+ p_dropout=0,
305
+ gin_channels=0,
306
+ mean_only=False):
307
+ assert channels % 2 == 0, "channels should be divisible by 2"
308
+ super().__init__()
309
+ self.channels = channels
310
+ self.hidden_channels = hidden_channels
311
+ self.kernel_size = kernel_size
312
+ self.dilation_rate = dilation_rate
313
+ self.n_layers = n_layers
314
+ self.half_channels = channels // 2
315
+ self.mean_only = mean_only
316
+
317
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
318
+ self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
319
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
320
+ self.post.weight.data.zero_()
321
+ self.post.bias.data.zero_()
322
+
323
+ def forward(self, x, x_mask, g=None, reverse=False):
324
+ x0, x1 = torch.split(x, [self.half_channels]*2, 1)
325
+ h = self.pre(x0) * x_mask
326
+ h = self.enc(h, x_mask, g=g)
327
+ stats = self.post(h) * x_mask
328
+ if not self.mean_only:
329
+ m, logs = torch.split(stats, [self.half_channels]*2, 1)
330
+ else:
331
+ m = stats
332
+ logs = torch.zeros_like(m)
333
+
334
+ if not reverse:
335
+ x1 = m + x1 * torch.exp(logs) * x_mask
336
+ x = torch.cat([x0, x1], 1)
337
+ logdet = torch.sum(logs, [1,2])
338
+ return x, logdet
339
+ else:
340
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
341
+ x = torch.cat([x0, x1], 1)
342
+ return x
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ffmpeg
2
+ libsndfile1
requirements.txt ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ torchaudio
4
+ numpy==1.22.0
5
+ face_alignment==1.3.0
6
+ imageio==2.19.3
7
+ imageio-ffmpeg==0.4.7
8
+ librosa==0.8.1
9
+ numba
10
+ resampy==0.3.1
11
+ pydub==0.25.1
12
+ scipy
13
+ kornia==0.6.8
14
+ tqdm
15
+ yacs==0.1.8
16
+ pyyaml
17
+ joblib==1.1.0
18
+ scikit-image==0.19.3
19
+ basicsr==1.4.2
20
+ facexlib==0.3.0
21
+ dlib-bin
22
+ gfpgan
23
+ av
24
+ safetensors
25
+ transformers
26
+ webrtcvad==2.0.10
27
+ protobuf
28
+ cpm_kernels
29
+ mdtex2html
30
+ sentencepiece
31
+ accelerate
32
+ loguru
33
+ edge_tts
34
+ altair
35
+ gradio==3.36.1
speaker_encoder/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
speaker_encoder/audio.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from scipy.ndimage.morphology import binary_dilation
2
+ from speaker_encoder.params_data import *
3
+ from pathlib import Path
4
+ from typing import Optional, Union
5
+ import numpy as np
6
+ import webrtcvad
7
+ import librosa
8
+ import struct
9
+
10
+ int16_max = (2 ** 15) - 1
11
+
12
+
13
+ def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
14
+ source_sr: Optional[int] = None):
15
+ """
16
+ Applies the preprocessing operations used in training the Speaker Encoder to a waveform
17
+ either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
18
+
19
+ :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
20
+ just .wav), either the waveform as a numpy array of floats.
21
+ :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
22
+ preprocessing. After preprocessing, the waveform's sampling rate will match the data
23
+ hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
24
+ this argument will be ignored.
25
+ """
26
+ # Load the wav from disk if needed
27
+ if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
28
+ wav, source_sr = librosa.load(fpath_or_wav, sr=None)
29
+ else:
30
+ wav = fpath_or_wav
31
+
32
+ # Resample the wav if needed
33
+ if source_sr is not None and source_sr != sampling_rate:
34
+ wav = librosa.resample(wav, source_sr, sampling_rate)
35
+
36
+ # Apply the preprocessing: normalize volume and shorten long silences
37
+ wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
38
+ wav = trim_long_silences(wav)
39
+
40
+ return wav
41
+
42
+
43
+ def wav_to_mel_spectrogram(wav):
44
+ """
45
+ Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
46
+ Note: this not a log-mel spectrogram.
47
+ """
48
+ frames = librosa.feature.melspectrogram(
49
+ y=wav,
50
+ sr=sampling_rate,
51
+ n_fft=int(sampling_rate * mel_window_length / 1000),
52
+ hop_length=int(sampling_rate * mel_window_step / 1000),
53
+ n_mels=mel_n_channels
54
+ )
55
+ return frames.astype(np.float32).T
56
+
57
+
58
+ def trim_long_silences(wav):
59
+ """
60
+ Ensures that segments without voice in the waveform remain no longer than a
61
+ threshold determined by the VAD parameters in params.py.
62
+
63
+ :param wav: the raw waveform as a numpy array of floats
64
+ :return: the same waveform with silences trimmed away (length <= original wav length)
65
+ """
66
+ # Compute the voice detection window size
67
+ samples_per_window = (vad_window_length * sampling_rate) // 1000
68
+
69
+ # Trim the end of the audio to have a multiple of the window size
70
+ wav = wav[:len(wav) - (len(wav) % samples_per_window)]
71
+
72
+ # Convert the float waveform to 16-bit mono PCM
73
+ pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
74
+
75
+ # Perform voice activation detection
76
+ voice_flags = []
77
+ vad = webrtcvad.Vad(mode=3)
78
+ for window_start in range(0, len(wav), samples_per_window):
79
+ window_end = window_start + samples_per_window
80
+ voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
81
+ sample_rate=sampling_rate))
82
+ voice_flags = np.array(voice_flags)
83
+
84
+ # Smooth the voice detection with a moving average
85
+ def moving_average(array, width):
86
+ array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
87
+ ret = np.cumsum(array_padded, dtype=float)
88
+ ret[width:] = ret[width:] - ret[:-width]
89
+ return ret[width - 1:] / width
90
+
91
+ audio_mask = moving_average(voice_flags, vad_moving_average_width)
92
+ audio_mask = np.round(audio_mask).astype(np.bool)
93
+
94
+ # Dilate the voiced regions
95
+ audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
96
+ audio_mask = np.repeat(audio_mask, samples_per_window)
97
+
98
+ return wav[audio_mask == True]
99
+
100
+
101
+ def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
102
+ if increase_only and decrease_only:
103
+ raise ValueError("Both increase only and decrease only are set")
104
+ dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
105
+ if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
106
+ return wav
107
+ return wav * (10 ** (dBFS_change / 20))
speaker_encoder/ckpt/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
speaker_encoder/ckpt/pretrained_bak_5805000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc7ff82ef75becd495aab2ede3a8220da393a717f178ae9534df355a6173bbca
3
+ size 17090379
speaker_encoder/compute_embed.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder import inference as encoder
2
+ from multiprocessing.pool import Pool
3
+ from functools import partial
4
+ from pathlib import Path
5
+ # from utils import logmmse
6
+ # from tqdm import tqdm
7
+ # import numpy as np
8
+ # import librosa
9
+
10
+
11
+ def embed_utterance(fpaths, encoder_model_fpath):
12
+ if not encoder.is_loaded():
13
+ encoder.load_model(encoder_model_fpath)
14
+
15
+ # Compute the speaker embedding of the utterance
16
+ wav_fpath, embed_fpath = fpaths
17
+ wav = np.load(wav_fpath)
18
+ wav = encoder.preprocess_wav(wav)
19
+ embed = encoder.embed_utterance(wav)
20
+ np.save(embed_fpath, embed, allow_pickle=False)
21
+
22
+
23
+ def create_embeddings(outdir_root: Path, wav_dir: Path, encoder_model_fpath: Path, n_processes: int):
24
+
25
+ wav_dir = outdir_root.joinpath("audio")
26
+ metadata_fpath = synthesizer_root.joinpath("train.txt")
27
+ assert wav_dir.exists() and metadata_fpath.exists()
28
+ embed_dir = synthesizer_root.joinpath("embeds")
29
+ embed_dir.mkdir(exist_ok=True)
30
+
31
+ # Gather the input wave filepath and the target output embed filepath
32
+ with metadata_fpath.open("r") as metadata_file:
33
+ metadata = [line.split("|") for line in metadata_file]
34
+ fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata]
35
+
36
+ # TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here.
37
+ # Embed the utterances in separate threads
38
+ func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath)
39
+ job = Pool(n_processes).imap(func, fpaths)
40
+ list(tqdm(job, "Embedding", len(fpaths), unit="utterances"))
speaker_encoder/config.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ librispeech_datasets = {
2
+ "train": {
3
+ "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
4
+ "other": ["LibriSpeech/train-other-500"]
5
+ },
6
+ "test": {
7
+ "clean": ["LibriSpeech/test-clean"],
8
+ "other": ["LibriSpeech/test-other"]
9
+ },
10
+ "dev": {
11
+ "clean": ["LibriSpeech/dev-clean"],
12
+ "other": ["LibriSpeech/dev-other"]
13
+ },
14
+ }
15
+ libritts_datasets = {
16
+ "train": {
17
+ "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
18
+ "other": ["LibriTTS/train-other-500"]
19
+ },
20
+ "test": {
21
+ "clean": ["LibriTTS/test-clean"],
22
+ "other": ["LibriTTS/test-other"]
23
+ },
24
+ "dev": {
25
+ "clean": ["LibriTTS/dev-clean"],
26
+ "other": ["LibriTTS/dev-other"]
27
+ },
28
+ }
29
+ voxceleb_datasets = {
30
+ "voxceleb1" : {
31
+ "train": ["VoxCeleb1/wav"],
32
+ "test": ["VoxCeleb1/test_wav"]
33
+ },
34
+ "voxceleb2" : {
35
+ "train": ["VoxCeleb2/dev/aac"],
36
+ "test": ["VoxCeleb2/test_wav"]
37
+ }
38
+ }
39
+
40
+ other_datasets = [
41
+ "LJSpeech-1.1",
42
+ "VCTK-Corpus/wav48",
43
+ ]
44
+
45
+ anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
speaker_encoder/data_objects/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
2
+ from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
speaker_encoder/data_objects/random_cycler.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ class RandomCycler:
4
+ """
5
+ Creates an internal copy of a sequence and allows access to its items in a constrained random
6
+ order. For a source sequence of n items and one or several consecutive queries of a total
7
+ of m items, the following guarantees hold (one implies the other):
8
+ - Each item will be returned between m // n and ((m - 1) // n) + 1 times.
9
+ - Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
10
+ """
11
+
12
+ def __init__(self, source):
13
+ if len(source) == 0:
14
+ raise Exception("Can't create RandomCycler from an empty collection")
15
+ self.all_items = list(source)
16
+ self.next_items = []
17
+
18
+ def sample(self, count: int):
19
+ shuffle = lambda l: random.sample(l, len(l))
20
+
21
+ out = []
22
+ while count > 0:
23
+ if count >= len(self.all_items):
24
+ out.extend(shuffle(list(self.all_items)))
25
+ count -= len(self.all_items)
26
+ continue
27
+ n = min(count, len(self.next_items))
28
+ out.extend(self.next_items[:n])
29
+ count -= n
30
+ self.next_items = self.next_items[n:]
31
+ if len(self.next_items) == 0:
32
+ self.next_items = shuffle(list(self.all_items))
33
+ return out
34
+
35
+ def __next__(self):
36
+ return self.sample(1)[0]
37
+
speaker_encoder/data_objects/speaker.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.data_objects.random_cycler import RandomCycler
2
+ from speaker_encoder.data_objects.utterance import Utterance
3
+ from pathlib import Path
4
+
5
+ # Contains the set of utterances of a single speaker
6
+ class Speaker:
7
+ def __init__(self, root: Path):
8
+ self.root = root
9
+ self.name = root.name
10
+ self.utterances = None
11
+ self.utterance_cycler = None
12
+
13
+ def _load_utterances(self):
14
+ with self.root.joinpath("_sources.txt").open("r") as sources_file:
15
+ sources = [l.split(",") for l in sources_file]
16
+ sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
17
+ self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
18
+ self.utterance_cycler = RandomCycler(self.utterances)
19
+
20
+ def random_partial(self, count, n_frames):
21
+ """
22
+ Samples a batch of <count> unique partial utterances from the disk in a way that all
23
+ utterances come up at least once every two cycles and in a random order every time.
24
+
25
+ :param count: The number of partial utterances to sample from the set of utterances from
26
+ that speaker. Utterances are guaranteed not to be repeated if <count> is not larger than
27
+ the number of utterances available.
28
+ :param n_frames: The number of frames in the partial utterance.
29
+ :return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
30
+ frames are the frames of the partial utterances and range is the range of the partial
31
+ utterance with regard to the complete utterance.
32
+ """
33
+ if self.utterances is None:
34
+ self._load_utterances()
35
+
36
+ utterances = self.utterance_cycler.sample(count)
37
+
38
+ a = [(u,) + u.random_partial(n_frames) for u in utterances]
39
+
40
+ return a
speaker_encoder/data_objects/speaker_batch.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from typing import List
3
+ from speaker_encoder.data_objects.speaker import Speaker
4
+
5
+ class SpeakerBatch:
6
+ def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
7
+ self.speakers = speakers
8
+ self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}
9
+
10
+ # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
11
+ # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
12
+ self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
speaker_encoder/data_objects/speaker_verification_dataset.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.data_objects.random_cycler import RandomCycler
2
+ from speaker_encoder.data_objects.speaker_batch import SpeakerBatch
3
+ from speaker_encoder.data_objects.speaker import Speaker
4
+ from speaker_encoder.params_data import partials_n_frames
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from pathlib import Path
7
+
8
+ # TODO: improve with a pool of speakers for data efficiency
9
+
10
+ class SpeakerVerificationDataset(Dataset):
11
+ def __init__(self, datasets_root: Path):
12
+ self.root = datasets_root
13
+ speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
14
+ if len(speaker_dirs) == 0:
15
+ raise Exception("No speakers found. Make sure you are pointing to the directory "
16
+ "containing all preprocessed speaker directories.")
17
+ self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
18
+ self.speaker_cycler = RandomCycler(self.speakers)
19
+
20
+ def __len__(self):
21
+ return int(1e10)
22
+
23
+ def __getitem__(self, index):
24
+ return next(self.speaker_cycler)
25
+
26
+ def get_logs(self):
27
+ log_string = ""
28
+ for log_fpath in self.root.glob("*.txt"):
29
+ with log_fpath.open("r") as log_file:
30
+ log_string += "".join(log_file.readlines())
31
+ return log_string
32
+
33
+
34
+ class SpeakerVerificationDataLoader(DataLoader):
35
+ def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
36
+ batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
37
+ worker_init_fn=None):
38
+ self.utterances_per_speaker = utterances_per_speaker
39
+
40
+ super().__init__(
41
+ dataset=dataset,
42
+ batch_size=speakers_per_batch,
43
+ shuffle=False,
44
+ sampler=sampler,
45
+ batch_sampler=batch_sampler,
46
+ num_workers=num_workers,
47
+ collate_fn=self.collate,
48
+ pin_memory=pin_memory,
49
+ drop_last=False,
50
+ timeout=timeout,
51
+ worker_init_fn=worker_init_fn
52
+ )
53
+
54
+ def collate(self, speakers):
55
+ return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
56
+
speaker_encoder/data_objects/utterance.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ class Utterance:
5
+ def __init__(self, frames_fpath, wave_fpath):
6
+ self.frames_fpath = frames_fpath
7
+ self.wave_fpath = wave_fpath
8
+
9
+ def get_frames(self):
10
+ return np.load(self.frames_fpath)
11
+
12
+ def random_partial(self, n_frames):
13
+ """
14
+ Crops the frames into a partial utterance of n_frames
15
+
16
+ :param n_frames: The number of frames of the partial utterance
17
+ :return: the partial utterance frames and a tuple indicating the start and end of the
18
+ partial utterance in the complete utterance.
19
+ """
20
+ frames = self.get_frames()
21
+ if frames.shape[0] == n_frames:
22
+ start = 0
23
+ else:
24
+ start = np.random.randint(0, frames.shape[0] - n_frames)
25
+ end = start + n_frames
26
+ return frames[start:end], (start, end)
speaker_encoder/hparams.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Mel-filterbank
2
+ mel_window_length = 25 # In milliseconds
3
+ mel_window_step = 10 # In milliseconds
4
+ mel_n_channels = 40
5
+
6
+
7
+ ## Audio
8
+ sampling_rate = 16000
9
+ # Number of spectrogram frames in a partial utterance
10
+ partials_n_frames = 160 # 1600 ms
11
+
12
+
13
+ ## Voice Activation Detection
14
+ # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
15
+ # This sets the granularity of the VAD. Should not need to be changed.
16
+ vad_window_length = 30 # In milliseconds
17
+ # Number of frames to average together when performing the moving average smoothing.
18
+ # The larger this value, the larger the VAD variations must be to not get smoothed out.
19
+ vad_moving_average_width = 8
20
+ # Maximum number of consecutive silent frames a segment can have.
21
+ vad_max_silence_length = 6
22
+
23
+
24
+ ## Audio volume normalization
25
+ audio_norm_target_dBFS = -30
26
+
27
+
28
+ ## Model parameters
29
+ model_hidden_size = 256
30
+ model_embedding_size = 256
31
+ model_num_layers = 3
speaker_encoder/inference.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.params_data import *
2
+ from speaker_encoder.model import SpeakerEncoder
3
+ from speaker_encoder.audio import preprocess_wav # We want to expose this function from here
4
+ from matplotlib import cm
5
+ from speaker_encoder import audio
6
+ from pathlib import Path
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import torch
10
+
11
+ _model = None # type: SpeakerEncoder
12
+ _device = None # type: torch.device
13
+
14
+
15
+ def load_model(weights_fpath: Path, device=None):
16
+ """
17
+ Loads the model in memory. If this function is not explicitely called, it will be run on the
18
+ first call to embed_frames() with the default weights file.
19
+
20
+ :param weights_fpath: the path to saved model weights.
21
+ :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
22
+ model will be loaded and will run on this device. Outputs will however always be on the cpu.
23
+ If None, will default to your GPU if it"s available, otherwise your CPU.
24
+ """
25
+ # TODO: I think the slow loading of the encoder might have something to do with the device it
26
+ # was saved on. Worth investigating.
27
+ global _model, _device
28
+ if device is None:
29
+ _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+ elif isinstance(device, str):
31
+ _device = torch.device(device)
32
+ _model = SpeakerEncoder(_device, torch.device("cpu"))
33
+ checkpoint = torch.load(weights_fpath)
34
+ _model.load_state_dict(checkpoint["model_state"])
35
+ _model.eval()
36
+ print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
37
+
38
+
39
+ def is_loaded():
40
+ return _model is not None
41
+
42
+
43
+ def embed_frames_batch(frames_batch):
44
+ """
45
+ Computes embeddings for a batch of mel spectrogram.
46
+
47
+ :param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape
48
+ (batch_size, n_frames, n_channels)
49
+ :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
50
+ """
51
+ if _model is None:
52
+ raise Exception("Model was not loaded. Call load_model() before inference.")
53
+
54
+ frames = torch.from_numpy(frames_batch).to(_device)
55
+ embed = _model.forward(frames).detach().cpu().numpy()
56
+ return embed
57
+
58
+
59
+ def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
60
+ min_pad_coverage=0.75, overlap=0.5):
61
+ """
62
+ Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
63
+ partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
64
+ spectrogram slices are returned, so as to make each partial utterance waveform correspond to
65
+ its spectrogram. This function assumes that the mel spectrogram parameters used are those
66
+ defined in params_data.py.
67
+
68
+ The returned ranges may be indexing further than the length of the waveform. It is
69
+ recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
70
+
71
+ :param n_samples: the number of samples in the waveform
72
+ :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
73
+ utterance
74
+ :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
75
+ enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
76
+ then the last partial utterance will be considered, as if we padded the audio. Otherwise,
77
+ it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
78
+ utterance, this parameter is ignored so that the function always returns at least 1 slice.
79
+ :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
80
+ utterances are entirely disjoint.
81
+ :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
82
+ respectively the waveform and the mel spectrogram with these slices to obtain the partial
83
+ utterances.
84
+ """
85
+ assert 0 <= overlap < 1
86
+ assert 0 < min_pad_coverage <= 1
87
+
88
+ samples_per_frame = int((sampling_rate * mel_window_step / 1000))
89
+ n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
90
+ frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
91
+
92
+ # Compute the slices
93
+ wav_slices, mel_slices = [], []
94
+ steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
95
+ for i in range(0, steps, frame_step):
96
+ mel_range = np.array([i, i + partial_utterance_n_frames])
97
+ wav_range = mel_range * samples_per_frame
98
+ mel_slices.append(slice(*mel_range))
99
+ wav_slices.append(slice(*wav_range))
100
+
101
+ # Evaluate whether extra padding is warranted or not
102
+ last_wav_range = wav_slices[-1]
103
+ coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
104
+ if coverage < min_pad_coverage and len(mel_slices) > 1:
105
+ mel_slices = mel_slices[:-1]
106
+ wav_slices = wav_slices[:-1]
107
+
108
+ return wav_slices, mel_slices
109
+
110
+
111
+ def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
112
+ """
113
+ Computes an embedding for a single utterance.
114
+
115
+ # TODO: handle multiple wavs to benefit from batching on GPU
116
+ :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
117
+ :param using_partials: if True, then the utterance is split in partial utterances of
118
+ <partial_utterance_n_frames> frames and the utterance embedding is computed from their
119
+ normalized average. If False, the utterance is instead computed from feeding the entire
120
+ spectogram to the network.
121
+ :param return_partials: if True, the partial embeddings will also be returned along with the
122
+ wav slices that correspond to the partial embeddings.
123
+ :param kwargs: additional arguments to compute_partial_splits()
124
+ :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
125
+ <return_partials> is True, the partial utterances as a numpy array of float32 of shape
126
+ (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
127
+ returned. If <using_partials> is simultaneously set to False, both these values will be None
128
+ instead.
129
+ """
130
+ # Process the entire utterance if not using partials
131
+ if not using_partials:
132
+ frames = audio.wav_to_mel_spectrogram(wav)
133
+ embed = embed_frames_batch(frames[None, ...])[0]
134
+ if return_partials:
135
+ return embed, None, None
136
+ return embed
137
+
138
+ # Compute where to split the utterance into partials and pad if necessary
139
+ wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
140
+ max_wave_length = wave_slices[-1].stop
141
+ if max_wave_length >= len(wav):
142
+ wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
143
+
144
+ # Split the utterance into partials
145
+ frames = audio.wav_to_mel_spectrogram(wav)
146
+ frames_batch = np.array([frames[s] for s in mel_slices])
147
+ partial_embeds = embed_frames_batch(frames_batch)
148
+
149
+ # Compute the utterance embedding from the partial embeddings
150
+ raw_embed = np.mean(partial_embeds, axis=0)
151
+ embed = raw_embed / np.linalg.norm(raw_embed, 2)
152
+
153
+ if return_partials:
154
+ return embed, partial_embeds, wave_slices
155
+ return embed
156
+
157
+
158
+ def embed_speaker(wavs, **kwargs):
159
+ raise NotImplemented()
160
+
161
+
162
+ def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
163
+ if ax is None:
164
+ ax = plt.gca()
165
+
166
+ if shape is None:
167
+ height = int(np.sqrt(len(embed)))
168
+ shape = (height, -1)
169
+ embed = embed.reshape(shape)
170
+
171
+ cmap = cm.get_cmap()
172
+ mappable = ax.imshow(embed, cmap=cmap)
173
+ cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
174
+ cbar.set_clim(*color_range)
175
+
176
+ ax.set_xticks([]), ax.set_yticks([])
177
+ ax.set_title(title)
speaker_encoder/model.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.params_model import *
2
+ from speaker_encoder.params_data import *
3
+ from scipy.interpolate import interp1d
4
+ from sklearn.metrics import roc_curve
5
+ from torch.nn.utils import clip_grad_norm_
6
+ from scipy.optimize import brentq
7
+ from torch import nn
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
+ class SpeakerEncoder(nn.Module):
13
+ def __init__(self, device, loss_device):
14
+ super().__init__()
15
+ self.loss_device = loss_device
16
+
17
+ # Network defition
18
+ self.lstm = nn.LSTM(input_size=mel_n_channels, # 40
19
+ hidden_size=model_hidden_size, # 256
20
+ num_layers=model_num_layers, # 3
21
+ batch_first=True).to(device)
22
+ self.linear = nn.Linear(in_features=model_hidden_size,
23
+ out_features=model_embedding_size).to(device)
24
+ self.relu = torch.nn.ReLU().to(device)
25
+
26
+ # Cosine similarity scaling (with fixed initial parameter values)
27
+ self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
28
+ self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)
29
+
30
+ # Loss
31
+ self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
32
+
33
+ def do_gradient_ops(self):
34
+ # Gradient scale
35
+ self.similarity_weight.grad *= 0.01
36
+ self.similarity_bias.grad *= 0.01
37
+
38
+ # Gradient clipping
39
+ clip_grad_norm_(self.parameters(), 3, norm_type=2)
40
+
41
+ def forward(self, utterances, hidden_init=None):
42
+ """
43
+ Computes the embeddings of a batch of utterance spectrograms.
44
+
45
+ :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
46
+ (batch_size, n_frames, n_channels)
47
+ :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
48
+ batch_size, hidden_size). Will default to a tensor of zeros if None.
49
+ :return: the embeddings as a tensor of shape (batch_size, embedding_size)
50
+ """
51
+ # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
52
+ # and the final cell state.
53
+ out, (hidden, cell) = self.lstm(utterances, hidden_init)
54
+
55
+ # We take only the hidden state of the last layer
56
+ embeds_raw = self.relu(self.linear(hidden[-1]))
57
+
58
+ # L2-normalize it
59
+ embeds = embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
60
+
61
+ return embeds
62
+
63
+ def similarity_matrix(self, embeds):
64
+ """
65
+ Computes the similarity matrix according the section 2.1 of GE2E.
66
+
67
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
68
+ utterances_per_speaker, embedding_size)
69
+ :return: the similarity matrix as a tensor of shape (speakers_per_batch,
70
+ utterances_per_speaker, speakers_per_batch)
71
+ """
72
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
73
+
74
+ # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
75
+ centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
76
+ centroids_incl = centroids_incl.clone() / torch.norm(centroids_incl, dim=2, keepdim=True)
77
+
78
+ # Exclusive centroids (1 per utterance)
79
+ centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
80
+ centroids_excl /= (utterances_per_speaker - 1)
81
+ centroids_excl = centroids_excl.clone() / torch.norm(centroids_excl, dim=2, keepdim=True)
82
+
83
+ # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
84
+ # product of these vectors (which is just an element-wise multiplication reduced by a sum).
85
+ # We vectorize the computation for efficiency.
86
+ sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
87
+ speakers_per_batch).to(self.loss_device)
88
+ mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int)
89
+ for j in range(speakers_per_batch):
90
+ mask = np.where(mask_matrix[j])[0]
91
+ sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
92
+ sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
93
+
94
+ ## Even more vectorized version (slower maybe because of transpose)
95
+ # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
96
+ # ).to(self.loss_device)
97
+ # eye = np.eye(speakers_per_batch, dtype=np.int)
98
+ # mask = np.where(1 - eye)
99
+ # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
100
+ # mask = np.where(eye)
101
+ # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
102
+ # sim_matrix2 = sim_matrix2.transpose(1, 2)
103
+
104
+ sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
105
+ return sim_matrix
106
+
107
+ def loss(self, embeds):
108
+ """
109
+ Computes the softmax loss according the section 2.1 of GE2E.
110
+
111
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
112
+ utterances_per_speaker, embedding_size)
113
+ :return: the loss and the EER for this batch of embeddings.
114
+ """
115
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
116
+
117
+ # Loss
118
+ sim_matrix = self.similarity_matrix(embeds)
119
+ sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
120
+ speakers_per_batch))
121
+ ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
122
+ target = torch.from_numpy(ground_truth).long().to(self.loss_device)
123
+ loss = self.loss_fn(sim_matrix, target)
124
+
125
+ # EER (not backpropagated)
126
+ with torch.no_grad():
127
+ inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0]
128
+ labels = np.array([inv_argmax(i) for i in ground_truth])
129
+ preds = sim_matrix.detach().cpu().numpy()
130
+
131
+ # Snippet from https://yangcha.github.io/EER-ROC/
132
+ fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
133
+ eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
134
+
135
+ return loss, eer
speaker_encoder/params_data.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## Mel-filterbank
3
+ mel_window_length = 25 # In milliseconds
4
+ mel_window_step = 10 # In milliseconds
5
+ mel_n_channels = 40
6
+
7
+
8
+ ## Audio
9
+ sampling_rate = 16000
10
+ # Number of spectrogram frames in a partial utterance
11
+ partials_n_frames = 160 # 1600 ms
12
+ # Number of spectrogram frames at inference
13
+ inference_n_frames = 80 # 800 ms
14
+
15
+
16
+ ## Voice Activation Detection
17
+ # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
18
+ # This sets the granularity of the VAD. Should not need to be changed.
19
+ vad_window_length = 30 # In milliseconds
20
+ # Number of frames to average together when performing the moving average smoothing.
21
+ # The larger this value, the larger the VAD variations must be to not get smoothed out.
22
+ vad_moving_average_width = 8
23
+ # Maximum number of consecutive silent frames a segment can have.
24
+ vad_max_silence_length = 6
25
+
26
+
27
+ ## Audio volume normalization
28
+ audio_norm_target_dBFS = -30
29
+