lewiswu1209 committed
Commit f4dac30
0 Parent(s)

initial commit

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +31 -0
  2. .gitignore +18 -0
  3. CODE_OF_CONDUCT.md +130 -0
  4. LICENSE.txt +24 -0
  5. README-CN.md +230 -0
  6. README.md +13 -0
  7. app.py +80 -0
  8. demo_toolbox.py +49 -0
  9. encoder/__init__.py +0 -0
  10. encoder/audio.py +117 -0
  11. encoder/config.py +45 -0
  12. encoder/data_objects/__init__.py +2 -0
  13. encoder/data_objects/random_cycler.py +37 -0
  14. encoder/data_objects/speaker.py +40 -0
  15. encoder/data_objects/speaker_batch.py +12 -0
  16. encoder/data_objects/speaker_verification_dataset.py +56 -0
  17. encoder/data_objects/utterance.py +26 -0
  18. encoder/inference.py +195 -0
  19. encoder/model.py +135 -0
  20. encoder/params_data.py +29 -0
  21. encoder/params_model.py +11 -0
  22. encoder/preprocess.py +184 -0
  23. encoder/saved_models/pretrained.pt +3 -0
  24. encoder/train.py +123 -0
  25. encoder/visualizations.py +178 -0
  26. encoder_preprocess.py +61 -0
  27. encoder_train.py +47 -0
  28. gen_voice.py +128 -0
  29. mkgui/__init__.py +0 -0
  30. mkgui/app.py +145 -0
  31. mkgui/app_vc.py +166 -0
  32. mkgui/base/__init__.py +2 -0
  33. mkgui/base/api/__init__.py +1 -0
  34. mkgui/base/api/fastapi_utils.py +102 -0
  35. mkgui/base/components/__init__.py +0 -0
  36. mkgui/base/components/outputs.py +43 -0
  37. mkgui/base/components/types.py +46 -0
  38. mkgui/base/core.py +203 -0
  39. mkgui/base/ui/__init__.py +1 -0
  40. mkgui/base/ui/schema_utils.py +129 -0
  41. mkgui/base/ui/streamlit_ui.py +888 -0
  42. mkgui/base/ui/streamlit_utils.py +13 -0
  43. mkgui/preprocess.py +96 -0
  44. mkgui/static/mb.png +0 -0
  45. mkgui/train.py +106 -0
  46. mkgui/train_vc.py +155 -0
  47. packages.txt +5 -0
  48. ppg2mel/__init__.py +209 -0
  49. ppg2mel/preprocess.py +113 -0
  50. ppg2mel/rnn_decoder_mol.py +374 -0
.gitattributes ADDED
@@ -0,0 +1,31 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,18 @@
+ *.pyc
+ *.aux
+ *.log
+ *.out
+ *.synctex.gz
+ *.suo
+ *__pycache__
+ *.idea
+ *.ipynb_checkpoints
+ *.pickle
+ *.npy
+ *.blg
+ *.bbl
+ *.bcf
+ *.toc
+ *.sh
+ wavs
+ log
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,130 @@
+ # Contributor Covenant Code of Conduct
+ ## First of all
+ Don't be evil, never
+
+ ## Our Pledge
+
+ We as members, contributors, and leaders pledge to make participation in our
+ community a harassment-free experience for everyone, regardless of age, body
+ size, visible or invisible disability, ethnicity, sex characteristics, gender
+ identity and expression, level of experience, education, socio-economic status,
+ nationality, personal appearance, race, religion, or sexual identity
+ and orientation.
+
+ We pledge to act and interact in ways that contribute to an open, welcoming,
+ diverse, inclusive, and healthy community.
+
+ ## Our Standards
+
+ Examples of behavior that contributes to a positive environment for our
+ community include:
+
+ * Demonstrating empathy and kindness toward other people
+ * Being respectful of differing opinions, viewpoints, and experiences
+ * Giving and gracefully accepting constructive feedback
+ * Accepting responsibility and apologizing to those affected by our mistakes,
+   and learning from the experience
+ * Focusing on what is best not just for us as individuals, but for the
+   overall community
+
+ Examples of unacceptable behavior include:
+
+ * The use of sexualized language or imagery, and sexual attention or
+   advances of any kind
+ * Trolling, insulting or derogatory comments, and personal or political attacks
+ * Public or private harassment
+ * Publishing others' private information, such as a physical or email
+   address, without their explicit permission
+ * Other conduct which could reasonably be considered inappropriate in a
+   professional setting
+
+ ## Enforcement Responsibilities
+
+ Community leaders are responsible for clarifying and enforcing our standards of
+ acceptable behavior and will take appropriate and fair corrective action in
+ response to any behavior that they deem inappropriate, threatening, offensive,
+ or harmful.
+
+ Community leaders have the right and responsibility to remove, edit, or reject
+ comments, commits, code, wiki edits, issues, and other contributions that are
+ not aligned to this Code of Conduct, and will communicate reasons for moderation
+ decisions when appropriate.
+
+ ## Scope
+
+ This Code of Conduct applies within all community spaces, and also applies when
+ an individual is officially representing the community in public spaces.
+ Examples of representing our community include using an official e-mail address,
+ posting via an official social media account, or acting as an appointed
+ representative at an online or offline event.
+
+ ## Enforcement
+
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
+ reported to the community leaders responsible for enforcement at
+
+ All complaints will be reviewed and investigated promptly and fairly.
+
+ All community leaders are obligated to respect the privacy and security of the
+ reporter of any incident.
+
+ ## Enforcement Guidelines
+
+ Community leaders will follow these Community Impact Guidelines in determining
+ the consequences for any action they deem in violation of this Code of Conduct:
+
+ ### 1. Correction
+
+ **Community Impact**: Use of inappropriate language or other behavior deemed
+ unprofessional or unwelcome in the community.
+
+ **Consequence**: A private, written warning from community leaders, providing
+ clarity around the nature of the violation and an explanation of why the
+ behavior was inappropriate. A public apology may be requested.
+
+ ### 2. Warning
+
+ **Community Impact**: A violation through a single incident or series
+ of actions.
+
+ **Consequence**: A warning with consequences for continued behavior. No
+ interaction with the people involved, including unsolicited interaction with
+ those enforcing the Code of Conduct, for a specified period of time. This
+ includes avoiding interactions in community spaces as well as external channels
+ like social media. Violating these terms may lead to a temporary or
+ permanent ban.
+
+ ### 3. Temporary Ban
+
+ **Community Impact**: A serious violation of community standards, including
+ sustained inappropriate behavior.
+
+ **Consequence**: A temporary ban from any sort of interaction or public
+ communication with the community for a specified period of time. No public or
+ private interaction with the people involved, including unsolicited interaction
+ with those enforcing the Code of Conduct, is allowed during this period.
+ Violating these terms may lead to a permanent ban.
+
+ ### 4. Permanent Ban
+
+ **Community Impact**: Demonstrating a pattern of violation of community
+ standards, including sustained inappropriate behavior, harassment of an
+ individual, or aggression toward or disparagement of classes of individuals.
+
+ **Consequence**: A permanent ban from any sort of public interaction within
+ the community.
+
+ ## Attribution
+
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage],
+ version 2.0, available at
+ https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
+
+ Community Impact Guidelines were inspired by [Mozilla's code of conduct
+ enforcement ladder](https://github.com/mozilla/diversity).
+
+ [homepage]: https://www.contributor-covenant.org
+
+ For answers to common questions about this code of conduct, see the FAQ at
+ https://www.contributor-covenant.org/faq. Translations are available at
+ https://www.contributor-covenant.org/translations.
LICENSE.txt ADDED
@@ -0,0 +1,24 @@
+ MIT License
+
+ Modified & original work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ)
+ Original work Copyright (c) 2018 Rayhane Mama (https://github.com/Rayhane-mamah)
+ Original work Copyright (c) 2019 fatchord (https://github.com/fatchord)
+ Original work Copyright (c) 2015 braindead (https://github.com/braindead)
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README-CN.md ADDED
@@ -0,0 +1,230 @@
+ ## Real-Time Voice Cloning - Chinese/Mandarin
+ ![mockingbird](https://user-images.githubusercontent.com/12797292/131216767-6eb251d6-14fc-4951-8324-2722f0cd4c63.jpg)
+
+ [![MIT License](https://img.shields.io/badge/license-MIT-blue.svg?style=flat)](http://choosealicense.com/licenses/mit/)
+
+ ### [English](README.md) | Chinese
+
+ ### [DEMO VIDEO](https://www.bilibili.com/video/BV17Q4y1B7mY/) | [Wiki tutorial](https://github.com/babysor/MockingBird/wiki/Quick-Start-(Newbie)) | [Training tutorial](https://vaj2fgg8yn.feishu.cn/docs/doccn7kAbr3SJz0KM0SIDJ0Xnhd)
+
+ ## Features
+ 🌍 **Chinese** Supports Mandarin and is tested on multiple Chinese datasets: aidatatang_200zh, magicdata, aishell3, biaobei, MozillaCommonVoice, data_aishell, etc.
+
+ 🤩 **PyTorch** Works with PyTorch, tested on version 1.9.0 (latest as of August 2021) with Tesla T4 and GTX 2060 GPUs
+
+ 🌍 **Windows + Linux** Runs on both Windows and Linux (community members have also run it successfully on Apple M1)
+
+ 🤩 **Easy & Awesome** Good results by only downloading or newly training a synthesizer, reusing the pretrained encoder/vocoder, or using a real-time HiFi-GAN as the vocoder
+
+ 🌍 **Webserver Ready** Can serve your training results for remote calls
+
+ ### Work in progress
+ * Major GUI/client upgrade and merge
+ [X] Initialize the `./mkgui` framework (based on streamlit + fastapi) and the [technical design](https://vaj2fgg8yn.feishu.cn/docs/doccnvotLWylBub8VJIjKzoEaee)
+ [X] Add a demo page for Voice Cloning and Conversion
+ [X] Add preprocessing and training pages for Voice Conversion
+ [ ] Add preprocessing and training pages for the other modules
+ * Upgrade the model backend to ESPnet2
+
+
+ ## Getting Started
+ ### 1. Install requirements
+ > Follow the original repository to check whether your environment is ready.
+ **Python 3.7 or higher** is required to run the toolbox (demo_toolbox.py).
+
+ * Install [PyTorch](https://pytorch.org/get-started/locally/).
+ > If pip installation fails with `ERROR: Could not find a version that satisfies the requirement torch==1.9.0+cu102 (from versions: 0.1.2, 0.1.2.post1, 0.1.2.post2)`, your Python version may be too old; 3.9 installs successfully
+ * Install [ffmpeg](https://ffmpeg.org/download.html#get-packages).
+ * Run `pip install -r requirements.txt` to install the remaining required packages.
+ * Install webrtcvad with `pip install webrtcvad-wheels`.
+
+ ### 2. Prepare pretrained models
+ Consider training your own models or downloading models trained by the community:
+ > A [Zhihu column](https://www.zhihu.com/column/c_1425605280340504576) was created recently and will be updated occasionally with training tips and lessons learned; questions are welcome
+ #### 2.1 Train the encoder with your own datasets (optional)
+
+ * Preprocess the audio and mel spectrograms:
+ `python encoder_preprocess.py <datasets_root>`
+ Use `-d {dataset}` to specify the dataset; librispeech_other, voxceleb1 and aidatatang_200zh are supported; separate multiple datasets with commas.
+ * Train the encoder: `python encoder_train.py my_run <datasets_root>/SV2TTS/encoder`
+ > Encoder training uses visdom. You can add `-no_visdom` to disable it, but visualization is better. Run "visdom" in a separate terminal/process to start the visdom server.
+
+ #### 2.2 Train the synthesizer with your own datasets (choose either this or 2.3)
+ * Download a dataset and extract it: make sure you can access all the audio files (e.g. .wav) in the *train* folder
+ * Preprocess the audio and mel spectrograms:
+ `python pre.py <datasets_root> -d {dataset} -n {number}`
+ Supported arguments:
+ * `-d {dataset}` specifies the dataset; aidatatang_200zh, magicdata, aishell3 and data_aishell are supported; defaults to aidatatang_200zh if omitted
+ * `-n {number}` sets the number of parallel workers; 10 was verified to work on an 11770k CPU with 32 GB RAM
+ > If the downloaded `aidatatang_200zh` files are on drive D and the `train` folder path is `D:\data\aidatatang_200zh\corpus\train`, then your `datasets_root` is `D:\data\`
+
+ * Train the synthesizer:
+ `python synthesizer_train.py mandarin <datasets_root>/SV2TTS/synthesizer`
+
+ * When the attention line appears in the training folder *synthesizer/saved_models/* and the loss meets your needs, go to the `Launch` step.
+
+ #### 2.3 Use a synthesizer pretrained by the community (choose either this or 2.2)
+ > If you have no hardware or do not want to spend time tuning, you can use models contributed by the community (continued sharing is welcome):
+
+ | Author | Download link | Preview | Info |
+ | --- | ----------- | ----- | ----- |
+ | Author | https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g [Baidu drive link](https://pan.baidu.com/s/1iONvRxmkI-t1nHqxKytY3g) code: 4j5d | | 75k steps, trained on a mix of 3 open-source datasets
+ | Author | https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw [Baidu drive link](https://pan.baidu.com/s/1fMh9IlgKJlL2PIiRTYDUvw) code: om7f | | 25k steps, trained on a mix of 3 open-source datasets; switch to tag v0.0.1 to use
+ |@FawenYo | https://drive.google.com/file/d/1H-YGOUHpmqKxJ9FRc6vAjPuqQki24UbC/view?usp=sharing [Baidu drive link](https://pan.baidu.com/s/1vSYXO4wsLyjnF3Unl-Xoxg) code: 1024 | [input](https://github.com/babysor/MockingBird/wiki/audio/self_test.mp3) [output](https://github.com/babysor/MockingBird/wiki/audio/export.wav) | 200k steps, Taiwanese accent; switch to tag v0.0.1 to use
+ |@miven| https://pan.baidu.com/s/1PI-hM3sn5wbeChRryX-RCQ code: 2021 | https://www.bilibili.com/video/BV1uh411B7AD/ | 150k steps; note: apply the fix from this [issue](https://github.com/babysor/MockingBird/issues/37) and switch to tag v0.0.1 to use
+
+ #### 2.4 Train a vocoder (optional)
+ The vocoder has little impact on the results; 3 are already bundled. If you want to train your own, you can refer to the commands below.
+ * Preprocess the data:
+ `python vocoder_preprocess.py <datasets_root> -m <synthesizer_model_path>`
+ > Replace `<datasets_root>` with your dataset directory and `<synthesizer_model_path>` with the directory of your best synthesizer model, e.g. *synthesizer\saved_models\xxx*
+
+
+ * Train the wavernn vocoder:
+ `python vocoder_train.py <trainid> <datasets_root>`
+ > Replace `<trainid>` with any identifier you like; training again with the same identifier resumes from the existing model
+
+ * Train the hifigan vocoder:
+ `python vocoder_train.py <trainid> <datasets_root> hifigan`
+ > Replace `<trainid>` with any identifier you like; training again with the same identifier resumes from the existing model
+ * Train the fregan vocoder:
+ `python vocoder_train.py <trainid> <datasets_root> --config config.json fregan`
+ > Replace `<trainid>` with any identifier you like; training again with the same identifier resumes from the existing model
+ * To switch GAN vocoder training to multi-GPU mode: change the "num_gpus" parameter in the .json file under the GAN folder
+ ### 3. Launch the program or toolbox
+ You can try the following commands:
+
+ ### 3.1 Launch the web program (v2):
+ `python web.py`
+ Once it is running, open the address in a browser; the default is `http://localhost:8080`
+ > * Only newly made manual recordings (16 kHz) are supported; recordings larger than 4 MB are not supported; the best length is 5~15 seconds
+
+ ### 3.2 Launch the toolbox:
+ `python demo_toolbox.py -d <datasets_root>`
+ > Specify a usable dataset path; if a supported dataset is present it is loaded automatically for debugging, and the path also serves as the storage directory for manually recorded audio.
+
+ <img width="1042" alt="d48ea37adf3660e657cfb047c10edbc" src="https://user-images.githubusercontent.com/7423248/134275227-c1ddf154-f118-4b77-8949-8c4c7daf25f0.png">
+
+ ### 4. Extra: Voice Conversion (PPG based)
+ Ever imagined Conan holding a voice changer and speaking in Kogoro Mouri's voice? Based on PPG-VC, this project now introduces two extra modules (PPG extractor + PPG2Mel) to support voice conversion. (The documentation is incomplete, especially the training part, and is being filled in.)
+ #### 4.0 Prepare the environment
+ * Make sure the environment above is installed, then run `pip install espnet` to install the remaining required packages.
+ * Download the following models. Link: https://pan.baidu.com/s/1bl_x_DHJSAUyN2fma-Q_Wg
+ Extraction code: gh41
+ * the vocoder dedicated to the 24 kHz sampling rate (hifigan), into *vocoder\saved_models\xxx*
+ * the pretrained PPG feature encoder (ppg_extractor), into *ppg_extractor\saved_models\xxx*
+ * the pretrained PPG2Mel, into *ppg2mel\saved_models\xxx*
+
+ #### 4.1 Train the PPG2Mel model with your own datasets (optional)
+
+ * Download the aidatatang_200zh dataset and extract it: make sure you can access all the audio files (e.g. .wav) in the *train* folder
+ * Preprocess the audio and mel spectrograms:
+ `python pre4ppg.py <datasets_root> -d {dataset} -n {number}`
+ Supported arguments:
+ * `-d {dataset}` specifies the dataset; only aidatatang_200zh is supported; defaults to aidatatang_200zh if omitted
+ * `-n {number}` sets the number of parallel workers; with 8 on an 11770k CPU this takes 12 to 18 hours! To be optimized
+ > If the downloaded `aidatatang_200zh` files are on drive D and the `train` folder path is `D:\data\aidatatang_200zh\corpus\train`, then your `datasets_root` is `D:\data\`
+
+ * Train the synthesizer; note that `ppg2mel.yaml` must be downloaded in the previous step, and the paths inside it must be edited to point to the pretrained folders:
+ `python ppg2mel_train.py --config .\ppg2mel\saved_models\ppg2mel.yaml --oneshotvc `
+ * To resume a previous training run, use the `--load .\ppg2mel\saved_models\<old_pt_file>` argument to specify a pretrained model file.
+
+ #### 4.2 Launch the toolbox in VC mode
+ You can try the following command:
+ `python demo_toolbox.py -vc -d <datasets_root>`
+ > Specify a usable dataset path; if a supported dataset is present it is loaded automatically for debugging, and the path also serves as the storage directory for manually recorded audio.
+ <img width="971" alt="微信图片_20220305005351" src="https://user-images.githubusercontent.com/7423248/156805733-2b093dbc-d989-4e68-8609-db11f365886a.png">
+
+ ## References and papers
+ > This repository was originally forked from [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning), which only supports English; thanks to its author.
+
+ | URL | Designation | Title | Implementation source |
+ | --- | ----------- | ----- | --------------------- |
+ | [1803.09017](https://arxiv.org/abs/1803.09017) | GlobalStyleToken (synthesizer)| Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis | This repo |
+ | [2010.05646](https://arxiv.org/abs/2010.05646) | HiFi-GAN (vocoder)| Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis | This repo |
+ | [2106.02297](https://arxiv.org/abs/2106.02297) | Fre-GAN (vocoder)| Fre-GAN: Adversarial Frequency-consistent Audio Synthesis | This repo |
+ |[**1806.04558**](https://arxiv.org/pdf/1806.04558.pdf) | SV2TTS | Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis | This repo |
+ |[1802.08435](https://arxiv.org/pdf/1802.08435.pdf) | WaveRNN (vocoder) | Efficient Neural Audio Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN) |
+ |[1703.10135](https://arxiv.org/pdf/1703.10135.pdf) | Tacotron (synthesizer) | Tacotron: Towards End-to-End Speech Synthesis | [fatchord/WaveRNN](https://github.com/fatchord/WaveRNN)
+ |[1710.10467](https://arxiv.org/pdf/1710.10467.pdf) | GE2E (encoder)| Generalized End-To-End Loss for Speaker Verification | This repo |
+
+ ## Frequently Asked Questions (FAQ)
+ #### 1. Where can I download the datasets?
+ | Dataset | OpenSLR link | Other sources (Google Drive, Baidu netdisk, etc.) |
+ | --- | ----------- | ---------------|
+ | aidatatang_200zh | [OpenSLR](http://www.openslr.org/62/) | [Google Drive](https://drive.google.com/file/d/110A11KZoVe7vy6kXlLb6zVPLb_J91I_t/view?usp=sharing) |
+ | magicdata | [OpenSLR](http://www.openslr.org/68/) | [Google Drive (Dev set)](https://drive.google.com/file/d/1g5bWRUSNH68ycC6eNvtwh07nX3QhOOlo/view?usp=sharing) |
+ | aishell3 | [OpenSLR](https://www.openslr.org/93/) | [Google Drive](https://drive.google.com/file/d/1shYp_o4Z0X0cZSKQDtFirct2luFUwKzZ/view?usp=sharing) |
+ | data_aishell | [OpenSLR](https://www.openslr.org/33/) | |
+ > After extracting aidatatang_200zh, you still need to select all the archives under `aidatatang_200zh\corpus\train` and extract them
+
+ #### 2. What does `<datasets_root>` mean?
+ If the dataset path is `D:\data\aidatatang_200zh`, then `<datasets_root>` is `D:\data`
+
+ #### 3. Running out of GPU memory while training
+ When training the synthesizer: reduce the batch_size parameter in `synthesizer/hparams.py`
+ ```
+ # Before adjustment
+ tts_schedule = [(2, 1e-3, 20_000, 12),   # Progressive training schedule
+                 (2, 5e-4, 40_000, 12),   # (r, lr, step, batch_size)
+                 (2, 2e-4, 80_000, 12),   #
+                 (2, 1e-4, 160_000, 12),  # r = reduction factor (# of mel frames
+                 (2, 3e-5, 320_000, 12),  # synthesized for each decoder iteration)
+                 (2, 1e-5, 640_000, 12)], # lr = learning rate
+ # After adjustment
+ tts_schedule = [(2, 1e-3, 20_000, 8),    # Progressive training schedule
+                 (2, 5e-4, 40_000, 8),    # (r, lr, step, batch_size)
+                 (2, 2e-4, 80_000, 8),    #
+                 (2, 1e-4, 160_000, 8),   # r = reduction factor (# of mel frames
+                 (2, 3e-5, 320_000, 8),   # synthesized for each decoder iteration)
+                 (2, 1e-5, 640_000, 8)],  # lr = learning rate
+ ```
+
+ Vocoder - when preprocessing the dataset: reduce the batch_size parameter in `synthesizer/hparams.py`
+ ```
+ # Before adjustment
+ ### Data Preprocessing
+ max_mel_frames = 900,
+ rescale = True,
+ rescaling_max = 0.9,
+ synthesis_batch_size = 16,  # For vocoder preprocessing and inference.
+ # After adjustment
+ ### Data Preprocessing
+ max_mel_frames = 900,
+ rescale = True,
+ rescaling_max = 0.9,
+ synthesis_batch_size = 8,   # For vocoder preprocessing and inference.
+ ```
+
+ Vocoder - when training the vocoder: reduce the batch_size parameter in `vocoder/wavernn/hparams.py`
+ ```
+ # Before adjustment
+ # Training
+ voc_batch_size = 100
+ voc_lr = 1e-4
+ voc_gen_at_checkpoint = 5
+ voc_pad = 2
+
+ # After adjustment
+ # Training
+ voc_batch_size = 6
+ voc_lr = 1e-4
+ voc_gen_at_checkpoint = 5
+ voc_pad = 2
+ ```
+
+ #### 4. Encountering `RuntimeError: Error(s) in loading state_dict for Tacotron: size mismatch for encoder.embedding.weight: copying a param with shape torch.Size([70, 512]) from checkpoint, the shape in current model is torch.Size([75, 512]).`
+ Please refer to issue [#37](https://github.com/babysor/MockingBird/issues/37)
+
+ #### 5. How can I improve CPU and GPU utilization?
+ Adjust the batch_size parameter as appropriate
+
+ #### 6. Getting `The paging file is too small for this operation to complete`
+ Refer to this [article](https://blog.csdn.net/qq_17755303/article/details/112564030) and increase the virtual memory to 100 GB (102400); for example, if the files are on drive D, change drive D's virtual memory
+
+ #### 7. When is training considered finished?
+ First, the attention alignment must appear; second, the loss must be low enough, which depends on your hardware and dataset. For reference, my attention appeared after 18k steps and the loss dropped below 0.4 after 50k steps
+ ![attention_step_20500_sample_1](https://user-images.githubusercontent.com/7423248/128587252-f669f05a-f411-4811-8784-222156ea5e9d.png)
+
+ ![step-135500-mel-spectrogram_sample_1](https://user-images.githubusercontent.com/7423248/128587255-4945faa0-5517-46ea-b173-928eff999330.png)
+
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: MockingBird
+ emoji: 🔥
+ colorFrom: red
+ colorTo: red
+ sdk: gradio
+ sdk_version: 3.1.3
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,80 @@
+
+ import gradio as gr
+
+ import re
+ import random
+ import string
+ import librosa
+ import numpy as np
+
+ from pathlib import Path
+ from scipy.io.wavfile import write
+
+ from encoder import inference as encoder
+ from vocoder.hifigan import inference as gan_vocoder
+ from synthesizer.inference import Synthesizer
+
+ class Mandarin:
+     def __init__(self):
+         self.encoder_path = "encoder/saved_models/pretrained.pt"
+         self.vocoder_path = "vocoder/saved_models/pretrained/g_hifigan.pt"
+         self.config_fpath = "vocoder/hifigan/config_16k_.json"
+         self.accent = "synthesizer/saved_models/普通话.pt"
+
+         synthesizers_cache = {}
+         if synthesizers_cache.get(self.accent) is None:
+             self.current_synt = Synthesizer(Path(self.accent))
+             synthesizers_cache[self.accent] = self.current_synt
+         else:
+             self.current_synt = synthesizers_cache[self.accent]
+
+         encoder.load_model(Path(self.encoder_path))
+         gan_vocoder.load_model(Path(self.vocoder_path), self.config_fpath)
+
+     def setVoice(self, timbre):
+         self.timbre = timbre
+         wav, sample_rate = librosa.load(self.timbre)
+
+         encoder_wav = encoder.preprocess_wav(wav, sample_rate)
+         self.embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
+
+     def say(self, text):
+         texts = filter(None, text.split("\n"))
+         punctuation = "!,。、?!,.?::"  # punctuation used to split/clean the text
+         processed_texts = []
+         for text in texts:
+             for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'):
+                 if processed_text:
+                     processed_texts.append(processed_text.strip())
+         texts = processed_texts
+         embeds = [self.embed] * len(texts)
+
+         specs = self.current_synt.synthesize_spectrograms(texts, embeds)
+         spec = np.concatenate(specs, axis=1)
+         wav, sample_rate = gan_vocoder.infer_waveform(spec)
+
+         return wav, sample_rate
+
+ def greet(audio, text, voice=None):
+
+     if voice is None:
+         voice = Mandarin()
+         voice.setVoice(audio.name)
+         voice.say("加载成功")
+     wav, sample_rate = voice.say(text)
+
+     output_file = "".join(random.sample(string.ascii_lowercase + string.digits, 11)) + ".wav"
+
+     write(output_file, sample_rate, wav.astype(np.float32))
+
+     return output_file, voice
+
+ def main():
+     gr.Interface(
+         fn=greet,
+         inputs=[gr.inputs.Audio(type="file"), "text", "state"],
+         outputs=[gr.outputs.Audio(type="file"), "state"]
+     ).launch()
+
+ if __name__=="__main__":
+     main()
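A minimal sketch of driving the `Mandarin` class above without the Gradio interface, assuming the pretrained models referenced in `__init__` are in place; `reference.wav`, `cloned.wav` and the example sentence are placeholders:

```
import numpy as np
from scipy.io.wavfile import write

from app import Mandarin

voice = Mandarin()                # loads the encoder, synthesizer and HiFi-GAN vocoder
voice.setVoice("reference.wav")   # compute the speaker embedding from a reference clip
wav, sample_rate = voice.say("欢迎使用语音克隆工具")
write("cloned.wav", sample_rate, wav.astype(np.float32))
```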
demo_toolbox.py ADDED
@@ -0,0 +1,49 @@
+ from pathlib import Path
+ from toolbox import Toolbox
+ from utils.argutils import print_args
+ from utils.modelutils import check_model_paths
+ import argparse
+ import os
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(
+         description="Runs the toolbox",
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter
+     )
+
+     parser.add_argument("-d", "--datasets_root", type=Path, help= \
+         "Path to the directory containing your datasets. See toolbox/__init__.py for a list of "
+         "supported datasets.", default=None)
+     parser.add_argument("-vc", "--vc_mode", action="store_true",
+                         help="Voice Conversion Mode (PPG based)")
+     parser.add_argument("-e", "--enc_models_dir", type=Path, default="encoder/saved_models",
+                         help="Directory containing saved encoder models")
+     parser.add_argument("-s", "--syn_models_dir", type=Path, default="synthesizer/saved_models",
+                         help="Directory containing saved synthesizer models")
+     parser.add_argument("-v", "--voc_models_dir", type=Path, default="vocoder/saved_models",
+                         help="Directory containing saved vocoder models")
+     parser.add_argument("-ex", "--extractor_models_dir", type=Path, default="ppg_extractor/saved_models",
+                         help="Directory containing saved extractor models")
+     parser.add_argument("-cv", "--convertor_models_dir", type=Path, default="ppg2mel/saved_models",
+                         help="Directory containing saved converter models")
+     parser.add_argument("--cpu", action="store_true", help=\
+         "If True, processing is done on CPU, even when a GPU is available.")
+     parser.add_argument("--seed", type=int, default=None, help=\
+         "Optional random number seed value to make toolbox deterministic.")
+     parser.add_argument("--no_mp3_support", action="store_true", help=\
+         "If True, no mp3 files are allowed.")
+     args = parser.parse_args()
+     print_args(args, parser)
+
+     if args.cpu:
+         # Hide GPUs from Pytorch to force CPU processing
+         os.environ["CUDA_VISIBLE_DEVICES"] = ""
+     del args.cpu
+
+     ## Remind the user to download pretrained models if needed
+     check_model_paths(encoder_path=args.enc_models_dir, synthesizer_path=args.syn_models_dir,
+                       vocoder_path=args.voc_models_dir)
+
+     # Launch the toolbox
+     Toolbox(**vars(args))
encoder/__init__.py ADDED
File without changes
encoder/audio.py ADDED
@@ -0,0 +1,117 @@
+ from scipy.ndimage.morphology import binary_dilation
+ from encoder.params_data import *
+ from pathlib import Path
+ from typing import Optional, Union
+ from warnings import warn
+ import numpy as np
+ import librosa
+ import struct
+
+ try:
+     import webrtcvad
+ except Exception:
+     warn("Unable to import 'webrtcvad'. This package enables noise removal and is recommended.")
+     webrtcvad = None
+
+ int16_max = (2 ** 15) - 1
+
+
+ def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
+                    source_sr: Optional[int] = None,
+                    normalize: Optional[bool] = True,
+                    trim_silence: Optional[bool] = True):
+     """
+     Applies the preprocessing operations used in training the Speaker Encoder to a waveform
+     either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
+
+     :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
+     just .wav), or the waveform as a numpy array of floats.
+     :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
+     preprocessing. After preprocessing, the waveform's sampling rate will match the data
+     hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
+     this argument will be ignored.
+     """
+     # Load the wav from disk if needed
+     if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
+         wav, source_sr = librosa.load(str(fpath_or_wav), sr=None)
+     else:
+         wav = fpath_or_wav
+
+     # Resample the wav if needed
+     if source_sr is not None and source_sr != sampling_rate:
+         wav = librosa.resample(wav, source_sr, sampling_rate)
+
+     # Apply the preprocessing: normalize volume and shorten long silences
+     if normalize:
+         wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
+     if webrtcvad and trim_silence:
+         wav = trim_long_silences(wav)
+
+     return wav
+
+
+ def wav_to_mel_spectrogram(wav):
+     """
+     Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
+     Note: this is not a log-mel spectrogram.
+     """
+     frames = librosa.feature.melspectrogram(
+         y=wav,
+         sr=sampling_rate,
+         n_fft=int(sampling_rate * mel_window_length / 1000),
+         hop_length=int(sampling_rate * mel_window_step / 1000),
+         n_mels=mel_n_channels
+     )
+     return frames.astype(np.float32).T
+
+
+ def trim_long_silences(wav):
+     """
+     Ensures that segments without voice in the waveform remain no longer than a
+     threshold determined by the VAD parameters in params.py.
+
+     :param wav: the raw waveform as a numpy array of floats
+     :return: the same waveform with silences trimmed away (length <= original wav length)
+     """
+     # Compute the voice detection window size
+     samples_per_window = (vad_window_length * sampling_rate) // 1000
+
+     # Trim the end of the audio to have a multiple of the window size
+     wav = wav[:len(wav) - (len(wav) % samples_per_window)]
+
+     # Convert the float waveform to 16-bit mono PCM
+     pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
+
+     # Perform voice activation detection
+     voice_flags = []
+     vad = webrtcvad.Vad(mode=3)
+     for window_start in range(0, len(wav), samples_per_window):
+         window_end = window_start + samples_per_window
+         voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
+                                          sample_rate=sampling_rate))
+     voice_flags = np.array(voice_flags)
+
+     # Smooth the voice detection with a moving average
+     def moving_average(array, width):
+         array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
+         ret = np.cumsum(array_padded, dtype=float)
+         ret[width:] = ret[width:] - ret[:-width]
+         return ret[width - 1:] / width
+
+     audio_mask = moving_average(voice_flags, vad_moving_average_width)
+     audio_mask = np.round(audio_mask).astype(bool)
+
+     # Dilate the voiced regions
+     audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
+     audio_mask = np.repeat(audio_mask, samples_per_window)
+
+     return wav[audio_mask == True]
+
+
+ def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
+     if increase_only and decrease_only:
+         raise ValueError("Both increase only and decrease only are set")
+     dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
+     if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
+         return wav
+     return wav * (10 ** (dBFS_change / 20))
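A short sketch of how the helpers above are typically chained, assuming an input file `sample.wav` exists:

```
from encoder.audio import preprocess_wav, wav_to_mel_spectrogram

wav = preprocess_wav("sample.wav")    # load, resample to 16 kHz, normalize volume, trim long silences
frames = wav_to_mel_spectrogram(wav)  # shape (n_frames, mel_n_channels), one frame per 10 ms hop
print(wav.shape, frames.shape)
```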
encoder/config.py ADDED
@@ -0,0 +1,45 @@
+ librispeech_datasets = {
+     "train": {
+         "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
+         "other": ["LibriSpeech/train-other-500"]
+     },
+     "test": {
+         "clean": ["LibriSpeech/test-clean"],
+         "other": ["LibriSpeech/test-other"]
+     },
+     "dev": {
+         "clean": ["LibriSpeech/dev-clean"],
+         "other": ["LibriSpeech/dev-other"]
+     },
+ }
+ libritts_datasets = {
+     "train": {
+         "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
+         "other": ["LibriTTS/train-other-500"]
+     },
+     "test": {
+         "clean": ["LibriTTS/test-clean"],
+         "other": ["LibriTTS/test-other"]
+     },
+     "dev": {
+         "clean": ["LibriTTS/dev-clean"],
+         "other": ["LibriTTS/dev-other"]
+     },
+ }
+ voxceleb_datasets = {
+     "voxceleb1": {
+         "train": ["VoxCeleb1/wav"],
+         "test": ["VoxCeleb1/test_wav"]
+     },
+     "voxceleb2": {
+         "train": ["VoxCeleb2/dev/aac"],
+         "test": ["VoxCeleb2/test_wav"]
+     }
+ }
+
+ other_datasets = [
+     "LJSpeech-1.1",
+     "VCTK-Corpus/wav48",
+ ]
+
+ anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
encoder/data_objects/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
encoder/data_objects/random_cycler.py ADDED
@@ -0,0 +1,37 @@
+ import random
+
+ class RandomCycler:
+     """
+     Creates an internal copy of a sequence and allows access to its items in a constrained random
+     order. For a source sequence of n items and one or several consecutive queries of a total
+     of m items, the following guarantees hold (one implies the other):
+         - Each item will be returned between m // n and ((m - 1) // n) + 1 times.
+         - Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
+     """
+
+     def __init__(self, source):
+         if len(source) == 0:
+             raise Exception("Can't create RandomCycler from an empty collection")
+         self.all_items = list(source)
+         self.next_items = []
+
+     def sample(self, count: int):
+         shuffle = lambda l: random.sample(l, len(l))
+
+         out = []
+         while count > 0:
+             if count >= len(self.all_items):
+                 out.extend(shuffle(list(self.all_items)))
+                 count -= len(self.all_items)
+                 continue
+             n = min(count, len(self.next_items))
+             out.extend(self.next_items[:n])
+             count -= n
+             self.next_items = self.next_items[n:]
+             if len(self.next_items) == 0:
+                 self.next_items = shuffle(list(self.all_items))
+         return out
+
+     def __next__(self):
+         return self.sample(1)[0]
+
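A quick illustration of the sampling guarantee in the docstring: drawing m = 10 items from an n = 5 item source returns every item exactly 10 // 5 = 2 times, in a shuffled order:

```
from collections import Counter

from encoder.data_objects.random_cycler import RandomCycler

cycler = RandomCycler(["a", "b", "c", "d", "e"])
draws = cycler.sample(10)   # two full shuffled passes over the source
print(Counter(draws))       # every item appears exactly twice
```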
encoder/data_objects/speaker.py ADDED
@@ -0,0 +1,40 @@
+ from encoder.data_objects.random_cycler import RandomCycler
+ from encoder.data_objects.utterance import Utterance
+ from pathlib import Path
+
+ # Contains the set of utterances of a single speaker
+ class Speaker:
+     def __init__(self, root: Path):
+         self.root = root
+         self.name = root.name
+         self.utterances = None
+         self.utterance_cycler = None
+
+     def _load_utterances(self):
+         with self.root.joinpath("_sources.txt").open("r") as sources_file:
+             sources = [l.split(",") for l in sources_file]
+         sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
+         self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
+         self.utterance_cycler = RandomCycler(self.utterances)
+
+     def random_partial(self, count, n_frames):
+         """
+         Samples a batch of <count> unique partial utterances from the disk in a way that all
+         utterances come up at least once every two cycles and in a random order every time.
+
+         :param count: The number of partial utterances to sample from the set of utterances from
+         that speaker. Utterances are guaranteed not to be repeated if <count> is not larger than
+         the number of utterances available.
+         :param n_frames: The number of frames in the partial utterance.
+         :return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
+         frames are the frames of the partial utterances and range is the range of the partial
+         utterance with regard to the complete utterance.
+         """
+         if self.utterances is None:
+             self._load_utterances()
+
+         utterances = self.utterance_cycler.sample(count)
+
+         a = [(u,) + u.random_partial(n_frames) for u in utterances]
+
+         return a
encoder/data_objects/speaker_batch.py ADDED
@@ -0,0 +1,12 @@
+ import numpy as np
+ from typing import List
+ from encoder.data_objects.speaker import Speaker
+
+ class SpeakerBatch:
+     def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
+         self.speakers = speakers
+         self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}
+
+         # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
+         # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
+         self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
encoder/data_objects/speaker_verification_dataset.py ADDED
@@ -0,0 +1,56 @@
+ from encoder.data_objects.random_cycler import RandomCycler
+ from encoder.data_objects.speaker_batch import SpeakerBatch
+ from encoder.data_objects.speaker import Speaker
+ from encoder.params_data import partials_n_frames
+ from torch.utils.data import Dataset, DataLoader
+ from pathlib import Path
+
+ # TODO: improve with a pool of speakers for data efficiency
+
+ class SpeakerVerificationDataset(Dataset):
+     def __init__(self, datasets_root: Path):
+         self.root = datasets_root
+         speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
+         if len(speaker_dirs) == 0:
+             raise Exception("No speakers found. Make sure you are pointing to the directory "
+                             "containing all preprocessed speaker directories.")
+         self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
+         self.speaker_cycler = RandomCycler(self.speakers)
+
+     def __len__(self):
+         return int(1e10)
+
+     def __getitem__(self, index):
+         return next(self.speaker_cycler)
+
+     def get_logs(self):
+         log_string = ""
+         for log_fpath in self.root.glob("*.txt"):
+             with log_fpath.open("r") as log_file:
+                 log_string += "".join(log_file.readlines())
+         return log_string
+
+
+ class SpeakerVerificationDataLoader(DataLoader):
+     def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
+                  batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
+                  worker_init_fn=None):
+         self.utterances_per_speaker = utterances_per_speaker
+
+         super().__init__(
+             dataset=dataset,
+             batch_size=speakers_per_batch,
+             shuffle=False,
+             sampler=sampler,
+             batch_sampler=batch_sampler,
+             num_workers=num_workers,
+             collate_fn=self.collate,
+             pin_memory=pin_memory,
+             drop_last=False,
+             timeout=timeout,
+             worker_init_fn=worker_init_fn
+         )
+
+     def collate(self, speakers):
+         return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
+
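A minimal sketch of how the dataset and loader above are wired together for encoder training; `<datasets_root>/SV2TTS/encoder` stands for the directory produced by `encoder_preprocess.py`, and the batch sizes come from `encoder/params_model.py`:

```
from pathlib import Path

from encoder.data_objects import SpeakerVerificationDataset, SpeakerVerificationDataLoader
from encoder.params_model import speakers_per_batch, utterances_per_speaker

dataset = SpeakerVerificationDataset(Path("<datasets_root>/SV2TTS/encoder"))
loader = SpeakerVerificationDataLoader(dataset, speakers_per_batch, utterances_per_speaker)

batch = next(iter(loader))
# (speakers_per_batch * utterances_per_speaker, partials_n_frames, mel_n_channels)
print(batch.data.shape)   # (640, 160, 40) with the default parameters
```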
encoder/data_objects/utterance.py ADDED
@@ -0,0 +1,26 @@
+ import numpy as np
+
+
+ class Utterance:
+     def __init__(self, frames_fpath, wave_fpath):
+         self.frames_fpath = frames_fpath
+         self.wave_fpath = wave_fpath
+
+     def get_frames(self):
+         return np.load(self.frames_fpath)
+
+     def random_partial(self, n_frames):
+         """
+         Crops the frames into a partial utterance of n_frames
+
+         :param n_frames: The number of frames of the partial utterance
+         :return: the partial utterance frames and a tuple indicating the start and end of the
+         partial utterance in the complete utterance.
+         """
+         frames = self.get_frames()
+         if frames.shape[0] == n_frames:
+             start = 0
+         else:
+             start = np.random.randint(0, frames.shape[0] - n_frames)
+         end = start + n_frames
+         return frames[start:end], (start, end)
encoder/inference.py ADDED
@@ -0,0 +1,195 @@
+ from encoder.params_data import *
+ from encoder.model import SpeakerEncoder
+ from encoder.audio import preprocess_wav  # We want to expose this function from here
+ from matplotlib import cm
+ from encoder import audio
+ from pathlib import Path
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+
+ _model = None  # type: SpeakerEncoder
+ _device = None  # type: torch.device
+
+
+ def load_model(weights_fpath: Path, device=None):
+     """
+     Loads the model in memory. If this function is not explicitly called, it will be run on the
+     first call to embed_frames() with the default weights file.
+
+     :param weights_fpath: the path to saved model weights.
+     :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
+     model will be loaded and will run on this device. Outputs will however always be on the cpu.
+     If None, will default to your GPU if it's available, otherwise your CPU.
+     """
+     # TODO: I think the slow loading of the encoder might have something to do with the device it
+     # was saved on. Worth investigating.
+     global _model, _device
+     if device is None:
+         _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     elif isinstance(device, str):
+         _device = torch.device(device)
+     _model = SpeakerEncoder(_device, torch.device("cpu"))
+     checkpoint = torch.load(weights_fpath, _device)
+     _model.load_state_dict(checkpoint["model_state"])
+     _model.eval()
+     print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
+     return _model
+
+ def set_model(model, device=None):
+     global _model, _device
+     _model = model
+     if device is None:
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     _device = device
+     _model.to(device)
+
+ def is_loaded():
+     return _model is not None
+
+
+ def embed_frames_batch(frames_batch):
+     """
+     Computes embeddings for a batch of mel spectrograms.
+
+     :param frames_batch: a batch of mel spectrograms as a numpy array of float32 of shape
+     (batch_size, n_frames, n_channels)
+     :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
+     """
+     if _model is None:
+         raise Exception("Model was not loaded. Call load_model() before inference.")
+
+     frames = torch.from_numpy(frames_batch).to(_device)
+     embed = _model.forward(frames).detach().cpu().numpy()
+     return embed
+
+
+ def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
+                            min_pad_coverage=0.75, overlap=0.5, rate=None):
+     """
+     Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
+     partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
+     spectrogram slices are returned, so as to make each partial utterance waveform correspond to
+     its spectrogram. This function assumes that the mel spectrogram parameters used are those
+     defined in params_data.py.
+
+     The returned ranges may be indexing further than the length of the waveform. It is
+     recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
+
+     :param n_samples: the number of samples in the waveform
+     :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
+     utterance
+     :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
+     enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
+     then the last partial utterance will be considered, as if we padded the audio. Otherwise,
+     it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
+     utterance, this parameter is ignored so that the function always returns at least 1 slice.
+     :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
+     utterances are entirely disjoint.
+     :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
+     respectively the waveform and the mel spectrogram with these slices to obtain the partial
+     utterances.
+     """
+     assert 0 <= overlap < 1
+     assert 0 < min_pad_coverage <= 1
+
+     if rate is not None:
+         samples_per_frame = int((sampling_rate * mel_window_step / 1000))
+         n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
+         frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
+     else:
+         samples_per_frame = int((sampling_rate * mel_window_step / 1000))
+         n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
+         frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
+
+     assert 0 < frame_step, "The rate is too high"
+     assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \
+         (sampling_rate / (samples_per_frame * partials_n_frames))
+
+     # Compute the slices
+     wav_slices, mel_slices = [], []
+     steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
+     for i in range(0, steps, frame_step):
+         mel_range = np.array([i, i + partial_utterance_n_frames])
+         wav_range = mel_range * samples_per_frame
+         mel_slices.append(slice(*mel_range))
+         wav_slices.append(slice(*wav_range))
+
+     # Evaluate whether extra padding is warranted or not
+     last_wav_range = wav_slices[-1]
+     coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
+     if coverage < min_pad_coverage and len(mel_slices) > 1:
+         mel_slices = mel_slices[:-1]
+         wav_slices = wav_slices[:-1]
+
+     return wav_slices, mel_slices
+
+
+ def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
+     """
+     Computes an embedding for a single utterance.
+
+     # TODO: handle multiple wavs to benefit from batching on GPU
+     :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
+     :param using_partials: if True, then the utterance is split in partial utterances of
+     <partial_utterance_n_frames> frames and the utterance embedding is computed from their
+     normalized average. If False, the utterance is instead computed from feeding the entire
+     spectrogram to the network.
+     :param return_partials: if True, the partial embeddings will also be returned along with the
+     wav slices that correspond to the partial embeddings.
+     :param kwargs: additional arguments to compute_partial_splits()
+     :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
+     <return_partials> is True, the partial utterances as a numpy array of float32 of shape
+     (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
+     returned. If <using_partials> is simultaneously set to False, both these values will be None
+     instead.
+     """
+     # Process the entire utterance if not using partials
+     if not using_partials:
+         frames = audio.wav_to_mel_spectrogram(wav)
+         embed = embed_frames_batch(frames[None, ...])[0]
+         if return_partials:
+             return embed, None, None
+         return embed
+
+     # Compute where to split the utterance into partials and pad if necessary
+     wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
+     max_wave_length = wave_slices[-1].stop
+     if max_wave_length >= len(wav):
+         wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
+
+     # Split the utterance into partials
+     frames = audio.wav_to_mel_spectrogram(wav)
+     frames_batch = np.array([frames[s] for s in mel_slices])
+     partial_embeds = embed_frames_batch(frames_batch)
+
+     # Compute the utterance embedding from the partial embeddings
+     raw_embed = np.mean(partial_embeds, axis=0)
+     embed = raw_embed / np.linalg.norm(raw_embed, 2)
+
+     if return_partials:
+         return embed, partial_embeds, wave_slices
+     return embed
+
+
+ def embed_speaker(wavs, **kwargs):
+     raise NotImplementedError()
+
+
+ def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
+     if ax is None:
+         ax = plt.gca()
+
+     if shape is None:
+         height = int(np.sqrt(len(embed)))
+         shape = (height, -1)
+     embed = embed.reshape(shape)
+
+     cmap = cm.get_cmap()
+     mappable = ax.imshow(embed, cmap=cmap)
+     cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
+     sm = cm.ScalarMappable(cmap=cmap)
+     sm.set_clim(*color_range)
+
+     ax.set_xticks([]), ax.set_yticks([])
+     ax.set_title(title)
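A short usage sketch combining `load_model`, `preprocess_wav` and `embed_utterance`; the checkpoint path matches the file added in this commit, while `speaker.wav` is a placeholder reference clip:

```
from pathlib import Path

import numpy as np

from encoder import inference as encoder

encoder.load_model(Path("encoder/saved_models/pretrained.pt"))
wav = encoder.preprocess_wav("speaker.wav")
embed, partial_embeds, wav_slices = encoder.embed_utterance(wav, return_partials=True)

print(embed.shape)            # (256,) with model_embedding_size = 256
print(np.linalg.norm(embed))  # ~1.0, the utterance embedding is L2-normalized
```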
encoder/model.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from encoder.params_model import *
2
+ from encoder.params_data import *
3
+ from scipy.interpolate import interp1d
4
+ from sklearn.metrics import roc_curve
5
+ from torch.nn.utils import clip_grad_norm_
6
+ from scipy.optimize import brentq
7
+ from torch import nn
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
+ class SpeakerEncoder(nn.Module):
13
+ def __init__(self, device, loss_device):
14
+ super().__init__()
15
+ self.loss_device = loss_device
16
+
17
+ # Network defition
18
+ self.lstm = nn.LSTM(input_size=mel_n_channels,
19
+ hidden_size=model_hidden_size,
20
+ num_layers=model_num_layers,
21
+ batch_first=True).to(device)
22
+ self.linear = nn.Linear(in_features=model_hidden_size,
23
+ out_features=model_embedding_size).to(device)
24
+ self.relu = torch.nn.ReLU().to(device)
25
+
26
+ # Cosine similarity scaling (with fixed initial parameter values)
27
+ self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
28
+ self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)
29
+
30
+ # Loss
31
+ self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
32
+
33
+ def do_gradient_ops(self):
34
+ # Gradient scale
35
+ self.similarity_weight.grad *= 0.01
36
+ self.similarity_bias.grad *= 0.01
37
+
38
+ # Gradient clipping
39
+ clip_grad_norm_(self.parameters(), 3, norm_type=2)
40
+
41
+ def forward(self, utterances, hidden_init=None):
42
+ """
43
+ Computes the embeddings of a batch of utterance spectrograms.
44
+
45
+ :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
46
+ (batch_size, n_frames, n_channels)
47
+ :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
48
+ batch_size, hidden_size). Will default to a tensor of zeros if None.
49
+ :return: the embeddings as a tensor of shape (batch_size, embedding_size)
50
+ """
51
+ # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
52
+ # and the final cell state.
53
+ out, (hidden, cell) = self.lstm(utterances, hidden_init)
54
+
55
+ # We take only the hidden state of the last layer
56
+ embeds_raw = self.relu(self.linear(hidden[-1]))
57
+
58
+ # L2-normalize it
59
+ embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5)
60
+
61
+ return embeds
62
+
63
+ def similarity_matrix(self, embeds):
64
+ """
65
+ Computes the similarity matrix according the section 2.1 of GE2E.
66
+
67
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
68
+ utterances_per_speaker, embedding_size)
69
+ :return: the similarity matrix as a tensor of shape (speakers_per_batch,
70
+ utterances_per_speaker, speakers_per_batch)
71
+ """
72
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
73
+
74
+ # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
75
+ centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
76
+ centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)
77
+
78
+ # Exclusive centroids (1 per utterance)
79
+ centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
80
+ centroids_excl /= (utterances_per_speaker - 1)
81
+ centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)
82
+
83
+ # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
84
+ # product of these vectors (which is just an element-wise multiplication reduced by a sum).
85
+ # We vectorize the computation for efficiency.
86
+ sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
87
+ speakers_per_batch).to(self.loss_device)
88
+ mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int)
89
+ for j in range(speakers_per_batch):
90
+ mask = np.where(mask_matrix[j])[0]
91
+ sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
92
+ sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
93
+
94
+ ## Even more vectorized version (slower maybe because of transpose)
95
+ # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
96
+ # ).to(self.loss_device)
97
+ # eye = np.eye(speakers_per_batch, dtype=np.int)
98
+ # mask = np.where(1 - eye)
99
+ # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
100
+ # mask = np.where(eye)
101
+ # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
102
+ # sim_matrix2 = sim_matrix2.transpose(1, 2)
103
+
104
+ sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
105
+ return sim_matrix
106
+
107
+ def loss(self, embeds):
108
+ """
109
+ Computes the softmax loss according to section 2.1 of GE2E.
110
+
111
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
112
+ utterances_per_speaker, embedding_size)
113
+ :return: the loss and the EER for this batch of embeddings.
114
+ """
115
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
116
+
117
+ # Loss
118
+ sim_matrix = self.similarity_matrix(embeds)
119
+ sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
120
+ speakers_per_batch))
121
+ ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
122
+ target = torch.from_numpy(ground_truth).long().to(self.loss_device)
123
+ loss = self.loss_fn(sim_matrix, target)
124
+
125
+ # EER (not backpropagated)
126
+ with torch.no_grad():
127
+ inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=int)[0]
128
+ labels = np.array([inv_argmax(i) for i in ground_truth])
129
+ preds = sim_matrix.detach().cpu().numpy()
130
+
131
+ # Snippet from https://yangcha.github.io/EER-ROC/
132
+ fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
133
+ eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
134
+
135
+ return loss, eer
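A minimal sketch of how the forward pass and the GE2E loss above fit together, assuming the SpeakerEncoder class from this file, CPU devices, and a made-up batch of 4 speakers x 5 utterances (160 frames of 40 mel channels each, matching params_data):

    import torch
    from encoder.model import SpeakerEncoder

    device = loss_device = torch.device("cpu")
    model = SpeakerEncoder(device, loss_device)

    utterances = torch.rand(4 * 5, 160, 40)      # (batch, n_frames, n_channels)
    embeds = model(utterances)                   # (20, 256), L2-normalized embeddings
    embeds = embeds.view(4, 5, -1)               # (speakers_per_batch, utterances_per_speaker, emb_size)
    loss, eer = model.loss(embeds)               # GE2E softmax loss + equal error rate for the batch
    loss.backward()
    model.do_gradient_ops()                      # scale the similarity grads, clip the rest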
encoder/params_data.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## Mel-filterbank
3
+ mel_window_length = 25 # In milliseconds
4
+ mel_window_step = 10 # In milliseconds
5
+ mel_n_channels = 40
6
+
7
+
8
+ ## Audio
9
+ sampling_rate = 16000
10
+ # Number of spectrogram frames in a partial utterance
11
+ partials_n_frames = 160 # 1600 ms
12
+ # Number of spectrogram frames at inference
13
+ inference_n_frames = 80 # 800 ms
14
+
15
+
16
+ ## Voice Activation Detection
17
+ # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
18
+ # This sets the granularity of the VAD. Should not need to be changed.
19
+ vad_window_length = 30 # In milliseconds
20
+ # Number of frames to average together when performing the moving average smoothing.
21
+ # The larger this value, the larger the VAD variations must be to not get smoothed out.
22
+ vad_moving_average_width = 8
23
+ # Maximum number of consecutive silent frames a segment can have.
24
+ vad_max_silence_length = 6
25
+
26
+
27
+ ## Audio volume normalization
28
+ audio_norm_target_dBFS = -30
29
+
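Worked out, these settings mean a mel frame is produced every mel_window_step = 10 ms, so a training partial spans partials_n_frames * 10 ms = 1.6 s of audio and an inference window spans inference_n_frames * 10 ms = 0.8 s, each as a sequence of 40-channel mel vectors.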
encoder/params_model.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## Model parameters
3
+ model_hidden_size = 256
4
+ model_embedding_size = 256
5
+ model_num_layers = 3
6
+
7
+
8
+ ## Training parameters
9
+ learning_rate_init = 1e-4
10
+ speakers_per_batch = 64
11
+ utterances_per_speaker = 10
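Worked out, one training batch therefore holds speakers_per_batch * utterances_per_speaker = 64 * 10 = 640 partial utterances, and the GE2E similarity matrix built in encoder/model.py has shape (speakers_per_batch, utterances_per_speaker, speakers_per_batch) = (64, 10, 64).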
encoder/preprocess.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from multiprocess.pool import ThreadPool
2
+ from encoder.params_data import *
3
+ from encoder.config import librispeech_datasets, anglophone_nationalites
4
+ from datetime import datetime
5
+ from encoder import audio
6
+ from pathlib import Path
7
+ from tqdm import tqdm
8
+ import numpy as np
9
+
10
+
11
+ class DatasetLog:
12
+ """
13
+ Registers metadata about the dataset in a text file.
14
+ """
15
+ def __init__(self, root, name):
16
+ self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
17
+ self.sample_data = dict()
18
+
19
+ start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
20
+ self.write_line("Creating dataset %s on %s" % (name, start_time))
21
+ self.write_line("-----")
22
+ self._log_params()
23
+
24
+ def _log_params(self):
25
+ from encoder import params_data
26
+ self.write_line("Parameter values:")
27
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
28
+ value = getattr(params_data, param_name)
29
+ self.write_line("\t%s: %s" % (param_name, value))
30
+ self.write_line("-----")
31
+
32
+ def write_line(self, line):
33
+ self.text_file.write("%s\n" % line)
34
+
35
+ def add_sample(self, **kwargs):
36
+ for param_name, value in kwargs.items():
37
+ if param_name not in self.sample_data:
38
+ self.sample_data[param_name] = []
39
+ self.sample_data[param_name].append(value)
40
+
41
+ def finalize(self):
42
+ self.write_line("Statistics:")
43
+ for param_name, values in self.sample_data.items():
44
+ self.write_line("\t%s:" % param_name)
45
+ self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
46
+ self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
47
+ self.write_line("-----")
48
+ end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
49
+ self.write_line("Finished on %s" % end_time)
50
+ self.text_file.close()
51
+
52
+
53
+ def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
54
+ dataset_root = datasets_root.joinpath(dataset_name)
55
+ if not dataset_root.exists():
56
+ print("Couldn\'t find %s, skipping this dataset." % dataset_root)
57
+ return None, None
58
+ return dataset_root, DatasetLog(out_dir, dataset_name)
59
+
60
+
61
+ def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
62
+ skip_existing, logger):
63
+ print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
64
+
65
+ # Function to preprocess utterances for one speaker
66
+ def preprocess_speaker(speaker_dir: Path):
67
+ # Give a name to the speaker that includes its dataset
68
+ speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
69
+
70
+ # Create an output directory with that name, as well as a txt file containing a
71
+ # reference to each source file.
72
+ speaker_out_dir = out_dir.joinpath(speaker_name)
73
+ speaker_out_dir.mkdir(exist_ok=True)
74
+ sources_fpath = speaker_out_dir.joinpath("_sources.txt")
75
+
76
+ # The preprocessing may have been interrupted earlier; check whether a sources
77
+ # file already exists.
78
+ if sources_fpath.exists():
79
+ try:
80
+ with sources_fpath.open("r") as sources_file:
81
+ existing_fnames = {line.split(",")[0] for line in sources_file}
82
+ except:
83
+ existing_fnames = {}
84
+ else:
85
+ existing_fnames = {}
86
+
87
+ # Gather all audio files for that speaker recursively
88
+ sources_file = sources_fpath.open("a" if skip_existing else "w")
89
+ for in_fpath in speaker_dir.glob("**/*.%s" % extension):
90
+ # Check if the target output file already exists
91
+ out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
92
+ out_fname = out_fname.replace(".%s" % extension, ".npy")
93
+ if skip_existing and out_fname in existing_fnames:
94
+ continue
95
+
96
+ # Load and preprocess the waveform
97
+ wav = audio.preprocess_wav(in_fpath)
98
+ if len(wav) == 0:
99
+ continue
100
+
101
+ # Create the mel spectrogram, discard those that are too short
102
+ frames = audio.wav_to_mel_spectrogram(wav)
103
+ if len(frames) < partials_n_frames:
104
+ continue
105
+
106
+ out_fpath = speaker_out_dir.joinpath(out_fname)
107
+ np.save(out_fpath, frames)
108
+ logger.add_sample(duration=len(wav) / sampling_rate)
109
+ sources_file.write("%s,%s\n" % (out_fname, in_fpath))
110
+
111
+ sources_file.close()
112
+
113
+ # Process the utterances for each speaker
114
+ with ThreadPool(8) as pool:
115
+ list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
116
+ unit="speakers"))
117
+ logger.finalize()
118
+ print("Done preprocessing %s.\n" % dataset_name)
119
+
120
+ def preprocess_aidatatang_200zh(datasets_root: Path, out_dir: Path, skip_existing=False):
121
+ dataset_name = "aidatatang_200zh"
122
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
123
+ if not dataset_root:
124
+ return
125
+ # Preprocess all speakers
126
+ speaker_dirs = list(dataset_root.joinpath("corpus", "train").glob("*"))
127
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav",
128
+ skip_existing, logger)
129
+
130
+ def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
131
+ for dataset_name in librispeech_datasets["train"]["other"]:
132
+ # Initialize the preprocessing
133
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
134
+ if not dataset_root:
135
+ return
136
+
137
+ # Preprocess all speakers
138
+ speaker_dirs = list(dataset_root.glob("*"))
139
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "flac",
140
+ skip_existing, logger)
141
+
142
+
143
+ def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
144
+ # Initialize the preprocessing
145
+ dataset_name = "VoxCeleb1"
146
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
147
+ if not dataset_root:
148
+ return
149
+
150
+ # Get the contents of the meta file
151
+ with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
152
+ metadata = [line.split("\t") for line in metafile][1:]
153
+
154
+ # Select the ID and the nationality, filter out non-anglophone speakers
155
+ nationalities = {line[0]: line[3] for line in metadata}
156
+ keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
157
+ nationality.lower() in anglophone_nationalites]
158
+ print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
159
+ (len(keep_speaker_ids), len(nationalities)))
160
+
161
+ # Get the speaker directories for anglophone speakers only
162
+ speaker_dirs = dataset_root.joinpath("wav").glob("*")
163
+ speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
164
+ speaker_dir.name in keep_speaker_ids]
165
+ print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
166
+ (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))
167
+
168
+ # Preprocess all speakers
169
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav",
170
+ skip_existing, logger)
171
+
172
+
173
+ def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
174
+ # Initialize the preprocessing
175
+ dataset_name = "VoxCeleb2"
176
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
177
+ if not dataset_root:
178
+ return
179
+
180
+ # Get the speaker directories
181
+ # Preprocess all speakers
182
+ speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
183
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "m4a",
184
+ skip_existing, logger)
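A minimal sketch (not part of the commit) of driving these preprocessing functions directly from Python; the dataset root is a placeholder and must contain aidatatang_200zh/corpus/train:

    from pathlib import Path
    from encoder.preprocess import preprocess_aidatatang_200zh

    datasets_root = Path("/data/datasets")                  # placeholder path
    out_dir = datasets_root.joinpath("SV2TTS", "encoder")
    out_dir.mkdir(parents=True, exist_ok=True)
    preprocess_aidatatang_200zh(datasets_root, out_dir, skip_existing=True)
    # out_dir now holds one folder per speaker, each with .npy mel frames and a _sources.txt index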
encoder/saved_models/pretrained.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57715adc6f36047166ab06e37b904240aee2f4d10fc88f78ed91510cf4b38666
3
+ size 17095158
encoder/train.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from encoder.visualizations import Visualizations
2
+ from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
3
+ from encoder.params_model import *
4
+ from encoder.model import SpeakerEncoder
5
+ from utils.profiler import Profiler
6
+ from pathlib import Path
7
+ import torch
8
+
9
+ def sync(device: torch.device):
10
+ # For correct profiling (cuda operations are async)
11
+ if device.type == "cuda":
12
+ torch.cuda.synchronize(device)
13
+
14
+
15
+ def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
16
+ backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
17
+ no_visdom: bool):
18
+ # Create a dataset and a dataloader
19
+ dataset = SpeakerVerificationDataset(clean_data_root)
20
+ loader = SpeakerVerificationDataLoader(
21
+ dataset,
22
+ speakers_per_batch,
23
+ utterances_per_speaker,
24
+ num_workers=8,
25
+ )
26
+
27
+ # Setup the device on which to run the forward pass and the loss. These can be different,
28
+ # because the forward pass is faster on the GPU whereas the loss is often (depending on your
29
+ # hyperparameters) faster on the CPU.
30
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+ # FIXME: currently, the gradient is None if loss_device is cuda
32
+ loss_device = torch.device("cpu")
33
+
34
+ # Create the model and the optimizer
35
+ model = SpeakerEncoder(device, loss_device)
36
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
37
+ init_step = 1
38
+
39
+ # Configure file path for the model
40
+ state_fpath = models_dir.joinpath(run_id + ".pt")
41
+ backup_dir = models_dir.joinpath(run_id + "_backups")
42
+
43
+ # Load any existing model
44
+ if not force_restart:
45
+ if state_fpath.exists():
46
+ print("Found existing model \"%s\", loading it and resuming training." % run_id)
47
+ checkpoint = torch.load(state_fpath)
48
+ init_step = checkpoint["step"]
49
+ model.load_state_dict(checkpoint["model_state"])
50
+ optimizer.load_state_dict(checkpoint["optimizer_state"])
51
+ optimizer.param_groups[0]["lr"] = learning_rate_init
52
+ else:
53
+ print("No model \"%s\" found, starting training from scratch." % run_id)
54
+ else:
55
+ print("Starting the training from scratch.")
56
+ model.train()
57
+
58
+ # Initialize the visualization environment
59
+ vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
60
+ vis.log_dataset(dataset)
61
+ vis.log_params()
62
+ device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
63
+ vis.log_implementation({"Device": device_name})
64
+
65
+ # Training loop
66
+ profiler = Profiler(summarize_every=10, disabled=False)
67
+ for step, speaker_batch in enumerate(loader, init_step):
68
+ profiler.tick("Blocking, waiting for batch (threaded)")
69
+
70
+ # Forward pass
71
+ inputs = torch.from_numpy(speaker_batch.data).to(device)
72
+ sync(device)
73
+ profiler.tick("Data to %s" % device)
74
+ embeds = model(inputs)
75
+ sync(device)
76
+ profiler.tick("Forward pass")
77
+ embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
78
+ loss, eer = model.loss(embeds_loss)
79
+ sync(loss_device)
80
+ profiler.tick("Loss")
81
+
82
+ # Backward pass
83
+ model.zero_grad()
84
+ loss.backward()
85
+ profiler.tick("Backward pass")
86
+ model.do_gradient_ops()
87
+ optimizer.step()
88
+ profiler.tick("Parameter update")
89
+
90
+ # Update visualizations
91
+ # learning_rate = optimizer.param_groups[0]["lr"]
92
+ vis.update(loss.item(), eer, step)
93
+
94
+ # Draw projections and save them to the backup folder
95
+ if umap_every != 0 and step % umap_every == 0:
96
+ print("Drawing and saving projections (step %d)" % step)
97
+ backup_dir.mkdir(exist_ok=True)
98
+ projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
99
+ embeds = embeds.detach().cpu().numpy()
100
+ vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
101
+ vis.save()
102
+
103
+ # Overwrite the latest version of the model
104
+ if save_every != 0 and step % save_every == 0:
105
+ print("Saving the model (step %d)" % step)
106
+ torch.save({
107
+ "step": step + 1,
108
+ "model_state": model.state_dict(),
109
+ "optimizer_state": optimizer.state_dict(),
110
+ }, state_fpath)
111
+
112
+ # Make a backup
113
+ if backup_every != 0 and step % backup_every == 0:
114
+ print("Making a backup (step %d)" % step)
115
+ backup_dir.mkdir(exist_ok=True)
116
+ backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step))
117
+ torch.save({
118
+ "step": step + 1,
119
+ "model_state": model.state_dict(),
120
+ "optimizer_state": optimizer.state_dict(),
121
+ }, backup_fpath)
122
+
123
+ profiler.tick("Extras (visualizations, saving)")
encoder/visualizations.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
2
+ from datetime import datetime
3
+ from time import perf_counter as timer
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ # import webbrowser
7
+ import visdom
8
+ import umap
9
+
10
+ colormap = np.array([
11
+ [76, 255, 0],
12
+ [0, 127, 70],
13
+ [255, 0, 0],
14
+ [255, 217, 38],
15
+ [0, 135, 255],
16
+ [165, 0, 165],
17
+ [255, 167, 255],
18
+ [0, 255, 255],
19
+ [255, 96, 38],
20
+ [142, 76, 0],
21
+ [33, 0, 127],
22
+ [0, 0, 0],
23
+ [183, 183, 183],
24
+ ], dtype=float) / 255
25
+
26
+
27
+ class Visualizations:
28
+ def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
29
+ # Tracking data
30
+ self.last_update_timestamp = timer()
31
+ self.update_every = update_every
32
+ self.step_times = []
33
+ self.losses = []
34
+ self.eers = []
35
+ print("Updating the visualizations every %d steps." % update_every)
36
+
37
+ # If visdom is disabled TODO: use a better paradigm for that
38
+ self.disabled = disabled
39
+ if self.disabled:
40
+ return
41
+
42
+ # Set the environment name
43
+ now = str(datetime.now().strftime("%d-%m %Hh%M"))
44
+ if env_name is None:
45
+ self.env_name = now
46
+ else:
47
+ self.env_name = "%s (%s)" % (env_name, now)
48
+
49
+ # Connect to visdom and open the corresponding window in the browser
50
+ try:
51
+ self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
52
+ except ConnectionError:
53
+ raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
54
+ "start it.")
55
+ # webbrowser.open("http://localhost:8097/env/" + self.env_name)
56
+
57
+ # Create the windows
58
+ self.loss_win = None
59
+ self.eer_win = None
60
+ # self.lr_win = None
61
+ self.implementation_win = None
62
+ self.projection_win = None
63
+ self.implementation_string = ""
64
+
65
+ def log_params(self):
66
+ if self.disabled:
67
+ return
68
+ from encoder import params_data
69
+ from encoder import params_model
70
+ param_string = "<b>Model parameters</b>:<br>"
71
+ for param_name in (p for p in dir(params_model) if not p.startswith("__")):
72
+ value = getattr(params_model, param_name)
73
+ param_string += "\t%s: %s<br>" % (param_name, value)
74
+ param_string += "<b>Data parameters</b>:<br>"
75
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
76
+ value = getattr(params_data, param_name)
77
+ param_string += "\t%s: %s<br>" % (param_name, value)
78
+ self.vis.text(param_string, opts={"title": "Parameters"})
79
+
80
+ def log_dataset(self, dataset: SpeakerVerificationDataset):
81
+ if self.disabled:
82
+ return
83
+ dataset_string = ""
84
+ dataset_string += "<b>Speakers</b>: %s\n" % len(dataset.speakers)
85
+ dataset_string += "\n" + dataset.get_logs()
86
+ dataset_string = dataset_string.replace("\n", "<br>")
87
+ self.vis.text(dataset_string, opts={"title": "Dataset"})
88
+
89
+ def log_implementation(self, params):
90
+ if self.disabled:
91
+ return
92
+ implementation_string = ""
93
+ for param, value in params.items():
94
+ implementation_string += "<b>%s</b>: %s\n" % (param, value)
95
+ implementation_string = implementation_string.replace("\n", "<br>")
96
+ self.implementation_string = implementation_string
97
+ self.implementation_win = self.vis.text(
98
+ implementation_string,
99
+ opts={"title": "Training implementation"}
100
+ )
101
+
102
+ def update(self, loss, eer, step):
103
+ # Update the tracking data
104
+ now = timer()
105
+ self.step_times.append(1000 * (now - self.last_update_timestamp))
106
+ self.last_update_timestamp = now
107
+ self.losses.append(loss)
108
+ self.eers.append(eer)
109
+ print(".", end="")
110
+
111
+ # Update the plots every <update_every> steps
112
+ if step % self.update_every != 0:
113
+ return
114
+ time_string = "Step time: mean: %5dms std: %5dms" % \
115
+ (int(np.mean(self.step_times)), int(np.std(self.step_times)))
116
+ print("\nStep %6d Loss: %.4f EER: %.4f %s" %
117
+ (step, np.mean(self.losses), np.mean(self.eers), time_string))
118
+ if not self.disabled:
119
+ self.loss_win = self.vis.line(
120
+ [np.mean(self.losses)],
121
+ [step],
122
+ win=self.loss_win,
123
+ update="append" if self.loss_win else None,
124
+ opts=dict(
125
+ legend=["Avg. loss"],
126
+ xlabel="Step",
127
+ ylabel="Loss",
128
+ title="Loss",
129
+ )
130
+ )
131
+ self.eer_win = self.vis.line(
132
+ [np.mean(self.eers)],
133
+ [step],
134
+ win=self.eer_win,
135
+ update="append" if self.eer_win else None,
136
+ opts=dict(
137
+ legend=["Avg. EER"],
138
+ xlabel="Step",
139
+ ylabel="EER",
140
+ title="Equal error rate"
141
+ )
142
+ )
143
+ if self.implementation_win is not None:
144
+ self.vis.text(
145
+ self.implementation_string + ("<b>%s</b>" % time_string),
146
+ win=self.implementation_win,
147
+ opts={"title": "Training implementation"},
148
+ )
149
+
150
+ # Reset the tracking
151
+ self.losses.clear()
152
+ self.eers.clear()
153
+ self.step_times.clear()
154
+
155
+ def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None,
156
+ max_speakers=10):
157
+ max_speakers = min(max_speakers, len(colormap))
158
+ embeds = embeds[:max_speakers * utterances_per_speaker]
159
+
160
+ n_speakers = len(embeds) // utterances_per_speaker
161
+ ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
162
+ colors = [colormap[i] for i in ground_truth]
163
+
164
+ reducer = umap.UMAP()
165
+ projected = reducer.fit_transform(embeds)
166
+ plt.scatter(projected[:, 0], projected[:, 1], c=colors)
167
+ plt.gca().set_aspect("equal", "datalim")
168
+ plt.title("UMAP projection (step %d)" % step)
169
+ if not self.disabled:
170
+ self.projection_win = self.vis.matplot(plt, win=self.projection_win)
171
+ if out_fpath is not None:
172
+ plt.savefig(out_fpath)
173
+ plt.clf()
174
+
175
+ def save(self):
176
+ if not self.disabled:
177
+ self.vis.save([self.env_name])
178
+
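A minimal sketch of using draw_projections without a running visdom server (disabled=True skips the connection); the embeddings are random placeholders and umap-learn plus matplotlib must be installed:

    import numpy as np
    from encoder.visualizations import Visualizations

    vis = Visualizations(env_name="offline", disabled=True)
    embeds = np.random.rand(4 * 10, 256).astype(np.float32)   # 4 fake speakers x 10 utterances
    vis.draw_projections(embeds, utterances_per_speaker=10, step=0, out_fpath="umap_step0.png")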
encoder_preprocess.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from encoder.preprocess import preprocess_librispeech, preprocess_voxceleb1, preprocess_voxceleb2, preprocess_aidatatang_200zh
2
+ from utils.argutils import print_args
3
+ from pathlib import Path
4
+ import argparse
5
+
6
+ if __name__ == "__main__":
7
+ class MyFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
8
+ pass
9
+
10
+ parser = argparse.ArgumentParser(
11
+ description="Preprocesses audio files from datasets, encodes them as mel spectrograms and "
12
+ "writes them to the disk. This will allow you to train the encoder. The "
13
+ "datasets required are at least one of LibriSpeech, VoxCeleb1, VoxCeleb2, aidatatang_200zh. ",
14
+ formatter_class=MyFormatter
15
+ )
16
+ parser.add_argument("datasets_root", type=Path, help=\
17
+ "Path to the directory containing your LibriSpeech/TTS and VoxCeleb datasets.")
18
+ parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
19
+ "Path to the output directory that will contain the mel spectrograms. If left out, "
20
+ "defaults to <datasets_root>/SV2TTS/encoder/")
21
+ parser.add_argument("-d", "--datasets", type=str,
22
+ default="librispeech_other,voxceleb1,aidatatang_200zh", help=\
23
+ "Comma-separated list of the name of the datasets you want to preprocess. Only the train "
24
+ "set of these datasets will be used. Possible names: librispeech_other, voxceleb1, "
25
+ "voxceleb2.")
26
+ parser.add_argument("-s", "--skip_existing", action="store_true", help=\
27
+ "Whether to skip existing output files with the same name. Useful if this script was "
28
+ "interrupted.")
29
+ parser.add_argument("--no_trim", action="store_true", help=\
30
+ "Preprocess audio without trimming silences (not recommended).")
31
+ args = parser.parse_args()
32
+
33
+ # Verify webrtcvad is available
34
+ if not args.no_trim:
35
+ try:
36
+ import webrtcvad
37
+ except:
38
+ raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables "
39
+ "noise removal and is recommended. Please install and try again. If installation fails, "
40
+ "use --no_trim to disable this error message.")
41
+ del args.no_trim
42
+
43
+ # Process the arguments
44
+ args.datasets = args.datasets.split(",")
45
+ if not hasattr(args, "out_dir"):
46
+ args.out_dir = args.datasets_root.joinpath("SV2TTS", "encoder")
47
+ assert args.datasets_root.exists()
48
+ args.out_dir.mkdir(exist_ok=True, parents=True)
49
+
50
+ # Preprocess the datasets
51
+ print_args(args, parser)
52
+ preprocess_func = {
53
+ "librispeech_other": preprocess_librispeech,
54
+ "voxceleb1": preprocess_voxceleb1,
55
+ "voxceleb2": preprocess_voxceleb2,
56
+ "aidatatang_200zh": preprocess_aidatatang_200zh,
57
+ }
58
+ args = vars(args)
59
+ for dataset in args.pop("datasets"):
60
+ print("Preprocessing %s" % dataset)
61
+ preprocess_func[dataset](**args)
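As a usage example (paths are placeholders), preprocessing only the aidatatang_200zh train set located under /data/datasets, and skipping files produced by an earlier interrupted run, would look like: python encoder_preprocess.py /data/datasets -d aidatatang_200zh -s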
encoder_train.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils.argutils import print_args
2
+ from encoder.train import train
3
+ from pathlib import Path
4
+ import argparse
5
+
6
+
7
+ if __name__ == "__main__":
8
+ parser = argparse.ArgumentParser(
9
+ description="Trains the speaker encoder. You must have run encoder_preprocess.py first.",
10
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
11
+ )
12
+
13
+ parser.add_argument("run_id", type=str, help= \
14
+ "Name for this model instance. If a model state from the same run ID was previously "
15
+ "saved, the training will restart from there. Pass -f to overwrite saved states and "
16
+ "restart from scratch.")
17
+ parser.add_argument("clean_data_root", type=Path, help= \
18
+ "Path to the output directory of encoder_preprocess.py. If you left the default "
19
+ "output directory when preprocessing, it should be <datasets_root>/SV2TTS/encoder/.")
20
+ parser.add_argument("-m", "--models_dir", type=Path, default="encoder/saved_models/", help=\
21
+ "Path to the output directory that will contain the saved model weights, as well as "
22
+ "backups of those weights and plots generated during training.")
23
+ parser.add_argument("-v", "--vis_every", type=int, default=10, help= \
24
+ "Number of steps between updates of the loss and the plots.")
25
+ parser.add_argument("-u", "--umap_every", type=int, default=100, help= \
26
+ "Number of steps between updates of the umap projection. Set to 0 to never update the "
27
+ "projections.")
28
+ parser.add_argument("-s", "--save_every", type=int, default=500, help= \
29
+ "Number of steps between updates of the model on the disk. Set to 0 to never save the "
30
+ "model.")
31
+ parser.add_argument("-b", "--backup_every", type=int, default=7500, help= \
32
+ "Number of steps between backups of the model. Set to 0 to never make backups of the "
33
+ "model.")
34
+ parser.add_argument("-f", "--force_restart", action="store_true", help= \
35
+ "Do not load any saved model.")
36
+ parser.add_argument("--visdom_server", type=str, default="http://localhost")
37
+ parser.add_argument("--no_visdom", action="store_true", help= \
38
+ "Disable visdom.")
39
+ args = parser.parse_args()
40
+
41
+ # Process the arguments
42
+ args.models_dir.mkdir(exist_ok=True)
43
+
44
+ # Run the training
45
+ print_args(args, parser)
46
+ train(**vars(args))
47
+
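As a usage example (run name and path are placeholders): python encoder_train.py my_run /data/datasets/SV2TTS/encoder --no_visdom trains without a visdom server; drop --no_visdom and launch visdom beforehand to get live loss and EER plots.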
gen_voice.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from encoder.params_model import model_embedding_size as speaker_embedding_size
2
+ from utils.argutils import print_args
3
+ from utils.modelutils import check_model_paths
4
+ from synthesizer.inference import Synthesizer
5
+ from encoder import inference as encoder
6
+ from vocoder.wavernn import inference as rnn_vocoder
7
+ from vocoder.hifigan import inference as gan_vocoder
8
+ from pathlib import Path
9
+ import numpy as np
10
+ import soundfile as sf
11
+ import librosa
12
+ import argparse
13
+ import torch
14
+ import sys
15
+ import os
16
+ import re
17
+ import cn2an
18
+ import glob
19
+
20
+ from audioread.exceptions import NoBackendError
21
+ vocoder = gan_vocoder
22
+
23
+ def gen_one_wav(synthesizer, in_fpath, embed, texts, file_name, seq):
24
+ embeds = [embed] * len(texts)
25
+ # If you know what the attention layer alignments are, you can retrieve them here by
26
+ # passing return_alignments=True
27
+ specs = synthesizer.synthesize_spectrograms(texts, embeds, style_idx=-1, min_stop_token=4, steps=400)
28
+ #spec = specs[0]
29
+ breaks = [spec.shape[1] for spec in specs]
30
+ spec = np.concatenate(specs, axis=1)
31
+
32
+ # If seed is specified, reset torch seed and reload vocoder
33
+ # Synthesizing the waveform is fairly straightforward. Remember that the longer the
34
+ # spectrogram, the more time-efficient the vocoder.
35
+ generated_wav, output_sample_rate = vocoder.infer_waveform(spec)
36
+
37
+ # Add breaks
38
+ b_ends = np.cumsum(np.array(breaks) * synthesizer.hparams.hop_size)
39
+ b_starts = np.concatenate(([0], b_ends[:-1]))
40
+ wavs = [generated_wav[start:end] for start, end, in zip(b_starts, b_ends)]
41
+ breaks = [np.zeros(int(0.15 * synthesizer.sample_rate))] * len(breaks)
42
+ generated_wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])
43
+
44
+ ## Post-generation
45
+ # There's a bug with sounddevice that makes the audio cut one second earlier, so we
46
+ # pad it.
47
+
48
+ # Trim excess silences to compensate for gaps in spectrograms (issue #53)
49
+ generated_wav = encoder.preprocess_wav(generated_wav)
50
+ generated_wav = generated_wav / np.abs(generated_wav).max() * 0.97
51
+
52
+ # Save it on the disk
53
+ model=os.path.basename(in_fpath)
54
+ filename = "%s_%d_%s.wav" %(file_name, seq, model)
55
+ sf.write(filename, generated_wav, synthesizer.sample_rate)
56
+
57
+ print("\nSaved output as %s\n\n" % filename)
58
+
59
+
60
+ def generate_wav(enc_model_fpath, syn_model_fpath, voc_model_fpath, in_fpath, input_txt, file_name):
61
+ if torch.cuda.is_available():
62
+ device_id = torch.cuda.current_device()
63
+ gpu_properties = torch.cuda.get_device_properties(device_id)
64
+ ## Print some environment information (for debugging purposes)
65
+ print("Found %d GPUs available. Using GPU %d (%s) of compute capability %d.%d with "
66
+ "%.1fGb total memory.\n" %
67
+ (torch.cuda.device_count(),
68
+ device_id,
69
+ gpu_properties.name,
70
+ gpu_properties.major,
71
+ gpu_properties.minor,
72
+ gpu_properties.total_memory / 1e9))
73
+ else:
74
+ print("Using CPU for inference.\n")
75
+
76
+ print("Preparing the encoder, the synthesizer and the vocoder...")
77
+ encoder.load_model(enc_model_fpath)
78
+ synthesizer = Synthesizer(syn_model_fpath)
79
+ vocoder.load_model(voc_model_fpath)
80
+
81
+ encoder_wav = synthesizer.load_preprocess_wav(in_fpath)
82
+ embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
83
+
84
+ texts = input_txt.split("\n")
85
+ seq=0
86
+ each_num=1500
87
+
88
+ punctuation = '!,。、,' # punctuate and split/clean text
89
+ processed_texts = []
90
+ cur_num = 0
91
+ for text in texts:
92
+ for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'):
93
+ if processed_text:
94
+ processed_texts.append(processed_text.strip())
95
+ cur_num += len(processed_text.strip())
96
+ if cur_num > each_num:
97
+ seq = seq +1
98
+ gen_one_wav(synthesizer, in_fpath, embed, processed_texts, file_name, seq)
99
+ processed_texts = []
100
+ cur_num = 0
101
+
102
+ if len(processed_texts)>0:
103
+ seq = seq +1
104
+ gen_one_wav(synthesizer, in_fpath, embed, processed_texts, file_name, seq)
105
+
106
+ if (len(sys.argv)>=3):
107
+ my_txt = ""
108
+ print("reading from :", sys.argv[1])
109
+ with open(sys.argv[1], "r") as f:
110
+ for line in f.readlines():
111
+ #line = line.strip('\n')
112
+ my_txt += line
113
+ txt_file_name = sys.argv[1]
114
+ wav_file_name = sys.argv[2]
115
+
116
+ output = cn2an.transform(my_txt, "an2cn")
117
+ print(output)
118
+ generate_wav(
119
+ Path("encoder/saved_models/pretrained.pt"),
120
+ Path("synthesizer/saved_models/mandarin.pt"),
121
+ Path("vocoder/saved_models/pretrained/g_hifigan.pt"), wav_file_name, output, txt_file_name
122
+ )
123
+
124
+ else:
125
+ print("please input the file name")
126
+ exit(1)
127
+
128
+
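The text chunking used above, in isolation with a made-up input string: sentences are cut on Chinese punctuation and then grouped into batches of roughly each_num characters before synthesis.

    import re

    punctuation = '!,。、,'
    text = "你好,欢迎使用语音克隆工具。祝使用愉快!"        # placeholder input
    pieces = [p.strip() for p in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n') if p]
    print(pieces)    # ['你好', '欢迎使用语音克隆工具', '祝使用愉快']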
mkgui/__init__.py ADDED
File without changes
mkgui/app.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ import os
3
+ from pathlib import Path
4
+ from enum import Enum
5
+ from encoder import inference as encoder
6
+ import librosa
7
+ from scipy.io.wavfile import write
8
+ import re
9
+ import numpy as np
10
+ from mkgui.base.components.types import FileContent
11
+ from vocoder.hifigan import inference as gan_vocoder
12
+ from synthesizer.inference import Synthesizer
13
+ from typing import Any, Tuple
14
+ import matplotlib.pyplot as plt
15
+
16
+ # Constants
17
+ AUDIO_SAMPLES_DIR = f"samples{os.sep}"
18
+ SYN_MODELS_DIRT = f"synthesizer{os.sep}saved_models"
19
+ ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"
20
+ VOC_MODELS_DIRT = f"vocoder{os.sep}saved_models"
21
+ TEMP_SOURCE_AUDIO = f"wavs{os.sep}temp_source.wav"
22
+ TEMP_RESULT_AUDIO = f"wavs{os.sep}temp_result.wav"
23
+ if not os.path.isdir("wavs"):
24
+ os.makedirs("wavs")
25
+
26
+ # Load local sample audio as options TODO: load dataset
27
+ if os.path.isdir(AUDIO_SAMPLES_DIR):
28
+ audio_input_selection = Enum('samples', list((file.name, file) for file in Path(AUDIO_SAMPLES_DIR).glob("*.wav")))
29
+ # Pre-Load models
30
+ if os.path.isdir(SYN_MODELS_DIRT):
31
+ synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt")))
32
+ print("Loaded synthesizer models: " + str(len(synthesizers)))
33
+ else:
34
+ raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist.")
35
+
36
+ if os.path.isdir(ENC_MODELS_DIRT):
37
+ encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
38
+ print("Loaded encoders models: " + str(len(encoders)))
39
+ else:
40
+ raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
41
+
42
+ if os.path.isdir(VOC_MODELS_DIRT):
43
+ vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt")))
44
+ print("Loaded vocoders models: " + str(len(synthesizers)))
45
+ else:
46
+ raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")
47
+
48
+
49
+
50
+ class Input(BaseModel):
51
+ message: str = Field(
52
+ ..., example="欢迎使用工具箱, 现已支持中文输入!", alias="文本内容"
53
+ )
54
+ local_audio_file: audio_input_selection = Field(
55
+ ..., alias="输入语音(本地wav)",
56
+ description="选择本地语音文件."
57
+ )
58
+ upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
59
+ description="拖拽或点击上传.", mime_type="audio/wav")
60
+ encoder: encoders = Field(
61
+ ..., alias="编码模型",
62
+ description="选择语音编码模型文件."
63
+ )
64
+ synthesizer: synthesizers = Field(
65
+ ..., alias="合成模型",
66
+ description="选择语音合成模型文件."
67
+ )
68
+ vocoder: vocoders = Field(
69
+ ..., alias="语音解码模型",
70
+ description="选择语音解码模型文件(目前只支持HifiGan类型)."
71
+ )
72
+
73
+ class AudioEntity(BaseModel):
74
+ content: bytes
75
+ mel: Any
76
+
77
+ class Output(BaseModel):
78
+ __root__: Tuple[AudioEntity, AudioEntity]
79
+
80
+ def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
81
+ """Custom output UI.
82
+ If this method is implemented, it will be used instead of the default Output UI renderer.
83
+ """
84
+ src, result = self.__root__
85
+
86
+ streamlit_app.subheader("Synthesized Audio")
87
+ streamlit_app.audio(result.content, format="audio/wav")
88
+
89
+ fig, ax = plt.subplots()
90
+ ax.imshow(src.mel, aspect="equal", interpolation="none")
91
+ ax.set_title("mel spectrogram(Source Audio)")
92
+ streamlit_app.pyplot(fig)
93
+ fig, ax = plt.subplots()
94
+ ax.imshow(result.mel, aspect="equal", interpolation="none")
95
+ ax.set_title("mel spectrogram(Result Audio)")
96
+ streamlit_app.pyplot(fig)
97
+
98
+
99
+ def synthesize(input: Input) -> Output:
100
+ """synthesize(合成)"""
101
+ # load models
102
+ encoder.load_model(Path(input.encoder.value))
103
+ current_synt = Synthesizer(Path(input.synthesizer.value))
104
+ gan_vocoder.load_model(Path(input.vocoder.value))
105
+
106
+ # load file
107
+ if input.upload_audio_file != None:
108
+ with open(TEMP_SOURCE_AUDIO, "w+b") as f:
109
+ f.write(input.upload_audio_file.as_bytes())
110
+ f.seek(0)
111
+ wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
112
+ else:
113
+ wav, sample_rate = librosa.load(input.local_audio_file.value)
114
+ write(TEMP_SOURCE_AUDIO, sample_rate, wav) #Make sure we get the correct wav
115
+
116
+ source_spec = Synthesizer.make_spectrogram(wav)
117
+
118
+ # preprocess
119
+ encoder_wav = encoder.preprocess_wav(wav, sample_rate)
120
+ embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
121
+
122
+ # Load input text
123
+ texts = filter(None, input.message.split("\n"))
124
+ punctuation = '!,。、,' # punctuate and split/clean text
125
+ processed_texts = []
126
+ for text in texts:
127
+ for processed_text in re.sub(r'[{}]+'.format(punctuation), '\n', text).split('\n'):
128
+ if processed_text:
129
+ processed_texts.append(processed_text.strip())
130
+ texts = processed_texts
131
+
132
+ # synthesize and vocode
133
+ embeds = [embed] * len(texts)
134
+ specs = current_synt.synthesize_spectrograms(texts, embeds)
135
+ spec = np.concatenate(specs, axis=1)
136
+ sample_rate = Synthesizer.sample_rate
137
+ wav, sample_rate = gan_vocoder.infer_waveform(spec)
138
+
139
+ # write and output
140
+ write(TEMP_RESULT_AUDIO, sample_rate, wav) #Make sure we get the correct wav
141
+ with open(TEMP_SOURCE_AUDIO, "rb") as f:
142
+ source_file = f.read()
143
+ with open(TEMP_RESULT_AUDIO, "rb") as f:
144
+ result_file = f.read()
145
+ return Output(__root__=(AudioEntity(content=source_file, mel=source_spec), AudioEntity(content=result_file, mel=spec)))
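The Enum-from-glob pattern used by the model pickers above, in isolation (the file listing is a placeholder): each pydantic field typed with such an Enum is meant to become a selection widget in the generated UI, and .value hands the chosen Path to the loader.

    from enum import Enum
    from pathlib import Path

    files = [Path("synthesizer/saved_models/mandarin.pt")]          # placeholder listing
    synthesizers = Enum('synthesizers', [(f.name, f) for f in files])
    choice = synthesizers['mandarin.pt']
    print(choice.value)       # synthesizer/saved_models/mandarin.pt -- what Synthesizer(...) receives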
mkgui/app_vc.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from synthesizer.inference import Synthesizer
2
+ from pydantic import BaseModel, Field
3
+ from encoder import inference as speacker_encoder
4
+ import torch
5
+ import os
6
+ from pathlib import Path
7
+ from enum import Enum
8
+ import ppg_extractor as Extractor
9
+ import ppg2mel as Convertor
10
+ import librosa
11
+ from scipy.io.wavfile import write
12
+ import re
13
+ import numpy as np
14
+ from mkgui.base.components.types import FileContent
15
+ from vocoder.hifigan import inference as gan_vocoder
16
+ from typing import Any, Tuple
17
+ import matplotlib.pyplot as plt
18
+
19
+
20
+ # Constants
21
+ AUDIO_SAMPLES_DIR = f'sample{os.sep}'
22
+ EXT_MODELS_DIRT = f'ppg_extractor{os.sep}saved_models'
23
+ CONV_MODELS_DIRT = f'ppg2mel{os.sep}saved_models'
24
+ VOC_MODELS_DIRT = f'vocoder{os.sep}saved_models'
25
+ TEMP_SOURCE_AUDIO = f'wavs{os.sep}temp_source.wav'
26
+ TEMP_TARGET_AUDIO = f'wavs{os.sep}temp_target.wav'
27
+ TEMP_RESULT_AUDIO = f'wavs{os.sep}temp_result.wav'
28
+
29
+ # Load local sample audio as options TODO: load dataset
30
+ if os.path.isdir(AUDIO_SAMPLES_DIR):
31
+ audio_input_selection = Enum('samples', list((file.name, file) for file in Path(AUDIO_SAMPLES_DIR).glob("*.wav")))
32
+ # Pre-Load models
33
+ if os.path.isdir(EXT_MODELS_DIRT):
34
+ extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
35
+ print("Loaded extractor models: " + str(len(extractors)))
36
+ else:
37
+ raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
38
+
39
+ if os.path.isdir(CONV_MODELS_DIRT):
40
+ convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
41
+ print("Loaded convertor models: " + str(len(convertors)))
42
+ else:
43
+ raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")
44
+
45
+ if os.path.isdir(VOC_MODELS_DIRT):
46
+ vocoders = Enum('vocoders', list((file.name, file) for file in Path(VOC_MODELS_DIRT).glob("**/*gan*.pt")))
47
+ print("Loaded vocoders models: " + str(len(vocoders)))
48
+ else:
49
+ raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")
50
+
51
+ class Input(BaseModel):
52
+ local_audio_file: audio_input_selection = Field(
53
+ ..., alias="输入语音(本地wav)",
54
+ description="选择本地语音文件."
55
+ )
56
+ upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
57
+ description="拖拽或点击上传.", mime_type="audio/wav")
58
+ local_audio_file_target: audio_input_selection = Field(
59
+ ..., alias="目标语音(本地wav)",
60
+ description="选择本地语音文件."
61
+ )
62
+ upload_audio_file_target: FileContent = Field(default=None, alias="或上传目标语音",
63
+ description="拖拽或点击上传.", mime_type="audio/wav")
64
+ extractor: extractors = Field(
65
+ ..., alias="编码模型",
66
+ description="选择语音编码模型文件."
67
+ )
68
+ convertor: convertors = Field(
69
+ ..., alias="转换模型",
70
+ description="选择语音转换模型文件."
71
+ )
72
+ vocoder: vocoders = Field(
73
+ ..., alias="语音解码模型",
74
+ description="选择语音解码模型文件(目前只支持HifiGan类型)."
75
+ )
76
+
77
+ class AudioEntity(BaseModel):
78
+ content: bytes
79
+ mel: Any
80
+
81
+ class Output(BaseModel):
82
+ __root__: Tuple[AudioEntity, AudioEntity, AudioEntity]
83
+
84
+ def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
85
+ """Custom output UI.
86
+ If this method is implemented, it will be used instead of the default Output UI renderer.
87
+ """
88
+ src, target, result = self.__root__
89
+
90
+ streamlit_app.subheader("Synthesized Audio")
91
+ streamlit_app.audio(result.content, format="audio/wav")
92
+
93
+ fig, ax = plt.subplots()
94
+ ax.imshow(src.mel, aspect="equal", interpolation="none")
95
+ ax.set_title("mel spectrogram(Source Audio)")
96
+ streamlit_app.pyplot(fig)
97
+ fig, ax = plt.subplots()
98
+ ax.imshow(target.mel, aspect="equal", interpolation="none")
99
+ ax.set_title("mel spectrogram(Target Audio)")
100
+ streamlit_app.pyplot(fig)
101
+ fig, ax = plt.subplots()
102
+ ax.imshow(result.mel, aspect="equal", interpolation="none")
103
+ ax.set_title("mel spectrogram(Result Audio)")
104
+ streamlit_app.pyplot(fig)
105
+
106
+ def convert(input: Input) -> Output:
107
+ """convert(转换)"""
108
+ # load models
109
+ extractor = Extractor.load_model(Path(input.extractor.value))
110
+ convertor = Convertor.load_model(Path(input.convertor.value))
111
+ # current_synt = Synthesizer(Path(input.synthesizer.value))
112
+ gan_vocoder.load_model(Path(input.vocoder.value))
113
+
114
+ # load file
115
+ if input.upload_audio_file != None:
116
+ with open(TEMP_SOURCE_AUDIO, "w+b") as f:
117
+ f.write(input.upload_audio_file.as_bytes())
118
+ f.seek(0)
119
+ src_wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
120
+ else:
121
+ src_wav, sample_rate = librosa.load(input.local_audio_file.value)
122
+ write(TEMP_SOURCE_AUDIO, sample_rate, src_wav) #Make sure we get the correct wav
123
+
124
+ if input.upload_audio_file_target != None:
125
+ with open(TEMP_TARGET_AUDIO, "w+b") as f:
126
+ f.write(input.upload_audio_file_target.as_bytes())
127
+ f.seek(0)
128
+ ref_wav, _ = librosa.load(TEMP_TARGET_AUDIO)
129
+ else:
130
+ ref_wav, _ = librosa.load(input.local_audio_file_target.value)
131
+ write(TEMP_TARGET_AUDIO, sample_rate, ref_wav) #Make sure we get the correct wav
132
+
133
+ ppg = extractor.extract_from_wav(src_wav)
134
+ # Import the dependencies needed for voice conversion
135
+ from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
136
+ ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
137
+ speacker_encoder.load_model(Path("encoder{os.sep}saved_models{os.sep}pretrained_bak_5805000.pt"))
138
+ embed = speacker_encoder.embed_utterance(ref_wav)
139
+ lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
140
+ min_len = min(ppg.shape[1], len(lf0_uv))
141
+ ppg = ppg[:, :min_len]
142
+ lf0_uv = lf0_uv[:min_len]
143
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
144
+ _, mel_pred, att_ws = convertor.inference(
145
+ ppg,
146
+ logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
147
+ spembs=torch.from_numpy(embed).unsqueeze(0).to(device),
148
+ )
149
+ mel_pred= mel_pred.transpose(0, 1)
150
+ breaks = [mel_pred.shape[1]]
151
+ mel_pred= mel_pred.detach().cpu().numpy()
152
+
153
+ # synthesize and vocode
154
+ wav, sample_rate = gan_vocoder.infer_waveform(mel_pred)
155
+
156
+ # write and output
157
+ write(TEMP_RESULT_AUDIO, sample_rate, wav) #Make sure we get the correct wav
158
+ with open(TEMP_SOURCE_AUDIO, "rb") as f:
159
+ source_file = f.read()
160
+ with open(TEMP_TARGET_AUDIO, "rb") as f:
161
+ target_file = f.read()
162
+ with open(TEMP_RESULT_AUDIO, "rb") as f:
163
+ result_file = f.read()
164
+
165
+
166
+ return Output(__root__=(AudioEntity(content=source_file, mel=Synthesizer.make_spectrogram(src_wav)), AudioEntity(content=target_file, mel=Synthesizer.make_spectrogram(ref_wav)), AudioEntity(content=result_file, mel=Synthesizer.make_spectrogram(wav))))
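The frame-alignment step above in isolation, with dummy arrays (all shapes are placeholders): the extracted PPG sequence and the converted lf0/UV track are trimmed to a common length before convertor.inference is called.

    import numpy as np

    ppg = np.zeros((1, 523, 144))     # (batch, frames, ppg_dim) -- placeholder shape
    lf0_uv = np.zeros((520, 2))       # (frames, [lf0, uv])      -- placeholder shape
    min_len = min(ppg.shape[1], len(lf0_uv))
    ppg, lf0_uv = ppg[:, :min_len], lf0_uv[:min_len]
    assert ppg.shape[1] == len(lf0_uv) == 520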
mkgui/base/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+
2
+ from .core import Opyrator
mkgui/base/api/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .fastapi_app import create_api
mkgui/base/api/fastapi_utils.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Collection of utilities for FastAPI apps."""
2
+
3
+ import inspect
4
+ from typing import Any, Type
5
+
6
+ from fastapi import FastAPI, Form
7
+ from pydantic import BaseModel
8
+
9
+
10
+ def as_form(cls: Type[BaseModel]) -> Any:
11
+ """Adds an as_form class method to decorated models.
12
+
13
+ The as_form class method can be used with FastAPI endpoints
14
+ """
15
+ new_params = [
16
+ inspect.Parameter(
17
+ field.alias,
18
+ inspect.Parameter.POSITIONAL_ONLY,
19
+ default=(Form(field.default) if not field.required else Form(...)),
20
+ )
21
+ for field in cls.__fields__.values()
22
+ ]
23
+
24
+ async def _as_form(**data): # type: ignore
25
+ return cls(**data)
26
+
27
+ sig = inspect.signature(_as_form)
28
+ sig = sig.replace(parameters=new_params)
29
+ _as_form.__signature__ = sig # type: ignore
30
+ setattr(cls, "as_form", _as_form)
31
+ return cls
32
+
33
+
34
+ def patch_fastapi(app: FastAPI) -> None:
35
+ """Patch function to allow relative url resolution.
36
+
37
+ This patch is required to make fastapi fully functional with a relative url path.
38
+ This code snippet can be copy-pasted to any Fastapi application.
39
+ """
40
+ from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
41
+ from starlette.requests import Request
42
+ from starlette.responses import HTMLResponse
43
+
44
+ async def redoc_ui_html(req: Request) -> HTMLResponse:
45
+ assert app.openapi_url is not None
46
+ redoc_ui = get_redoc_html(
47
+ openapi_url="./" + app.openapi_url.lstrip("/"),
48
+ title=app.title + " - Redoc UI",
49
+ )
50
+
51
+ return HTMLResponse(redoc_ui.body.decode("utf-8"))
52
+
53
+ async def swagger_ui_html(req: Request) -> HTMLResponse:
54
+ assert app.openapi_url is not None
55
+ swagger_ui = get_swagger_ui_html(
56
+ openapi_url="./" + app.openapi_url.lstrip("/"),
57
+ title=app.title + " - Swagger UI",
58
+ oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
59
+ )
60
+
61
+ # Insert a request interceptor so that all requests run on a relative path
62
+ request_interceptor = (
63
+ "requestInterceptor: (e) => {"
64
+ "\n\t\t\tvar url = window.location.origin + window.location.pathname"
65
+ '\n\t\t\turl = url.substring( 0, url.lastIndexOf( "/" ) + 1);'
66
+ "\n\t\t\turl = e.url.replace(/http(s)?:\/\/[^/]*\//i, url);" # noqa: W605
67
+ "\n\t\t\te.contextUrl = url"
68
+ "\n\t\t\te.url = url"
69
+ "\n\t\t\treturn e;}"
70
+ )
71
+
72
+ return HTMLResponse(
73
+ swagger_ui.body.decode("utf-8").replace(
74
+ "dom_id: '#swagger-ui',",
75
+ "dom_id: '#swagger-ui',\n\t\t" + request_interceptor + ",",
76
+ )
77
+ )
78
+
79
+ # remove old docs route and add our patched route
80
+ routes_new = []
81
+ for app_route in app.routes:
82
+ if app_route.path == "/docs": # type: ignore
83
+ continue
84
+
85
+ if app_route.path == "/redoc": # type: ignore
86
+ continue
87
+
88
+ routes_new.append(app_route)
89
+
90
+ app.router.routes = routes_new
91
+
92
+ assert app.docs_url is not None
93
+ app.add_route(app.docs_url, swagger_ui_html, include_in_schema=False)
94
+ assert app.redoc_url is not None
95
+ app.add_route(app.redoc_url, redoc_ui_html, include_in_schema=False)
96
+
97
+ # Make graphql relative
98
+ from starlette import graphql
99
+
100
+ graphql.GRAPHIQL = graphql.GRAPHIQL.replace(
101
+ "({{REQUEST_PATH}}", '("." + {{REQUEST_PATH}}'
102
+ )
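A minimal sketch of the as_form decorator above on a toy model (the endpoint and fields are hypothetical): it adds an as_form constructor so a pydantic model can be filled from HTML form fields via Depends.

    from fastapi import Depends, FastAPI
    from pydantic import BaseModel

    from mkgui.base.api.fastapi_utils import as_form

    app = FastAPI()

    @as_form
    class LoginForm(BaseModel):
        username: str
        password: str

    @app.post("/login")
    async def login(form: LoginForm = Depends(LoginForm.as_form)):  # type: ignore
        return {"user": form.username}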
mkgui/base/components/__init__.py ADDED
File without changes
mkgui/base/components/outputs.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
+ class ScoredLabel(BaseModel):
7
+ label: str
8
+ score: float
9
+
10
+
11
+ class ClassificationOutput(BaseModel):
12
+ __root__: List[ScoredLabel]
13
+
14
+ def __iter__(self): # type: ignore
15
+ return iter(self.__root__)
16
+
17
+ def __getitem__(self, item): # type: ignore
18
+ return self.__root__[item]
19
+
20
+ def render_output_ui(self, streamlit) -> None: # type: ignore
21
+ import plotly.express as px
22
+
23
+ sorted_predictions = sorted(
24
+ [prediction.dict() for prediction in self.__root__],
25
+ key=lambda k: k["score"],
26
+ )
27
+
28
+ num_labels = len(sorted_predictions)
29
+ if len(sorted_predictions) > 10:
30
+ num_labels = streamlit.slider(
31
+ "Maximum labels to show: ",
32
+ min_value=1,
33
+ max_value=len(sorted_predictions),
34
+ value=len(sorted_predictions),
35
+ )
36
+ fig = px.bar(
37
+ sorted_predictions[len(sorted_predictions) - num_labels :],
38
+ x="score",
39
+ y="label",
40
+ orientation="h",
41
+ )
42
+ streamlit.plotly_chart(fig, use_container_width=True)
43
+ # fig.show()
mkgui/base/components/types.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ from typing import Any, Dict
3
+
4
+
5
+ class FileContent(str):
6
+ def as_bytes(self) -> bytes:
7
+ return base64.b64decode(self, validate=True)
8
+
9
+ def as_str(self) -> str:
10
+ return self.as_bytes().decode()
11
+
12
+ @classmethod
13
+ def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
14
+ field_schema.update(format="byte")
15
+
16
+ @classmethod
17
+ def __get_validators__(cls) -> Any: # type: ignore
18
+ yield cls.validate
19
+
20
+ @classmethod
21
+ def validate(cls, value: Any) -> "FileContent":
22
+ if isinstance(value, FileContent):
23
+ return value
24
+ elif isinstance(value, str):
25
+ return FileContent(value)
26
+ elif isinstance(value, (bytes, bytearray, memoryview)):
27
+ return FileContent(base64.b64encode(value).decode())
28
+ else:
29
+ raise Exception("Wrong type")
30
+
31
+ # # Not usable for now, because the browser offers no way to select a folder
32
+ # class DirectoryContent(FileContent):
33
+ # @classmethod
34
+ # def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
35
+ # field_schema.update(format="path")
36
+
37
+ # @classmethod
38
+ # def validate(cls, value: Any) -> "DirectoryContent":
39
+ # if isinstance(value, DirectoryContent):
40
+ # return value
41
+ # elif isinstance(value, str):
42
+ # return DirectoryContent(value)
43
+ # elif isinstance(value, (bytes, bytearray, memoryview)):
44
+ # return DirectoryContent(base64.b64encode(value).decode())
45
+ # else:
46
+ # raise Exception("Wrong type")
mkgui/base/core.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import inspect
3
+ import re
4
+ from typing import Any, Callable, Type, Union, get_type_hints
5
+
6
+ from pydantic import BaseModel, parse_raw_as
7
+ from pydantic.tools import parse_obj_as
8
+
9
+
10
+ def name_to_title(name: str) -> str:
11
+ """Converts a camelCase or snake_case name to title case."""
12
+ # If camelCase -> convert to snake case
13
+ name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
14
+ name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
15
+ # Convert to title case
16
+ return name.replace("_", " ").strip().title()
17
+
18
+
19
+ def is_compatible_type(type: Type) -> bool:
20
+ """Returns `True` if the type is opyrator-compatible."""
21
+ try:
22
+ if issubclass(type, BaseModel):
23
+ return True
24
+ except Exception:
25
+ pass
26
+
27
+ try:
28
+ # valid list type
29
+ if type.__origin__ is list and issubclass(type.__args__[0], BaseModel):
30
+ return True
31
+ except Exception:
32
+ pass
33
+
34
+ return False
35
+
36
+
37
+ def get_input_type(func: Callable) -> Type:
38
+ """Returns the input type of a given function (callable).
39
+
40
+ Args:
41
+ func: The function for which to get the input type.
42
+
43
+ Raises:
44
+ ValueError: If the function does not have a valid input type annotation.
45
+ """
46
+ type_hints = get_type_hints(func)
47
+
48
+ if "input" not in type_hints:
49
+ raise ValueError(
50
+ "The callable MUST have a parameter with the name `input` with typing annotation. "
51
+ "For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
52
+ )
53
+
54
+ input_type = type_hints["input"]
55
+
56
+ if not is_compatible_type(input_type):
57
+ raise ValueError(
58
+ "The `input` parameter MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
59
+ )
60
+
61
+ # TODO: return warning if more than one input parameters
62
+
63
+ return input_type
64
+
65
+
66
+ def get_output_type(func: Callable) -> Type:
67
+ """Returns the output type of a given function (callable).
68
+
69
+ Args:
70
+ func: The function for which to get the output type.
71
+
72
+ Raises:
73
+ ValueError: If the function does not have a valid output type annotation.
74
+ """
75
+ type_hints = get_type_hints(func)
76
+ if "return" not in type_hints:
77
+ raise ValueError(
78
+ "The return type of the callable MUST be annotated with type hints."
79
+ "For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
80
+ )
81
+
82
+ output_type = type_hints["return"]
83
+
84
+ if not is_compatible_type(output_type):
85
+ raise ValueError(
86
+ "The return value MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
87
+ )
88
+
89
+ return output_type
90
+
91
+
92
+ def get_callable(import_string: str) -> Callable:
93
+ """Import a callable from an string."""
94
+ callable_seperator = ":"
95
+ if callable_seperator not in import_string:
96
+ # Use dot as separator
97
+ callable_seperator = "."
98
+
99
+ if callable_seperator not in import_string:
100
+ raise ValueError("The callable path MUST specify the function. ")
101
+
102
+ mod_name, callable_name = import_string.rsplit(callable_seperator, 1)
103
+ mod = importlib.import_module(mod_name)
104
+ return getattr(mod, callable_name)
105
+
106
+
107
+ class Opyrator:
108
+ def __init__(self, func: Union[Callable, str]) -> None:
109
+ if isinstance(func, str):
110
+ # Try to load the function from a string notation
111
+ self.function = get_callable(func)
112
+ else:
113
+ self.function = func
114
+
115
+ self._action = "Execute"
116
+ self._input_type = None
117
+ self._output_type = None
118
+
119
+ if not callable(self.function):
120
+ raise ValueError("The provided function parameter is not a callable.")
121
+
122
+ if inspect.isclass(self.function):
123
+ raise ValueError(
124
+ "The provided callable is an uninitialized Class. This is not allowed."
125
+ )
126
+
127
+ if inspect.isfunction(self.function):
128
+ # The provided callable is a function
129
+ self._input_type = get_input_type(self.function)
130
+ self._output_type = get_output_type(self.function)
131
+
132
+ try:
133
+ # Get name
134
+ self._name = name_to_title(self.function.__name__)
135
+ except Exception:
136
+ pass
137
+
138
+ try:
139
+ # Get description from function
140
+ doc_string = inspect.getdoc(self.function)
141
+ if doc_string:
142
+ self._action = doc_string
143
+ except Exception:
144
+ pass
145
+ elif hasattr(self.function, "__call__"):
146
+ # The provided callable is a class instance that implements __call__
147
+ self._input_type = get_input_type(self.function.__call__) # type: ignore
148
+ self._output_type = get_output_type(self.function.__call__) # type: ignore
149
+
150
+ try:
151
+ # Get name
152
+ self._name = name_to_title(type(self.function).__name__)
153
+ except Exception:
154
+ pass
155
+
156
+ try:
157
+ # Get action from the __call__ docstring
158
+ doc_string = inspect.getdoc(self.function.__call__) # type: ignore
159
+ if doc_string:
160
+ self._action = doc_string
161
+
162
+ if (
163
+ not self._action
164
+ or self._action == "Call"
165
+ ):
166
+ # Get docstring from class instead of __call__ function
167
+ doc_string = inspect.getdoc(self.function)
168
+ if doc_string:
169
+ self._action = doc_string
170
+ except Exception:
171
+ pass
172
+ else:
173
+ raise ValueError("Unknown callable type.")
174
+
175
+ @property
176
+ def name(self) -> str:
177
+ return self._name
178
+
179
+ @property
180
+ def action(self) -> str:
181
+ return self._action
182
+
183
+ @property
184
+ def input_type(self) -> Any:
185
+ return self._input_type
186
+
187
+ @property
188
+ def output_type(self) -> Any:
189
+ return self._output_type
190
+
191
+ def __call__(self, input: Any, **kwargs: Any) -> Any:
192
+
193
+ input_obj = input
194
+
195
+ if isinstance(input, str):
196
+ # Allow json input
197
+ input_obj = parse_raw_as(self.input_type, input)
198
+
199
+ if isinstance(input, dict):
200
+ # Allow dict input
201
+ input_obj = parse_obj_as(self.input_type, input)
202
+
203
+ return self.function(input_obj, **kwargs)
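For orientation, here is a minimal usage sketch of the Opyrator wrapper defined above; the EchoInput/EchoOutput models and the echo function are hypothetical examples, not part of this repository:

# Hypothetical usage sketch for the Opyrator wrapper above.
from pydantic import BaseModel
from mkgui.base import Opyrator

class EchoInput(BaseModel):
    text: str

class EchoOutput(BaseModel):
    text: str

def echo(input: EchoInput) -> EchoOutput:
    """Echo the given text back."""
    return EchoOutput(text=input.text)

op = Opyrator(echo)                # or Opyrator("my_module:echo") to load from an import string
print(op.name, "-", op.action)     # Echo - Echo the given text back.
print(op('{"text": "hi"}'))        # JSON strings, dicts and model instances are all accepted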
mkgui/base/ui/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .streamlit_ui import render_streamlit_ui
mkgui/base/ui/schema_utils.py ADDED
@@ -0,0 +1,129 @@
1
+ from typing import Dict
2
+
3
+
4
+ def resolve_reference(reference: str, references: Dict) -> Dict:
5
+ return references[reference.split("/")[-1]]
6
+
7
+
8
+ def get_single_reference_item(property: Dict, references: Dict) -> Dict:
9
+ # Ref can either be directly in the properties or the first element of allOf
10
+ reference = property.get("$ref")
11
+ if reference is None:
12
+ reference = property["allOf"][0]["$ref"]
13
+ return resolve_reference(reference, references)
14
+
15
+
16
+ def is_single_string_property(property: Dict) -> bool:
17
+ return property.get("type") == "string"
18
+
19
+
20
+ def is_single_datetime_property(property: Dict) -> bool:
21
+ if property.get("type") != "string":
22
+ return False
23
+ return property.get("format") in ["date-time", "time", "date"]
24
+
25
+
26
+ def is_single_boolean_property(property: Dict) -> bool:
27
+ return property.get("type") == "boolean"
28
+
29
+
30
+ def is_single_number_property(property: Dict) -> bool:
31
+ return property.get("type") in ["integer", "number"]
32
+
33
+
34
+ def is_single_file_property(property: Dict) -> bool:
35
+ if property.get("type") != "string":
36
+ return False
37
+ # TODO: binary?
38
+ return property.get("format") == "byte"
39
+
40
+
41
+ def is_single_directory_property(property: Dict) -> bool:
42
+ if property.get("type") != "string":
43
+ return False
44
+ return property.get("format") == "path"
45
+
46
+ def is_multi_enum_property(property: Dict, references: Dict) -> bool:
47
+ if property.get("type") != "array":
48
+ return False
49
+
50
+ if property.get("uniqueItems") is not True:
51
+ # Only relevant if it is a set or other datastructures with unique items
52
+ return False
53
+
54
+ try:
55
+ _ = resolve_reference(property["items"]["$ref"], references)["enum"]
56
+ return True
57
+ except Exception:
58
+ return False
59
+
60
+
61
+ def is_single_enum_property(property: Dict, references: Dict) -> bool:
62
+ try:
63
+ _ = get_single_reference_item(property, references)["enum"]
64
+ return True
65
+ except Exception:
66
+ return False
67
+
68
+
69
+ def is_single_dict_property(property: Dict) -> bool:
70
+ if property.get("type") != "object":
71
+ return False
72
+ return "additionalProperties" in property
73
+
74
+
75
+ def is_single_reference(property: Dict) -> bool:
76
+ if property.get("type") is not None:
77
+ return False
78
+
79
+ return bool(property.get("$ref"))
80
+
81
+
82
+ def is_multi_file_property(property: Dict) -> bool:
83
+ if property.get("type") != "array":
84
+ return False
85
+
86
+ if property.get("items") is None:
87
+ return False
88
+
89
+ try:
90
+ # TODO: binary
91
+ return property["items"]["format"] == "byte"
92
+ except Exception:
93
+ return False
94
+
95
+
96
+ def is_single_object(property: Dict, references: Dict) -> bool:
97
+ try:
98
+ object_reference = get_single_reference_item(property, references)
99
+ if object_reference["type"] != "object":
100
+ return False
101
+ return "properties" in object_reference
102
+ except Exception:
103
+ return False
104
+
105
+
106
+ def is_property_list(property: Dict) -> bool:
107
+ if property.get("type") != "array":
108
+ return False
109
+
110
+ if property.get("items") is None:
111
+ return False
112
+
113
+ try:
114
+ return property["items"]["type"] in ["string", "number", "integer"]
115
+ except Exception:
116
+ return False
117
+
118
+
119
+ def is_object_list_property(property: Dict, references: Dict) -> bool:
120
+ if property.get("type") != "array":
121
+ return False
122
+
123
+ try:
124
+ object_reference = resolve_reference(property["items"]["$ref"], references)
125
+ if object_reference["type"] != "object":
126
+ return False
127
+ return "properties" in object_reference
128
+ except Exception:
129
+ return False
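As a rough illustration of what these predicates operate on, the sketch below (with a hypothetical Demo model) shows the JSON-schema property dicts they receive from pydantic:

# Hypothetical sketch: the JSON-schema fragments the helpers above inspect.
from pydantic import BaseModel
from mkgui.base.ui import schema_utils

class Demo(BaseModel):
    name: str
    count: int
    enabled: bool

props = Demo.schema(by_alias=True)["properties"]
print(props["name"])                                              # {'title': 'Name', 'type': 'string'}
print(schema_utils.is_single_string_property(props["name"]))      # True
print(schema_utils.is_single_number_property(props["count"]))     # True
print(schema_utils.is_single_boolean_property(props["enabled"]))  # True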
mkgui/base/ui/streamlit_ui.py ADDED
@@ -0,0 +1,888 @@
1
+ import datetime
2
+ import inspect
3
+ import mimetypes
4
+ import sys
5
+ from os import getcwd, unlink
6
+ from platform import system
7
+ from tempfile import NamedTemporaryFile
8
+ from typing import Any, Callable, Dict, List, Type
9
+ from PIL import Image
10
+
11
+ import pandas as pd
12
+ import streamlit as st
13
+ from fastapi.encoders import jsonable_encoder
14
+ from loguru import logger
15
+ from pydantic import BaseModel, ValidationError, parse_obj_as
16
+
17
+ from mkgui.base import Opyrator
18
+ from mkgui.base.core import name_to_title
19
+ from mkgui.base.ui import schema_utils
20
+ from mkgui.base.ui.streamlit_utils import CUSTOM_STREAMLIT_CSS
21
+
22
+ STREAMLIT_RUNNER_SNIPPET = """
23
+ from mkgui.base.ui import render_streamlit_ui
24
+ from mkgui.base import Opyrator
25
+
26
+ import streamlit as st
27
+
28
+ # TODO: Make it configurable
29
+ # Page config can only be setup once
30
+ st.set_page_config(
31
+ page_title="MockingBird",
32
+ page_icon="🧊",
33
+ layout="wide")
34
+
35
+ render_streamlit_ui()
36
+ """
37
+
38
+ # with st.spinner("Loading MockingBird GUI. Please wait..."):
39
+ # opyrator = Opyrator("{opyrator_path}")
40
+
41
+
42
+ def launch_ui(port: int = 8501) -> None:
43
+ with NamedTemporaryFile(
44
+ suffix=".py", mode="w", encoding="utf-8", delete=False
45
+ ) as f:
46
+ f.write(STREAMLIT_RUNNER_SNIPPET)
47
+ f.seek(0)
48
+
49
+ import subprocess
50
+
51
+ python_path = f'PYTHONPATH="$PYTHONPATH:{getcwd()}"'
52
+ if system() == "Windows":
53
+ python_path = f"set PYTHONPATH=%PYTHONPATH%;{getcwd()} &&"
54
+ subprocess.run(
55
+ f"""set STREAMLIT_GLOBAL_SHOW_WARNING_ON_DIRECT_EXECUTION=false""",
56
+ shell=True,
57
+ )
58
+
59
+ subprocess.run(
60
+ f"""{python_path} "{sys.executable}" -m streamlit run --server.port={port} --server.headless=True --runner.magicEnabled=False --server.maxUploadSize=50 --browser.gatherUsageStats=False {f.name}""",
61
+ shell=True,
62
+ )
63
+
64
+ f.close()
65
+ unlink(f.name)
66
+
67
+
68
+ def function_has_named_arg(func: Callable, parameter: str) -> bool:
69
+ try:
70
+ sig = inspect.signature(func)
71
+ for param in sig.parameters.values():
72
+ if param.name == "input":
73
+ if param.name == parameter:
74
+ except Exception:
75
+ return False
76
+ return False
77
+
78
+
79
+ def has_output_ui_renderer(data_item: BaseModel) -> bool:
80
+ return hasattr(data_item, "render_output_ui")
81
+
82
+
83
+ def has_input_ui_renderer(input_class: Type[BaseModel]) -> bool:
84
+ return hasattr(input_class, "render_input_ui")
85
+
86
+
87
+ def is_compatible_audio(mime_type: str) -> bool:
88
+ return mime_type in ["audio/mpeg", "audio/ogg", "audio/wav"]
89
+
90
+
91
+ def is_compatible_image(mime_type: str) -> bool:
92
+ return mime_type in ["image/png", "image/jpeg"]
93
+
94
+
95
+ def is_compatible_video(mime_type: str) -> bool:
96
+ return mime_type in ["video/mp4"]
97
+
98
+
99
+ class InputUI:
100
+ def __init__(self, session_state, input_class: Type[BaseModel]):
101
+ self._session_state = session_state
102
+ self._input_class = input_class
103
+
104
+ self._schema_properties = input_class.schema(by_alias=True).get(
105
+ "properties", {}
106
+ )
107
+ self._schema_references = input_class.schema(by_alias=True).get(
108
+ "definitions", {}
109
+ )
110
+
111
+ def render_ui(self, streamlit_app_root) -> None:
112
+ if has_input_ui_renderer(self._input_class):
113
+ # The input model has a rendering function
114
+ # The rendering also returns the current state of input data
115
+ self._session_state.input_data = self._input_class.render_input_ui( # type: ignore
116
+ st, self._session_state.input_data
117
+ )
118
+ return
119
+
120
+ # print(self._schema_properties)
121
+ for property_key in self._schema_properties.keys():
122
+ property = self._schema_properties[property_key]
123
+
124
+ if not property.get("title"):
125
+ # Set property key as fallback title
126
+ property["title"] = name_to_title(property_key)
127
+
128
+ try:
129
+ if "input_data" in self._session_state:
130
+ self._store_value(
131
+ property_key,
132
+ self._render_property(streamlit_app_root, property_key, property),
133
+ )
134
+ except Exception as e:
135
+ print("Exception!", e)
136
+ pass
137
+
138
+ def _get_default_streamlit_input_kwargs(self, key: str, property: Dict) -> Dict:
139
+ streamlit_kwargs = {
140
+ "label": property.get("title"),
141
+ "key": key,
142
+ }
143
+
144
+ if property.get("description"):
145
+ streamlit_kwargs["help"] = property.get("description")
146
+ return streamlit_kwargs
147
+
148
+ def _store_value(self, key: str, value: Any) -> None:
149
+ data_element = self._session_state.input_data
150
+ key_elements = key.split(".")
151
+ for i, key_element in enumerate(key_elements):
152
+ if i == len(key_elements) - 1:
153
+ # add value to this element
154
+ data_element[key_element] = value
155
+ return
156
+ if key_element not in data_element:
157
+ data_element[key_element] = {}
158
+ data_element = data_element[key_element]
159
+
160
+ def _get_value(self, key: str) -> Any:
161
+ data_element = self._session_state.input_data
162
+ key_elements = key.split(".")
163
+ for i, key_element in enumerate(key_elements):
164
+ if i == len(key_elements) - 1:
165
+ # add value to this element
166
+ if key_element not in data_element:
167
+ return None
168
+ return data_element[key_element]
169
+ if key_element not in data_element:
170
+ data_element[key_element] = {}
171
+ data_element = data_element[key_element]
172
+ return None
173
+
174
+ def _render_single_datetime_input(
175
+ self, streamlit_app: st, key: str, property: Dict
176
+ ) -> Any:
177
+ streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
178
+
179
+ if property.get("format") == "time":
180
+ if property.get("default"):
181
+ try:
182
+ streamlit_kwargs["value"] = datetime.time.fromisoformat( # type: ignore
183
+ property.get("default")
184
+ )
185
+ except Exception:
186
+ pass
187
+ return streamlit_app.time_input(**streamlit_kwargs)
188
+ elif property.get("format") == "date":
189
+ if property.get("default"):
190
+ try:
191
+ streamlit_kwargs["value"] = datetime.date.fromisoformat( # type: ignore
192
+ property.get("default")
193
+ )
194
+ except Exception:
195
+ pass
196
+ return streamlit_app.date_input(**streamlit_kwargs)
197
+ elif property.get("format") == "date-time":
198
+ if property.get("default"):
199
+ try:
200
+ streamlit_kwargs["value"] = datetime.datetime.fromisoformat( # type: ignore
201
+ property.get("default")
202
+ )
203
+ except Exception:
204
+ pass
205
+ with streamlit_app.container():
206
+ streamlit_app.subheader(streamlit_kwargs.get("label"))
207
+ if streamlit_kwargs.get("help"):
208
+ streamlit_app.text(streamlit_kwargs.get("help"))
209
+ selected_date = None
210
+ selected_time = None
211
+ date_col, time_col = streamlit_app.columns(2)
212
+ with date_col:
213
+ date_kwargs = {"label": "Date", "key": key + "-date-input"}
214
+ if streamlit_kwargs.get("value"):
215
+ try:
216
+ date_kwargs["value"] = streamlit_kwargs.get( # type: ignore
217
+ "value"
218
+ ).date()
219
+ except Exception:
220
+ pass
221
+ selected_date = streamlit_app.date_input(**date_kwargs)
222
+
223
+ with time_col:
224
+ time_kwargs = {"label": "Time", "key": key + "-time-input"}
225
+ if streamlit_kwargs.get("value"):
226
+ try:
227
+ time_kwargs["value"] = streamlit_kwargs.get( # type: ignore
228
+ "value"
229
+ ).time()
230
+ except Exception:
231
+ pass
232
+ selected_time = streamlit_app.time_input(**time_kwargs)
233
+ return datetime.datetime.combine(selected_date, selected_time)
234
+ else:
235
+ streamlit_app.warning(
236
+ "Date format is not supported: " + str(property.get("format"))
237
+ )
238
+
239
+ def _render_single_file_input(
240
+ self, streamlit_app: st, key: str, property: Dict
241
+ ) -> Any:
242
+ streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
243
+ file_extension = None
244
+ if "mime_type" in property:
245
+ file_extension = mimetypes.guess_extension(property["mime_type"])
246
+
247
+ uploaded_file = streamlit_app.file_uploader(
248
+ **streamlit_kwargs, accept_multiple_files=False, type=file_extension
249
+ )
250
+ if uploaded_file is None:
251
+ return None
252
+
253
+ bytes = uploaded_file.getvalue()
254
+ if property.get("mime_type"):
255
+ if is_compatible_audio(property["mime_type"]):
256
+ # Show audio
257
+ streamlit_app.audio(bytes, format=property.get("mime_type"))
258
+ if is_compatible_image(property["mime_type"]):
259
+ # Show image
260
+ streamlit_app.image(bytes)
261
+ if is_compatible_video(property["mime_type"]):
262
+ # Show video
263
+ streamlit_app.video(bytes, format=property.get("mime_type"))
264
+ return bytes
265
+
266
+ def _render_single_string_input(
267
+ self, streamlit_app: st, key: str, property: Dict
268
+ ) -> Any:
269
+ streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
270
+
271
+ if property.get("default"):
272
+ streamlit_kwargs["value"] = property.get("default")
273
+ elif property.get("example"):
274
+ # TODO: also use example for other property types
275
+ # Use example as value if it is provided
276
+ streamlit_kwargs["value"] = property.get("example")
277
+
278
+ if property.get("maxLength") is not None:
279
+ streamlit_kwargs["max_chars"] = property.get("maxLength")
280
+
281
+ if (
282
+ property.get("format")
283
+ or (
284
+ property.get("maxLength") is not None
285
+ and int(property.get("maxLength")) < 140 # type: ignore
286
+ )
287
+ or property.get("writeOnly")
288
+ ):
289
+ # If any format is set, use single text input
290
+ # If max chars is set to less than 140, use single text input
291
+ # If write only -> password field
292
+ if property.get("writeOnly"):
293
+ streamlit_kwargs["type"] = "password"
294
+ return streamlit_app.text_input(**streamlit_kwargs)
295
+ else:
296
+ # Otherwise use multiline text area
297
+ return streamlit_app.text_area(**streamlit_kwargs)
298
+
299
+ def _render_multi_enum_input(
300
+ self, streamlit_app: st, key: str, property: Dict
301
+ ) -> Any:
302
+ streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
303
+ reference_item = schema_utils.resolve_reference(
304
+ property["items"]["$ref"], self._schema_references
305
+ )
306
+ # TODO: how to select defaults
307
+ return streamlit_app.multiselect(
308
+ **streamlit_kwargs, options=reference_item["enum"]
309
+ )
310
+
311
+ def _render_single_enum_input(
312
+ self, streamlit_app: st, key: str, property: Dict
313
+ ) -> Any:
314
+
315
+ streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
316
+ reference_item = schema_utils.get_single_reference_item(
317
+ property, self._schema_references
318
+ )
319
+
320
+ if property.get("default") is not None:
321
+ try:
322
+ streamlit_kwargs["index"] = reference_item["enum"].index(
323
+ property.get("default")
324
+ )
325
+ except Exception:
326
+ # Use default selection
327
+ pass
328
+
329
+ return streamlit_app.selectbox(
330
+ **streamlit_kwargs, options=reference_item["enum"]
331
+ )
332
+
333
+ def _render_single_dict_input(
334
+ self, streamlit_app: st, key: str, property: Dict
335
+ ) -> Any:
336
+
337
+ # Add title and subheader
338
+ streamlit_app.subheader(property.get("title"))
339
+ if property.get("description"):
340
+ streamlit_app.markdown(property.get("description"))
341
+
342
+ streamlit_app.markdown("---")
343
+
344
+ current_dict = self._get_value(key)
345
+ if not current_dict:
346
+ current_dict = {}
347
+
348
+ key_col, value_col = streamlit_app.columns(2)
349
+
350
+ with key_col:
351
+ updated_key = streamlit_app.text_input(
352
+ "Key", value="", key=key + "-new-key"
353
+ )
354
+
355
+ with value_col:
356
+ # TODO: also add boolean?
357
+ value_kwargs = {"label": "Value", "key": key + "-new-value"}
358
+ if property["additionalProperties"].get("type") == "integer":
359
+ value_kwargs["value"] = 0 # type: ignore
360
+ updated_value = streamlit_app.number_input(**value_kwargs)
361
+ elif property["additionalProperties"].get("type") == "number":
362
+ value_kwargs["value"] = 0.0 # type: ignore
363
+ value_kwargs["format"] = "%f"
364
+ updated_value = streamlit_app.number_input(**value_kwargs)
365
+ else:
366
+ value_kwargs["value"] = ""
367
+ updated_value = streamlit_app.text_input(**value_kwargs)
368
+
369
+ streamlit_app.markdown("---")
370
+
371
+ with streamlit_app.container():
372
+ clear_col, add_col = streamlit_app.columns([1, 2])
373
+
374
+ with clear_col:
375
+ if streamlit_app.button("Clear Items", key=key + "-clear-items"):
376
+ current_dict = {}
377
+
378
+ with add_col:
379
+ if (
380
+ streamlit_app.button("Add Item", key=key + "-add-item")
381
+ and updated_key
382
+ ):
383
+ current_dict[updated_key] = updated_value
384
+
385
+ streamlit_app.write(current_dict)
386
+
387
+ return current_dict
388
+
389
+ def _render_single_reference(
390
+ self, streamlit_app: st, key: str, property: Dict
391
+ ) -> Any:
392
+ reference_item = schema_utils.get_single_reference_item(
393
+ property, self._schema_references
394
+ )
395
+ return self._render_property(streamlit_app, key, reference_item)
396
+
397
+ def _render_multi_file_input(
398
+ self, streamlit_app: st, key: str, property: Dict
399
+ ) -> Any:
400
+ streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
401
+
402
+ file_extension = None
403
+ if "mime_type" in property:
404
+ file_extension = mimetypes.guess_extension(property["mime_type"])
405
+
406
+ uploaded_files = streamlit_app.file_uploader(
407
+ **streamlit_kwargs, accept_multiple_files=True, type=file_extension
408
+ )
409
+ uploaded_files_bytes = []
410
+ if uploaded_files:
411
+ for uploaded_file in uploaded_files:
412
+ uploaded_files_bytes.append(uploaded_file.read())
413
+ return uploaded_files_bytes
414
+
415
+ def _render_single_boolean_input(
416
+ self, streamlit_app: st, key: str, property: Dict
417
+ ) -> Any:
418
+ streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
419
+
420
+ if property.get("default"):
421
+ streamlit_kwargs["value"] = property.get("default")
422
+ return streamlit_app.checkbox(**streamlit_kwargs)
423
+
424
+ def _render_single_number_input(
425
+ self, streamlit_app: st, key: str, property: Dict
426
+ ) -> Any:
427
+ streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
428
+
429
+ number_transform = int
430
+ if property.get("type") == "number":
431
+ number_transform = float # type: ignore
432
+ streamlit_kwargs["format"] = "%f"
433
+
434
+ if "multipleOf" in property:
435
+ # Set stepcount based on multiple of parameter
436
+ streamlit_kwargs["step"] = number_transform(property["multipleOf"])
437
+ elif number_transform == int:
438
+ # Set step size to 1 as default
439
+ streamlit_kwargs["step"] = 1
440
+ elif number_transform == float:
441
+ # Set step size to 0.01 as default
442
+ # TODO: adapt to default value
443
+ streamlit_kwargs["step"] = 0.01
444
+
445
+ if "minimum" in property:
446
+ streamlit_kwargs["min_value"] = number_transform(property["minimum"])
447
+ if "exclusiveMinimum" in property:
448
+ streamlit_kwargs["min_value"] = number_transform(
449
+ property["exclusiveMinimum"] + streamlit_kwargs["step"]
450
+ )
451
+ if "maximum" in property:
452
+ streamlit_kwargs["max_value"] = number_transform(property["maximum"])
453
+
454
+ if "exclusiveMaximum" in property:
455
+ streamlit_kwargs["max_value"] = number_transform(
456
+ property["exclusiveMaximum"] - streamlit_kwargs["step"]
457
+ )
458
+
459
+ if property.get("default") is not None:
460
+ streamlit_kwargs["value"] = number_transform(property.get("default")) # type: ignore
461
+ else:
462
+ if "min_value" in streamlit_kwargs:
463
+ streamlit_kwargs["value"] = streamlit_kwargs["min_value"]
464
+ elif number_transform == int:
465
+ streamlit_kwargs["value"] = 0
466
+ else:
467
+ # Set default value to step
468
+ streamlit_kwargs["value"] = number_transform(streamlit_kwargs["step"])
469
+
470
+ if "min_value" in streamlit_kwargs and "max_value" in streamlit_kwargs:
471
+ # TODO: Only if less than X steps
472
+ return streamlit_app.slider(**streamlit_kwargs)
473
+ else:
474
+ return streamlit_app.number_input(**streamlit_kwargs)
475
+
476
+ def _render_object_input(self, streamlit_app: st, key: str, property: Dict) -> Any:
477
+ properties = property["properties"]
478
+ object_inputs = {}
479
+ for property_key in properties:
480
+ property = properties[property_key]
481
+ if not property.get("title"):
482
+ # Set property key as fallback title
483
+ property["title"] = name_to_title(property_key)
484
+ # construct full key based on key parts -> required later to get the value
485
+ full_key = key + "." + property_key
486
+ object_inputs[property_key] = self._render_property(
487
+ streamlit_app, full_key, property
488
+ )
489
+ return object_inputs
490
+
491
+ def _render_single_object_input(
492
+ self, streamlit_app: st, key: str, property: Dict
493
+ ) -> Any:
494
+ # Add title and subheader
495
+ title = property.get("title")
496
+ streamlit_app.subheader(title)
497
+ if property.get("description"):
498
+ streamlit_app.markdown(property.get("description"))
499
+
500
+ object_reference = schema_utils.get_single_reference_item(
501
+ property, self._schema_references
502
+ )
503
+ return self._render_object_input(streamlit_app, key, object_reference)
504
+
505
+ def _render_property_list_input(
506
+ self, streamlit_app: st, key: str, property: Dict
507
+ ) -> Any:
508
+
509
+ # Add title and subheader
510
+ streamlit_app.subheader(property.get("title"))
511
+ if property.get("description"):
512
+ streamlit_app.markdown(property.get("description"))
513
+
514
+ streamlit_app.markdown("---")
515
+
516
+ current_list = self._get_value(key)
517
+ if not current_list:
518
+ current_list = []
519
+
520
+ value_kwargs = {"label": "Value", "key": key + "-new-value"}
521
+ if property["items"]["type"] == "integer":
522
+ value_kwargs["value"] = 0 # type: ignore
523
+ new_value = streamlit_app.number_input(**value_kwargs)
524
+ elif property["items"]["type"] == "number":
525
+ value_kwargs["value"] = 0.0 # type: ignore
526
+ value_kwargs["format"] = "%f"
527
+ new_value = streamlit_app.number_input(**value_kwargs)
528
+ else:
529
+ value_kwargs["value"] = ""
530
+ new_value = streamlit_app.text_input(**value_kwargs)
531
+
532
+ streamlit_app.markdown("---")
533
+
534
+ with streamlit_app.container():
535
+ clear_col, add_col = streamlit_app.columns([1, 2])
536
+
537
+ with clear_col:
538
+ if streamlit_app.button("Clear Items", key=key + "-clear-items"):
539
+ current_list = []
540
+
541
+ with add_col:
542
+ if (
543
+ streamlit_app.button("Add Item", key=key + "-add-item")
544
+ and new_value is not None
545
+ ):
546
+ current_list.append(new_value)
547
+
548
+ streamlit_app.write(current_list)
549
+
550
+ return current_list
551
+
552
+ def _render_object_list_input(
553
+ self, streamlit_app: st, key: str, property: Dict
554
+ ) -> Any:
555
+
556
+ # TODO: support max_items, and min_items properties
557
+
558
+ # Add title and subheader
559
+ streamlit_app.subheader(property.get("title"))
560
+ if property.get("description"):
561
+ streamlit_app.markdown(property.get("description"))
562
+
563
+ streamlit_app.markdown("---")
564
+
565
+ current_list = self._get_value(key)
566
+ if not current_list:
567
+ current_list = []
568
+
569
+ object_reference = schema_utils.resolve_reference(
570
+ property["items"]["$ref"], self._schema_references
571
+ )
572
+ input_data = self._render_object_input(streamlit_app, key, object_reference)
573
+
574
+ streamlit_app.markdown("---")
575
+
576
+ with streamlit_app.container():
577
+ clear_col, add_col = streamlit_app.columns([1, 2])
578
+
579
+ with clear_col:
580
+ if streamlit_app.button("Clear Items", key=key + "-clear-items"):
581
+ current_list = []
582
+
583
+ with add_col:
584
+ if (
585
+ streamlit_app.button("Add Item", key=key + "-add-item")
586
+ and input_data
587
+ ):
588
+ current_list.append(input_data)
589
+
590
+ streamlit_app.write(current_list)
591
+ return current_list
592
+
593
+ def _render_property(self, streamlit_app: st, key: str, property: Dict) -> Any:
594
+ if schema_utils.is_single_enum_property(property, self._schema_references):
595
+ return self._render_single_enum_input(streamlit_app, key, property)
596
+
597
+ if schema_utils.is_multi_enum_property(property, self._schema_references):
598
+ return self._render_multi_enum_input(streamlit_app, key, property)
599
+
600
+ if schema_utils.is_single_file_property(property):
601
+ return self._render_single_file_input(streamlit_app, key, property)
602
+
603
+ if schema_utils.is_multi_file_property(property):
604
+ return self._render_multi_file_input(streamlit_app, key, property)
605
+
606
+ if schema_utils.is_single_datetime_property(property):
607
+ return self._render_single_datetime_input(streamlit_app, key, property)
608
+
609
+ if schema_utils.is_single_boolean_property(property):
610
+ return self._render_single_boolean_input(streamlit_app, key, property)
611
+
612
+ if schema_utils.is_single_dict_property(property):
613
+ return self._render_single_dict_input(streamlit_app, key, property)
614
+
615
+ if schema_utils.is_single_number_property(property):
616
+ return self._render_single_number_input(streamlit_app, key, property)
617
+
618
+ if schema_utils.is_single_string_property(property):
619
+ return self._render_single_string_input(streamlit_app, key, property)
620
+
621
+ if schema_utils.is_single_object(property, self._schema_references):
622
+ return self._render_single_object_input(streamlit_app, key, property)
623
+
624
+ if schema_utils.is_object_list_property(property, self._schema_references):
625
+ return self._render_object_list_input(streamlit_app, key, property)
626
+
627
+ if schema_utils.is_property_list(property):
628
+ return self._render_property_list_input(streamlit_app, key, property)
629
+
630
+ if schema_utils.is_single_reference(property):
631
+ return self._render_single_reference(streamlit_app, key, property)
632
+
633
+ streamlit_app.warning(
634
+ "The type of the following property is currently not supported: "
635
+ + str(property.get("title"))
636
+ )
637
+ raise Exception("Unsupported property")
638
+
639
+
640
+ class OutputUI:
641
+ def __init__(self, output_data: Any, input_data: Any):
642
+ self._output_data = output_data
643
+ self._input_data = input_data
644
+
645
+ def render_ui(self, streamlit_app) -> None:
646
+ try:
647
+ if isinstance(self._output_data, BaseModel):
648
+ self._render_single_output(streamlit_app, self._output_data)
649
+ return
650
+ if type(self._output_data) == list:
651
+ self._render_list_output(streamlit_app, self._output_data)
652
+ return
653
+ except Exception as ex:
654
+ streamlit_app.exception(ex)
655
+ # Fallback to JSON output
656
+ streamlit_app.json(jsonable_encoder(self._output_data))
657
+
658
+ def _render_single_text_property(
659
+ self, streamlit: st, property_schema: Dict, value: Any
660
+ ) -> None:
661
+ # Add title and subheader
662
+ streamlit.subheader(property_schema.get("title"))
663
+ if property_schema.get("description"):
664
+ streamlit.markdown(property_schema.get("description"))
665
+ if value is None or value == "":
666
+ streamlit.info("No value returned!")
667
+ else:
668
+ streamlit.code(str(value), language="plain")
669
+
670
+ def _render_single_file_property(
671
+ self, streamlit: st, property_schema: Dict, value: Any
672
+ ) -> None:
673
+ # Add title and subheader
674
+ streamlit.subheader(property_schema.get("title"))
675
+ if property_schema.get("description"):
676
+ streamlit.markdown(property_schema.get("description"))
677
+ if value is None or value == "":
678
+ streamlit.info("No value returned!")
679
+ else:
680
+ # TODO: Detect if it is a FileContent instance
681
+ # TODO: detect if it is base64
682
+ file_extension = ""
683
+ if "mime_type" in property_schema:
684
+ mime_type = property_schema["mime_type"]
685
+ file_extension = mimetypes.guess_extension(mime_type) or ""
686
+
687
+ if is_compatible_audio(mime_type):
688
+ streamlit.audio(value.as_bytes(), format=mime_type)
689
+ return
690
+
691
+ if is_compatible_image(mime_type):
692
+ streamlit.image(value.as_bytes())
693
+ return
694
+
695
+ if is_compatible_video(mime_type):
696
+ streamlit.video(value.as_bytes(), format=mime_type)
697
+ return
698
+
699
+ filename = (
700
+ (property_schema["title"] + file_extension)
701
+ .lower()
702
+ .strip()
703
+ .replace(" ", "-")
704
+ )
705
+ streamlit.markdown(
706
+ f'<a href="data:application/octet-stream;base64,{value}" download="{filename}"><input type="button" value="Download File"></a>',
707
+ unsafe_allow_html=True,
708
+ )
709
+
710
+ def _render_single_complex_property(
711
+ self, streamlit: st, property_schema: Dict, value: Any
712
+ ) -> None:
713
+ # Add title and subheader
714
+ streamlit.subheader(property_schema.get("title"))
715
+ if property_schema.get("description"):
716
+ streamlit.markdown(property_schema.get("description"))
717
+
718
+ streamlit.json(jsonable_encoder(value))
719
+
720
+ def _render_single_output(self, streamlit: st, output_data: BaseModel) -> None:
721
+ try:
722
+ if has_output_ui_renderer(output_data):
723
+ if function_has_named_arg(output_data.render_output_ui, "input"): # type: ignore
724
+ # render method also requests the input data
725
+ output_data.render_output_ui(streamlit, input=self._input_data) # type: ignore
726
+ else:
727
+ output_data.render_output_ui(streamlit) # type: ignore
728
+ return
729
+ except Exception:
730
+ # Use default auto-generation methods if the custom rendering throws an exception
731
+ logger.exception(
732
+ "Failed to execute custom render_output_ui function. Using auto-generation instead"
733
+ )
734
+
735
+ model_schema = output_data.schema(by_alias=False)
736
+ model_properties = model_schema.get("properties")
737
+ definitions = model_schema.get("definitions")
738
+
739
+ if model_properties:
740
+ for property_key in output_data.__dict__:
741
+ property_schema = model_properties.get(property_key)
742
+ if not property_schema.get("title"):
743
+ # Set property key as fallback title
744
+ property_schema["title"] = property_key
745
+
746
+ output_property_value = output_data.__dict__[property_key]
747
+
748
+ if has_output_ui_renderer(output_property_value):
749
+ output_property_value.render_output_ui(streamlit) # type: ignore
750
+ continue
751
+
752
+ if isinstance(output_property_value, BaseModel):
753
+ # Render output recursively
754
+ streamlit.subheader(property_schema.get("title"))
755
+ if property_schema.get("description"):
756
+ streamlit.markdown(property_schema.get("description"))
757
+ self._render_single_output(streamlit, output_property_value)
758
+ continue
759
+
760
+ if property_schema:
761
+ if schema_utils.is_single_file_property(property_schema):
762
+ self._render_single_file_property(
763
+ streamlit, property_schema, output_property_value
764
+ )
765
+ continue
766
+
767
+ if (
768
+ schema_utils.is_single_string_property(property_schema)
769
+ or schema_utils.is_single_number_property(property_schema)
770
+ or schema_utils.is_single_datetime_property(property_schema)
771
+ or schema_utils.is_single_boolean_property(property_schema)
772
+ ):
773
+ self._render_single_text_property(
774
+ streamlit, property_schema, output_property_value
775
+ )
776
+ continue
777
+ if definitions and schema_utils.is_single_enum_property(
778
+ property_schema, definitions
779
+ ):
780
+ self._render_single_text_property(
781
+ streamlit, property_schema, output_property_value.value
782
+ )
783
+ continue
784
+
785
+ # TODO: render dict as table
786
+
787
+ self._render_single_complex_property(
788
+ streamlit, property_schema, output_property_value
789
+ )
790
+ return
791
+
792
+ def _render_list_output(self, streamlit: st, output_data: List) -> None:
793
+ try:
794
+ data_items: List = []
795
+ for data_item in output_data:
796
+ if has_output_ui_renderer(data_item):
797
+ # Render using the render function
798
+ data_item.render_output_ui(streamlit) # type: ignore
799
+ continue
800
+ data_items.append(data_item.dict())
801
+ # Try to show as dataframe
802
+ streamlit.table(pd.DataFrame(data_items))
803
+ except Exception:
804
+ # Fallback to JSON output
805
+ streamlit.json(jsonable_encoder(output_data))
806
+
807
+
808
+ def getOpyrator(mode: str) -> Opyrator:
809
+ if mode is None or mode.startswith('VC'):
810
+ from mkgui.app_vc import convert
811
+ return Opyrator(convert)
812
+ if mode is None or mode.startswith('预处理'):
813
+ from mkgui.preprocess import preprocess
814
+ return Opyrator(preprocess)
815
+ if mode is None or mode.startswith('模型训练(VC)'):
816
+ from mkgui.train_vc import train_vc
817
+ return Opyrator(train_vc)
818
+ if mode is None or mode.startswith('模型训练'):
819
+ from mkgui.train import train
820
+ return Opyrator(train)
821
+ from mkgui.app import synthesize
822
+ return Opyrator(synthesize)
823
+
824
+
825
+ def render_streamlit_ui() -> None:
826
+ # init
827
+ session_state = st.session_state
828
+ session_state.input_data = {}
829
+ # Add custom css settings
830
+ st.markdown(f"<style>{CUSTOM_STREAMLIT_CSS}</style>", unsafe_allow_html=True)
831
+
832
+ with st.spinner("Loading MockingBird GUI. Please wait..."):
833
+ session_state.mode = st.sidebar.selectbox(
834
+ '模式选择',
835
+ ( "AI拟音", "VC拟音", "预处理", "模型训练", "模型训练(VC)")
836
+ )
837
+ if "mode" in session_state:
838
+ mode = session_state.mode
839
+ else:
840
+ mode = ""
841
+ opyrator = getOpyrator(mode)
842
+ title = opyrator.name + mode
843
+
844
+ col1, col2, _ = st.columns(3)
845
+ col2.title(title)
846
+ col2.markdown("欢迎使用MockingBird Web 2")
847
+
848
+ image = Image.open("mkgui/static/mb.png")
849
+ col1.image(image)
850
+
851
+ st.markdown("---")
852
+ left, right = st.columns([0.4, 0.6])
853
+
854
+ with left:
855
+ st.header("Control 控制")
856
+ InputUI(session_state=session_state, input_class=opyrator.input_type).render_ui(st)
857
+ execute_selected = st.button(opyrator.action)
858
+ if execute_selected:
859
+ with st.spinner("Executing operation. Please wait..."):
860
+ try:
861
+ input_data_obj = parse_obj_as(
862
+ opyrator.input_type, session_state.input_data
863
+ )
864
+ session_state.output_data = opyrator(input=input_data_obj)
865
+ session_state.latest_operation_input = input_data_obj # should this really be saved as additional session object?
866
+ except ValidationError as ex:
867
+ st.error(ex)
868
+ else:
869
+ # st.success("Operation executed successfully.")
870
+ pass
871
+
872
+ with right:
873
+ st.header("Result 结果")
874
+ if 'output_data' in session_state:
875
+ OutputUI(
876
+ session_state.output_data, session_state.latest_operation_input
877
+ ).render_ui(st)
878
+ if st.button("Clear"):
879
+ # Clear all state
880
+ for key in st.session_state.keys():
881
+ del st.session_state[key]
882
+ session_state.input_data = {}
883
+ st.experimental_rerun()
884
+ else:
885
+ # placeholder
886
+ st.caption("请使用左侧控制板进行输入并运行获得结果")
887
+
888
+
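The nested-input handling above relies on dotted keys ("a.b.c") that _store_value and _get_value expand into nested dicts; a minimal standalone sketch of that convention (independent of Streamlit):

# Standalone illustration of the dotted-key convention used by InputUI above.
input_data = {}

def store_value(data, key, value):
    parts = key.split(".")
    node = data
    for part in parts[:-1]:
        node = node.setdefault(part, {})
    node[parts[-1]] = value

store_value(input_data, "speaker.name", "demo")
store_value(input_data, "speaker.embedding.dim", 256)
print(input_data)  # {'speaker': {'name': 'demo', 'embedding': {'dim': 256}}}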
mkgui/base/ui/streamlit_utils.py ADDED
@@ -0,0 +1,13 @@
1
+ CUSTOM_STREAMLIT_CSS = """
2
+ div[data-testid="stBlock"] button {
3
+ width: 100% !important;
4
+ margin-bottom: 20px !important;
5
+ border-color: #bfbfbf !important;
6
+ }
7
+ section[data-testid="stSidebar"] div {
8
+ max-width: 10rem;
9
+ }
10
+ pre code {
11
+ white-space: pre-wrap;
12
+ }
13
+ """
mkgui/preprocess.py ADDED
@@ -0,0 +1,96 @@
1
+ from pydantic import BaseModel, Field
2
+ import os
3
+ from pathlib import Path
4
+ from enum import Enum
5
+ from typing import Any, Tuple
6
+
7
+
8
+ # Constants
9
+ EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models"
10
+ ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"
11
+
12
+
13
+ if os.path.isdir(EXT_MODELS_DIRT):
14
+ extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
15
+ print("Loaded extractor models: " + str(len(extractors)))
16
+ else:
17
+ raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
18
+
19
+ if os.path.isdir(ENC_MODELS_DIRT):
20
+ encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
21
+ print("Loaded encoders models: " + str(len(encoders)))
22
+ else:
23
+ raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
24
+
25
+ class Model(str, Enum):
26
+ VC_PPG2MEL = "ppg2mel"
27
+
28
+ class Dataset(str, Enum):
29
+ AIDATATANG_200ZH = "aidatatang_200zh"
30
+ AIDATATANG_200ZH_S = "aidatatang_200zh_s"
31
+
32
+ class Input(BaseModel):
33
+ # def render_input_ui(st, input) -> Dict:
34
+ # input["selected_dataset"] = st.selectbox(
35
+ # '选择数据集',
36
+ # ("aidatatang_200zh", "aidatatang_200zh_s")
37
+ # )
38
+ # return input
39
+ model: Model = Field(
40
+ Model.VC_PPG2MEL, title="目标模型",
41
+ )
42
+ dataset: Dataset = Field(
43
+ Dataset.AIDATATANG_200ZH, title="数据集选择",
44
+ )
45
+ datasets_root: str = Field(
46
+ ..., alias="数据集根目录", description="输入数据集根目录(相对/绝对)",
47
+ format=True,
48
+ example="..\\trainning_data\\"
49
+ )
50
+ output_root: str = Field(
51
+ ..., alias="输出根目录", description="输出结果根目录(相对/绝对)",
52
+ format=True,
53
+ example="..\\trainning_data\\"
54
+ )
55
+ n_processes: int = Field(
56
+ 2, alias="处理线程数", description="根据CPU线程数来设置",
57
+ le=32, ge=1
58
+ )
59
+ extractor: extractors = Field(
60
+ ..., alias="特征提取模型",
61
+ description="选择PPG特征提取模型文件."
62
+ )
63
+ encoder: encoders = Field(
64
+ ..., alias="语音编码模型",
65
+ description="选择语音编码模型文件."
66
+ )
67
+
68
+ class AudioEntity(BaseModel):
69
+ content: bytes
70
+ mel: Any
71
+
72
+ class Output(BaseModel):
73
+ __root__: Tuple[str, int]
74
+
75
+ def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
76
+ """Custom output UI.
77
+ If this method is implmeneted, it will be used instead of the default Output UI renderer.
78
+ """
79
+ sr, count = self.__root__
80
+ streamlit_app.subheader(f"Dataset {sr} done processed total of {count}")
81
+
82
+ def preprocess(input: Input) -> Output:
83
+ """Preprocess(预处理)"""
84
+ finished = 0
85
+ if input.model == Model.VC_PPG2MEL:
86
+ from ppg2mel.preprocess import preprocess_dataset
87
+ finished = preprocess_dataset(
88
+ datasets_root=Path(input.datasets_root),
89
+ dataset=input.dataset,
90
+ out_dir=Path(input.output_root),
91
+ n_processes=input.n_processes,
92
+ ppg_encoder_model_fpath=Path(input.extractor.value),
93
+ speaker_encoder_model=Path(input.encoder.value)
94
+ )
95
+ # TODO: pass useful return code
96
+ return Output(__root__=(input.dataset, finished))
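The Enum('extractors', ...) construction near the top of this file turns every checkpoint found on disk into a selectable enum member; a standalone sketch of the same pattern with hypothetical file names:

# Standalone sketch of the dynamic-Enum pattern used above (file names are hypothetical).
from enum import Enum
from pathlib import Path

files = [Path("encoder/saved_models/pretrained.pt"), Path("encoder/saved_models/finetuned.pt")]
encoders = Enum("encoders", [(f.name, f) for f in files])

print(len(encoders))                    # 2
print(encoders["pretrained.pt"].value)  # encoder/saved_models/pretrained.pt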
mkgui/static/mb.png ADDED
mkgui/train.py ADDED
@@ -0,0 +1,106 @@
1
+ from pydantic import BaseModel, Field
2
+ import os
3
+ from pathlib import Path
4
+ from enum import Enum
5
+ from typing import Any
6
+ from synthesizer.hparams import hparams
7
+ from synthesizer.train import train as synt_train
8
+
9
+ # Constants
10
+ SYN_MODELS_DIRT = f"synthesizer{os.sep}saved_models"
11
+ ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"
12
+
13
+
14
+ # EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models"
15
+ # CONV_MODELS_DIRT = f"ppg2mel{os.sep}saved_models"
16
+ # ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"
17
+
18
+ # Pre-Load models
19
+ if os.path.isdir(SYN_MODELS_DIRT):
20
+ synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt")))
21
+ print("Loaded synthesizer models: " + str(len(synthesizers)))
22
+ else:
23
+ raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist.")
24
+
25
+ if os.path.isdir(ENC_MODELS_DIRT):
26
+ encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
27
+ print("Loaded encoders models: " + str(len(encoders)))
28
+ else:
29
+ raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
30
+
31
+ class Model(str, Enum):
32
+ DEFAULT = "default"
33
+
34
+ class Input(BaseModel):
35
+ model: Model = Field(
36
+ Model.DEFAULT, title="模型类型",
37
+ )
38
+ # datasets_root: str = Field(
39
+ # ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型",
40
+ # format=True,
41
+ # example="..\\trainning_data\\"
42
+ # )
43
+ input_root: str = Field(
44
+ ..., alias="输入目录", description="预处理数据根目录",
45
+ format=True,
46
+ example=f"..{os.sep}audiodata{os.sep}SV2TTS{os.sep}synthesizer"
47
+ )
48
+ run_id: str = Field(
49
+ "", alias="新模型名/运行ID", description="使用新ID进行重新训练,否则选择下面的模型进行继续训练",
50
+ )
51
+ synthesizer: synthesizers = Field(
52
+ ..., alias="已有合成模型",
53
+ description="选择语音合成模型文件."
54
+ )
55
+ gpu: bool = Field(
56
+ True, alias="GPU训练", description="选择“是”,则使用GPU训练",
57
+ )
58
+ verbose: bool = Field(
59
+ True, alias="打印详情", description="选择“是”,输出更多详情",
60
+ )
61
+ encoder: encoders = Field(
62
+ ..., alias="语音编码模型",
63
+ description="选择语音编码模型文件."
64
+ )
65
+ save_every: int = Field(
66
+ 1000, alias="更新间隔", description="每隔n步则更新一次模型",
67
+ )
68
+ backup_every: int = Field(
69
+ 10000, alias="保存间隔", description="每隔n步则保存一次模型",
70
+ )
71
+ log_every: int = Field(
72
+ 500, alias="打印间隔", description="每隔n步则打印一次训练统计",
73
+ )
74
+
75
+ class AudioEntity(BaseModel):
76
+ content: bytes
77
+ mel: Any
78
+
79
+ class Output(BaseModel):
80
+ __root__: int
81
+
82
+ def render_output_ui(self, streamlit_app) -> None: # type: ignore
83
+ """Custom output UI.
84
+ If this method is implmeneted, it will be used instead of the default Output UI renderer.
85
+ """
86
+ streamlit_app.subheader(f"Training started with code: {self.__root__}")
87
+
88
+ def train(input: Input) -> Output:
89
+ """Train(训练)"""
90
+
91
+ print(">>> Start training ...")
92
+ force_restart = len(input.run_id) > 0
93
+ if not force_restart:
94
+ input.run_id = Path(input.synthesizer.value).name.split('.')[0]
95
+
96
+ synt_train(
97
+ input.run_id,
98
+ input.input_root,
99
+ f"synthesizer{os.sep}saved_models",
100
+ input.save_every,
101
+ input.backup_every,
102
+ input.log_every,
103
+ force_restart,
104
+ hparams
105
+ )
106
+ return Output(__root__=0)
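The run-ID logic in train() above restarts from scratch when a new ID is given and otherwise continues from the selected synthesizer checkpoint; a minimal standalone sketch of that decision (paths are hypothetical):

# Standalone sketch of the run-ID decision used by train() above.
from pathlib import Path

def resolve_run_id(run_id: str, synthesizer_path: str):
    force_restart = len(run_id) > 0
    if not force_restart:
        # Continue training: reuse the checkpoint's base name as the run ID.
        run_id = Path(synthesizer_path).name.split(".")[0]
    return run_id, force_restart

print(resolve_run_id("", "synthesizer/saved_models/pretrained.pt"))            # ('pretrained', False)
print(resolve_run_id("my_new_run", "synthesizer/saved_models/pretrained.pt"))  # ('my_new_run', True)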
mkgui/train_vc.py ADDED
@@ -0,0 +1,155 @@
1
+ from pydantic import BaseModel, Field
2
+ import os
3
+ from pathlib import Path
4
+ from enum import Enum
5
+ from typing import Any, Tuple
6
+ import numpy as np
7
+ from utils.load_yaml import HpsYaml
8
+ from utils.util import AttrDict
9
+ import torch
10
+
11
+ # Constants
12
+ EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models"
13
+ CONV_MODELS_DIRT = f"ppg2mel{os.sep}saved_models"
14
+ ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"
15
+
16
+
17
+ if os.path.isdir(EXT_MODELS_DIRT):
18
+ extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
19
+ print("Loaded extractor models: " + str(len(extractors)))
20
+ else:
21
+ raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
22
+
23
+ if os.path.isdir(CONV_MODELS_DIRT):
24
+ convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
25
+ print("Loaded convertor models: " + str(len(convertors)))
26
+ else:
27
+ raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")
28
+
29
+ if os.path.isdir(ENC_MODELS_DIRT):
30
+ encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
31
+ print("Loaded encoders models: " + str(len(encoders)))
32
+ else:
33
+ raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
34
+
35
+ class Model(str, Enum):
36
+ VC_PPG2MEL = "ppg2mel"
37
+
38
+ class Dataset(str, Enum):
39
+ AIDATATANG_200ZH = "aidatatang_200zh"
40
+ AIDATATANG_200ZH_S = "aidatatang_200zh_s"
41
+
42
+ class Input(BaseModel):
43
+ # def render_input_ui(st, input) -> Dict:
44
+ # input["selected_dataset"] = st.selectbox(
45
+ # '选择数据集',
46
+ # ("aidatatang_200zh", "aidatatang_200zh_s")
47
+ # )
48
+ # return input
49
+ model: Model = Field(
50
+ Model.VC_PPG2MEL, title="模型类型",
51
+ )
52
+ # datasets_root: str = Field(
53
+ # ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型",
54
+ # format=True,
55
+ # example="..\\trainning_data\\"
56
+ # )
57
+ output_root: str = Field(
58
+ ..., alias="输出目录(可选)", description="建议不填,保持默认",
59
+ format=True,
60
+ example=""
61
+ )
62
+ continue_mode: bool = Field(
63
+ True, alias="继续训练模式", description="选择“是”,则从下面选择的模型中继续训练",
64
+ )
65
+ gpu: bool = Field(
66
+ True, alias="GPU训练", description="选择“是”,则使用GPU训练",
67
+ )
68
+ verbose: bool = Field(
69
+ True, alias="打印详情", description="选择“是”,输出更多详情",
70
+ )
71
+ # TODO: Move to hiden fields by default
72
+ convertor: convertors = Field(
73
+ ..., alias="转换模型",
74
+ description="选择语音转换模型文件."
75
+ )
76
+ extractor: extractors = Field(
77
+ ..., alias="特征提取模型",
78
+ description="选择PPG特征提取模型文件."
79
+ )
80
+ encoder: encoders = Field(
81
+ ..., alias="语音编码模型",
82
+ description="选择语音编码模型文件."
83
+ )
84
+ njobs: int = Field(
85
+ 8, alias="进程数", description="适用于ppg2mel",
86
+ )
87
+ seed: int = Field(
88
+ default=0, alias="初始随机数", description="适用于ppg2mel",
89
+ )
90
+ model_name: str = Field(
91
+ ..., alias="新模型名", description="仅在重新训练时生效,选中继续训练时无效",
92
+ example="test"
93
+ )
94
+ model_config: str = Field(
95
+ ..., alias="新模型配置", description="仅在重新训练时生效,选中继续训练时无效",
96
+ example=".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2"
97
+ )
98
+
99
+ class AudioEntity(BaseModel):
100
+ content: bytes
101
+ mel: Any
102
+
103
+ class Output(BaseModel):
104
+ __root__: Tuple[str, int]
105
+
106
+ def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
107
+ """Custom output UI.
108
+ If this method is implmeneted, it will be used instead of the default Output UI renderer.
109
+ """
110
+ sr, count = self.__root__
111
+ streamlit_app.subheader(f"Dataset {sr} done processed total of {count}")
112
+
113
+ def train_vc(input: Input) -> Output:
114
+ """Train VC(训练 VC)"""
115
+
116
+ print(">>> OneShot VC training ...")
117
+ params = AttrDict()
118
+ params.update({
119
+ "gpu": input.gpu,
120
+ "cpu": not input.gpu,
121
+ "njobs": input.njobs,
122
+ "seed": input.seed,
123
+ "verbose": input.verbose,
124
+ "load": input.convertor.value,
125
+ "warm_start": False,
126
+ })
127
+ if input.continue_mode:
128
+ # trace old model and config
129
+ p = Path(input.convertor.value)
130
+ params.name = p.parent.name
131
+ # search a config file
132
+ model_config_fpaths = list(p.parent.rglob("*.yaml"))
133
+ if len(model_config_fpaths) == 0:
134
+ raise Exception("No model yaml config found for convertor")
135
+ config = HpsYaml(model_config_fpaths[0])
136
+ params.ckpdir = p.parent.parent
137
+ params.config = model_config_fpaths[0]
138
+ params.logdir = os.path.join(p.parent, "log")
139
+ else:
140
+ # Make the config dict dot visitable
141
+ config = HpsYaml(input.model_config)
142
+ np.random.seed(input.seed)
143
+ torch.manual_seed(input.seed)
144
+ if torch.cuda.is_available():
145
+ torch.cuda.manual_seed_all(input.seed)
146
+ mode = "train"
147
+ from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
148
+ solver = Solver(config, params, mode)
149
+ solver.load_data()
150
+ solver.set_model()
151
+ solver.exec()
152
+ print(">>> Oneshot VC train finished!")
153
+
154
+ # TODO: pass useful return code
155
+ return Output(__root__=(input.model_name, 0))
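In continue mode, train_vc derives the run name, checkpoint directory and config from the location of the selected convertor checkpoint; a standalone sketch of that path arithmetic (the checkpoint path is hypothetical):

# Standalone sketch of the continue-mode path derivation above (hypothetical checkpoint path).
from pathlib import Path

p = Path("ppg2mel/saved_models/my_run/best_loss_step_50000.pth")
print(p.parent.name)    # my_run -> used as params.name
print(p.parent.parent)  # ppg2mel/saved_models -> used as params.ckpdir
# The training config is then searched next to the checkpoint:
# model_config_fpaths = list(p.parent.rglob("*.yaml"))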
packages.txt ADDED
@@ -0,0 +1,5 @@
1
+ libasound2-dev
2
+ portaudio19-dev
3
+ libportaudio2
4
+ libportaudiocpp0
5
+ ffmpeg
ppg2mel/__init__.py ADDED
@@ -0,0 +1,209 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2020 Songxiang Liu
4
+ # Apache 2.0
5
+
6
+ from typing import List
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+ import numpy as np
12
+
13
+ from .utils.abs_model import AbsMelDecoder
14
+ from .rnn_decoder_mol import Decoder
15
+ from .utils.cnn_postnet import Postnet
16
+ from .utils.vc_utils import get_mask_from_lengths
17
+
18
+ from utils.load_yaml import HpsYaml
19
+
20
+ class MelDecoderMOLv2(AbsMelDecoder):
21
+ """Use an encoder to preprocess ppg."""
22
+ def __init__(
23
+ self,
24
+ num_speakers: int,
25
+ spk_embed_dim: int,
26
+ bottle_neck_feature_dim: int,
27
+ encoder_dim: int = 256,
28
+ encoder_downsample_rates: List = [2, 2],
29
+ attention_rnn_dim: int = 512,
30
+ decoder_rnn_dim: int = 512,
31
+ num_decoder_rnn_layer: int = 1,
32
+ concat_context_to_last: bool = True,
33
+ prenet_dims: List = [256, 128],
34
+ num_mixtures: int = 5,
35
+ frames_per_step: int = 2,
36
+ mask_padding: bool = True,
37
+ ):
38
+ super().__init__()
39
+
40
+ self.mask_padding = mask_padding
41
+ self.bottle_neck_feature_dim = bottle_neck_feature_dim
42
+ self.num_mels = 80
43
+ self.encoder_down_factor=np.cumprod(encoder_downsample_rates)[-1]
44
+ self.frames_per_step = frames_per_step
45
+ self.use_spk_dvec = True
46
+
47
+ input_dim = bottle_neck_feature_dim
48
+
49
+ # Downsampling convolution
50
+ self.bnf_prenet = torch.nn.Sequential(
51
+ torch.nn.Conv1d(input_dim, encoder_dim, kernel_size=1, bias=False),
52
+ torch.nn.LeakyReLU(0.1),
53
+
54
+ torch.nn.InstanceNorm1d(encoder_dim, affine=False),
55
+ torch.nn.Conv1d(
56
+ encoder_dim, encoder_dim,
57
+ kernel_size=2*encoder_downsample_rates[0],
58
+ stride=encoder_downsample_rates[0],
59
+ padding=encoder_downsample_rates[0]//2,
60
+ ),
61
+ torch.nn.LeakyReLU(0.1),
62
+
63
+ torch.nn.InstanceNorm1d(encoder_dim, affine=False),
64
+ torch.nn.Conv1d(
65
+ encoder_dim, encoder_dim,
66
+ kernel_size=2*encoder_downsample_rates[1],
67
+ stride=encoder_downsample_rates[1],
68
+ padding=encoder_downsample_rates[1]//2,
69
+ ),
70
+ torch.nn.LeakyReLU(0.1),
71
+
72
+ torch.nn.InstanceNorm1d(encoder_dim, affine=False),
73
+ )
74
+ decoder_enc_dim = encoder_dim
75
+ self.pitch_convs = torch.nn.Sequential(
76
+ torch.nn.Conv1d(2, encoder_dim, kernel_size=1, bias=False),
77
+ torch.nn.LeakyReLU(0.1),
78
+
79
+ torch.nn.InstanceNorm1d(encoder_dim, affine=False),
80
+ torch.nn.Conv1d(
81
+ encoder_dim, encoder_dim,
82
+ kernel_size=2*encoder_downsample_rates[0],
83
+ stride=encoder_downsample_rates[0],
84
+ padding=encoder_downsample_rates[0]//2,
85
+ ),
86
+ torch.nn.LeakyReLU(0.1),
87
+
88
+ torch.nn.InstanceNorm1d(encoder_dim, affine=False),
89
+ torch.nn.Conv1d(
90
+ encoder_dim, encoder_dim,
91
+ kernel_size=2*encoder_downsample_rates[1],
92
+ stride=encoder_downsample_rates[1],
93
+ padding=encoder_downsample_rates[1]//2,
94
+ ),
95
+ torch.nn.LeakyReLU(0.1),
96
+
97
+ torch.nn.InstanceNorm1d(encoder_dim, affine=False),
98
+ )
99
+
100
+ self.reduce_proj = torch.nn.Linear(encoder_dim + spk_embed_dim, encoder_dim)
101
+
102
+ # Decoder
103
+ self.decoder = Decoder(
104
+ enc_dim=decoder_enc_dim,
105
+ num_mels=self.num_mels,
106
+ frames_per_step=frames_per_step,
107
+ attention_rnn_dim=attention_rnn_dim,
108
+ decoder_rnn_dim=decoder_rnn_dim,
109
+ num_decoder_rnn_layer=num_decoder_rnn_layer,
110
+ prenet_dims=prenet_dims,
111
+ num_mixtures=num_mixtures,
112
+ use_stop_tokens=True,
113
+ concat_context_to_last=concat_context_to_last,
114
+ encoder_down_factor=self.encoder_down_factor,
115
+ )
116
+
117
+ # Mel-Spec Postnet: some residual CNN layers
118
+ self.postnet = Postnet()
119
+
120
+ def parse_output(self, outputs, output_lengths=None):
121
+ if self.mask_padding and output_lengths is not None:
122
+ mask = ~get_mask_from_lengths(output_lengths, outputs[0].size(1))
123
+ mask = mask.unsqueeze(2).expand(mask.size(0), mask.size(1), self.num_mels)
124
+ outputs[0].data.masked_fill_(mask, 0.0)
125
+ outputs[1].data.masked_fill_(mask, 0.0)
126
+ return outputs
127
+
128
+ def forward(
129
+ self,
130
+ bottle_neck_features: torch.Tensor,
131
+ feature_lengths: torch.Tensor,
132
+ speech: torch.Tensor,
133
+ speech_lengths: torch.Tensor,
134
+ logf0_uv: torch.Tensor = None,
135
+ spembs: torch.Tensor = None,
136
+ output_att_ws: bool = False,
137
+ ):
138
+ decoder_inputs = self.bnf_prenet(
139
+ bottle_neck_features.transpose(1, 2)
140
+ ).transpose(1, 2)
141
+ logf0_uv = self.pitch_convs(logf0_uv.transpose(1, 2)).transpose(1, 2)
142
+ decoder_inputs = decoder_inputs + logf0_uv
143
+
144
+ assert spembs is not None
145
+ spk_embeds = F.normalize(
146
+ spembs).unsqueeze(1).expand(-1, decoder_inputs.size(1), -1)
147
+ decoder_inputs = torch.cat([decoder_inputs, spk_embeds], dim=-1)
148
+ decoder_inputs = self.reduce_proj(decoder_inputs)
149
+
150
+ # (B, num_mels, T_dec)
151
+ T_dec = torch.div(feature_lengths, int(self.encoder_down_factor), rounding_mode='floor')
152
+ mel_outputs, predicted_stop, alignments = self.decoder(
153
+ decoder_inputs, speech, T_dec)
154
+ ## Post-processing
155
+ mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2)
156
+ mel_outputs_postnet = mel_outputs + mel_outputs_postnet
157
+ if output_att_ws:
158
+ return self.parse_output(
159
+ [mel_outputs, mel_outputs_postnet, predicted_stop, alignments], speech_lengths)
160
+ else:
161
+ return self.parse_output(
162
+ [mel_outputs, mel_outputs_postnet, predicted_stop], speech_lengths)
163
+
164
+ # return mel_outputs, mel_outputs_postnet
165
+
166
+ def inference(
167
+ self,
168
+ bottle_neck_features: torch.Tensor,
169
+ logf0_uv: torch.Tensor = None,
170
+ spembs: torch.Tensor = None,
171
+ ):
172
+ decoder_inputs = self.bnf_prenet(bottle_neck_features.transpose(1, 2)).transpose(1, 2)
173
+ logf0_uv = self.pitch_convs(logf0_uv.transpose(1, 2)).transpose(1, 2)
174
+ decoder_inputs = decoder_inputs + logf0_uv
175
+
176
+ assert spembs is not None
177
+ spk_embeds = F.normalize(
178
+ spembs).unsqueeze(1).expand(-1, decoder_inputs.size(1), -1)
179
+ bottle_neck_features = torch.cat([decoder_inputs, spk_embeds], dim=-1)
180
+ bottle_neck_features = self.reduce_proj(bottle_neck_features)
181
+
182
+ ## Decoder
183
+ if bottle_neck_features.size(0) > 1:
184
+ mel_outputs, alignments = self.decoder.inference_batched(bottle_neck_features)
185
+ else:
186
+ mel_outputs, alignments = self.decoder.inference(bottle_neck_features,)
187
+ ## Post-processing
188
+ mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2)
189
+ mel_outputs_postnet = mel_outputs + mel_outputs_postnet
190
+ # outputs = mel_outputs_postnet[0]
191
+
192
+ return mel_outputs[0], mel_outputs_postnet[0], alignments[0]
193
+
194
+ def load_model(model_file, device=None):
195
+ # search a config file
196
+ model_config_fpaths = list(model_file.parent.rglob("*.yaml"))
197
+ if len(model_config_fpaths) == 0:
198
+ raise "No model yaml config found for convertor"
199
+ if device is None:
200
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
201
+
202
+ model_config = HpsYaml(model_config_fpaths[0])
203
+ ppg2mel_model = MelDecoderMOLv2(
204
+ **model_config["model"]
205
+ ).to(device)
206
+ ckpt = torch.load(model_file, map_location=device)
207
+ ppg2mel_model.load_state_dict(ckpt["model"])
208
+ ppg2mel_model.eval()
209
+ return ppg2mel_model
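Editor's note: the sketch below is not part of the commit; it shows one way the converter defined above might be loaded and run. The checkpoint path and the feature sizes (144-dim bottleneck features, 256-dim speaker embedding) are illustrative assumptions; the real dimensions come from the yaml config stored next to the checkpoint.
# Hypothetical usage sketch for load_model / MelDecoderMOLv2.inference (paths and sizes are assumptions)
from pathlib import Path
import torch
from ppg2mel import load_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ppg2mel_model = load_model(Path("ppg2mel/saved_models/example.pth"), device)  # placeholder path

ppg = torch.randn(1, 200, 144, device=device)      # (B, T, bottle_neck_feature_dim) - assumed 144
logf0_uv = torch.randn(1, 200, 2, device=device)   # (B, T, 2): log-F0 and voiced/unvoiced flag
spk_embed = torch.randn(1, 256, device=device)     # (B, spk_embed_dim) speaker d-vector - assumed 256

with torch.no_grad():
    mel, mel_postnet, alignment = ppg2mel_model.inference(ppg, logf0_uv=logf0_uv, spembs=spk_embed)
print(mel_postnet.shape)                           # (T_mel, 80) mel-spectrogram frames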
ppg2mel/preprocess.py ADDED
@@ -0,0 +1,113 @@
1
+
2
+ import os
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from pathlib import Path
7
+ import soundfile
8
+ import resampy
9
+
10
+ from ppg_extractor import load_model
11
+ import encoder.inference as Encoder
12
+ from encoder.audio import preprocess_wav
13
+ from encoder import audio
14
+ from utils.f0_utils import compute_f0
15
+
16
+ from torch.multiprocessing import Pool, cpu_count
17
+ from functools import partial
18
+
19
+ SAMPLE_RATE=16000
20
+
21
+ def _compute_bnf(
22
+ wav: np.ndarray,
23
+ output_fpath: str,
24
+ device: torch.device,
25
+ ppg_model_local: torch.nn.Module,
26
+ ):
27
+ """
28
+ Compute CTC-Attention Seq2seq ASR encoder bottle-neck features (BNF).
29
+ """
30
+ ppg_model_local.to(device)
31
+ wav_tensor = torch.from_numpy(wav).float().to(device).unsqueeze(0)
32
+ wav_length = torch.LongTensor([wav.shape[0]]).to(device)
33
+ with torch.no_grad():
34
+ bnf = ppg_model_local(wav_tensor, wav_length)
35
+ bnf_npy = bnf.squeeze(0).cpu().numpy()
36
+ np.save(output_fpath, bnf_npy, allow_pickle=False)
37
+ return bnf_npy, len(bnf_npy)
38
+
39
+ def _compute_f0_from_wav(wav, output_fpath):
40
+ """Compute merged f0 values."""
41
+ f0 = compute_f0(wav, SAMPLE_RATE)
42
+ np.save(output_fpath, f0, allow_pickle=False)
43
+ return f0, len(f0)
44
+
45
+ def _compute_spkEmbed(wav, output_fpath, encoder_model_local, device):
46
+ Encoder.set_model(encoder_model_local)
47
+ # Compute where to split the utterance into partials and pad if necessary
48
+ wave_slices, mel_slices = Encoder.compute_partial_slices(len(wav), rate=1.3, min_pad_coverage=0.75)
49
+ max_wave_length = wave_slices[-1].stop
50
+ if max_wave_length >= len(wav):
51
+ wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
52
+
53
+ # Split the utterance into partials
54
+ frames = audio.wav_to_mel_spectrogram(wav)
55
+ frames_batch = np.array([frames[s] for s in mel_slices])
56
+ partial_embeds = Encoder.embed_frames_batch(frames_batch)
57
+
58
+ # Compute the utterance embedding from the partial embeddings
59
+ raw_embed = np.mean(partial_embeds, axis=0)
60
+ embed = raw_embed / np.linalg.norm(raw_embed, 2)
61
+
62
+ np.save(output_fpath, embed, allow_pickle=False)
63
+ return embed, len(embed)
64
+
65
+ def preprocess_one(wav_path, out_dir, device, ppg_model_local, encoder_model_local):
66
+ # wav = preprocess_wav(wav_path)
67
+ # try:
68
+ wav, sr = soundfile.read(wav_path)
69
+ if len(wav) < sr:
70
+ return None, sr, len(wav)
71
+ if sr != SAMPLE_RATE:
72
+ wav = resampy.resample(wav, sr, SAMPLE_RATE)
73
+ sr = SAMPLE_RATE
74
+ utt_id = os.path.splitext(os.path.basename(wav_path))[0]  # filename without the ".wav" extension
75
+
76
+ _, length_bnf = _compute_bnf(output_fpath=f"{out_dir}/bnf/{utt_id}.ling_feat.npy", wav=wav, device=device, ppg_model_local=ppg_model_local)
77
+ _, length_f0 = _compute_f0_from_wav(output_fpath=f"{out_dir}/f0/{utt_id}.f0.npy", wav=wav)
78
+ _, length_embed = _compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", device=device, encoder_model_local=encoder_model_local, wav=wav)
79
+
80
+ def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder_model_fpath, speaker_encoder_model):
81
+ # Glob wav files
82
+ wav_file_list = sorted(Path(f"{datasets_root}/{dataset}").glob("**/*.wav"))
83
+ print(f"Globbed {len(wav_file_list)} wav files.")
84
+
85
+ out_dir.joinpath("bnf").mkdir(exist_ok=True, parents=True)
86
+ out_dir.joinpath("f0").mkdir(exist_ok=True, parents=True)
87
+ out_dir.joinpath("embed").mkdir(exist_ok=True, parents=True)
88
+ ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu")
89
+ encoder_model_local = Encoder.load_model(speaker_encoder_model, "cpu")
90
+ if n_processes is None:
91
+ n_processes = cpu_count()
92
+
93
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
94
+ func = partial(preprocess_one, out_dir=out_dir, ppg_model_local=ppg_model_local, encoder_model_local=encoder_model_local, device=device)
95
+ job = Pool(n_processes).imap(func, wav_file_list)
96
+ list(tqdm(job, "Preprocessing", len(wav_file_list), unit="wav"))
97
+
98
+ # finish processing and mark
99
+ t_fid_file = out_dir.joinpath("train_fidlist.txt").open("w", encoding="utf-8")
100
+ d_fid_file = out_dir.joinpath("dev_fidlist.txt").open("w", encoding="utf-8")
101
+ e_fid_file = out_dir.joinpath("eval_fidlist.txt").open("w", encoding="utf-8")
102
+ for file in sorted(out_dir.joinpath("f0").glob("*.npy")):
103
+ id = os.path.basename(file).split(".f0.npy")[0]
104
+ if id.endswith("01"):
105
+ d_fid_file.write(id + "\n")
106
+ elif id.endswith("09"):
107
+ e_fid_file.write(id + "\n")
108
+ else:
109
+ t_fid_file.write(id + "\n")
110
+ t_fid_file.close()
111
+ d_fid_file.close()
112
+ e_fid_file.close()
113
+ return len(wav_file_list)
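Editor's note: a minimal invocation sketch for preprocess_dataset above, not part of the commit. The dataset name, worker count, and model checkpoint locations are placeholders.
# Hypothetical driver for the preprocessing entry point (all paths/names are placeholders)
from pathlib import Path
from ppg2mel.preprocess import preprocess_dataset

n_files = preprocess_dataset(
    datasets_root=Path("datasets"),               # expects datasets/<dataset>/**/*.wav
    dataset="my_corpus",                          # placeholder dataset folder name
    out_dir=Path("datasets/ppg2mel_features"),    # bnf/, f0/, embed/ and the fid lists land here
    n_processes=4,
    ppg_encoder_model_fpath=Path("ppg_extractor/saved_models/example.pt"),  # placeholder
    speaker_encoder_model=Path("encoder/saved_models/pretrained.pt"),       # placeholder
)
print(f"Preprocessed {n_files} wav files")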
ppg2mel/rnn_decoder_mol.py ADDED
@@ -0,0 +1,374 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from .utils.mol_attention import MOLAttention
6
+ from .utils.basic_layers import Linear
7
+ from .utils.vc_utils import get_mask_from_lengths
8
+
9
+
10
+ class DecoderPrenet(nn.Module):
11
+ def __init__(self, in_dim, sizes):
12
+ super().__init__()
13
+ in_sizes = [in_dim] + sizes[:-1]
14
+ self.layers = nn.ModuleList(
15
+ [Linear(in_size, out_size, bias=False)
16
+ for (in_size, out_size) in zip(in_sizes, sizes)])
17
+
18
+ def forward(self, x):
19
+ for linear in self.layers:
20
+ x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
21
+ return x
22
+
23
+
24
+ class Decoder(nn.Module):
25
+ """Mixture of Logistic (MoL) attention-based RNN Decoder."""
26
+ def __init__(
27
+ self,
28
+ enc_dim,
29
+ num_mels,
30
+ frames_per_step,
31
+ attention_rnn_dim,
32
+ decoder_rnn_dim,
33
+ prenet_dims,
34
+ num_mixtures,
35
+ encoder_down_factor=1,
36
+ num_decoder_rnn_layer=1,
37
+ use_stop_tokens=False,
38
+ concat_context_to_last=False,
39
+ ):
40
+ super().__init__()
41
+ self.enc_dim = enc_dim
42
+ self.encoder_down_factor = encoder_down_factor
43
+ self.num_mels = num_mels
44
+ self.frames_per_step = frames_per_step
45
+ self.attention_rnn_dim = attention_rnn_dim
46
+ self.decoder_rnn_dim = decoder_rnn_dim
47
+ self.prenet_dims = prenet_dims
48
+ self.use_stop_tokens = use_stop_tokens
49
+ self.num_decoder_rnn_layer = num_decoder_rnn_layer
50
+ self.concat_context_to_last = concat_context_to_last
51
+
52
+ # Mel prenet
53
+ self.prenet = DecoderPrenet(num_mels, prenet_dims)
54
+ self.prenet_pitch = DecoderPrenet(num_mels, prenet_dims)
55
+
56
+ # Attention RNN
57
+ self.attention_rnn = nn.LSTMCell(
58
+ prenet_dims[-1] + enc_dim,
59
+ attention_rnn_dim
60
+ )
61
+
62
+ # Attention
63
+ self.attention_layer = MOLAttention(
64
+ attention_rnn_dim,
65
+ r=frames_per_step/encoder_down_factor,
66
+ M=num_mixtures,
67
+ )
68
+
69
+ # Decoder RNN
70
+ self.decoder_rnn_layers = nn.ModuleList()
71
+ for i in range(num_decoder_rnn_layer):
72
+ if i == 0:
73
+ self.decoder_rnn_layers.append(
74
+ nn.LSTMCell(
75
+ enc_dim + attention_rnn_dim,
76
+ decoder_rnn_dim))
77
+ else:
78
+ self.decoder_rnn_layers.append(
79
+ nn.LSTMCell(
80
+ decoder_rnn_dim,
81
+ decoder_rnn_dim))
82
+ # self.decoder_rnn = nn.LSTMCell(
83
+ # 2 * enc_dim + attention_rnn_dim,
84
+ # decoder_rnn_dim
85
+ # )
86
+ if concat_context_to_last:
87
+ self.linear_projection = Linear(
88
+ enc_dim + decoder_rnn_dim,
89
+ num_mels * frames_per_step
90
+ )
91
+ else:
92
+ self.linear_projection = Linear(
93
+ decoder_rnn_dim,
94
+ num_mels * frames_per_step
95
+ )
96
+
97
+
98
+ # Stop-token layer
99
+ if self.use_stop_tokens:
100
+ if concat_context_to_last:
101
+ self.stop_layer = Linear(
102
+ enc_dim + decoder_rnn_dim, 1, bias=True, w_init_gain="sigmoid"
103
+ )
104
+ else:
105
+ self.stop_layer = Linear(
106
+ decoder_rnn_dim, 1, bias=True, w_init_gain="sigmoid"
107
+ )
108
+
109
+
110
+ def get_go_frame(self, memory):
111
+ B = memory.size(0)
112
+ go_frame = torch.zeros((B, self.num_mels), dtype=torch.float,
113
+ device=memory.device)
114
+ return go_frame
115
+
116
+ def initialize_decoder_states(self, memory, mask):
117
+ device = next(self.parameters()).device
118
+ B = memory.size(0)
119
+
120
+ # attention rnn states
121
+ self.attention_hidden = torch.zeros(
122
+ (B, self.attention_rnn_dim), device=device)
123
+ self.attention_cell = torch.zeros(
124
+ (B, self.attention_rnn_dim), device=device)
125
+
126
+ # decoder rnn states
127
+ self.decoder_hiddens = []
128
+ self.decoder_cells = []
129
+ for i in range(self.num_decoder_rnn_layer):
130
+ self.decoder_hiddens.append(
131
+ torch.zeros((B, self.decoder_rnn_dim),
132
+ device=device)
133
+ )
134
+ self.decoder_cells.append(
135
+ torch.zeros((B, self.decoder_rnn_dim),
136
+ device=device)
137
+ )
138
+ # self.decoder_hidden = torch.zeros(
139
+ # (B, self.decoder_rnn_dim), device=device)
140
+ # self.decoder_cell = torch.zeros(
141
+ # (B, self.decoder_rnn_dim), device=device)
142
+
143
+ self.attention_context = torch.zeros(
144
+ (B, self.enc_dim), device=device)
145
+
146
+ self.memory = memory
147
+ # self.processed_memory = self.attention_layer.memory_layer(memory)
148
+ self.mask = mask
149
+
150
+ def parse_decoder_inputs(self, decoder_inputs):
151
+ """Prepare decoder inputs, i.e. gt mel
152
+ Args:
153
+ decoder_inputs:(B, T_out, n_mel_channels) inputs used for teacher-forced training.
154
+ """
155
+ decoder_inputs = decoder_inputs.reshape(
156
+ decoder_inputs.size(0),
157
+ int(decoder_inputs.size(1)/self.frames_per_step), -1)
158
+ # (B, T_out//r, r*num_mels) -> (T_out//r, B, r*num_mels)
159
+ decoder_inputs = decoder_inputs.transpose(0, 1)
160
+ # (T_out//r, B, num_mels)
161
+ decoder_inputs = decoder_inputs[:,:,-self.num_mels:]
162
+ return decoder_inputs
163
+
164
+ def parse_decoder_outputs(self, mel_outputs, alignments, stop_outputs):
165
+ """ Prepares decoder outputs for output
166
+ Args:
167
+ mel_outputs:
168
+ alignments:
169
+ """
170
+ # (T_out//r, B, T_enc) -> (B, T_out//r, T_enc)
171
+ alignments = torch.stack(alignments).transpose(0, 1)
172
+ # (T_out//r, B) -> (B, T_out//r)
173
+ if stop_outputs is not None:
174
+ if alignments.size(0) == 1:
175
+ stop_outputs = torch.stack(stop_outputs).unsqueeze(0)
176
+ else:
177
+ stop_outputs = torch.stack(stop_outputs).transpose(0, 1)
178
+ stop_outputs = stop_outputs.contiguous()
179
+ # (T_out//r, B, num_mels*r) -> (B, T_out//r, num_mels*r)
180
+ mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
181
+ # decouple frames per step
182
+ # (B, T_out, num_mels)
183
+ mel_outputs = mel_outputs.view(
184
+ mel_outputs.size(0), -1, self.num_mels)
185
+ return mel_outputs, alignments, stop_outputs
186
+
187
+ def attend(self, decoder_input):
188
+ cell_input = torch.cat((decoder_input, self.attention_context), -1)
189
+ self.attention_hidden, self.attention_cell = self.attention_rnn(
190
+ cell_input, (self.attention_hidden, self.attention_cell))
191
+ self.attention_context, attention_weights = self.attention_layer(
192
+ self.attention_hidden, self.memory, None, self.mask)
193
+
194
+ decoder_rnn_input = torch.cat(
195
+ (self.attention_hidden, self.attention_context), -1)
196
+
197
+ return decoder_rnn_input, self.attention_context, attention_weights
198
+
199
+ def decode(self, decoder_input):
200
+ for i in range(self.num_decoder_rnn_layer):
201
+ if i == 0:
202
+ self.decoder_hiddens[i], self.decoder_cells[i] = self.decoder_rnn_layers[i](
203
+ decoder_input, (self.decoder_hiddens[i], self.decoder_cells[i]))
204
+ else:
205
+ self.decoder_hiddens[i], self.decoder_cells[i] = self.decoder_rnn_layers[i](
206
+ self.decoder_hiddens[i-1], (self.decoder_hiddens[i], self.decoder_cells[i]))
207
+ return self.decoder_hiddens[-1]
208
+
209
+ def forward(self, memory, mel_inputs, memory_lengths):
210
+ """ Decoder forward pass for training
211
+ Args:
212
+ memory: (B, T_enc, enc_dim) Encoder outputs
213
+ mel_inputs: (B, T, num_mels) Ground-truth mel frames used for teacher forcing.
214
+ memory_lengths: (B, ) Encoder output lengths for attention masking.
215
+ Returns:
216
+ mel_outputs: (B, T, num_mels) mel outputs from the decoder
217
+ alignments: (B, T//r, T_enc) attention weights.
218
+ """
219
+ # [1, B, num_mels]
220
+ go_frame = self.get_go_frame(memory).unsqueeze(0)
221
+ # [T//r, B, num_mels]
222
+ mel_inputs = self.parse_decoder_inputs(mel_inputs)
223
+ # [T//r + 1, B, num_mels]
224
+ mel_inputs = torch.cat((go_frame, mel_inputs), dim=0)
225
+ # [T//r + 1, B, prenet_dim]
226
+ decoder_inputs = self.prenet(mel_inputs)
227
+ # decoder_inputs_pitch = self.prenet_pitch(decoder_inputs__)
228
+
229
+ self.initialize_decoder_states(
230
+ memory, mask=~get_mask_from_lengths(memory_lengths),
231
+ )
232
+
233
+ self.attention_layer.init_states(memory)
234
+ # self.attention_layer_pitch.init_states(memory_pitch)
235
+
236
+ mel_outputs, alignments = [], []
237
+ if self.use_stop_tokens:
238
+ stop_outputs = []
239
+ else:
240
+ stop_outputs = None
241
+ while len(mel_outputs) < decoder_inputs.size(0) - 1:
242
+ decoder_input = decoder_inputs[len(mel_outputs)]
243
+ # decoder_input_pitch = decoder_inputs_pitch[len(mel_outputs)]
244
+
245
+ decoder_rnn_input, context, attention_weights = self.attend(decoder_input)
246
+
247
+ decoder_rnn_output = self.decode(decoder_rnn_input)
248
+ if self.concat_context_to_last:
249
+ decoder_rnn_output = torch.cat(
250
+ (decoder_rnn_output, context), dim=1)
251
+
252
+ mel_output = self.linear_projection(decoder_rnn_output)
253
+ if self.use_stop_tokens:
254
+ stop_output = self.stop_layer(decoder_rnn_output)
255
+ stop_outputs += [stop_output.squeeze()]
256
+ mel_outputs += [mel_output.squeeze(1)]  # NOTE: squeeze(1) is likely unnecessary here
257
+ alignments += [attention_weights]
258
+ # alignments_pitch += [attention_weights_pitch]
259
+
260
+ mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs(
261
+ mel_outputs, alignments, stop_outputs)
262
+ if stop_outputs is None:
263
+ return mel_outputs, alignments
264
+ else:
265
+ return mel_outputs, stop_outputs, alignments
266
+
267
+ def inference(self, memory, stop_threshold=0.5):
268
+ """ Decoder inference
269
+ Args:
270
+ memory: (1, T_enc, D_enc) Encoder outputs
271
+ Returns:
272
+ mel_outputs: mel outputs from the decoder
273
+ alignments: sequence of attention weights from the decoder
274
+ """
275
+ # [1, num_mels]
276
+ decoder_input = self.get_go_frame(memory)
277
+
278
+ self.initialize_decoder_states(memory, mask=None)
279
+
280
+ self.attention_layer.init_states(memory)
281
+
282
+ mel_outputs, alignments = [], []
283
+ # NOTE(sx): heuristic
284
+ max_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step
285
+ min_decoder_step = memory.size(1)*self.encoder_down_factor // self.frames_per_step - 5
286
+ while True:
287
+ decoder_input = self.prenet(decoder_input)
288
+
289
+ decoder_input_final, context, alignment = self.attend(decoder_input)
290
+
291
+ #mel_output, stop_output, alignment = self.decode(decoder_input)
292
+ decoder_rnn_output = self.decode(decoder_input_final)
293
+ if self.concat_context_to_last:
294
+ decoder_rnn_output = torch.cat(
295
+ (decoder_rnn_output, context), dim=1)
296
+
297
+ mel_output = self.linear_projection(decoder_rnn_output)
298
+ stop_output = self.stop_layer(decoder_rnn_output)
299
+
300
+ mel_outputs += [mel_output.squeeze(1)]
301
+ alignments += [alignment]
302
+
303
+ if torch.sigmoid(stop_output.data) > stop_threshold and len(mel_outputs) >= min_decoder_step:
304
+ break
305
+ if len(mel_outputs) >= max_decoder_step:
306
+ # print("Warning! Decoding steps reaches max decoder steps.")
307
+ break
308
+
309
+ decoder_input = mel_output[:,-self.num_mels:]
310
+
311
+
312
+ mel_outputs, alignments, _ = self.parse_decoder_outputs(
313
+ mel_outputs, alignments, None)
314
+
315
+ return mel_outputs, alignments
316
+
317
+ def inference_batched(self, memory, stop_threshold=0.5):
318
+ """ Decoder inference
319
+ Args:
320
+ memory: (B, T_enc, D_enc) Encoder outputs
321
+ Returns:
322
+ mel_outputs: mel outputs from the decoder
323
+ alignments: sequence of attention weights from the decoder
324
+ """
325
+ # [1, num_mels]
326
+ decoder_input = self.get_go_frame(memory)
327
+
328
+ self.initialize_decoder_states(memory, mask=None)
329
+
330
+ self.attention_layer.init_states(memory)
331
+
332
+ mel_outputs, alignments = [], []
333
+ stop_outputs = []
334
+ # NOTE(sx): heuristic
335
+ max_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step
336
+ min_decoder_step = memory.size(1)*self.encoder_down_factor // self.frames_per_step - 5
337
+ while True:
338
+ decoder_input = self.prenet(decoder_input)
339
+
340
+ decoder_input_final, context, alignment = self.attend(decoder_input)
341
+
342
+ #mel_output, stop_output, alignment = self.decode(decoder_input)
343
+ decoder_rnn_output = self.decode(decoder_input_final)
344
+ if self.concat_context_to_last:
345
+ decoder_rnn_output = torch.cat(
346
+ (decoder_rnn_output, context), dim=1)
347
+
348
+ mel_output = self.linear_projection(decoder_rnn_output)
349
+ # (B, 1)
350
+ stop_output = self.stop_layer(decoder_rnn_output)
351
+ stop_outputs += [stop_output.squeeze()]
352
+ # stop_outputs.append(stop_output)
353
+
354
+ mel_outputs += [mel_output.squeeze(1)]
355
+ alignments += [alignment]
356
+ # print(stop_output.shape)
357
+ if torch.all(torch.sigmoid(stop_output.squeeze().data) > stop_threshold) \
358
+ and len(mel_outputs) >= min_decoder_step:
359
+ break
360
+ if len(mel_outputs) >= max_decoder_step:
361
+ # print("Warning! Decoding steps reaches max decoder steps.")
362
+ break
363
+
364
+ decoder_input = mel_output[:,-self.num_mels:]
365
+
366
+
367
+ mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs(
368
+ mel_outputs, alignments, stop_outputs)
369
+ mel_outputs_stacked = []
370
+ for mel, stop_logit in zip(mel_outputs, stop_outputs):
371
+ idx = np.argwhere(torch.sigmoid(stop_logit.cpu()) > stop_threshold)[0][0].item()
372
+ mel_outputs_stacked.append(mel[:idx,:])
373
+ mel_outputs = torch.cat(mel_outputs_stacked, dim=0).unsqueeze(0)
374
+ return mel_outputs, alignments
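Editor's note: a standalone shape-check sketch for the MOL-attention Decoder above, not part of the commit. It assumes the package's .utils submodules (MOLAttention, Linear, get_mask_from_lengths) are importable, and all dimensions are arbitrary; with frames_per_step=2 and encoder_down_factor=4 the attention step is r = 2/4 = 0.5 encoder frames, so 200 mel frames pair with 50 encoder frames.
# Hypothetical teacher-forced forward pass with assumed dimensions
import torch
from ppg2mel.rnn_decoder_mol import Decoder

decoder = Decoder(
    enc_dim=256, num_mels=80, frames_per_step=2,
    attention_rnn_dim=512, decoder_rnn_dim=512,
    prenet_dims=[256, 128], num_mixtures=5,
    encoder_down_factor=4, use_stop_tokens=True,
)
memory = torch.randn(2, 50, 256)            # (B, T_enc, enc_dim) encoder outputs
mels = torch.randn(2, 200, 80)              # (B, T_mel, num_mels); T_mel divisible by frames_per_step
memory_lengths = torch.LongTensor([50, 50])

mel_out, stop_logits, alignments = decoder(memory, mels, memory_lengths)
print(mel_out.shape, alignments.shape)      # torch.Size([2, 200, 80]) torch.Size([2, 100, 50])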