fb700 commited on
Commit
8c9c9c7
1 Parent(s): 4202a0b

Upload 171 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .flake8 +1 -1
  2. .gitattributes +18 -1
  3. .gitignore +159 -0
  4. Dockerfile +59 -28
  5. LICENSE +21 -674
  6. README.md +8 -7
  7. app old.py +608 -0
  8. app.py +435 -176
  9. checkpoint/__init__.py +0 -0
  10. checkpoint/freevc-24.pth +3 -0
  11. checkpoints/BFM_Fitting/01_MorphableModel.mat +0 -0
  12. checkpoints/BFM_Fitting/BFM09_model_info.mat +0 -0
  13. checkpoints/BFM_Fitting/BFM_exp_idx.mat +0 -0
  14. checkpoints/BFM_Fitting/BFM_front_idx.mat +0 -0
  15. checkpoints/BFM_Fitting/facemodel_info.mat +0 -0
  16. checkpoints/BFM_Fitting/select_vertex_id.mat +0 -0
  17. checkpoints/BFM_Fitting/similarity_Lm3D_all.mat +0 -0
  18. checkpoints/BFM_Fitting/std_exp.txt +0 -0
  19. checkpoints/shape_predictor_68_face_landmarks.dat +0 -0
  20. commons.py +171 -0
  21. configs/freevc-24.json +54 -0
  22. mel_processing.py +112 -0
  23. models.py +351 -0
  24. modules.py +342 -0
  25. packages.txt +2 -0
  26. requirements.txt +30 -4
  27. speaker_encoder/__init__.py +1 -0
  28. speaker_encoder/audio.py +107 -0
  29. speaker_encoder/ckpt/__init__.py +1 -0
  30. speaker_encoder/ckpt/pretrained_bak_5805000.pt +3 -0
  31. speaker_encoder/compute_embed.py +40 -0
  32. speaker_encoder/config.py +45 -0
  33. speaker_encoder/data_objects/__init__.py +2 -0
  34. speaker_encoder/data_objects/random_cycler.py +37 -0
  35. speaker_encoder/data_objects/speaker.py +40 -0
  36. speaker_encoder/data_objects/speaker_batch.py +12 -0
  37. speaker_encoder/data_objects/speaker_verification_dataset.py +56 -0
  38. speaker_encoder/data_objects/utterance.py +26 -0
  39. speaker_encoder/hparams.py +31 -0
  40. speaker_encoder/inference.py +177 -0
  41. speaker_encoder/model.py +135 -0
  42. speaker_encoder/params_data.py +29 -0
  43. speaker_encoder/params_model.py +11 -0
  44. speaker_encoder/preprocess.py +285 -0
  45. speaker_encoder/train.py +125 -0
  46. speaker_encoder/visualizations.py +178 -0
  47. speaker_encoder/voice_encoder.py +173 -0
  48. src/audio2exp_models/audio2exp.py +41 -0
  49. src/audio2exp_models/networks.py +74 -0
  50. src/audio2pose_models/audio2pose.py +94 -0
.flake8 CHANGED
@@ -18,4 +18,4 @@ exclude =
18
  dist,
19
  .venv
20
  pad*.py
21
- max-complexity = 25
 
18
  dist,
19
  .venv
20
  pad*.py
21
+ max-complexity = 25
.gitattributes CHANGED
@@ -25,7 +25,6 @@
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
  *.tflite filter=lfs diff=lfs merge=lfs -text
30
  *.tgz filter=lfs diff=lfs merge=lfs -text
31
  *.wasm filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +32,21 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
 
28
  *.tflite filter=lfs diff=lfs merge=lfs -text
29
  *.tgz filter=lfs diff=lfs merge=lfs -text
30
  *.wasm filter=lfs diff=lfs merge=lfs -text
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ checkpoints/BFM_Fitting/01_MorphableModel.mat filter=lfs diff=lfs merge=lfs -text
36
+ checkpoints/BFM_Fitting/BFM09_model_info.mat filter=lfs diff=lfs merge=lfs -text
37
+ checkpoints/facevid2vid_00189-model.pth.tar filter=lfs diff=lfs merge=lfs -text
38
+ checkpoints/mapping_00229-model.pth.tar filter=lfs diff=lfs merge=lfs -text
39
+ checkpoints/shape_predictor_68_face_landmarks.dat filter=lfs diff=lfs merge=lfs -text
40
+ examples/driven_audio/chinese_news.wav filter=lfs diff=lfs merge=lfs -text
41
+ examples/driven_audio/deyu.wav filter=lfs diff=lfs merge=lfs -text
42
+ examples/driven_audio/eluosi.wav filter=lfs diff=lfs merge=lfs -text
43
+ examples/driven_audio/fayu.wav filter=lfs diff=lfs merge=lfs -text
44
+ examples/driven_audio/imagine.wav filter=lfs diff=lfs merge=lfs -text
45
+ examples/driven_audio/japanese.wav filter=lfs diff=lfs merge=lfs -text
46
+ examples/source_image/art_16.png filter=lfs diff=lfs merge=lfs -text
47
+ examples/source_image/art_17.png filter=lfs diff=lfs merge=lfs -text
48
+ examples/source_image/art_3.png filter=lfs diff=lfs merge=lfs -text
49
+ examples/source_image/art_4.png filter=lfs diff=lfs merge=lfs -text
50
+ examples/source_image/art_5.png filter=lfs diff=lfs merge=lfs -text
51
+ examples/source_image/art_8.png filter=lfs diff=lfs merge=lfs -text
52
+ examples/source_image/art_9.png filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ results/
156
+ checkpoints/
157
+ gradio_cached_examples/
158
+ gfpgan/
159
+ start.sh
Dockerfile CHANGED
@@ -1,28 +1,59 @@
1
- # 此Dockerfile适用于“无本地模型”的环境构建,如果需要使用chatglm等本地模型,请参考 docs/Dockerfile+ChatGLM
2
- # 如何构建: 先修改 `config.py`, 然后 docker build -t gpt-academic .
3
- # 如何运行: docker run --rm -it --net=host gpt-academic
4
- FROM python:3.11
5
-
6
- RUN echo '[global]' > /etc/pip.conf && \
7
- echo 'index-url = https://mirrors.aliyun.com/pypi/simple/' >> /etc/pip.conf && \
8
- echo 'trusted-host = mirrors.aliyun.com' >> /etc/pip.conf
9
-
10
-
11
- WORKDIR /gpt
12
-
13
-
14
-
15
-
16
- # 安装依赖
17
- COPY requirements.txt ./
18
- COPY ./docs/gradio-3.32.2-py3-none-any.whl ./docs/gradio-3.32.2-py3-none-any.whl
19
- RUN pip3 install -r requirements.txt
20
- # 装载项目文件
21
- COPY . .
22
- RUN pip3 install -r requirements.txt
23
-
24
- # 可选步骤,用于预热模块
25
- RUN python3 -c 'from check_proxy import warm_up_modules; warm_up_modules()'
26
-
27
- # 启动
28
- CMD ["python3", "-u", "main.py"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04
2
+ ENV DEBIAN_FRONTEND=noninteractive
3
+ RUN apt-get update && \
4
+ apt-get upgrade -y && \
5
+ apt-get install -y --no-install-recommends \
6
+ git \
7
+ zip \
8
+ unzip \
9
+ git-lfs \
10
+ wget \
11
+ curl \
12
+ # ffmpeg \
13
+ ffmpeg \
14
+ x264 \
15
+ # python build dependencies \
16
+ build-essential \
17
+ libssl-dev \
18
+ zlib1g-dev \
19
+ libbz2-dev \
20
+ libreadline-dev \
21
+ libsqlite3-dev \
22
+ libncursesw5-dev \
23
+ xz-utils \
24
+ tk-dev \
25
+ libxml2-dev \
26
+ libxmlsec1-dev \
27
+ libffi-dev \
28
+ liblzma-dev && \
29
+ apt-get clean && \
30
+ rm -rf /var/lib/apt/lists/*
31
+
32
+ RUN useradd -m -u 1000 user
33
+ USER user
34
+ ENV HOME=/home/user \
35
+ PATH=/home/user/.local/bin:${PATH}
36
+ WORKDIR ${HOME}/app
37
+
38
+ RUN curl https://pyenv.run | bash
39
+ ENV PATH=${HOME}/.pyenv/shims:${HOME}/.pyenv/bin:${PATH}
40
+ ENV PYTHON_VERSION=3.10.9
41
+ RUN pyenv install ${PYTHON_VERSION} && \
42
+ pyenv global ${PYTHON_VERSION} && \
43
+ pyenv rehash && \
44
+ pip install --no-cache-dir -U pip setuptools wheel
45
+
46
+ RUN pip install --no-cache-dir -U torch==1.12.1 torchvision==0.13.1
47
+ COPY --chown=1000 requirements.txt /tmp/requirements.txt
48
+ RUN pip install --no-cache-dir -U -r /tmp/requirements.txt
49
+
50
+ COPY --chown=1000 . ${HOME}/app
51
+ RUN ls -a
52
+ ENV PYTHONPATH=${HOME}/app \
53
+ PYTHONUNBUFFERED=1 \
54
+ GRADIO_ALLOW_FLAGGING=never \
55
+ GRADIO_NUM_PORTS=1 \
56
+ GRADIO_SERVER_NAME=0.0.0.0 \
57
+ GRADIO_THEME=huggingface \
58
+ SYSTEM=spaces
59
+ CMD ["python", "app.py"]
LICENSE CHANGED
@@ -1,674 +1,21 @@
1
- GNU GENERAL PUBLIC LICENSE
2
- Version 3, 29 June 2007
3
-
4
- Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5
- Everyone is permitted to copy and distribute verbatim copies
6
- of this license document, but changing it is not allowed.
7
-
8
- Preamble
9
-
10
- The GNU General Public License is a free, copyleft license for
11
- software and other kinds of works.
12
-
13
- The licenses for most software and other practical works are designed
14
- to take away your freedom to share and change the works. By contrast,
15
- the GNU General Public License is intended to guarantee your freedom to
16
- share and change all versions of a program--to make sure it remains free
17
- software for all its users. We, the Free Software Foundation, use the
18
- GNU General Public License for most of our software; it applies also to
19
- any other work released this way by its authors. You can apply it to
20
- your programs, too.
21
-
22
- When we speak of free software, we are referring to freedom, not
23
- price. Our General Public Licenses are designed to make sure that you
24
- have the freedom to distribute copies of free software (and charge for
25
- them if you wish), that you receive source code or can get it if you
26
- want it, that you can change the software or use pieces of it in new
27
- free programs, and that you know you can do these things.
28
-
29
- To protect your rights, we need to prevent others from denying you
30
- these rights or asking you to surrender the rights. Therefore, you have
31
- certain responsibilities if you distribute copies of the software, or if
32
- you modify it: responsibilities to respect the freedom of others.
33
-
34
- For example, if you distribute copies of such a program, whether
35
- gratis or for a fee, you must pass on to the recipients the same
36
- freedoms that you received. You must make sure that they, too, receive
37
- or can get the source code. And you must show them these terms so they
38
- know their rights.
39
-
40
- Developers that use the GNU GPL protect your rights with two steps:
41
- (1) assert copyright on the software, and (2) offer you this License
42
- giving you legal permission to copy, distribute and/or modify it.
43
-
44
- For the developers' and authors' protection, the GPL clearly explains
45
- that there is no warranty for this free software. For both users' and
46
- authors' sake, the GPL requires that modified versions be marked as
47
- changed, so that their problems will not be attributed erroneously to
48
- authors of previous versions.
49
-
50
- Some devices are designed to deny users access to install or run
51
- modified versions of the software inside them, although the manufacturer
52
- can do so. This is fundamentally incompatible with the aim of
53
- protecting users' freedom to change the software. The systematic
54
- pattern of such abuse occurs in the area of products for individuals to
55
- use, which is precisely where it is most unacceptable. Therefore, we
56
- have designed this version of the GPL to prohibit the practice for those
57
- products. If such problems arise substantially in other domains, we
58
- stand ready to extend this provision to those domains in future versions
59
- of the GPL, as needed to protect the freedom of users.
60
-
61
- Finally, every program is threatened constantly by software patents.
62
- States should not allow patents to restrict development and use of
63
- software on general-purpose computers, but in those that do, we wish to
64
- avoid the special danger that patents applied to a free program could
65
- make it effectively proprietary. To prevent this, the GPL assures that
66
- patents cannot be used to render the program non-free.
67
-
68
- The precise terms and conditions for copying, distribution and
69
- modification follow.
70
-
71
- TERMS AND CONDITIONS
72
-
73
- 0. Definitions.
74
-
75
- "This License" refers to version 3 of the GNU General Public License.
76
-
77
- "Copyright" also means copyright-like laws that apply to other kinds of
78
- works, such as semiconductor masks.
79
-
80
- "The Program" refers to any copyrightable work licensed under this
81
- License. Each licensee is addressed as "you". "Licensees" and
82
- "recipients" may be individuals or organizations.
83
-
84
- To "modify" a work means to copy from or adapt all or part of the work
85
- in a fashion requiring copyright permission, other than the making of an
86
- exact copy. The resulting work is called a "modified version" of the
87
- earlier work or a work "based on" the earlier work.
88
-
89
- A "covered work" means either the unmodified Program or a work based
90
- on the Program.
91
-
92
- To "propagate" a work means to do anything with it that, without
93
- permission, would make you directly or secondarily liable for
94
- infringement under applicable copyright law, except executing it on a
95
- computer or modifying a private copy. Propagation includes copying,
96
- distribution (with or without modification), making available to the
97
- public, and in some countries other activities as well.
98
-
99
- To "convey" a work means any kind of propagation that enables other
100
- parties to make or receive copies. Mere interaction with a user through
101
- a computer network, with no transfer of a copy, is not conveying.
102
-
103
- An interactive user interface displays "Appropriate Legal Notices"
104
- to the extent that it includes a convenient and prominently visible
105
- feature that (1) displays an appropriate copyright notice, and (2)
106
- tells the user that there is no warranty for the work (except to the
107
- extent that warranties are provided), that licensees may convey the
108
- work under this License, and how to view a copy of this License. If
109
- the interface presents a list of user commands or options, such as a
110
- menu, a prominent item in the list meets this criterion.
111
-
112
- 1. Source Code.
113
-
114
- The "source code" for a work means the preferred form of the work
115
- for making modifications to it. "Object code" means any non-source
116
- form of a work.
117
-
118
- A "Standard Interface" means an interface that either is an official
119
- standard defined by a recognized standards body, or, in the case of
120
- interfaces specified for a particular programming language, one that
121
- is widely used among developers working in that language.
122
-
123
- The "System Libraries" of an executable work include anything, other
124
- than the work as a whole, that (a) is included in the normal form of
125
- packaging a Major Component, but which is not part of that Major
126
- Component, and (b) serves only to enable use of the work with that
127
- Major Component, or to implement a Standard Interface for which an
128
- implementation is available to the public in source code form. A
129
- "Major Component", in this context, means a major essential component
130
- (kernel, window system, and so on) of the specific operating system
131
- (if any) on which the executable work runs, or a compiler used to
132
- produce the work, or an object code interpreter used to run it.
133
-
134
- The "Corresponding Source" for a work in object code form means all
135
- the source code needed to generate, install, and (for an executable
136
- work) run the object code and to modify the work, including scripts to
137
- control those activities. However, it does not include the work's
138
- System Libraries, or general-purpose tools or generally available free
139
- programs which are used unmodified in performing those activities but
140
- which are not part of the work. For example, Corresponding Source
141
- includes interface definition files associated with source files for
142
- the work, and the source code for shared libraries and dynamically
143
- linked subprograms that the work is specifically designed to require,
144
- such as by intimate data communication or control flow between those
145
- subprograms and other parts of the work.
146
-
147
- The Corresponding Source need not include anything that users
148
- can regenerate automatically from other parts of the Corresponding
149
- Source.
150
-
151
- The Corresponding Source for a work in source code form is that
152
- same work.
153
-
154
- 2. Basic Permissions.
155
-
156
- All rights granted under this License are granted for the term of
157
- copyright on the Program, and are irrevocable provided the stated
158
- conditions are met. This License explicitly affirms your unlimited
159
- permission to run the unmodified Program. The output from running a
160
- covered work is covered by this License only if the output, given its
161
- content, constitutes a covered work. This License acknowledges your
162
- rights of fair use or other equivalent, as provided by copyright law.
163
-
164
- You may make, run and propagate covered works that you do not
165
- convey, without conditions so long as your license otherwise remains
166
- in force. You may convey covered works to others for the sole purpose
167
- of having them make modifications exclusively for you, or provide you
168
- with facilities for running those works, provided that you comply with
169
- the terms of this License in conveying all material for which you do
170
- not control copyright. Those thus making or running the covered works
171
- for you must do so exclusively on your behalf, under your direction
172
- and control, on terms that prohibit them from making any copies of
173
- your copyrighted material outside their relationship with you.
174
-
175
- Conveying under any other circumstances is permitted solely under
176
- the conditions stated below. Sublicensing is not allowed; section 10
177
- makes it unnecessary.
178
-
179
- 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180
-
181
- No covered work shall be deemed part of an effective technological
182
- measure under any applicable law fulfilling obligations under article
183
- 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184
- similar laws prohibiting or restricting circumvention of such
185
- measures.
186
-
187
- When you convey a covered work, you waive any legal power to forbid
188
- circumvention of technological measures to the extent such circumvention
189
- is effected by exercising rights under this License with respect to
190
- the covered work, and you disclaim any intention to limit operation or
191
- modification of the work as a means of enforcing, against the work's
192
- users, your or third parties' legal rights to forbid circumvention of
193
- technological measures.
194
-
195
- 4. Conveying Verbatim Copies.
196
-
197
- You may convey verbatim copies of the Program's source code as you
198
- receive it, in any medium, provided that you conspicuously and
199
- appropriately publish on each copy an appropriate copyright notice;
200
- keep intact all notices stating that this License and any
201
- non-permissive terms added in accord with section 7 apply to the code;
202
- keep intact all notices of the absence of any warranty; and give all
203
- recipients a copy of this License along with the Program.
204
-
205
- You may charge any price or no price for each copy that you convey,
206
- and you may offer support or warranty protection for a fee.
207
-
208
- 5. Conveying Modified Source Versions.
209
-
210
- You may convey a work based on the Program, or the modifications to
211
- produce it from the Program, in the form of source code under the
212
- terms of section 4, provided that you also meet all of these conditions:
213
-
214
- a) The work must carry prominent notices stating that you modified
215
- it, and giving a relevant date.
216
-
217
- b) The work must carry prominent notices stating that it is
218
- released under this License and any conditions added under section
219
- 7. This requirement modifies the requirement in section 4 to
220
- "keep intact all notices".
221
-
222
- c) You must license the entire work, as a whole, under this
223
- License to anyone who comes into possession of a copy. This
224
- License will therefore apply, along with any applicable section 7
225
- additional terms, to the whole of the work, and all its parts,
226
- regardless of how they are packaged. This License gives no
227
- permission to license the work in any other way, but it does not
228
- invalidate such permission if you have separately received it.
229
-
230
- d) If the work has interactive user interfaces, each must display
231
- Appropriate Legal Notices; however, if the Program has interactive
232
- interfaces that do not display Appropriate Legal Notices, your
233
- work need not make them do so.
234
-
235
- A compilation of a covered work with other separate and independent
236
- works, which are not by their nature extensions of the covered work,
237
- and which are not combined with it such as to form a larger program,
238
- in or on a volume of a storage or distribution medium, is called an
239
- "aggregate" if the compilation and its resulting copyright are not
240
- used to limit the access or legal rights of the compilation's users
241
- beyond what the individual works permit. Inclusion of a covered work
242
- in an aggregate does not cause this License to apply to the other
243
- parts of the aggregate.
244
-
245
- 6. Conveying Non-Source Forms.
246
-
247
- You may convey a covered work in object code form under the terms
248
- of sections 4 and 5, provided that you also convey the
249
- machine-readable Corresponding Source under the terms of this License,
250
- in one of these ways:
251
-
252
- a) Convey the object code in, or embodied in, a physical product
253
- (including a physical distribution medium), accompanied by the
254
- Corresponding Source fixed on a durable physical medium
255
- customarily used for software interchange.
256
-
257
- b) Convey the object code in, or embodied in, a physical product
258
- (including a physical distribution medium), accompanied by a
259
- written offer, valid for at least three years and valid for as
260
- long as you offer spare parts or customer support for that product
261
- model, to give anyone who possesses the object code either (1) a
262
- copy of the Corresponding Source for all the software in the
263
- product that is covered by this License, on a durable physical
264
- medium customarily used for software interchange, for a price no
265
- more than your reasonable cost of physically performing this
266
- conveying of source, or (2) access to copy the
267
- Corresponding Source from a network server at no charge.
268
-
269
- c) Convey individual copies of the object code with a copy of the
270
- written offer to provide the Corresponding Source. This
271
- alternative is allowed only occasionally and noncommercially, and
272
- only if you received the object code with such an offer, in accord
273
- with subsection 6b.
274
-
275
- d) Convey the object code by offering access from a designated
276
- place (gratis or for a charge), and offer equivalent access to the
277
- Corresponding Source in the same way through the same place at no
278
- further charge. You need not require recipients to copy the
279
- Corresponding Source along with the object code. If the place to
280
- copy the object code is a network server, the Corresponding Source
281
- may be on a different server (operated by you or a third party)
282
- that supports equivalent copying facilities, provided you maintain
283
- clear directions next to the object code saying where to find the
284
- Corresponding Source. Regardless of what server hosts the
285
- Corresponding Source, you remain obligated to ensure that it is
286
- available for as long as needed to satisfy these requirements.
287
-
288
- e) Convey the object code using peer-to-peer transmission, provided
289
- you inform other peers where the object code and Corresponding
290
- Source of the work are being offered to the general public at no
291
- charge under subsection 6d.
292
-
293
- A separable portion of the object code, whose source code is excluded
294
- from the Corresponding Source as a System Library, need not be
295
- included in conveying the object code work.
296
-
297
- A "User Product" is either (1) a "consumer product", which means any
298
- tangible personal property which is normally used for personal, family,
299
- or household purposes, or (2) anything designed or sold for incorporation
300
- into a dwelling. In determining whether a product is a consumer product,
301
- doubtful cases shall be resolved in favor of coverage. For a particular
302
- product received by a particular user, "normally used" refers to a
303
- typical or common use of that class of product, regardless of the status
304
- of the particular user or of the way in which the particular user
305
- actually uses, or expects or is expected to use, the product. A product
306
- is a consumer product regardless of whether the product has substantial
307
- commercial, industrial or non-consumer uses, unless such uses represent
308
- the only significant mode of use of the product.
309
-
310
- "Installation Information" for a User Product means any methods,
311
- procedures, authorization keys, or other information required to install
312
- and execute modified versions of a covered work in that User Product from
313
- a modified version of its Corresponding Source. The information must
314
- suffice to ensure that the continued functioning of the modified object
315
- code is in no case prevented or interfered with solely because
316
- modification has been made.
317
-
318
- If you convey an object code work under this section in, or with, or
319
- specifically for use in, a User Product, and the conveying occurs as
320
- part of a transaction in which the right of possession and use of the
321
- User Product is transferred to the recipient in perpetuity or for a
322
- fixed term (regardless of how the transaction is characterized), the
323
- Corresponding Source conveyed under this section must be accompanied
324
- by the Installation Information. But this requirement does not apply
325
- if neither you nor any third party retains the ability to install
326
- modified object code on the User Product (for example, the work has
327
- been installed in ROM).
328
-
329
- The requirement to provide Installation Information does not include a
330
- requirement to continue to provide support service, warranty, or updates
331
- for a work that has been modified or installed by the recipient, or for
332
- the User Product in which it has been modified or installed. Access to a
333
- network may be denied when the modification itself materially and
334
- adversely affects the operation of the network or violates the rules and
335
- protocols for communication across the network.
336
-
337
- Corresponding Source conveyed, and Installation Information provided,
338
- in accord with this section must be in a format that is publicly
339
- documented (and with an implementation available to the public in
340
- source code form), and must require no special password or key for
341
- unpacking, reading or copying.
342
-
343
- 7. Additional Terms.
344
-
345
- "Additional permissions" are terms that supplement the terms of this
346
- License by making exceptions from one or more of its conditions.
347
- Additional permissions that are applicable to the entire Program shall
348
- be treated as though they were included in this License, to the extent
349
- that they are valid under applicable law. If additional permissions
350
- apply only to part of the Program, that part may be used separately
351
- under those permissions, but the entire Program remains governed by
352
- this License without regard to the additional permissions.
353
-
354
- When you convey a copy of a covered work, you may at your option
355
- remove any additional permissions from that copy, or from any part of
356
- it. (Additional permissions may be written to require their own
357
- removal in certain cases when you modify the work.) You may place
358
- additional permissions on material, added by you to a covered work,
359
- for which you have or can give appropriate copyright permission.
360
-
361
- Notwithstanding any other provision of this License, for material you
362
- add to a covered work, you may (if authorized by the copyright holders of
363
- that material) supplement the terms of this License with terms:
364
-
365
- a) Disclaiming warranty or limiting liability differently from the
366
- terms of sections 15 and 16 of this License; or
367
-
368
- b) Requiring preservation of specified reasonable legal notices or
369
- author attributions in that material or in the Appropriate Legal
370
- Notices displayed by works containing it; or
371
-
372
- c) Prohibiting misrepresentation of the origin of that material, or
373
- requiring that modified versions of such material be marked in
374
- reasonable ways as different from the original version; or
375
-
376
- d) Limiting the use for publicity purposes of names of licensors or
377
- authors of the material; or
378
-
379
- e) Declining to grant rights under trademark law for use of some
380
- trade names, trademarks, or service marks; or
381
-
382
- f) Requiring indemnification of licensors and authors of that
383
- material by anyone who conveys the material (or modified versions of
384
- it) with contractual assumptions of liability to the recipient, for
385
- any liability that these contractual assumptions directly impose on
386
- those licensors and authors.
387
-
388
- All other non-permissive additional terms are considered "further
389
- restrictions" within the meaning of section 10. If the Program as you
390
- received it, or any part of it, contains a notice stating that it is
391
- governed by this License along with a term that is a further
392
- restriction, you may remove that term. If a license document contains
393
- a further restriction but permits relicensing or conveying under this
394
- License, you may add to a covered work material governed by the terms
395
- of that license document, provided that the further restriction does
396
- not survive such relicensing or conveying.
397
-
398
- If you add terms to a covered work in accord with this section, you
399
- must place, in the relevant source files, a statement of the
400
- additional terms that apply to those files, or a notice indicating
401
- where to find the applicable terms.
402
-
403
- Additional terms, permissive or non-permissive, may be stated in the
404
- form of a separately written license, or stated as exceptions;
405
- the above requirements apply either way.
406
-
407
- 8. Termination.
408
-
409
- You may not propagate or modify a covered work except as expressly
410
- provided under this License. Any attempt otherwise to propagate or
411
- modify it is void, and will automatically terminate your rights under
412
- this License (including any patent licenses granted under the third
413
- paragraph of section 11).
414
-
415
- However, if you cease all violation of this License, then your
416
- license from a particular copyright holder is reinstated (a)
417
- provisionally, unless and until the copyright holder explicitly and
418
- finally terminates your license, and (b) permanently, if the copyright
419
- holder fails to notify you of the violation by some reasonable means
420
- prior to 60 days after the cessation.
421
-
422
- Moreover, your license from a particular copyright holder is
423
- reinstated permanently if the copyright holder notifies you of the
424
- violation by some reasonable means, this is the first time you have
425
- received notice of violation of this License (for any work) from that
426
- copyright holder, and you cure the violation prior to 30 days after
427
- your receipt of the notice.
428
-
429
- Termination of your rights under this section does not terminate the
430
- licenses of parties who have received copies or rights from you under
431
- this License. If your rights have been terminated and not permanently
432
- reinstated, you do not qualify to receive new licenses for the same
433
- material under section 10.
434
-
435
- 9. Acceptance Not Required for Having Copies.
436
-
437
- You are not required to accept this License in order to receive or
438
- run a copy of the Program. Ancillary propagation of a covered work
439
- occurring solely as a consequence of using peer-to-peer transmission
440
- to receive a copy likewise does not require acceptance. However,
441
- nothing other than this License grants you permission to propagate or
442
- modify any covered work. These actions infringe copyright if you do
443
- not accept this License. Therefore, by modifying or propagating a
444
- covered work, you indicate your acceptance of this License to do so.
445
-
446
- 10. Automatic Licensing of Downstream Recipients.
447
-
448
- Each time you convey a covered work, the recipient automatically
449
- receives a license from the original licensors, to run, modify and
450
- propagate that work, subject to this License. You are not responsible
451
- for enforcing compliance by third parties with this License.
452
-
453
- An "entity transaction" is a transaction transferring control of an
454
- organization, or substantially all assets of one, or subdividing an
455
- organization, or merging organizations. If propagation of a covered
456
- work results from an entity transaction, each party to that
457
- transaction who receives a copy of the work also receives whatever
458
- licenses to the work the party's predecessor in interest had or could
459
- give under the previous paragraph, plus a right to possession of the
460
- Corresponding Source of the work from the predecessor in interest, if
461
- the predecessor has it or can get it with reasonable efforts.
462
-
463
- You may not impose any further restrictions on the exercise of the
464
- rights granted or affirmed under this License. For example, you may
465
- not impose a license fee, royalty, or other charge for exercise of
466
- rights granted under this License, and you may not initiate litigation
467
- (including a cross-claim or counterclaim in a lawsuit) alleging that
468
- any patent claim is infringed by making, using, selling, offering for
469
- sale, or importing the Program or any portion of it.
470
-
471
- 11. Patents.
472
-
473
- A "contributor" is a copyright holder who authorizes use under this
474
- License of the Program or a work on which the Program is based. The
475
- work thus licensed is called the contributor's "contributor version".
476
-
477
- A contributor's "essential patent claims" are all patent claims
478
- owned or controlled by the contributor, whether already acquired or
479
- hereafter acquired, that would be infringed by some manner, permitted
480
- by this License, of making, using, or selling its contributor version,
481
- but do not include claims that would be infringed only as a
482
- consequence of further modification of the contributor version. For
483
- purposes of this definition, "control" includes the right to grant
484
- patent sublicenses in a manner consistent with the requirements of
485
- this License.
486
-
487
- Each contributor grants you a non-exclusive, worldwide, royalty-free
488
- patent license under the contributor's essential patent claims, to
489
- make, use, sell, offer for sale, import and otherwise run, modify and
490
- propagate the contents of its contributor version.
491
-
492
- In the following three paragraphs, a "patent license" is any express
493
- agreement or commitment, however denominated, not to enforce a patent
494
- (such as an express permission to practice a patent or covenant not to
495
- sue for patent infringement). To "grant" such a patent license to a
496
- party means to make such an agreement or commitment not to enforce a
497
- patent against the party.
498
-
499
- If you convey a covered work, knowingly relying on a patent license,
500
- and the Corresponding Source of the work is not available for anyone
501
- to copy, free of charge and under the terms of this License, through a
502
- publicly available network server or other readily accessible means,
503
- then you must either (1) cause the Corresponding Source to be so
504
- available, or (2) arrange to deprive yourself of the benefit of the
505
- patent license for this particular work, or (3) arrange, in a manner
506
- consistent with the requirements of this License, to extend the patent
507
- license to downstream recipients. "Knowingly relying" means you have
508
- actual knowledge that, but for the patent license, your conveying the
509
- covered work in a country, or your recipient's use of the covered work
510
- in a country, would infringe one or more identifiable patents in that
511
- country that you have reason to believe are valid.
512
-
513
- If, pursuant to or in connection with a single transaction or
514
- arrangement, you convey, or propagate by procuring conveyance of, a
515
- covered work, and grant a patent license to some of the parties
516
- receiving the covered work authorizing them to use, propagate, modify
517
- or convey a specific copy of the covered work, then the patent license
518
- you grant is automatically extended to all recipients of the covered
519
- work and works based on it.
520
-
521
- A patent license is "discriminatory" if it does not include within
522
- the scope of its coverage, prohibits the exercise of, or is
523
- conditioned on the non-exercise of one or more of the rights that are
524
- specifically granted under this License. You may not convey a covered
525
- work if you are a party to an arrangement with a third party that is
526
- in the business of distributing software, under which you make payment
527
- to the third party based on the extent of your activity of conveying
528
- the work, and under which the third party grants, to any of the
529
- parties who would receive the covered work from you, a discriminatory
530
- patent license (a) in connection with copies of the covered work
531
- conveyed by you (or copies made from those copies), or (b) primarily
532
- for and in connection with specific products or compilations that
533
- contain the covered work, unless you entered into that arrangement,
534
- or that patent license was granted, prior to 28 March 2007.
535
-
536
- Nothing in this License shall be construed as excluding or limiting
537
- any implied license or other defenses to infringement that may
538
- otherwise be available to you under applicable patent law.
539
-
540
- 12. No Surrender of Others' Freedom.
541
-
542
- If conditions are imposed on you (whether by court order, agreement or
543
- otherwise) that contradict the conditions of this License, they do not
544
- excuse you from the conditions of this License. If you cannot convey a
545
- covered work so as to satisfy simultaneously your obligations under this
546
- License and any other pertinent obligations, then as a consequence you may
547
- not convey it at all. For example, if you agree to terms that obligate you
548
- to collect a royalty for further conveying from those to whom you convey
549
- the Program, the only way you could satisfy both those terms and this
550
- License would be to refrain entirely from conveying the Program.
551
-
552
- 13. Use with the GNU Affero General Public License.
553
-
554
- Notwithstanding any other provision of this License, you have
555
- permission to link or combine any covered work with a work licensed
556
- under version 3 of the GNU Affero General Public License into a single
557
- combined work, and to convey the resulting work. The terms of this
558
- License will continue to apply to the part which is the covered work,
559
- but the special requirements of the GNU Affero General Public License,
560
- section 13, concerning interaction through a network will apply to the
561
- combination as such.
562
-
563
- 14. Revised Versions of this License.
564
-
565
- The Free Software Foundation may publish revised and/or new versions of
566
- the GNU General Public License from time to time. Such new versions will
567
- be similar in spirit to the present version, but may differ in detail to
568
- address new problems or concerns.
569
-
570
- Each version is given a distinguishing version number. If the
571
- Program specifies that a certain numbered version of the GNU General
572
- Public License "or any later version" applies to it, you have the
573
- option of following the terms and conditions either of that numbered
574
- version or of any later version published by the Free Software
575
- Foundation. If the Program does not specify a version number of the
576
- GNU General Public License, you may choose any version ever published
577
- by the Free Software Foundation.
578
-
579
- If the Program specifies that a proxy can decide which future
580
- versions of the GNU General Public License can be used, that proxy's
581
- public statement of acceptance of a version permanently authorizes you
582
- to choose that version for the Program.
583
-
584
- Later license versions may give you additional or different
585
- permissions. However, no additional obligations are imposed on any
586
- author or copyright holder as a result of your choosing to follow a
587
- later version.
588
-
589
- 15. Disclaimer of Warranty.
590
-
591
- THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592
- APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593
- HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594
- OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595
- THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596
- PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597
- IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598
- ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599
-
600
- 16. Limitation of Liability.
601
-
602
- IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603
- WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604
- THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605
- GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606
- USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607
- DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608
- PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609
- EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610
- SUCH DAMAGES.
611
-
612
- 17. Interpretation of Sections 15 and 16.
613
-
614
- If the disclaimer of warranty and limitation of liability provided
615
- above cannot be given local legal effect according to their terms,
616
- reviewing courts shall apply local law that most closely approximates
617
- an absolute waiver of all civil liability in connection with the
618
- Program, unless a warranty or assumption of liability accompanies a
619
- copy of the Program in return for a fee.
620
-
621
- END OF TERMS AND CONDITIONS
622
-
623
- How to Apply These Terms to Your New Programs
624
-
625
- If you develop a new program, and you want it to be of the greatest
626
- possible use to the public, the best way to achieve this is to make it
627
- free software which everyone can redistribute and change under these terms.
628
-
629
- To do so, attach the following notices to the program. It is safest
630
- to attach them to the start of each source file to most effectively
631
- state the exclusion of warranty; and each file should have at least
632
- the "copyright" line and a pointer to where the full notice is found.
633
-
634
- <one line to give the program's name and a brief idea of what it does.>
635
- Copyright (C) <year> <name of author>
636
-
637
- This program is free software: you can redistribute it and/or modify
638
- it under the terms of the GNU General Public License as published by
639
- the Free Software Foundation, either version 3 of the License, or
640
- (at your option) any later version.
641
-
642
- This program is distributed in the hope that it will be useful,
643
- but WITHOUT ANY WARRANTY; without even the implied warranty of
644
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645
- GNU General Public License for more details.
646
-
647
- You should have received a copy of the GNU General Public License
648
- along with this program. If not, see <https://www.gnu.org/licenses/>.
649
-
650
- Also add information on how to contact you by electronic and paper mail.
651
-
652
- If the program does terminal interaction, make it output a short
653
- notice like this when it starts in an interactive mode:
654
-
655
- <program> Copyright (C) <year> <name of author>
656
- This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657
- This is free software, and you are welcome to redistribute it
658
- under certain conditions; type `show c' for details.
659
-
660
- The hypothetical commands `show w' and `show c' should show the appropriate
661
- parts of the General Public License. Of course, your program's commands
662
- might be different; for a GUI interface, you would use an "about box".
663
-
664
- You should also get your employer (if you work as a programmer) or school,
665
- if any, to sign a "copyright disclaimer" for the program, if necessary.
666
- For more information on this, and how to apply and follow the GNU GPL, see
667
- <https://www.gnu.org/licenses/>.
668
-
669
- The GNU General Public License does not permit incorporating your program
670
- into proprietary programs. If your program is a subroutine library, you
671
- may consider it more useful to permit linking proprietary applications with
672
- the library. If this is what you want to do, use the GNU Lesser General
673
- Public License instead of this License. But first, please read
674
- <https://www.gnu.org/licenses/why-not-lgpl.html>.
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Tencent AI Lab
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -1,13 +1,14 @@
1
  ---
2
- title: Bofan Chatglm Fitness RLHF lora
3
- emoji: 🍀
4
- colorFrom: yellow
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 3.38.0
8
  app_file: app.py
9
  pinned: false
10
- license: apache-2.0
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
1
  ---
2
+ title: ChatGLM2-SadTalker
3
+ emoji: 📺
4
+ colorFrom: purple
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 3.23.0
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app old.py ADDED
@@ -0,0 +1,608 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ import tempfile
3
+ import gradio as gr
4
+ from src.gradio_demo import SadTalker
5
+ # from src.utils.text2speech import TTSTalker
6
+ from huggingface_hub import snapshot_download
7
+
8
+ import torch
9
+ import librosa
10
+ from scipy.io.wavfile import write
11
+ from transformers import WavLMModel
12
+
13
+ import utils
14
+ from models import SynthesizerTrn
15
+ from mel_processing import mel_spectrogram_torch
16
+ from speaker_encoder.voice_encoder import SpeakerEncoder
17
+
18
+ import time
19
+ from textwrap import dedent
20
+
21
+ import mdtex2html
22
+ from loguru import logger
23
+ from transformers import AutoModel, AutoTokenizer
24
+
25
+ from tts_voice import tts_order_voice
26
+ import edge_tts
27
+ import tempfile
28
+ import anyio
29
+
30
+
31
+ def get_source_image(image):
32
+ return image
33
+
34
+ try:
35
+ import webui # in webui
36
+ in_webui = True
37
+ except:
38
+ in_webui = False
39
+
40
+
41
+ def toggle_audio_file(choice):
42
+ if choice == False:
43
+ return gr.update(visible=True), gr.update(visible=False)
44
+ else:
45
+ return gr.update(visible=False), gr.update(visible=True)
46
+
47
+ def ref_video_fn(path_of_ref_video):
48
+ if path_of_ref_video is not None:
49
+ return gr.update(value=True)
50
+ else:
51
+ return gr.update(value=False)
52
+
53
+ def download_model():
54
+ REPO_ID = 'vinthony/SadTalker-V002rc'
55
+ snapshot_download(repo_id=REPO_ID, local_dir='./checkpoints', local_dir_use_symlinks=True)
56
+
57
+ def sadtalker_demo():
58
+
59
+ download_model()
60
+
61
+ sad_talker = SadTalker(lazy_load=True)
62
+ # tts_talker = TTSTalker()
63
+
64
+ download_model()
65
+ sad_talker = SadTalker(lazy_load=True)
66
+
67
+
68
+ # ChatGLM2 & FreeVC
69
+
70
+ '''
71
+ def get_wavlm():
72
+ os.system('gdown https://drive.google.com/uc?id=12-cB34qCTvByWT-QtOcZaqwwO21FLSqU')
73
+ shutil.move('WavLM-Large.pt', 'wavlm')
74
+ '''
75
+
76
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
77
+
78
+ smodel = SpeakerEncoder('speaker_encoder/ckpt/pretrained_bak_5805000.pt')
79
+
80
+ print("Loading FreeVC(24k)...")
81
+ hps = utils.get_hparams_from_file("configs/freevc-24.json")
82
+ freevc_24 = SynthesizerTrn(
83
+ hps.data.filter_length // 2 + 1,
84
+ hps.train.segment_size // hps.data.hop_length,
85
+ **hps.model).to(device)
86
+ _ = freevc_24.eval()
87
+ _ = utils.load_checkpoint("checkpoint/freevc-24.pth", freevc_24, None)
88
+
89
+ print("Loading WavLM for content...")
90
+ cmodel = WavLMModel.from_pretrained("microsoft/wavlm-large").to(device)
91
+
92
+ def convert(model, src, tgt):
93
+ with torch.no_grad():
94
+ # tgt
95
+ wav_tgt, _ = librosa.load(tgt, sr=hps.data.sampling_rate)
96
+ wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
97
+ if model == "FreeVC" or model == "FreeVC (24kHz)":
98
+ g_tgt = smodel.embed_utterance(wav_tgt)
99
+ g_tgt = torch.from_numpy(g_tgt).unsqueeze(0).to(device)
100
+ else:
101
+ wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(device)
102
+ mel_tgt = mel_spectrogram_torch(
103
+ wav_tgt,
104
+ hps.data.filter_length,
105
+ hps.data.n_mel_channels,
106
+ hps.data.sampling_rate,
107
+ hps.data.hop_length,
108
+ hps.data.win_length,
109
+ hps.data.mel_fmin,
110
+ hps.data.mel_fmax
111
+ )
112
+ # src
113
+ wav_src, _ = librosa.load(src, sr=hps.data.sampling_rate)
114
+ wav_src = torch.from_numpy(wav_src).unsqueeze(0).to(device)
115
+ c = cmodel(wav_src).last_hidden_state.transpose(1, 2).to(device)
116
+ # infer
117
+ if model == "FreeVC":
118
+ audio = freevc.infer(c, g=g_tgt)
119
+ elif model == "FreeVC-s":
120
+ audio = freevc_s.infer(c, mel=mel_tgt)
121
+ else:
122
+ audio = freevc_24.infer(c, g=g_tgt)
123
+ audio = audio[0][0].data.cpu().float().numpy()
124
+ if model == "FreeVC" or model == "FreeVC-s":
125
+ write("out.wav", hps.data.sampling_rate, audio)
126
+ else:
127
+ write("out.wav", 24000, audio)
128
+ out = "out.wav"
129
+ return out
130
+
131
+ # GLM2
132
+
133
+ language_dict = tts_order_voice
134
+
135
+ # fix timezone in Linux
136
+ os.environ["TZ"] = "Asia/Shanghai"
137
+ try:
138
+ time.tzset() # type: ignore # pylint: disable=no-member
139
+ except Exception:
140
+ # Windows
141
+ logger.warning("Windows, cant run time.tzset()")
142
+
143
+ # model_name = "THUDM/chatglm2-6b"
144
+ model_name = "THUDM/chatglm2-6b-int4"
145
+
146
+ RETRY_FLAG = False
147
+
148
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
149
+
150
+ # model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()
151
+
152
+ # 4/8 bit
153
+ # model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).quantize(4).cuda()
154
+
155
+ has_cuda = torch.cuda.is_available()
156
+
157
+ # has_cuda = False # force cpu
158
+
159
+ if has_cuda:
160
+ model_glm = (
161
+ AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda().half()
162
+ ) # 3.92G
163
+ else:
164
+ model_glm = AutoModel.from_pretrained(
165
+ model_name, trust_remote_code=True
166
+ ).float() # .float() .half().float()
167
+
168
+ model_glm = model_glm.eval()
169
+
170
+ _ = """Override Chatbot.postprocess"""
171
+
172
+
173
+ def postprocess(self, y):
174
+ if y is None:
175
+ return []
176
+ for i, (message, response) in enumerate(y):
177
+ y[i] = (
178
+ None if message is None else mdtex2html.convert((message)),
179
+ None if response is None else mdtex2html.convert(response),
180
+ )
181
+ return y
182
+
183
+
184
+ gr.Chatbot.postprocess = postprocess
185
+
186
+
187
+ def parse_text(text):
188
+ """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
189
+ lines = text.split("\n")
190
+ lines = [line for line in lines if line != ""]
191
+ count = 0
192
+ for i, line in enumerate(lines):
193
+ if "```" in line:
194
+ count += 1
195
+ items = line.split("`")
196
+ if count % 2 == 1:
197
+ lines[i] = f'<pre><code class="language-{items[-1]}">'
198
+ else:
199
+ lines[i] = "<br></code></pre>"
200
+ else:
201
+ if i > 0:
202
+ if count % 2 == 1:
203
+ line = line.replace("`", r"\`")
204
+ line = line.replace("<", "&lt;")
205
+ line = line.replace(">", "&gt;")
206
+ line = line.replace(" ", "&nbsp;")
207
+ line = line.replace("*", "&ast;")
208
+ line = line.replace("_", "&lowbar;")
209
+ line = line.replace("-", "&#45;")
210
+ line = line.replace(".", "&#46;")
211
+ line = line.replace("!", "&#33;")
212
+ line = line.replace("(", "&#40;")
213
+ line = line.replace(")", "&#41;")
214
+ line = line.replace("$", "&#36;")
215
+ lines[i] = "<br>" + line
216
+ text = "".join(lines)
217
+ return text
218
+
219
+
220
+ def predict(
221
+ RETRY_FLAG, input, chatbot, max_length, top_p, temperature, history, past_key_values
222
+ ):
223
+ try:
224
+ chatbot.append((parse_text(input), ""))
225
+ except Exception as exc:
226
+ logger.error(exc)
227
+ logger.debug(f"{chatbot=}")
228
+ _ = """
229
+ if chatbot:
230
+ chatbot[-1] = (parse_text(input), str(exc))
231
+ yield chatbot, history, past_key_values
232
+ # """
233
+ yield chatbot, history, past_key_values
234
+
235
+ for response, history, past_key_values in model_glm.stream_chat(
236
+ tokenizer,
237
+ input,
238
+ history,
239
+ past_key_values=past_key_values,
240
+ return_past_key_values=True,
241
+ max_length=max_length,
242
+ top_p=top_p,
243
+ temperature=temperature,
244
+ ):
245
+ chatbot[-1] = (parse_text(input), parse_text(response))
246
+ # chatbot[-1][-1] = parse_text(response)
247
+
248
+ yield chatbot, history, past_key_values, parse_text(response)
249
+
250
+
251
+ def trans_api(input, max_length=4096, top_p=0.8, temperature=0.2):
252
+ if max_length < 10:
253
+ max_length = 4096
254
+ if top_p < 0.1 or top_p > 1:
255
+ top_p = 0.85
256
+ if temperature <= 0 or temperature > 1:
257
+ temperature = 0.01
258
+ try:
259
+ res, _ = model_glm.chat(
260
+ tokenizer,
261
+ input,
262
+ history=[],
263
+ past_key_values=None,
264
+ max_length=max_length,
265
+ top_p=top_p,
266
+ temperature=temperature,
267
+ )
268
+ # logger.debug(f"{res=} \n{_=}")
269
+ except Exception as exc:
270
+ logger.error(f"{exc=}")
271
+ res = str(exc)
272
+
273
+ return res
274
+
275
+
276
+ def reset_user_input():
277
+ return gr.update(value="")
278
+
279
+
280
+ def reset_state():
281
+ return [], [], None, ""
282
+
283
+
284
+ # Delete last turn
285
+ def delete_last_turn(chat, history):
286
+ if chat and history:
287
+ chat.pop(-1)
288
+ history.pop(-1)
289
+ return chat, history
290
+
291
+
292
+ # Regenerate response
293
+ def retry_last_answer(
294
+ user_input, chatbot, max_length, top_p, temperature, history, past_key_values
295
+ ):
296
+ if chatbot and history:
297
+ # Removing the previous conversation from chat
298
+ chatbot.pop(-1)
299
+ # Setting up a flag to capture a retry
300
+ RETRY_FLAG = True
301
+ # Getting last message from user
302
+ user_input = history[-1][0]
303
+ # Removing bot response from the history
304
+ history.pop(-1)
305
+
306
+ yield from predict(
307
+ RETRY_FLAG, # type: ignore
308
+ user_input,
309
+ chatbot,
310
+ max_length,
311
+ top_p,
312
+ temperature,
313
+ history,
314
+ past_key_values,
315
+ )
316
+
317
+ # print
318
+
319
+ def print(text):
320
+ return text
321
+
322
+ # TTS
323
+
324
+ async def text_to_speech_edge(text, language_code):
325
+ voice = language_dict[language_code]
326
+ communicate = edge_tts.Communicate(text, voice)
327
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
328
+ tmp_path = tmp_file.name
329
+
330
+ await communicate.save(tmp_path)
331
+
332
+ return tmp_path
333
+
334
+
335
+ with gr.Blocks(title="ChatGLM2-6B-int4", theme=gr.themes.Soft(text_size="sm"), analytics_enabled=False) as demo:
336
+ gr.HTML("<center>"
337
+ "<h1>📺💕🎶 - ChatGLM2+声音克隆+视频对话:和喜欢的角色畅所欲言吧!</h1>"
338
+ "</center>")
339
+ gr.Markdown("## <center>🥳 - ChatGLM2+FreeVC+SadTalker,为您打造沉浸式的视频对话体验,支持中英双语</center>")
340
+ gr.Markdown("## <center>🌊 - 更多精彩应用,尽在[滔滔AI](http://www.talktalkai.com);滔滔AI,为爱滔滔!💕</center>")
341
+ gr.Markdown("### <center>⭐ - 如果您喜欢这个程序,欢迎给我的[GitHub项目](https://github.com/KevinWang676/ChatGLM2-Voice-Cloning)点赞支持!</center>")
342
+
343
+ with gr.Tab("🍻 - ChatGLM2聊天区"):
344
+ with gr.Accordion("📒 相关信息", open=False):
345
+ _ = f""" ChatGLM2的可选参数信息:
346
+ * Low temperature: responses will be more deterministic and focused; High temperature: responses more creative.
347
+ * Suggested temperatures -- translation: up to 0.3; chatting: > 0.4
348
+ * Top P controls dynamic vocabulary selection based on context.\n
349
+ 如果您想让ChatGLM2进行角色扮演并与之对话,请先输入恰当的提示词,如“请你扮演成动漫角色蜡笔小新并和我进行对话”;您也可以为ChatGLM2提供自定义的角色设定\n
350
+ 当您使用声音克隆功能时,请先在此程序的对应位置上传一段您喜欢的音频
351
+ """
352
+ gr.Markdown(dedent(_))
353
+ chatbot = gr.Chatbot(height=300)
354
+ with gr.Row():
355
+ with gr.Column(scale=4):
356
+ with gr.Column(scale=12):
357
+ user_input = gr.Textbox(
358
+ label="请在此处和GLM2聊天 (按回车键即可发送)",
359
+ placeholder="聊点什么吧",
360
+ )
361
+ RETRY_FLAG = gr.Checkbox(value=False, visible=False)
362
+ with gr.Column(min_width=32, scale=1):
363
+ with gr.Row():
364
+ submitBtn = gr.Button("开始和GLM2交流吧", variant="primary")
365
+ deleteBtn = gr.Button("删除最新一轮对话", variant="secondary")
366
+ retryBtn = gr.Button("重新生成最新一轮对话", variant="secondary")
367
+
368
+ with gr.Accordion("🔧 更多设置", open=False):
369
+ with gr.Row():
370
+ emptyBtn = gr.Button("清空所有聊天记录")
371
+ max_length = gr.Slider(
372
+ 0,
373
+ 32768,
374
+ value=8192,
375
+ step=1.0,
376
+ label="Maximum length",
377
+ interactive=True,
378
+ )
379
+ top_p = gr.Slider(
380
+ 0, 1, value=0.85, step=0.01, label="Top P", interactive=True
381
+ )
382
+ temperature = gr.Slider(
383
+ 0.01, 1, value=0.95, step=0.01, label="Temperature", interactive=True
384
+ )
385
+
386
+
387
+ with gr.Row():
388
+ test1 = gr.Textbox(label="GLM2的最新回答 (可编辑)", lines = 3)
389
+ with gr.Column():
390
+ language = gr.Dropdown(choices=list(language_dict.keys()), value="普通话 (中国大陆)-Xiaoxiao-女", label="请选择文本对应的语言及您喜欢的说话人")
391
+ tts_btn = gr.Button("生成对应的音频吧", variant="primary")
392
+ output_audio = gr.Audio(type="filepath", label="为您生成的音频", interactive=False)
393
+
394
+ tts_btn.click(text_to_speech_edge, inputs=[test1, language], outputs=[output_audio])
395
+
396
+ with gr.Row():
397
+ model_choice = gr.Dropdown(choices=["FreeVC", "FreeVC-s", "FreeVC (24kHz)"], value="FreeVC (24kHz)", label="Model", visible=False)
398
+ audio1 = output_audio
399
+ audio2 = gr.Audio(label="请上传您喜欢的声音进行声音克隆", type='filepath')
400
+ clone_btn = gr.Button("开始AI声音克隆吧", variant="primary")
401
+ audio_cloned = gr.Audio(label="为您生成的专属声音克隆音频", type='filepath')
402
+
403
+ clone_btn.click(convert, inputs=[model_choice, audio1, audio2], outputs=[audio_cloned])
404
+
405
+ history = gr.State([])
406
+ past_key_values = gr.State(None)
407
+
408
+ user_input.submit(
409
+ predict,
410
+ [
411
+ RETRY_FLAG,
412
+ user_input,
413
+ chatbot,
414
+ max_length,
415
+ top_p,
416
+ temperature,
417
+ history,
418
+ past_key_values,
419
+ ],
420
+ [chatbot, history, past_key_values, test1],
421
+ show_progress="full",
422
+ )
423
+ submitBtn.click(
424
+ predict,
425
+ [
426
+ RETRY_FLAG,
427
+ user_input,
428
+ chatbot,
429
+ max_length,
430
+ top_p,
431
+ temperature,
432
+ history,
433
+ past_key_values,
434
+ ],
435
+ [chatbot, history, past_key_values, test1],
436
+ show_progress="full",
437
+ api_name="predict",
438
+ )
439
+ submitBtn.click(reset_user_input, [], [user_input])
440
+
441
+ emptyBtn.click(
442
+ reset_state, outputs=[chatbot, history, past_key_values, test1], show_progress="full"
443
+ )
444
+
445
+ retryBtn.click(
446
+ retry_last_answer,
447
+ inputs=[
448
+ user_input,
449
+ chatbot,
450
+ max_length,
451
+ top_p,
452
+ temperature,
453
+ history,
454
+ past_key_values,
455
+ ],
456
+ # outputs = [chatbot, history, last_user_message, user_message]
457
+ outputs=[chatbot, history, past_key_values, test1],
458
+ )
459
+ deleteBtn.click(delete_last_turn, [chatbot, history], [chatbot, history])
460
+
461
+ with gr.Accordion("📔 提示词示例", open=False):
462
+ etext = """In America, where cars are an important part of the national psyche, a decade ago people had suddenly started to drive less, which had not happened since the oil shocks of the 1970s. """
463
+ examples = gr.Examples(
464
+ examples=[
465
+ ["Explain the plot of Cinderella in a sentence."],
466
+ [
467
+ "How long does it take to become proficient in French, and what are the best methods for retaining information?"
468
+ ],
469
+ ["What are some common mistakes to avoid when writing code?"],
470
+ ["Build a prompt to generate a beautiful portrait of a horse"],
471
+ ["Suggest four metaphors to describe the benefits of AI"],
472
+ ["Write a pop song about leaving home for the sandy beaches."],
473
+ ["Write a summary demonstrating my ability to tame lions"],
474
+ ["鲁迅和周树人什么关系"],
475
+ ["从前有一头牛,这头牛后面有什么?"],
476
+ ["正无穷大加一大于正无穷大吗?"],
477
+ ["正无穷大加正无穷大大于正无穷大吗?"],
478
+ ["-2的平方根等于什么"],
479
+ ["树上有5只鸟,猎人开枪打死了一只。树上还有几只鸟?"],
480
+ ["树上有11只鸟,猎人开枪打死了一只。树上还有几只鸟?提示:需考虑鸟可能受惊吓飞走。"],
481
+ ["鲁迅和周树人什么关系 用英文回答"],
482
+ ["以红楼梦的行文风格写一张委婉的请假条。不少于320字。"],
483
+ [f"{etext} 翻成中文,列出3个版本"],
484
+ [f"{etext} \n 翻成中文,保留原意,但使用文学性的语言。不要写解释。列出3个版本"],
485
+ ["js 判断一个数是不是质数"],
486
+ ["js 实现python 的 range(10)"],
487
+ ["js 实现python 的 [*(range(10)]"],
488
+ ["假定 1 + 2 = 4, 试求 7 + 8"],
489
+ ["Erkläre die Handlung von Cinderella in einem Satz."],
490
+ ["Erkläre die Handlung von Cinderella in einem Satz. Auf Deutsch"],
491
+ ],
492
+ inputs=[user_input],
493
+ examples_per_page=30,
494
+ )
495
+
496
+ with gr.Accordion("For Chat/Translation API", open=False, visible=False):
497
+ input_text = gr.Text()
498
+ tr_btn = gr.Button("Go", variant="primary")
499
+ out_text = gr.Text()
500
+ tr_btn.click(
501
+ trans_api,
502
+ [input_text, max_length, top_p, temperature],
503
+ out_text,
504
+ # show_progress="full",
505
+ api_name="tr",
506
+ )
507
+ _ = """
508
+ input_text.submit(
509
+ trans_api,
510
+ [input_text, max_length, top_p, temperature],
511
+ out_text,
512
+ show_progress="full",
513
+ api_name="tr1",
514
+ )
515
+ # """
516
+ with gr.Tab("📺 - 视频聊天区"):
517
+ with gr.Row().style(equal_height=False):
518
+ with gr.Column(variant='panel'):
519
+ with gr.Tabs(elem_id="sadtalker_source_image"):
520
+ with gr.TabItem('图片上传'):
521
+ with gr.Row():
522
+ source_image = gr.Image(label="请上传一张您喜欢角色的图片", source="upload", type="filepath", elem_id="img2img_image").style(width=512)
523
+
524
+
525
+ with gr.Tabs(elem_id="sadtalker_driven_audio"):
526
+ with gr.TabItem('💡您还可以将视频下载到本地'):
527
+
528
+ with gr.Row():
529
+ driven_audio = audio_cloned
530
+ driven_audio_no = gr.Audio(label="Use IDLE mode, no audio is required", source="upload", type="filepath", visible=False)
531
+
532
+ with gr.Column():
533
+ use_idle_mode = gr.Checkbox(label="Use Idle Animation", visible=False)
534
+ length_of_audio = gr.Number(value=5, label="The length(seconds) of the generated video.", visible=False)
535
+ use_idle_mode.change(toggle_audio_file, inputs=use_idle_mode, outputs=[driven_audio, driven_audio_no]) # todo
536
+
537
+ with gr.Row():
538
+ ref_video = gr.Video(label="Reference Video", source="upload", type="filepath", elem_id="vidref", visible=False).style(width=512)
539
+
540
+ with gr.Column():
541
+ use_ref_video = gr.Checkbox(label="Use Reference Video", visible=False)
542
+ ref_info = gr.Radio(['pose', 'blink','pose+blink', 'all'], value='pose', label='Reference Video',info="How to borrow from reference Video?((fully transfer, aka, video driving mode))", visible=False)
543
+
544
+ ref_video.change(ref_video_fn, inputs=ref_video, outputs=[use_ref_video]) # todo
545
+
546
+
547
+ with gr.Column(variant='panel'):
548
+ with gr.Tabs(elem_id="sadtalker_checkbox"):
549
+ with gr.TabItem('视频设置'):
550
+ with gr.Column(variant='panel'):
551
+ # width = gr.Slider(minimum=64, elem_id="img2img_width", maximum=2048, step=8, label="Manually Crop Width", value=512) # img2img_width
552
+ # height = gr.Slider(minimum=64, elem_id="img2img_height", maximum=2048, step=8, label="Manually Crop Height", value=512) # img2img_width
553
+ with gr.Row():
554
+ pose_style = gr.Slider(minimum=0, maximum=45, step=1, label="Pose style", value=0, visible=False) #
555
+ exp_weight = gr.Slider(minimum=0, maximum=3, step=0.1, label="expression scale", value=1, visible=False) #
556
+ blink_every = gr.Checkbox(label="use eye blink", value=True, visible=False)
557
+
558
+ with gr.Row():
559
+ size_of_image = gr.Radio([256, 512], value=256, label='face model resolution', info="use 256/512 model?", visible=False) #
560
+ preprocess_type = gr.Radio(['crop', 'full'], value='crop', label='是否聚焦角色面部', info="crop:视频会聚焦角色面部;full:视频会显示图片全貌")
561
+
562
+ with gr.Row():
563
+ is_still_mode = gr.Checkbox(label="静态模式 (开启静态模式,角色的面部动作会减少;默认开启)", value=True)
564
+ facerender = gr.Radio(['facevid2vid','pirender'], value='facevid2vid', label='facerender', info="which face render?", visible=False)
565
+
566
+ with gr.Row():
567
+ batch_size = gr.Slider(label="Batch size (数值越大,生成速度越快;若显卡性能好,可增大数值)", step=1, maximum=32, value=2)
568
+ enhancer = gr.Checkbox(label="GFPGAN as Face enhancer", value=True, visible=False)
569
+
570
+ submit = gr.Button('开始视频聊天吧', elem_id="sadtalker_generate", variant='primary')
571
+
572
+ with gr.Tabs(elem_id="sadtalker_genearted"):
573
+ gen_video = gr.Video(label="为您生成的专属视频", format="mp4").style(width=256)
574
+
575
+
576
+
577
+ submit.click(
578
+ fn=sad_talker.test,
579
+ inputs=[source_image,
580
+ driven_audio,
581
+ preprocess_type,
582
+ is_still_mode,
583
+ enhancer,
584
+ batch_size,
585
+ size_of_image,
586
+ pose_style,
587
+ facerender,
588
+ exp_weight,
589
+ use_ref_video,
590
+ ref_video,
591
+ ref_info,
592
+ use_idle_mode,
593
+ length_of_audio,
594
+ blink_every
595
+ ],
596
+ outputs=[gen_video]
597
+ )
598
+ gr.Markdown("### <center>注意❗:请不要生成会对个人以及组织造成侵害的内容,此程序仅供科研、学习及个人娱乐使用。</center>")
599
+ gr.Markdown("<center>💡- 如何使用此程序:输入您对ChatGLM的提问后,依次点击“开始和GLM2交流吧”、“生成对应的音频吧”、“开始AI声音克隆吧”、“开始视频聊天吧”四个按键即可;使用声音克隆功能时,请先上传一段您喜欢的音频</center>")
600
+ gr.HTML('''
601
+ <div class="footer">
602
+ <p>🌊🏞️🎶 - 江水东流急,滔滔无尽声。 明·顾璘
603
+ </p>
604
+ </div>
605
+ ''')
606
+
607
+
608
+ demo.queue().launch(show_error=True, debug=True)
app.py CHANGED
@@ -1,22 +1,137 @@
1
- """Credit to https://github.com/THUDM/ChatGLM2-6B/blob/main/web_demo.py while mistakes are mine."""
2
- # pylint: disable=broad-exception-caught, redefined-outer-name, missing-function-docstring, missing-module-docstring, too-many-arguments, line-too-long, invalid-name, redefined-builtin, redefined-argument-from-local
3
- # import gradio as gr
 
 
 
4
 
5
- # model_name = "models/THUDM/chatglm2-6b-int4"
6
- # gr.load(model_name).lauch()
 
 
7
 
8
- # %%writefile demo-4bit.py
 
 
 
9
 
10
- import os
11
  import time
12
  from textwrap import dedent
13
 
14
- import gradio as gr
15
  import mdtex2html
16
- import torch
17
  from loguru import logger
18
  from transformers import AutoModel, AutoTokenizer
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  # fix timezone in Linux
21
  os.environ["TZ"] = "Asia/Shanghai"
22
  try:
@@ -25,16 +140,32 @@ except Exception:
25
  # Windows
26
  logger.warning("Windows, cant run time.tzset()")
27
 
28
-
29
-
30
- model_name = "fb700/chatglm-fitness-RLHF"
31
 
32
  RETRY_FLAG = False
33
 
34
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
35
- #model = AutoModel.from_pretrained(model_name, trust_remote_code=True).quantize(4).half().cuda()
36
- model = AutoModel.from_pretrained(model_name, trust_remote_code=True).half().cuda()
37
- model = model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  _ = """Override Chatbot.postprocess"""
40
 
@@ -54,6 +185,7 @@ gr.Chatbot.postprocess = postprocess
54
 
55
 
56
  def parse_text(text):
 
57
  lines = text.split("\n")
58
  lines = [line for line in lines if line != ""]
59
  count = 0
@@ -99,8 +231,8 @@ def predict(
99
  yield chatbot, history, past_key_values
100
  # """
101
  yield chatbot, history, past_key_values
102
- """
103
- for response, history, past_key_values in model.stream_chat(
104
  tokenizer,
105
  input,
106
  history,
@@ -110,23 +242,21 @@ def predict(
110
  top_p=top_p,
111
  temperature=temperature,
112
  ):
113
- """
114
- for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
115
- temperature=temperature):
116
  chatbot[-1] = (parse_text(input), parse_text(response))
 
117
 
118
- yield chatbot, history, past_key_values
119
 
120
 
121
- def trans_api(input, max_length=40960, top_p=0.8, temperature=0.2):
122
  if max_length < 10:
123
- max_length = 40960
124
  if top_p < 0.1 or top_p > 1:
125
  top_p = 0.85
126
  if temperature <= 0 or temperature > 1:
127
  temperature = 0.01
128
  try:
129
- res, _ = model.chat(
130
  tokenizer,
131
  input,
132
  history=[],
@@ -148,7 +278,7 @@ def reset_user_input():
148
 
149
 
150
  def reset_state():
151
- return [], [], None
152
 
153
 
154
  # Delete last turn
@@ -184,131 +314,177 @@ def retry_last_answer(
184
  past_key_values,
185
  )
186
 
 
187
 
188
- with gr.Blocks(title="Bofan Ai", theme=gr.themes.Soft(text_size="sm")) as demo:
189
- # gr.HTML("""<h1 align="center">ChatGLM2-6B-int4</h1>""")
190
- gr.HTML(
191
- """<center><a href="https://huggingface.co/spaces/mikeee/chatglm2-6b-4bit?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>It's beyond Fitness,模型由[帛凡]基于ChatGLM-6b进行微调后,在健康(全科)、心理等领域达至少60分的专业水准,而且中文总结能力超越了GPT3.5各版本。</center>"""
192
- """<center>特别声明:本应用仅为模型能力演示,无任何商业行为,部署资源为Huggingface官方免费提供,任何通过此项目产生的知识仅用于学术参考,作者和网站均不承担任何责任。</center>"""
193
- """<h1 align="center">帛凡 Fitness AI 演示</h1>"""
194
- """<center><a href="https://huggingface.co/fb700/chatglm-fitness-RLHF">Bofan基于chatglm-6的微调模型</a>如果喜欢请给个 ❤ 。遇到任何问题可邮件和我联系👉 [email protected]</center>"""
195
- )
196
 
197
- with gr.Accordion("🎈 Info", open=False):
198
- _ = f"""
199
- ## {model_name}
200
 
201
- ChatGLM-6B 是开源中英双语对话模型,本次训练基于ChatGLM-6B 的第一代版本,在保留了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上开展训练。
202
-
203
- 本项目经过多位网友实测,中文总结能力超越了GPT3.5各版本,健康咨询水平优于其它同量级模型,且经优化目前可以支持无限context,远大于4k、8K、16K......,可能是任何个人和中小企业首选模型。
204
-
205
- *首先,用40万条高质量数据进行强化训练,以提高模型的基础能力;
206
-
207
- *第二,使用30万条人类反馈数据,构建一个表达方式规范优雅的语言模式(RM模型);
208
-
209
- *第三,在保留SFT阶段三分之一训练数据的同时,增加了30万条fitness数据,叠加RM模型,对ChatGLM-6B进行强化训练。
210
-
211
- 通过训练我们对模型有了更深刻的认知,LLM在一直在进化,好的方法和数据可以挖掘出模型的更大潜能。
212
- 训练中特别强化了中英文学术论文的翻译和总结,可以成为普通用户和科研人员的得力助手。
213
-
214
- 免责声明:本应用仅为模型能力演示,无任何商业行为,部署资源为huggingface官方免费提供,任何通过此项目产生的知识仅用��学术参考,作者和网站均不承担任何责任 。
215
-
216
- The T4 GPU is sponsored by a community GPU grant from Huggingface. Thanks a lot!
217
-
218
- [模型下载地址](https://huggingface.co/fb700/chatglm-fitness-RLHF)
219
-
220
-
221
- """
222
- gr.Markdown(dedent(_))
223
- chatbot = gr.Chatbot()
224
- with gr.Row():
225
- with gr.Column(scale=4):
226
- with gr.Column(scale=12):
227
- user_input = gr.Textbox(
228
- show_label=False,
229
- placeholder="请输入内容Input...",
230
- ).style(container=False)
231
- RETRY_FLAG = gr.Checkbox(value=False, visible=False)
232
- with gr.Column(min_width=32, scale=1):
233
- with gr.Row():
234
- submitBtn = gr.Button("发送Submit", variant="primary")
235
- deleteBtn = gr.Button("删除最后一条对话", variant="secondary")
236
- retryBtn = gr.Button("重新生成Regenerate", variant="secondary")
237
- with gr.Column(scale=1):
238
- emptyBtn = gr.Button("清空对话Clear History")
239
- max_length = gr.Slider(
240
- 0,
241
- 32768,
242
- value=8192,
243
- step=1.0,
244
- label="Maximum length",
245
- interactive=True,
246
- )
247
- top_p = gr.Slider(
248
- 0, 1, value=0.2, step=0.01, label="Top P", interactive=True
249
- )
250
- temperature = gr.Slider(
251
- 0.01, 1, value=0.85, step=0.01, label="Temperature", interactive=True
252
- )
253
 
254
- history = gr.State([])
255
- past_key_values = gr.State(None)
256
-
257
- user_input.submit(
258
- predict,
259
- [
260
- RETRY_FLAG,
261
- user_input,
262
- chatbot,
263
- max_length,
264
- top_p,
265
- temperature,
266
- history,
267
- past_key_values,
268
- ],
269
- [chatbot, history, past_key_values],
270
- show_progress="full",
271
- )
272
- submitBtn.click(
273
- predict,
274
- [
275
- RETRY_FLAG,
276
- user_input,
277
- chatbot,
278
- max_length,
279
- top_p,
280
- temperature,
281
- history,
282
- past_key_values,
283
- ],
284
- [chatbot, history, past_key_values],
285
- show_progress="full",
286
- api_name="predict",
287
- )
288
- submitBtn.click(reset_user_input, [], [user_input])
289
 
290
- emptyBtn.click(
291
- reset_state, outputs=[chatbot, history, past_key_values], show_progress="full"
292
- )
293
 
294
- retryBtn.click(
295
- retry_last_answer,
296
- inputs=[
297
- user_input,
298
- chatbot,
299
- max_length,
300
- top_p,
301
- temperature,
302
- history,
303
- past_key_values,
304
- ],
305
- # outputs = [chatbot, history, last_user_message, user_message]
306
- outputs=[chatbot, history, past_key_values],
307
- )
308
- deleteBtn.click(delete_last_turn, [chatbot, history], [chatbot, history])
309
 
310
- with gr.Accordion("Example inputs", open=True):
311
- etext0 = """ "act": "作为基于文本的冒险游戏",\n "prompt": "我想让你扮演一个基于文本的冒险游戏。我在这个基于文本的冒险游戏中扮演一个角色。请尽可能具体地描述角色所看到的内容和环境,并在游戏输出1、2、3让用户选择进行回复,而不是其它方式。我将输入命令来告诉角色该做什么,而你需要回复角色的行动结果以推动游戏的进行。我的第一个命令是'醒来',请从这里开始故事 “ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  etext = """In America, where cars are an important part of the national psyche, a decade ago people had suddenly started to drive less, which had not happened since the oil shocks of the 1970s. """
313
  etext1 = """云南大学(Yunnan University),简称云大(YNU),位于云南省昆明市,是教育部与云南省“以部为主、部省合建”的全国重点大学,国家“双一流”建设高校 [31] 、211工程、一省一校、中西部高校基础能力建设工程,云南省重点支持的国家一流大学建设高校,“111计划”、卓越法律人才教育培养计划、卓越工程师教育培养计划、国家建设高水平大学公派研究生项目、中国政府奖学金来华留学生接收院校、全国深化创新创业教育改革示范高校,为中西部“一省一校”国家重点建设大学(Z14)联盟、南亚东南亚大学联盟牵头单位。 [1]
314
  云南大学始建于1922年,时为私立东陆大学。1930年,改为省立东陆大学。1934年更名为省立云南大学。1938年改为国立云南大学。1946年,《不列颠百科全书》将云南大学列为中国15所在世界最具影响的大学之一。1950年定名为云南大学。1958年,云南大学由中央高教部划归云南省管理。1978年,云南大学被国务院确定为88所全国重点大学之一。1996年首批列入国家“211工程”重点建设大学。1999年,云南政法高等专科学校并入云南大学。 [2] [23]
@@ -370,37 +546,120 @@ with gr.Blocks(title="Bofan Ai", theme=gr.themes.Soft(text_size="sm")) as demo:
370
  ["Erkläre die Handlung von Cinderella in einem Satz."],
371
  ["Erkläre die Handlung von Cinderella in einem Satz. Auf Deutsch"],
372
  ],
373
- inputs=[user_input],
374
- examples_per_page=50,
 
 
 
 
 
 
 
 
 
 
 
 
375
  )
376
-
377
- with gr.Accordion("For Chat/Translation API", open=False, visible=False):
378
- input_text = gr.Text()
379
- tr_btn = gr.Button("Go", variant="primary")
380
- out_text = gr.Text()
381
- tr_btn.click(
382
- trans_api,
383
- [input_text, max_length, top_p, temperature],
384
- out_text,
385
- # show_progress="full",
386
- api_name="tr",
387
- )
388
- _ = """
389
- input_text.submit(
390
- trans_api,
391
- [input_text, max_length, top_p, temperature],
392
- out_text,
393
- show_progress="full",
394
- api_name="tr1",
395
- )
396
- # """
397
-
398
- # demo.queue().launch(share=False, inbrowser=True)
399
- # demo.queue().launch(share=True, inbrowser=True, debug=True)
400
-
401
- # concurrency_count > 1 requires more memory, max_size: queue size
402
- # T4 medium: 30GB, model size: ~4G concurrency_count = 6
403
- # leave one for api access
404
- # reduce to 5 if OOM occurs to often
405
-
406
- demo.queue(concurrency_count=6, max_size=30).launch(debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ import tempfile
3
+ import gradio as gr
4
+ from src.gradio_demo import SadTalker
5
+ # from src.utils.text2speech import TTSTalker
6
+ from huggingface_hub import snapshot_download
7
 
8
+ import torch
9
+ import librosa
10
+ from scipy.io.wavfile import write
11
+ from transformers import WavLMModel
12
 
13
+ import utils
14
+ from models import SynthesizerTrn
15
+ from mel_processing import mel_spectrogram_torch
16
+ from speaker_encoder.voice_encoder import SpeakerEncoder
17
 
 
18
  import time
19
  from textwrap import dedent
20
 
 
21
  import mdtex2html
 
22
  from loguru import logger
23
  from transformers import AutoModel, AutoTokenizer
24
 
25
+ from tts_voice import tts_order_voice
26
+ import edge_tts
27
+ import tempfile
28
+ import anyio
29
+
30
+
31
+ def get_source_image(image):
32
+ return image
33
+
34
+ try:
35
+ import webui # in webui
36
+ in_webui = True
37
+ except:
38
+ in_webui = False
39
+
40
+
41
+ def toggle_audio_file(choice):
42
+ if choice == False:
43
+ return gr.update(visible=True), gr.update(visible=False)
44
+ else:
45
+ return gr.update(visible=False), gr.update(visible=True)
46
+
47
+ def ref_video_fn(path_of_ref_video):
48
+ if path_of_ref_video is not None:
49
+ return gr.update(value=True)
50
+ else:
51
+ return gr.update(value=False)
52
+
53
+ def download_model():
54
+ REPO_ID = 'vinthony/SadTalker-V002rc'
55
+ snapshot_download(repo_id=REPO_ID, local_dir='./checkpoints', local_dir_use_symlinks=True)
56
+
57
+ def sadtalker_demo():
58
+
59
+ download_model()
60
+
61
+ sad_talker = SadTalker(lazy_load=True)
62
+ # tts_talker = TTSTalker()
63
+
64
+ download_model()
65
+ sad_talker = SadTalker(lazy_load=True)
66
+
67
+
68
+ # ChatGLM2 & FreeVC
69
+
70
+ '''
71
+ def get_wavlm():
72
+ os.system('gdown https://drive.google.com/uc?id=12-cB34qCTvByWT-QtOcZaqwwO21FLSqU')
73
+ shutil.move('WavLM-Large.pt', 'wavlm')
74
+ '''
75
+
76
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
77
+
78
+ smodel = SpeakerEncoder('speaker_encoder/ckpt/pretrained_bak_5805000.pt')
79
+
80
+ print("Loading FreeVC(24k)...")
81
+ hps = utils.get_hparams_from_file("configs/freevc-24.json")
82
+ freevc_24 = SynthesizerTrn(
83
+ hps.data.filter_length // 2 + 1,
84
+ hps.train.segment_size // hps.data.hop_length,
85
+ **hps.model).to(device)
86
+ _ = freevc_24.eval()
87
+ _ = utils.load_checkpoint("checkpoint/freevc-24.pth", freevc_24, None)
88
+
89
+ print("Loading WavLM for content...")
90
+ cmodel = WavLMModel.from_pretrained("microsoft/wavlm-large").to(device)
91
+
92
+ def convert(model, src, tgt):
93
+ with torch.no_grad():
94
+ # tgt
95
+ wav_tgt, _ = librosa.load(tgt, sr=hps.data.sampling_rate)
96
+ wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
97
+ if model == "FreeVC" or model == "FreeVC (24kHz)":
98
+ g_tgt = smodel.embed_utterance(wav_tgt)
99
+ g_tgt = torch.from_numpy(g_tgt).unsqueeze(0).to(device)
100
+ else:
101
+ wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(device)
102
+ mel_tgt = mel_spectrogram_torch(
103
+ wav_tgt,
104
+ hps.data.filter_length,
105
+ hps.data.n_mel_channels,
106
+ hps.data.sampling_rate,
107
+ hps.data.hop_length,
108
+ hps.data.win_length,
109
+ hps.data.mel_fmin,
110
+ hps.data.mel_fmax
111
+ )
112
+ # src
113
+ wav_src, _ = librosa.load(src, sr=hps.data.sampling_rate)
114
+ wav_src = torch.from_numpy(wav_src).unsqueeze(0).to(device)
115
+ c = cmodel(wav_src).last_hidden_state.transpose(1, 2).to(device)
116
+ # infer
117
+ if model == "FreeVC":
118
+ audio = freevc.infer(c, g=g_tgt)
119
+ elif model == "FreeVC-s":
120
+ audio = freevc_s.infer(c, mel=mel_tgt)
121
+ else:
122
+ audio = freevc_24.infer(c, g=g_tgt)
123
+ audio = audio[0][0].data.cpu().float().numpy()
124
+ if model == "FreeVC" or model == "FreeVC-s":
125
+ write("out.wav", hps.data.sampling_rate, audio)
126
+ else:
127
+ write("out.wav", 24000, audio)
128
+ out = "out.wav"
129
+ return out
130
+
131
+ # BofanAi
132
+
133
+ language_dict = tts_order_voice
134
+
135
  # fix timezone in Linux
136
  os.environ["TZ"] = "Asia/Shanghai"
137
  try:
 
140
  # Windows
141
  logger.warning("Windows, cant run time.tzset()")
142
 
143
+ # model_name = "THUDM/chatglm2-6b"
144
+ model_name = "fb700/chatglm-fitness-RLHF"
 
145
 
146
  RETRY_FLAG = False
147
 
148
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
149
+
150
+ # model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()
151
+
152
+ # 4/8 bit
153
+ # model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).quantize(4).cuda()
154
+
155
+ has_cuda = torch.cuda.is_available()
156
+
157
+ # has_cuda = False # force cpu
158
+
159
+ if has_cuda:
160
+ model_glm = (
161
+ AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda().half()
162
+ ) # 3.92G
163
+ else:
164
+ model_glm = AutoModel.from_pretrained(
165
+ model_name, trust_remote_code=True
166
+ ).float() # .float() .half().float()
167
+
168
+ model_glm = model_glm.eval()
169
 
170
  _ = """Override Chatbot.postprocess"""
171
 
 
185
 
186
 
187
  def parse_text(text):
188
+ """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
189
  lines = text.split("\n")
190
  lines = [line for line in lines if line != ""]
191
  count = 0
 
231
  yield chatbot, history, past_key_values
232
  # """
233
  yield chatbot, history, past_key_values
234
+
235
+ for response, history, past_key_values in model_glm.stream_chat(
236
  tokenizer,
237
  input,
238
  history,
 
242
  top_p=top_p,
243
  temperature=temperature,
244
  ):
 
 
 
245
  chatbot[-1] = (parse_text(input), parse_text(response))
246
+ # chatbot[-1][-1] = parse_text(response)
247
 
248
+ yield chatbot, history, past_key_values, parse_text(response)
249
 
250
 
251
+ def trans_api(input, max_length=4096, top_p=0.8, temperature=0.2):
252
  if max_length < 10:
253
+ max_length = 4096
254
  if top_p < 0.1 or top_p > 1:
255
  top_p = 0.85
256
  if temperature <= 0 or temperature > 1:
257
  temperature = 0.01
258
  try:
259
+ res, _ = model_glm.chat(
260
  tokenizer,
261
  input,
262
  history=[],
 
278
 
279
 
280
  def reset_state():
281
+ return [], [], None, ""
282
 
283
 
284
  # Delete last turn
 
314
  past_key_values,
315
  )
316
 
317
+ # print
318
 
319
+ def print(text):
320
+ return text
 
 
 
 
 
 
321
 
322
+ # TTS
 
 
323
 
324
+ async def text_to_speech_edge(text, language_code):
325
+ voice = language_dict[language_code]
326
+ communicate = edge_tts.Communicate(text, voice)
327
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
328
+ tmp_path = tmp_file.name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
 
330
+ await communicate.save(tmp_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
 
332
+ return tmp_path
 
 
333
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
 
335
+ with gr.Blocks(title="Bofan Ai", theme=gr.themes.Soft(text_size="sm"), analytics_enabled=False) as demo:
336
+ gr.HTML("<center>"
337
+ "<h1>📺💕🎶 - BofanAi+声音克隆+视频对话:和喜欢的角色畅所欲言吧!</h1>"
338
+ "</center>"
339
+ """<center><a href="https://huggingface.co/fb700/chatglm-fitness-RLHF">Bofan基于chatglm-6的微调模型</a>如果喜欢请给个 ❤ 。遇到任何问题可邮件和我联系👉 [email protected]</center>"""
340
+ )
341
+ gr.Markdown("## <center>帛凡 Fitness AI 演示</center>"
342
+ """<center>特别声明:本应用仅为模型能力演示,无任何商业行为,部署资源为Huggingface官方免费提供,任何通过此项目产生的知识仅用于学术参考,作者和网站均不承担任何责任。</center>"""
343
+ )
344
+
345
+ with gr.Tab("🍻 - BofanAi聊天区"):
346
+ with gr.Accordion("📒 相关信息", open=False):
347
+ _ = f""" BofanAi的可选参数信息:
348
+ * Low temperature: responses will be more deterministic and focused; High temperature: responses more creative.
349
+ * Suggested temperatures -- translation: up to 0.3; chatting: > 0.4
350
+ * Top P controls dynamic vocabulary selection based on context.\n
351
+ 如果您想让BofanAi进行角色扮演并与之对话,请先输入恰当的提示词,如“请你扮演成动漫角色蜡笔小新并和我进行对话”;您也可以为BofanAi提供自定义的角色设定\n
352
+ 当您使用声音克隆功能时,请先在此程序的对应位置上传一段您喜欢的音频
353
+ ## {model_name}
354
+
355
+ ChatGLM-6B 是开源中英双语对话模型,本次训练基于ChatGLM-6B 的第一代版本,在保留了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上开展训练。
356
+
357
+ 本项目经过多位网友实测,中文总结能力超越了GPT3.5各版本,健康咨询水平优于其它同量级模型,且经优化目前可以支持无限context,远大于4k、8K、16K......,可能是任何个人和中小企业首选模型。
358
+
359
+ *首先,用40万条高质量数据进行强化训练,以提高模型的基础能力;
360
+
361
+ *第二,使用30万条人类反馈数据,构建一个表达方式规范优雅的语言模式(RM模型);
362
+
363
+ *第三,在保留SFT阶段三分之一训练数据的同时,增加了30万条fitness数据,叠加RM模型,对ChatGLM-6B进行强化训练。
364
+
365
+ 通过训练我们对模型有了更深刻的认知,LLM在一直在进化,好的方法和数据可以挖掘出模型的更大潜能。
366
+ 训练中特别强化了中英文学术论文的翻译和总结,可以成为普通用户和科研人员的得力助手。
367
+
368
+ 免责声明:本应用仅为模型能力演示,无任何商业行为,部署资源为huggingface官方免费提供,任何通过此项目产生的知识仅用于学术参考,作者和网站均不承担任何责任 。
369
+
370
+ The T4 GPU is sponsored by a community GPU grant from Huggingface. Thanks a lot!
371
+
372
+ [模型下载地址](https://huggingface.co/fb700/chatglm-fitness-RLHF)
373
+ """
374
+ gr.Markdown(dedent(_))
375
+ chatbot = gr.Chatbot(height=300)
376
+ with gr.Row():
377
+ with gr.Column(scale=4):
378
+ with gr.Column(scale=12):
379
+ user_input = gr.Textbox(
380
+ label="请在此处和BofanAi聊天 (按回车键即可发送)",
381
+ placeholder="聊点什么吧",
382
+ )
383
+ RETRY_FLAG = gr.Checkbox(value=False, visible=False)
384
+ with gr.Column(min_width=32, scale=1):
385
+ with gr.Row():
386
+ submitBtn = gr.Button("开始和BofanAi交流吧", variant="primary")
387
+ deleteBtn = gr.Button("删除最新一轮对话", variant="secondary")
388
+ retryBtn = gr.Button("重新生成最新一轮对话", variant="secondary")
389
+
390
+ with gr.Accordion("🔧 更多设置", open=False):
391
+ with gr.Row():
392
+ emptyBtn = gr.Button("清空所有聊天记录")
393
+ max_length = gr.Slider(
394
+ 0,
395
+ 32768,
396
+ value=8192,
397
+ step=1.0,
398
+ label="Maximum length",
399
+ interactive=True,
400
+ )
401
+ top_p = gr.Slider(
402
+ 0, 1, value=0.2, step=0.01, label="Top P", interactive=True
403
+ )
404
+ temperature = gr.Slider(
405
+ 0.01, 1, value=0.85, step=0.01, label="Temperature", interactive=True
406
+ )
407
+
408
+
409
+ with gr.Row():
410
+ test1 = gr.Textbox(label="BofanAi的最新回答 (可编辑)", lines = 3)
411
+ with gr.Column():
412
+ language = gr.Dropdown(choices=list(language_dict.keys()), value="普通话 (中国大陆)-Xiaoxiao-女", label="请选择文本对应的语言及您喜欢的说话人")
413
+ tts_btn = gr.Button("生成对应的音频吧", variant="primary")
414
+ output_audio = gr.Audio(type="filepath", label="为您生成的音频", interactive=False)
415
+
416
+ tts_btn.click(text_to_speech_edge, inputs=[test1, language], outputs=[output_audio])
417
+
418
+ with gr.Row():
419
+ model_choice = gr.Dropdown(choices=["FreeVC", "FreeVC-s", "FreeVC (24kHz)"], value="FreeVC (24kHz)", label="Model", visible=False)
420
+ audio1 = output_audio
421
+ audio2 = gr.Audio(label="请上传您喜欢的声音进行声音克隆", type='filepath')
422
+ clone_btn = gr.Button("开始AI声音克隆吧", variant="primary")
423
+ audio_cloned = gr.Audio(label="为您生成的专属声音克隆音频", type='filepath')
424
+
425
+ clone_btn.click(convert, inputs=[model_choice, audio1, audio2], outputs=[audio_cloned])
426
+
427
+ history = gr.State([])
428
+ past_key_values = gr.State(None)
429
+
430
+ user_input.submit(
431
+ predict,
432
+ [
433
+ RETRY_FLAG,
434
+ user_input,
435
+ chatbot,
436
+ max_length,
437
+ top_p,
438
+ temperature,
439
+ history,
440
+ past_key_values,
441
+ ],
442
+ [chatbot, history, past_key_values, test1],
443
+ show_progress="full",
444
+ )
445
+ submitBtn.click(
446
+ predict,
447
+ [
448
+ RETRY_FLAG,
449
+ user_input,
450
+ chatbot,
451
+ max_length,
452
+ top_p,
453
+ temperature,
454
+ history,
455
+ past_key_values,
456
+ ],
457
+ [chatbot, history, past_key_values, test1],
458
+ show_progress="full",
459
+ api_name="predict",
460
+ )
461
+ submitBtn.click(reset_user_input, [], [user_input])
462
+
463
+ emptyBtn.click(
464
+ reset_state, outputs=[chatbot, history, past_key_values, test1], show_progress="full"
465
+ )
466
+
467
+ retryBtn.click(
468
+ retry_last_answer,
469
+ inputs=[
470
+ user_input,
471
+ chatbot,
472
+ max_length,
473
+ top_p,
474
+ temperature,
475
+ history,
476
+ past_key_values,
477
+ ],
478
+ # outputs = [chatbot, history, last_user_message, user_message]
479
+ outputs=[chatbot, history, past_key_values, test1],
480
+ )
481
+ deleteBtn.click(delete_last_turn, [chatbot, history], [chatbot, history])
482
+
483
+ with gr.Accordion("📔 模型应用测试", open=True):
484
+ etext = """In America, where cars are an important part of the national psyche, a decade ago people had suddenly started to drive less, which had not happened since the oil shocks of the 1970s. """
485
+ examples = gr.Examples(
486
+ examples=[
487
+ etext0 = """ "act": "作为基于文本的冒险游戏",\n "prompt": "我想让你扮演一个基于文本的冒险游戏。我在这个基于文本的冒险游戏中扮演一个角色。请尽可能具体地描述角色所看到的内容和环境,并在游戏输出1、2、3让用户选择进行回复,而不是其它方式。我将输入命令来告诉角色该做什么,而你需要回复角色的行动结果以推动游戏的进行。我的第一个命令是'醒来',请从这里开始故事 “ """
488
  etext = """In America, where cars are an important part of the national psyche, a decade ago people had suddenly started to drive less, which had not happened since the oil shocks of the 1970s. """
489
  etext1 = """云南大学(Yunnan University),简称云大(YNU),位于云南省昆明市,是教育部与云南省“以部为主、部省合建”的全国重点大学,国家“双一流”建设高校 [31] 、211工程、一省一校、中西部高校基础能力建设工程,云南省重点支持的国家一流大学建设高校,“111计划”、卓越法律人才教育培养计划、卓越工程师教育培养计划、国家建设高水平大学公派研究生项目、中国政府奖学金来华留学生接收院校、全国深化创新创业教育改革示范高校,为中西部“一省一校”国家重点建设大学(Z14)联盟、南亚东南亚大学联盟牵头单位。 [1]
490
  云南大学始建于1922年,时为私立东陆大学。1930年,改为省立东陆大学。1934年更名为省立云南大学。1938年改为国立云南大学。1946年,《不列颠百科全书》将云南大学列为中国15所在世界最具影响的大学之一。1950年定名为云南大学。1958年,云南大学由中央高教部划归云南省管理。1978年,云南大学被国务院确定为88所全国重点大学之一。1996年首批列入国家“211工程”重点建设大学。1999年,云南政法高等专科学校并入云南大学。 [2] [23]
 
546
  ["Erkläre die Handlung von Cinderella in einem Satz."],
547
  ["Erkläre die Handlung von Cinderella in einem Satz. Auf Deutsch"],
548
  ],
549
+ inputs=[user_input],
550
+ examples_per_page=50,
551
+ )
552
+
553
+ with gr.Accordion("For Chat/Translation API", open=False, visible=False):
554
+ input_text = gr.Text()
555
+ tr_btn = gr.Button("Go", variant="primary")
556
+ out_text = gr.Text()
557
+ tr_btn.click(
558
+ trans_api,
559
+ [input_text, max_length, top_p, temperature],
560
+ out_text,
561
+ # show_progress="full",
562
+ api_name="tr",
563
  )
564
+ _ = """
565
+ input_text.submit(
566
+ trans_api,
567
+ [input_text, max_length, top_p, temperature],
568
+ out_text,
569
+ show_progress="full",
570
+ api_name="tr1",
571
+ )
572
+ # """
573
+ with gr.Tab("📺 - 视频聊天区"):
574
+ with gr.Row().style(equal_height=False):
575
+ with gr.Column(variant='panel'):
576
+ with gr.Tabs(elem_id="sadtalker_source_image"):
577
+ with gr.TabItem('图片上传'):
578
+ with gr.Row():
579
+ source_image = gr.Image(label="请上传一张您喜欢角色的图片", source="upload", type="filepath", elem_id="img2img_image").style(width=512)
580
+
581
+
582
+ with gr.Tabs(elem_id="sadtalker_driven_audio"):
583
+ with gr.TabItem('💡您还可以将视频下载到本地'):
584
+
585
+ with gr.Row():
586
+ driven_audio = audio_cloned
587
+ driven_audio_no = gr.Audio(label="Use IDLE mode, no audio is required", source="upload", type="filepath", visible=False)
588
+
589
+ with gr.Column():
590
+ use_idle_mode = gr.Checkbox(label="Use Idle Animation", visible=False)
591
+ length_of_audio = gr.Number(value=5, label="The length(seconds) of the generated video.", visible=False)
592
+ use_idle_mode.change(toggle_audio_file, inputs=use_idle_mode, outputs=[driven_audio, driven_audio_no]) # todo
593
+
594
+ with gr.Row():
595
+ ref_video = gr.Video(label="Reference Video", source="upload", type="filepath", elem_id="vidref", visible=False).style(width=512)
596
+
597
+ with gr.Column():
598
+ use_ref_video = gr.Checkbox(label="Use Reference Video", visible=False)
599
+ ref_info = gr.Radio(['pose', 'blink','pose+blink', 'all'], value='pose', label='Reference Video',info="How to borrow from reference Video?((fully transfer, aka, video driving mode))", visible=False)
600
+
601
+ ref_video.change(ref_video_fn, inputs=ref_video, outputs=[use_ref_video]) # todo
602
+
603
+
604
+ with gr.Column(variant='panel'):
605
+ with gr.Tabs(elem_id="sadtalker_checkbox"):
606
+ with gr.TabItem('视频设置'):
607
+ with gr.Column(variant='panel'):
608
+ # width = gr.Slider(minimum=64, elem_id="img2img_width", maximum=2048, step=8, label="Manually Crop Width", value=512) # img2img_width
609
+ # height = gr.Slider(minimum=64, elem_id="img2img_height", maximum=2048, step=8, label="Manually Crop Height", value=512) # img2img_width
610
+ with gr.Row():
611
+ pose_style = gr.Slider(minimum=0, maximum=45, step=1, label="Pose style", value=0, visible=False) #
612
+ exp_weight = gr.Slider(minimum=0, maximum=3, step=0.1, label="expression scale", value=1, visible=False) #
613
+ blink_every = gr.Checkbox(label="use eye blink", value=True, visible=False)
614
+
615
+ with gr.Row():
616
+ size_of_image = gr.Radio([256, 512], value=256, label='face model resolution', info="use 256/512 model?", visible=False) #
617
+ preprocess_type = gr.Radio(['crop', 'full'], value='crop', label='是否聚焦角色面部', info="crop:视频会聚焦角色面部;full:视频会显示图片全貌")
618
+
619
+ with gr.Row():
620
+ is_still_mode = gr.Checkbox(label="静态模式 (开启静态模式,角色的面部动作会减少;默认开启)", value=True)
621
+ facerender = gr.Radio(['facevid2vid','pirender'], value='facevid2vid', label='facerender', info="which face render?", visible=False)
622
+
623
+ with gr.Row():
624
+ batch_size = gr.Slider(label="Batch size (数值越大,生成速度越快;若显卡性能好,可增大数值)", step=1, maximum=32, value=2)
625
+ enhancer = gr.Checkbox(label="GFPGAN as Face enhancer", value=True, visible=False)
626
+
627
+ submit = gr.Button('开始视频聊天吧', elem_id="sadtalker_generate", variant='primary')
628
+
629
+ with gr.Tabs(elem_id="sadtalker_genearted"):
630
+ gen_video = gr.Video(label="为您生成的专属视频", format="mp4").style(width=256)
631
+
632
+
633
+
634
+ submit.click(
635
+ fn=sad_talker.test,
636
+ inputs=[source_image,
637
+ driven_audio,
638
+ preprocess_type,
639
+ is_still_mode,
640
+ enhancer,
641
+ batch_size,
642
+ size_of_image,
643
+ pose_style,
644
+ facerender,
645
+ exp_weight,
646
+ use_ref_video,
647
+ ref_video,
648
+ ref_info,
649
+ use_idle_mode,
650
+ length_of_audio,
651
+ blink_every
652
+ ],
653
+ outputs=[gen_video]
654
+ )
655
+ gr.Markdown("### <center>注意❗:请不要生成会对个人以及组织造成侵害的内容,此程序仅供科研、学习及个人娱乐使用。</center>")
656
+ gr.Markdown("<center>💡- 如何使用此程序:输入您对ChatGLM的提问后,依次点击“开始和BofanAi交流吧”、“生成对应的音频吧”、“开始AI声音克隆吧”、“开始视频聊天吧”四个按键即可;使用声音克隆功能时,请先上传一段您喜欢的音频</center>")
657
+ gr.HTML('''
658
+ <div class="footer">
659
+ <p>🌊🏞️🎶 - 江水东流急,滔滔无尽声。 明·顾璘
660
+ </p>
661
+ </div>
662
+ ''')
663
+
664
+
665
+ demo.queue().launch(show_error=True, debug=True)
checkpoint/__init__.py ADDED
File without changes
checkpoint/freevc-24.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b39a86fefbc9ec6e30be8d26ee2a6aa5ffe6d235f6ab15773d01cdf348e5b20
3
+ size 472644351
checkpoints/BFM_Fitting/01_MorphableModel.mat ADDED
File without changes
checkpoints/BFM_Fitting/BFM09_model_info.mat ADDED
File without changes
checkpoints/BFM_Fitting/BFM_exp_idx.mat ADDED
File without changes
checkpoints/BFM_Fitting/BFM_front_idx.mat ADDED
File without changes
checkpoints/BFM_Fitting/facemodel_info.mat ADDED
File without changes
checkpoints/BFM_Fitting/select_vertex_id.mat ADDED
File without changes
checkpoints/BFM_Fitting/similarity_Lm3D_all.mat ADDED
File without changes
checkpoints/BFM_Fitting/std_exp.txt ADDED
File without changes
checkpoints/shape_predictor_68_face_landmarks.dat ADDED
File without changes
commons.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+
8
+ def init_weights(m, mean=0.0, std=0.01):
9
+ classname = m.__class__.__name__
10
+ if classname.find("Conv") != -1:
11
+ m.weight.data.normal_(mean, std)
12
+
13
+
14
+ def get_padding(kernel_size, dilation=1):
15
+ return int((kernel_size*dilation - dilation)/2)
16
+
17
+
18
+ def convert_pad_shape(pad_shape):
19
+ l = pad_shape[::-1]
20
+ pad_shape = [item for sublist in l for item in sublist]
21
+ return pad_shape
22
+
23
+
24
+ def intersperse(lst, item):
25
+ result = [item] * (len(lst) * 2 + 1)
26
+ result[1::2] = lst
27
+ return result
28
+
29
+
30
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
31
+ """KL(P||Q)"""
32
+ kl = (logs_q - logs_p) - 0.5
33
+ kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
34
+ return kl
35
+
36
+
37
+ def rand_gumbel(shape):
38
+ """Sample from the Gumbel distribution, protect from overflows."""
39
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
40
+ return -torch.log(-torch.log(uniform_samples))
41
+
42
+
43
+ def rand_gumbel_like(x):
44
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
45
+ return g
46
+
47
+
48
+ def slice_segments(x, ids_str, segment_size=4):
49
+ ret = torch.zeros_like(x[:, :, :segment_size])
50
+ for i in range(x.size(0)):
51
+ idx_str = ids_str[i]
52
+ idx_end = idx_str + segment_size
53
+ ret[i] = x[i, :, idx_str:idx_end]
54
+ return ret
55
+
56
+
57
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
58
+ b, d, t = x.size()
59
+ if x_lengths is None:
60
+ x_lengths = t
61
+ ids_str_max = x_lengths - segment_size + 1
62
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
63
+ ret = slice_segments(x, ids_str, segment_size)
64
+ return ret, ids_str
65
+
66
+
67
+ def rand_spec_segments(x, x_lengths=None, segment_size=4):
68
+ b, d, t = x.size()
69
+ if x_lengths is None:
70
+ x_lengths = t
71
+ ids_str_max = x_lengths - segment_size
72
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
73
+ ret = slice_segments(x, ids_str, segment_size)
74
+ return ret, ids_str
75
+
76
+
77
+ def get_timing_signal_1d(
78
+ length, channels, min_timescale=1.0, max_timescale=1.0e4):
79
+ position = torch.arange(length, dtype=torch.float)
80
+ num_timescales = channels // 2
81
+ log_timescale_increment = (
82
+ math.log(float(max_timescale) / float(min_timescale)) /
83
+ (num_timescales - 1))
84
+ inv_timescales = min_timescale * torch.exp(
85
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
86
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
87
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
88
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
89
+ signal = signal.view(1, channels, length)
90
+ return signal
91
+
92
+
93
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
94
+ b, channels, length = x.size()
95
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
96
+ return x + signal.to(dtype=x.dtype, device=x.device)
97
+
98
+
99
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
100
+ b, channels, length = x.size()
101
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
102
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
103
+
104
+
105
+ def subsequent_mask(length):
106
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
107
+ return mask
108
+
109
+
110
+ @torch.jit.script
111
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
112
+ n_channels_int = n_channels[0]
113
+ in_act = input_a + input_b
114
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
115
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
116
+ acts = t_act * s_act
117
+ return acts
118
+
119
+
120
+ def convert_pad_shape(pad_shape):
121
+ l = pad_shape[::-1]
122
+ pad_shape = [item for sublist in l for item in sublist]
123
+ return pad_shape
124
+
125
+
126
+ def shift_1d(x):
127
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
128
+ return x
129
+
130
+
131
+ def sequence_mask(length, max_length=None):
132
+ if max_length is None:
133
+ max_length = length.max()
134
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
135
+ return x.unsqueeze(0) < length.unsqueeze(1)
136
+
137
+
138
+ def generate_path(duration, mask):
139
+ """
140
+ duration: [b, 1, t_x]
141
+ mask: [b, 1, t_y, t_x]
142
+ """
143
+ device = duration.device
144
+
145
+ b, _, t_y, t_x = mask.shape
146
+ cum_duration = torch.cumsum(duration, -1)
147
+
148
+ cum_duration_flat = cum_duration.view(b * t_x)
149
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
150
+ path = path.view(b, t_x, t_y)
151
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
152
+ path = path.unsqueeze(1).transpose(2,3) * mask
153
+ return path
154
+
155
+
156
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
157
+ if isinstance(parameters, torch.Tensor):
158
+ parameters = [parameters]
159
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
160
+ norm_type = float(norm_type)
161
+ if clip_value is not None:
162
+ clip_value = float(clip_value)
163
+
164
+ total_norm = 0
165
+ for p in parameters:
166
+ param_norm = p.grad.data.norm(norm_type)
167
+ total_norm += param_norm.item() ** norm_type
168
+ if clip_value is not None:
169
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
170
+ total_norm = total_norm ** (1. / norm_type)
171
+ return total_norm
configs/freevc-24.json ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "eval_interval": 10000,
5
+ "seed": 1234,
6
+ "epochs": 10000,
7
+ "learning_rate": 2e-4,
8
+ "betas": [0.8, 0.99],
9
+ "eps": 1e-9,
10
+ "batch_size": 64,
11
+ "fp16_run": false,
12
+ "lr_decay": 0.999875,
13
+ "segment_size": 8640,
14
+ "init_lr_ratio": 1,
15
+ "warmup_epochs": 0,
16
+ "c_mel": 45,
17
+ "c_kl": 1.0,
18
+ "use_sr": true,
19
+ "max_speclen": 128,
20
+ "port": "8008"
21
+ },
22
+ "data": {
23
+ "training_files":"filelists/train.txt",
24
+ "validation_files":"filelists/val.txt",
25
+ "max_wav_value": 32768.0,
26
+ "sampling_rate": 16000,
27
+ "filter_length": 1280,
28
+ "hop_length": 320,
29
+ "win_length": 1280,
30
+ "n_mel_channels": 80,
31
+ "mel_fmin": 0.0,
32
+ "mel_fmax": null
33
+ },
34
+ "model": {
35
+ "inter_channels": 192,
36
+ "hidden_channels": 192,
37
+ "filter_channels": 768,
38
+ "n_heads": 2,
39
+ "n_layers": 6,
40
+ "kernel_size": 3,
41
+ "p_dropout": 0.1,
42
+ "resblock": "1",
43
+ "resblock_kernel_sizes": [3,7,11],
44
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
45
+ "upsample_rates": [10,6,4,2],
46
+ "upsample_initial_channel": 512,
47
+ "upsample_kernel_sizes": [16,16,4,4],
48
+ "n_layers_q": 3,
49
+ "use_spectral_norm": false,
50
+ "gin_channels": 256,
51
+ "ssl_dim": 1024,
52
+ "use_spk": true
53
+ }
54
+ }
mel_processing.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import random
4
+ import torch
5
+ from torch import nn
6
+ import torch.nn.functional as F
7
+ import torch.utils.data
8
+ import numpy as np
9
+ import librosa
10
+ import librosa.util as librosa_util
11
+ from librosa.util import normalize, pad_center, tiny
12
+ from scipy.signal import get_window
13
+ from scipy.io.wavfile import read
14
+ from librosa.filters import mel as librosa_mel_fn
15
+
16
+ MAX_WAV_VALUE = 32768.0
17
+
18
+
19
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
20
+ """
21
+ PARAMS
22
+ ------
23
+ C: compression factor
24
+ """
25
+ return torch.log(torch.clamp(x, min=clip_val) * C)
26
+
27
+
28
+ def dynamic_range_decompression_torch(x, C=1):
29
+ """
30
+ PARAMS
31
+ ------
32
+ C: compression factor used to compress
33
+ """
34
+ return torch.exp(x) / C
35
+
36
+
37
+ def spectral_normalize_torch(magnitudes):
38
+ output = dynamic_range_compression_torch(magnitudes)
39
+ return output
40
+
41
+
42
+ def spectral_de_normalize_torch(magnitudes):
43
+ output = dynamic_range_decompression_torch(magnitudes)
44
+ return output
45
+
46
+
47
+ mel_basis = {}
48
+ hann_window = {}
49
+
50
+
51
+ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
52
+ if torch.min(y) < -1.:
53
+ print('min value is ', torch.min(y))
54
+ if torch.max(y) > 1.:
55
+ print('max value is ', torch.max(y))
56
+
57
+ global hann_window
58
+ dtype_device = str(y.dtype) + '_' + str(y.device)
59
+ wnsize_dtype_device = str(win_size) + '_' + dtype_device
60
+ if wnsize_dtype_device not in hann_window:
61
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
62
+
63
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
64
+ y = y.squeeze(1)
65
+
66
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
67
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
68
+
69
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
70
+ return spec
71
+
72
+
73
+ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
74
+ global mel_basis
75
+ dtype_device = str(spec.dtype) + '_' + str(spec.device)
76
+ fmax_dtype_device = str(fmax) + '_' + dtype_device
77
+ if fmax_dtype_device not in mel_basis:
78
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
79
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
80
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
81
+ spec = spectral_normalize_torch(spec)
82
+ return spec
83
+
84
+
85
+ def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
86
+ if torch.min(y) < -1.:
87
+ print('min value is ', torch.min(y))
88
+ if torch.max(y) > 1.:
89
+ print('max value is ', torch.max(y))
90
+
91
+ global mel_basis, hann_window
92
+ dtype_device = str(y.dtype) + '_' + str(y.device)
93
+ fmax_dtype_device = str(fmax) + '_' + dtype_device
94
+ wnsize_dtype_device = str(win_size) + '_' + dtype_device
95
+ if fmax_dtype_device not in mel_basis:
96
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
97
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
98
+ if wnsize_dtype_device not in hann_window:
99
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
100
+
101
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
102
+ y = y.squeeze(1)
103
+
104
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
105
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
106
+
107
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
108
+
109
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
110
+ spec = spectral_normalize_torch(spec)
111
+
112
+ return spec
models.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ import commons
8
+ import modules
9
+
10
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
+ from commons import init_weights, get_padding
13
+
14
+
15
+ class ResidualCouplingBlock(nn.Module):
16
+ def __init__(self,
17
+ channels,
18
+ hidden_channels,
19
+ kernel_size,
20
+ dilation_rate,
21
+ n_layers,
22
+ n_flows=4,
23
+ gin_channels=0):
24
+ super().__init__()
25
+ self.channels = channels
26
+ self.hidden_channels = hidden_channels
27
+ self.kernel_size = kernel_size
28
+ self.dilation_rate = dilation_rate
29
+ self.n_layers = n_layers
30
+ self.n_flows = n_flows
31
+ self.gin_channels = gin_channels
32
+
33
+ self.flows = nn.ModuleList()
34
+ for i in range(n_flows):
35
+ self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
36
+ self.flows.append(modules.Flip())
37
+
38
+ def forward(self, x, x_mask, g=None, reverse=False):
39
+ if not reverse:
40
+ for flow in self.flows:
41
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
42
+ else:
43
+ for flow in reversed(self.flows):
44
+ x = flow(x, x_mask, g=g, reverse=reverse)
45
+ return x
46
+
47
+
48
+ class Encoder(nn.Module):
49
+ def __init__(self,
50
+ in_channels,
51
+ out_channels,
52
+ hidden_channels,
53
+ kernel_size,
54
+ dilation_rate,
55
+ n_layers,
56
+ gin_channels=0):
57
+ super().__init__()
58
+ self.in_channels = in_channels
59
+ self.out_channels = out_channels
60
+ self.hidden_channels = hidden_channels
61
+ self.kernel_size = kernel_size
62
+ self.dilation_rate = dilation_rate
63
+ self.n_layers = n_layers
64
+ self.gin_channels = gin_channels
65
+
66
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
67
+ self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
68
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
69
+
70
+ def forward(self, x, x_lengths, g=None):
71
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
72
+ x = self.pre(x) * x_mask
73
+ x = self.enc(x, x_mask, g=g)
74
+ stats = self.proj(x) * x_mask
75
+ m, logs = torch.split(stats, self.out_channels, dim=1)
76
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
77
+ return z, m, logs, x_mask
78
+
79
+
80
+ class Generator(torch.nn.Module):
81
+ def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
82
+ super(Generator, self).__init__()
83
+ self.num_kernels = len(resblock_kernel_sizes)
84
+ self.num_upsamples = len(upsample_rates)
85
+ self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
86
+ resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
87
+
88
+ self.ups = nn.ModuleList()
89
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
90
+ self.ups.append(weight_norm(
91
+ ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
92
+ k, u, padding=(k-u)//2)))
93
+
94
+ self.resblocks = nn.ModuleList()
95
+ for i in range(len(self.ups)):
96
+ ch = upsample_initial_channel//(2**(i+1))
97
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
98
+ self.resblocks.append(resblock(ch, k, d))
99
+
100
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
101
+ self.ups.apply(init_weights)
102
+
103
+ if gin_channels != 0:
104
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
105
+
106
+ def forward(self, x, g=None):
107
+ x = self.conv_pre(x)
108
+ if g is not None:
109
+ x = x + self.cond(g)
110
+
111
+ for i in range(self.num_upsamples):
112
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
113
+ x = self.ups[i](x)
114
+ xs = None
115
+ for j in range(self.num_kernels):
116
+ if xs is None:
117
+ xs = self.resblocks[i*self.num_kernels+j](x)
118
+ else:
119
+ xs += self.resblocks[i*self.num_kernels+j](x)
120
+ x = xs / self.num_kernels
121
+ x = F.leaky_relu(x)
122
+ x = self.conv_post(x)
123
+ x = torch.tanh(x)
124
+
125
+ return x
126
+
127
+ def remove_weight_norm(self):
128
+ print('Removing weight norm...')
129
+ for l in self.ups:
130
+ remove_weight_norm(l)
131
+ for l in self.resblocks:
132
+ l.remove_weight_norm()
133
+
134
+
135
+ class DiscriminatorP(torch.nn.Module):
136
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
137
+ super(DiscriminatorP, self).__init__()
138
+ self.period = period
139
+ self.use_spectral_norm = use_spectral_norm
140
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
141
+ self.convs = nn.ModuleList([
142
+ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
143
+ norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
144
+ norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
145
+ norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
146
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
147
+ ])
148
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
149
+
150
+ def forward(self, x):
151
+ fmap = []
152
+
153
+ # 1d to 2d
154
+ b, c, t = x.shape
155
+ if t % self.period != 0: # pad first
156
+ n_pad = self.period - (t % self.period)
157
+ x = F.pad(x, (0, n_pad), "reflect")
158
+ t = t + n_pad
159
+ x = x.view(b, c, t // self.period, self.period)
160
+
161
+ for l in self.convs:
162
+ x = l(x)
163
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
164
+ fmap.append(x)
165
+ x = self.conv_post(x)
166
+ fmap.append(x)
167
+ x = torch.flatten(x, 1, -1)
168
+
169
+ return x, fmap
170
+
171
+
172
+ class DiscriminatorS(torch.nn.Module):
173
+ def __init__(self, use_spectral_norm=False):
174
+ super(DiscriminatorS, self).__init__()
175
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
176
+ self.convs = nn.ModuleList([
177
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
178
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
179
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
180
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
181
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
182
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
183
+ ])
184
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
185
+
186
+ def forward(self, x):
187
+ fmap = []
188
+
189
+ for l in self.convs:
190
+ x = l(x)
191
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
192
+ fmap.append(x)
193
+ x = self.conv_post(x)
194
+ fmap.append(x)
195
+ x = torch.flatten(x, 1, -1)
196
+
197
+ return x, fmap
198
+
199
+
200
+ class MultiPeriodDiscriminator(torch.nn.Module):
201
+ def __init__(self, use_spectral_norm=False):
202
+ super(MultiPeriodDiscriminator, self).__init__()
203
+ periods = [2,3,5,7,11]
204
+
205
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
206
+ discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
207
+ self.discriminators = nn.ModuleList(discs)
208
+
209
+ def forward(self, y, y_hat):
210
+ y_d_rs = []
211
+ y_d_gs = []
212
+ fmap_rs = []
213
+ fmap_gs = []
214
+ for i, d in enumerate(self.discriminators):
215
+ y_d_r, fmap_r = d(y)
216
+ y_d_g, fmap_g = d(y_hat)
217
+ y_d_rs.append(y_d_r)
218
+ y_d_gs.append(y_d_g)
219
+ fmap_rs.append(fmap_r)
220
+ fmap_gs.append(fmap_g)
221
+
222
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
223
+
224
+
225
+ class SpeakerEncoder(torch.nn.Module):
226
+ def __init__(self, mel_n_channels=80, model_num_layers=3, model_hidden_size=256, model_embedding_size=256):
227
+ super(SpeakerEncoder, self).__init__()
228
+ self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
229
+ self.linear = nn.Linear(model_hidden_size, model_embedding_size)
230
+ self.relu = nn.ReLU()
231
+
232
+ def forward(self, mels):
233
+ self.lstm.flatten_parameters()
234
+ _, (hidden, _) = self.lstm(mels)
235
+ embeds_raw = self.relu(self.linear(hidden[-1]))
236
+ return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
237
+
238
+ def compute_partial_slices(self, total_frames, partial_frames, partial_hop):
239
+ mel_slices = []
240
+ for i in range(0, total_frames-partial_frames, partial_hop):
241
+ mel_range = torch.arange(i, i+partial_frames)
242
+ mel_slices.append(mel_range)
243
+
244
+ return mel_slices
245
+
246
+ def embed_utterance(self, mel, partial_frames=128, partial_hop=64):
247
+ mel_len = mel.size(1)
248
+ last_mel = mel[:,-partial_frames:]
249
+
250
+ if mel_len > partial_frames:
251
+ mel_slices = self.compute_partial_slices(mel_len, partial_frames, partial_hop)
252
+ mels = list(mel[:,s] for s in mel_slices)
253
+ mels.append(last_mel)
254
+ mels = torch.stack(tuple(mels), 0).squeeze(1)
255
+
256
+ with torch.no_grad():
257
+ partial_embeds = self(mels)
258
+ embed = torch.mean(partial_embeds, axis=0).unsqueeze(0)
259
+ #embed = embed / torch.linalg.norm(embed, 2)
260
+ else:
261
+ with torch.no_grad():
262
+ embed = self(last_mel)
263
+
264
+ return embed
265
+
266
+
267
+ class SynthesizerTrn(nn.Module):
268
+ """
269
+ Synthesizer for Training
270
+ """
271
+
272
+ def __init__(self,
273
+ spec_channels,
274
+ segment_size,
275
+ inter_channels,
276
+ hidden_channels,
277
+ filter_channels,
278
+ n_heads,
279
+ n_layers,
280
+ kernel_size,
281
+ p_dropout,
282
+ resblock,
283
+ resblock_kernel_sizes,
284
+ resblock_dilation_sizes,
285
+ upsample_rates,
286
+ upsample_initial_channel,
287
+ upsample_kernel_sizes,
288
+ gin_channels,
289
+ ssl_dim,
290
+ use_spk,
291
+ **kwargs):
292
+
293
+ super().__init__()
294
+ self.spec_channels = spec_channels
295
+ self.inter_channels = inter_channels
296
+ self.hidden_channels = hidden_channels
297
+ self.filter_channels = filter_channels
298
+ self.n_heads = n_heads
299
+ self.n_layers = n_layers
300
+ self.kernel_size = kernel_size
301
+ self.p_dropout = p_dropout
302
+ self.resblock = resblock
303
+ self.resblock_kernel_sizes = resblock_kernel_sizes
304
+ self.resblock_dilation_sizes = resblock_dilation_sizes
305
+ self.upsample_rates = upsample_rates
306
+ self.upsample_initial_channel = upsample_initial_channel
307
+ self.upsample_kernel_sizes = upsample_kernel_sizes
308
+ self.segment_size = segment_size
309
+ self.gin_channels = gin_channels
310
+ self.ssl_dim = ssl_dim
311
+ self.use_spk = use_spk
312
+
313
+ self.enc_p = Encoder(ssl_dim, inter_channels, hidden_channels, 5, 1, 16)
314
+ self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
315
+ self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
316
+ self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
317
+
318
+ if not self.use_spk:
319
+ self.enc_spk = SpeakerEncoder(model_hidden_size=gin_channels, model_embedding_size=gin_channels)
320
+
321
+ def forward(self, c, spec, g=None, mel=None, c_lengths=None, spec_lengths=None):
322
+ if c_lengths == None:
323
+ c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
324
+ if spec_lengths == None:
325
+ spec_lengths = (torch.ones(spec.size(0)) * spec.size(-1)).to(spec.device)
326
+
327
+ if not self.use_spk:
328
+ g = self.enc_spk(mel.transpose(1,2))
329
+ g = g.unsqueeze(-1)
330
+
331
+ _, m_p, logs_p, _ = self.enc_p(c, c_lengths)
332
+ z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g)
333
+ z_p = self.flow(z, spec_mask, g=g)
334
+
335
+ z_slice, ids_slice = commons.rand_slice_segments(z, spec_lengths, self.segment_size)
336
+ o = self.dec(z_slice, g=g)
337
+
338
+ return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
339
+
340
+ def infer(self, c, g=None, mel=None, c_lengths=None):
341
+ if c_lengths == None:
342
+ c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
343
+ if not self.use_spk:
344
+ g = self.enc_spk.embed_utterance(mel.transpose(1,2))
345
+ g = g.unsqueeze(-1)
346
+
347
+ z_p, m_p, logs_p, c_mask = self.enc_p(c, c_lengths)
348
+ z = self.flow(z_p, c_mask, g=g, reverse=True)
349
+ o = self.dec(z * c_mask, g=g)
350
+
351
+ return o
modules.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import numpy as np
4
+ import scipy
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10
+ from torch.nn.utils import weight_norm, remove_weight_norm
11
+
12
+ import commons
13
+ from commons import init_weights, get_padding
14
+
15
+
16
+ LRELU_SLOPE = 0.1
17
+
18
+
19
+ class LayerNorm(nn.Module):
20
+ def __init__(self, channels, eps=1e-5):
21
+ super().__init__()
22
+ self.channels = channels
23
+ self.eps = eps
24
+
25
+ self.gamma = nn.Parameter(torch.ones(channels))
26
+ self.beta = nn.Parameter(torch.zeros(channels))
27
+
28
+ def forward(self, x):
29
+ x = x.transpose(1, -1)
30
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
31
+ return x.transpose(1, -1)
32
+
33
+
34
+ class ConvReluNorm(nn.Module):
35
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
36
+ super().__init__()
37
+ self.in_channels = in_channels
38
+ self.hidden_channels = hidden_channels
39
+ self.out_channels = out_channels
40
+ self.kernel_size = kernel_size
41
+ self.n_layers = n_layers
42
+ self.p_dropout = p_dropout
43
+ assert n_layers > 1, "Number of layers should be larger than 0."
44
+
45
+ self.conv_layers = nn.ModuleList()
46
+ self.norm_layers = nn.ModuleList()
47
+ self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
48
+ self.norm_layers.append(LayerNorm(hidden_channels))
49
+ self.relu_drop = nn.Sequential(
50
+ nn.ReLU(),
51
+ nn.Dropout(p_dropout))
52
+ for _ in range(n_layers-1):
53
+ self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
54
+ self.norm_layers.append(LayerNorm(hidden_channels))
55
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
56
+ self.proj.weight.data.zero_()
57
+ self.proj.bias.data.zero_()
58
+
59
+ def forward(self, x, x_mask):
60
+ x_org = x
61
+ for i in range(self.n_layers):
62
+ x = self.conv_layers[i](x * x_mask)
63
+ x = self.norm_layers[i](x)
64
+ x = self.relu_drop(x)
65
+ x = x_org + self.proj(x)
66
+ return x * x_mask
67
+
68
+
69
+ class DDSConv(nn.Module):
70
+ """
71
+ Dialted and Depth-Separable Convolution
72
+ """
73
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
74
+ super().__init__()
75
+ self.channels = channels
76
+ self.kernel_size = kernel_size
77
+ self.n_layers = n_layers
78
+ self.p_dropout = p_dropout
79
+
80
+ self.drop = nn.Dropout(p_dropout)
81
+ self.convs_sep = nn.ModuleList()
82
+ self.convs_1x1 = nn.ModuleList()
83
+ self.norms_1 = nn.ModuleList()
84
+ self.norms_2 = nn.ModuleList()
85
+ for i in range(n_layers):
86
+ dilation = kernel_size ** i
87
+ padding = (kernel_size * dilation - dilation) // 2
88
+ self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
89
+ groups=channels, dilation=dilation, padding=padding
90
+ ))
91
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
92
+ self.norms_1.append(LayerNorm(channels))
93
+ self.norms_2.append(LayerNorm(channels))
94
+
95
+ def forward(self, x, x_mask, g=None):
96
+ if g is not None:
97
+ x = x + g
98
+ for i in range(self.n_layers):
99
+ y = self.convs_sep[i](x * x_mask)
100
+ y = self.norms_1[i](y)
101
+ y = F.gelu(y)
102
+ y = self.convs_1x1[i](y)
103
+ y = self.norms_2[i](y)
104
+ y = F.gelu(y)
105
+ y = self.drop(y)
106
+ x = x + y
107
+ return x * x_mask
108
+
109
+
110
+ class WN(torch.nn.Module):
111
+ def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
112
+ super(WN, self).__init__()
113
+ assert(kernel_size % 2 == 1)
114
+ self.hidden_channels =hidden_channels
115
+ self.kernel_size = kernel_size,
116
+ self.dilation_rate = dilation_rate
117
+ self.n_layers = n_layers
118
+ self.gin_channels = gin_channels
119
+ self.p_dropout = p_dropout
120
+
121
+ self.in_layers = torch.nn.ModuleList()
122
+ self.res_skip_layers = torch.nn.ModuleList()
123
+ self.drop = nn.Dropout(p_dropout)
124
+
125
+ if gin_channels != 0:
126
+ cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
127
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
128
+
129
+ for i in range(n_layers):
130
+ dilation = dilation_rate ** i
131
+ padding = int((kernel_size * dilation - dilation) / 2)
132
+ in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
133
+ dilation=dilation, padding=padding)
134
+ in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
135
+ self.in_layers.append(in_layer)
136
+
137
+ # last one is not necessary
138
+ if i < n_layers - 1:
139
+ res_skip_channels = 2 * hidden_channels
140
+ else:
141
+ res_skip_channels = hidden_channels
142
+
143
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
144
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
145
+ self.res_skip_layers.append(res_skip_layer)
146
+
147
+ def forward(self, x, x_mask, g=None, **kwargs):
148
+ output = torch.zeros_like(x)
149
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
150
+
151
+ if g is not None:
152
+ g = self.cond_layer(g)
153
+
154
+ for i in range(self.n_layers):
155
+ x_in = self.in_layers[i](x)
156
+ if g is not None:
157
+ cond_offset = i * 2 * self.hidden_channels
158
+ g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
159
+ else:
160
+ g_l = torch.zeros_like(x_in)
161
+
162
+ acts = commons.fused_add_tanh_sigmoid_multiply(
163
+ x_in,
164
+ g_l,
165
+ n_channels_tensor)
166
+ acts = self.drop(acts)
167
+
168
+ res_skip_acts = self.res_skip_layers[i](acts)
169
+ if i < self.n_layers - 1:
170
+ res_acts = res_skip_acts[:,:self.hidden_channels,:]
171
+ x = (x + res_acts) * x_mask
172
+ output = output + res_skip_acts[:,self.hidden_channels:,:]
173
+ else:
174
+ output = output + res_skip_acts
175
+ return output * x_mask
176
+
177
+ def remove_weight_norm(self):
178
+ if self.gin_channels != 0:
179
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
180
+ for l in self.in_layers:
181
+ torch.nn.utils.remove_weight_norm(l)
182
+ for l in self.res_skip_layers:
183
+ torch.nn.utils.remove_weight_norm(l)
184
+
185
+
186
+ class ResBlock1(torch.nn.Module):
187
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
188
+ super(ResBlock1, self).__init__()
189
+ self.convs1 = nn.ModuleList([
190
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
191
+ padding=get_padding(kernel_size, dilation[0]))),
192
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
193
+ padding=get_padding(kernel_size, dilation[1]))),
194
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
195
+ padding=get_padding(kernel_size, dilation[2])))
196
+ ])
197
+ self.convs1.apply(init_weights)
198
+
199
+ self.convs2 = nn.ModuleList([
200
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
201
+ padding=get_padding(kernel_size, 1))),
202
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
203
+ padding=get_padding(kernel_size, 1))),
204
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
205
+ padding=get_padding(kernel_size, 1)))
206
+ ])
207
+ self.convs2.apply(init_weights)
208
+
209
+ def forward(self, x, x_mask=None):
210
+ for c1, c2 in zip(self.convs1, self.convs2):
211
+ xt = F.leaky_relu(x, LRELU_SLOPE)
212
+ if x_mask is not None:
213
+ xt = xt * x_mask
214
+ xt = c1(xt)
215
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
216
+ if x_mask is not None:
217
+ xt = xt * x_mask
218
+ xt = c2(xt)
219
+ x = xt + x
220
+ if x_mask is not None:
221
+ x = x * x_mask
222
+ return x
223
+
224
+ def remove_weight_norm(self):
225
+ for l in self.convs1:
226
+ remove_weight_norm(l)
227
+ for l in self.convs2:
228
+ remove_weight_norm(l)
229
+
230
+
231
+ class ResBlock2(torch.nn.Module):
232
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
233
+ super(ResBlock2, self).__init__()
234
+ self.convs = nn.ModuleList([
235
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
236
+ padding=get_padding(kernel_size, dilation[0]))),
237
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
238
+ padding=get_padding(kernel_size, dilation[1])))
239
+ ])
240
+ self.convs.apply(init_weights)
241
+
242
+ def forward(self, x, x_mask=None):
243
+ for c in self.convs:
244
+ xt = F.leaky_relu(x, LRELU_SLOPE)
245
+ if x_mask is not None:
246
+ xt = xt * x_mask
247
+ xt = c(xt)
248
+ x = xt + x
249
+ if x_mask is not None:
250
+ x = x * x_mask
251
+ return x
252
+
253
+ def remove_weight_norm(self):
254
+ for l in self.convs:
255
+ remove_weight_norm(l)
256
+
257
+
258
+ class Log(nn.Module):
259
+ def forward(self, x, x_mask, reverse=False, **kwargs):
260
+ if not reverse:
261
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
262
+ logdet = torch.sum(-y, [1, 2])
263
+ return y, logdet
264
+ else:
265
+ x = torch.exp(x) * x_mask
266
+ return x
267
+
268
+
269
+ class Flip(nn.Module):
270
+ def forward(self, x, *args, reverse=False, **kwargs):
271
+ x = torch.flip(x, [1])
272
+ if not reverse:
273
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
274
+ return x, logdet
275
+ else:
276
+ return x
277
+
278
+
279
+ class ElementwiseAffine(nn.Module):
280
+ def __init__(self, channels):
281
+ super().__init__()
282
+ self.channels = channels
283
+ self.m = nn.Parameter(torch.zeros(channels,1))
284
+ self.logs = nn.Parameter(torch.zeros(channels,1))
285
+
286
+ def forward(self, x, x_mask, reverse=False, **kwargs):
287
+ if not reverse:
288
+ y = self.m + torch.exp(self.logs) * x
289
+ y = y * x_mask
290
+ logdet = torch.sum(self.logs * x_mask, [1,2])
291
+ return y, logdet
292
+ else:
293
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
294
+ return x
295
+
296
+
297
+ class ResidualCouplingLayer(nn.Module):
298
+ def __init__(self,
299
+ channels,
300
+ hidden_channels,
301
+ kernel_size,
302
+ dilation_rate,
303
+ n_layers,
304
+ p_dropout=0,
305
+ gin_channels=0,
306
+ mean_only=False):
307
+ assert channels % 2 == 0, "channels should be divisible by 2"
308
+ super().__init__()
309
+ self.channels = channels
310
+ self.hidden_channels = hidden_channels
311
+ self.kernel_size = kernel_size
312
+ self.dilation_rate = dilation_rate
313
+ self.n_layers = n_layers
314
+ self.half_channels = channels // 2
315
+ self.mean_only = mean_only
316
+
317
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
318
+ self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
319
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
320
+ self.post.weight.data.zero_()
321
+ self.post.bias.data.zero_()
322
+
323
+ def forward(self, x, x_mask, g=None, reverse=False):
324
+ x0, x1 = torch.split(x, [self.half_channels]*2, 1)
325
+ h = self.pre(x0) * x_mask
326
+ h = self.enc(h, x_mask, g=g)
327
+ stats = self.post(h) * x_mask
328
+ if not self.mean_only:
329
+ m, logs = torch.split(stats, [self.half_channels]*2, 1)
330
+ else:
331
+ m = stats
332
+ logs = torch.zeros_like(m)
333
+
334
+ if not reverse:
335
+ x1 = m + x1 * torch.exp(logs) * x_mask
336
+ x = torch.cat([x0, x1], 1)
337
+ logdet = torch.sum(logs, [1,2])
338
+ return x, logdet
339
+ else:
340
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
341
+ x = torch.cat([x0, x1], 1)
342
+ return x
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ffmpeg
2
+ libsndfile1
requirements.txt CHANGED
@@ -1,9 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  protobuf
2
- transformers==4.30.2
3
  cpm_kernels
4
- torch>=2.0
5
- # gradio
6
  mdtex2html
7
  sentencepiece
8
  accelerate
9
- loguru
 
 
 
 
1
+ torch
2
+ torchvision
3
+ torchaudio
4
+ numpy==1.22.0
5
+ face_alignment==1.3.0
6
+ imageio==2.19.3
7
+ imageio-ffmpeg==0.4.7
8
+ librosa==0.8.1
9
+ numba
10
+ resampy==0.3.1
11
+ pydub==0.25.1
12
+ scipy
13
+ kornia==0.6.8
14
+ tqdm
15
+ yacs==0.1.8
16
+ pyyaml
17
+ joblib==1.1.0
18
+ scikit-image==0.19.3
19
+ basicsr==1.4.2
20
+ facexlib==0.3.0
21
+ dlib-bin
22
+ gfpgan
23
+ av
24
+ safetensors
25
+ transformers
26
+ webrtcvad==2.0.10
27
  protobuf
 
28
  cpm_kernels
 
 
29
  mdtex2html
30
  sentencepiece
31
  accelerate
32
+ loguru
33
+ edge_tts
34
+ altair
35
+ gradio==3.36.1
speaker_encoder/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
speaker_encoder/audio.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from scipy.ndimage.morphology import binary_dilation
2
+ from speaker_encoder.params_data import *
3
+ from pathlib import Path
4
+ from typing import Optional, Union
5
+ import numpy as np
6
+ import webrtcvad
7
+ import librosa
8
+ import struct
9
+
10
+ int16_max = (2 ** 15) - 1
11
+
12
+
13
+ def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
14
+ source_sr: Optional[int] = None):
15
+ """
16
+ Applies the preprocessing operations used in training the Speaker Encoder to a waveform
17
+ either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
18
+
19
+ :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
20
+ just .wav), either the waveform as a numpy array of floats.
21
+ :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
22
+ preprocessing. After preprocessing, the waveform's sampling rate will match the data
23
+ hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
24
+ this argument will be ignored.
25
+ """
26
+ # Load the wav from disk if needed
27
+ if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
28
+ wav, source_sr = librosa.load(fpath_or_wav, sr=None)
29
+ else:
30
+ wav = fpath_or_wav
31
+
32
+ # Resample the wav if needed
33
+ if source_sr is not None and source_sr != sampling_rate:
34
+ wav = librosa.resample(wav, source_sr, sampling_rate)
35
+
36
+ # Apply the preprocessing: normalize volume and shorten long silences
37
+ wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
38
+ wav = trim_long_silences(wav)
39
+
40
+ return wav
41
+
42
+
43
+ def wav_to_mel_spectrogram(wav):
44
+ """
45
+ Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
46
+ Note: this not a log-mel spectrogram.
47
+ """
48
+ frames = librosa.feature.melspectrogram(
49
+ y=wav,
50
+ sr=sampling_rate,
51
+ n_fft=int(sampling_rate * mel_window_length / 1000),
52
+ hop_length=int(sampling_rate * mel_window_step / 1000),
53
+ n_mels=mel_n_channels
54
+ )
55
+ return frames.astype(np.float32).T
56
+
57
+
58
+ def trim_long_silences(wav):
59
+ """
60
+ Ensures that segments without voice in the waveform remain no longer than a
61
+ threshold determined by the VAD parameters in params.py.
62
+
63
+ :param wav: the raw waveform as a numpy array of floats
64
+ :return: the same waveform with silences trimmed away (length <= original wav length)
65
+ """
66
+ # Compute the voice detection window size
67
+ samples_per_window = (vad_window_length * sampling_rate) // 1000
68
+
69
+ # Trim the end of the audio to have a multiple of the window size
70
+ wav = wav[:len(wav) - (len(wav) % samples_per_window)]
71
+
72
+ # Convert the float waveform to 16-bit mono PCM
73
+ pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
74
+
75
+ # Perform voice activation detection
76
+ voice_flags = []
77
+ vad = webrtcvad.Vad(mode=3)
78
+ for window_start in range(0, len(wav), samples_per_window):
79
+ window_end = window_start + samples_per_window
80
+ voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
81
+ sample_rate=sampling_rate))
82
+ voice_flags = np.array(voice_flags)
83
+
84
+ # Smooth the voice detection with a moving average
85
+ def moving_average(array, width):
86
+ array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
87
+ ret = np.cumsum(array_padded, dtype=float)
88
+ ret[width:] = ret[width:] - ret[:-width]
89
+ return ret[width - 1:] / width
90
+
91
+ audio_mask = moving_average(voice_flags, vad_moving_average_width)
92
+ audio_mask = np.round(audio_mask).astype(np.bool)
93
+
94
+ # Dilate the voiced regions
95
+ audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
96
+ audio_mask = np.repeat(audio_mask, samples_per_window)
97
+
98
+ return wav[audio_mask == True]
99
+
100
+
101
+ def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
102
+ if increase_only and decrease_only:
103
+ raise ValueError("Both increase only and decrease only are set")
104
+ dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
105
+ if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
106
+ return wav
107
+ return wav * (10 ** (dBFS_change / 20))
speaker_encoder/ckpt/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
speaker_encoder/ckpt/pretrained_bak_5805000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc7ff82ef75becd495aab2ede3a8220da393a717f178ae9534df355a6173bbca
3
+ size 17090379
speaker_encoder/compute_embed.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder import inference as encoder
2
+ from multiprocessing.pool import Pool
3
+ from functools import partial
4
+ from pathlib import Path
5
+ # from utils import logmmse
6
+ # from tqdm import tqdm
7
+ # import numpy as np
8
+ # import librosa
9
+
10
+
11
+ def embed_utterance(fpaths, encoder_model_fpath):
12
+ if not encoder.is_loaded():
13
+ encoder.load_model(encoder_model_fpath)
14
+
15
+ # Compute the speaker embedding of the utterance
16
+ wav_fpath, embed_fpath = fpaths
17
+ wav = np.load(wav_fpath)
18
+ wav = encoder.preprocess_wav(wav)
19
+ embed = encoder.embed_utterance(wav)
20
+ np.save(embed_fpath, embed, allow_pickle=False)
21
+
22
+
23
+ def create_embeddings(outdir_root: Path, wav_dir: Path, encoder_model_fpath: Path, n_processes: int):
24
+
25
+ wav_dir = outdir_root.joinpath("audio")
26
+ metadata_fpath = synthesizer_root.joinpath("train.txt")
27
+ assert wav_dir.exists() and metadata_fpath.exists()
28
+ embed_dir = synthesizer_root.joinpath("embeds")
29
+ embed_dir.mkdir(exist_ok=True)
30
+
31
+ # Gather the input wave filepath and the target output embed filepath
32
+ with metadata_fpath.open("r") as metadata_file:
33
+ metadata = [line.split("|") for line in metadata_file]
34
+ fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata]
35
+
36
+ # TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here.
37
+ # Embed the utterances in separate threads
38
+ func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath)
39
+ job = Pool(n_processes).imap(func, fpaths)
40
+ list(tqdm(job, "Embedding", len(fpaths), unit="utterances"))
speaker_encoder/config.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ librispeech_datasets = {
2
+ "train": {
3
+ "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
4
+ "other": ["LibriSpeech/train-other-500"]
5
+ },
6
+ "test": {
7
+ "clean": ["LibriSpeech/test-clean"],
8
+ "other": ["LibriSpeech/test-other"]
9
+ },
10
+ "dev": {
11
+ "clean": ["LibriSpeech/dev-clean"],
12
+ "other": ["LibriSpeech/dev-other"]
13
+ },
14
+ }
15
+ libritts_datasets = {
16
+ "train": {
17
+ "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
18
+ "other": ["LibriTTS/train-other-500"]
19
+ },
20
+ "test": {
21
+ "clean": ["LibriTTS/test-clean"],
22
+ "other": ["LibriTTS/test-other"]
23
+ },
24
+ "dev": {
25
+ "clean": ["LibriTTS/dev-clean"],
26
+ "other": ["LibriTTS/dev-other"]
27
+ },
28
+ }
29
+ voxceleb_datasets = {
30
+ "voxceleb1" : {
31
+ "train": ["VoxCeleb1/wav"],
32
+ "test": ["VoxCeleb1/test_wav"]
33
+ },
34
+ "voxceleb2" : {
35
+ "train": ["VoxCeleb2/dev/aac"],
36
+ "test": ["VoxCeleb2/test_wav"]
37
+ }
38
+ }
39
+
40
+ other_datasets = [
41
+ "LJSpeech-1.1",
42
+ "VCTK-Corpus/wav48",
43
+ ]
44
+
45
+ anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
speaker_encoder/data_objects/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
2
+ from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
speaker_encoder/data_objects/random_cycler.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ class RandomCycler:
4
+ """
5
+ Creates an internal copy of a sequence and allows access to its items in a constrained random
6
+ order. For a source sequence of n items and one or several consecutive queries of a total
7
+ of m items, the following guarantees hold (one implies the other):
8
+ - Each item will be returned between m // n and ((m - 1) // n) + 1 times.
9
+ - Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
10
+ """
11
+
12
+ def __init__(self, source):
13
+ if len(source) == 0:
14
+ raise Exception("Can't create RandomCycler from an empty collection")
15
+ self.all_items = list(source)
16
+ self.next_items = []
17
+
18
+ def sample(self, count: int):
19
+ shuffle = lambda l: random.sample(l, len(l))
20
+
21
+ out = []
22
+ while count > 0:
23
+ if count >= len(self.all_items):
24
+ out.extend(shuffle(list(self.all_items)))
25
+ count -= len(self.all_items)
26
+ continue
27
+ n = min(count, len(self.next_items))
28
+ out.extend(self.next_items[:n])
29
+ count -= n
30
+ self.next_items = self.next_items[n:]
31
+ if len(self.next_items) == 0:
32
+ self.next_items = shuffle(list(self.all_items))
33
+ return out
34
+
35
+ def __next__(self):
36
+ return self.sample(1)[0]
37
+
speaker_encoder/data_objects/speaker.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.data_objects.random_cycler import RandomCycler
2
+ from speaker_encoder.data_objects.utterance import Utterance
3
+ from pathlib import Path
4
+
5
+ # Contains the set of utterances of a single speaker
6
+ class Speaker:
7
+ def __init__(self, root: Path):
8
+ self.root = root
9
+ self.name = root.name
10
+ self.utterances = None
11
+ self.utterance_cycler = None
12
+
13
+ def _load_utterances(self):
14
+ with self.root.joinpath("_sources.txt").open("r") as sources_file:
15
+ sources = [l.split(",") for l in sources_file]
16
+ sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
17
+ self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
18
+ self.utterance_cycler = RandomCycler(self.utterances)
19
+
20
+ def random_partial(self, count, n_frames):
21
+ """
22
+ Samples a batch of <count> unique partial utterances from the disk in a way that all
23
+ utterances come up at least once every two cycles and in a random order every time.
24
+
25
+ :param count: The number of partial utterances to sample from the set of utterances from
26
+ that speaker. Utterances are guaranteed not to be repeated if <count> is not larger than
27
+ the number of utterances available.
28
+ :param n_frames: The number of frames in the partial utterance.
29
+ :return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
30
+ frames are the frames of the partial utterances and range is the range of the partial
31
+ utterance with regard to the complete utterance.
32
+ """
33
+ if self.utterances is None:
34
+ self._load_utterances()
35
+
36
+ utterances = self.utterance_cycler.sample(count)
37
+
38
+ a = [(u,) + u.random_partial(n_frames) for u in utterances]
39
+
40
+ return a
speaker_encoder/data_objects/speaker_batch.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from typing import List
3
+ from speaker_encoder.data_objects.speaker import Speaker
4
+
5
+ class SpeakerBatch:
6
+ def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
7
+ self.speakers = speakers
8
+ self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}
9
+
10
+ # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
11
+ # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
12
+ self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
speaker_encoder/data_objects/speaker_verification_dataset.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.data_objects.random_cycler import RandomCycler
2
+ from speaker_encoder.data_objects.speaker_batch import SpeakerBatch
3
+ from speaker_encoder.data_objects.speaker import Speaker
4
+ from speaker_encoder.params_data import partials_n_frames
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from pathlib import Path
7
+
8
+ # TODO: improve with a pool of speakers for data efficiency
9
+
10
+ class SpeakerVerificationDataset(Dataset):
11
+ def __init__(self, datasets_root: Path):
12
+ self.root = datasets_root
13
+ speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
14
+ if len(speaker_dirs) == 0:
15
+ raise Exception("No speakers found. Make sure you are pointing to the directory "
16
+ "containing all preprocessed speaker directories.")
17
+ self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
18
+ self.speaker_cycler = RandomCycler(self.speakers)
19
+
20
+ def __len__(self):
21
+ return int(1e10)
22
+
23
+ def __getitem__(self, index):
24
+ return next(self.speaker_cycler)
25
+
26
+ def get_logs(self):
27
+ log_string = ""
28
+ for log_fpath in self.root.glob("*.txt"):
29
+ with log_fpath.open("r") as log_file:
30
+ log_string += "".join(log_file.readlines())
31
+ return log_string
32
+
33
+
34
+ class SpeakerVerificationDataLoader(DataLoader):
35
+ def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
36
+ batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
37
+ worker_init_fn=None):
38
+ self.utterances_per_speaker = utterances_per_speaker
39
+
40
+ super().__init__(
41
+ dataset=dataset,
42
+ batch_size=speakers_per_batch,
43
+ shuffle=False,
44
+ sampler=sampler,
45
+ batch_sampler=batch_sampler,
46
+ num_workers=num_workers,
47
+ collate_fn=self.collate,
48
+ pin_memory=pin_memory,
49
+ drop_last=False,
50
+ timeout=timeout,
51
+ worker_init_fn=worker_init_fn
52
+ )
53
+
54
+ def collate(self, speakers):
55
+ return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
56
+
speaker_encoder/data_objects/utterance.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ class Utterance:
5
+ def __init__(self, frames_fpath, wave_fpath):
6
+ self.frames_fpath = frames_fpath
7
+ self.wave_fpath = wave_fpath
8
+
9
+ def get_frames(self):
10
+ return np.load(self.frames_fpath)
11
+
12
+ def random_partial(self, n_frames):
13
+ """
14
+ Crops the frames into a partial utterance of n_frames
15
+
16
+ :param n_frames: The number of frames of the partial utterance
17
+ :return: the partial utterance frames and a tuple indicating the start and end of the
18
+ partial utterance in the complete utterance.
19
+ """
20
+ frames = self.get_frames()
21
+ if frames.shape[0] == n_frames:
22
+ start = 0
23
+ else:
24
+ start = np.random.randint(0, frames.shape[0] - n_frames)
25
+ end = start + n_frames
26
+ return frames[start:end], (start, end)
speaker_encoder/hparams.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Mel-filterbank
2
+ mel_window_length = 25 # In milliseconds
3
+ mel_window_step = 10 # In milliseconds
4
+ mel_n_channels = 40
5
+
6
+
7
+ ## Audio
8
+ sampling_rate = 16000
9
+ # Number of spectrogram frames in a partial utterance
10
+ partials_n_frames = 160 # 1600 ms
11
+
12
+
13
+ ## Voice Activation Detection
14
+ # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
15
+ # This sets the granularity of the VAD. Should not need to be changed.
16
+ vad_window_length = 30 # In milliseconds
17
+ # Number of frames to average together when performing the moving average smoothing.
18
+ # The larger this value, the larger the VAD variations must be to not get smoothed out.
19
+ vad_moving_average_width = 8
20
+ # Maximum number of consecutive silent frames a segment can have.
21
+ vad_max_silence_length = 6
22
+
23
+
24
+ ## Audio volume normalization
25
+ audio_norm_target_dBFS = -30
26
+
27
+
28
+ ## Model parameters
29
+ model_hidden_size = 256
30
+ model_embedding_size = 256
31
+ model_num_layers = 3
speaker_encoder/inference.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.params_data import *
2
+ from speaker_encoder.model import SpeakerEncoder
3
+ from speaker_encoder.audio import preprocess_wav # We want to expose this function from here
4
+ from matplotlib import cm
5
+ from speaker_encoder import audio
6
+ from pathlib import Path
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import torch
10
+
11
+ _model = None # type: SpeakerEncoder
12
+ _device = None # type: torch.device
13
+
14
+
15
+ def load_model(weights_fpath: Path, device=None):
16
+ """
17
+ Loads the model in memory. If this function is not explicitely called, it will be run on the
18
+ first call to embed_frames() with the default weights file.
19
+
20
+ :param weights_fpath: the path to saved model weights.
21
+ :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
22
+ model will be loaded and will run on this device. Outputs will however always be on the cpu.
23
+ If None, will default to your GPU if it"s available, otherwise your CPU.
24
+ """
25
+ # TODO: I think the slow loading of the encoder might have something to do with the device it
26
+ # was saved on. Worth investigating.
27
+ global _model, _device
28
+ if device is None:
29
+ _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+ elif isinstance(device, str):
31
+ _device = torch.device(device)
32
+ _model = SpeakerEncoder(_device, torch.device("cpu"))
33
+ checkpoint = torch.load(weights_fpath)
34
+ _model.load_state_dict(checkpoint["model_state"])
35
+ _model.eval()
36
+ print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
37
+
38
+
39
+ def is_loaded():
40
+ return _model is not None
41
+
42
+
43
+ def embed_frames_batch(frames_batch):
44
+ """
45
+ Computes embeddings for a batch of mel spectrogram.
46
+
47
+ :param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape
48
+ (batch_size, n_frames, n_channels)
49
+ :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
50
+ """
51
+ if _model is None:
52
+ raise Exception("Model was not loaded. Call load_model() before inference.")
53
+
54
+ frames = torch.from_numpy(frames_batch).to(_device)
55
+ embed = _model.forward(frames).detach().cpu().numpy()
56
+ return embed
57
+
58
+
59
+ def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
60
+ min_pad_coverage=0.75, overlap=0.5):
61
+ """
62
+ Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
63
+ partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
64
+ spectrogram slices are returned, so as to make each partial utterance waveform correspond to
65
+ its spectrogram. This function assumes that the mel spectrogram parameters used are those
66
+ defined in params_data.py.
67
+
68
+ The returned ranges may be indexing further than the length of the waveform. It is
69
+ recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
70
+
71
+ :param n_samples: the number of samples in the waveform
72
+ :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
73
+ utterance
74
+ :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
75
+ enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
76
+ then the last partial utterance will be considered, as if we padded the audio. Otherwise,
77
+ it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
78
+ utterance, this parameter is ignored so that the function always returns at least 1 slice.
79
+ :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
80
+ utterances are entirely disjoint.
81
+ :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
82
+ respectively the waveform and the mel spectrogram with these slices to obtain the partial
83
+ utterances.
84
+ """
85
+ assert 0 <= overlap < 1
86
+ assert 0 < min_pad_coverage <= 1
87
+
88
+ samples_per_frame = int((sampling_rate * mel_window_step / 1000))
89
+ n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
90
+ frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
91
+
92
+ # Compute the slices
93
+ wav_slices, mel_slices = [], []
94
+ steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
95
+ for i in range(0, steps, frame_step):
96
+ mel_range = np.array([i, i + partial_utterance_n_frames])
97
+ wav_range = mel_range * samples_per_frame
98
+ mel_slices.append(slice(*mel_range))
99
+ wav_slices.append(slice(*wav_range))
100
+
101
+ # Evaluate whether extra padding is warranted or not
102
+ last_wav_range = wav_slices[-1]
103
+ coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
104
+ if coverage < min_pad_coverage and len(mel_slices) > 1:
105
+ mel_slices = mel_slices[:-1]
106
+ wav_slices = wav_slices[:-1]
107
+
108
+ return wav_slices, mel_slices
109
+
110
+
111
+ def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
112
+ """
113
+ Computes an embedding for a single utterance.
114
+
115
+ # TODO: handle multiple wavs to benefit from batching on GPU
116
+ :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
117
+ :param using_partials: if True, then the utterance is split in partial utterances of
118
+ <partial_utterance_n_frames> frames and the utterance embedding is computed from their
119
+ normalized average. If False, the utterance is instead computed from feeding the entire
120
+ spectogram to the network.
121
+ :param return_partials: if True, the partial embeddings will also be returned along with the
122
+ wav slices that correspond to the partial embeddings.
123
+ :param kwargs: additional arguments to compute_partial_splits()
124
+ :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
125
+ <return_partials> is True, the partial utterances as a numpy array of float32 of shape
126
+ (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
127
+ returned. If <using_partials> is simultaneously set to False, both these values will be None
128
+ instead.
129
+ """
130
+ # Process the entire utterance if not using partials
131
+ if not using_partials:
132
+ frames = audio.wav_to_mel_spectrogram(wav)
133
+ embed = embed_frames_batch(frames[None, ...])[0]
134
+ if return_partials:
135
+ return embed, None, None
136
+ return embed
137
+
138
+ # Compute where to split the utterance into partials and pad if necessary
139
+ wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
140
+ max_wave_length = wave_slices[-1].stop
141
+ if max_wave_length >= len(wav):
142
+ wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
143
+
144
+ # Split the utterance into partials
145
+ frames = audio.wav_to_mel_spectrogram(wav)
146
+ frames_batch = np.array([frames[s] for s in mel_slices])
147
+ partial_embeds = embed_frames_batch(frames_batch)
148
+
149
+ # Compute the utterance embedding from the partial embeddings
150
+ raw_embed = np.mean(partial_embeds, axis=0)
151
+ embed = raw_embed / np.linalg.norm(raw_embed, 2)
152
+
153
+ if return_partials:
154
+ return embed, partial_embeds, wave_slices
155
+ return embed
156
+
157
+
158
+ def embed_speaker(wavs, **kwargs):
159
+ raise NotImplemented()
160
+
161
+
162
+ def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
163
+ if ax is None:
164
+ ax = plt.gca()
165
+
166
+ if shape is None:
167
+ height = int(np.sqrt(len(embed)))
168
+ shape = (height, -1)
169
+ embed = embed.reshape(shape)
170
+
171
+ cmap = cm.get_cmap()
172
+ mappable = ax.imshow(embed, cmap=cmap)
173
+ cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
174
+ cbar.set_clim(*color_range)
175
+
176
+ ax.set_xticks([]), ax.set_yticks([])
177
+ ax.set_title(title)
speaker_encoder/model.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.params_model import *
2
+ from speaker_encoder.params_data import *
3
+ from scipy.interpolate import interp1d
4
+ from sklearn.metrics import roc_curve
5
+ from torch.nn.utils import clip_grad_norm_
6
+ from scipy.optimize import brentq
7
+ from torch import nn
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
+ class SpeakerEncoder(nn.Module):
13
+ def __init__(self, device, loss_device):
14
+ super().__init__()
15
+ self.loss_device = loss_device
16
+
17
+ # Network defition
18
+ self.lstm = nn.LSTM(input_size=mel_n_channels, # 40
19
+ hidden_size=model_hidden_size, # 256
20
+ num_layers=model_num_layers, # 3
21
+ batch_first=True).to(device)
22
+ self.linear = nn.Linear(in_features=model_hidden_size,
23
+ out_features=model_embedding_size).to(device)
24
+ self.relu = torch.nn.ReLU().to(device)
25
+
26
+ # Cosine similarity scaling (with fixed initial parameter values)
27
+ self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
28
+ self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)
29
+
30
+ # Loss
31
+ self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
32
+
33
+ def do_gradient_ops(self):
34
+ # Gradient scale
35
+ self.similarity_weight.grad *= 0.01
36
+ self.similarity_bias.grad *= 0.01
37
+
38
+ # Gradient clipping
39
+ clip_grad_norm_(self.parameters(), 3, norm_type=2)
40
+
41
+ def forward(self, utterances, hidden_init=None):
42
+ """
43
+ Computes the embeddings of a batch of utterance spectrograms.
44
+
45
+ :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
46
+ (batch_size, n_frames, n_channels)
47
+ :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
48
+ batch_size, hidden_size). Will default to a tensor of zeros if None.
49
+ :return: the embeddings as a tensor of shape (batch_size, embedding_size)
50
+ """
51
+ # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
52
+ # and the final cell state.
53
+ out, (hidden, cell) = self.lstm(utterances, hidden_init)
54
+
55
+ # We take only the hidden state of the last layer
56
+ embeds_raw = self.relu(self.linear(hidden[-1]))
57
+
58
+ # L2-normalize it
59
+ embeds = embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
60
+
61
+ return embeds
62
+
63
+ def similarity_matrix(self, embeds):
64
+ """
65
+ Computes the similarity matrix according the section 2.1 of GE2E.
66
+
67
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
68
+ utterances_per_speaker, embedding_size)
69
+ :return: the similarity matrix as a tensor of shape (speakers_per_batch,
70
+ utterances_per_speaker, speakers_per_batch)
71
+ """
72
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
73
+
74
+ # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
75
+ centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
76
+ centroids_incl = centroids_incl.clone() / torch.norm(centroids_incl, dim=2, keepdim=True)
77
+
78
+ # Exclusive centroids (1 per utterance)
79
+ centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
80
+ centroids_excl /= (utterances_per_speaker - 1)
81
+ centroids_excl = centroids_excl.clone() / torch.norm(centroids_excl, dim=2, keepdim=True)
82
+
83
+ # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
84
+ # product of these vectors (which is just an element-wise multiplication reduced by a sum).
85
+ # We vectorize the computation for efficiency.
86
+ sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
87
+ speakers_per_batch).to(self.loss_device)
88
+ mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int)
89
+ for j in range(speakers_per_batch):
90
+ mask = np.where(mask_matrix[j])[0]
91
+ sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
92
+ sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
93
+
94
+ ## Even more vectorized version (slower maybe because of transpose)
95
+ # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
96
+ # ).to(self.loss_device)
97
+ # eye = np.eye(speakers_per_batch, dtype=np.int)
98
+ # mask = np.where(1 - eye)
99
+ # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
100
+ # mask = np.where(eye)
101
+ # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
102
+ # sim_matrix2 = sim_matrix2.transpose(1, 2)
103
+
104
+ sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
105
+ return sim_matrix
106
+
107
+ def loss(self, embeds):
108
+ """
109
+ Computes the softmax loss according the section 2.1 of GE2E.
110
+
111
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
112
+ utterances_per_speaker, embedding_size)
113
+ :return: the loss and the EER for this batch of embeddings.
114
+ """
115
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
116
+
117
+ # Loss
118
+ sim_matrix = self.similarity_matrix(embeds)
119
+ sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
120
+ speakers_per_batch))
121
+ ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
122
+ target = torch.from_numpy(ground_truth).long().to(self.loss_device)
123
+ loss = self.loss_fn(sim_matrix, target)
124
+
125
+ # EER (not backpropagated)
126
+ with torch.no_grad():
127
+ inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0]
128
+ labels = np.array([inv_argmax(i) for i in ground_truth])
129
+ preds = sim_matrix.detach().cpu().numpy()
130
+
131
+ # Snippet from https://yangcha.github.io/EER-ROC/
132
+ fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
133
+ eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
134
+
135
+ return loss, eer
speaker_encoder/params_data.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## Mel-filterbank
3
+ mel_window_length = 25 # In milliseconds
4
+ mel_window_step = 10 # In milliseconds
5
+ mel_n_channels = 40
6
+
7
+
8
+ ## Audio
9
+ sampling_rate = 16000
10
+ # Number of spectrogram frames in a partial utterance
11
+ partials_n_frames = 160 # 1600 ms
12
+ # Number of spectrogram frames at inference
13
+ inference_n_frames = 80 # 800 ms
14
+
15
+
16
+ ## Voice Activation Detection
17
+ # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
18
+ # This sets the granularity of the VAD. Should not need to be changed.
19
+ vad_window_length = 30 # In milliseconds
20
+ # Number of frames to average together when performing the moving average smoothing.
21
+ # The larger this value, the larger the VAD variations must be to not get smoothed out.
22
+ vad_moving_average_width = 8
23
+ # Maximum number of consecutive silent frames a segment can have.
24
+ vad_max_silence_length = 6
25
+
26
+
27
+ ## Audio volume normalization
28
+ audio_norm_target_dBFS = -30
29
+
speaker_encoder/params_model.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## Model parameters
3
+ model_hidden_size = 256
4
+ model_embedding_size = 256
5
+ model_num_layers = 3
6
+
7
+
8
+ ## Training parameters
9
+ learning_rate_init = 1e-4
10
+ speakers_per_batch = 64
11
+ utterances_per_speaker = 10
speaker_encoder/preprocess.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from multiprocess.pool import ThreadPool
2
+ from speaker_encoder.params_data import *
3
+ from speaker_encoder.config import librispeech_datasets, anglophone_nationalites
4
+ from datetime import datetime
5
+ from speaker_encoder import audio
6
+ from pathlib import Path
7
+ from tqdm import tqdm
8
+ import numpy as np
9
+
10
+
11
+ class DatasetLog:
12
+ """
13
+ Registers metadata about the dataset in a text file.
14
+ """
15
+ def __init__(self, root, name):
16
+ self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
17
+ self.sample_data = dict()
18
+
19
+ start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
20
+ self.write_line("Creating dataset %s on %s" % (name, start_time))
21
+ self.write_line("-----")
22
+ self._log_params()
23
+
24
+ def _log_params(self):
25
+ from speaker_encoder import params_data
26
+ self.write_line("Parameter values:")
27
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
28
+ value = getattr(params_data, param_name)
29
+ self.write_line("\t%s: %s" % (param_name, value))
30
+ self.write_line("-----")
31
+
32
+ def write_line(self, line):
33
+ self.text_file.write("%s\n" % line)
34
+
35
+ def add_sample(self, **kwargs):
36
+ for param_name, value in kwargs.items():
37
+ if not param_name in self.sample_data:
38
+ self.sample_data[param_name] = []
39
+ self.sample_data[param_name].append(value)
40
+
41
+ def finalize(self):
42
+ self.write_line("Statistics:")
43
+ for param_name, values in self.sample_data.items():
44
+ self.write_line("\t%s:" % param_name)
45
+ self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
46
+ self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
47
+ self.write_line("-----")
48
+ end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
49
+ self.write_line("Finished on %s" % end_time)
50
+ self.text_file.close()
51
+
52
+
53
+ def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
54
+ dataset_root = datasets_root.joinpath(dataset_name)
55
+ if not dataset_root.exists():
56
+ print("Couldn\'t find %s, skipping this dataset." % dataset_root)
57
+ return None, None
58
+ return dataset_root, DatasetLog(out_dir, dataset_name)
59
+
60
+
61
+ def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
62
+ skip_existing, logger):
63
+ print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
64
+
65
+ # Function to preprocess utterances for one speaker
66
+ def preprocess_speaker(speaker_dir: Path):
67
+ # Give a name to the speaker that includes its dataset
68
+ speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
69
+
70
+ # Create an output directory with that name, as well as a txt file containing a
71
+ # reference to each source file.
72
+ speaker_out_dir = out_dir.joinpath(speaker_name)
73
+ speaker_out_dir.mkdir(exist_ok=True)
74
+ sources_fpath = speaker_out_dir.joinpath("_sources.txt")
75
+
76
+ # There's a possibility that the preprocessing was interrupted earlier, check if
77
+ # there already is a sources file.
78
+ if sources_fpath.exists():
79
+ try:
80
+ with sources_fpath.open("r") as sources_file:
81
+ existing_fnames = {line.split(",")[0] for line in sources_file}
82
+ except:
83
+ existing_fnames = {}
84
+ else:
85
+ existing_fnames = {}
86
+
87
+ # Gather all audio files for that speaker recursively
88
+ sources_file = sources_fpath.open("a" if skip_existing else "w")
89
+ for in_fpath in speaker_dir.glob("**/*.%s" % extension):
90
+ # Check if the target output file already exists
91
+ out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
92
+ out_fname = out_fname.replace(".%s" % extension, ".npy")
93
+ if skip_existing and out_fname in existing_fnames:
94
+ continue
95
+
96
+ # Load and preprocess the waveform
97
+ wav = audio.preprocess_wav(in_fpath)
98
+ if len(wav) == 0:
99
+ continue
100
+
101
+ # Create the mel spectrogram, discard those that are too short
102
+ frames = audio.wav_to_mel_spectrogram(wav)
103
+ if len(frames) < partials_n_frames:
104
+ continue
105
+
106
+ out_fpath = speaker_out_dir.joinpath(out_fname)
107
+ np.save(out_fpath, frames)
108
+ logger.add_sample(duration=len(wav) / sampling_rate)
109
+ sources_file.write("%s,%s\n" % (out_fname, in_fpath))
110
+
111
+ sources_file.close()
112
+
113
+ # Process the utterances for each speaker
114
+ with ThreadPool(8) as pool:
115
+ list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
116
+ unit="speakers"))
117
+ logger.finalize()
118
+ print("Done preprocessing %s.\n" % dataset_name)
119
+
120
+
121
+ # Function to preprocess utterances for one speaker
122
+ def __preprocess_speaker(speaker_dir: Path, datasets_root: Path, out_dir: Path, extension: str, skip_existing: bool):
123
+ # Give a name to the speaker that includes its dataset
124
+ speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
125
+
126
+ # Create an output directory with that name, as well as a txt file containing a
127
+ # reference to each source file.
128
+ speaker_out_dir = out_dir.joinpath(speaker_name)
129
+ speaker_out_dir.mkdir(exist_ok=True)
130
+ sources_fpath = speaker_out_dir.joinpath("_sources.txt")
131
+
132
+ # There's a possibility that the preprocessing was interrupted earlier, check if
133
+ # there already is a sources file.
134
+ # if sources_fpath.exists():
135
+ # try:
136
+ # with sources_fpath.open("r") as sources_file:
137
+ # existing_fnames = {line.split(",")[0] for line in sources_file}
138
+ # except:
139
+ # existing_fnames = {}
140
+ # else:
141
+ # existing_fnames = {}
142
+ existing_fnames = {}
143
+ # Gather all audio files for that speaker recursively
144
+ sources_file = sources_fpath.open("a" if skip_existing else "w")
145
+
146
+ for in_fpath in speaker_dir.glob("**/*.%s" % extension):
147
+ # Check if the target output file already exists
148
+ out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
149
+ out_fname = out_fname.replace(".%s" % extension, ".npy")
150
+ if skip_existing and out_fname in existing_fnames:
151
+ continue
152
+
153
+ # Load and preprocess the waveform
154
+ wav = audio.preprocess_wav(in_fpath)
155
+ if len(wav) == 0:
156
+ continue
157
+
158
+ # Create the mel spectrogram, discard those that are too short
159
+ frames = audio.wav_to_mel_spectrogram(wav)
160
+ if len(frames) < partials_n_frames:
161
+ continue
162
+
163
+ out_fpath = speaker_out_dir.joinpath(out_fname)
164
+ np.save(out_fpath, frames)
165
+ # logger.add_sample(duration=len(wav) / sampling_rate)
166
+ sources_file.write("%s,%s\n" % (out_fname, in_fpath))
167
+
168
+ sources_file.close()
169
+ return len(wav)
170
+
171
+ def _preprocess_speaker_dirs_vox2(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
172
+ skip_existing, logger):
173
+ # from multiprocessing import Pool, cpu_count
174
+ from pathos.multiprocessing import ProcessingPool as Pool
175
+ # Function to preprocess utterances for one speaker
176
+ def __preprocess_speaker(speaker_dir: Path):
177
+ # Give a name to the speaker that includes its dataset
178
+ speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
179
+
180
+ # Create an output directory with that name, as well as a txt file containing a
181
+ # reference to each source file.
182
+ speaker_out_dir = out_dir.joinpath(speaker_name)
183
+ speaker_out_dir.mkdir(exist_ok=True)
184
+ sources_fpath = speaker_out_dir.joinpath("_sources.txt")
185
+
186
+ existing_fnames = {}
187
+ # Gather all audio files for that speaker recursively
188
+ sources_file = sources_fpath.open("a" if skip_existing else "w")
189
+ wav_lens = []
190
+ for in_fpath in speaker_dir.glob("**/*.%s" % extension):
191
+ # Check if the target output file already exists
192
+ out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
193
+ out_fname = out_fname.replace(".%s" % extension, ".npy")
194
+ if skip_existing and out_fname in existing_fnames:
195
+ continue
196
+
197
+ # Load and preprocess the waveform
198
+ wav = audio.preprocess_wav(in_fpath)
199
+ if len(wav) == 0:
200
+ continue
201
+
202
+ # Create the mel spectrogram, discard those that are too short
203
+ frames = audio.wav_to_mel_spectrogram(wav)
204
+ if len(frames) < partials_n_frames:
205
+ continue
206
+
207
+ out_fpath = speaker_out_dir.joinpath(out_fname)
208
+ np.save(out_fpath, frames)
209
+ # logger.add_sample(duration=len(wav) / sampling_rate)
210
+ sources_file.write("%s,%s\n" % (out_fname, in_fpath))
211
+ wav_lens.append(len(wav))
212
+ sources_file.close()
213
+ return wav_lens
214
+
215
+ print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
216
+ # Process the utterances for each speaker
217
+ # with ThreadPool(8) as pool:
218
+ # list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
219
+ # unit="speakers"))
220
+ pool = Pool(processes=20)
221
+ for i, wav_lens in enumerate(pool.map(__preprocess_speaker, speaker_dirs), 1):
222
+ for wav_len in wav_lens:
223
+ logger.add_sample(duration=wav_len / sampling_rate)
224
+ print(f'{i}/{len(speaker_dirs)} \r')
225
+
226
+ logger.finalize()
227
+ print("Done preprocessing %s.\n" % dataset_name)
228
+
229
+
230
+ def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
231
+ for dataset_name in librispeech_datasets["train"]["other"]:
232
+ # Initialize the preprocessing
233
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
234
+ if not dataset_root:
235
+ return
236
+
237
+ # Preprocess all speakers
238
+ speaker_dirs = list(dataset_root.glob("*"))
239
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "flac",
240
+ skip_existing, logger)
241
+
242
+
243
+ def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
244
+ # Initialize the preprocessing
245
+ dataset_name = "VoxCeleb1"
246
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
247
+ if not dataset_root:
248
+ return
249
+
250
+ # Get the contents of the meta file
251
+ with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
252
+ metadata = [line.split("\t") for line in metafile][1:]
253
+
254
+ # Select the ID and the nationality, filter out non-anglophone speakers
255
+ nationalities = {line[0]: line[3] for line in metadata}
256
+ # keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
257
+ # nationality.lower() in anglophone_nationalites]
258
+ keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items()]
259
+ print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
260
+ (len(keep_speaker_ids), len(nationalities)))
261
+
262
+ # Get the speaker directories for anglophone speakers only
263
+ speaker_dirs = dataset_root.joinpath("wav").glob("*")
264
+ speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
265
+ speaker_dir.name in keep_speaker_ids]
266
+ print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
267
+ (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))
268
+
269
+ # Preprocess all speakers
270
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav",
271
+ skip_existing, logger)
272
+
273
+
274
+ def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
275
+ # Initialize the preprocessing
276
+ dataset_name = "VoxCeleb2"
277
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
278
+ if not dataset_root:
279
+ return
280
+
281
+ # Get the speaker directories
282
+ # Preprocess all speakers
283
+ speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
284
+ _preprocess_speaker_dirs_vox2(speaker_dirs, dataset_name, datasets_root, out_dir, "m4a",
285
+ skip_existing, logger)
speaker_encoder/train.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.visualizations import Visualizations
2
+ from speaker_encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
3
+ from speaker_encoder.params_model import *
4
+ from speaker_encoder.model import SpeakerEncoder
5
+ from utils.profiler import Profiler
6
+ from pathlib import Path
7
+ import torch
8
+
9
+ def sync(device: torch.device):
10
+ # FIXME
11
+ return
12
+ # For correct profiling (cuda operations are async)
13
+ if device.type == "cuda":
14
+ torch.cuda.synchronize(device)
15
+
16
+ def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
17
+ backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
18
+ no_visdom: bool):
19
+ # Create a dataset and a dataloader
20
+ dataset = SpeakerVerificationDataset(clean_data_root)
21
+ loader = SpeakerVerificationDataLoader(
22
+ dataset,
23
+ speakers_per_batch, # 64
24
+ utterances_per_speaker, # 10
25
+ num_workers=8,
26
+ )
27
+
28
+ # Setup the device on which to run the forward pass and the loss. These can be different,
29
+ # because the forward pass is faster on the GPU whereas the loss is often (depending on your
30
+ # hyperparameters) faster on the CPU.
31
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+ # FIXME: currently, the gradient is None if loss_device is cuda
33
+ loss_device = torch.device("cpu")
34
+
35
+ # Create the model and the optimizer
36
+ model = SpeakerEncoder(device, loss_device)
37
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
38
+ init_step = 1
39
+
40
+ # Configure file path for the model
41
+ state_fpath = models_dir.joinpath(run_id + ".pt")
42
+ backup_dir = models_dir.joinpath(run_id + "_backups")
43
+
44
+ # Load any existing model
45
+ if not force_restart:
46
+ if state_fpath.exists():
47
+ print("Found existing model \"%s\", loading it and resuming training." % run_id)
48
+ checkpoint = torch.load(state_fpath)
49
+ init_step = checkpoint["step"]
50
+ model.load_state_dict(checkpoint["model_state"])
51
+ optimizer.load_state_dict(checkpoint["optimizer_state"])
52
+ optimizer.param_groups[0]["lr"] = learning_rate_init
53
+ else:
54
+ print("No model \"%s\" found, starting training from scratch." % run_id)
55
+ else:
56
+ print("Starting the training from scratch.")
57
+ model.train()
58
+
59
+ # Initialize the visualization environment
60
+ vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
61
+ vis.log_dataset(dataset)
62
+ vis.log_params()
63
+ device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
64
+ vis.log_implementation({"Device": device_name})
65
+
66
+ # Training loop
67
+ profiler = Profiler(summarize_every=10, disabled=False)
68
+ for step, speaker_batch in enumerate(loader, init_step):
69
+ profiler.tick("Blocking, waiting for batch (threaded)")
70
+
71
+ # Forward pass
72
+ inputs = torch.from_numpy(speaker_batch.data).to(device)
73
+ sync(device)
74
+ profiler.tick("Data to %s" % device)
75
+ embeds = model(inputs)
76
+ sync(device)
77
+ profiler.tick("Forward pass")
78
+ embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
79
+ loss, eer = model.loss(embeds_loss)
80
+ sync(loss_device)
81
+ profiler.tick("Loss")
82
+
83
+ # Backward pass
84
+ model.zero_grad()
85
+ loss.backward()
86
+ profiler.tick("Backward pass")
87
+ model.do_gradient_ops()
88
+ optimizer.step()
89
+ profiler.tick("Parameter update")
90
+
91
+ # Update visualizations
92
+ # learning_rate = optimizer.param_groups[0]["lr"]
93
+ vis.update(loss.item(), eer, step)
94
+
95
+ # Draw projections and save them to the backup folder
96
+ if umap_every != 0 and step % umap_every == 0:
97
+ print("Drawing and saving projections (step %d)" % step)
98
+ backup_dir.mkdir(exist_ok=True)
99
+ projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
100
+ embeds = embeds.detach().cpu().numpy()
101
+ vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
102
+ vis.save()
103
+
104
+ # Overwrite the latest version of the model
105
+ if save_every != 0 and step % save_every == 0:
106
+ print("Saving the model (step %d)" % step)
107
+ torch.save({
108
+ "step": step + 1,
109
+ "model_state": model.state_dict(),
110
+ "optimizer_state": optimizer.state_dict(),
111
+ }, state_fpath)
112
+
113
+ # Make a backup
114
+ if backup_every != 0 and step % backup_every == 0:
115
+ print("Making a backup (step %d)" % step)
116
+ backup_dir.mkdir(exist_ok=True)
117
+ backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step))
118
+ torch.save({
119
+ "step": step + 1,
120
+ "model_state": model.state_dict(),
121
+ "optimizer_state": optimizer.state_dict(),
122
+ }, backup_fpath)
123
+
124
+ profiler.tick("Extras (visualizations, saving)")
125
+
speaker_encoder/visualizations.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
2
+ from datetime import datetime
3
+ from time import perf_counter as timer
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ # import webbrowser
7
+ import visdom
8
+ import umap
9
+
10
+ colormap = np.array([
11
+ [76, 255, 0],
12
+ [0, 127, 70],
13
+ [255, 0, 0],
14
+ [255, 217, 38],
15
+ [0, 135, 255],
16
+ [165, 0, 165],
17
+ [255, 167, 255],
18
+ [0, 255, 255],
19
+ [255, 96, 38],
20
+ [142, 76, 0],
21
+ [33, 0, 127],
22
+ [0, 0, 0],
23
+ [183, 183, 183],
24
+ ], dtype=np.float) / 255
25
+
26
+
27
+ class Visualizations:
28
+ def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
29
+ # Tracking data
30
+ self.last_update_timestamp = timer()
31
+ self.update_every = update_every
32
+ self.step_times = []
33
+ self.losses = []
34
+ self.eers = []
35
+ print("Updating the visualizations every %d steps." % update_every)
36
+
37
+ # If visdom is disabled TODO: use a better paradigm for that
38
+ self.disabled = disabled
39
+ if self.disabled:
40
+ return
41
+
42
+ # Set the environment name
43
+ now = str(datetime.now().strftime("%d-%m %Hh%M"))
44
+ if env_name is None:
45
+ self.env_name = now
46
+ else:
47
+ self.env_name = "%s (%s)" % (env_name, now)
48
+
49
+ # Connect to visdom and open the corresponding window in the browser
50
+ try:
51
+ self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
52
+ except ConnectionError:
53
+ raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
54
+ "start it.")
55
+ # webbrowser.open("http://localhost:8097/env/" + self.env_name)
56
+
57
+ # Create the windows
58
+ self.loss_win = None
59
+ self.eer_win = None
60
+ # self.lr_win = None
61
+ self.implementation_win = None
62
+ self.projection_win = None
63
+ self.implementation_string = ""
64
+
65
+ def log_params(self):
66
+ if self.disabled:
67
+ return
68
+ from speaker_encoder import params_data
69
+ from speaker_encoder import params_model
70
+ param_string = "<b>Model parameters</b>:<br>"
71
+ for param_name in (p for p in dir(params_model) if not p.startswith("__")):
72
+ value = getattr(params_model, param_name)
73
+ param_string += "\t%s: %s<br>" % (param_name, value)
74
+ param_string += "<b>Data parameters</b>:<br>"
75
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
76
+ value = getattr(params_data, param_name)
77
+ param_string += "\t%s: %s<br>" % (param_name, value)
78
+ self.vis.text(param_string, opts={"title": "Parameters"})
79
+
80
+ def log_dataset(self, dataset: SpeakerVerificationDataset):
81
+ if self.disabled:
82
+ return
83
+ dataset_string = ""
84
+ dataset_string += "<b>Speakers</b>: %s\n" % len(dataset.speakers)
85
+ dataset_string += "\n" + dataset.get_logs()
86
+ dataset_string = dataset_string.replace("\n", "<br>")
87
+ self.vis.text(dataset_string, opts={"title": "Dataset"})
88
+
89
+ def log_implementation(self, params):
90
+ if self.disabled:
91
+ return
92
+ implementation_string = ""
93
+ for param, value in params.items():
94
+ implementation_string += "<b>%s</b>: %s\n" % (param, value)
95
+ implementation_string = implementation_string.replace("\n", "<br>")
96
+ self.implementation_string = implementation_string
97
+ self.implementation_win = self.vis.text(
98
+ implementation_string,
99
+ opts={"title": "Training implementation"}
100
+ )
101
+
102
+ def update(self, loss, eer, step):
103
+ # Update the tracking data
104
+ now = timer()
105
+ self.step_times.append(1000 * (now - self.last_update_timestamp))
106
+ self.last_update_timestamp = now
107
+ self.losses.append(loss)
108
+ self.eers.append(eer)
109
+ print(".", end="")
110
+
111
+ # Update the plots every <update_every> steps
112
+ if step % self.update_every != 0:
113
+ return
114
+ time_string = "Step time: mean: %5dms std: %5dms" % \
115
+ (int(np.mean(self.step_times)), int(np.std(self.step_times)))
116
+ print("\nStep %6d Loss: %.4f EER: %.4f %s" %
117
+ (step, np.mean(self.losses), np.mean(self.eers), time_string))
118
+ if not self.disabled:
119
+ self.loss_win = self.vis.line(
120
+ [np.mean(self.losses)],
121
+ [step],
122
+ win=self.loss_win,
123
+ update="append" if self.loss_win else None,
124
+ opts=dict(
125
+ legend=["Avg. loss"],
126
+ xlabel="Step",
127
+ ylabel="Loss",
128
+ title="Loss",
129
+ )
130
+ )
131
+ self.eer_win = self.vis.line(
132
+ [np.mean(self.eers)],
133
+ [step],
134
+ win=self.eer_win,
135
+ update="append" if self.eer_win else None,
136
+ opts=dict(
137
+ legend=["Avg. EER"],
138
+ xlabel="Step",
139
+ ylabel="EER",
140
+ title="Equal error rate"
141
+ )
142
+ )
143
+ if self.implementation_win is not None:
144
+ self.vis.text(
145
+ self.implementation_string + ("<b>%s</b>" % time_string),
146
+ win=self.implementation_win,
147
+ opts={"title": "Training implementation"},
148
+ )
149
+
150
+ # Reset the tracking
151
+ self.losses.clear()
152
+ self.eers.clear()
153
+ self.step_times.clear()
154
+
155
+ def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None,
156
+ max_speakers=10):
157
+ max_speakers = min(max_speakers, len(colormap))
158
+ embeds = embeds[:max_speakers * utterances_per_speaker]
159
+
160
+ n_speakers = len(embeds) // utterances_per_speaker
161
+ ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
162
+ colors = [colormap[i] for i in ground_truth]
163
+
164
+ reducer = umap.UMAP()
165
+ projected = reducer.fit_transform(embeds)
166
+ plt.scatter(projected[:, 0], projected[:, 1], c=colors)
167
+ plt.gca().set_aspect("equal", "datalim")
168
+ plt.title("UMAP projection (step %d)" % step)
169
+ if not self.disabled:
170
+ self.projection_win = self.vis.matplot(plt, win=self.projection_win)
171
+ if out_fpath is not None:
172
+ plt.savefig(out_fpath)
173
+ plt.clf()
174
+
175
+ def save(self):
176
+ if not self.disabled:
177
+ self.vis.save([self.env_name])
178
+
speaker_encoder/voice_encoder.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.hparams import *
2
+ from speaker_encoder import audio
3
+ from pathlib import Path
4
+ from typing import Union, List
5
+ from torch import nn
6
+ from time import perf_counter as timer
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
+ class SpeakerEncoder(nn.Module):
12
+ def __init__(self, weights_fpath, device: Union[str, torch.device]=None, verbose=True):
13
+ """
14
+ :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda").
15
+ If None, defaults to cuda if it is available on your machine, otherwise the model will
16
+ run on cpu. Outputs are always returned on the cpu, as numpy arrays.
17
+ """
18
+ super().__init__()
19
+
20
+ # Define the network
21
+ self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
22
+ self.linear = nn.Linear(model_hidden_size, model_embedding_size)
23
+ self.relu = nn.ReLU()
24
+
25
+ # Get the target device
26
+ if device is None:
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ elif isinstance(device, str):
29
+ device = torch.device(device)
30
+ self.device = device
31
+
32
+ # Load the pretrained model'speaker weights
33
+ # weights_fpath = Path(__file__).resolve().parent.joinpath("pretrained.pt")
34
+ # if not weights_fpath.exists():
35
+ # raise Exception("Couldn't find the voice encoder pretrained model at %s." %
36
+ # weights_fpath)
37
+
38
+ start = timer()
39
+ checkpoint = torch.load(weights_fpath, map_location="cpu")
40
+
41
+ self.load_state_dict(checkpoint["model_state"], strict=False)
42
+ self.to(device)
43
+
44
+ if verbose:
45
+ print("Loaded the voice encoder model on %s in %.2f seconds." %
46
+ (device.type, timer() - start))
47
+
48
+ def forward(self, mels: torch.FloatTensor):
49
+ """
50
+ Computes the embeddings of a batch of utterance spectrograms.
51
+ :param mels: a batch of mel spectrograms of same duration as a float32 tensor of shape
52
+ (batch_size, n_frames, n_channels)
53
+ :return: the embeddings as a float 32 tensor of shape (batch_size, embedding_size).
54
+ Embeddings are positive and L2-normed, thus they lay in the range [0, 1].
55
+ """
56
+ # Pass the input through the LSTM layers and retrieve the final hidden state of the last
57
+ # layer. Apply a cutoff to 0 for negative values and L2 normalize the embeddings.
58
+ _, (hidden, _) = self.lstm(mels)
59
+ embeds_raw = self.relu(self.linear(hidden[-1]))
60
+ return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
61
+
62
+ @staticmethod
63
+ def compute_partial_slices(n_samples: int, rate, min_coverage):
64
+ """
65
+ Computes where to split an utterance waveform and its corresponding mel spectrogram to
66
+ obtain partial utterances of <partials_n_frames> each. Both the waveform and the
67
+ mel spectrogram slices are returned, so as to make each partial utterance waveform
68
+ correspond to its spectrogram.
69
+
70
+ The returned ranges may be indexing further than the length of the waveform. It is
71
+ recommended that you pad the waveform with zeros up to wav_slices[-1].stop.
72
+
73
+ :param n_samples: the number of samples in the waveform
74
+ :param rate: how many partial utterances should occur per second. Partial utterances must
75
+ cover the span of the entire utterance, thus the rate should not be lower than the inverse
76
+ of the duration of a partial utterance. By default, partial utterances are 1.6s long and
77
+ the minimum rate is thus 0.625.
78
+ :param min_coverage: when reaching the last partial utterance, it may or may not have
79
+ enough frames. If at least <min_pad_coverage> of <partials_n_frames> are present,
80
+ then the last partial utterance will be considered by zero-padding the audio. Otherwise,
81
+ it will be discarded. If there aren't enough frames for one partial utterance,
82
+ this parameter is ignored so that the function always returns at least one slice.
83
+ :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
84
+ respectively the waveform and the mel spectrogram with these slices to obtain the partial
85
+ utterances.
86
+ """
87
+ assert 0 < min_coverage <= 1
88
+
89
+ # Compute how many frames separate two partial utterances
90
+ samples_per_frame = int((sampling_rate * mel_window_step / 1000))
91
+ n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
92
+ frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
93
+ assert 0 < frame_step, "The rate is too high"
94
+ assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \
95
+ (sampling_rate / (samples_per_frame * partials_n_frames))
96
+
97
+ # Compute the slices
98
+ wav_slices, mel_slices = [], []
99
+ steps = max(1, n_frames - partials_n_frames + frame_step + 1)
100
+ for i in range(0, steps, frame_step):
101
+ mel_range = np.array([i, i + partials_n_frames])
102
+ wav_range = mel_range * samples_per_frame
103
+ mel_slices.append(slice(*mel_range))
104
+ wav_slices.append(slice(*wav_range))
105
+
106
+ # Evaluate whether extra padding is warranted or not
107
+ last_wav_range = wav_slices[-1]
108
+ coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
109
+ if coverage < min_coverage and len(mel_slices) > 1:
110
+ mel_slices = mel_slices[:-1]
111
+ wav_slices = wav_slices[:-1]
112
+
113
+ return wav_slices, mel_slices
114
+
115
+ def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75):
116
+ """
117
+ Computes an embedding for a single utterance. The utterance is divided in partial
118
+ utterances and an embedding is computed for each. The complete utterance embedding is the
119
+ L2-normed average embedding of the partial utterances.
120
+
121
+ TODO: independent batched version of this function
122
+
123
+ :param wav: a preprocessed utterance waveform as a numpy array of float32
124
+ :param return_partials: if True, the partial embeddings will also be returned along with
125
+ the wav slices corresponding to each partial utterance.
126
+ :param rate: how many partial utterances should occur per second. Partial utterances must
127
+ cover the span of the entire utterance, thus the rate should not be lower than the inverse
128
+ of the duration of a partial utterance. By default, partial utterances are 1.6s long and
129
+ the minimum rate is thus 0.625.
130
+ :param min_coverage: when reaching the last partial utterance, it may or may not have
131
+ enough frames. If at least <min_pad_coverage> of <partials_n_frames> are present,
132
+ then the last partial utterance will be considered by zero-padding the audio. Otherwise,
133
+ it will be discarded. If there aren't enough frames for one partial utterance,
134
+ this parameter is ignored so that the function always returns at least one slice.
135
+ :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
136
+ <return_partials> is True, the partial utterances as a numpy array of float32 of shape
137
+ (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
138
+ returned.
139
+ """
140
+ # Compute where to split the utterance into partials and pad the waveform with zeros if
141
+ # the partial utterances cover a larger range.
142
+ wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage)
143
+ max_wave_length = wav_slices[-1].stop
144
+ if max_wave_length >= len(wav):
145
+ wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
146
+
147
+ # Split the utterance into partials and forward them through the model
148
+ mel = audio.wav_to_mel_spectrogram(wav)
149
+ mels = np.array([mel[s] for s in mel_slices])
150
+ with torch.no_grad():
151
+ mels = torch.from_numpy(mels).to(self.device)
152
+ partial_embeds = self(mels).cpu().numpy()
153
+
154
+ # Compute the utterance embedding from the partial embeddings
155
+ raw_embed = np.mean(partial_embeds, axis=0)
156
+ embed = raw_embed / np.linalg.norm(raw_embed, 2)
157
+
158
+ if return_partials:
159
+ return embed, partial_embeds, wav_slices
160
+ return embed
161
+
162
+ def embed_speaker(self, wavs: List[np.ndarray], **kwargs):
163
+ """
164
+ Compute the embedding of a collection of wavs (presumably from the same speaker) by
165
+ averaging their embedding and L2-normalizing it.
166
+
167
+ :param wavs: list of wavs a numpy arrays of float32.
168
+ :param kwargs: extra arguments to embed_utterance()
169
+ :return: the embedding as a numpy array of float32 of shape (model_embedding_size,).
170
+ """
171
+ raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) \
172
+ for wav in wavs], axis=0)
173
+ return raw_embed / np.linalg.norm(raw_embed, 2)
src/audio2exp_models/audio2exp.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import tqdm
2
+ import torch
3
+ from torch import nn
4
+
5
+
6
+ class Audio2Exp(nn.Module):
7
+ def __init__(self, netG, cfg, device, prepare_training_loss=False):
8
+ super(Audio2Exp, self).__init__()
9
+ self.cfg = cfg
10
+ self.device = device
11
+ self.netG = netG.to(device)
12
+
13
+ def test(self, batch):
14
+
15
+ mel_input = batch['indiv_mels'] # bs T 1 80 16
16
+ bs = mel_input.shape[0]
17
+ T = mel_input.shape[1]
18
+
19
+ exp_coeff_pred = []
20
+
21
+ for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames
22
+
23
+ current_mel_input = mel_input[:,i:i+10]
24
+
25
+ #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64
26
+ ref = batch['ref'][:, :, :64][:, i:i+10]
27
+ ratio = batch['ratio_gt'][:, i:i+10] #bs T
28
+
29
+ audiox = current_mel_input.view(-1, 1, 80, 16) # bs*T 1 80 16
30
+
31
+ curr_exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64
32
+
33
+ exp_coeff_pred += [curr_exp_coeff_pred]
34
+
35
+ # BS x T x 64
36
+ results_dict = {
37
+ 'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1)
38
+ }
39
+ return results_dict
40
+
41
+
src/audio2exp_models/networks.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+ class Conv2d(nn.Module):
6
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs):
7
+ super().__init__(*args, **kwargs)
8
+ self.conv_block = nn.Sequential(
9
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
10
+ nn.BatchNorm2d(cout)
11
+ )
12
+ self.act = nn.ReLU()
13
+ self.residual = residual
14
+ self.use_act = use_act
15
+
16
+ def forward(self, x):
17
+ out = self.conv_block(x)
18
+ if self.residual:
19
+ out += x
20
+
21
+ if self.use_act:
22
+ return self.act(out)
23
+ else:
24
+ return out
25
+
26
+ class SimpleWrapperV2(nn.Module):
27
+ def __init__(self) -> None:
28
+ super().__init__()
29
+ self.audio_encoder = nn.Sequential(
30
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
31
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
32
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
33
+
34
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
35
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
36
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
37
+
38
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
39
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
40
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
41
+
42
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
43
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
44
+
45
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
46
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
47
+ )
48
+
49
+ #### load the pre-trained audio_encoder
50
+ #self.audio_encoder = self.audio_encoder.to(device)
51
+ '''
52
+ wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
53
+ state_dict = self.audio_encoder.state_dict()
54
+
55
+ for k,v in wav2lip_state_dict.items():
56
+ if 'audio_encoder' in k:
57
+ print('init:', k)
58
+ state_dict[k.replace('module.audio_encoder.', '')] = v
59
+ self.audio_encoder.load_state_dict(state_dict)
60
+ '''
61
+
62
+ self.mapping1 = nn.Linear(512+64+1, 64)
63
+ #self.mapping2 = nn.Linear(30, 64)
64
+ #nn.init.constant_(self.mapping1.weight, 0.)
65
+ nn.init.constant_(self.mapping1.bias, 0.)
66
+
67
+ def forward(self, x, ref, ratio):
68
+ x = self.audio_encoder(x).view(x.size(0), -1)
69
+ ref_reshape = ref.reshape(x.size(0), -1)
70
+ ratio = ratio.reshape(x.size(0), -1)
71
+
72
+ y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1))
73
+ out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # resudial
74
+ return out
src/audio2pose_models/audio2pose.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from src.audio2pose_models.cvae import CVAE
4
+ from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
5
+ from src.audio2pose_models.audio_encoder import AudioEncoder
6
+
7
+ class Audio2Pose(nn.Module):
8
+ def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
9
+ super().__init__()
10
+ self.cfg = cfg
11
+ self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
12
+ self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
13
+ self.device = device
14
+
15
+ self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
16
+ self.audio_encoder.eval()
17
+ for param in self.audio_encoder.parameters():
18
+ param.requires_grad = False
19
+
20
+ self.netG = CVAE(cfg)
21
+ self.netD_motion = PoseSequenceDiscriminator(cfg)
22
+
23
+
24
+ def forward(self, x):
25
+
26
+ batch = {}
27
+ coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73
28
+ batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70] #bs frame_len 6
29
+ batch['ref'] = coeff_gt[:, 0, 64:70] #bs 6
30
+ batch['class'] = x['class'].squeeze(0).cuda() # bs
31
+ indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16
32
+
33
+ # forward
34
+ audio_emb_list = []
35
+ audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512
36
+ batch['audio_emb'] = audio_emb
37
+ batch = self.netG(batch)
38
+
39
+ pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6
40
+ pose_gt = coeff_gt[:, 1:, 64:70].clone() # bs frame_len 6
41
+ pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred # bs frame_len 6
42
+
43
+ batch['pose_pred'] = pose_pred
44
+ batch['pose_gt'] = pose_gt
45
+
46
+ return batch
47
+
48
+ def test(self, x):
49
+
50
+ batch = {}
51
+ ref = x['ref'] #bs 1 70
52
+ batch['ref'] = x['ref'][:,0,-6:]
53
+ batch['class'] = x['class']
54
+ bs = ref.shape[0]
55
+
56
+ indiv_mels= x['indiv_mels'] # bs T 1 80 16
57
+ indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame
58
+ num_frames = x['num_frames']
59
+ num_frames = int(num_frames) - 1
60
+
61
+ #
62
+ div = num_frames//self.seq_len
63
+ re = num_frames%self.seq_len
64
+ audio_emb_list = []
65
+ pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype,
66
+ device=batch['ref'].device)]
67
+
68
+ for i in range(div):
69
+ z = torch.randn(bs, self.latent_dim).to(ref.device)
70
+ batch['z'] = z
71
+ audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512
72
+ batch['audio_emb'] = audio_emb
73
+ batch = self.netG.test(batch)
74
+ pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6
75
+
76
+ if re != 0:
77
+ z = torch.randn(bs, self.latent_dim).to(ref.device)
78
+ batch['z'] = z
79
+ audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512
80
+ if audio_emb.shape[1] != self.seq_len:
81
+ pad_dim = self.seq_len-audio_emb.shape[1]
82
+ pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
83
+ audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
84
+ batch['audio_emb'] = audio_emb
85
+ batch = self.netG.test(batch)
86
+ pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:])
87
+
88
+ pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1)
89
+ batch['pose_motion_pred'] = pose_motion_pred
90
+
91
+ pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6
92
+
93
+ batch['pose_pred'] = pose_pred
94
+ return batch