Spaces: Runtime error
Commit f4dac30 · Parent(s): initial commit
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +31 -0
- .gitignore +18 -0
- CODE_OF_CONDUCT.md +130 -0
- LICENSE.txt +24 -0
- README-CN.md +230 -0
- README.md +13 -0
- app.py +80 -0
- demo_toolbox.py +49 -0
- encoder/__init__.py +0 -0
- encoder/audio.py +117 -0
- encoder/config.py +45 -0
- encoder/data_objects/__init__.py +2 -0
- encoder/data_objects/random_cycler.py +37 -0
- encoder/data_objects/speaker.py +40 -0
- encoder/data_objects/speaker_batch.py +12 -0
- encoder/data_objects/speaker_verification_dataset.py +56 -0
- encoder/data_objects/utterance.py +26 -0
- encoder/inference.py +195 -0
- encoder/model.py +135 -0
- encoder/params_data.py +29 -0
- encoder/params_model.py +11 -0
- encoder/preprocess.py +184 -0
- encoder/saved_models/pretrained.pt +3 -0
- encoder/train.py +123 -0
- encoder/visualizations.py +178 -0
- encoder_preprocess.py +61 -0
- encoder_train.py +47 -0
- gen_voice.py +128 -0
- mkgui/__init__.py +0 -0
- mkgui/app.py +145 -0
- mkgui/app_vc.py +166 -0
- mkgui/base/__init__.py +2 -0
- mkgui/base/api/__init__.py +1 -0
- mkgui/base/api/fastapi_utils.py +102 -0
- mkgui/base/components/__init__.py +0 -0
- mkgui/base/components/outputs.py +43 -0
- mkgui/base/components/types.py +46 -0
- mkgui/base/core.py +203 -0
- mkgui/base/ui/__init__.py +1 -0
- mkgui/base/ui/schema_utils.py +129 -0
- mkgui/base/ui/streamlit_ui.py +888 -0
- mkgui/base/ui/streamlit_utils.py +13 -0
- mkgui/preprocess.py +96 -0
- mkgui/static/mb.png +0 -0
- mkgui/train.py +106 -0
- mkgui/train_vc.py +155 -0
- packages.txt +5 -0
- ppg2mel/__init__.py +209 -0
- ppg2mel/preprocess.py +113 -0
- ppg2mel/rnn_decoder_mol.py +374 -0
.gitattributes
ADDED
@@ -0,0 +1,31 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zstandard filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,18 @@
*.pyc
*.aux
*.log
*.out
*.synctex.gz
*.suo
*__pycache__
*.idea
*.ipynb_checkpoints
*.pickle
*.npy
*.blg
*.bbl
*.bcf
*.toc
*.sh
wavs
log
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,130 @@
# Contributor Covenant Code of Conduct

## First of all

Don't be evil, never

## Our Pledge

We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our
community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
  and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
  overall community

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or
  advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
  address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Enforcement Responsibilities

Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.

Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at

All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the
reporter of any incident.

## Enforcement Guidelines

Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:

### 1. Correction

**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.

### 2. Warning

**Community Impact**: A violation through a single incident or series
of actions.

**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.

### 3. Temporary Ban

**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within
the community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.

Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.
LICENSE.txt
ADDED
@@ -0,0 +1,24 @@
MIT License

Modified & original work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ)
Original work Copyright (c) 2018 Rayhane Mama (https://github.com/Rayhane-mamah)
Original work Copyright (c) 2019 fatchord (https://github.com/fatchord)
Original work Copyright (c) 2015 braindead (https://github.com/braindead)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README-CN.md
ADDED
@@ -0,0 +1,230 @@
## Real-Time Voice Cloning - Chinese/Mandarin


[](http://choosealicense.com/licenses/mit/)

### [English](README.md) | Chinese

### [DEMO VIDEO](https://www.bilibili.com/video/BV17Q4y1B7mY/) | [Wiki Tutorial](https://github.com/babysor/MockingBird/wiki/Quick-Start-(Newbie)) | [Training Tutorial](https://vaj2fgg8yn.feishu.cn/docs/doccn7kAbr3SJz0KM0SIDJ0Xnhd)

## Features
🌍 **Chinese** Supports Mandarin and is tested with several Chinese datasets: aidatatang_200zh, magicdata, aishell3, biaobei, MozillaCommonVoice, data_aishell, etc.

🤩 **PyTorch** Works with PyTorch, tested on version 1.9.0 (latest as of August 2021) with Tesla T4 and GTX 2060 GPUs

🌍 **Windows + Linux** Runs on both Windows and Linux (the community has also reported success on Apple Silicon M1)

🤩 **Easy & Awesome** Good results with only a newly trained or downloaded synthesizer; reuses the pretrained encoder/vocoder, or a real-time HiFi-GAN as the vocoder

🌍 **Webserver Ready** Can serve your trained models for remote calls

### Work in progress
* Major GUI/client upgrade and merge
    [X] Initial framework `./mkgui` (based on streamlit + fastapi) and [technical design](https://vaj2fgg8yn.feishu.cn/docs/doccnvotLWylBub8VJIjKzoEaee)
    [X] Add demo pages for Voice Cloning and Conversion
    [X] Add preprocessing and training pages for Voice Conversion
    [ ] Add preprocessing and training pages for the remaining models
* Upgrade the model backend to ESPnet2


## Getting Started
### 1. Install requirements
> Follow the original repository to check whether your environment is ready.
Running the toolbox (demo_toolbox.py) requires **Python 3.7 or higher**.

* Install [PyTorch](https://pytorch.org/get-started/locally/).
> If pip reports `ERROR: Could not find a version that satisfies the requirement torch==1.9.0+cu102 (from versions: 0.1.2, 0.1.2.post1, 0.1.2.post2)`, your Python version may be too old; 3.9 installs successfully
* Install [ffmpeg](https://ffmpeg.org/download.html#get-packages).
* Run `pip install -r requirements.txt` to install the remaining required packages.
* Install webrtcvad with `pip install webrtcvad-wheels`.

### 2. Prepare pretrained models
Consider training your own models or downloading models trained by the community:
> A [Zhihu column](https://www.zhihu.com/column/c_1425605280340504576) was created recently and will be updated occasionally with training tips and notes; questions are welcome
#### 2.1 Train your own encoder model with a dataset (optional)

* Preprocess audio and mel spectrograms:
`python encoder_preprocess.py <datasets_root>`
Use `-d {dataset}` to specify the dataset; librispeech_other, voxceleb1 and aidatatang_200zh are supported; separate multiple datasets with commas.
* Train the encoder: `python encoder_train.py my_run <datasets_root>/SV2TTS/encoder`
> Encoder training uses visdom. You can add `-no_visdom` to disable it, but visualization is better. Run "visdom" in a separate terminal/process to start the visdom server.

#### 2.2 Train your own synthesizer model with a dataset (choose either this or 2.3)
* Download a dataset and unzip it: make sure you can access all audio files (e.g. .wav) in the *train* folder
* Preprocess audio and mel spectrograms:
`python pre.py <datasets_root> -d {dataset} -n {number}`
Supported arguments:
* `-d {dataset}` specifies the dataset; aidatatang_200zh, magicdata, aishell3 and data_aishell are supported; defaults to aidatatang_200zh
* `-n {number}` sets the number of parallel workers; 10 worked fine on a CPU 11770k with 32GB RAM
> If the downloaded `aidatatang_200zh` files are on drive D and the `train` folder path is `D:\data\aidatatang_200zh\corpus\train`, then your `datasets_root` is `D:\data\`

* Train the synthesizer:
`python synthesizer_train.py mandarin <datasets_root>/SV2TTS/synthesizer`

* When the attention alignment appears and the loss shown in the training folder *synthesizer/saved_models/* meets your needs, go to the `Launch` step.

#### 2.3 Use a synthesizer pretrained by the community (choose either this or 2.2)
> If you have no hardware or don't want to tune things slowly, you can use models contributed by the community (ongoing sharing is welcome):

| Author | Download link | Preview | Info |
| --- | ----------- | ----- | ----- |
| Author | https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g [Baidu Pan link](https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g) code: 4j5d | | 75k steps, trained on a mix of 3 open datasets
| Author | https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw [Baidu Pan link](https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw) code: om7f | | 25k steps, trained on a mix of 3 open datasets; switch to tag v0.0.1 to use
|@FawenYo | https://drive.google.com/file/d/1H-YGOUHpmqKxJ9FRc6vAjPuqQki24UbC/view?usp=sharing [Baidu Pan link](https://pan.baidu.com/s/1vSYXO4wsLyjnF3Unl-Xoxg) code: 1024 | [input](https://github.com/babysor/MockingBird/wiki/audio/self_test.mp3) [output](https://github.com/babysor/MockingBird/wiki/audio/export.wav) | 200k steps, Taiwanese accent; switch to tag v0.0.1 to use
|@miven| https://pan.baidu.com/s/1PI-hM3sn5wbeChRryX-RCQ code: 2021 | https://www.bilibili.com/video/BV1uh411B7AD/ | 150k steps; note: apply the fix from this [issue](https://github.com/babysor/MockingBird/issues/37) and switch to tag v0.0.1 to use

#### 2.4 Train a vocoder (optional)
The vocoder has little effect on the results; 3 are already bundled. If you want to train your own, refer to the commands below.
* Preprocess the data:
`python vocoder_preprocess.py <datasets_root> -m <synthesizer_model_path>`
> Replace `<datasets_root>` with your dataset directory and `<synthesizer_model_path>` with the directory of your best synthesizer model, e.g. *sythensizer\saved_models\xxx*


* Train the wavernn vocoder:
`python vocoder_train.py <trainid> <datasets_root>`
> Replace `<trainid>` with an identifier of your choice; training again with the same identifier resumes from the existing model

* Train the hifigan vocoder:
`python vocoder_train.py <trainid> <datasets_root> hifigan`
> Replace `<trainid>` with an identifier of your choice; training again with the same identifier resumes from the existing model
* Train the fregan vocoder:
`python vocoder_train.py <trainid> <datasets_root> --config config.json fregan`
> Replace `<trainid>` with an identifier of your choice; training again with the same identifier resumes from the existing model
* To switch GAN vocoder training to multi-GPU mode: edit the "num_gpus" parameter in the .json file under the GAN folder
### 3. Launch the program or toolbox
You can try the following commands:

### 3.1 Launch the web program (v2):
`python web.py`
Once it is running, open the address in a browser, `http://localhost:8080` by default
> * Only newly recorded audio (16 kHz) is supported; recordings over 4MB are not supported; the best length is 5-15 seconds

### 3.2 Launch the toolbox:
`python demo_toolbox.py -d <datasets_root>`
> Please specify a valid dataset path; if a supported dataset is present it will be loaded automatically for debugging, and the path is also used as the storage directory for manually recorded audio.

<img width="1042" alt="d48ea37adf3660e657cfb047c10edbc" src="https://user-images.githubusercontent.com/7423248/134275227-c1ddf154-f118-4b77-8949-8c4c7daf25f0.png">

### 4. Extra: Voice Conversion (PPG based)
Ever imagined Conan using a voice changer to speak with Kogoro Mouri's voice? Based on PPG-VC, this project adds two extra modules (PPG extractor + PPG2Mel) to support voice conversion. (The documentation is incomplete, especially the training part, and is being filled in.)
#### 4.0 Prepare the environment
* Make sure the environment above is set up, then run `pip install espnet` to install the remaining required packages.
* Download the following models. Link: https://pan.baidu.com/s/1bl_x_DHJSAUyN2fma-Q_Wg
Code: gh41
* A vocoder (hifigan) dedicated to the 24K sample rate, into *vocoder\saved_models\xxx*
* A pretrained PPG feature encoder (ppg_extractor), into *ppg_extractor\saved_models\xxx*
* A pretrained PPG2Mel, into *ppg2mel\saved_models\xxx*

#### 4.1 Train your own PPG2Mel model with a dataset (optional)

* Download the aidatatang_200zh dataset and unzip it: make sure you can access all audio files (e.g. .wav) in the *train* folder
* Preprocess audio and mel spectrograms:
`python pre4ppg.py <datasets_root> -d {dataset} -n {number}`
Supported arguments:
* `-d {dataset}` specifies the dataset; only aidatatang_200zh is supported and it is the default
* `-n {number}` sets the number of parallel workers; with 8 on a CPU 11770k it takes 12 to 18 hours! Optimization pending
> If the downloaded `aidatatang_200zh` files are on drive D and the `train` folder path is `D:\data\aidatatang_200zh\corpus\train`, then your `datasets_root` is `D:\data\`

* Train the converter; note that you must first download `ppg2mel.yaml` in the previous step and edit the paths inside it to point to the pretrained folders:
`python ppg2mel_train.py --config .\ppg2mel\saved_models\ppg2mel.yaml --oneshotvc `
* To continue a previous training run, specify a pretrained model file with the `--load .\ppg2mel\saved_models\<old_pt_file>` argument.

#### 4.2 Launch the toolbox in VC mode
You can try the following command:
`python demo_toolbox.py -vc -d <datasets_root>`
> Please specify a valid dataset path; if a supported dataset is present it will be loaded automatically for debugging, and the path is also used as the storage directory for manually recorded audio.
<img width="971" alt="微信图片_20220305005351" src="https://user-images.githubusercontent.com/7423248/156805733-2b093dbc-d989-4e68-8609-db11f365886a.png">

## References and papers
> This repository was originally forked from the English-only [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning); many thanks to its author.

| URL | Designation | Title | Implementation source |
| --- | ----------- | ----- | --------------------- |
| [1803.09017](https://arxiv.org/abs/1803.09017) | GlobalStyleToken (synthesizer)| Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis | This repo |
| [2010.05646](https://arxiv.org/abs/2010.05646) | HiFi-GAN (vocoder)| Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis | This repo |
| [2106.02297](https://arxiv.org/abs/2106.02297) | Fre-GAN (vocoder)| Fre-GAN: Adversarial Frequency-consistent Audio Synthesis | This repo |
|[**1806.04558**](https://arxiv.org/pdf/1806.04558.pdf) | SV2TTS | Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis | This repo |
|[1802.08435](https://arxiv.org/pdf/1802.08435.pdf) | WaveRNN (vocoder) | Efficient Neural Audio Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN) |
|[1703.10135](https://arxiv.org/pdf/1703.10135.pdf) | Tacotron (synthesizer) | Tacotron: Towards End-to-End Speech Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN)
|[1710.10467](https://arxiv.org/pdf/1710.10467.pdf) | GE2E (encoder)| Generalized End-To-End Loss for Speaker Verification | This repo |

## FAQ
#### 1. Where can I download the datasets?
| Dataset | OpenSLR link | Other sources (Google Drive, Baidu Pan, etc.) |
| --- | ----------- | ---------------|
| aidatatang_200zh | [OpenSLR](http://www.openslr.org/62/) | [Google Drive](https://drive.google.com/file/d/110A11KZoVe7vy6kXlLb6zVPLb_J91I_t/view?usp=sharing) |
| magicdata | [OpenSLR](http://www.openslr.org/68/) | [Google Drive (Dev set)](https://drive.google.com/file/d/1g5bWRUSNH68ycC6eNvtwh07nX3QhOOlo/view?usp=sharing) |
| aishell3 | [OpenSLR](https://www.openslr.org/93/) | [Google Drive](https://drive.google.com/file/d/1shYp_o4Z0X0cZSKQDtFirct2luFUwKzZ/view?usp=sharing) |
| data_aishell | [OpenSLR](https://www.openslr.org/33/) | |
> After unzipping aidatatang_200zh, you also need to select and extract all the archives under `aidatatang_200zh\corpus\train`

#### 2. What does `<datasets_root>` mean?
If the dataset path is `D:\data\aidatatang_200zh`, then `<datasets_root>` is `D:\data`

#### 3. Out of GPU memory during training
When training the synthesizer: reduce the batch_size parameter in `synthesizer/hparams.py`
```
// before
tts_schedule = [(2,  1e-3,  20_000,  12),   # Progressive training schedule
                (2,  5e-4,  40_000,  12),   # (r, lr, step, batch_size)
                (2,  2e-4,  80_000,  12),   #
                (2,  1e-4, 160_000,  12),   # r = reduction factor (# of mel frames
                (2,  3e-5, 320_000,  12),   #     synthesized for each decoder iteration)
                (2,  1e-5, 640_000,  12)],  # lr = learning rate
// after
tts_schedule = [(2,  1e-3,  20_000,  8),    # Progressive training schedule
                (2,  5e-4,  40_000,  8),    # (r, lr, step, batch_size)
                (2,  2e-4,  80_000,  8),    #
                (2,  1e-4, 160_000,  8),    # r = reduction factor (# of mel frames
                (2,  3e-5, 320_000,  8),    #     synthesized for each decoder iteration)
                (2,  1e-5, 640_000,  8)],   # lr = learning rate
```

Vocoder - when preprocessing the dataset: reduce the batch_size parameter in `synthesizer/hparams.py`
```
// before
### Data Preprocessing
        max_mel_frames = 900,
        rescale = True,
        rescaling_max = 0.9,
        synthesis_batch_size = 16,                  # For vocoder preprocessing and inference.
// after
### Data Preprocessing
        max_mel_frames = 900,
        rescale = True,
        rescaling_max = 0.9,
        synthesis_batch_size = 8,                   # For vocoder preprocessing and inference.
```

Vocoder - when training the vocoder: reduce the batch_size parameter in `vocoder/wavernn/hparams.py`
```
// before
# Training
voc_batch_size = 100
voc_lr = 1e-4
voc_gen_at_checkpoint = 5
voc_pad = 2

// after
# Training
voc_batch_size = 6
voc_lr = 1e-4
voc_gen_at_checkpoint = 5
voc_pad = 2
```

#### 4. Encountering `RuntimeError: Error(s) in loading state_dict for Tacotron: size mismatch for encoder.embedding.weight: copying a param with shape torch.Size([70, 512]) from checkpoint, the shape in current model is torch.Size([75, 512]).`
Please refer to issue [#37](https://github.com/babysor/MockingBird/issues/37)

#### 5. How can I improve CPU/GPU utilization?
Adjust the batch_size parameter as appropriate

#### 6. Getting "The paging file is too small for this operation to complete"
Please refer to this [article](https://blog.csdn.net/qq_17755303/article/details/112564030) and change the virtual memory to 100G (102400); for example, if the files are on drive D, change the virtual memory of drive D

#### 7. When is training considered done?
First, the attention alignment must appear; second, the loss must be low enough, which depends on the hardware and dataset. For reference, my attention appeared after 18k steps and the loss dropped below 0.4 after 50k steps



README.md
ADDED
@@ -0,0 +1,13 @@
---
title: MockingBird
emoji: 🔥
colorFrom: red
colorTo: red
sdk: gradio
sdk_version: 3.1.3
app_file: app.py
pinned: false
license: mit
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,80 @@

import gradio as gr

import re
import random
import string
import librosa
import numpy as np

from pathlib import Path
from scipy.io.wavfile import write

from encoder import inference as encoder
from vocoder.hifigan import inference as gan_vocoder
from synthesizer.inference import Synthesizer

class Mandarin:
    def __init__(self):
        self.encoder_path = "encoder/saved_models/pretrained.pt"
        self.vocoder_path = "vocoder/saved_models/pretrained/g_hifigan.pt"
        self.config_fpath = "vocoder/hifigan/config_16k_.json"
        self.accent = "synthesizer/saved_models/普通话.pt"

        synthesizers_cache = {}
        if synthesizers_cache.get(self.accent) is None:
            self.current_synt = Synthesizer(Path(self.accent))
            synthesizers_cache[self.accent] = self.current_synt
        else:
            self.current_synt = synthesizers_cache[self.accent]

        encoder.load_model(Path(self.encoder_path))
        gan_vocoder.load_model(Path(self.vocoder_path), self.config_fpath)

    def setVoice(self, timbre):
        self.timbre = timbre
        wav, sample_rate = librosa.load(self.timbre)

        encoder_wav = encoder.preprocess_wav(wav, sample_rate)
        self.embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True)

    def say(self, text):
        texts = filter(None, text.split("\n"))
        punctuation = "!,。、?!,.?::" # punctuate and split/clean text
        processed_texts = []
        for text in texts:
            for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'):
                if processed_text:
                    processed_texts.append(processed_text.strip())
        texts = processed_texts
        embeds = [self.embed] * len(texts)

        specs = self.current_synt.synthesize_spectrograms(texts, embeds)
        spec = np.concatenate(specs, axis=1)
        wav, sample_rate = gan_vocoder.infer_waveform(spec)

        return wav, sample_rate

def greet(audio, text, voice=None):

    if voice is None:
        voice = Mandarin()
        voice.setVoice(audio.name)
        voice.say("加载成功")
    wav, sample_rate = voice.say(text)

    output_file = "".join( random.sample(string.ascii_lowercase + string.digits, 11) ) + ".wav"

    write(output_file, sample_rate, wav.astype(np.float32))

    return output_file, voice

def main():
    gr.Interface(
        fn=greet,
        inputs=[gr.inputs.Audio(type="file"), "text", "state"],
        outputs=[gr.outputs.Audio(type="file"), "state"]
    ).launch()

if __name__=="__main__":
    main()
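Aside (not part of the commit): `Mandarin.say` above first splits the input text into short clauses on Chinese and Latin punctuation, then synthesizes one spectrogram per clause. Below is a minimal, standalone sketch of just that splitting step; the sample sentence is made up.
```
import re

# Same punctuation set used by Mandarin.say in app.py
punctuation = "!,。、?!,.?::"

def split_clauses(text):
    clauses = []
    for line in filter(None, text.split("\n")):
        # Replace any run of punctuation with a newline, then split on it
        for clause in re.sub(r'[{}]+'.format(punctuation), '\n', line).split('\n'):
            if clause:
                clauses.append(clause.strip())
    return clauses

print(split_clauses("你好,世界!今天天气不错。"))
# -> ['你好', '世界', '今天天气不错']
```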
demo_toolbox.py
ADDED
@@ -0,0 +1,49 @@
from pathlib import Path
from toolbox import Toolbox
from utils.argutils import print_args
from utils.modelutils import check_model_paths
import argparse
import os


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Runs the toolbox",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument("-d", "--datasets_root", type=Path, help= \
        "Path to the directory containing your datasets. See toolbox/__init__.py for a list of "
        "supported datasets.", default=None)
    parser.add_argument("-vc", "--vc_mode", action="store_true",
                        help="Voice Conversion Mode(PPG based)")
    parser.add_argument("-e", "--enc_models_dir", type=Path, default="encoder/saved_models",
                        help="Directory containing saved encoder models")
    parser.add_argument("-s", "--syn_models_dir", type=Path, default="synthesizer/saved_models",
                        help="Directory containing saved synthesizer models")
    parser.add_argument("-v", "--voc_models_dir", type=Path, default="vocoder/saved_models",
                        help="Directory containing saved vocoder models")
    parser.add_argument("-ex", "--extractor_models_dir", type=Path, default="ppg_extractor/saved_models",
                        help="Directory containing saved extrator models")
    parser.add_argument("-cv", "--convertor_models_dir", type=Path, default="ppg2mel/saved_models",
                        help="Directory containing saved convert models")
    parser.add_argument("--cpu", action="store_true", help=\
        "If True, processing is done on CPU, even when a GPU is available.")
    parser.add_argument("--seed", type=int, default=None, help=\
        "Optional random number seed value to make toolbox deterministic.")
    parser.add_argument("--no_mp3_support", action="store_true", help=\
        "If True, no mp3 files are allowed.")
    args = parser.parse_args()
    print_args(args, parser)

    if args.cpu:
        # Hide GPUs from Pytorch to force CPU processing
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    del args.cpu

    ## Remind the user to download pretrained models if needed
    check_model_paths(encoder_path=args.enc_models_dir, synthesizer_path=args.syn_models_dir,
                      vocoder_path=args.voc_models_dir)

    # Launch the toolbox
    Toolbox(**vars(args))
encoder/__init__.py
ADDED
File without changes
encoder/audio.py
ADDED
@@ -0,0 +1,117 @@
from scipy.ndimage.morphology import binary_dilation
from encoder.params_data import *
from pathlib import Path
from typing import Optional, Union
from warnings import warn
import numpy as np
import librosa
import struct

try:
    import webrtcvad
except:
    warn("Unable to import 'webrtcvad'. This package enables noise removal and is recommended.")
    webrtcvad=None

int16_max = (2 ** 15) - 1


def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
                   source_sr: Optional[int] = None,
                   normalize: Optional[bool] = True,
                   trim_silence: Optional[bool] = True):
    """
    Applies the preprocessing operations used in training the Speaker Encoder to a waveform
    either on disk or in memory. The waveform will be resampled to match the data hyperparameters.

    :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
    just .wav), either the waveform as a numpy array of floats.
    :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
    preprocessing. After preprocessing, the waveform's sampling rate will match the data
    hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
    this argument will be ignored.
    """
    # Load the wav from disk if needed
    if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
        wav, source_sr = librosa.load(str(fpath_or_wav), sr=None)
    else:
        wav = fpath_or_wav

    # Resample the wav if needed
    if source_sr is not None and source_sr != sampling_rate:
        wav = librosa.resample(wav, source_sr, sampling_rate)

    # Apply the preprocessing: normalize volume and shorten long silences
    if normalize:
        wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
    if webrtcvad and trim_silence:
        wav = trim_long_silences(wav)

    return wav


def wav_to_mel_spectrogram(wav):
    """
    Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
    Note: this not a log-mel spectrogram.
    """
    frames = librosa.feature.melspectrogram(
        y=wav,
        sr=sampling_rate,
        n_fft=int(sampling_rate * mel_window_length / 1000),
        hop_length=int(sampling_rate * mel_window_step / 1000),
        n_mels=mel_n_channels
    )
    return frames.astype(np.float32).T


def trim_long_silences(wav):
    """
    Ensures that segments without voice in the waveform remain no longer than a
    threshold determined by the VAD parameters in params.py.

    :param wav: the raw waveform as a numpy array of floats
    :return: the same waveform with silences trimmed away (length <= original wav length)
    """
    # Compute the voice detection window size
    samples_per_window = (vad_window_length * sampling_rate) // 1000

    # Trim the end of the audio to have a multiple of the window size
    wav = wav[:len(wav) - (len(wav) % samples_per_window)]

    # Convert the float waveform to 16-bit mono PCM
    pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))

    # Perform voice activation detection
    voice_flags = []
    vad = webrtcvad.Vad(mode=3)
    for window_start in range(0, len(wav), samples_per_window):
        window_end = window_start + samples_per_window
        voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
                                         sample_rate=sampling_rate))
    voice_flags = np.array(voice_flags)

    # Smooth the voice detection with a moving average
    def moving_average(array, width):
        array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
        ret = np.cumsum(array_padded, dtype=float)
        ret[width:] = ret[width:] - ret[:-width]
        return ret[width - 1:] / width

    audio_mask = moving_average(voice_flags, vad_moving_average_width)
    audio_mask = np.round(audio_mask).astype(np.bool)

    # Dilate the voiced regions
    audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
    audio_mask = np.repeat(audio_mask, samples_per_window)

    return wav[audio_mask == True]


def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
    if increase_only and decrease_only:
        raise ValueError("Both increase only and decrease only are set")
    dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
    if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
        return wav
    return wav * (10 ** (dBFS_change / 20))
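Aside (not part of the commit): a minimal sketch of how the preprocessing functions above are typically chained, assuming the hyperparameters in `encoder/params_data.py` are importable and that `speech.wav` is a placeholder path to any mono speech recording.
```
from encoder import audio

wav = audio.preprocess_wav("speech.wav")   # load, resample, normalize volume, trim long silences
mel = audio.wav_to_mel_spectrogram(wav)    # float32 array of shape (n_frames, mel_n_channels)
print(mel.shape)
```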
encoder/config.py
ADDED
@@ -0,0 +1,45 @@
librispeech_datasets = {
    "train": {
        "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
        "other": ["LibriSpeech/train-other-500"]
    },
    "test": {
        "clean": ["LibriSpeech/test-clean"],
        "other": ["LibriSpeech/test-other"]
    },
    "dev": {
        "clean": ["LibriSpeech/dev-clean"],
        "other": ["LibriSpeech/dev-other"]
    },
}
libritts_datasets = {
    "train": {
        "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
        "other": ["LibriTTS/train-other-500"]
    },
    "test": {
        "clean": ["LibriTTS/test-clean"],
        "other": ["LibriTTS/test-other"]
    },
    "dev": {
        "clean": ["LibriTTS/dev-clean"],
        "other": ["LibriTTS/dev-other"]
    },
}
voxceleb_datasets = {
    "voxceleb1" : {
        "train": ["VoxCeleb1/wav"],
        "test": ["VoxCeleb1/test_wav"]
    },
    "voxceleb2" : {
        "train": ["VoxCeleb2/dev/aac"],
        "test": ["VoxCeleb2/test_wav"]
    }
}

other_datasets = [
    "LJSpeech-1.1",
    "VCTK-Corpus/wav48",
]

anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
encoder/data_objects/__init__.py
ADDED
@@ -0,0 +1,2 @@
from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
encoder/data_objects/random_cycler.py
ADDED
@@ -0,0 +1,37 @@
import random

class RandomCycler:
    """
    Creates an internal copy of a sequence and allows access to its items in a constrained random
    order. For a source sequence of n items and one or several consecutive queries of a total
    of m items, the following guarantees hold (one implies the other):
        - Each item will be returned between m // n and ((m - 1) // n) + 1 times.
        - Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
    """

    def __init__(self, source):
        if len(source) == 0:
            raise Exception("Can't create RandomCycler from an empty collection")
        self.all_items = list(source)
        self.next_items = []

    def sample(self, count: int):
        shuffle = lambda l: random.sample(l, len(l))

        out = []
        while count > 0:
            if count >= len(self.all_items):
                out.extend(shuffle(list(self.all_items)))
                count -= len(self.all_items)
                continue
            n = min(count, len(self.next_items))
            out.extend(self.next_items[:n])
            count -= n
            self.next_items = self.next_items[n:]
            if len(self.next_items) == 0:
                self.next_items = shuffle(list(self.all_items))
        return out

    def __next__(self):
        return self.sample(1)[0]

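Aside (not part of the commit): a small usage sketch illustrating the guarantee stated in the `RandomCycler` docstring, with a made-up 3-item source.
```
from encoder.data_objects.random_cycler import RandomCycler

cycler = RandomCycler(["a", "b", "c"])
# Sampling m=7 items from n=3: each item appears between 7 // 3 = 2 and ((7 - 1) // 3) + 1 = 3 times.
print(cycler.sample(7))   # e.g. ['b', 'a', 'c', 'c', 'a', 'b', 'a']
print(next(cycler))       # __next__ draws a single item
```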
encoder/data_objects/speaker.py
ADDED
@@ -0,0 +1,40 @@
from encoder.data_objects.random_cycler import RandomCycler
from encoder.data_objects.utterance import Utterance
from pathlib import Path

# Contains the set of utterances of a single speaker
class Speaker:
    def __init__(self, root: Path):
        self.root = root
        self.name = root.name
        self.utterances = None
        self.utterance_cycler = None

    def _load_utterances(self):
        with self.root.joinpath("_sources.txt").open("r") as sources_file:
            sources = [l.split(",") for l in sources_file]
        sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
        self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
        self.utterance_cycler = RandomCycler(self.utterances)

    def random_partial(self, count, n_frames):
        """
        Samples a batch of <count> unique partial utterances from the disk in a way that all
        utterances come up at least once every two cycles and in a random order every time.

        :param count: The number of partial utterances to sample from the set of utterances from
        that speaker. Utterances are guaranteed not to be repeated if <count> is not larger than
        the number of utterances available.
        :param n_frames: The number of frames in the partial utterance.
        :return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
        frames are the frames of the partial utterances and range is the range of the partial
        utterance with regard to the complete utterance.
        """
        if self.utterances is None:
            self._load_utterances()

        utterances = self.utterance_cycler.sample(count)

        a = [(u,) + u.random_partial(n_frames) for u in utterances]

        return a
encoder/data_objects/speaker_batch.py
ADDED
@@ -0,0 +1,12 @@
import numpy as np
from typing import List
from encoder.data_objects.speaker import Speaker

class SpeakerBatch:
    def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
        self.speakers = speakers
        self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}

        # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
        # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
        self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
encoder/data_objects/speaker_verification_dataset.py
ADDED
@@ -0,0 +1,56 @@
from encoder.data_objects.random_cycler import RandomCycler
from encoder.data_objects.speaker_batch import SpeakerBatch
from encoder.data_objects.speaker import Speaker
from encoder.params_data import partials_n_frames
from torch.utils.data import Dataset, DataLoader
from pathlib import Path

# TODO: improve with a pool of speakers for data efficiency

class SpeakerVerificationDataset(Dataset):
    def __init__(self, datasets_root: Path):
        self.root = datasets_root
        speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
        if len(speaker_dirs) == 0:
            raise Exception("No speakers found. Make sure you are pointing to the directory "
                            "containing all preprocessed speaker directories.")
        self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
        self.speaker_cycler = RandomCycler(self.speakers)

    def __len__(self):
        return int(1e10)

    def __getitem__(self, index):
        return next(self.speaker_cycler)

    def get_logs(self):
        log_string = ""
        for log_fpath in self.root.glob("*.txt"):
            with log_fpath.open("r") as log_file:
                log_string += "".join(log_file.readlines())
        return log_string


class SpeakerVerificationDataLoader(DataLoader):
    def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
                 batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
                 worker_init_fn=None):
        self.utterances_per_speaker = utterances_per_speaker

        super().__init__(
            dataset=dataset,
            batch_size=speakers_per_batch,
            shuffle=False,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=self.collate,
            pin_memory=pin_memory,
            drop_last=False,
            timeout=timeout,
            worker_init_fn=worker_init_fn
        )

    def collate(self, speakers):
        return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)

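Aside (not part of the commit): a sketch of how the dataset and loader above fit together, assuming `<datasets_root>/SV2TTS/encoder` is a placeholder for a directory of speaker folders produced by `encoder_preprocess.py`; the batch sizes are illustrative.
```
from pathlib import Path
from encoder.data_objects import SpeakerVerificationDataset, SpeakerVerificationDataLoader

dataset = SpeakerVerificationDataset(Path("<datasets_root>/SV2TTS/encoder"))
loader = SpeakerVerificationDataLoader(dataset, speakers_per_batch=4, utterances_per_speaker=5)

batch = next(iter(loader))   # a SpeakerBatch built by the custom collate function
# batch.data has shape (speakers_per_batch * utterances_per_speaker, partials_n_frames, mel_n_channels)
print(batch.data.shape)
```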
encoder/data_objects/utterance.py
ADDED
@@ -0,0 +1,26 @@
import numpy as np


class Utterance:
    def __init__(self, frames_fpath, wave_fpath):
        self.frames_fpath = frames_fpath
        self.wave_fpath = wave_fpath

    def get_frames(self):
        return np.load(self.frames_fpath)

    def random_partial(self, n_frames):
        """
        Crops the frames into a partial utterance of n_frames

        :param n_frames: The number of frames of the partial utterance
        :return: the partial utterance frames and a tuple indicating the start and end of the
        partial utterance in the complete utterance.
        """
        frames = self.get_frames()
        if frames.shape[0] == n_frames:
            start = 0
        else:
            start = np.random.randint(0, frames.shape[0] - n_frames)
        end = start + n_frames
        return frames[start:end], (start, end)
encoder/inference.py
ADDED
@@ -0,0 +1,195 @@
1 |
+
from encoder.params_data import *
|
2 |
+
from encoder.model import SpeakerEncoder
|
3 |
+
from encoder.audio import preprocess_wav # We want to expose this function from here
|
4 |
+
from matplotlib import cm
|
5 |
+
from encoder import audio
|
6 |
+
from pathlib import Path
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
|
11 |
+
_model = None # type: SpeakerEncoder
|
12 |
+
_device = None # type: torch.device
|
13 |
+
|
14 |
+
|
15 |
+
def load_model(weights_fpath: Path, device=None):
|
16 |
+
"""
|
17 |
+
Loads the model in memory. If this function is not explicitely called, it will be run on the
|
18 |
+
first call to embed_frames() with the default weights file.
|
19 |
+
|
20 |
+
:param weights_fpath: the path to saved model weights.
|
21 |
+
:param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
|
22 |
+
model will be loaded and will run on this device. Outputs will however always be on the cpu.
|
23 |
+
If None, will default to your GPU if it"s available, otherwise your CPU.
|
24 |
+
"""
|
25 |
+
# TODO: I think the slow loading of the encoder might have something to do with the device it
|
26 |
+
# was saved on. Worth investigating.
|
27 |
+
global _model, _device
|
28 |
+
if device is None:
|
29 |
+
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
30 |
+
elif isinstance(device, str):
|
31 |
+
_device = torch.device(device)
|
32 |
+
_model = SpeakerEncoder(_device, torch.device("cpu"))
|
33 |
+
checkpoint = torch.load(weights_fpath, _device)
|
34 |
+
_model.load_state_dict(checkpoint["model_state"])
|
35 |
+
_model.eval()
|
36 |
+
print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
|
37 |
+
return _model
|
38 |
+
|
39 |
+
def set_model(model, device=None):
|
40 |
+
global _model, _device
|
41 |
+
_model = model
|
42 |
+
if device is None:
|
43 |
+
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
44 |
+
_device = device
|
45 |
+
_model.to(device)
|
46 |
+
|
47 |
+
def is_loaded():
|
48 |
+
return _model is not None
|
49 |
+
|
50 |
+
|
51 |
+
def embed_frames_batch(frames_batch):
|
52 |
+
"""
|
53 |
+
Computes embeddings for a batch of mel spectrogram.
|
54 |
+
|
55 |
+
:param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape
|
56 |
+
(batch_size, n_frames, n_channels)
|
57 |
+
:return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
|
58 |
+
"""
|
59 |
+
if _model is None:
|
60 |
+
raise Exception("Model was not loaded. Call load_model() before inference.")
|
61 |
+
|
62 |
+
frames = torch.from_numpy(frames_batch).to(_device)
|
63 |
+
embed = _model.forward(frames).detach().cpu().numpy()
|
64 |
+
return embed
|
65 |
+
|
66 |
+
|
67 |
+
def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
|
68 |
+
min_pad_coverage=0.75, overlap=0.5, rate=None):
|
69 |
+
"""
|
70 |
+
Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
|
71 |
+
partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
|
72 |
+
spectrogram slices are returned, so as to make each partial utterance waveform correspond to
|
73 |
+
its spectrogram. This function assumes that the mel spectrogram parameters used are those
|
74 |
+
defined in params_data.py.
|
75 |
+
|
76 |
+
The returned ranges may be indexing further than the length of the waveform. It is
|
77 |
+
recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
|
78 |
+
|
79 |
+
:param n_samples: the number of samples in the waveform
|
80 |
+
:param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
|
81 |
+
utterance
|
82 |
+
:param min_pad_coverage: when reaching the last partial utterance, it may or may not have
|
83 |
+
enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
|
84 |
+
then the last partial utterance will be considered, as if we padded the audio. Otherwise,
|
85 |
+
it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
|
86 |
+
utterance, this parameter is ignored so that the function always returns at least 1 slice.
|
87 |
+
:param overlap: by how much the partial utterance should overlap. If set to 0, the partial
|
88 |
+
    utterances are entirely disjoint.
    :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
    respectively the waveform and the mel spectrogram with these slices to obtain the partial
    utterances.
    """
    assert 0 <= overlap < 1
    assert 0 < min_pad_coverage <= 1

    if rate is not None:
        samples_per_frame = int((sampling_rate * mel_window_step / 1000))
        n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
        frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
    else:
        samples_per_frame = int((sampling_rate * mel_window_step / 1000))
        n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
        frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)

    assert 0 < frame_step, "The rate is too high"
    assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \
        (sampling_rate / (samples_per_frame * partials_n_frames))

    # Compute the slices
    wav_slices, mel_slices = [], []
    steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
    for i in range(0, steps, frame_step):
        mel_range = np.array([i, i + partial_utterance_n_frames])
        wav_range = mel_range * samples_per_frame
        mel_slices.append(slice(*mel_range))
        wav_slices.append(slice(*wav_range))

    # Evaluate whether extra padding is warranted or not
    last_wav_range = wav_slices[-1]
    coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
    if coverage < min_pad_coverage and len(mel_slices) > 1:
        mel_slices = mel_slices[:-1]
        wav_slices = wav_slices[:-1]

    return wav_slices, mel_slices


def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
    """
    Computes an embedding for a single utterance.

    # TODO: handle multiple wavs to benefit from batching on GPU
    :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
    :param using_partials: if True, then the utterance is split in partial utterances of
    <partial_utterance_n_frames> frames and the utterance embedding is computed from their
    normalized average. If False, the utterance embedding is instead computed from feeding the
    entire spectrogram to the network.
    :param return_partials: if True, the partial embeddings will also be returned along with the
    wav slices that correspond to the partial embeddings.
    :param kwargs: additional arguments to compute_partial_slices()
    :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
    <return_partials> is True, the partial utterances as a numpy array of float32 of shape
    (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
    returned. If <using_partials> is simultaneously set to False, both these values will be None
    instead.
    """
    # Process the entire utterance if not using partials
    if not using_partials:
        frames = audio.wav_to_mel_spectrogram(wav)
        embed = embed_frames_batch(frames[None, ...])[0]
        if return_partials:
            return embed, None, None
        return embed

    # Compute where to split the utterance into partials and pad if necessary
    wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
    max_wave_length = wave_slices[-1].stop
    if max_wave_length >= len(wav):
        wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")

    # Split the utterance into partials
    frames = audio.wav_to_mel_spectrogram(wav)
    frames_batch = np.array([frames[s] for s in mel_slices])
    partial_embeds = embed_frames_batch(frames_batch)

    # Compute the utterance embedding from the partial embeddings
    raw_embed = np.mean(partial_embeds, axis=0)
    embed = raw_embed / np.linalg.norm(raw_embed, 2)

    if return_partials:
        return embed, partial_embeds, wave_slices
    return embed


def embed_speaker(wavs, **kwargs):
    raise NotImplementedError()


def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
    if ax is None:
        ax = plt.gca()

    if shape is None:
        height = int(np.sqrt(len(embed)))
        shape = (height, -1)
    embed = embed.reshape(shape)

    cmap = cm.get_cmap()
    mappable = ax.imshow(embed, cmap=cmap)
    cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
    sm = cm.ScalarMappable(cmap=cmap)
    sm.set_clim(*color_range)

    ax.set_xticks([]), ax.set_yticks([])
    ax.set_title(title)
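
Note: a minimal usage sketch of the embedding API above, assuming the pretrained checkpoint added in this commit (encoder/saved_models/pretrained.pt) and an example 16 kHz mono file "speaker.wav" (the file name is illustrative only):

# Sketch: compute a speaker embedding for one utterance with the functions above.
from pathlib import Path
import numpy as np
from encoder import inference as encoder

encoder.load_model(Path("encoder/saved_models/pretrained.pt"))
wav = encoder.preprocess_wav(Path("speaker.wav"))            # resample, normalize volume, trim silences
embed, partial_embeds, wav_slices = encoder.embed_utterance(wav, return_partials=True)
assert np.isclose(np.linalg.norm(embed), 1.0, atol=1e-3)     # the utterance embedding is L2-normalized
print(embed.shape, partial_embeds.shape)                     # (256,) and (n_partials, 256)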
encoder/model.py
ADDED
@@ -0,0 +1,135 @@
from encoder.params_model import *
from encoder.params_data import *
from scipy.interpolate import interp1d
from sklearn.metrics import roc_curve
from torch.nn.utils import clip_grad_norm_
from scipy.optimize import brentq
from torch import nn
import numpy as np
import torch


class SpeakerEncoder(nn.Module):
    def __init__(self, device, loss_device):
        super().__init__()
        self.loss_device = loss_device

        # Network definition
        self.lstm = nn.LSTM(input_size=mel_n_channels,
                            hidden_size=model_hidden_size,
                            num_layers=model_num_layers,
                            batch_first=True).to(device)
        self.linear = nn.Linear(in_features=model_hidden_size,
                                out_features=model_embedding_size).to(device)
        self.relu = torch.nn.ReLU().to(device)

        # Cosine similarity scaling (with fixed initial parameter values)
        self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
        self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)

        # Loss
        self.loss_fn = nn.CrossEntropyLoss().to(loss_device)

    def do_gradient_ops(self):
        # Gradient scale
        self.similarity_weight.grad *= 0.01
        self.similarity_bias.grad *= 0.01

        # Gradient clipping
        clip_grad_norm_(self.parameters(), 3, norm_type=2)

    def forward(self, utterances, hidden_init=None):
        """
        Computes the embeddings of a batch of utterance spectrograms.

        :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
        (batch_size, n_frames, n_channels)
        :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
        batch_size, hidden_size). Will default to a tensor of zeros if None.
        :return: the embeddings as a tensor of shape (batch_size, embedding_size)
        """
        # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
        # and the final cell state.
        out, (hidden, cell) = self.lstm(utterances, hidden_init)

        # We take only the hidden state of the last layer
        embeds_raw = self.relu(self.linear(hidden[-1]))

        # L2-normalize it
        embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5)

        return embeds

    def similarity_matrix(self, embeds):
        """
        Computes the similarity matrix according to section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the similarity matrix as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, speakers_per_batch)
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
        centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
        centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)

        # Exclusive centroids (1 per utterance)
        centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
        centroids_excl /= (utterances_per_speaker - 1)
        centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)

        # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
        # product of these vectors (which is just an element-wise multiplication reduced by a sum).
        # We vectorize the computation for efficiency.
        sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
                                 speakers_per_batch).to(self.loss_device)
        mask_matrix = 1 - np.eye(speakers_per_batch, dtype=int)
        for j in range(speakers_per_batch):
            mask = np.where(mask_matrix[j])[0]
            sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
            sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)

        ## Even more vectorized version (slower maybe because of transpose)
        # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
        #                           ).to(self.loss_device)
        # eye = np.eye(speakers_per_batch, dtype=int)
        # mask = np.where(1 - eye)
        # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
        # mask = np.where(eye)
        # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
        # sim_matrix2 = sim_matrix2.transpose(1, 2)

        sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
        return sim_matrix

    def loss(self, embeds):
        """
        Computes the softmax loss according to section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the loss and the EER for this batch of embeddings.
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Loss
        sim_matrix = self.similarity_matrix(embeds)
        sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
                                         speakers_per_batch))
        ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
        target = torch.from_numpy(ground_truth).long().to(self.loss_device)
        loss = self.loss_fn(sim_matrix, target)

        # EER (not backpropagated)
        with torch.no_grad():
            inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=int)[0]
            labels = np.array([inv_argmax(i) for i in ground_truth])
            preds = sim_matrix.detach().cpu().numpy()

            # Snippet from https://yangcha.github.io/EER-ROC/
            fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
            eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

        return loss, eer
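
Note: a quick shape sanity check for the GE2E similarity matrix and loss above, using random L2-normalized embeddings in place of real LSTM outputs (the batch sizes are arbitrary and only illustrate the expected tensor shapes):

# Sketch: exercise SpeakerEncoder.similarity_matrix and .loss on random embeddings.
import torch
from encoder.model import SpeakerEncoder

device = loss_device = torch.device("cpu")
model = SpeakerEncoder(device, loss_device)

speakers, utterances, embed_size = 4, 5, 256
embeds = torch.nn.functional.normalize(torch.randn(speakers, utterances, embed_size), dim=2)

sim = model.similarity_matrix(embeds)   # (speakers, utterances, speakers)
loss, eer = model.loss(embeds)          # scalar loss, EER around 0.5 for random inputs
print(sim.shape, float(loss), eer)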
encoder/params_data.py
ADDED
@@ -0,0 +1,29 @@

## Mel-filterbank
mel_window_length = 25  # In milliseconds
mel_window_step = 10    # In milliseconds
mel_n_channels = 40


## Audio
sampling_rate = 16000
# Number of spectrogram frames in a partial utterance
partials_n_frames = 160     # 1600 ms
# Number of spectrogram frames at inference
inference_n_frames = 80     # 800 ms


## Voice Activation Detection
# Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
# This sets the granularity of the VAD. Should not need to be changed.
vad_window_length = 30  # In milliseconds
# Number of frames to average together when performing the moving average smoothing.
# The larger this value, the larger the VAD variations must be to not get smoothed out.
vad_moving_average_width = 8
# Maximum number of consecutive silent frames a segment can have.
vad_max_silence_length = 6


## Audio volume normalization
audio_norm_target_dBFS = -30
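
Note: to make the relationship between these constants concrete, here is a small sketch of the frame/sample arithmetic used by compute_partial_slices; the values follow directly from the parameters above.

# Sketch: how the mel parameters translate into samples and durations.
sampling_rate = 16000
mel_window_step = 10        # ms between successive mel frames
partials_n_frames = 160     # frames per partial utterance

samples_per_frame = int(sampling_rate * mel_window_step / 1000)   # 160 samples per frame
partial_samples = partials_n_frames * samples_per_frame           # 25600 samples per partial
print(partial_samples / sampling_rate)                            # 1.6 seconds per partial utterance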
encoder/params_model.py
ADDED
@@ -0,0 +1,11 @@

## Model parameters
model_hidden_size = 256
model_embedding_size = 256
model_num_layers = 3


## Training parameters
learning_rate_init = 1e-4
speakers_per_batch = 64
utterances_per_speaker = 10
encoder/preprocess.py
ADDED
@@ -0,0 +1,184 @@
from multiprocess.pool import ThreadPool
from encoder.params_data import *
from encoder.config import librispeech_datasets, anglophone_nationalites
from datetime import datetime
from encoder import audio
from pathlib import Path
from tqdm import tqdm
import numpy as np


class DatasetLog:
    """
    Registers metadata about the dataset in a text file.
    """
    def __init__(self, root, name):
        self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
        self.sample_data = dict()

        start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
        self.write_line("Creating dataset %s on %s" % (name, start_time))
        self.write_line("-----")
        self._log_params()

    def _log_params(self):
        from encoder import params_data
        self.write_line("Parameter values:")
        for param_name in (p for p in dir(params_data) if not p.startswith("__")):
            value = getattr(params_data, param_name)
            self.write_line("\t%s: %s" % (param_name, value))
        self.write_line("-----")

    def write_line(self, line):
        self.text_file.write("%s\n" % line)

    def add_sample(self, **kwargs):
        for param_name, value in kwargs.items():
            if not param_name in self.sample_data:
                self.sample_data[param_name] = []
            self.sample_data[param_name].append(value)

    def finalize(self):
        self.write_line("Statistics:")
        for param_name, values in self.sample_data.items():
            self.write_line("\t%s:" % param_name)
            self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
            self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
        self.write_line("-----")
        end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
        self.write_line("Finished on %s" % end_time)
        self.text_file.close()


def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
    dataset_root = datasets_root.joinpath(dataset_name)
    if not dataset_root.exists():
        print("Couldn't find %s, skipping this dataset." % dataset_root)
        return None, None
    return dataset_root, DatasetLog(out_dir, dataset_name)


def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
                             skip_existing, logger):
    print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))

    # Function to preprocess utterances for one speaker
    def preprocess_speaker(speaker_dir: Path):
        # Give a name to the speaker that includes its dataset
        speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)

        # Create an output directory with that name, as well as a txt file containing a
        # reference to each source file.
        speaker_out_dir = out_dir.joinpath(speaker_name)
        speaker_out_dir.mkdir(exist_ok=True)
        sources_fpath = speaker_out_dir.joinpath("_sources.txt")

        # There's a possibility that the preprocessing was interrupted earlier, check if
        # there already is a sources file.
        if sources_fpath.exists():
            try:
                with sources_fpath.open("r") as sources_file:
                    existing_fnames = {line.split(",")[0] for line in sources_file}
            except:
                existing_fnames = {}
        else:
            existing_fnames = {}

        # Gather all audio files for that speaker recursively
        sources_file = sources_fpath.open("a" if skip_existing else "w")
        for in_fpath in speaker_dir.glob("**/*.%s" % extension):
            # Check if the target output file already exists
            out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
            out_fname = out_fname.replace(".%s" % extension, ".npy")
            if skip_existing and out_fname in existing_fnames:
                continue

            # Load and preprocess the waveform
            wav = audio.preprocess_wav(in_fpath)
            if len(wav) == 0:
                continue

            # Create the mel spectrogram, discard those that are too short
            frames = audio.wav_to_mel_spectrogram(wav)
            if len(frames) < partials_n_frames:
                continue

            out_fpath = speaker_out_dir.joinpath(out_fname)
            np.save(out_fpath, frames)
            logger.add_sample(duration=len(wav) / sampling_rate)
            sources_file.write("%s,%s\n" % (out_fname, in_fpath))

        sources_file.close()

    # Process the utterances for each speaker
    with ThreadPool(8) as pool:
        list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
                  unit="speakers"))
    logger.finalize()
    print("Done preprocessing %s.\n" % dataset_name)


def preprocess_aidatatang_200zh(datasets_root: Path, out_dir: Path, skip_existing=False):
    dataset_name = "aidatatang_200zh"
    dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
    if not dataset_root:
        return
    # Preprocess all speakers
    speaker_dirs = list(dataset_root.joinpath("corpus", "train").glob("*"))
    _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav",
                             skip_existing, logger)


def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
    for dataset_name in librispeech_datasets["train"]["other"]:
        # Initialize the preprocessing
        dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
        if not dataset_root:
            return

        # Preprocess all speakers
        speaker_dirs = list(dataset_root.glob("*"))
        _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "flac",
                                 skip_existing, logger)


def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
    # Initialize the preprocessing
    dataset_name = "VoxCeleb1"
    dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
    if not dataset_root:
        return

    # Get the contents of the meta file
    with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
        metadata = [line.split("\t") for line in metafile][1:]

    # Select the ID and the nationality, filter out non-anglophone speakers
    nationalities = {line[0]: line[3] for line in metadata}
    keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
                        nationality.lower() in anglophone_nationalites]
    print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
          (len(keep_speaker_ids), len(nationalities)))

    # Get the speaker directories for anglophone speakers only
    speaker_dirs = dataset_root.joinpath("wav").glob("*")
    speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
                    speaker_dir.name in keep_speaker_ids]
    print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
          (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))

    # Preprocess all speakers
    _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav",
                             skip_existing, logger)


def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
    # Initialize the preprocessing
    dataset_name = "VoxCeleb2"
    dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
    if not dataset_root:
        return

    # Get the speaker directories
    # Preprocess all speakers
    speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
    _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "m4a",
                             skip_existing, logger)
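
Note: a hedged sketch of calling one of these entry points directly (it assumes the repo's preprocessing dependencies such as webrtcvad are installed; the "datasets" path is an example and must contain aidatatang_200zh/corpus/train/<speaker dirs>):

# Sketch: preprocess the aidatatang_200zh training set into mel-spectrogram .npy files.
from pathlib import Path
from encoder.preprocess import preprocess_aidatatang_200zh

datasets_root = Path("datasets")                            # example path, adjust to your layout
out_dir = datasets_root.joinpath("SV2TTS", "encoder")       # same default as encoder_preprocess.py
out_dir.mkdir(exist_ok=True, parents=True)
preprocess_aidatatang_200zh(datasets_root, out_dir, skip_existing=True)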
encoder/saved_models/pretrained.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:57715adc6f36047166ab06e37b904240aee2f4d10fc88f78ed91510cf4b38666
size 17095158
encoder/train.py
ADDED
@@ -0,0 +1,123 @@
from encoder.visualizations import Visualizations
from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
from encoder.params_model import *
from encoder.model import SpeakerEncoder
from utils.profiler import Profiler
from pathlib import Path
import torch

def sync(device: torch.device):
    # For correct profiling (cuda operations are async)
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
          backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
          no_visdom: bool):
    # Create a dataset and a dataloader
    dataset = SpeakerVerificationDataset(clean_data_root)
    loader = SpeakerVerificationDataLoader(
        dataset,
        speakers_per_batch,
        utterances_per_speaker,
        num_workers=8,
    )

    # Setup the device on which to run the forward pass and the loss. These can be different,
    # because the forward pass is faster on the GPU whereas the loss is often (depending on your
    # hyperparameters) faster on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # FIXME: currently, the gradient is None if loss_device is cuda
    loss_device = torch.device("cpu")

    # Create the model and the optimizer
    model = SpeakerEncoder(device, loss_device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
    init_step = 1

    # Configure file path for the model
    state_fpath = models_dir.joinpath(run_id + ".pt")
    backup_dir = models_dir.joinpath(run_id + "_backups")

    # Load any existing model
    if not force_restart:
        if state_fpath.exists():
            print("Found existing model \"%s\", loading it and resuming training." % run_id)
            checkpoint = torch.load(state_fpath)
            init_step = checkpoint["step"]
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            optimizer.param_groups[0]["lr"] = learning_rate_init
        else:
            print("No model \"%s\" found, starting training from scratch." % run_id)
    else:
        print("Starting the training from scratch.")
    model.train()

    # Initialize the visualization environment
    vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
    vis.log_dataset(dataset)
    vis.log_params()
    device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
    vis.log_implementation({"Device": device_name})

    # Training loop
    profiler = Profiler(summarize_every=10, disabled=False)
    for step, speaker_batch in enumerate(loader, init_step):
        profiler.tick("Blocking, waiting for batch (threaded)")

        # Forward pass
        inputs = torch.from_numpy(speaker_batch.data).to(device)
        sync(device)
        profiler.tick("Data to %s" % device)
        embeds = model(inputs)
        sync(device)
        profiler.tick("Forward pass")
        embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
        loss, eer = model.loss(embeds_loss)
        sync(loss_device)
        profiler.tick("Loss")

        # Backward pass
        model.zero_grad()
        loss.backward()
        profiler.tick("Backward pass")
        model.do_gradient_ops()
        optimizer.step()
        profiler.tick("Parameter update")

        # Update visualizations
        # learning_rate = optimizer.param_groups[0]["lr"]
        vis.update(loss.item(), eer, step)

        # Draw projections and save them to the backup folder
        if umap_every != 0 and step % umap_every == 0:
            print("Drawing and saving projections (step %d)" % step)
            backup_dir.mkdir(exist_ok=True)
            projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
            embeds = embeds.detach().cpu().numpy()
            vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
            vis.save()

        # Overwrite the latest version of the model
        if save_every != 0 and step % save_every == 0:
            print("Saving the model (step %d)" % step)
            torch.save({
                "step": step + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            }, state_fpath)

        # Make a backup
        if backup_every != 0 and step % backup_every == 0:
            print("Making a backup (step %d)" % step)
            backup_dir.mkdir(exist_ok=True)
            backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step))
            torch.save({
                "step": step + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            }, backup_fpath)

        profiler.tick("Extras (visualizations, saving)")
encoder/visualizations.py
ADDED
@@ -0,0 +1,178 @@
from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
from datetime import datetime
from time import perf_counter as timer
import matplotlib.pyplot as plt
import numpy as np
# import webbrowser
import visdom
import umap

colormap = np.array([
    [76, 255, 0],
    [0, 127, 70],
    [255, 0, 0],
    [255, 217, 38],
    [0, 135, 255],
    [165, 0, 165],
    [255, 167, 255],
    [0, 255, 255],
    [255, 96, 38],
    [142, 76, 0],
    [33, 0, 127],
    [0, 0, 0],
    [183, 183, 183],
], dtype=float) / 255


class Visualizations:
    def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
        # Tracking data
        self.last_update_timestamp = timer()
        self.update_every = update_every
        self.step_times = []
        self.losses = []
        self.eers = []
        print("Updating the visualizations every %d steps." % update_every)

        # If visdom is disabled TODO: use a better paradigm for that
        self.disabled = disabled
        if self.disabled:
            return

        # Set the environment name
        now = str(datetime.now().strftime("%d-%m %Hh%M"))
        if env_name is None:
            self.env_name = now
        else:
            self.env_name = "%s (%s)" % (env_name, now)

        # Connect to visdom and open the corresponding window in the browser
        try:
            self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
        except ConnectionError:
            raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
                            "start it.")
        # webbrowser.open("http://localhost:8097/env/" + self.env_name)

        # Create the windows
        self.loss_win = None
        self.eer_win = None
        # self.lr_win = None
        self.implementation_win = None
        self.projection_win = None
        self.implementation_string = ""

    def log_params(self):
        if self.disabled:
            return
        from encoder import params_data
        from encoder import params_model
        param_string = "<b>Model parameters</b>:<br>"
        for param_name in (p for p in dir(params_model) if not p.startswith("__")):
            value = getattr(params_model, param_name)
            param_string += "\t%s: %s<br>" % (param_name, value)
        param_string += "<b>Data parameters</b>:<br>"
        for param_name in (p for p in dir(params_data) if not p.startswith("__")):
            value = getattr(params_data, param_name)
            param_string += "\t%s: %s<br>" % (param_name, value)
        self.vis.text(param_string, opts={"title": "Parameters"})

    def log_dataset(self, dataset: SpeakerVerificationDataset):
        if self.disabled:
            return
        dataset_string = ""
        dataset_string += "<b>Speakers</b>: %s\n" % len(dataset.speakers)
        dataset_string += "\n" + dataset.get_logs()
        dataset_string = dataset_string.replace("\n", "<br>")
        self.vis.text(dataset_string, opts={"title": "Dataset"})

    def log_implementation(self, params):
        if self.disabled:
            return
        implementation_string = ""
        for param, value in params.items():
            implementation_string += "<b>%s</b>: %s\n" % (param, value)
            implementation_string = implementation_string.replace("\n", "<br>")
        self.implementation_string = implementation_string
        self.implementation_win = self.vis.text(
            implementation_string,
            opts={"title": "Training implementation"}
        )

    def update(self, loss, eer, step):
        # Update the tracking data
        now = timer()
        self.step_times.append(1000 * (now - self.last_update_timestamp))
        self.last_update_timestamp = now
        self.losses.append(loss)
        self.eers.append(eer)
        print(".", end="")

        # Update the plots every <update_every> steps
        if step % self.update_every != 0:
            return
        time_string = "Step time: mean: %5dms std: %5dms" % \
                      (int(np.mean(self.step_times)), int(np.std(self.step_times)))
        print("\nStep %6d Loss: %.4f EER: %.4f %s" %
              (step, np.mean(self.losses), np.mean(self.eers), time_string))
        if not self.disabled:
            self.loss_win = self.vis.line(
                [np.mean(self.losses)],
                [step],
                win=self.loss_win,
                update="append" if self.loss_win else None,
                opts=dict(
                    legend=["Avg. loss"],
                    xlabel="Step",
                    ylabel="Loss",
                    title="Loss",
                )
            )
            self.eer_win = self.vis.line(
                [np.mean(self.eers)],
                [step],
                win=self.eer_win,
                update="append" if self.eer_win else None,
                opts=dict(
                    legend=["Avg. EER"],
                    xlabel="Step",
                    ylabel="EER",
                    title="Equal error rate"
                )
            )
            if self.implementation_win is not None:
                self.vis.text(
                    self.implementation_string + ("<b>%s</b>" % time_string),
                    win=self.implementation_win,
                    opts={"title": "Training implementation"},
                )

        # Reset the tracking
        self.losses.clear()
        self.eers.clear()
        self.step_times.clear()

    def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None,
                         max_speakers=10):
        max_speakers = min(max_speakers, len(colormap))
        embeds = embeds[:max_speakers * utterances_per_speaker]

        n_speakers = len(embeds) // utterances_per_speaker
        ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
        colors = [colormap[i] for i in ground_truth]

        reducer = umap.UMAP()
        projected = reducer.fit_transform(embeds)
        plt.scatter(projected[:, 0], projected[:, 1], c=colors)
        plt.gca().set_aspect("equal", "datalim")
        plt.title("UMAP projection (step %d)" % step)
        if not self.disabled:
            self.projection_win = self.vis.matplot(plt, win=self.projection_win)
        if out_fpath is not None:
            plt.savefig(out_fpath)
        plt.clf()

    def save(self):
        if not self.disabled:
            self.vis.save([self.env_name])
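
Note: a small usage sketch of the class above with visdom disabled (assuming the repo's requirements such as visdom and umap-learn are installed so the module imports succeed); in this mode it only aggregates and prints loss/EER statistics every update_every steps:

# Sketch: use Visualizations without a running visdom server.
from encoder.visualizations import Visualizations

vis = Visualizations(env_name="demo_run", update_every=10, disabled=True)
for step in range(1, 31):
    vis.update(loss=1.0 / step, eer=0.5 / step, step=step)   # prints aggregated stats at steps 10, 20, 30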
encoder_preprocess.py
ADDED
@@ -0,0 +1,61 @@
from encoder.preprocess import preprocess_librispeech, preprocess_voxceleb1, preprocess_voxceleb2, preprocess_aidatatang_200zh
from utils.argutils import print_args
from pathlib import Path
import argparse

if __name__ == "__main__":
    class MyFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
        pass

    parser = argparse.ArgumentParser(
        description="Preprocesses audio files from datasets, encodes them as mel spectrograms and "
                    "writes them to the disk. This will allow you to train the encoder. The "
                    "datasets required are at least one of LibriSpeech, VoxCeleb1, VoxCeleb2, aidatatang_200zh. ",
        formatter_class=MyFormatter
    )
    parser.add_argument("datasets_root", type=Path, help=\
        "Path to the directory containing your LibriSpeech/TTS and VoxCeleb datasets.")
    parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
        "Path to the output directory that will contain the mel spectrograms. If left out, "
        "defaults to <datasets_root>/SV2TTS/encoder/")
    parser.add_argument("-d", "--datasets", type=str,
                        default="librispeech_other,voxceleb1,aidatatang_200zh", help=\
        "Comma-separated list of the name of the datasets you want to preprocess. Only the train "
        "set of these datasets will be used. Possible names: librispeech_other, voxceleb1, "
        "voxceleb2, aidatatang_200zh.")
    parser.add_argument("-s", "--skip_existing", action="store_true", help=\
        "Whether to skip existing output files with the same name. Useful if this script was "
        "interrupted.")
    parser.add_argument("--no_trim", action="store_true", help=\
        "Preprocess audio without trimming silences (not recommended).")
    args = parser.parse_args()

    # Verify webrtcvad is available
    if not args.no_trim:
        try:
            import webrtcvad
        except:
            raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables "
                "noise removal and is recommended. Please install and try again. If installation fails, "
                "use --no_trim to disable this error message.")
    del args.no_trim

    # Process the arguments
    args.datasets = args.datasets.split(",")
    if not hasattr(args, "out_dir"):
        args.out_dir = args.datasets_root.joinpath("SV2TTS", "encoder")
    assert args.datasets_root.exists()
    args.out_dir.mkdir(exist_ok=True, parents=True)

    # Preprocess the datasets
    print_args(args, parser)
    preprocess_func = {
        "librispeech_other": preprocess_librispeech,
        "voxceleb1": preprocess_voxceleb1,
        "voxceleb2": preprocess_voxceleb2,
        "aidatatang_200zh": preprocess_aidatatang_200zh,
    }
    args = vars(args)
    for dataset in args.pop("datasets"):
        print("Preprocessing %s" % dataset)
        preprocess_func[dataset](**args)
encoder_train.py
ADDED
@@ -0,0 +1,47 @@
from utils.argutils import print_args
from encoder.train import train
from pathlib import Path
import argparse


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Trains the speaker encoder. You must have run encoder_preprocess.py first.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument("run_id", type=str, help= \
        "Name for this model instance. If a model state from the same run ID was previously "
        "saved, the training will restart from there. Pass -f to overwrite saved states and "
        "restart from scratch.")
    parser.add_argument("clean_data_root", type=Path, help= \
        "Path to the output directory of encoder_preprocess.py. If you left the default "
        "output directory when preprocessing, it should be <datasets_root>/SV2TTS/encoder/.")
    parser.add_argument("-m", "--models_dir", type=Path, default="encoder/saved_models/", help=\
        "Path to the output directory that will contain the saved model weights, as well as "
        "backups of those weights and plots generated during training.")
    parser.add_argument("-v", "--vis_every", type=int, default=10, help= \
        "Number of steps between updates of the loss and the plots.")
    parser.add_argument("-u", "--umap_every", type=int, default=100, help= \
        "Number of steps between updates of the umap projection. Set to 0 to never update the "
        "projections.")
    parser.add_argument("-s", "--save_every", type=int, default=500, help= \
        "Number of steps between updates of the model on the disk. Set to 0 to never save the "
        "model.")
    parser.add_argument("-b", "--backup_every", type=int, default=7500, help= \
        "Number of steps between backups of the model. Set to 0 to never make backups of the "
        "model.")
    parser.add_argument("-f", "--force_restart", action="store_true", help= \
        "Do not load any saved model.")
    parser.add_argument("--visdom_server", type=str, default="http://localhost")
    parser.add_argument("--no_visdom", action="store_true", help= \
        "Disable visdom.")
    args = parser.parse_args()

    # Process the arguments
    args.models_dir.mkdir(exist_ok=True)

    # Run the training
    print_args(args, parser)
    train(**vars(args))
gen_voice.py
ADDED
@@ -0,0 +1,128 @@
from encoder.params_model import model_embedding_size as speaker_embedding_size
from utils.argutils import print_args
from utils.modelutils import check_model_paths
from synthesizer.inference import Synthesizer
from encoder import inference as encoder
from vocoder.wavernn import inference as rnn_vocoder
from vocoder.hifigan import inference as gan_vocoder
from pathlib import Path
import numpy as np
import soundfile as sf
import librosa
import argparse
import torch
import sys
import os
import re
import cn2an
import glob

from audioread.exceptions import NoBackendError
vocoder = gan_vocoder

def gen_one_wav(synthesizer, in_fpath, embed, texts, file_name, seq):
    embeds = [embed] * len(texts)
    # If you know what the attention layer alignments are, you can retrieve them here by
    # passing return_alignments=True
    specs = synthesizer.synthesize_spectrograms(texts, embeds, style_idx=-1, min_stop_token=4, steps=400)
    #spec = specs[0]
    breaks = [spec.shape[1] for spec in specs]
    spec = np.concatenate(specs, axis=1)

    # If seed is specified, reset torch seed and reload vocoder
    # Synthesizing the waveform is fairly straightforward. Remember that the longer the
    # spectrogram, the more time-efficient the vocoder.
    generated_wav, output_sample_rate = vocoder.infer_waveform(spec)

    # Add breaks
    b_ends = np.cumsum(np.array(breaks) * synthesizer.hparams.hop_size)
    b_starts = np.concatenate(([0], b_ends[:-1]))
    wavs = [generated_wav[start:end] for start, end in zip(b_starts, b_ends)]
    breaks = [np.zeros(int(0.15 * synthesizer.sample_rate))] * len(breaks)
    generated_wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])

    ## Post-generation
    # There's a bug with sounddevice that makes the audio cut one second earlier, so we
    # pad it.

    # Trim excess silences to compensate for gaps in spectrograms (issue #53)
    generated_wav = encoder.preprocess_wav(generated_wav)
    generated_wav = generated_wav / np.abs(generated_wav).max() * 0.97

    # Save it on the disk
    model = os.path.basename(in_fpath)
    filename = "%s_%d_%s.wav" % (file_name, seq, model)
    sf.write(filename, generated_wav, synthesizer.sample_rate)

    print("\nSaved output as %s\n\n" % filename)


def generate_wav(enc_model_fpath, syn_model_fpath, voc_model_fpath, in_fpath, input_txt, file_name):
    if torch.cuda.is_available():
        device_id = torch.cuda.current_device()
        gpu_properties = torch.cuda.get_device_properties(device_id)
        ## Print some environment information (for debugging purposes)
        print("Found %d GPUs available. Using GPU %d (%s) of compute capability %d.%d with "
              "%.1fGb total memory.\n" %
              (torch.cuda.device_count(),
               device_id,
               gpu_properties.name,
               gpu_properties.major,
               gpu_properties.minor,
               gpu_properties.total_memory / 1e9))
    else:
        print("Using CPU for inference.\n")

    print("Preparing the encoder, the synthesizer and the vocoder...")
    encoder.load_model(enc_model_fpath)
    synthesizer = Synthesizer(syn_model_fpath)
    vocoder.load_model(voc_model_fpath)

    encoder_wav = synthesizer.load_preprocess_wav(in_fpath)
    embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)

    texts = input_txt.split("\n")
    seq = 0
    each_num = 1500

    punctuation = '!,。、,'  # punctuate and split/clean text
    processed_texts = []
    cur_num = 0
    for text in texts:
        for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'):
            if processed_text:
                processed_texts.append(processed_text.strip())
                cur_num += len(processed_text.strip())
                if cur_num > each_num:
                    seq = seq + 1
                    gen_one_wav(synthesizer, in_fpath, embed, processed_texts, file_name, seq)
                    processed_texts = []
                    cur_num = 0

    if len(processed_texts) > 0:
        seq = seq + 1
        gen_one_wav(synthesizer, in_fpath, embed, processed_texts, file_name, seq)

if len(sys.argv) >= 3:
    my_txt = ""
    print("reading from :", sys.argv[1])
    with open(sys.argv[1], "r") as f:
        for line in f.readlines():
            #line = line.strip('\n')
            my_txt += line
    txt_file_name = sys.argv[1]
    wav_file_name = sys.argv[2]

    output = cn2an.transform(my_txt, "an2cn")
    print(output)
    generate_wav(
        Path("encoder/saved_models/pretrained.pt"),
        Path("synthesizer/saved_models/mandarin.pt"),
        Path("vocoder/saved_models/pretrained/g_hifigan.pt"), wav_file_name, output, txt_file_name
    )

else:
    print("please input the file name")
    exit(1)
mkgui/__init__.py
ADDED
File without changes
mkgui/app.py
ADDED
@@ -0,0 +1,145 @@
from pydantic import BaseModel, Field
import os
from pathlib import Path
from enum import Enum
from encoder import inference as encoder
import librosa
from scipy.io.wavfile import write
import re
import numpy as np
from mkgui.base.components.types import FileContent
from vocoder.hifigan import inference as gan_vocoder
from synthesizer.inference import Synthesizer
from typing import Any, Tuple
import matplotlib.pyplot as plt

# Constants
AUDIO_SAMPLES_DIR = f"samples{os.sep}"
SYN_MODELS_DIRT = f"synthesizer{os.sep}saved_models"
ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"
VOC_MODELS_DIRT = f"vocoder{os.sep}saved_models"
TEMP_SOURCE_AUDIO = f"wavs{os.sep}temp_source.wav"
TEMP_RESULT_AUDIO = f"wavs{os.sep}temp_result.wav"
if not os.path.isdir("wavs"):
    os.makedirs("wavs")

# Load local sample audio as options TODO: load dataset
if os.path.isdir(AUDIO_SAMPLES_DIR):
    audio_input_selection = Enum('samples', list((file.name, file) for file in Path(AUDIO_SAMPLES_DIR).glob("*.wav")))
# Pre-Load models
if os.path.isdir(SYN_MODELS_DIRT):
    synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded synthesizer models: " + str(len(synthesizers)))
else:
    raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist.")

if os.path.isdir(ENC_MODELS_DIRT):
    encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded encoder models: " + str(len(encoders)))
else:
    raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")

if os.path.isdir(VOC_MODELS_DIRT):
    vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt")))
    print("Loaded vocoder models: " + str(len(vocoders)))
else:
    raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")


class Input(BaseModel):
    message: str = Field(
        ..., example="欢迎使用工具箱, 现已支持中文输入!", alias="文本内容"
    )
    local_audio_file: audio_input_selection = Field(
        ..., alias="输入语音(本地wav)",
        description="选择本地语音文件."
    )
    upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
        description="拖拽或点击上传.", mime_type="audio/wav")
    encoder: encoders = Field(
        ..., alias="编码模型",
        description="选择语音编码模型文件."
    )
    synthesizer: synthesizers = Field(
        ..., alias="合成模型",
        description="选择语音合成模型文件."
    )
    vocoder: vocoders = Field(
        ..., alias="语音解码模型",
        description="选择语音解码模型文件(目前只支持HifiGan类型)."
    )

class AudioEntity(BaseModel):
    content: bytes
    mel: Any

class Output(BaseModel):
    __root__: Tuple[AudioEntity, AudioEntity]

    def render_output_ui(self, streamlit_app, input) -> None:  # type: ignore
        """Custom output UI.
        If this method is implemented, it will be used instead of the default Output UI renderer.
        """
        src, result = self.__root__

        streamlit_app.subheader("Synthesized Audio")
        streamlit_app.audio(result.content, format="audio/wav")

        fig, ax = plt.subplots()
        ax.imshow(src.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram(Source Audio)")
        streamlit_app.pyplot(fig)
        fig, ax = plt.subplots()
        ax.imshow(result.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram(Result Audio)")
        streamlit_app.pyplot(fig)


def synthesize(input: Input) -> Output:
    """synthesize(合成)"""
    # load models
    encoder.load_model(Path(input.encoder.value))
    current_synt = Synthesizer(Path(input.synthesizer.value))
    gan_vocoder.load_model(Path(input.vocoder.value))

    # load file
    if input.upload_audio_file is not None:
        with open(TEMP_SOURCE_AUDIO, "w+b") as f:
            f.write(input.upload_audio_file.as_bytes())
            f.seek(0)
        wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
    else:
        wav, sample_rate = librosa.load(input.local_audio_file.value)
        write(TEMP_SOURCE_AUDIO, sample_rate, wav)  # Make sure we get the correct wav

    source_spec = Synthesizer.make_spectrogram(wav)

    # preprocess
    encoder_wav = encoder.preprocess_wav(wav, sample_rate)
    embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True)

    # Load input text
    texts = filter(None, input.message.split("\n"))
    punctuation = '!,。、,'  # punctuate and split/clean text
    processed_texts = []
    for text in texts:
        for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'):
            if processed_text:
                processed_texts.append(processed_text.strip())
    texts = processed_texts

    # synthesize and vocode
    embeds = [embed] * len(texts)
    specs = current_synt.synthesize_spectrograms(texts, embeds)
    spec = np.concatenate(specs, axis=1)
    sample_rate = Synthesizer.sample_rate
    wav, sample_rate = gan_vocoder.infer_waveform(spec)

    # write and output
    write(TEMP_RESULT_AUDIO, sample_rate, wav)  # Make sure we get the correct wav
    with open(TEMP_SOURCE_AUDIO, "rb") as f:
        source_file = f.read()
    with open(TEMP_RESULT_AUDIO, "rb") as f:
        result_file = f.read()
    return Output(__root__=(AudioEntity(content=source_file, mel=source_spec), AudioEntity(content=result_file, mel=spec)))
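
Note: the Input/Output models and the synthesize() callable above are wired into a web UI elsewhere in the repo. As a hedged sketch only, assuming the Opyrator wrapper exported from mkgui/base (shown further below) accepts the dotted import path of such a pydantic-typed callable; the actual launcher wiring is not part of this excerpt:

# Hedged sketch, not the repo's actual launcher code.
from mkgui.base import Opyrator

ui_function = Opyrator("mkgui.app.synthesize")  # assumption: an import path to the callable is accepted here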
mkgui/app_vc.py
ADDED
@@ -0,0 +1,166 @@
from synthesizer.inference import Synthesizer
from pydantic import BaseModel, Field
from encoder import inference as speacker_encoder
import torch
import os
from pathlib import Path
from enum import Enum
import ppg_extractor as Extractor
import ppg2mel as Convertor
import librosa
from scipy.io.wavfile import write
import re
import numpy as np
from mkgui.base.components.types import FileContent
from vocoder.hifigan import inference as gan_vocoder
from typing import Any, Tuple
import matplotlib.pyplot as plt


# Constants
AUDIO_SAMPLES_DIR = f'sample{os.sep}'
EXT_MODELS_DIRT = f'ppg_extractor{os.sep}saved_models'
CONV_MODELS_DIRT = f'ppg2mel{os.sep}saved_models'
VOC_MODELS_DIRT = f'vocoder{os.sep}saved_models'
TEMP_SOURCE_AUDIO = f'wavs{os.sep}temp_source.wav'
TEMP_TARGET_AUDIO = f'wavs{os.sep}temp_target.wav'
TEMP_RESULT_AUDIO = f'wavs{os.sep}temp_result.wav'

# Load local sample audio as options TODO: load dataset
if os.path.isdir(AUDIO_SAMPLES_DIR):
    audio_input_selection = Enum('samples', list((file.name, file) for file in Path(AUDIO_SAMPLES_DIR).glob("*.wav")))
# Pre-Load models
if os.path.isdir(EXT_MODELS_DIRT):
    extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded extractor models: " + str(len(extractors)))
else:
    raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")

if os.path.isdir(CONV_MODELS_DIRT):
    convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
    print("Loaded convertor models: " + str(len(convertors)))
else:
    raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")

if os.path.isdir(VOC_MODELS_DIRT):
    vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt")))
    print("Loaded vocoder models: " + str(len(vocoders)))
else:
    raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")

class Input(BaseModel):
    local_audio_file: audio_input_selection = Field(
        ..., alias="输入语音(本地wav)",
        description="选择本地语音文件."
    )
    upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
        description="拖拽或点击上传.", mime_type="audio/wav")
    local_audio_file_target: audio_input_selection = Field(
        ..., alias="目标语音(本地wav)",
        description="选择本地语音文件."
    )
    upload_audio_file_target: FileContent = Field(default=None, alias="或上传目标语音",
        description="拖拽或点击上传.", mime_type="audio/wav")
    extractor: extractors = Field(
        ..., alias="编码模型",
        description="选择语音编码模型文件."
    )
    convertor: convertors = Field(
        ..., alias="转换模型",
        description="选择语音转换模型文件."
    )
    vocoder: vocoders = Field(
        ..., alias="语音解码模型",
        description="选择语音解码模型文件(目前只支持HifiGan类型)."
    )

class AudioEntity(BaseModel):
    content: bytes
    mel: Any

class Output(BaseModel):
    __root__: Tuple[AudioEntity, AudioEntity, AudioEntity]

    def render_output_ui(self, streamlit_app, input) -> None:  # type: ignore
        """Custom output UI.
        If this method is implemented, it will be used instead of the default Output UI renderer.
        """
        src, target, result = self.__root__

        streamlit_app.subheader("Synthesized Audio")
        streamlit_app.audio(result.content, format="audio/wav")

        fig, ax = plt.subplots()
        ax.imshow(src.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram(Source Audio)")
        streamlit_app.pyplot(fig)
        fig, ax = plt.subplots()
        ax.imshow(target.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram(Target Audio)")
        streamlit_app.pyplot(fig)
        fig, ax = plt.subplots()
        ax.imshow(result.mel, aspect="equal", interpolation="none")
        ax.set_title("mel spectrogram(Result Audio)")
        streamlit_app.pyplot(fig)

def convert(input: Input) -> Output:
    """convert(转换)"""
    # load models
    extractor = Extractor.load_model(Path(input.extractor.value))
    convertor = Convertor.load_model(Path(input.convertor.value))
    # current_synt = Synthesizer(Path(input.synthesizer.value))
    gan_vocoder.load_model(Path(input.vocoder.value))

    # load file
    if input.upload_audio_file is not None:
        with open(TEMP_SOURCE_AUDIO, "w+b") as f:
            f.write(input.upload_audio_file.as_bytes())
            f.seek(0)
        src_wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
    else:
        src_wav, sample_rate = librosa.load(input.local_audio_file.value)
        write(TEMP_SOURCE_AUDIO, sample_rate, src_wav)  # Make sure we get the correct wav

    if input.upload_audio_file_target is not None:
        with open(TEMP_TARGET_AUDIO, "w+b") as f:
            f.write(input.upload_audio_file_target.as_bytes())
            f.seek(0)
        ref_wav, _ = librosa.load(TEMP_TARGET_AUDIO)
    else:
        ref_wav, _ = librosa.load(input.local_audio_file_target.value)
        write(TEMP_TARGET_AUDIO, sample_rate, ref_wav)  # Make sure we get the correct wav

    ppg = extractor.extract_from_wav(src_wav)
    # Import necessary dependency of Voice Conversion
    from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
    ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
    speacker_encoder.load_model(Path(f"encoder{os.sep}saved_models{os.sep}pretrained_bak_5805000.pt"))
    embed = speacker_encoder.embed_utterance(ref_wav)
    lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
    min_len = min(ppg.shape[1], len(lf0_uv))
    ppg = ppg[:, :min_len]
    lf0_uv = lf0_uv[:min_len]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _, mel_pred, att_ws = convertor.inference(
        ppg,
        logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
        spembs=torch.from_numpy(embed).unsqueeze(0).to(device),
    )
    mel_pred = mel_pred.transpose(0, 1)
    breaks = [mel_pred.shape[1]]
    mel_pred = mel_pred.detach().cpu().numpy()

    # synthesize and vocode
    wav, sample_rate = gan_vocoder.infer_waveform(mel_pred)

    # write and output
    write(TEMP_RESULT_AUDIO, sample_rate, wav)  # Make sure we get the correct wav
    with open(TEMP_SOURCE_AUDIO, "rb") as f:
        source_file = f.read()
    with open(TEMP_TARGET_AUDIO, "rb") as f:
        target_file = f.read()
    with open(TEMP_RESULT_AUDIO, "rb") as f:
        result_file = f.read()

    return Output(__root__=(AudioEntity(content=source_file, mel=Synthesizer.make_spectrogram(src_wav)), AudioEntity(content=target_file, mel=Synthesizer.make_spectrogram(ref_wav)), AudioEntity(content=result_file, mel=Synthesizer.make_spectrogram(wav))))
mkgui/base/__init__.py
ADDED
@@ -0,0 +1,2 @@

from .core import Opyrator
mkgui/base/api/__init__.py
ADDED
@@ -0,0 +1 @@
from .fastapi_app import create_api
mkgui/base/api/fastapi_utils.py
ADDED
@@ -0,0 +1,102 @@
"""Collection of utilities for FastAPI apps."""

import inspect
from typing import Any, Type

from fastapi import FastAPI, Form
from pydantic import BaseModel


def as_form(cls: Type[BaseModel]) -> Any:
    """Adds an as_form class method to decorated models.

    The as_form class method can be used with FastAPI endpoints.
    """
    new_params = [
        inspect.Parameter(
            field.alias,
            inspect.Parameter.POSITIONAL_ONLY,
            default=(Form(field.default) if not field.required else Form(...)),
        )
        for field in cls.__fields__.values()
    ]

    async def _as_form(**data):  # type: ignore
        return cls(**data)

    sig = inspect.signature(_as_form)
    sig = sig.replace(parameters=new_params)
    _as_form.__signature__ = sig  # type: ignore
    setattr(cls, "as_form", _as_form)
    return cls


def patch_fastapi(app: FastAPI) -> None:
    """Patch function to allow relative url resolution.

    This patch is required to make fastapi fully functional with a relative url path.
    This code snippet can be copy-pasted to any FastAPI application.
    """
    from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
    from starlette.requests import Request
    from starlette.responses import HTMLResponse

    async def redoc_ui_html(req: Request) -> HTMLResponse:
        assert app.openapi_url is not None
        redoc_ui = get_redoc_html(
            openapi_url="./" + app.openapi_url.lstrip("/"),
            title=app.title + " - Redoc UI",
        )

        return HTMLResponse(redoc_ui.body.decode("utf-8"))

    async def swagger_ui_html(req: Request) -> HTMLResponse:
        assert app.openapi_url is not None
        swagger_ui = get_swagger_ui_html(
            openapi_url="./" + app.openapi_url.lstrip("/"),
            title=app.title + " - Swagger UI",
            oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
        )

        # insert request interceptor to have all requests run on a relative path
        request_interceptor = (
            "requestInterceptor: (e) => {"
            "\n\t\t\tvar url = window.location.origin + window.location.pathname"
            '\n\t\t\turl = url.substring( 0, url.lastIndexOf( "/" ) + 1);'
            "\n\t\t\turl = e.url.replace(/http(s)?:\/\/[^/]*\//i, url);"  # noqa: W605
            "\n\t\t\te.contextUrl = url"
            "\n\t\t\te.url = url"
            "\n\t\t\treturn e;}"
        )

        return HTMLResponse(
            swagger_ui.body.decode("utf-8").replace(
                "dom_id: '#swagger-ui',",
                "dom_id: '#swagger-ui',\n\t\t" + request_interceptor + ",",
            )
        )

    # remove old docs route and add our patched route
    routes_new = []
    for app_route in app.routes:
        if app_route.path == "/docs":  # type: ignore
            continue

        if app_route.path == "/redoc":  # type: ignore
            continue

        routes_new.append(app_route)

    app.router.routes = routes_new

    assert app.docs_url is not None
    app.add_route(app.docs_url, swagger_ui_html, include_in_schema=False)
    assert app.redoc_url is not None
    app.add_route(app.redoc_url, redoc_ui_html, include_in_schema=False)

    # Make graphql relative
    from starlette import graphql

    graphql.GRAPHIQL = graphql.GRAPHIQL.replace(
        "({{REQUEST_PATH}}", '("." + {{REQUEST_PATH}}'
    )
mkgui/base/components/__init__.py
ADDED
File without changes
mkgui/base/components/outputs.py
ADDED
@@ -0,0 +1,43 @@
from typing import List

from pydantic import BaseModel


class ScoredLabel(BaseModel):
    label: str
    score: float


class ClassificationOutput(BaseModel):
    __root__: List[ScoredLabel]

    def __iter__(self):  # type: ignore
        return iter(self.__root__)

    def __getitem__(self, item):  # type: ignore
        return self.__root__[item]

    def render_output_ui(self, streamlit) -> None:  # type: ignore
        import plotly.express as px

        sorted_predictions = sorted(
            [prediction.dict() for prediction in self.__root__],
            key=lambda k: k["score"],
        )

        num_labels = len(sorted_predictions)
        if len(sorted_predictions) > 10:
            num_labels = streamlit.slider(
                "Maximum labels to show: ",
                min_value=1,
                max_value=len(sorted_predictions),
                value=len(sorted_predictions),
            )
        fig = px.bar(
            sorted_predictions[len(sorted_predictions) - num_labels :],
            x="score",
            y="label",
            orientation="h",
        )
        streamlit.plotly_chart(fig, use_container_width=True)
        # fig.show()
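
A small sketch, not part of the commit, showing how `ClassificationOutput` can be instantiated; inside a Streamlit app, `render_output_ui` would then draw the horizontal bar chart defined above.

from mkgui.base.components.outputs import ClassificationOutput, ScoredLabel

predictions = ClassificationOutput(
    __root__=[
        ScoredLabel(label="cat", score=0.92),
        ScoredLabel(label="dog", score=0.07),
        ScoredLabel(label="bird", score=0.01),
    ]
)
# The model behaves like a list thanks to __iter__ / __getitem__:
best = max(predictions, key=lambda p: p.score)
print(best.label, best.score)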
mkgui/base/components/types.py
ADDED
@@ -0,0 +1,46 @@
import base64
from typing import Any, Dict, overload


class FileContent(str):
    def as_bytes(self) -> bytes:
        return base64.b64decode(self, validate=True)

    def as_str(self) -> str:
        return self.as_bytes().decode()

    @classmethod
    def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
        field_schema.update(format="byte")

    @classmethod
    def __get_validators__(cls) -> Any:  # type: ignore
        yield cls.validate

    @classmethod
    def validate(cls, value: Any) -> "FileContent":
        if isinstance(value, FileContent):
            return value
        elif isinstance(value, str):
            return FileContent(value)
        elif isinstance(value, (bytes, bytearray, memoryview)):
            return FileContent(base64.b64encode(value).decode())
        else:
            raise Exception("Wrong type")

# # Not usable for now, because browsers offer no way to select a folder
# class DirectoryContent(FileContent):
#     @classmethod
#     def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
#         field_schema.update(format="path")

#     @classmethod
#     def validate(cls, value: Any) -> "DirectoryContent":
#         if isinstance(value, DirectoryContent):
#             return value
#         elif isinstance(value, str):
#             return DirectoryContent(value)
#         elif isinstance(value, (bytes, bytearray, memoryview)):
#             return DirectoryContent(base64.b64encode(value).decode())
#         else:
#             raise Exception("Wrong type")
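
A short round-trip sketch for the `FileContent` type above, illustrative only and assuming the repository root is on PYTHONPATH: raw bytes are base64-encoded by `validate()` and recovered with `as_bytes()`.

from pydantic import BaseModel

from mkgui.base.components.types import FileContent

class Payload(BaseModel):
    attachment: FileContent

payload = Payload(attachment=b"hello wav bytes")  # bytes are base64-encoded on validation
assert payload.attachment.as_bytes() == b"hello wav bytes"
print(payload.attachment)  # the base64 string stored in the model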
mkgui/base/core.py
ADDED
@@ -0,0 +1,203 @@
import importlib
import inspect
import re
from typing import Any, Callable, Type, Union, get_type_hints

from pydantic import BaseModel, parse_raw_as
from pydantic.tools import parse_obj_as


def name_to_title(name: str) -> str:
    """Converts a camelCase or snake_case name to title case."""
    # If camelCase -> convert to snake case
    name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
    name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
    # Convert to title case
    return name.replace("_", " ").strip().title()


def is_compatible_type(type: Type) -> bool:
    """Returns `True` if the type is opyrator-compatible."""
    try:
        if issubclass(type, BaseModel):
            return True
    except Exception:
        pass

    try:
        # valid list type
        if type.__origin__ is list and issubclass(type.__args__[0], BaseModel):
            return True
    except Exception:
        pass

    return False


def get_input_type(func: Callable) -> Type:
    """Returns the input type of a given function (callable).

    Args:
        func: The function for which to get the input type.

    Raises:
        ValueError: If the function does not have a valid input type annotation.
    """
    type_hints = get_type_hints(func)

    if "input" not in type_hints:
        raise ValueError(
            "The callable MUST have a parameter with the name `input` with typing annotation. "
            "For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
        )

    input_type = type_hints["input"]

    if not is_compatible_type(input_type):
        raise ValueError(
            "The `input` parameter MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
        )

    # TODO: return warning if more than one input parameter

    return input_type


def get_output_type(func: Callable) -> Type:
    """Returns the output type of a given function (callable).

    Args:
        func: The function for which to get the output type.

    Raises:
        ValueError: If the function does not have a valid output type annotation.
    """
    type_hints = get_type_hints(func)
    if "return" not in type_hints:
        raise ValueError(
            "The return type of the callable MUST be annotated with type hints. "
            "For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
        )

    output_type = type_hints["return"]

    if not is_compatible_type(output_type):
        raise ValueError(
            "The return value MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
        )

    return output_type


def get_callable(import_string: str) -> Callable:
    """Import a callable from a string."""
    callable_seperator = ":"
    if callable_seperator not in import_string:
        # Use dot as separator
        callable_seperator = "."

    if callable_seperator not in import_string:
        raise ValueError("The callable path MUST specify the function. ")

    mod_name, callable_name = import_string.rsplit(callable_seperator, 1)
    mod = importlib.import_module(mod_name)
    return getattr(mod, callable_name)


class Opyrator:
    def __init__(self, func: Union[Callable, str]) -> None:
        if isinstance(func, str):
            # Try to load the function from a string notation
            self.function = get_callable(func)
        else:
            self.function = func

        self._action = "Execute"
        self._input_type = None
        self._output_type = None

        if not callable(self.function):
            raise ValueError("The provided function parameter is not a callable.")

        if inspect.isclass(self.function):
            raise ValueError(
                "The provided callable is an uninitialized Class. This is not allowed."
            )

        if inspect.isfunction(self.function):
            # The provided callable is a function
            self._input_type = get_input_type(self.function)
            self._output_type = get_output_type(self.function)

            try:
                # Get name
                self._name = name_to_title(self.function.__name__)
            except Exception:
                pass

            try:
                # Get description from function
                doc_string = inspect.getdoc(self.function)
                if doc_string:
                    self._action = doc_string
            except Exception:
                pass
        elif hasattr(self.function, "__call__"):
            # The provided callable is a class instance with a __call__ method
            self._input_type = get_input_type(self.function.__call__)  # type: ignore
            self._output_type = get_output_type(self.function.__call__)  # type: ignore

            try:
                # Get name
                self._name = name_to_title(type(self.function).__name__)
            except Exception:
                pass

            try:
                # Get action from the __call__ docstring
                doc_string = inspect.getdoc(self.function.__call__)  # type: ignore
                if doc_string:
                    self._action = doc_string

                if (
                    not self._action
                    or self._action == "Call"
                ):
                    # Get docstring from class instead of __call__ function
                    doc_string = inspect.getdoc(self.function)
                    if doc_string:
                        self._action = doc_string
            except Exception:
                pass
        else:
            raise ValueError("Unknown callable type.")

    @property
    def name(self) -> str:
        return self._name

    @property
    def action(self) -> str:
        return self._action

    @property
    def input_type(self) -> Any:
        return self._input_type

    @property
    def output_type(self) -> Any:
        return self._output_type

    def __call__(self, input: Any, **kwargs: Any) -> Any:

        input_obj = input

        if isinstance(input, str):
            # Allow json input
            input_obj = parse_raw_as(self.input_type, input)

        if isinstance(input, dict):
            # Allow dict input
            input_obj = parse_obj_as(self.input_type, input)

        return self.function(input_obj, **kwargs)
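
A minimal sketch of wrapping a function with the Opyrator class above; the `greet` function and its models are hypothetical, not part of the commit. It illustrates that dict (and JSON string) inputs are parsed into the input model by `__call__`.

from pydantic import BaseModel

from mkgui.base import Opyrator

class GreetInput(BaseModel):
    name: str

class GreetOutput(BaseModel):
    message: str

def greet(input: GreetInput) -> GreetOutput:
    """Say hello"""
    return GreetOutput(message=f"Hello, {input.name}!")

op = Opyrator(greet)
print(op.name, "-", op.action)      # "Greet - Say hello" (name via name_to_title, action via docstring)
print(op({"name": "MockingBird"}))  # dict input is parsed via parse_obj_as before calling greet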
mkgui/base/ui/__init__.py
ADDED
@@ -0,0 +1 @@
from .streamlit_ui import render_streamlit_ui
mkgui/base/ui/schema_utils.py
ADDED
@@ -0,0 +1,129 @@
from typing import Dict


def resolve_reference(reference: str, references: Dict) -> Dict:
    return references[reference.split("/")[-1]]


def get_single_reference_item(property: Dict, references: Dict) -> Dict:
    # Ref can either be directly in the properties or the first element of allOf
    reference = property.get("$ref")
    if reference is None:
        reference = property["allOf"][0]["$ref"]
    return resolve_reference(reference, references)


def is_single_string_property(property: Dict) -> bool:
    return property.get("type") == "string"


def is_single_datetime_property(property: Dict) -> bool:
    if property.get("type") != "string":
        return False
    return property.get("format") in ["date-time", "time", "date"]


def is_single_boolean_property(property: Dict) -> bool:
    return property.get("type") == "boolean"


def is_single_number_property(property: Dict) -> bool:
    return property.get("type") in ["integer", "number"]


def is_single_file_property(property: Dict) -> bool:
    if property.get("type") != "string":
        return False
    # TODO: binary?
    return property.get("format") == "byte"


def is_single_directory_property(property: Dict) -> bool:
    if property.get("type") != "string":
        return False
    return property.get("format") == "path"

def is_multi_enum_property(property: Dict, references: Dict) -> bool:
    if property.get("type") != "array":
        return False

    if property.get("uniqueItems") is not True:
        # Only relevant if it is a set or other datastructures with unique items
        return False

    try:
        _ = resolve_reference(property["items"]["$ref"], references)["enum"]
        return True
    except Exception:
        return False


def is_single_enum_property(property: Dict, references: Dict) -> bool:
    try:
        _ = get_single_reference_item(property, references)["enum"]
        return True
    except Exception:
        return False


def is_single_dict_property(property: Dict) -> bool:
    if property.get("type") != "object":
        return False
    return "additionalProperties" in property


def is_single_reference(property: Dict) -> bool:
    if property.get("type") is not None:
        return False

    return bool(property.get("$ref"))


def is_multi_file_property(property: Dict) -> bool:
    if property.get("type") != "array":
        return False

    if property.get("items") is None:
        return False

    try:
        # TODO: binary
        return property["items"]["format"] == "byte"
    except Exception:
        return False


def is_single_object(property: Dict, references: Dict) -> bool:
    try:
        object_reference = get_single_reference_item(property, references)
        if object_reference["type"] != "object":
            return False
        return "properties" in object_reference
    except Exception:
        return False


def is_property_list(property: Dict) -> bool:
    if property.get("type") != "array":
        return False

    if property.get("items") is None:
        return False

    try:
        return property["items"]["type"] in ["string", "number", "integer"]
    except Exception:
        return False


def is_object_list_property(property: Dict, references: Dict) -> bool:
    if property.get("type") != "array":
        return False

    try:
        object_reference = resolve_reference(property["items"]["$ref"], references)
        if object_reference["type"] != "object":
            return False
        return "properties" in object_reference
    except Exception:
        return False
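
An illustrative check, not part of the commit, of how the predicates above classify the properties of a Pydantic JSON schema; this is how streamlit_ui.py (next file) decides which widget to render for each field. The `Example` model is hypothetical.

from pydantic import BaseModel

from mkgui.base.ui import schema_utils

class Example(BaseModel):
    name: str
    count: int
    ratio: float

props = Example.schema()["properties"]
print(schema_utils.is_single_string_property(props["name"]))   # True
print(schema_utils.is_single_number_property(props["count"]))  # True
print(schema_utils.is_single_number_property(props["ratio"]))  # True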
mkgui/base/ui/streamlit_ui.py
ADDED
@@ -0,0 +1,888 @@
import datetime
import inspect
import mimetypes
import sys
from os import getcwd, unlink
from platform import system
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Dict, List, Type
from PIL import Image

import pandas as pd
import streamlit as st
from fastapi.encoders import jsonable_encoder
from loguru import logger
from pydantic import BaseModel, ValidationError, parse_obj_as

from mkgui.base import Opyrator
from mkgui.base.core import name_to_title
from mkgui.base.ui import schema_utils
from mkgui.base.ui.streamlit_utils import CUSTOM_STREAMLIT_CSS

STREAMLIT_RUNNER_SNIPPET = """
from mkgui.base.ui import render_streamlit_ui
from mkgui.base import Opyrator

import streamlit as st

# TODO: Make it configurable
# Page config can only be setup once
st.set_page_config(
    page_title="MockingBird",
    page_icon="🧊",
    layout="wide")

render_streamlit_ui()
"""

# with st.spinner("Loading MockingBird GUI. Please wait..."):
#     opyrator = Opyrator("{opyrator_path}")


def launch_ui(port: int = 8501) -> None:
    with NamedTemporaryFile(
        suffix=".py", mode="w", encoding="utf-8", delete=False
    ) as f:
        f.write(STREAMLIT_RUNNER_SNIPPET)
        f.seek(0)

        import subprocess

        python_path = f'PYTHONPATH="$PYTHONPATH:{getcwd()}"'
        if system() == "Windows":
            python_path = f"set PYTHONPATH=%PYTHONPATH%;{getcwd()} &&"
            subprocess.run(
                f"""set STREAMLIT_GLOBAL_SHOW_WARNING_ON_DIRECT_EXECUTION=false""",
                shell=True,
            )

        subprocess.run(
            f"""{python_path} "{sys.executable}" -m streamlit run --server.port={port} --server.headless=True --runner.magicEnabled=False --server.maxUploadSize=50 --browser.gatherUsageStats=False {f.name}""",
            shell=True,
        )

        f.close()
        unlink(f.name)


def function_has_named_arg(func: Callable, parameter: str) -> bool:
    try:
        sig = inspect.signature(func)
        for param in sig.parameters.values():
            if param.name == "input":
                return True
    except Exception:
        return False
    return False


def has_output_ui_renderer(data_item: BaseModel) -> bool:
    return hasattr(data_item, "render_output_ui")


def has_input_ui_renderer(input_class: Type[BaseModel]) -> bool:
    return hasattr(input_class, "render_input_ui")


def is_compatible_audio(mime_type: str) -> bool:
    return mime_type in ["audio/mpeg", "audio/ogg", "audio/wav"]


def is_compatible_image(mime_type: str) -> bool:
    return mime_type in ["image/png", "image/jpeg"]


def is_compatible_video(mime_type: str) -> bool:
    return mime_type in ["video/mp4"]


class InputUI:
    def __init__(self, session_state, input_class: Type[BaseModel]):
        self._session_state = session_state
        self._input_class = input_class

        self._schema_properties = input_class.schema(by_alias=True).get(
            "properties", {}
        )
        self._schema_references = input_class.schema(by_alias=True).get(
            "definitions", {}
        )

    def render_ui(self, streamlit_app_root) -> None:
        if has_input_ui_renderer(self._input_class):
            # The input model has a rendering function
            # The rendering also returns the current state of input data
            self._session_state.input_data = self._input_class.render_input_ui(  # type: ignore
                st, self._session_state.input_data
            )
            return

        # print(self._schema_properties)
        for property_key in self._schema_properties.keys():
            property = self._schema_properties[property_key]

            if not property.get("title"):
                # Set property key as fallback title
                property["title"] = name_to_title(property_key)

            try:
                if "input_data" in self._session_state:
                    self._store_value(
                        property_key,
                        self._render_property(streamlit_app_root, property_key, property),
                    )
            except Exception as e:
                print("Exception!", e)
                pass

    def _get_default_streamlit_input_kwargs(self, key: str, property: Dict) -> Dict:
        streamlit_kwargs = {
            "label": property.get("title"),
            "key": key,
        }

        if property.get("description"):
            streamlit_kwargs["help"] = property.get("description")
        return streamlit_kwargs

    def _store_value(self, key: str, value: Any) -> None:
        data_element = self._session_state.input_data
        key_elements = key.split(".")
        for i, key_element in enumerate(key_elements):
            if i == len(key_elements) - 1:
                # add value to this element
                data_element[key_element] = value
                return
            if key_element not in data_element:
                data_element[key_element] = {}
            data_element = data_element[key_element]

    def _get_value(self, key: str) -> Any:
        data_element = self._session_state.input_data
        key_elements = key.split(".")
        for i, key_element in enumerate(key_elements):
            if i == len(key_elements) - 1:
                # return the value stored for this element, if any
                if key_element not in data_element:
                    return None
                return data_element[key_element]
            if key_element not in data_element:
                data_element[key_element] = {}
            data_element = data_element[key_element]
        return None

    def _render_single_datetime_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)

        if property.get("format") == "time":
            if property.get("default"):
                try:
                    streamlit_kwargs["value"] = datetime.time.fromisoformat(  # type: ignore
                        property.get("default")
                    )
                except Exception:
                    pass
            return streamlit_app.time_input(**streamlit_kwargs)
        elif property.get("format") == "date":
            if property.get("default"):
                try:
                    streamlit_kwargs["value"] = datetime.date.fromisoformat(  # type: ignore
                        property.get("default")
                    )
                except Exception:
                    pass
            return streamlit_app.date_input(**streamlit_kwargs)
        elif property.get("format") == "date-time":
            if property.get("default"):
                try:
                    streamlit_kwargs["value"] = datetime.datetime.fromisoformat(  # type: ignore
                        property.get("default")
                    )
                except Exception:
                    pass
            with streamlit_app.container():
                streamlit_app.subheader(streamlit_kwargs.get("label"))
                if streamlit_kwargs.get("description"):
                    streamlit_app.text(streamlit_kwargs.get("description"))
                selected_date = None
                selected_time = None
                date_col, time_col = streamlit_app.columns(2)
                with date_col:
                    date_kwargs = {"label": "Date", "key": key + "-date-input"}
                    if streamlit_kwargs.get("value"):
                        try:
                            date_kwargs["value"] = streamlit_kwargs.get(  # type: ignore
                                "value"
                            ).date()
                        except Exception:
                            pass
                    selected_date = streamlit_app.date_input(**date_kwargs)

                with time_col:
                    time_kwargs = {"label": "Time", "key": key + "-time-input"}
                    if streamlit_kwargs.get("value"):
                        try:
                            time_kwargs["value"] = streamlit_kwargs.get(  # type: ignore
                                "value"
                            ).time()
                        except Exception:
                            pass
                    selected_time = streamlit_app.time_input(**time_kwargs)
            return datetime.datetime.combine(selected_date, selected_time)
        else:
            streamlit_app.warning(
                "Date format is not supported: " + str(property.get("format"))
            )

    def _render_single_file_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
        file_extension = None
        if "mime_type" in property:
            file_extension = mimetypes.guess_extension(property["mime_type"])

        uploaded_file = streamlit_app.file_uploader(
            **streamlit_kwargs, accept_multiple_files=False, type=file_extension
        )
        if uploaded_file is None:
            return None

        bytes = uploaded_file.getvalue()
        if property.get("mime_type"):
            if is_compatible_audio(property["mime_type"]):
                # Show audio
                streamlit_app.audio(bytes, format=property.get("mime_type"))
            if is_compatible_image(property["mime_type"]):
                # Show image
                streamlit_app.image(bytes)
            if is_compatible_video(property["mime_type"]):
                # Show video
                streamlit_app.video(bytes, format=property.get("mime_type"))
        return bytes

    def _render_single_string_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)

        if property.get("default"):
            streamlit_kwargs["value"] = property.get("default")
        elif property.get("example"):
            # TODO: also use example for other property types
            # Use example as value if it is provided
            streamlit_kwargs["value"] = property.get("example")

        if property.get("maxLength") is not None:
            streamlit_kwargs["max_chars"] = property.get("maxLength")

        if (
            property.get("format")
            or (
                property.get("maxLength") is not None
                and int(property.get("maxLength")) < 140  # type: ignore
            )
            or property.get("writeOnly")
        ):
            # If any format is set, use single text input
            # If max chars is set to less than 140, use single text input
            # If write only -> password field
            if property.get("writeOnly"):
                streamlit_kwargs["type"] = "password"
            return streamlit_app.text_input(**streamlit_kwargs)
        else:
            # Otherwise use multiline text area
            return streamlit_app.text_area(**streamlit_kwargs)

    def _render_multi_enum_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
        reference_item = schema_utils.resolve_reference(
            property["items"]["$ref"], self._schema_references
        )
        # TODO: how to select defaults
        return streamlit_app.multiselect(
            **streamlit_kwargs, options=reference_item["enum"]
        )

    def _render_single_enum_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:

        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
        reference_item = schema_utils.get_single_reference_item(
            property, self._schema_references
        )

        if property.get("default") is not None:
            try:
                streamlit_kwargs["index"] = reference_item["enum"].index(
                    property.get("default")
                )
            except Exception:
                # Use default selection
                pass

        return streamlit_app.selectbox(
            **streamlit_kwargs, options=reference_item["enum"]
        )

    def _render_single_dict_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:

        # Add title and subheader
        streamlit_app.subheader(property.get("title"))
        if property.get("description"):
            streamlit_app.markdown(property.get("description"))

        streamlit_app.markdown("---")

        current_dict = self._get_value(key)
        if not current_dict:
            current_dict = {}

        key_col, value_col = streamlit_app.columns(2)

        with key_col:
            updated_key = streamlit_app.text_input(
                "Key", value="", key=key + "-new-key"
            )

        with value_col:
            # TODO: also add boolean?
            value_kwargs = {"label": "Value", "key": key + "-new-value"}
            if property["additionalProperties"].get("type") == "integer":
                value_kwargs["value"] = 0  # type: ignore
                updated_value = streamlit_app.number_input(**value_kwargs)
            elif property["additionalProperties"].get("type") == "number":
                value_kwargs["value"] = 0.0  # type: ignore
                value_kwargs["format"] = "%f"
                updated_value = streamlit_app.number_input(**value_kwargs)
            else:
                value_kwargs["value"] = ""
                updated_value = streamlit_app.text_input(**value_kwargs)

        streamlit_app.markdown("---")

        with streamlit_app.container():
            clear_col, add_col = streamlit_app.columns([1, 2])

            with clear_col:
                if streamlit_app.button("Clear Items", key=key + "-clear-items"):
                    current_dict = {}

            with add_col:
                if (
                    streamlit_app.button("Add Item", key=key + "-add-item")
                    and updated_key
                ):
                    current_dict[updated_key] = updated_value

        streamlit_app.write(current_dict)

        return current_dict

    def _render_single_reference(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        reference_item = schema_utils.get_single_reference_item(
            property, self._schema_references
        )
        return self._render_property(streamlit_app, key, reference_item)

    def _render_multi_file_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)

        file_extension = None
        if "mime_type" in property:
            file_extension = mimetypes.guess_extension(property["mime_type"])

        uploaded_files = streamlit_app.file_uploader(
            **streamlit_kwargs, accept_multiple_files=True, type=file_extension
        )
        uploaded_files_bytes = []
        if uploaded_files:
            for uploaded_file in uploaded_files:
                uploaded_files_bytes.append(uploaded_file.read())
        return uploaded_files_bytes

    def _render_single_boolean_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)

        if property.get("default"):
            streamlit_kwargs["value"] = property.get("default")
        return streamlit_app.checkbox(**streamlit_kwargs)

    def _render_single_number_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)

        number_transform = int
        if property.get("type") == "number":
            number_transform = float  # type: ignore
            streamlit_kwargs["format"] = "%f"

        if "multipleOf" in property:
            # Set step count based on the multipleOf parameter
            streamlit_kwargs["step"] = number_transform(property["multipleOf"])
        elif number_transform == int:
            # Set step size to 1 as default
            streamlit_kwargs["step"] = 1
        elif number_transform == float:
            # Set step size to 0.01 as default
            # TODO: adapt to default value
            streamlit_kwargs["step"] = 0.01

        if "minimum" in property:
            streamlit_kwargs["min_value"] = number_transform(property["minimum"])
        if "exclusiveMinimum" in property:
            streamlit_kwargs["min_value"] = number_transform(
                property["exclusiveMinimum"] + streamlit_kwargs["step"]
            )
        if "maximum" in property:
            streamlit_kwargs["max_value"] = number_transform(property["maximum"])

        if "exclusiveMaximum" in property:
            streamlit_kwargs["max_value"] = number_transform(
                property["exclusiveMaximum"] - streamlit_kwargs["step"]
            )

        if property.get("default") is not None:
            streamlit_kwargs["value"] = number_transform(property.get("default"))  # type: ignore
        else:
            if "min_value" in streamlit_kwargs:
                streamlit_kwargs["value"] = streamlit_kwargs["min_value"]
            elif number_transform == int:
                streamlit_kwargs["value"] = 0
            else:
                # Set default value to step
                streamlit_kwargs["value"] = number_transform(streamlit_kwargs["step"])

        if "min_value" in streamlit_kwargs and "max_value" in streamlit_kwargs:
            # TODO: Only if less than X steps
            return streamlit_app.slider(**streamlit_kwargs)
        else:
            return streamlit_app.number_input(**streamlit_kwargs)

    def _render_object_input(self, streamlit_app: st, key: str, property: Dict) -> Any:
        properties = property["properties"]
        object_inputs = {}
        for property_key in properties:
            property = properties[property_key]
            if not property.get("title"):
                # Set property key as fallback title
                property["title"] = name_to_title(property_key)
            # construct full key based on key parts -> required later to get the value
            full_key = key + "." + property_key
            object_inputs[property_key] = self._render_property(
                streamlit_app, full_key, property
            )
        return object_inputs

    def _render_single_object_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:
        # Add title and subheader
        title = property.get("title")
        streamlit_app.subheader(title)
        if property.get("description"):
            streamlit_app.markdown(property.get("description"))

        object_reference = schema_utils.get_single_reference_item(
            property, self._schema_references
        )
        return self._render_object_input(streamlit_app, key, object_reference)

    def _render_property_list_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:

        # Add title and subheader
        streamlit_app.subheader(property.get("title"))
        if property.get("description"):
            streamlit_app.markdown(property.get("description"))

        streamlit_app.markdown("---")

        current_list = self._get_value(key)
        if not current_list:
            current_list = []

        value_kwargs = {"label": "Value", "key": key + "-new-value"}
        if property["items"]["type"] == "integer":
            value_kwargs["value"] = 0  # type: ignore
            new_value = streamlit_app.number_input(**value_kwargs)
        elif property["items"]["type"] == "number":
            value_kwargs["value"] = 0.0  # type: ignore
            value_kwargs["format"] = "%f"
            new_value = streamlit_app.number_input(**value_kwargs)
        else:
            value_kwargs["value"] = ""
            new_value = streamlit_app.text_input(**value_kwargs)

        streamlit_app.markdown("---")

        with streamlit_app.container():
            clear_col, add_col = streamlit_app.columns([1, 2])

            with clear_col:
                if streamlit_app.button("Clear Items", key=key + "-clear-items"):
                    current_list = []

            with add_col:
                if (
                    streamlit_app.button("Add Item", key=key + "-add-item")
                    and new_value is not None
                ):
                    current_list.append(new_value)

        streamlit_app.write(current_list)

        return current_list

    def _render_object_list_input(
        self, streamlit_app: st, key: str, property: Dict
    ) -> Any:

        # TODO: support max_items, and min_items properties

        # Add title and subheader
        streamlit_app.subheader(property.get("title"))
        if property.get("description"):
            streamlit_app.markdown(property.get("description"))

        streamlit_app.markdown("---")

        current_list = self._get_value(key)
        if not current_list:
            current_list = []

        object_reference = schema_utils.resolve_reference(
            property["items"]["$ref"], self._schema_references
        )
        input_data = self._render_object_input(streamlit_app, key, object_reference)

        streamlit_app.markdown("---")

        with streamlit_app.container():
            clear_col, add_col = streamlit_app.columns([1, 2])

            with clear_col:
                if streamlit_app.button("Clear Items", key=key + "-clear-items"):
                    current_list = []

            with add_col:
                if (
                    streamlit_app.button("Add Item", key=key + "-add-item")
                    and input_data
                ):
                    current_list.append(input_data)

        streamlit_app.write(current_list)
        return current_list

    def _render_property(self, streamlit_app: st, key: str, property: Dict) -> Any:
        if schema_utils.is_single_enum_property(property, self._schema_references):
            return self._render_single_enum_input(streamlit_app, key, property)

        if schema_utils.is_multi_enum_property(property, self._schema_references):
            return self._render_multi_enum_input(streamlit_app, key, property)

        if schema_utils.is_single_file_property(property):
            return self._render_single_file_input(streamlit_app, key, property)

        if schema_utils.is_multi_file_property(property):
            return self._render_multi_file_input(streamlit_app, key, property)

        if schema_utils.is_single_datetime_property(property):
            return self._render_single_datetime_input(streamlit_app, key, property)

        if schema_utils.is_single_boolean_property(property):
            return self._render_single_boolean_input(streamlit_app, key, property)

        if schema_utils.is_single_dict_property(property):
            return self._render_single_dict_input(streamlit_app, key, property)

        if schema_utils.is_single_number_property(property):
            return self._render_single_number_input(streamlit_app, key, property)

        if schema_utils.is_single_string_property(property):
            return self._render_single_string_input(streamlit_app, key, property)

        if schema_utils.is_single_object(property, self._schema_references):
            return self._render_single_object_input(streamlit_app, key, property)

        if schema_utils.is_object_list_property(property, self._schema_references):
            return self._render_object_list_input(streamlit_app, key, property)

        if schema_utils.is_property_list(property):
            return self._render_property_list_input(streamlit_app, key, property)

        if schema_utils.is_single_reference(property):
            return self._render_single_reference(streamlit_app, key, property)

        streamlit_app.warning(
            "The type of the following property is currently not supported: "
            + str(property.get("title"))
        )
        raise Exception("Unsupported property")


class OutputUI:
    def __init__(self, output_data: Any, input_data: Any):
        self._output_data = output_data
        self._input_data = input_data

    def render_ui(self, streamlit_app) -> None:
        try:
            if isinstance(self._output_data, BaseModel):
                self._render_single_output(streamlit_app, self._output_data)
                return
            if type(self._output_data) == list:
                self._render_list_output(streamlit_app, self._output_data)
                return
        except Exception as ex:
            streamlit_app.exception(ex)
            # Fall back to raw JSON output
            streamlit_app.json(jsonable_encoder(self._output_data))

    def _render_single_text_property(
        self, streamlit: st, property_schema: Dict, value: Any
    ) -> None:
        # Add title and subheader
        streamlit.subheader(property_schema.get("title"))
        if property_schema.get("description"):
            streamlit.markdown(property_schema.get("description"))
        if value is None or value == "":
            streamlit.info("No value returned!")
        else:
            streamlit.code(str(value), language="plain")

    def _render_single_file_property(
        self, streamlit: st, property_schema: Dict, value: Any
    ) -> None:
        # Add title and subheader
        streamlit.subheader(property_schema.get("title"))
        if property_schema.get("description"):
            streamlit.markdown(property_schema.get("description"))
        if value is None or value == "":
            streamlit.info("No value returned!")
        else:
            # TODO: Detect if it is a FileContent instance
            # TODO: detect if it is base64
            file_extension = ""
            if "mime_type" in property_schema:
                mime_type = property_schema["mime_type"]
                file_extension = mimetypes.guess_extension(mime_type) or ""

                if is_compatible_audio(mime_type):
                    streamlit.audio(value.as_bytes(), format=mime_type)
                    return

                if is_compatible_image(mime_type):
                    streamlit.image(value.as_bytes())
                    return

                if is_compatible_video(mime_type):
                    streamlit.video(value.as_bytes(), format=mime_type)
                    return

            filename = (
                (property_schema["title"] + file_extension)
                .lower()
                .strip()
                .replace(" ", "-")
            )
            streamlit.markdown(
                f'<a href="data:application/octet-stream;base64,{value}" download="{filename}"><input type="button" value="Download File"></a>',
                unsafe_allow_html=True,
            )

    def _render_single_complex_property(
        self, streamlit: st, property_schema: Dict, value: Any
    ) -> None:
        # Add title and subheader
        streamlit.subheader(property_schema.get("title"))
        if property_schema.get("description"):
            streamlit.markdown(property_schema.get("description"))

        streamlit.json(jsonable_encoder(value))

    def _render_single_output(self, streamlit: st, output_data: BaseModel) -> None:
        try:
            if has_output_ui_renderer(output_data):
                if function_has_named_arg(output_data.render_output_ui, "input"):  # type: ignore
                    # render method also requests the input data
                    output_data.render_output_ui(streamlit, input=self._input_data)  # type: ignore
                else:
                    output_data.render_output_ui(streamlit)  # type: ignore
                return
        except Exception:
            # Use default auto-generation methods if the custom rendering throws an exception
            logger.exception(
                "Failed to execute custom render_output_ui function. Using auto-generation instead"
            )

        model_schema = output_data.schema(by_alias=False)
        model_properties = model_schema.get("properties")
        definitions = model_schema.get("definitions")

        if model_properties:
            for property_key in output_data.__dict__:
                property_schema = model_properties.get(property_key)
                if not property_schema.get("title"):
                    # Set property key as fallback title
                    property_schema["title"] = property_key

                output_property_value = output_data.__dict__[property_key]

                if has_output_ui_renderer(output_property_value):
                    output_property_value.render_output_ui(streamlit)  # type: ignore
                    continue

                if isinstance(output_property_value, BaseModel):
                    # Render output recursively
                    streamlit.subheader(property_schema.get("title"))
                    if property_schema.get("description"):
                        streamlit.markdown(property_schema.get("description"))
                    self._render_single_output(streamlit, output_property_value)
                    continue

                if property_schema:
                    if schema_utils.is_single_file_property(property_schema):
                        self._render_single_file_property(
                            streamlit, property_schema, output_property_value
                        )
                        continue

                    if (
                        schema_utils.is_single_string_property(property_schema)
                        or schema_utils.is_single_number_property(property_schema)
                        or schema_utils.is_single_datetime_property(property_schema)
                        or schema_utils.is_single_boolean_property(property_schema)
                    ):
                        self._render_single_text_property(
                            streamlit, property_schema, output_property_value
                        )
                        continue
                    if definitions and schema_utils.is_single_enum_property(
                        property_schema, definitions
                    ):
                        self._render_single_text_property(
                            streamlit, property_schema, output_property_value.value
                        )
                        continue

                    # TODO: render dict as table

                    self._render_single_complex_property(
                        streamlit, property_schema, output_property_value
                    )
        return

    def _render_list_output(self, streamlit: st, output_data: List) -> None:
        try:
            data_items: List = []
            for data_item in output_data:
                if has_output_ui_renderer(data_item):
                    # Render using the render function
                    data_item.render_output_ui(streamlit)  # type: ignore
                    continue
                data_items.append(data_item.dict())
            # Try to show as dataframe
            streamlit.table(pd.DataFrame(data_items))
        except Exception:
            # Fall back to raw JSON output
            streamlit.json(jsonable_encoder(output_data))


def getOpyrator(mode: str) -> Opyrator:
    if mode is None or mode.startswith('VC'):
        from mkgui.app_vc import convert
        return Opyrator(convert)
    if mode is None or mode.startswith('预处理'):
        from mkgui.preprocess import preprocess
        return Opyrator(preprocess)
    # "模型训练(VC)" must be matched before the shorter "模型训练" prefix,
    # otherwise the VC training mode would never be reachable.
    if mode is None or mode.startswith('模型训练(VC)'):
        from mkgui.train_vc import train_vc
        return Opyrator(train_vc)
    if mode is None or mode.startswith('模型训练'):
        from mkgui.train import train
        return Opyrator(train)
    from mkgui.app import synthesize
    return Opyrator(synthesize)


def render_streamlit_ui() -> None:
    # init
    session_state = st.session_state
    session_state.input_data = {}
    # Add custom css settings
    st.markdown(f"<style>{CUSTOM_STREAMLIT_CSS}</style>", unsafe_allow_html=True)

    with st.spinner("Loading MockingBird GUI. Please wait..."):
        session_state.mode = st.sidebar.selectbox(
            '模式选择',
            ( "AI拟音", "VC拟音", "预处理", "模型训练", "模型训练(VC)")
        )
        if "mode" in session_state:
            mode = session_state.mode
        else:
            mode = ""
        opyrator = getOpyrator(mode)
|
842 |
+
title = opyrator.name + mode
|
843 |
+
|
844 |
+
col1, col2, _ = st.columns(3)
|
845 |
+
col2.title(title)
|
846 |
+
col2.markdown("欢迎使用MockingBird Web 2")
|
847 |
+
|
848 |
+
image = Image.open('.\\mkgui\\static\\mb.png')
|
849 |
+
col1.image(image)
|
850 |
+
|
851 |
+
st.markdown("---")
|
852 |
+
left, right = st.columns([0.4, 0.6])
|
853 |
+
|
854 |
+
with left:
|
855 |
+
st.header("Control 控制")
|
856 |
+
InputUI(session_state=session_state, input_class=opyrator.input_type).render_ui(st)
|
857 |
+
execute_selected = st.button(opyrator.action)
|
858 |
+
if execute_selected:
|
859 |
+
with st.spinner("Executing operation. Please wait..."):
|
860 |
+
try:
|
861 |
+
input_data_obj = parse_obj_as(
|
862 |
+
opyrator.input_type, session_state.input_data
|
863 |
+
)
|
864 |
+
session_state.output_data = opyrator(input=input_data_obj)
|
865 |
+
session_state.latest_operation_input = input_data_obj # should this really be saved as additional session object?
|
866 |
+
except ValidationError as ex:
|
867 |
+
st.error(ex)
|
868 |
+
else:
|
869 |
+
# st.success("Operation executed successfully.")
|
870 |
+
pass
|
871 |
+
|
872 |
+
with right:
|
873 |
+
st.header("Result 结果")
|
874 |
+
if 'output_data' in session_state:
|
875 |
+
OutputUI(
|
876 |
+
session_state.output_data, session_state.latest_operation_input
|
877 |
+
).render_ui(st)
|
878 |
+
if st.button("Clear"):
|
879 |
+
# Clear all state
|
880 |
+
for key in st.session_state.keys():
|
881 |
+
del st.session_state[key]
|
882 |
+
session_state.input_data = {}
|
883 |
+
st.experimental_rerun()
|
884 |
+
else:
|
885 |
+
# placeholder
|
886 |
+
st.caption("请使用左侧控制板进行输入并运行获得结果")
|
887 |
+
|
888 |
+
|
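Taken together, `getOpyrator` picks the handler that matches the sidebar mode and `render_streamlit_ui` lays out the control/result columns around it. A minimal launcher sketch, assuming a small entry script at the repository root (the file name `run_web.py` is an assumption, not something this commit adds); start it with `streamlit run run_web.py`:

# run_web.py -- hypothetical Streamlit entry point for the GUI defined above.
from mkgui.base.ui.streamlit_ui import render_streamlit_ui

# Streamlit executes this module on each rerun; the function builds the whole page.
render_streamlit_ui()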
mkgui/base/ui/streamlit_utils.py
ADDED
@@ -0,0 +1,13 @@
CUSTOM_STREAMLIT_CSS = """
div[data-testid="stBlock"] button {
    width: 100% !important;
    margin-bottom: 20px !important;
    border-color: #bfbfbf !important;
}
section[data-testid="stSidebar"] div {
    max-width: 10rem;
}
pre code {
    white-space: pre-wrap;
}
"""
mkgui/preprocess.py
ADDED
@@ -0,0 +1,96 @@
from pydantic import BaseModel, Field
import os
from pathlib import Path
from enum import Enum
from typing import Any, Tuple


# Constants
EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models"
ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"


if os.path.isdir(EXT_MODELS_DIRT):
    extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded extractor models: " + str(len(extractors)))
else:
    raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")

if os.path.isdir(ENC_MODELS_DIRT):
    encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded encoders models: " + str(len(encoders)))
else:
    raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")

class Model(str, Enum):
    VC_PPG2MEL = "ppg2mel"

class Dataset(str, Enum):
    AIDATATANG_200ZH = "aidatatang_200zh"
    AIDATATANG_200ZH_S = "aidatatang_200zh_s"

class Input(BaseModel):
    # def render_input_ui(st, input) -> Dict:
    #     input["selected_dataset"] = st.selectbox(
    #         '选择数据集',
    #         ("aidatatang_200zh", "aidatatang_200zh_s")
    #     )
    #     return input
    model: Model = Field(
        Model.VC_PPG2MEL, title="目标模型",
    )
    dataset: Dataset = Field(
        Dataset.AIDATATANG_200ZH, title="数据集选择",
    )
    datasets_root: str = Field(
        ..., alias="数据集根目录", description="输入数据集根目录(相对/绝对)",
        format=True,
        example="..\\trainning_data\\"
    )
    output_root: str = Field(
        ..., alias="输出根目录", description="输出结果根目录(相对/绝对)",
        format=True,
        example="..\\trainning_data\\"
    )
    n_processes: int = Field(
        2, alias="处理线程数", description="根据CPU线程数来设置",
        le=32, ge=1
    )
    extractor: extractors = Field(
        ..., alias="特征提取模型",
        description="选择PPG特征提取模型文件."
    )
    encoder: encoders = Field(
        ..., alias="语音编码模型",
        description="选择语音编码模型文件."
    )

class AudioEntity(BaseModel):
    content: bytes
    mel: Any

class Output(BaseModel):
    __root__: Tuple[str, int]

    def render_output_ui(self, streamlit_app, input) -> None:  # type: ignore
        """Custom output UI.
        If this method is implemented, it will be used instead of the default Output UI renderer.
        """
        sr, count = self.__root__
        streamlit_app.subheader(f"Dataset {sr} done processed total of {count}")

def preprocess(input: Input) -> Output:
    """Preprocess(预处理)"""
    finished = 0
    if input.model == Model.VC_PPG2MEL:
        from ppg2mel.preprocess import preprocess_dataset
        finished = preprocess_dataset(
            datasets_root=Path(input.datasets_root),
            dataset=input.dataset,
            out_dir=Path(input.output_root),
            n_processes=input.n_processes,
            ppg_encoder_model_fpath=Path(input.extractor.value),
            speaker_encoder_model=Path(input.encoder.value)
        )
    # TODO: pass useful return code
    return Output(__root__=(input.dataset, finished))
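Since `preprocess` is an ordinary function, it can also be driven without the web UI. A minimal sketch of a direct call, assuming the extractor/encoder checkpoints listed by the two Enums exist and that the dataset paths (placeholders below) point at real data:

# Hypothetical direct invocation of the preprocess handler above; paths are placeholders.
from mkgui.preprocess import Input, Model, Dataset, extractors, encoders, preprocess

job = Input.parse_obj({
    "model": Model.VC_PPG2MEL,
    "dataset": Dataset.AIDATATANG_200ZH_S,
    "数据集根目录": "../audiodata",            # alias of datasets_root
    "输出根目录": "../audiodata/SV2TTS",       # alias of output_root
    "处理线程数": 2,                            # alias of n_processes
    "特征提取模型": list(extractors)[0],        # alias of extractor
    "语音编码模型": list(encoders)[0],          # alias of encoder
})
dataset_name, n_files = preprocess(job).__root__
print(dataset_name, n_files)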
mkgui/static/mb.png
ADDED
mkgui/train.py
ADDED
@@ -0,0 +1,106 @@
from pydantic import BaseModel, Field
import os
from pathlib import Path
from enum import Enum
from typing import Any
from synthesizer.hparams import hparams
from synthesizer.train import train as synt_train

# Constants
SYN_MODELS_DIRT = f"synthesizer{os.sep}saved_models"
ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"


# EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models"
# CONV_MODELS_DIRT = f"ppg2mel{os.sep}saved_models"
# ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"

# Pre-Load models
if os.path.isdir(SYN_MODELS_DIRT):
    synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded synthesizer models: " + str(len(synthesizers)))
else:
    raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist.")

if os.path.isdir(ENC_MODELS_DIRT):
    encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded encoders models: " + str(len(encoders)))
else:
    raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")

class Model(str, Enum):
    DEFAULT = "default"

class Input(BaseModel):
    model: Model = Field(
        Model.DEFAULT, title="模型类型",
    )
    # datasets_root: str = Field(
    #     ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型",
    #     format=True,
    #     example="..\\trainning_data\\"
    # )
    input_root: str = Field(
        ..., alias="输入目录", description="预处理数据根目录",
        format=True,
        example=f"..{os.sep}audiodata{os.sep}SV2TTS{os.sep}synthesizer"
    )
    run_id: str = Field(
        "", alias="新模型名/运行ID", description="使用新ID进行重新训练,否则选择下面的模型进行继续训练",
    )
    synthesizer: synthesizers = Field(
        ..., alias="已有合成模型",
        description="选择语音合成模型文件."
    )
    gpu: bool = Field(
        True, alias="GPU训练", description="选择“是”,则使用GPU训练",
    )
    verbose: bool = Field(
        True, alias="打印详情", description="选择“是”,输出更多详情",
    )
    encoder: encoders = Field(
        ..., alias="语音编码模型",
        description="选择语音编码模型文件."
    )
    save_every: int = Field(
        1000, alias="更新间隔", description="每隔n步则更新一次模型",
    )
    backup_every: int = Field(
        10000, alias="保存间隔", description="每隔n步则保存一次模型",
    )
    log_every: int = Field(
        500, alias="打印间隔", description="每隔n步则打印一次训练统计",
    )

class AudioEntity(BaseModel):
    content: bytes
    mel: Any

class Output(BaseModel):
    __root__: int

    def render_output_ui(self, streamlit_app) -> None:  # type: ignore
        """Custom output UI.
        If this method is implemented, it will be used instead of the default Output UI renderer.
        """
        streamlit_app.subheader(f"Training started with code: {self.__root__}")

def train(input: Input) -> Output:
    """Train(训练)"""

    print(">>> Start training ...")
    force_restart = len(input.run_id) > 0
    if not force_restart:
        input.run_id = Path(input.synthesizer.value).name.split('.')[0]

    synt_train(
        input.run_id,
        input.input_root,
        f"synthesizer{os.sep}saved_models",
        input.save_every,
        input.backup_every,
        input.log_every,
        force_restart,
        hparams
    )
    return Output(__root__=0)
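`train` is a thin wrapper around `synthesizer.train.train`: an empty run id means "continue from the selected checkpoint" (the id is derived from the checkpoint file name), while a non-empty one forces a fresh run. A minimal sketch of a direct call, assuming preprocessed data and at least one synthesizer/encoder checkpoint are present (the input path is a placeholder):

# Hypothetical direct invocation of the training wrapper above.
from mkgui.train import Input, Model, synthesizers, encoders, train

result = train(Input.parse_obj({
    "model": Model.DEFAULT,
    "输入目录": "../audiodata/SV2TTS/synthesizer",   # alias of input_root, placeholder path
    "新模型名/运行ID": "",                             # empty -> continue from the selected checkpoint
    "已有合成模型": list(synthesizers)[0],             # alias of synthesizer
    "语音编码模型": list(encoders)[0],                 # alias of encoder
}))
print(result.__root__)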
mkgui/train_vc.py
ADDED
@@ -0,0 +1,155 @@
from pydantic import BaseModel, Field
import os
from pathlib import Path
from enum import Enum
from typing import Any, Tuple
import numpy as np
from utils.load_yaml import HpsYaml
from utils.util import AttrDict
import torch

# Constants
EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models"
CONV_MODELS_DIRT = f"ppg2mel{os.sep}saved_models"
ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"


if os.path.isdir(EXT_MODELS_DIRT):
    extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded extractor models: " + str(len(extractors)))
else:
    raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")

if os.path.isdir(CONV_MODELS_DIRT):
    convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
    print("Loaded convertor models: " + str(len(convertors)))
else:
    raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")

if os.path.isdir(ENC_MODELS_DIRT):
    encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
    print("Loaded encoders models: " + str(len(encoders)))
else:
    raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")

class Model(str, Enum):
    VC_PPG2MEL = "ppg2mel"

class Dataset(str, Enum):
    AIDATATANG_200ZH = "aidatatang_200zh"
    AIDATATANG_200ZH_S = "aidatatang_200zh_s"

class Input(BaseModel):
    # def render_input_ui(st, input) -> Dict:
    #     input["selected_dataset"] = st.selectbox(
    #         '选择数据集',
    #         ("aidatatang_200zh", "aidatatang_200zh_s")
    #     )
    #     return input
    model: Model = Field(
        Model.VC_PPG2MEL, title="模型类型",
    )
    # datasets_root: str = Field(
    #     ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型",
    #     format=True,
    #     example="..\\trainning_data\\"
    # )
    output_root: str = Field(
        ..., alias="输出目录(可选)", description="建议不填,保持默认",
        format=True,
        example=""
    )
    continue_mode: bool = Field(
        True, alias="继续训练模式", description="选择“是”,则从下面选择的模型中继续训练",
    )
    gpu: bool = Field(
        True, alias="GPU训练", description="选择“是”,则使用GPU训练",
    )
    verbose: bool = Field(
        True, alias="打印详情", description="选择“是”,输出更多详情",
    )
    # TODO: Move to hiden fields by default
    convertor: convertors = Field(
        ..., alias="转换模型",
        description="选择语音转换模型文件."
    )
    extractor: extractors = Field(
        ..., alias="特征提取模型",
        description="选择PPG特征提取模型文件."
    )
    encoder: encoders = Field(
        ..., alias="语音编码模型",
        description="选择语音编码模型文件."
    )
    njobs: int = Field(
        8, alias="进程数", description="适用于ppg2mel",
    )
    seed: int = Field(
        default=0, alias="初始随机数", description="适用于ppg2mel",
    )
    model_name: str = Field(
        ..., alias="新模型名", description="仅在重新训练时生效,选中继续训练时无效",
        example="test"
    )
    model_config: str = Field(
        ..., alias="新模型配置", description="仅在重新训练时生效,选中继续训练时无效",
        example=".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2"
    )

class AudioEntity(BaseModel):
    content: bytes
    mel: Any

class Output(BaseModel):
    __root__: Tuple[str, int]

    def render_output_ui(self, streamlit_app, input) -> None:  # type: ignore
        """Custom output UI.
        If this method is implemented, it will be used instead of the default Output UI renderer.
        """
        sr, count = self.__root__
        streamlit_app.subheader(f"Dataset {sr} done processed total of {count}")

def train_vc(input: Input) -> Output:
    """Train VC(训练 VC)"""

    print(">>> OneShot VC training ...")
    params = AttrDict()
    params.update({
        "gpu": input.gpu,
        "cpu": not input.gpu,
        "njobs": input.njobs,
        "seed": input.seed,
        "verbose": input.verbose,
        "load": input.convertor.value,
        "warm_start": False,
    })
    if input.continue_mode:
        # trace old model and config
        p = Path(input.convertor.value)
        params.name = p.parent.name
        # search a config file
        model_config_fpaths = list(p.parent.rglob("*.yaml"))
        if len(model_config_fpaths) == 0:
            raise Exception("No model yaml config found for convertor")
        config = HpsYaml(model_config_fpaths[0])
        params.ckpdir = p.parent.parent
        params.config = model_config_fpaths[0]
        params.logdir = os.path.join(p.parent, "log")
    else:
        # Make the config dict dot visitable
        config = HpsYaml(input.model_config)
    np.random.seed(input.seed)
    torch.manual_seed(input.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(input.seed)
    mode = "train"
    from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
    solver = Solver(config, params, mode)
    solver.load_data()
    solver.set_model()
    solver.exec()
    print(">>> Oneshot VC train finished!")

    # TODO: pass useful return code
    return Output(__root__=(input.model_name, 0))
packages.txt
ADDED
@@ -0,0 +1,5 @@
libasound2-dev
portaudio19-dev
libportaudio2
libportaudiocpp0
ffmpeg
ppg2mel/__init__.py
ADDED
@@ -0,0 +1,209 @@
#!/usr/bin/env python3

# Copyright 2020 Songxiang Liu
# Apache 2.0

from typing import List

import torch
import torch.nn.functional as F

import numpy as np

from .utils.abs_model import AbsMelDecoder
from .rnn_decoder_mol import Decoder
from .utils.cnn_postnet import Postnet
from .utils.vc_utils import get_mask_from_lengths

from utils.load_yaml import HpsYaml

class MelDecoderMOLv2(AbsMelDecoder):
    """Use an encoder to preprocess ppg."""
    def __init__(
        self,
        num_speakers: int,
        spk_embed_dim: int,
        bottle_neck_feature_dim: int,
        encoder_dim: int = 256,
        encoder_downsample_rates: List = [2, 2],
        attention_rnn_dim: int = 512,
        decoder_rnn_dim: int = 512,
        num_decoder_rnn_layer: int = 1,
        concat_context_to_last: bool = True,
        prenet_dims: List = [256, 128],
        num_mixtures: int = 5,
        frames_per_step: int = 2,
        mask_padding: bool = True,
    ):
        super().__init__()

        self.mask_padding = mask_padding
        self.bottle_neck_feature_dim = bottle_neck_feature_dim
        self.num_mels = 80
        self.encoder_down_factor = np.cumprod(encoder_downsample_rates)[-1]
        self.frames_per_step = frames_per_step
        self.use_spk_dvec = True

        input_dim = bottle_neck_feature_dim

        # Downsampling convolution
        self.bnf_prenet = torch.nn.Sequential(
            torch.nn.Conv1d(input_dim, encoder_dim, kernel_size=1, bias=False),
            torch.nn.LeakyReLU(0.1),

            torch.nn.InstanceNorm1d(encoder_dim, affine=False),
            torch.nn.Conv1d(
                encoder_dim, encoder_dim,
                kernel_size=2*encoder_downsample_rates[0],
                stride=encoder_downsample_rates[0],
                padding=encoder_downsample_rates[0]//2,
            ),
            torch.nn.LeakyReLU(0.1),

            torch.nn.InstanceNorm1d(encoder_dim, affine=False),
            torch.nn.Conv1d(
                encoder_dim, encoder_dim,
                kernel_size=2*encoder_downsample_rates[1],
                stride=encoder_downsample_rates[1],
                padding=encoder_downsample_rates[1]//2,
            ),
            torch.nn.LeakyReLU(0.1),

            torch.nn.InstanceNorm1d(encoder_dim, affine=False),
        )
        decoder_enc_dim = encoder_dim
        self.pitch_convs = torch.nn.Sequential(
            torch.nn.Conv1d(2, encoder_dim, kernel_size=1, bias=False),
            torch.nn.LeakyReLU(0.1),

            torch.nn.InstanceNorm1d(encoder_dim, affine=False),
            torch.nn.Conv1d(
                encoder_dim, encoder_dim,
                kernel_size=2*encoder_downsample_rates[0],
                stride=encoder_downsample_rates[0],
                padding=encoder_downsample_rates[0]//2,
            ),
            torch.nn.LeakyReLU(0.1),

            torch.nn.InstanceNorm1d(encoder_dim, affine=False),
            torch.nn.Conv1d(
                encoder_dim, encoder_dim,
                kernel_size=2*encoder_downsample_rates[1],
                stride=encoder_downsample_rates[1],
                padding=encoder_downsample_rates[1]//2,
            ),
            torch.nn.LeakyReLU(0.1),

            torch.nn.InstanceNorm1d(encoder_dim, affine=False),
        )

        self.reduce_proj = torch.nn.Linear(encoder_dim + spk_embed_dim, encoder_dim)

        # Decoder
        self.decoder = Decoder(
            enc_dim=decoder_enc_dim,
            num_mels=self.num_mels,
            frames_per_step=frames_per_step,
            attention_rnn_dim=attention_rnn_dim,
            decoder_rnn_dim=decoder_rnn_dim,
            num_decoder_rnn_layer=num_decoder_rnn_layer,
            prenet_dims=prenet_dims,
            num_mixtures=num_mixtures,
            use_stop_tokens=True,
            concat_context_to_last=concat_context_to_last,
            encoder_down_factor=self.encoder_down_factor,
        )

        # Mel-Spec Postnet: some residual CNN layers
        self.postnet = Postnet()

    def parse_output(self, outputs, output_lengths=None):
        if self.mask_padding and output_lengths is not None:
            mask = ~get_mask_from_lengths(output_lengths, outputs[0].size(1))
            mask = mask.unsqueeze(2).expand(mask.size(0), mask.size(1), self.num_mels)
            outputs[0].data.masked_fill_(mask, 0.0)
            outputs[1].data.masked_fill_(mask, 0.0)
        return outputs

    def forward(
        self,
        bottle_neck_features: torch.Tensor,
        feature_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        logf0_uv: torch.Tensor = None,
        spembs: torch.Tensor = None,
        output_att_ws: bool = False,
    ):
        decoder_inputs = self.bnf_prenet(
            bottle_neck_features.transpose(1, 2)
        ).transpose(1, 2)
        logf0_uv = self.pitch_convs(logf0_uv.transpose(1, 2)).transpose(1, 2)
        decoder_inputs = decoder_inputs + logf0_uv

        assert spembs is not None
        spk_embeds = F.normalize(
            spembs).unsqueeze(1).expand(-1, decoder_inputs.size(1), -1)
        decoder_inputs = torch.cat([decoder_inputs, spk_embeds], dim=-1)
        decoder_inputs = self.reduce_proj(decoder_inputs)

        # (B, num_mels, T_dec)
        T_dec = torch.div(feature_lengths, int(self.encoder_down_factor), rounding_mode='floor')
        mel_outputs, predicted_stop, alignments = self.decoder(
            decoder_inputs, speech, T_dec)
        ## Post-processing
        mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet
        if output_att_ws:
            return self.parse_output(
                [mel_outputs, mel_outputs_postnet, predicted_stop, alignments], speech_lengths)
        else:
            return self.parse_output(
                [mel_outputs, mel_outputs_postnet, predicted_stop], speech_lengths)

        # return mel_outputs, mel_outputs_postnet

    def inference(
        self,
        bottle_neck_features: torch.Tensor,
        logf0_uv: torch.Tensor = None,
        spembs: torch.Tensor = None,
    ):
        decoder_inputs = self.bnf_prenet(bottle_neck_features.transpose(1, 2)).transpose(1, 2)
        logf0_uv = self.pitch_convs(logf0_uv.transpose(1, 2)).transpose(1, 2)
        decoder_inputs = decoder_inputs + logf0_uv

        assert spembs is not None
        spk_embeds = F.normalize(
            spembs).unsqueeze(1).expand(-1, decoder_inputs.size(1), -1)
        bottle_neck_features = torch.cat([decoder_inputs, spk_embeds], dim=-1)
        bottle_neck_features = self.reduce_proj(bottle_neck_features)

        ## Decoder
        if bottle_neck_features.size(0) > 1:
            mel_outputs, alignments = self.decoder.inference_batched(bottle_neck_features)
        else:
            mel_outputs, alignments = self.decoder.inference(bottle_neck_features,)
        ## Post-processing
        mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet
        # outputs = mel_outputs_postnet[0]

        return mel_outputs[0], mel_outputs_postnet[0], alignments[0]

def load_model(model_file, device=None):
    # search a config file
    model_config_fpaths = list(model_file.parent.rglob("*.yaml"))
    if len(model_config_fpaths) == 0:
        raise Exception("No model yaml config found for convertor")
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_config = HpsYaml(model_config_fpaths[0])
    ppg2mel_model = MelDecoderMOLv2(
        **model_config["model"]
    ).to(device)
    ckpt = torch.load(model_file, map_location=device)
    ppg2mel_model.load_state_dict(ckpt["model"])
    ppg2mel_model.eval()
    return ppg2mel_model
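`load_model` expects a `Path` to a `.pth` checkpoint whose parent directory also contains the matching `*.yaml` hyper-parameter file. A minimal usage sketch; the checkpoint path and the feature dimensions below are placeholders that must match your own model config, not values shipped by this commit:

# Hypothetical inference call against the ppg2mel package defined above.
from pathlib import Path
import torch
import ppg2mel

model = ppg2mel.load_model(Path("ppg2mel/saved_models/your_checkpoint.pth"))  # placeholder path
device = next(model.parameters()).device

ppg = torch.randn(1, 200, 144, device=device)       # (B, T, bottle_neck_feature_dim); 144 is illustrative
logf0_uv = torch.randn(1, 200, 2, device=device)     # (B, T, 2): interpolated logF0 + U/V flag
spk_embed = torch.randn(1, 256, device=device)       # (B, spk_embed_dim) speaker d-vector

with torch.no_grad():
    mel_before, mel_after, attention = model.inference(ppg, logf0_uv=logf0_uv, spembs=spk_embed)
print(mel_after.shape)  # roughly (T // encoder_down_factor * frames_per_step, 80)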
ppg2mel/preprocess.py
ADDED
@@ -0,0 +1,113 @@
import os
import torch
import numpy as np
from tqdm import tqdm
from pathlib import Path
import soundfile
import resampy

from ppg_extractor import load_model
import encoder.inference as Encoder
from encoder.audio import preprocess_wav
from encoder import audio
from utils.f0_utils import compute_f0

from torch.multiprocessing import Pool, cpu_count
from functools import partial

SAMPLE_RATE = 16000

def _compute_bnf(
    wav: any,
    output_fpath: str,
    device: torch.device,
    ppg_model_local: any,
):
    """
    Compute CTC-Attention Seq2seq ASR encoder bottle-neck features (BNF).
    """
    ppg_model_local.to(device)
    wav_tensor = torch.from_numpy(wav).float().to(device).unsqueeze(0)
    wav_length = torch.LongTensor([wav.shape[0]]).to(device)
    with torch.no_grad():
        bnf = ppg_model_local(wav_tensor, wav_length)
    bnf_npy = bnf.squeeze(0).cpu().numpy()
    np.save(output_fpath, bnf_npy, allow_pickle=False)
    return bnf_npy, len(bnf_npy)

def _compute_f0_from_wav(wav, output_fpath):
    """Compute merged f0 values."""
    f0 = compute_f0(wav, SAMPLE_RATE)
    np.save(output_fpath, f0, allow_pickle=False)
    return f0, len(f0)

def _compute_spkEmbed(wav, output_fpath, encoder_model_local, device):
    Encoder.set_model(encoder_model_local)
    # Compute where to split the utterance into partials and pad if necessary
    wave_slices, mel_slices = Encoder.compute_partial_slices(len(wav), rate=1.3, min_pad_coverage=0.75)
    max_wave_length = wave_slices[-1].stop
    if max_wave_length >= len(wav):
        wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")

    # Split the utterance into partials
    frames = audio.wav_to_mel_spectrogram(wav)
    frames_batch = np.array([frames[s] for s in mel_slices])
    partial_embeds = Encoder.embed_frames_batch(frames_batch)

    # Compute the utterance embedding from the partial embeddings
    raw_embed = np.mean(partial_embeds, axis=0)
    embed = raw_embed / np.linalg.norm(raw_embed, 2)

    np.save(output_fpath, embed, allow_pickle=False)
    return embed, len(embed)

def preprocess_one(wav_path, out_dir, device, ppg_model_local, encoder_model_local):
    # wav = preprocess_wav(wav_path)
    # try:
    wav, sr = soundfile.read(wav_path)
    if len(wav) < sr:
        return None, sr, len(wav)
    if sr != SAMPLE_RATE:
        wav = resampy.resample(wav, sr, SAMPLE_RATE)
        sr = SAMPLE_RATE
    # splitext keeps the utterance id intact (rstrip(".wav") would strip matching characters, not the suffix)
    utt_id = os.path.splitext(os.path.basename(wav_path))[0]

    _, length_bnf = _compute_bnf(output_fpath=f"{out_dir}/bnf/{utt_id}.ling_feat.npy", wav=wav, device=device, ppg_model_local=ppg_model_local)
    _, length_f0 = _compute_f0_from_wav(output_fpath=f"{out_dir}/f0/{utt_id}.f0.npy", wav=wav)
    _, length_embed = _compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", device=device, encoder_model_local=encoder_model_local, wav=wav)

def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder_model_fpath, speaker_encoder_model):
    # Glob wav files
    wav_file_list = sorted(Path(f"{datasets_root}/{dataset}").glob("**/*.wav"))
    print(f"Globbed {len(wav_file_list)} wav files.")

    out_dir.joinpath("bnf").mkdir(exist_ok=True, parents=True)
    out_dir.joinpath("f0").mkdir(exist_ok=True, parents=True)
    out_dir.joinpath("embed").mkdir(exist_ok=True, parents=True)
    ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu")
    encoder_model_local = Encoder.load_model(speaker_encoder_model, "cpu")
    if n_processes is None:
        n_processes = cpu_count()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    func = partial(preprocess_one, out_dir=out_dir, ppg_model_local=ppg_model_local, encoder_model_local=encoder_model_local, device=device)
    job = Pool(n_processes).imap(func, wav_file_list)
    list(tqdm(job, "Preprocessing", len(wav_file_list), unit="wav"))

    # finish processing and mark
    t_fid_file = out_dir.joinpath("train_fidlist.txt").open("w", encoding="utf-8")
    d_fid_file = out_dir.joinpath("dev_fidlist.txt").open("w", encoding="utf-8")
    e_fid_file = out_dir.joinpath("eval_fidlist.txt").open("w", encoding="utf-8")
    for file in sorted(out_dir.joinpath("f0").glob("*.npy")):
        id = os.path.basename(file).split(".f0.npy")[0]
        if id.endswith("01"):
            d_fid_file.write(id + "\n")
        elif id.endswith("09"):
            e_fid_file.write(id + "\n")
        else:
            t_fid_file.write(id + "\n")
    t_fid_file.close()
    d_fid_file.close()
    e_fid_file.close()
    return len(wav_file_list)
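`preprocess_dataset` is what the web preprocess page ultimately calls: it writes `bnf/`, `f0/` and `embed/` feature folders plus train/dev/eval file lists under `out_dir`. A minimal sketch of a direct invocation; the dataset path and the PPG extractor checkpoint name are placeholders (only `encoder/saved_models/pretrained.pt` is shipped in this commit):

# Hypothetical standalone run of the VC preprocessing step; paths are placeholders.
from pathlib import Path
from ppg2mel.preprocess import preprocess_dataset

n_globbed = preprocess_dataset(
    datasets_root=Path("../audiodata"),
    dataset="aidatatang_200zh_s",
    out_dir=Path("../audiodata/SV2TTS/ppg2mel"),
    n_processes=4,
    ppg_encoder_model_fpath=Path("ppg_extractor/saved_models/your_ppg_model.pt"),  # placeholder
    speaker_encoder_model=Path("encoder/saved_models/pretrained.pt"),
)
print(f"{n_globbed} wav files were queued for feature extraction")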
ppg2mel/rnn_decoder_mol.py
ADDED
@@ -0,0 +1,374 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .utils.mol_attention import MOLAttention
from .utils.basic_layers import Linear
from .utils.vc_utils import get_mask_from_lengths


class DecoderPrenet(nn.Module):
    def __init__(self, in_dim, sizes):
        super().__init__()
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [Linear(in_size, out_size, bias=False)
             for (in_size, out_size) in zip(in_sizes, sizes)])

    def forward(self, x):
        for linear in self.layers:
            x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
        return x


class Decoder(nn.Module):
    """Mixture of Logistic (MoL) attention-based RNN Decoder."""
    def __init__(
        self,
        enc_dim,
        num_mels,
        frames_per_step,
        attention_rnn_dim,
        decoder_rnn_dim,
        prenet_dims,
        num_mixtures,
        encoder_down_factor=1,
        num_decoder_rnn_layer=1,
        use_stop_tokens=False,
        concat_context_to_last=False,
    ):
        super().__init__()
        self.enc_dim = enc_dim
        self.encoder_down_factor = encoder_down_factor
        self.num_mels = num_mels
        self.frames_per_step = frames_per_step
        self.attention_rnn_dim = attention_rnn_dim
        self.decoder_rnn_dim = decoder_rnn_dim
        self.prenet_dims = prenet_dims
        self.use_stop_tokens = use_stop_tokens
        self.num_decoder_rnn_layer = num_decoder_rnn_layer
        self.concat_context_to_last = concat_context_to_last

        # Mel prenet
        self.prenet = DecoderPrenet(num_mels, prenet_dims)
        self.prenet_pitch = DecoderPrenet(num_mels, prenet_dims)

        # Attention RNN
        self.attention_rnn = nn.LSTMCell(
            prenet_dims[-1] + enc_dim,
            attention_rnn_dim
        )

        # Attention
        self.attention_layer = MOLAttention(
            attention_rnn_dim,
            r=frames_per_step/encoder_down_factor,
            M=num_mixtures,
        )

        # Decoder RNN
        self.decoder_rnn_layers = nn.ModuleList()
        for i in range(num_decoder_rnn_layer):
            if i == 0:
                self.decoder_rnn_layers.append(
                    nn.LSTMCell(
                        enc_dim + attention_rnn_dim,
                        decoder_rnn_dim))
            else:
                self.decoder_rnn_layers.append(
                    nn.LSTMCell(
                        decoder_rnn_dim,
                        decoder_rnn_dim))
        # self.decoder_rnn = nn.LSTMCell(
        #     2 * enc_dim + attention_rnn_dim,
        #     decoder_rnn_dim
        # )
        if concat_context_to_last:
            self.linear_projection = Linear(
                enc_dim + decoder_rnn_dim,
                num_mels * frames_per_step
            )
        else:
            self.linear_projection = Linear(
                decoder_rnn_dim,
                num_mels * frames_per_step
            )

        # Stop-token layer
        if self.use_stop_tokens:
            if concat_context_to_last:
                self.stop_layer = Linear(
                    enc_dim + decoder_rnn_dim, 1, bias=True, w_init_gain="sigmoid"
                )
            else:
                self.stop_layer = Linear(
                    decoder_rnn_dim, 1, bias=True, w_init_gain="sigmoid"
                )

    def get_go_frame(self, memory):
        B = memory.size(0)
        go_frame = torch.zeros((B, self.num_mels), dtype=torch.float,
                               device=memory.device)
        return go_frame

    def initialize_decoder_states(self, memory, mask):
        device = next(self.parameters()).device
        B = memory.size(0)

        # attention rnn states
        self.attention_hidden = torch.zeros(
            (B, self.attention_rnn_dim), device=device)
        self.attention_cell = torch.zeros(
            (B, self.attention_rnn_dim), device=device)

        # decoder rnn states
        self.decoder_hiddens = []
        self.decoder_cells = []
        for i in range(self.num_decoder_rnn_layer):
            self.decoder_hiddens.append(
                torch.zeros((B, self.decoder_rnn_dim),
                            device=device)
            )
            self.decoder_cells.append(
                torch.zeros((B, self.decoder_rnn_dim),
                            device=device)
            )
        # self.decoder_hidden = torch.zeros(
        #     (B, self.decoder_rnn_dim), device=device)
        # self.decoder_cell = torch.zeros(
        #     (B, self.decoder_rnn_dim), device=device)

        self.attention_context = torch.zeros(
            (B, self.enc_dim), device=device)

        self.memory = memory
        # self.processed_memory = self.attention_layer.memory_layer(memory)
        self.mask = mask

    def parse_decoder_inputs(self, decoder_inputs):
        """Prepare decoder inputs, i.e. gt mel
        Args:
            decoder_inputs: (B, T_out, n_mel_channels) inputs used for teacher-forced training.
        """
        decoder_inputs = decoder_inputs.reshape(
            decoder_inputs.size(0),
            int(decoder_inputs.size(1)/self.frames_per_step), -1)
        # (B, T_out//r, r*num_mels) -> (T_out//r, B, r*num_mels)
        decoder_inputs = decoder_inputs.transpose(0, 1)
        # (T_out//r, B, num_mels)
        decoder_inputs = decoder_inputs[:, :, -self.num_mels:]
        return decoder_inputs

    def parse_decoder_outputs(self, mel_outputs, alignments, stop_outputs):
        """ Prepares decoder outputs for output
        Args:
            mel_outputs:
            alignments:
        """
        # (T_out//r, B, T_enc) -> (B, T_out//r, T_enc)
        alignments = torch.stack(alignments).transpose(0, 1)
        # (T_out//r, B) -> (B, T_out//r)
        if stop_outputs is not None:
            if alignments.size(0) == 1:
                stop_outputs = torch.stack(stop_outputs).unsqueeze(0)
            else:
                stop_outputs = torch.stack(stop_outputs).transpose(0, 1)
            stop_outputs = stop_outputs.contiguous()
        # (T_out//r, B, num_mels*r) -> (B, T_out//r, num_mels*r)
        mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
        # decouple frames per step
        # (B, T_out, num_mels)
        mel_outputs = mel_outputs.view(
            mel_outputs.size(0), -1, self.num_mels)
        return mel_outputs, alignments, stop_outputs

    def attend(self, decoder_input):
        cell_input = torch.cat((decoder_input, self.attention_context), -1)
        self.attention_hidden, self.attention_cell = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell))
        self.attention_context, attention_weights = self.attention_layer(
            self.attention_hidden, self.memory, None, self.mask)

        decoder_rnn_input = torch.cat(
            (self.attention_hidden, self.attention_context), -1)

        return decoder_rnn_input, self.attention_context, attention_weights

    def decode(self, decoder_input):
        for i in range(self.num_decoder_rnn_layer):
            if i == 0:
                self.decoder_hiddens[i], self.decoder_cells[i] = self.decoder_rnn_layers[i](
                    decoder_input, (self.decoder_hiddens[i], self.decoder_cells[i]))
            else:
                self.decoder_hiddens[i], self.decoder_cells[i] = self.decoder_rnn_layers[i](
                    self.decoder_hiddens[i-1], (self.decoder_hiddens[i], self.decoder_cells[i]))
        return self.decoder_hiddens[-1]

    def forward(self, memory, mel_inputs, memory_lengths):
        """ Decoder forward pass for training
        Args:
            memory: (B, T_enc, enc_dim) Encoder outputs
            decoder_inputs: (B, T, num_mels) Decoder inputs for teacher forcing.
            memory_lengths: (B, ) Encoder output lengths for attention masking.
        Returns:
            mel_outputs: (B, T, num_mels) mel outputs from the decoder
            alignments: (B, T//r, T_enc) attention weights.
        """
        # [1, B, num_mels]
        go_frame = self.get_go_frame(memory).unsqueeze(0)
        # [T//r, B, num_mels]
        mel_inputs = self.parse_decoder_inputs(mel_inputs)
        # [T//r + 1, B, num_mels]
        mel_inputs = torch.cat((go_frame, mel_inputs), dim=0)
        # [T//r + 1, B, prenet_dim]
        decoder_inputs = self.prenet(mel_inputs)
        # decoder_inputs_pitch = self.prenet_pitch(decoder_inputs__)

        self.initialize_decoder_states(
            memory, mask=~get_mask_from_lengths(memory_lengths),
        )

        self.attention_layer.init_states(memory)
        # self.attention_layer_pitch.init_states(memory_pitch)

        mel_outputs, alignments = [], []
        if self.use_stop_tokens:
            stop_outputs = []
        else:
            stop_outputs = None
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]
            # decoder_input_pitch = decoder_inputs_pitch[len(mel_outputs)]

            decoder_rnn_input, context, attention_weights = self.attend(decoder_input)

            decoder_rnn_output = self.decode(decoder_rnn_input)
            if self.concat_context_to_last:
                decoder_rnn_output = torch.cat(
                    (decoder_rnn_output, context), dim=1)

            mel_output = self.linear_projection(decoder_rnn_output)
            if self.use_stop_tokens:
                stop_output = self.stop_layer(decoder_rnn_output)
                stop_outputs += [stop_output.squeeze()]
            mel_outputs += [mel_output.squeeze(1)]  #? perhaps don't need squeeze
            alignments += [attention_weights]
            # alignments_pitch += [attention_weights_pitch]

        mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs(
            mel_outputs, alignments, stop_outputs)
        if stop_outputs is None:
            return mel_outputs, alignments
        else:
            return mel_outputs, stop_outputs, alignments

    def inference(self, memory, stop_threshold=0.5):
        """ Decoder inference
        Args:
            memory: (1, T_enc, D_enc) Encoder outputs
        Returns:
            mel_outputs: mel outputs from the decoder
            alignments: sequence of attention weights from the decoder
        """
        # [1, num_mels]
        decoder_input = self.get_go_frame(memory)

        self.initialize_decoder_states(memory, mask=None)

        self.attention_layer.init_states(memory)

        mel_outputs, alignments = [], []
        # NOTE(sx): heuristic
        max_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step
        min_decoder_step = memory.size(1)*self.encoder_down_factor // self.frames_per_step - 5
        while True:
            decoder_input = self.prenet(decoder_input)

            decoder_input_final, context, alignment = self.attend(decoder_input)

            # mel_output, stop_output, alignment = self.decode(decoder_input)
            decoder_rnn_output = self.decode(decoder_input_final)
            if self.concat_context_to_last:
                decoder_rnn_output = torch.cat(
                    (decoder_rnn_output, context), dim=1)

            mel_output = self.linear_projection(decoder_rnn_output)
            stop_output = self.stop_layer(decoder_rnn_output)

            mel_outputs += [mel_output.squeeze(1)]
            alignments += [alignment]

            if torch.sigmoid(stop_output.data) > stop_threshold and len(mel_outputs) >= min_decoder_step:
                break
            if len(mel_outputs) >= max_decoder_step:
                # print("Warning! Decoding steps reaches max decoder steps.")
                break

            decoder_input = mel_output[:, -self.num_mels:]

        mel_outputs, alignments, _ = self.parse_decoder_outputs(
            mel_outputs, alignments, None)

        return mel_outputs, alignments

    def inference_batched(self, memory, stop_threshold=0.5):
        """ Decoder inference
        Args:
            memory: (B, T_enc, D_enc) Encoder outputs
        Returns:
            mel_outputs: mel outputs from the decoder
            alignments: sequence of attention weights from the decoder
        """
        # [1, num_mels]
        decoder_input = self.get_go_frame(memory)

        self.initialize_decoder_states(memory, mask=None)

        self.attention_layer.init_states(memory)

        mel_outputs, alignments = [], []
        stop_outputs = []
        # NOTE(sx): heuristic
        max_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step
        min_decoder_step = memory.size(1)*self.encoder_down_factor // self.frames_per_step - 5
        while True:
            decoder_input = self.prenet(decoder_input)

            decoder_input_final, context, alignment = self.attend(decoder_input)

            # mel_output, stop_output, alignment = self.decode(decoder_input)
            decoder_rnn_output = self.decode(decoder_input_final)
            if self.concat_context_to_last:
                decoder_rnn_output = torch.cat(
                    (decoder_rnn_output, context), dim=1)

            mel_output = self.linear_projection(decoder_rnn_output)
            # (B, 1)
            stop_output = self.stop_layer(decoder_rnn_output)
            stop_outputs += [stop_output.squeeze()]
            # stop_outputs.append(stop_output)

            mel_outputs += [mel_output.squeeze(1)]
            alignments += [alignment]
            # print(stop_output.shape)
            if torch.all(torch.sigmoid(stop_output.squeeze().data) > stop_threshold) \
                    and len(mel_outputs) >= min_decoder_step:
                break
            if len(mel_outputs) >= max_decoder_step:
                # print("Warning! Decoding steps reaches max decoder steps.")
                break

            decoder_input = mel_output[:, -self.num_mels:]

        mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs(
            mel_outputs, alignments, stop_outputs)
        mel_outputs_stacked = []
        for mel, stop_logit in zip(mel_outputs, stop_outputs):
            idx = np.argwhere(torch.sigmoid(stop_logit.cpu()) > stop_threshold)[0][0].item()
            mel_outputs_stacked.append(mel[:idx, :])
        mel_outputs = torch.cat(mel_outputs_stacked, dim=0).unsqueeze(0)
        return mel_outputs, alignments
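The reshaping in `parse_decoder_inputs` is the one non-obvious piece of tensor handling in this decoder: with reduction factor r = frames_per_step, a (B, T, num_mels) target mel is grouped into T//r decoder steps, and only the last frame of each group is fed back as the teacher-forcing input. A tiny self-contained check of that behaviour (the shapes below are arbitrary, not values used by the model):

# Standalone illustration of the parse_decoder_inputs reshape; numbers are arbitrary.
import torch

B, T, num_mels, r = 2, 8, 3, 2          # batch, target frames, mel bins, frames_per_step
mel = torch.arange(B * T * num_mels, dtype=torch.float).reshape(B, T, num_mels)

grouped = mel.reshape(B, T // r, -1)     # (B, T//r, r*num_mels): r frames fused per decoder step
stepped = grouped.transpose(0, 1)        # (T//r, B, r*num_mels): time-major for the decoding loop
teacher = stepped[:, :, -num_mels:]      # keep only the last frame of each group

assert teacher.shape == (T // r, B, num_mels)
assert torch.equal(teacher[0, 0], mel[0, r - 1])   # step 0 feeds back frame r-1 of the target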