Delete main
- main/app/app.py +0 -0
- main/app/tensorboard.py +0 -30
- main/configs/config.json +0 -26
- main/configs/config.py +0 -70
- main/configs/v1/32000.json +0 -46
- main/configs/v1/40000.json +0 -46
- main/configs/v1/44100.json +0 -46
- main/configs/v1/48000.json +0 -46
- main/configs/v2/32000.json +0 -42
- main/configs/v2/40000.json +0 -42
- main/configs/v2/44100.json +0 -42
- main/configs/v2/48000.json +0 -42
- main/inference/audio_effects.py +0 -170
- main/inference/convert.py +0 -650
- main/inference/create_dataset.py +0 -240
- main/inference/create_index.py +0 -100
- main/inference/extract.py +0 -450
- main/inference/preprocess.py +0 -290
- main/inference/separator_music.py +0 -290
- main/inference/train.py +0 -1000
- main/library/algorithm/commons.py +0 -50
- main/library/algorithm/modules.py +0 -70
- main/library/algorithm/mrf_hifigan.py +0 -160
- main/library/algorithm/refinegan.py +0 -180
- main/library/algorithm/residuals.py +0 -140
- main/library/algorithm/separator.py +0 -330
- main/library/algorithm/synthesizers.py +0 -450
- main/library/architectures/demucs_separator.py +0 -160
- main/library/architectures/mdx_separator.py +0 -320
- main/library/predictors/CREPE.py +0 -210
- main/library/predictors/FCPE.py +0 -670
- main/library/predictors/RMVPE.py +0 -260
- main/library/predictors/WORLD.py +0 -90
- main/library/utils.py +0 -100
- main/library/uvr5_separator/common_separator.py +0 -250
- main/library/uvr5_separator/demucs/apply.py +0 -250
- main/library/uvr5_separator/demucs/demucs.py +0 -370
- main/library/uvr5_separator/demucs/hdemucs.py +0 -760
- main/library/uvr5_separator/demucs/htdemucs.py +0 -600
- main/library/uvr5_separator/demucs/states.py +0 -55
- main/library/uvr5_separator/demucs/utils.py +0 -8
- main/library/uvr5_separator/spec_utils.py +0 -900
- main/tools/edge_tts.py +0 -180
- main/tools/gdown.py +0 -110
- main/tools/google_tts.py +0 -30
- main/tools/huggingface.py +0 -24
- main/tools/mediafire.py +0 -30
- main/tools/meganz.py +0 -160
- main/tools/noisereduce.py +0 -200
- main/tools/pixeldrain.py +0 -16
main/app/app.py
DELETED
The diff for this file is too large to render; see the raw diff.
main/app/tensorboard.py
DELETED
@@ -1,30 +0,0 @@
-import os
-import sys
-import json
-import logging
-import webbrowser
-
-from tensorboard import program
-
-sys.path.append(os.getcwd())
-
-from main.configs.config import Config
-translations = Config().translations
-
-with open(os.path.join("main", "configs", "config.json"), "r") as f:
-    configs = json.load(f)
-
-def launch_tensorboard():
-    for l in ["root", "tensorboard"]:
-        logging.getLogger(l).setLevel(logging.ERROR)
-
-    tb = program.TensorBoard()
-    tb.configure(argv=[None, "--logdir", "assets/logs", f"--port={configs['tensorboard_port']}"])
-    url = tb.launch()
-
-    print(f"{translations['tensorboard_url']}: {url}")
-    if "--open" in sys.argv: webbrowser.open(url)
-
-    return f"{translations['tensorboard_url']}: {url}"
-
-if __name__ == "__main__": launch_tensorboard()
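For reference, the deleted helper could be run as a script or imported; a minimal usage sketch, assuming the repository layout above is on the Python path:

    # Hypothetical usage sketch: launch_tensorboard() prints and returns the URL string.
    # It reads the port from main/configs/config.json and checks sys.argv for "--open"
    # to decide whether to auto-open a browser tab.
    from main.app.tensorboard import launch_tensorboard
    url_message = launch_tensorboard()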
main/configs/config.json
DELETED
@@ -1,26 +0,0 @@
-{
-    "language": "vi-VN",
-    "support_language": ["en-US", "vi-VN"],
-    "theme": "NoCrypt/miku",
-    "themes": ["NoCrypt/miku", "gstaff/xkcd", "JohnSmith9982/small_and_pretty", "ParityError/Interstellar", "earneleh/paris", "shivi/calm_seafoam", "Hev832/Applio", "YTheme/Minecraft", "gstaff/sketch", "SebastianBravo/simci_css", "allenai/gradio-theme", "Nymbo/Nymbo_Theme_5", "lone17/kotaemon", "Zarkel/IBM_Carbon_Theme", "SherlockRamos/Feliz", "freddyaboulton/dracula_revamped", "freddyaboulton/bad-theme-space", "gradio/dracula_revamped", "abidlabs/dracula_revamped", "gradio/dracula_test", "gradio/seafoam", "gradio/glass", "gradio/monochrome", "gradio/soft", "gradio/default", "gradio/base", "abidlabs/pakistan", "dawood/microsoft_windows", "ysharma/steampunk", "ysharma/huggingface", "abidlabs/Lime", "freddyaboulton/this-theme-does-not-exist-2", "aliabid94/new-theme", "aliabid94/test2", "aliabid94/test3", "aliabid94/test4", "abidlabs/banana", "freddyaboulton/test-blue", "gstaff/whiteboard", "ysharma/llamas", "abidlabs/font-test", "YenLai/Superhuman", "bethecloud/storj_theme", "sudeepshouche/minimalist", "knotdgaf/gradiotest", "ParityError/Anime", "Ajaxon6255/Emerald_Isle", "ParityError/LimeFace", "finlaymacklon/smooth_slate", "finlaymacklon/boxy_violet", "derekzen/stardust", "EveryPizza/Cartoony-Gradio-Theme", "Ifeanyi/Cyanister", "Tshackelton/IBMPlex-DenseReadable", "snehilsanyal/scikit-learn", "Himhimhim/xkcd", "nota-ai/theme", "rawrsor1/Everforest", "rottenlittlecreature/Moon_Goblin", "abidlabs/test-yellow", "abidlabs/test-yellow3", "idspicQstitho/dracula_revamped", "kfahn/AnimalPose", "HaleyCH/HaleyCH_Theme", "simulKitke/dracula_test", "braintacles/CrimsonNight", "wentaohe/whiteboardv2", "reilnuud/polite", "remilia/Ghostly", "Franklisi/darkmode", "coding-alt/soft", "xiaobaiyuan/theme_land", "step-3-profit/Midnight-Deep", "xiaobaiyuan/theme_demo", "Taithrah/Minimal", "Insuz/SimpleIndigo", "zkunn/Alipay_Gradio_theme", "Insuz/Mocha", "xiaobaiyuan/theme_brief", "Ama434/434-base-Barlow", "Ama434/def_barlow", "Ama434/neutral-barlow", "dawood/dracula_test", "nuttea/Softblue", "BlueDancer/Alien_Diffusion", "naughtondale/monochrome", "Dagfinn1962/standard", "default"],
-    "mdx_model": ["Main_340", "Main_390", "Main_406", "Main_427", "Main_438", "Inst_full_292", "Inst_HQ_1", "Inst_HQ_2", "Inst_HQ_3", "Inst_HQ_4", "Inst_HQ_5", "Kim_Vocal_1", "Kim_Vocal_2", "Kim_Inst", "Inst_187_beta", "Inst_82_beta", "Inst_90_beta", "Voc_FT", "Crowd_HQ", "Inst_1", "Inst_2", "Inst_3", "MDXNET_1_9703", "MDXNET_2_9682", "MDXNET_3_9662", "Inst_Main", "MDXNET_Main", "MDXNET_9482"],
-    "demucs_model": ["HT-Normal", "HT-Tuned", "HD_MMI", "HT_6S"],
-    "edge_tts": ["af-ZA-AdriNeural", "af-ZA-WillemNeural", "sq-AL-AnilaNeural", "sq-AL-IlirNeural", "am-ET-AmehaNeural", "am-ET-MekdesNeural", "ar-DZ-AminaNeural", "ar-DZ-IsmaelNeural", "ar-BH-AliNeural", "ar-BH-LailaNeural", "ar-EG-SalmaNeural", "ar-EG-ShakirNeural", "ar-IQ-BasselNeural", "ar-IQ-RanaNeural", "ar-JO-SanaNeural", "ar-JO-TaimNeural", "ar-KW-FahedNeural", "ar-KW-NouraNeural", "ar-LB-LaylaNeural", "ar-LB-RamiNeural", "ar-LY-ImanNeural", "ar-LY-OmarNeural", "ar-MA-JamalNeural", "ar-MA-MounaNeural", "ar-OM-AbdullahNeural", "ar-OM-AyshaNeural", "ar-QA-AmalNeural", "ar-QA-MoazNeural", "ar-SA-HamedNeural", "ar-SA-ZariyahNeural", "ar-SY-AmanyNeural", "ar-SY-LaithNeural", "ar-TN-HediNeural", "ar-TN-ReemNeural", "ar-AE-FatimaNeural", "ar-AE-HamdanNeural", "ar-YE-MaryamNeural", "ar-YE-SalehNeural", "az-AZ-BabekNeural", "az-AZ-BanuNeural", "bn-BD-NabanitaNeural", "bn-BD-PradeepNeural", "bn-IN-BashkarNeural", "bn-IN-TanishaaNeural", "bs-BA-GoranNeural", "bs-BA-VesnaNeural", "bg-BG-BorislavNeural", "bg-BG-KalinaNeural", "my-MM-NilarNeural", "my-MM-ThihaNeural", "ca-ES-EnricNeural", "ca-ES-JoanaNeural", "zh-HK-HiuGaaiNeural", "zh-HK-HiuMaanNeural", "zh-HK-WanLungNeural", "zh-CN-XiaoxiaoNeural", "zh-CN-XiaoyiNeural", "zh-CN-YunjianNeural", "zh-CN-YunxiNeural", "zh-CN-YunxiaNeural", "zh-CN-YunyangNeural", "zh-CN-liaoning-XiaobeiNeural", "zh-TW-HsiaoChenNeural", "zh-TW-YunJheNeural", "zh-TW-HsiaoYuNeural", "zh-CN-shaanxi-XiaoniNeural", "hr-HR-GabrijelaNeural", "hr-HR-SreckoNeural", "cs-CZ-AntoninNeural", "cs-CZ-VlastaNeural", "da-DK-ChristelNeural", "da-DK-JeppeNeural", "nl-BE-ArnaudNeural", "nl-BE-DenaNeural", "nl-NL-ColetteNeural", "nl-NL-FennaNeural", "nl-NL-MaartenNeural", "en-AU-NatashaNeural", "en-AU-WilliamNeural", "en-CA-ClaraNeural", "en-CA-LiamNeural", "en-HK-SamNeural", "en-HK-YanNeural", "en-IN-NeerjaExpressiveNeural", "en-IN-NeerjaNeural", "en-IN-PrabhatNeural", "en-IE-ConnorNeural", "en-IE-EmilyNeural", "en-KE-AsiliaNeural", "en-KE-ChilembaNeural", "en-NZ-MitchellNeural", "en-NZ-MollyNeural", "en-NG-AbeoNeural", "en-NG-EzinneNeural", "en-PH-JamesNeural", "en-PH-RosaNeural", "en-SG-LunaNeural", "en-SG-WayneNeural", "en-ZA-LeahNeural", "en-ZA-LukeNeural", "en-TZ-ElimuNeural", "en-TZ-ImaniNeural", "en-GB-LibbyNeural", "en-GB-MaisieNeural", "en-GB-RyanNeural", "en-GB-SoniaNeural", "en-GB-ThomasNeural", "en-US-AvaMultilingualNeural", "en-US-AndrewMultilingualNeural", "en-US-EmmaMultilingualNeural", "en-US-BrianMultilingualNeural", "en-US-AvaNeural", "en-US-AndrewNeural", "en-US-EmmaNeural", "en-US-BrianNeural", "en-US-AnaNeural", "en-US-AriaNeural", "en-US-ChristopherNeural", "en-US-EricNeural", "en-US-GuyNeural", "en-US-JennyNeural", "en-US-MichelleNeural", "en-US-RogerNeural", "en-US-SteffanNeural", "et-EE-AnuNeural", "et-EE-KertNeural", "fil-PH-AngeloNeural", "fil-PH-BlessicaNeural", "fi-FI-HarriNeural", "fi-FI-NooraNeural", "fr-BE-CharlineNeural", "fr-BE-GerardNeural", "fr-CA-ThierryNeural", "fr-CA-AntoineNeural", "fr-CA-JeanNeural", "fr-CA-SylvieNeural", "fr-FR-VivienneMultilingualNeural", "fr-FR-RemyMultilingualNeural", "fr-FR-DeniseNeural", "fr-FR-EloiseNeural", "fr-FR-HenriNeural", "fr-CH-ArianeNeural", "fr-CH-FabriceNeural", "gl-ES-RoiNeural", "gl-ES-SabelaNeural", "ka-GE-EkaNeural", "ka-GE-GiorgiNeural", "de-AT-IngridNeural", "de-AT-JonasNeural", "de-DE-SeraphinaMultilingualNeural", "de-DE-FlorianMultilingualNeural", "de-DE-AmalaNeural", "de-DE-ConradNeural", "de-DE-KatjaNeural", "de-DE-KillianNeural", "de-CH-JanNeural", "de-CH-LeniNeural", "el-GR-AthinaNeural", "el-GR-NestorasNeural", "gu-IN-DhwaniNeural", "gu-IN-NiranjanNeural", "he-IL-AvriNeural", "he-IL-HilaNeural", "hi-IN-MadhurNeural", "hi-IN-SwaraNeural", "hu-HU-NoemiNeural", "hu-HU-TamasNeural", "is-IS-GudrunNeural", "is-IS-GunnarNeural", "id-ID-ArdiNeural", "id-ID-GadisNeural", "ga-IE-ColmNeural", "ga-IE-OrlaNeural", "it-IT-GiuseppeNeural", "it-IT-DiegoNeural", "it-IT-ElsaNeural", "it-IT-IsabellaNeural", "ja-JP-KeitaNeural", "ja-JP-NanamiNeural", "jv-ID-DimasNeural", "jv-ID-SitiNeural", "kn-IN-GaganNeural", "kn-IN-SapnaNeural", "kk-KZ-AigulNeural", "kk-KZ-DauletNeural", "km-KH-PisethNeural", "km-KH-SreymomNeural", "ko-KR-HyunsuNeural", "ko-KR-InJoonNeural", "ko-KR-SunHiNeural", "lo-LA-ChanthavongNeural", "lo-LA-KeomanyNeural", "lv-LV-EveritaNeural", "lv-LV-NilsNeural", "lt-LT-LeonasNeural", "lt-LT-OnaNeural", "mk-MK-AleksandarNeural", "mk-MK-MarijaNeural", "ms-MY-OsmanNeural", "ms-MY-YasminNeural", "ml-IN-MidhunNeural", "ml-IN-SobhanaNeural", "mt-MT-GraceNeural", "mt-MT-JosephNeural", "mr-IN-AarohiNeural", "mr-IN-ManoharNeural", "mn-MN-BataaNeural", "mn-MN-YesuiNeural", "ne-NP-HemkalaNeural", "ne-NP-SagarNeural", "nb-NO-FinnNeural", "nb-NO-PernilleNeural", "ps-AF-GulNawazNeural", "ps-AF-LatifaNeural", "fa-IR-DilaraNeural", "fa-IR-FaridNeural", "pl-PL-MarekNeural", "pl-PL-ZofiaNeural", "pt-BR-ThalitaNeural", "pt-BR-AntonioNeural", "pt-BR-FranciscaNeural", "pt-PT-DuarteNeural", "pt-PT-RaquelNeural", "ro-RO-AlinaNeural", "ro-RO-EmilNeural", "ru-RU-DmitryNeural", "ru-RU-SvetlanaNeural", "sr-RS-NicholasNeural", "sr-RS-SophieNeural", "si-LK-SameeraNeural", "si-LK-ThiliniNeural", "sk-SK-LukasNeural", "sk-SK-ViktoriaNeural", "sl-SI-PetraNeural", "sl-SI-RokNeural", "so-SO-MuuseNeural", "so-SO-UbaxNeural", "es-AR-ElenaNeural", "es-AR-TomasNeural", "es-BO-MarceloNeural", "es-BO-SofiaNeural", "es-CL-CatalinaNeural", "es-CL-LorenzoNeural", "es-ES-XimenaNeural", "es-CO-GonzaloNeural", "es-CO-SalomeNeural", "es-CR-JuanNeural", "es-CR-MariaNeural", "es-CU-BelkysNeural", "es-CU-ManuelNeural", "es-DO-EmilioNeural", "es-DO-RamonaNeural", "es-EC-AndreaNeural", "es-EC-LuisNeural", "es-SV-LorenaNeural", "es-SV-RodrigoNeural", "es-GQ-JavierNeural", "es-GQ-TeresaNeural", "es-GT-AndresNeural", "es-GT-MartaNeural", "es-HN-CarlosNeural", "es-HN-KarlaNeural", "es-MX-DaliaNeural", "es-MX-JorgeNeural", "es-NI-FedericoNeural", "es-NI-YolandaNeural", "es-PA-MargaritaNeural", "es-PA-RobertoNeural", "es-PY-MarioNeural", "es-PY-TaniaNeural", "es-PE-AlexNeural", "es-PE-CamilaNeural", "es-PR-KarinaNeural", "es-PR-VictorNeural", "es-ES-AlvaroNeural", "es-ES-ElviraNeural", "es-US-AlonsoNeural", "es-US-PalomaNeural", "es-UY-MateoNeural", "es-UY-ValentinaNeural", "es-VE-PaolaNeural", "es-VE-SebastianNeural", "su-ID-JajangNeural", "su-ID-TutiNeural", "sw-KE-RafikiNeural", "sw-KE-ZuriNeural", "sw-TZ-DaudiNeural", "sw-TZ-RehemaNeural", "sv-SE-MattiasNeural", "sv-SE-SofieNeural", "ta-IN-PallaviNeural", "ta-IN-ValluvarNeural", "ta-MY-KaniNeural", "ta-MY-SuryaNeural", "ta-SG-AnbuNeural", "ta-SG-VenbaNeural", "ta-LK-KumarNeural", "ta-LK-SaranyaNeural", "te-IN-MohanNeural", "te-IN-ShrutiNeural", "th-TH-NiwatNeural", "th-TH-PremwadeeNeural", "tr-TR-AhmetNeural", "tr-TR-EmelNeural", "uk-UA-OstapNeural", "uk-UA-PolinaNeural", "ur-IN-GulNeural", "ur-IN-SalmanNeural", "ur-PK-AsadNeural", "ur-PK-UzmaNeural", "uz-UZ-MadinaNeural", "uz-UZ-SardorNeural", "vi-VN-HoaiMyNeural", "vi-VN-NamMinhNeural", "cy-GB-AledNeural", "cy-GB-NiaNeural", "zu-ZA-ThandoNeural", "zu-ZA-ThembaNeural"],
-    "google_tts_voice": ["af", "am", "ar", "bg", "bn", "bs", "ca", "cs", "cy", "da", "de", "el", "en", "es", "et", "eu", "fi", "fr", "fr-CA", "gl", "gu", "ha", "hi", "hr", "hu", "id", "is", "it", "iw", "ja", "jw", "km", "kn", "ko", "la", "lt", "lv", "ml", "mr", "ms", "my", "ne", "nl", "no", "pa", "pl", "pt", "pt-PT", "ro", "ru", "si", "sk", "sq", "sr", "su", "sv", "sw", "ta", "te", "th", "tl", "tr", "uk", "ur", "vi", "yue", "zh-CN", "zh-TW", "zh"],
-    "separator_tab": true,
-    "convert_tab": true,
-    "tts_tab": true,
-    "effects_tab": true,
-    "create_dataset_tab": true,
-    "training_tab": true,
-    "fushion_tab": true,
-    "read_tab": true,
-    "downloads_tab": true,
-    "settings_tab": true,
-    "report_bug_tab": true,
-    "app_port": 7860,
-    "tensorboard_port": 6870,
-    "num_of_restart": 5,
-    "server_name": "0.0.0.0",
-    "app_show_error": true
-}
main/configs/config.py
DELETED
@@ -1,70 +0,0 @@
-import os
-import json
-import torch
-
-version_config_paths = [os.path.join(version, size) for version in ["v1", "v2"] for size in ["32000.json", "40000.json", "44100.json", "48000.json"]]
-
-def singleton(cls):
-    instances = {}
-    def get_instance(*args, **kwargs):
-        if cls not in instances: instances[cls] = cls(*args, **kwargs)
-        return instances[cls]
-    return get_instance
-
-@singleton
-class Config:
-    def __init__(self):
-        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        self.gpu_name = (torch.cuda.get_device_name(int(self.device.split(":")[-1])) if self.device.startswith("cuda") else None)
-        self.translations = self.multi_language()
-        self.json_config = self.load_config_json()
-        self.gpu_mem = None
-        self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
-
-    def multi_language(self):
-        try:
-            with open(os.path.join("main", "configs", "config.json"), "r") as f:
-                configs = json.load(f)
-
-            lang = configs.get("language", "vi-VN")
-            if len([l for l in os.listdir(os.path.join("assets", "languages")) if l.endswith(".json")]) < 1: raise FileNotFoundError("No language packs found")
-
-            if not lang: lang = "vi-VN"
-            if lang not in configs["support_language"]: raise ValueError("Language not supported")
-
-            lang_path = os.path.join("assets", "languages", f"{lang}.json")
-            if not os.path.exists(lang_path): lang_path = os.path.join("assets", "languages", "vi-VN.json")
-
-            with open(lang_path, encoding="utf-8") as f:
-                translations = json.load(f)
-        except json.JSONDecodeError:
-            print(self.translations["empty_json"].format(file=lang))
-            pass
-        return translations
-
-    def load_config_json(self):
-        configs = {}
-        for config_file in version_config_paths:
-            try:
-                with open(os.path.join("main", "configs", config_file), "r") as f:
-                    configs[config_file] = json.load(f)
-            except json.JSONDecodeError:
-                print(self.translations["empty_json"].format(file=config_file))
-                pass
-        return configs
-
-    def device_config(self):
-        if self.device.startswith("cuda"): self.set_cuda_config()
-        elif self.has_mps(): self.device = "mps"
-        else: self.device = "cpu"
-
-        if self.gpu_mem is not None and self.gpu_mem <= 4: return 1, 5, 30, 32
-        return 1, 6, 38, 41
-
-    def set_cuda_config(self):
-        i_device = int(self.device.split(":")[-1])
-        self.gpu_name = torch.cuda.get_device_name(i_device)
-        self.gpu_mem = torch.cuda.get_device_properties(i_device).total_memory // (1024**3)
-
-    def has_mps(self):
-        return torch.backends.mps.is_available()
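Worth noting: the @singleton decorator above caches the first instance, so every later Config() call returns the same object; the translations and JSON configs are therefore loaded only once per process. A minimal sketch of the resulting behavior, assuming the module is importable:

    # Both names point at one shared instance thanks to the singleton decorator
    from main.configs.config import Config

    a = Config()
    b = Config()
    assert a is b  # the decorator returned the cached instance
    print(a.device, a.x_pad, a.x_query)  # device string plus the padding/query window sizes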
main/configs/v1/32000.json
DELETED
@@ -1,46 +0,0 @@
-{
-    "train": {
-        "log_interval": 200,
-        "seed": 1234,
-        "epochs": 20000,
-        "learning_rate": 0.0001,
-        "betas": [0.8, 0.99],
-        "eps": 1e-09,
-        "batch_size": 4,
-        "lr_decay": 0.999875,
-        "segment_size": 12800,
-        "init_lr_ratio": 1,
-        "warmup_epochs": 0,
-        "c_mel": 45,
-        "c_kl": 1.0
-    },
-    "data": {
-        "max_wav_value": 32768.0,
-        "sample_rate": 32000,
-        "filter_length": 1024,
-        "hop_length": 320,
-        "win_length": 1024,
-        "n_mel_channels": 80,
-        "mel_fmin": 0.0,
-        "mel_fmax": null
-    },
-    "model": {
-        "inter_channels": 192,
-        "hidden_channels": 192,
-        "filter_channels": 768,
-        "text_enc_hidden_dim": 256,
-        "n_heads": 2,
-        "n_layers": 6,
-        "kernel_size": 3,
-        "p_dropout": 0,
-        "resblock": "1",
-        "resblock_kernel_sizes": [3, 7, 11],
-        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-        "upsample_rates": [10, 4, 2, 2, 2],
-        "upsample_initial_channel": 512,
-        "upsample_kernel_sizes": [16, 16, 4, 4, 4],
-        "use_spectral_norm": false,
-        "gin_channels": 256,
-        "spk_embed_dim": 109
-    }
-}
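One relation that holds in this and each of the other deleted sample-rate configs below: the product of model.upsample_rates equals data.hop_length (10 * 4 * 2 * 2 * 2 = 320 here), which is what lets the decoder's upsampling stages rebuild exactly one hop of audio samples per feature frame. A quick check of the values in these files:

    # Sanity check: upsample_rates must multiply out to hop_length
    from math import prod
    assert prod([10, 4, 2, 2, 2]) == 320  # v1/32000: hop_length 320
    assert prod([10, 10, 2, 2]) == 400    # 40000 configs: hop_length 400
    assert prod([7, 7, 3, 3]) == 441      # 44100 configs: hop_length 441
    assert prod([12, 10, 2, 2]) == 480    # v2/48000: hop_length 480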
main/configs/v1/40000.json
DELETED
@@ -1,46 +0,0 @@
-{
-    "train": {
-        "log_interval": 200,
-        "seed": 1234,
-        "epochs": 20000,
-        "learning_rate": 0.0001,
-        "betas": [0.8, 0.99],
-        "eps": 1e-09,
-        "batch_size": 4,
-        "lr_decay": 0.999875,
-        "segment_size": 12800,
-        "init_lr_ratio": 1,
-        "warmup_epochs": 0,
-        "c_mel": 45,
-        "c_kl": 1.0
-    },
-    "data": {
-        "max_wav_value": 32768.0,
-        "sample_rate": 40000,
-        "filter_length": 2048,
-        "hop_length": 400,
-        "win_length": 2048,
-        "n_mel_channels": 125,
-        "mel_fmin": 0.0,
-        "mel_fmax": null
-    },
-    "model": {
-        "inter_channels": 192,
-        "hidden_channels": 192,
-        "filter_channels": 768,
-        "text_enc_hidden_dim": 256,
-        "n_heads": 2,
-        "n_layers": 6,
-        "kernel_size": 3,
-        "p_dropout": 0,
-        "resblock": "1",
-        "resblock_kernel_sizes": [3, 7, 11],
-        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-        "upsample_rates": [10, 10, 2, 2],
-        "upsample_initial_channel": 512,
-        "upsample_kernel_sizes": [16, 16, 4, 4],
-        "use_spectral_norm": false,
-        "gin_channels": 256,
-        "spk_embed_dim": 109
-    }
-}
main/configs/v1/44100.json
DELETED
@@ -1,46 +0,0 @@
-{
-    "train": {
-        "log_interval": 200,
-        "seed": 1234,
-        "epochs": 20000,
-        "learning_rate": 0.0001,
-        "betas": [0.8, 0.99],
-        "eps": 1e-09,
-        "batch_size": 4,
-        "lr_decay": 0.999875,
-        "segment_size": 15876,
-        "init_lr_ratio": 1,
-        "warmup_epochs": 0,
-        "c_mel": 45,
-        "c_kl": 1.0
-    },
-    "data": {
-        "max_wav_value": 32768.0,
-        "sample_rate": 44100,
-        "filter_length": 2048,
-        "hop_length": 441,
-        "win_length": 2048,
-        "n_mel_channels": 160,
-        "mel_fmin": 0.0,
-        "mel_fmax": null
-    },
-    "model": {
-        "inter_channels": 192,
-        "hidden_channels": 192,
-        "filter_channels": 768,
-        "text_enc_hidden_dim": 256,
-        "n_heads": 2,
-        "n_layers": 6,
-        "kernel_size": 3,
-        "p_dropout": 0,
-        "resblock": "1",
-        "resblock_kernel_sizes": [3, 7, 11],
-        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-        "upsample_rates": [7, 7, 3, 3],
-        "upsample_initial_channel": 512,
-        "upsample_kernel_sizes": [14, 14, 6, 6],
-        "use_spectral_norm": false,
-        "gin_channels": 256,
-        "spk_embed_dim": 109
-    }
-}
main/configs/v1/48000.json
DELETED
@@ -1,46 +0,0 @@
-{
-    "train": {
-        "log_interval": 200,
-        "seed": 1234,
-        "epochs": 20000,
-        "learning_rate": 0.0001,
-        "betas": [0.8, 0.99],
-        "eps": 1e-09,
-        "batch_size": 4,
-        "lr_decay": 0.999875,
-        "segment_size": 11520,
-        "init_lr_ratio": 1,
-        "warmup_epochs": 0,
-        "c_mel": 45,
-        "c_kl": 1.0
-    },
-    "data": {
-        "max_wav_value": 32768.0,
-        "sample_rate": 48000,
-        "filter_length": 2048,
-        "hop_length": 480,
-        "win_length": 2048,
-        "n_mel_channels": 128,
-        "mel_fmin": 0.0,
-        "mel_fmax": null
-    },
-    "model": {
-        "inter_channels": 192,
-        "hidden_channels": 192,
-        "filter_channels": 768,
-        "text_enc_hidden_dim": 256,
-        "n_heads": 2,
-        "n_layers": 6,
-        "kernel_size": 3,
-        "p_dropout": 0,
-        "resblock": "1",
-        "resblock_kernel_sizes": [3, 7, 11],
-        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-        "upsample_rates": [10, 6, 2, 2, 2],
-        "upsample_initial_channel": 512,
-        "upsample_kernel_sizes": [16, 16, 4, 4, 4],
-        "use_spectral_norm": false,
-        "gin_channels": 256,
-        "spk_embed_dim": 109
-    }
-}
main/configs/v2/32000.json
DELETED
@@ -1,42 +0,0 @@
-{
-    "train": {
-        "log_interval": 200,
-        "seed": 1234,
-        "learning_rate": 0.0001,
-        "betas": [0.8, 0.99],
-        "eps": 1e-09,
-        "lr_decay": 0.999875,
-        "segment_size": 12800,
-        "c_mel": 45,
-        "c_kl": 1.0
-    },
-    "data": {
-        "max_wav_value": 32768.0,
-        "sample_rate": 32000,
-        "filter_length": 1024,
-        "hop_length": 320,
-        "win_length": 1024,
-        "n_mel_channels": 80,
-        "mel_fmin": 0.0,
-        "mel_fmax": null
-    },
-    "model": {
-        "inter_channels": 192,
-        "hidden_channels": 192,
-        "filter_channels": 768,
-        "text_enc_hidden_dim": 768,
-        "n_heads": 2,
-        "n_layers": 6,
-        "kernel_size": 3,
-        "p_dropout": 0,
-        "resblock": "1",
-        "resblock_kernel_sizes": [3, 7, 11],
-        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-        "upsample_rates": [10, 8, 2, 2],
-        "upsample_initial_channel": 512,
-        "upsample_kernel_sizes": [20, 16, 4, 4],
-        "use_spectral_norm": false,
-        "gin_channels": 256,
-        "spk_embed_dim": 109
-    }
-}
main/configs/v2/40000.json
DELETED
@@ -1,42 +0,0 @@
-{
-    "train": {
-        "log_interval": 200,
-        "seed": 1234,
-        "learning_rate": 0.0001,
-        "betas": [0.8, 0.99],
-        "eps": 1e-09,
-        "lr_decay": 0.999875,
-        "segment_size": 12800,
-        "c_mel": 45,
-        "c_kl": 1.0
-    },
-    "data": {
-        "max_wav_value": 32768.0,
-        "sample_rate": 40000,
-        "filter_length": 2048,
-        "hop_length": 400,
-        "win_length": 2048,
-        "n_mel_channels": 125,
-        "mel_fmin": 0.0,
-        "mel_fmax": null
-    },
-    "model": {
-        "inter_channels": 192,
-        "hidden_channels": 192,
-        "filter_channels": 768,
-        "text_enc_hidden_dim": 768,
-        "n_heads": 2,
-        "n_layers": 6,
-        "kernel_size": 3,
-        "p_dropout": 0,
-        "resblock": "1",
-        "resblock_kernel_sizes": [3, 7, 11],
-        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-        "upsample_rates": [10, 10, 2, 2],
-        "upsample_initial_channel": 512,
-        "upsample_kernel_sizes": [16, 16, 4, 4],
-        "use_spectral_norm": false,
-        "gin_channels": 256,
-        "spk_embed_dim": 109
-    }
-}
main/configs/v2/44100.json
DELETED
@@ -1,42 +0,0 @@
-{
-    "train": {
-        "log_interval": 200,
-        "seed": 1234,
-        "learning_rate": 0.0001,
-        "betas": [0.8, 0.99],
-        "eps": 1e-09,
-        "lr_decay": 0.999875,
-        "segment_size": 15876,
-        "c_mel": 45,
-        "c_kl": 1.0
-    },
-    "data": {
-        "max_wav_value": 32768.0,
-        "sample_rate": 44100,
-        "filter_length": 2048,
-        "hop_length": 441,
-        "win_length": 2048,
-        "n_mel_channels": 160,
-        "mel_fmin": 0.0,
-        "mel_fmax": null
-    },
-    "model": {
-        "inter_channels": 192,
-        "hidden_channels": 192,
-        "filter_channels": 768,
-        "text_enc_hidden_dim": 768,
-        "n_heads": 2,
-        "n_layers": 6,
-        "kernel_size": 3,
-        "p_dropout": 0,
-        "resblock": "1",
-        "resblock_kernel_sizes": [3, 7, 11],
-        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-        "upsample_rates": [7, 7, 3, 3],
-        "upsample_initial_channel": 512,
-        "upsample_kernel_sizes": [14, 14, 6, 6],
-        "use_spectral_norm": false,
-        "gin_channels": 256,
-        "spk_embed_dim": 109
-    }
-}
main/configs/v2/48000.json
DELETED
@@ -1,42 +0,0 @@
-{
-    "train": {
-        "log_interval": 200,
-        "seed": 1234,
-        "learning_rate": 0.0001,
-        "betas": [0.8, 0.99],
-        "eps": 1e-09,
-        "lr_decay": 0.999875,
-        "segment_size": 17280,
-        "c_mel": 45,
-        "c_kl": 1.0
-    },
-    "data": {
-        "max_wav_value": 32768.0,
-        "sample_rate": 48000,
-        "filter_length": 2048,
-        "hop_length": 480,
-        "win_length": 2048,
-        "n_mel_channels": 128,
-        "mel_fmin": 0.0,
-        "mel_fmax": null
-    },
-    "model": {
-        "inter_channels": 192,
-        "hidden_channels": 192,
-        "filter_channels": 768,
-        "text_enc_hidden_dim": 768,
-        "n_heads": 2,
-        "n_layers": 6,
-        "kernel_size": 3,
-        "p_dropout": 0,
-        "resblock": "1",
-        "resblock_kernel_sizes": [3, 7, 11],
-        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-        "upsample_rates": [12, 10, 2, 2],
-        "upsample_initial_channel": 512,
-        "upsample_kernel_sizes": [24, 20, 4, 4],
-        "use_spectral_norm": false,
-        "gin_channels": 256,
-        "spk_embed_dim": 109
-    }
-}
main/inference/audio_effects.py
DELETED
@@ -1,170 +0,0 @@
-import os
-import sys
-import librosa
-import argparse
-
-import numpy as np
-import soundfile as sf
-
-from distutils.util import strtobool
-from scipy.signal import butter, filtfilt
-from pedalboard import Pedalboard, Chorus, Distortion, Reverb, PitchShift, Delay, Limiter, Gain, Bitcrush, Clipping, Compressor, Phaser, HighpassFilter
-
-sys.path.append(os.getcwd())
-
-from main.configs.config import Config
-from main.library.utils import pydub_convert
-
-translations = Config().translations
-
-def parse_arguments():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--input_path", type=str, required=True)
-    parser.add_argument("--output_path", type=str, default="./audios/apply_effects.wav")
-    parser.add_argument("--export_format", type=str, default="wav")
-    parser.add_argument("--resample", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--resample_sr", type=int, default=0)
-    parser.add_argument("--chorus", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--chorus_depth", type=float, default=0.5)
-    parser.add_argument("--chorus_rate", type=float, default=1.5)
-    parser.add_argument("--chorus_mix", type=float, default=0.5)
-    parser.add_argument("--chorus_delay", type=int, default=10)
-    parser.add_argument("--chorus_feedback", type=float, default=0)
-    parser.add_argument("--distortion", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--drive_db", type=int, default=20)
-    parser.add_argument("--reverb", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--reverb_room_size", type=float, default=0.5)
-    parser.add_argument("--reverb_damping", type=float, default=0.5)
-    parser.add_argument("--reverb_wet_level", type=float, default=0.33)
-    parser.add_argument("--reverb_dry_level", type=float, default=0.67)
-    parser.add_argument("--reverb_width", type=float, default=1)
-    parser.add_argument("--reverb_freeze_mode", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--pitchshift", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--pitch_shift", type=int, default=0)
-    parser.add_argument("--delay", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--delay_seconds", type=float, default=0.5)
-    parser.add_argument("--delay_feedback", type=float, default=0.5)
-    parser.add_argument("--delay_mix", type=float, default=0.5)
-    parser.add_argument("--compressor", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--compressor_threshold", type=int, default=-20)
-    parser.add_argument("--compressor_ratio", type=float, default=4)
-    parser.add_argument("--compressor_attack_ms", type=float, default=10)
-    parser.add_argument("--compressor_release_ms", type=int, default=200)
-    parser.add_argument("--limiter", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--limiter_threshold", type=int, default=0)
-    parser.add_argument("--limiter_release", type=int, default=100)
-    parser.add_argument("--gain", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--gain_db", type=int, default=0)
-    parser.add_argument("--bitcrush", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--bitcrush_bit_depth", type=int, default=16)
-    parser.add_argument("--clipping", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--clipping_threshold", type=int, default=-10)
-    parser.add_argument("--phaser", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--phaser_rate_hz", type=float, default=0.5)
-    parser.add_argument("--phaser_depth", type=float, default=0.5)
-    parser.add_argument("--phaser_centre_frequency_hz", type=int, default=1000)
-    parser.add_argument("--phaser_feedback", type=float, default=0)
-    parser.add_argument("--phaser_mix", type=float, default=0.5)
-    parser.add_argument("--treble_bass_boost", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--bass_boost_db", type=int, default=0)
-    parser.add_argument("--bass_boost_frequency", type=int, default=100)
-    parser.add_argument("--treble_boost_db", type=int, default=0)
-    parser.add_argument("--treble_boost_frequency", type=int, default=3000)
-    parser.add_argument("--fade_in_out", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--fade_in_duration", type=float, default=2000)
-    parser.add_argument("--fade_out_duration", type=float, default=2000)
-    parser.add_argument("--audio_combination", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--audio_combination_input", type=str)
-
-    return parser.parse_args()
-
-def process_audio(input_path, output_path, resample, resample_sr, chorus_depth, chorus_rate, chorus_mix, chorus_delay, chorus_feedback, distortion_drive, reverb_room_size, reverb_damping, reverb_wet_level, reverb_dry_level, reverb_width, reverb_freeze_mode, pitch_shift, delay_seconds, delay_feedback, delay_mix, compressor_threshold, compressor_ratio, compressor_attack_ms, compressor_release_ms, limiter_threshold, limiter_release, gain_db, bitcrush_bit_depth, clipping_threshold, phaser_rate_hz, phaser_depth, phaser_centre_frequency_hz, phaser_feedback, phaser_mix, bass_boost_db, bass_boost_frequency, treble_boost_db, treble_boost_frequency, fade_in_duration, fade_out_duration, export_format, chorus, distortion, reverb, pitchshift, delay, compressor, limiter, gain, bitcrush, clipping, phaser, treble_bass_boost, fade_in_out, audio_combination, audio_combination_input):
-    def bass_boost(audio, gain_db, frequency, sample_rate):
-        if gain_db >= 1:
-            b, a = butter(4, frequency / (0.5 * sample_rate), btype='low')
-            return filtfilt(b, a, audio) * 10 ** (gain_db / 20)
-        else: return audio
-
-    def treble_boost(audio, gain_db, frequency, sample_rate):
-        if gain_db >= 1:
-            b, a = butter(4, frequency / (0.5 * sample_rate), btype='high')
-            return filtfilt(b, a, audio) * 10 ** (gain_db / 20)
-        else: return audio
-
-    def fade_out_effect(audio, sr, duration=3.0):
-        length = int(duration * sr)
-        end = audio.shape[0]
-
-        if length > end: length = end
-        start = end - length
-
-        audio[start:end] = audio[start:end] * np.linspace(1.0, 0.0, length)
-        return audio
-
-    def fade_in_effect(audio, sr, duration=3.0):
-        length = int(duration * sr)
-        start = 0
-
-        if length > audio.shape[0]: length = audio.shape[0]
-        end = length
-
-        audio[start:end] = audio[start:end] * np.linspace(0.0, 1.0, length)
-        return audio
-
-    if not input_path or not os.path.exists(input_path):
-        print(translations["input_not_valid"])
-        sys.exit(1)
-
-    if not output_path:
-        print(translations["output_not_valid"])
-        sys.exit(1)
-
-    if os.path.exists(output_path): os.remove(output_path)
-
-    try:
-        audio, sample_rate = sf.read(input_path)
-    except Exception as e:
-        raise RuntimeError(translations["errors_loading_audio"].format(e=e))
-
-    try:
-        board = Pedalboard([HighpassFilter()])
-
-        if chorus: board.append(Chorus(depth=chorus_depth, rate_hz=chorus_rate, mix=chorus_mix, centre_delay_ms=chorus_delay, feedback=chorus_feedback))
-        if distortion: board.append(Distortion(drive_db=distortion_drive))
-        if reverb: board.append(Reverb(room_size=reverb_room_size, damping=reverb_damping, wet_level=reverb_wet_level, dry_level=reverb_dry_level, width=reverb_width, freeze_mode=1 if reverb_freeze_mode else 0))
-        if pitchshift: board.append(PitchShift(semitones=pitch_shift))
-        if delay: board.append(Delay(delay_seconds=delay_seconds, feedback=delay_feedback, mix=delay_mix))
-        if compressor: board.append(Compressor(threshold_db=compressor_threshold, ratio=compressor_ratio, attack_ms=compressor_attack_ms, release_ms=compressor_release_ms))
-        if limiter: board.append(Limiter(threshold_db=limiter_threshold, release_ms=limiter_release))
-        if gain: board.append(Gain(gain_db=gain_db))
-        if bitcrush: board.append(Bitcrush(bit_depth=bitcrush_bit_depth))
-        if clipping: board.append(Clipping(threshold_db=clipping_threshold))
-        if phaser: board.append(Phaser(rate_hz=phaser_rate_hz, depth=phaser_depth, centre_frequency_hz=phaser_centre_frequency_hz, feedback=phaser_feedback, mix=phaser_mix))
-
-        processed_audio = board(audio, sample_rate)
-
-        if treble_bass_boost:
-            processed_audio = bass_boost(processed_audio, bass_boost_db, bass_boost_frequency, sample_rate)
-            processed_audio = treble_boost(processed_audio, treble_boost_db, treble_boost_frequency, sample_rate)
-
-        if fade_in_out:
-            processed_audio = fade_in_effect(processed_audio, sample_rate, fade_in_duration)
-            processed_audio = fade_out_effect(processed_audio, sample_rate, fade_out_duration)
-
-        if resample_sr != sample_rate and resample_sr > 0 and resample:
-            target_sr = min([8000, 11025, 12000, 16000, 22050, 24000, 32000, 44100, 48000, 96000], key=lambda x: abs(x - resample_sr))
-            processed_audio = librosa.resample(processed_audio, orig_sr=sample_rate, target_sr=target_sr, res_type="soxr_vhq")
-            sample_rate = target_sr
-
-        sf.write(output_path.replace("wav", export_format), processed_audio, sample_rate, format=export_format)
-
-        if audio_combination:
-            from pydub import AudioSegment
-            pydub_convert(AudioSegment.from_file(audio_combination_input)).overlay(pydub_convert(AudioSegment.from_file(output_path.replace("wav", export_format)))).export(output_path.replace("wav", export_format), format=export_format)
-    except Exception as e:
-        raise RuntimeError(translations["apply_error"].format(e=e))
-    return output_path
-
-if __name__ == "__main__":
-    args = parse_arguments()
-    process_audio(input_path=args.input_path, output_path=args.output_path, resample=args.resample, resample_sr=args.resample_sr, chorus_depth=args.chorus_depth, chorus_rate=args.chorus_rate, chorus_mix=args.chorus_mix, chorus_delay=args.chorus_delay, chorus_feedback=args.chorus_feedback, distortion_drive=args.drive_db, reverb_room_size=args.reverb_room_size, reverb_damping=args.reverb_damping, reverb_wet_level=args.reverb_wet_level, reverb_dry_level=args.reverb_dry_level, reverb_width=args.reverb_width, reverb_freeze_mode=args.reverb_freeze_mode, pitch_shift=args.pitch_shift, delay_seconds=args.delay_seconds, delay_feedback=args.delay_feedback, delay_mix=args.delay_mix, compressor_threshold=args.compressor_threshold, compressor_ratio=args.compressor_ratio, compressor_attack_ms=args.compressor_attack_ms, compressor_release_ms=args.compressor_release_ms, limiter_threshold=args.limiter_threshold, limiter_release=args.limiter_release, gain_db=args.gain_db, bitcrush_bit_depth=args.bitcrush_bit_depth, clipping_threshold=args.clipping_threshold, phaser_rate_hz=args.phaser_rate_hz, phaser_depth=args.phaser_depth, phaser_centre_frequency_hz=args.phaser_centre_frequency_hz, phaser_feedback=args.phaser_feedback, phaser_mix=args.phaser_mix, bass_boost_db=args.bass_boost_db, bass_boost_frequency=args.bass_boost_frequency, treble_boost_db=args.treble_boost_db, treble_boost_frequency=args.treble_boost_frequency, fade_in_duration=args.fade_in_duration, fade_out_duration=args.fade_out_duration, export_format=args.export_format, chorus=args.chorus, distortion=args.distortion, reverb=args.reverb, pitchshift=args.pitchshift, delay=args.delay, compressor=args.compressor, limiter=args.limiter, gain=args.gain, bitcrush=args.bitcrush, clipping=args.clipping, phaser=args.phaser, treble_bass_boost=args.treble_bass_boost, fade_in_out=args.fade_in_out, audio_combination=args.audio_combination, audio_combination_input=args.audio_combination_input)
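The core pattern in the script above is assembling a pedalboard.Pedalboard from a list of effect processors and then calling the board directly on the audio buffer; a minimal self-contained sketch of that pattern (the file names are hypothetical):

    # Minimal Pedalboard chain in the same style as process_audio above
    import soundfile as sf
    from pedalboard import Pedalboard, HighpassFilter, Reverb, Gain

    audio, sr = sf.read("input.wav")  # hypothetical input file
    board = Pedalboard([HighpassFilter(), Reverb(room_size=0.5), Gain(gain_db=3)])
    sf.write("output.wav", board(audio, sr), sr)  # the board is callable: board(audio, sample_rate)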
main/inference/convert.py
DELETED
@@ -1,650 +0,0 @@
|
|
1 |
-
import re
|
2 |
-
import os
|
3 |
-
import sys
|
4 |
-
import time
|
5 |
-
import faiss
|
6 |
-
import torch
|
7 |
-
import shutil
|
8 |
-
import librosa
|
9 |
-
import logging
|
10 |
-
import argparse
|
11 |
-
import warnings
|
12 |
-
import parselmouth
|
13 |
-
import onnxruntime
|
14 |
-
import logging.handlers
|
15 |
-
|
16 |
-
import numpy as np
|
17 |
-
import soundfile as sf
|
18 |
-
import torch.nn.functional as F
|
19 |
-
|
20 |
-
from tqdm import tqdm
|
21 |
-
from scipy import signal
|
22 |
-
from distutils.util import strtobool
|
23 |
-
from fairseq import checkpoint_utils
|
24 |
-
|
25 |
-
warnings.filterwarnings("ignore")
|
26 |
-
sys.path.append(os.getcwd())
|
27 |
-
|
28 |
-
from main.configs.config import Config
|
29 |
-
from main.library.predictors.FCPE import FCPE
|
30 |
-
from main.library.predictors.RMVPE import RMVPE
|
31 |
-
from main.library.predictors.WORLD import PYWORLD
|
32 |
-
from main.library.algorithm.synthesizers import Synthesizer
|
33 |
-
from main.library.predictors.CREPE import predict, mean, median
|
34 |
-
from main.library.utils import check_predictors, check_embedders, load_audio, process_audio, merge_audio
|
35 |
-
|
36 |
-
bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=16000)
|
37 |
-
config = Config()
|
38 |
-
translations = config.translations
|
39 |
-
logger = logging.getLogger(__name__)
|
40 |
-
logger.propagate = False
|
41 |
-
|
42 |
-
for l in ["torch", "faiss", "httpx", "fairseq", "httpcore", "faiss.loader", "numba.core", "urllib3"]:
|
43 |
-
logging.getLogger(l).setLevel(logging.ERROR)
|
44 |
-
|
45 |
-
if logger.hasHandlers(): logger.handlers.clear()
|
46 |
-
else:
|
47 |
-
console_handler = logging.StreamHandler()
|
48 |
-
console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
49 |
-
console_handler.setFormatter(console_formatter)
|
50 |
-
console_handler.setLevel(logging.INFO)
|
51 |
-
file_handler = logging.handlers.RotatingFileHandler(os.path.join("assets", "logs", "convert.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
|
52 |
-
file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
53 |
-
file_handler.setFormatter(file_formatter)
|
54 |
-
file_handler.setLevel(logging.DEBUG)
|
55 |
-
logger.addHandler(console_handler)
|
56 |
-
logger.addHandler(file_handler)
|
57 |
-
logger.setLevel(logging.DEBUG)
|
58 |
-
|
59 |
-
def parse_arguments():
|
60 |
-
parser = argparse.ArgumentParser()
|
61 |
-
parser.add_argument("--pitch", type=int, default=0)
|
62 |
-
parser.add_argument("--filter_radius", type=int, default=3)
|
63 |
-
parser.add_argument("--index_rate", type=float, default=0.5)
|
64 |
-
parser.add_argument("--volume_envelope", type=float, default=1)
|
65 |
-
parser.add_argument("--protect", type=float, default=0.33)
|
66 |
-
parser.add_argument("--hop_length", type=int, default=64)
|
67 |
-
parser.add_argument("--f0_method", type=str, default="rmvpe")
|
68 |
-
parser.add_argument("--embedder_model", type=str, default="contentvec_base")
|
69 |
-
parser.add_argument("--input_path", type=str, required=True)
|
70 |
-
parser.add_argument("--output_path", type=str, default="./audios/output.wav")
|
71 |
-
parser.add_argument("--export_format", type=str, default="wav")
|
72 |
-
parser.add_argument("--pth_path", type=str, required=True)
|
73 |
-
parser.add_argument("--index_path", type=str)
|
74 |
-
parser.add_argument("--f0_autotune", type=lambda x: bool(strtobool(x)), default=False)
|
75 |
-
parser.add_argument("--f0_autotune_strength", type=float, default=1)
|
76 |
-
parser.add_argument("--clean_audio", type=lambda x: bool(strtobool(x)), default=False)
|
77 |
-
parser.add_argument("--clean_strength", type=float, default=0.7)
|
78 |
-
parser.add_argument("--resample_sr", type=int, default=0)
|
79 |
-
parser.add_argument("--split_audio", type=lambda x: bool(strtobool(x)), default=False)
|
80 |
-
parser.add_argument("--checkpointing", type=lambda x: bool(strtobool(x)), default=False)
|
81 |
-
|
82 |
-
return parser.parse_args()
|
83 |
-
|
84 |
-
def main():
|
85 |
-
args = parse_arguments()
|
86 |
-
pitch, filter_radius, index_rate, volume_envelope, protect, hop_length, f0_method, input_path, output_path, pth_path, index_path, f0_autotune, f0_autotune_strength, clean_audio, clean_strength, export_format, embedder_model, resample_sr, split_audio, checkpointing = args.pitch, args.filter_radius, args.index_rate, args.volume_envelope,args.protect, args.hop_length, args.f0_method, args.input_path, args.output_path, args.pth_path, args.index_path, args.f0_autotune, args.f0_autotune_strength, args.clean_audio, args.clean_strength, args.export_format, args.embedder_model, args.resample_sr, args.split_audio, args.checkpointing
|
87 |
-
|
88 |
-
log_data = {translations['pitch']: pitch, translations['filter_radius']: filter_radius, translations['index_strength']: index_rate, translations['volume_envelope']: volume_envelope, translations['protect']: protect, "Hop length": hop_length, translations['f0_method']: f0_method, translations['audio_path']: input_path, translations['output_path']: output_path.replace('wav', export_format), translations['model_path']: pth_path, translations['indexpath']: index_path, translations['autotune']: f0_autotune, translations['clear_audio']: clean_audio, translations['export_format']: export_format, translations['hubert_model']: embedder_model, translations['split_audio']: split_audio, translations['memory_efficient_training']: checkpointing}
|
89 |
-
|
90 |
-
if clean_audio: log_data[translations['clean_strength']] = clean_strength
|
91 |
-
if resample_sr != 0: log_data[translations['sample_rate']] = resample_sr
|
92 |
-
if f0_autotune: log_data[translations['autotune_rate_info']] = f0_autotune_strength
|
93 |
-
|
94 |
-
for key, value in log_data.items():
|
95 |
-
logger.debug(f"{key}: {value}")
|
96 |
-
|
97 |
-
check_predictors(f0_method)
|
98 |
-
check_embedders(embedder_model)
|
99 |
-
|
100 |
-
run_convert_script(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, input_path=input_path, output_path=output_path, pth_path=pth_path, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, embedder_model=embedder_model, resample_sr=resample_sr, split_audio=split_audio, checkpointing=checkpointing)
|
101 |
-
|
102 |
-
def run_batch_convert(params):
|
103 |
-
path, audio_temp, export_format, cut_files, pitch, filter_radius, index_rate, volume_envelope, protect, hop_length, f0_method, pth_path, index_path, f0_autotune, f0_autotune_strength, clean_audio, clean_strength, embedder_model, resample_sr, checkpointing = params["path"], params["audio_temp"], params["export_format"], params["cut_files"], params["pitch"], params["filter_radius"], params["index_rate"], params["volume_envelope"], params["protect"], params["hop_length"], params["f0_method"], params["pth_path"], params["index_path"], params["f0_autotune"], params["f0_autotune_strength"], params["clean_audio"], params["clean_strength"], params["embedder_model"], params["resample_sr"], params["checkpointing"]
|
104 |
-
|
105 |
-
segment_output_path = os.path.join(audio_temp, f"output_{cut_files.index(path)}.{export_format}")
|
106 |
-
if os.path.exists(segment_output_path): os.remove(segment_output_path)
|
107 |
-
|
108 |
-
VoiceConverter().convert_audio(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, audio_input_path=path, audio_output_path=segment_output_path, model_path=pth_path, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, embedder_model=embedder_model, resample_sr=resample_sr, checkpointing=checkpointing)
|
109 |
-
os.remove(path)
|
110 |
-
|
111 |
-
if os.path.exists(segment_output_path): return segment_output_path
|
112 |
-
else:
|
113 |
-
logger.warning(f"{translations['not_found_convert_file']}: {segment_output_path}")
|
114 |
-
sys.exit(1)
|
115 |
-
|
116 |
-
def run_convert_script(pitch, filter_radius, index_rate, volume_envelope, protect, hop_length, f0_method, input_path, output_path, pth_path, index_path, f0_autotune, f0_autotune_strength, clean_audio, clean_strength, export_format, embedder_model, resample_sr, split_audio, checkpointing):
|
117 |
-
cvt = VoiceConverter()
|
118 |
-
start_time = time.time()
|
119 |
-
|
120 |
-
pid_path = os.path.join("assets", "convert_pid.txt")
|
121 |
-
with open(pid_path, "w") as pid_file:
|
122 |
-
pid_file.write(str(os.getpid()))
|
123 |
-
|
124 |
-
if not pth_path or not os.path.exists(pth_path) or os.path.isdir(pth_path) or not pth_path.endswith(".pth"):
|
125 |
-
logger.warning(translations["provide_file"].format(filename=translations["model"]))
|
126 |
-
sys.exit(1)
|
127 |
-
|
128 |
-
output_dir = os.path.dirname(output_path) or output_path
|
129 |
-
if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True)
|
130 |
-
|
131 |
-
processed_segments = []
|
132 |
-
audio_temp = os.path.join("audios_temp")
|
133 |
-
if not os.path.exists(audio_temp) and split_audio: os.makedirs(audio_temp, exist_ok=True)
|
134 |
-
|
135 |
-
if os.path.isdir(input_path):
|
136 |
-
try:
|
137 |
-
logger.info(translations["convert_batch"])
|
138 |
-
audio_files = [f for f in os.listdir(input_path) if f.lower().endswith(("wav", "mp3", "flac", "ogg", "opus", "m4a", "mp4", "aac", "alac", "wma", "aiff", "webm", "ac3"))]
|
139 |
-
|
140 |
-
if not audio_files:
|
141 |
-
logger.warning(translations["not_found_audio"])
|
142 |
-
sys.exit(1)
|
143 |
-
|
144 |
-
logger.info(translations["found_audio"].format(audio_files=len(audio_files)))
|
145 |
-
|
146 |
-
for audio in audio_files:
|
147 |
-
audio_path = os.path.join(input_path, audio)
|
148 |
-
output_audio = os.path.join(input_path, os.path.splitext(audio)[0] + f"_output.{export_format}")
|
149 |
-
|
150 |
-
if split_audio:
|
151 |
-
try:
|
152 |
-
cut_files, time_stamps = process_audio(logger, audio_path, audio_temp)
|
153 |
-
-                        params_list = [{"path": path, "audio_temp": audio_temp, "export_format": export_format, "cut_files": cut_files, "pitch": pitch, "filter_radius": filter_radius, "index_rate": index_rate, "volume_envelope": volume_envelope, "protect": protect, "hop_length": hop_length, "f0_method": f0_method, "pth_path": pth_path, "index_path": index_path, "f0_autotune": f0_autotune, "f0_autotune_strength": f0_autotune_strength, "clean_audio": clean_audio, "clean_strength": clean_strength, "embedder_model": embedder_model, "resample_sr": resample_sr, "checkpointing": checkpointing} for path in cut_files]
-
-                        with tqdm(total=len(params_list), desc=translations["convert_audio"], ncols=100, unit="a") as pbar:
-                            for params in params_list:
-                                results = run_batch_convert(params)
-                                processed_segments.append(results)
-                                pbar.update(1)
-                                logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))
-
-                        merge_audio(processed_segments, time_stamps, audio_path, output_audio, export_format)
-                    except Exception as e:
-                        logger.error(translations["error_convert_batch"].format(e=e))
-                    finally:
-                        if os.path.exists(audio_temp): shutil.rmtree(audio_temp, ignore_errors=True)
-                else:
-                    try:
-                        logger.info(f"{translations['convert_audio']} '{audio_path}'...")
-                        if os.path.exists(output_audio): os.remove(output_audio)
-
-                        with tqdm(total=1, desc=translations["convert_audio"], ncols=100, unit="a") as pbar:
-                            cvt.convert_audio(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, audio_input_path=audio_path, audio_output_path=output_audio, model_path=pth_path, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, embedder_model=embedder_model, resample_sr=resample_sr, checkpointing=checkpointing)
-                            pbar.update(1)
-                            logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))
-                    except Exception as e:
-                        logger.error(translations["error_convert"].format(e=e))
-
-            elapsed_time = time.time() - start_time
-            logger.info(translations["convert_batch_success"].format(elapsed_time=f"{elapsed_time:.2f}", output_path=output_path.replace('wav', export_format)))
-        except Exception as e:
-            logger.error(translations["error_convert_batch_2"].format(e=e))
-    else:
-        logger.info(f"{translations['convert_audio']} '{input_path}'...")
-
-        if not os.path.exists(input_path):
-            logger.warning(translations["not_found_audio"])
-            sys.exit(1)
-
-        if os.path.isdir(output_path): output_path = os.path.join(output_path, f"output.{export_format}")
-        if os.path.exists(output_path): os.remove(output_path)
-
-        if split_audio:
-            try:
-                cut_files, time_stamps = process_audio(logger, input_path, audio_temp)
-                params_list = [{"path": path, "audio_temp": audio_temp, "export_format": export_format, "cut_files": cut_files, "pitch": pitch, "filter_radius": filter_radius, "index_rate": index_rate, "volume_envelope": volume_envelope, "protect": protect, "hop_length": hop_length, "f0_method": f0_method, "pth_path": pth_path, "index_path": index_path, "f0_autotune": f0_autotune, "f0_autotune_strength": f0_autotune_strength, "clean_audio": clean_audio, "clean_strength": clean_strength, "embedder_model": embedder_model, "resample_sr": resample_sr, "checkpointing": checkpointing} for path in cut_files]
-
-                with tqdm(total=len(params_list), desc=translations["convert_audio"], ncols=100, unit="a") as pbar:
-                    for params in params_list:
-                        results = run_batch_convert(params)
-                        processed_segments.append(results)
-                        pbar.update(1)
-                        logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))
-
-                merge_audio(processed_segments, time_stamps, input_path, output_path.replace("wav", export_format), export_format)
-            except Exception as e:
-                logger.error(translations["error_convert_batch"].format(e=e))
-            finally:
-                if os.path.exists(audio_temp): shutil.rmtree(audio_temp, ignore_errors=True)
-        else:
-            try:
-                with tqdm(total=1, desc=translations["convert_audio"], ncols=100, unit="a") as pbar:
-                    cvt.convert_audio(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, audio_input_path=input_path, audio_output_path=output_path, model_path=pth_path, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, embedder_model=embedder_model, resample_sr=resample_sr, checkpointing=checkpointing)
-                    pbar.update(1)
-
-                    logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))
-            except Exception as e:
-                logger.error(translations["error_convert"].format(e=e))
-
-        if os.path.exists(pid_path): os.remove(pid_path)
-        elapsed_time = time.time() - start_time
-        logger.info(translations["convert_audio_success"].format(input_path=input_path, elapsed_time=f"{elapsed_time:.2f}", output_path=output_path.replace('wav', export_format)))
-
-def change_rms(source_audio, source_rate, target_audio, target_rate, rate):
-    rms2 = F.interpolate(torch.from_numpy(librosa.feature.rms(y=target_audio, frame_length=target_rate // 2 * 2, hop_length=target_rate // 2)).float().unsqueeze(0), size=target_audio.shape[0], mode="linear").squeeze()
-    return (target_audio * (torch.pow(F.interpolate(torch.from_numpy(librosa.feature.rms(y=source_audio, frame_length=source_rate // 2 * 2, hop_length=source_rate // 2)).float().unsqueeze(0), size=target_audio.shape[0], mode="linear").squeeze(), 1 - rate) * torch.pow(torch.maximum(rms2, torch.zeros_like(rms2) + 1e-6), rate - 1)).numpy())
-
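Note: change_rms above transfers the input's loudness contour onto the converted audio: both signals' frame-wise RMS curves are interpolated to per-sample resolution and the output is scaled by rms_source^(1 - rate) * rms_target^(rate - 1), so rate=1 keeps the converted envelope and rate=0 copies the source's. A minimal self-contained sketch of the same idea on synthetic 16 kHz mono signals (all names here are illustrative, not part of the module):

    import numpy as np
    import torch
    import torch.nn.functional as F
    import librosa

    def rms_envelope(audio, sr, length):
        # frame-wise RMS, linearly interpolated back to one value per sample
        rms = librosa.feature.rms(y=audio, frame_length=sr // 2 * 2, hop_length=sr // 2)
        return F.interpolate(torch.from_numpy(rms).float().unsqueeze(0), size=length, mode="linear").squeeze()

    def mix_envelopes(source, sr_src, target, sr_tgt, rate):
        rms1 = rms_envelope(source, sr_src, target.shape[0])
        rms2 = rms_envelope(target, sr_tgt, target.shape[0]).clamp(min=1e-6)  # avoid 0 ** negative
        gain = torch.pow(rms1, 1 - rate) * torch.pow(rms2, rate - 1)
        return (torch.from_numpy(target).float() * gain).numpy()

    loud = (0.8 * np.sin(np.linspace(0, 100, 16000))).astype(np.float32)
    quiet = (0.2 * np.sin(np.linspace(0, 100, 16000))).astype(np.float32)
    out = mix_envelopes(loud, 16000, quiet, 16000, rate=0.5)  # quiet signal pushed toward the loud envelope
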
-class Autotune:
-    def __init__(self, ref_freqs):
-        self.ref_freqs = ref_freqs
-        self.note_dict = self.ref_freqs
-
-    def autotune_f0(self, f0, f0_autotune_strength):
-        autotuned_f0 = np.zeros_like(f0)
-
-        for i, freq in enumerate(f0):
-            autotuned_f0[i] = freq + (min(self.note_dict, key=lambda x: abs(x - freq)) - freq) * f0_autotune_strength
-
-        return autotuned_f0
-
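Note: autotune_f0 moves each frame toward the closest entry of note_dict by f0_autotune_strength, so 1.0 snaps hard to the scale and smaller values only lean toward it. A toy check of the rule, using a slice of the ref_freqs table defined in VC.__init__ below:

    refs = [392.00, 415.30, 440.00, 466.16]            # a few of VC.ref_freqs
    freq, strength = 450.0, 0.8
    nearest = min(refs, key=lambda x: abs(x - freq))   # 440.0
    snapped = freq + (nearest - freq) * strength       # 450 - 8 = 442.0
    assert abs(snapped - 442.0) < 1e-9
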
-class VC:
-    def __init__(self, tgt_sr, config):
-        self.x_pad = config.x_pad
-        self.x_query = config.x_query
-        self.x_center = config.x_center
-        self.x_max = config.x_max
-        self.sample_rate = 16000
-        self.window = 160
-        self.t_pad = self.sample_rate * self.x_pad
-        self.t_pad_tgt = tgt_sr * self.x_pad
-        self.t_pad2 = self.t_pad * 2
-        self.t_query = self.sample_rate * self.x_query
-        self.t_center = self.sample_rate * self.x_center
-        self.t_max = self.sample_rate * self.x_max
-        self.time_step = self.window / self.sample_rate * 1000
-        self.f0_min = 50
-        self.f0_max = 1100
-        self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700)
-        self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700)
-        self.device = config.device
-        self.ref_freqs = [49.00, 51.91, 55.00, 58.27, 61.74, 65.41, 69.30, 73.42, 77.78, 82.41, 87.31, 92.50, 98.00, 103.83, 110.00, 116.54, 123.47, 130.81, 138.59, 146.83, 155.56, 164.81, 174.61, 185.00, 196.00, 207.65, 220.00, 233.08, 246.94, 261.63, 277.18, 293.66, 311.13, 329.63, 349.23, 369.99, 392.00, 415.30, 440.00, 466.16, 493.88, 523.25, 554.37, 587.33, 622.25, 659.25, 698.46, 739.99, 783.99, 830.61, 880.00, 932.33, 987.77, 1046.50]
-        self.autotune = Autotune(self.ref_freqs)
-        self.note_dict = self.autotune.note_dict
-
-    def get_providers(self):
-        ort_providers = onnxruntime.get_available_providers()
-
-        if "CUDAExecutionProvider" in ort_providers: providers = ["CUDAExecutionProvider"]
-        elif "CoreMLExecutionProvider" in ort_providers: providers = ["CoreMLExecutionProvider"]
-        else: providers = ["CPUExecutionProvider"]
-
-        return providers
-
-    def get_f0_pm(self, x, p_len):
-        f0 = (parselmouth.Sound(x, self.sample_rate).to_pitch_ac(time_step=self.window / self.sample_rate * 1000 / 1000, voicing_threshold=0.6, pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array["frequency"])
-        pad_size = (p_len - len(f0) + 1) // 2
-
-        if pad_size > 0 or p_len - len(f0) - pad_size > 0: f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
-        return f0
-
-    def get_f0_mangio_crepe(self, x, p_len, hop_length, model="full", onnx=False):
-        providers = self.get_providers() if onnx else None
-
-        x = x.astype(np.float32)
-        x /= np.quantile(np.abs(x), 0.999)
-
-        audio = torch.unsqueeze(torch.from_numpy(x).to(self.device, copy=True), dim=0)
-        if audio.ndim == 2 and audio.shape[0] > 1: audio = torch.mean(audio, dim=0, keepdim=True).detach()
-
-        p_len = p_len or x.shape[0] // hop_length
-        source = np.array(predict(audio.detach(), self.sample_rate, hop_length, self.f0_min, self.f0_max, model, batch_size=hop_length * 2, device=self.device, pad=True, providers=providers, onnx=onnx).squeeze(0).cpu().float().numpy())
-        source[source < 0.001] = np.nan
-        return np.nan_to_num(np.interp(np.arange(0, len(source) * p_len, len(source)) / p_len, np.arange(0, len(source)), source))
-
-    def get_f0_crepe(self, x, model="full", onnx=False):
-        providers = self.get_providers() if onnx else None
-
-        f0, pd = predict(torch.tensor(np.copy(x))[None].float(), self.sample_rate, self.window, self.f0_min, self.f0_max, model, batch_size=512, device=self.device, return_periodicity=True, providers=providers, onnx=onnx)
-        f0, pd = mean(f0, 3), median(pd, 3)
-        f0[pd < 0.1] = 0
-
-        return f0[0].cpu().numpy()
-
-    def get_f0_fcpe(self, x, p_len, hop_length, onnx=False, legacy=False):
-        providers = self.get_providers() if onnx else None
-
-        model_fcpe = FCPE(os.path.join("assets", "models", "predictors", "fcpe" + (".onnx" if onnx else ".pt")), hop_length=int(hop_length), f0_min=int(self.f0_min), f0_max=int(self.f0_max), dtype=torch.float32, device=self.device, sample_rate=self.sample_rate, threshold=0.03, providers=providers, onnx=onnx) if legacy else FCPE(os.path.join("assets", "models", "predictors", "fcpe" + (".onnx" if onnx else ".pt")), hop_length=self.window, f0_min=0, f0_max=8000, dtype=torch.float32, device=self.device, sample_rate=self.sample_rate, threshold=0.006, providers=providers, onnx=onnx)
-        f0 = model_fcpe.compute_f0(x, p_len=p_len)
-
-        del model_fcpe
-        return f0
-
-    def get_f0_rmvpe(self, x, legacy=False, onnx=False):
-        providers = self.get_providers() if onnx else None
-
-        rmvpe_model = RMVPE(os.path.join("assets", "models", "predictors", "rmvpe" + (".onnx" if onnx else ".pt")), device=self.device, onnx=onnx, providers=providers)
-        f0 = rmvpe_model.infer_from_audio_with_pitch(x, thred=0.03, f0_min=self.f0_min, f0_max=self.f0_max) if legacy else rmvpe_model.infer_from_audio(x, thred=0.03)
-
-        del rmvpe_model
-        return f0
-
-    def get_f0_pyworld(self, x, filter_radius, model="harvest"):
-        pw = PYWORLD()
-
-        if model == "harvest": f0, t = pw.harvest(x.astype(np.double), fs=self.sample_rate, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=10)
-        elif model == "dio": f0, t = pw.dio(x.astype(np.double), fs=self.sample_rate, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=10)
-        else: raise ValueError(translations["method_not_valid"])
-
-        f0 = pw.stonemask(x.astype(np.double), self.sample_rate, t, f0)
-
-        if filter_radius > 2 or model == "dio": f0 = signal.medfilt(f0, 3)
-        return f0
-
-    def get_f0_yin(self, x, hop_length, p_len):
-        source = np.array(librosa.yin(x.astype(np.double), sr=self.sample_rate, fmin=self.f0_min, fmax=self.f0_max, hop_length=hop_length))
-        source[source < 0.001] = np.nan
-
-        return np.nan_to_num(np.interp(np.arange(0, len(source) * p_len, len(source)) / p_len, np.arange(0, len(source)), source))
-
-    def get_f0_pyin(self, x, hop_length, p_len):
-        f0, _, _ = librosa.pyin(x.astype(np.double), fmin=self.f0_min, fmax=self.f0_max, sr=self.sample_rate, hop_length=hop_length)
-        source = np.array(f0)
-        source[source < 0.001] = np.nan
-
-        return np.nan_to_num(np.interp(np.arange(0, len(source) * p_len, len(source)) / p_len, np.arange(0, len(source)), source))
-
-    def get_f0_hybrid(self, methods_str, x, p_len, hop_length, filter_radius):
-        methods_str = re.search("hybrid\[(.+)\]", methods_str)
-        if methods_str: methods = [method.strip() for method in methods_str.group(1).split("+")]
-
-        f0_computation_stack, resampled_stack = [], []
-        logger.debug(translations["hybrid_methods"].format(methods=methods))
-
-        x = x.astype(np.float32)
-        x /= np.quantile(np.abs(x), 0.999)
-
-        for method in methods:
-            f0 = None
-
-            if method == "pm": f0 = self.get_f0_pm(x, p_len)
-            elif method == "dio": f0 = self.get_f0_pyworld(x, filter_radius, "dio")
-            elif method == "mangio-crepe-tiny": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "tiny")
-            elif method == "mangio-crepe-tiny-onnx": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "tiny", onnx=True)
-            elif method == "mangio-crepe-small": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "small")
-            elif method == "mangio-crepe-small-onnx": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "small", onnx=True)
-            elif method == "mangio-crepe-medium": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "medium")
-            elif method == "mangio-crepe-medium-onnx": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "medium", onnx=True)
-            elif method == "mangio-crepe-large": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "large")
-            elif method == "mangio-crepe-large-onnx": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "large", onnx=True)
-            elif method == "mangio-crepe-full": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "full")
-            elif method == "mangio-crepe-full-onnx": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "full", onnx=True)
-            elif method == "crepe-tiny": f0 = self.get_f0_crepe(x, "tiny")
-            elif method == "crepe-tiny-onnx": f0 = self.get_f0_crepe(x, "tiny", onnx=True)
-            elif method == "crepe-small": f0 = self.get_f0_crepe(x, "small")
-            elif method == "crepe-small-onnx": f0 = self.get_f0_crepe(x, "small", onnx=True)
-            elif method == "crepe-medium": f0 = self.get_f0_crepe(x, "medium")
-            elif method == "crepe-medium-onnx": f0 = self.get_f0_crepe(x, "medium", onnx=True)
-            elif method == "crepe-large": f0 = self.get_f0_crepe(x, "large")
-            elif method == "crepe-large-onnx": f0 = self.get_f0_crepe(x, "large", onnx=True)
-            elif method == "crepe-full": f0 = self.get_f0_crepe(x, "full")
-            elif method == "crepe-full-onnx": f0 = self.get_f0_crepe(x, "full", onnx=True)
-            elif method == "fcpe": f0 = self.get_f0_fcpe(x, p_len, int(hop_length))
-            elif method == "fcpe-onnx": f0 = self.get_f0_fcpe(x, p_len, int(hop_length), onnx=True)
-            elif method == "fcpe-legacy": f0 = self.get_f0_fcpe(x, p_len, int(hop_length), legacy=True)
-            elif method == "fcpe-legacy-onnx": f0 = self.get_f0_fcpe(x, p_len, int(hop_length), onnx=True, legacy=True)
-            elif method == "rmvpe": f0 = self.get_f0_rmvpe(x)
-            elif method == "rmvpe-onnx": f0 = self.get_f0_rmvpe(x, onnx=True)
-            elif method == "rmvpe-legacy": f0 = self.get_f0_rmvpe(x, legacy=True)
-            elif method == "rmvpe-legacy-onnx": f0 = self.get_f0_rmvpe(x, legacy=True, onnx=True)
-            elif method == "harvest": f0 = self.get_f0_pyworld(x, filter_radius, "harvest")
-            elif method == "yin": f0 = self.get_f0_yin(x, int(hop_length), p_len)
-            elif method == "pyin": f0 = self.get_f0_pyin(x, int(hop_length), p_len)
-            else: raise ValueError(translations["method_not_valid"])
-
-            f0_computation_stack.append(f0)
-
-        for f0 in f0_computation_stack:
-            resampled_stack.append(np.interp(np.linspace(0, len(f0), p_len), np.arange(len(f0)), f0))
-
-        return resampled_stack[0] if len(resampled_stack) == 1 else np.nanmedian(np.vstack(resampled_stack), axis=0)
-
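Note: get_f0_hybrid merges several estimators: every track produced inside hybrid[...] is resampled to the common length p_len, then the frame-wise nanmedian is taken, so a frame that one method gets badly wrong is outvoted by the rest. A toy illustration with three hand-made tracks:

    import numpy as np

    p_len = 6
    tracks = [np.array([100.0, 110.0, 120.0, 130.0, 128.0, 126.0]),
              np.array([102.0, 108.0, 121.0, 520.0, 131.0, 129.0]),  # one bad frame
              np.array([99.0, 111.0, 119.0, 128.0, 132.0, 127.0])]
    resampled = [np.interp(np.linspace(0, len(t), p_len), np.arange(len(t)), t) for t in tracks]
    merged = np.nanmedian(np.vstack(resampled), axis=0)              # the 520 Hz spike is voted down
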
-    def get_f0(self, x, p_len, pitch, f0_method, filter_radius, hop_length, f0_autotune, f0_autotune_strength):
-        if f0_method == "pm": f0 = self.get_f0_pm(x, p_len)
-        elif f0_method == "dio": f0 = self.get_f0_pyworld(x, filter_radius, "dio")
-        elif f0_method == "mangio-crepe-tiny": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "tiny")
-        elif f0_method == "mangio-crepe-tiny-onnx": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "tiny", onnx=True)
-        elif f0_method == "mangio-crepe-small": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "small")
-        elif f0_method == "mangio-crepe-small-onnx": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "small", onnx=True)
-        elif f0_method == "mangio-crepe-medium": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "medium")
-        elif f0_method == "mangio-crepe-medium-onnx": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "medium", onnx=True)
-        elif f0_method == "mangio-crepe-large": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "large")
-        elif f0_method == "mangio-crepe-large-onnx": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "large", onnx=True)
-        elif f0_method == "mangio-crepe-full": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "full")
-        elif f0_method == "mangio-crepe-full-onnx": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "full", onnx=True)
-        elif f0_method == "crepe-tiny": f0 = self.get_f0_crepe(x, "tiny")
-        elif f0_method == "crepe-tiny-onnx": f0 = self.get_f0_crepe(x, "tiny", onnx=True)
-        elif f0_method == "crepe-small": f0 = self.get_f0_crepe(x, "small")
-        elif f0_method == "crepe-small-onnx": f0 = self.get_f0_crepe(x, "small", onnx=True)
-        elif f0_method == "crepe-medium": f0 = self.get_f0_crepe(x, "medium")
-        elif f0_method == "crepe-medium-onnx": f0 = self.get_f0_crepe(x, "medium", onnx=True)
-        elif f0_method == "crepe-large": f0 = self.get_f0_crepe(x, "large")
-        elif f0_method == "crepe-large-onnx": f0 = self.get_f0_crepe(x, "large", onnx=True)
-        elif f0_method == "crepe-full": f0 = self.get_f0_crepe(x, "full")
-        elif f0_method == "crepe-full-onnx": f0 = self.get_f0_crepe(x, "full", onnx=True)
-        elif f0_method == "fcpe": f0 = self.get_f0_fcpe(x, p_len, int(hop_length))
-        elif f0_method == "fcpe-onnx": f0 = self.get_f0_fcpe(x, p_len, int(hop_length), onnx=True)
-        elif f0_method == "fcpe-legacy": f0 = self.get_f0_fcpe(x, p_len, int(hop_length), legacy=True)
-        elif f0_method == "fcpe-legacy-onnx": f0 = self.get_f0_fcpe(x, p_len, int(hop_length), onnx=True, legacy=True)
-        elif f0_method == "rmvpe": f0 = self.get_f0_rmvpe(x)
-        elif f0_method == "rmvpe-onnx": f0 = self.get_f0_rmvpe(x, onnx=True)
-        elif f0_method == "rmvpe-legacy": f0 = self.get_f0_rmvpe(x, legacy=True)
-        elif f0_method == "rmvpe-legacy-onnx": f0 = self.get_f0_rmvpe(x, legacy=True, onnx=True)
-        elif f0_method == "harvest": f0 = self.get_f0_pyworld(x, filter_radius, "harvest")
-        elif f0_method == "yin": f0 = self.get_f0_yin(x, int(hop_length), p_len)
-        elif f0_method == "pyin": f0 = self.get_f0_pyin(x, int(hop_length), p_len)
-        elif "hybrid" in f0_method: f0 = self.get_f0_hybrid(f0_method, x, p_len, hop_length, filter_radius)
-        else: raise ValueError(translations["method_not_valid"])
-
-        if f0_autotune: f0 = Autotune.autotune_f0(self, f0, f0_autotune_strength)
-
-        f0 *= pow(2, pitch / 12)
-        f0_mel = 1127 * np.log(1 + f0 / 700)
-        f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - self.f0_mel_min) * 254 / (self.f0_mel_max - self.f0_mel_min) + 1
-        f0_mel[f0_mel <= 1] = 1
-        f0_mel[f0_mel > 255] = 255
-
-        return np.rint(f0_mel).astype(np.int32), f0.copy()
-
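Note: get_f0 hands back two views of the pitch after the semitone shift (2^(pitch/12)): the raw Hz curve, which pipeline() later passes to the model as pitchf, and a coarse copy quantised into 255 mel-spaced integer bins between f0_min=50 Hz and f0_max=1100 Hz, which becomes the integer pitch tensor. A standalone sketch of that quantisation step:

    import numpy as np

    f0_min, f0_max = 50.0, 1100.0
    mel_min = 1127 * np.log(1 + f0_min / 700)
    mel_max = 1127 * np.log(1 + f0_max / 700)

    def coarse_f0(f0):
        f0_mel = 1127 * np.log(1 + f0 / 700)
        f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - mel_min) * 254 / (mel_max - mel_min) + 1
        f0_mel[f0_mel <= 1] = 1                       # unvoiced frames land in bin 1
        f0_mel[f0_mel > 255] = 255
        return np.rint(f0_mel).astype(np.int32)

    print(coarse_f0(np.array([0.0, 50.0, 440.0, 1100.0])))  # [  1   1 122 255]
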
-    def voice_conversion(self, model, net_g, sid, audio0, pitch, pitchf, index, big_npy, index_rate, version, protect):
-        pitch_guidance = pitch != None and pitchf != None
-        feats = torch.from_numpy(audio0).float()
-
-        if feats.dim() == 2: feats = feats.mean(-1)
-        assert feats.dim() == 1, feats.dim()
-
-        feats = feats.view(1, -1)
-        padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False)
-        inputs = {"source": feats.to(self.device), "padding_mask": padding_mask, "output_layer": 9 if version == "v1" else 12}
-
-        with torch.no_grad():
-            logits = model.extract_features(**inputs)
-            feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
-
-        if protect < 0.5 and pitch_guidance: feats0 = feats.clone()
-
-        if (not isinstance(index, type(None)) and not isinstance(big_npy, type(None)) and index_rate != 0):
-            npy = feats[0].cpu().numpy()
-            score, ix = index.search(npy, k=8)
-            weight = np.square(1 / score)
-            weight /= weight.sum(axis=1, keepdims=True)
-            npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
-            feats = (torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate + (1 - index_rate) * feats)
-
-        feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
-        if protect < 0.5 and pitch_guidance: feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
-
-        p_len = audio0.shape[0] // self.window
-
-        if feats.shape[1] < p_len:
-            p_len = feats.shape[1]
-            if pitch_guidance:
-                pitch = pitch[:, :p_len]
-                pitchf = pitchf[:, :p_len]
-
-        if protect < 0.5 and pitch_guidance:
-            pitchff = pitchf.clone()
-            pitchff[pitchf > 0] = 1
-            pitchff[pitchf < 1] = protect
-            pitchff = pitchff.unsqueeze(-1)
-            feats = feats * pitchff + feats0 * (1 - pitchff)
-            feats = feats.to(feats0.dtype)
-
-        p_len = torch.tensor([p_len], device=self.device).long()
-        audio1 = ((net_g.infer(feats, p_len, pitch if pitch_guidance else None, pitchf if pitch_guidance else None, sid)[0][0, 0]).data.cpu().float().numpy())
-
-        del feats, p_len, padding_mask
-        if torch.cuda.is_available(): torch.cuda.empty_cache()
-        return audio1
-
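Note: the index block in voice_conversion is the retrieval step: each encoder frame is matched against the training-feature table (k=8 neighbours), the neighbours are averaged with inverse-squared-distance weights, and the result is blended back with the live features by index_rate. A self-contained sketch against a random table (with real features an exactly-zero distance is practically never hit; a guard would otherwise be needed before the division):

    import faiss
    import numpy as np

    dim = 768                                              # v2 feature size
    table = np.random.rand(1000, dim).astype(np.float32)   # stands in for big_npy
    index = faiss.IndexFlatL2(dim)
    index.add(table)

    feats = np.random.rand(25, dim).astype(np.float32)     # frames from the encoder
    score, ix = index.search(feats, k=8)                   # squared L2 distances
    weight = np.square(1 / score)
    weight /= weight.sum(axis=1, keepdims=True)
    retrieved = np.sum(table[ix] * np.expand_dims(weight, axis=2), axis=1)

    index_rate = 0.75
    blended = retrieved * index_rate + feats * (1 - index_rate)
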
-    def pipeline(self, model, net_g, sid, audio, pitch, f0_method, file_index, index_rate, pitch_guidance, filter_radius, tgt_sr, resample_sr, volume_envelope, version, protect, hop_length, f0_autotune, f0_autotune_strength):
-        if file_index != "" and os.path.exists(file_index) and index_rate != 0:
-            try:
-                index = faiss.read_index(file_index)
-                big_npy = index.reconstruct_n(0, index.ntotal)
-            except Exception as e:
-                logger.error(translations["read_faiss_index_error"].format(e=e))
-                index = big_npy = None
-        else: index = big_npy = None
-
-        audio = signal.filtfilt(bh, ah, audio)
-        audio_pad = np.pad(audio, (self.window // 2, self.window // 2), mode="reflect")
-        opt_ts, audio_opt = [], []
-
-        if audio_pad.shape[0] > self.t_max:
-            audio_sum = np.zeros_like(audio)
-
-            for i in range(self.window):
-                audio_sum += audio_pad[i : i - self.window]
-
-            for t in range(self.t_center, audio.shape[0], self.t_center):
-                opt_ts.append(t - self.t_query + np.where(np.abs(audio_sum[t - self.t_query : t + self.t_query]) == np.abs(audio_sum[t - self.t_query : t + self.t_query]).min())[0][0])
-
-        s = 0
-        t = None
-
-        audio_pad = np.pad(audio, (self.t_pad, self.t_pad), mode="reflect")
-        sid = torch.tensor(sid, device=self.device).unsqueeze(0).long()
-        p_len = audio_pad.shape[0] // self.window
-
-        if pitch_guidance:
-            pitch, pitchf = self.get_f0(audio_pad, p_len, pitch, f0_method, filter_radius, hop_length, f0_autotune, f0_autotune_strength)
-            pitch, pitchf = pitch[:p_len], pitchf[:p_len]
-
-            if self.device == "mps": pitchf = pitchf.astype(np.float32)
-            pitch, pitchf = torch.tensor(pitch, device=self.device).unsqueeze(0).long(), torch.tensor(pitchf, device=self.device).unsqueeze(0).float()
-
-        for t in opt_ts:
-            t = t // self.window * self.window
-            audio_opt.append(self.voice_conversion(model, net_g, sid, audio_pad[s : t + self.t_pad2 + self.window], pitch[:, s // self.window : (t + self.t_pad2) // self.window] if pitch_guidance else None, pitchf[:, s // self.window : (t + self.t_pad2) // self.window] if pitch_guidance else None, index, big_npy, index_rate, version, protect)[self.t_pad_tgt : -self.t_pad_tgt])
-            s = t
-
-        audio_opt.append(self.voice_conversion(model, net_g, sid, audio_pad[t:], (pitch[:, t // self.window :] if t is not None else pitch) if pitch_guidance else None, (pitchf[:, t // self.window :] if t is not None else pitchf) if pitch_guidance else None, index, big_npy, index_rate, version, protect)[self.t_pad_tgt : -self.t_pad_tgt])
-        audio_opt = np.concatenate(audio_opt)
-
-        if volume_envelope != 1: audio_opt = change_rms(audio, self.sample_rate, audio_opt, tgt_sr, volume_envelope)
-        if resample_sr >= self.sample_rate and tgt_sr != resample_sr: audio_opt = librosa.resample(audio_opt, orig_sr=tgt_sr, target_sr=resample_sr, res_type="soxr_vhq")
-
-        audio_max = np.abs(audio_opt).max() / 0.99
-        if audio_max > 1: audio_opt /= audio_max
-
-        if pitch_guidance: del pitch, pitchf
-        del sid
-
-        if torch.cuda.is_available(): torch.cuda.empty_cache()
-        return audio_opt
-
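Note: for audio longer than t_max, pipeline avoids cutting mid-phoneme: a sliding sum over one analysis window approximates local energy, and near every t_center candidate the quietest sample within ±t_query is chosen as the split point; each chunk is then converted with t_pad reflect padding that is trimmed off before concatenation. A toy version of the split-point search (sizes are illustrative only, not the real config values):

    import numpy as np

    window, t_center, t_query = 160, 48000, 1600
    audio = np.random.randn(200000).astype(np.float32)
    audio_pad = np.pad(audio, (window // 2, window // 2), mode="reflect")

    audio_sum = np.zeros_like(audio)
    for i in range(window):                                # audio_sum[n] = sum of one window around n
        audio_sum += audio_pad[i : i - window]

    opt_ts = []
    for t in range(t_center, audio.shape[0], t_center):
        seg = np.abs(audio_sum[t - t_query : t + t_query])
        opt_ts.append(t - t_query + np.where(seg == seg.min())[0][0])   # quietest point near t
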
-class VoiceConverter:
-    def __init__(self):
-        self.config = config
-        self.hubert_model = None
-        self.tgt_sr = None
-        self.net_g = None
-        self.vc = None
-        self.cpt = None
-        self.version = None
-        self.n_spk = None
-        self.use_f0 = None
-        self.loaded_model = None
-        self.vocoder = "Default"
-        self.checkpointing = False
-
-    def load_embedders(self, embedder_model):
-        try:
-            models, _, _ = checkpoint_utils.load_model_ensemble_and_task([os.path.join("assets", "models", "embedders", embedder_model + '.pt')], suffix="")
-        except Exception as e:
-            logger.error(translations["read_model_error"].format(e=e))
-        self.hubert_model = models[0].to(self.config.device).float().eval()
-
-    def convert_audio(self, audio_input_path, audio_output_path, model_path, index_path, embedder_model, pitch, f0_method, index_rate, volume_envelope, protect, hop_length, f0_autotune, f0_autotune_strength, filter_radius, clean_audio, clean_strength, export_format, resample_sr = 0, sid = 0, checkpointing = False):
-        try:
-            self.get_vc(model_path, sid)
-            audio = load_audio(audio_input_path)
-            self.checkpointing = checkpointing
-
-            audio_max = np.abs(audio).max() / 0.95
-            if audio_max > 1: audio /= audio_max
-
-            if not self.hubert_model:
-                if not os.path.exists(os.path.join("assets", "models", "embedders", embedder_model + '.pt')): raise FileNotFoundError(f"{translations['not_found'].format(name=translations['model'])}: {embedder_model}")
-                self.load_embedders(embedder_model)
-
-            if self.tgt_sr != resample_sr >= 16000: self.tgt_sr = resample_sr
-            target_sr = min([8000, 11025, 12000, 16000, 22050, 24000, 32000, 44100, 48000, 96000], key=lambda x: abs(x - self.tgt_sr))
-
-            audio_output = self.vc.pipeline(model=self.hubert_model, net_g=self.net_g, sid=sid, audio=audio, pitch=pitch, f0_method=f0_method, file_index=(index_path.strip().strip('"').strip("\n").strip('"').strip().replace("trained", "added")), index_rate=index_rate, pitch_guidance=self.use_f0, filter_radius=filter_radius, tgt_sr=self.tgt_sr, resample_sr=target_sr, volume_envelope=volume_envelope, version=self.version, protect=protect, hop_length=hop_length, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength)
-
-            if clean_audio:
-                from main.tools.noisereduce import reduce_noise
-                audio_output = reduce_noise(y=audio_output, sr=target_sr, prop_decrease=clean_strength)
-
-            sf.write(audio_output_path, audio_output, target_sr, format=export_format)
-        except Exception as e:
-            logger.error(translations["error_convert"].format(e=e))
-
-            import traceback
-            logger.debug(traceback.format_exc())
-
-    def get_vc(self, weight_root, sid):
-        if sid == "" or sid == []:
-            self.cleanup()
-            if torch.cuda.is_available(): torch.cuda.empty_cache()
-
-        if not self.loaded_model or self.loaded_model != weight_root:
-            self.load_model(weight_root)
-            if self.cpt is not None: self.setup()
-            self.loaded_model = weight_root
-
-    def cleanup(self):
-        if self.hubert_model is not None:
-            del self.net_g, self.n_spk, self.vc, self.hubert_model, self.tgt_sr
-            self.hubert_model = self.net_g = self.n_spk = self.vc = self.tgt_sr = None
-            if torch.cuda.is_available(): torch.cuda.empty_cache()
-
-        del self.net_g, self.cpt
-        if torch.cuda.is_available(): torch.cuda.empty_cache()
-        self.cpt = None
-
-    def load_model(self, weight_root):
-        self.cpt = (torch.load(weight_root, map_location="cpu") if os.path.isfile(weight_root) else None)
-
-    def setup(self):
-        if self.cpt is not None:
-            self.tgt_sr = self.cpt["config"][-1]
-            self.cpt["config"][-3] = self.cpt["weight"]["emb_g.weight"].shape[0]
-
-            self.use_f0 = self.cpt.get("f0", 1)
-            self.version = self.cpt.get("version", "v1")
-            self.vocoder = self.cpt.get("vocoder", "Default")
-
-            self.text_enc_hidden_dim = 768 if self.version == "v2" else 256
-            self.net_g = Synthesizer(*self.cpt["config"], use_f0=self.use_f0, text_enc_hidden_dim=self.text_enc_hidden_dim, vocoder=self.vocoder, checkpointing=self.checkpointing)
-            del self.net_g.enc_q
-
-            self.net_g.load_state_dict(self.cpt["weight"], strict=False)
-            self.net_g.eval().to(self.config.device).float()
-
-            self.vc = VC(self.tgt_sr, self.config)
-            self.n_spk = self.cpt["config"][-3]
-
-if __name__ == "__main__": main()
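Note: before this commit removed the module, the converter could be driven from Python roughly as below; main() wraps the same call behind argparse. All paths and model names here are placeholders, not shipped files:

    from main.inference.convert import VoiceConverter   # module removed by this commit

    cvt = VoiceConverter()
    cvt.convert_audio(
        audio_input_path="input.wav", audio_output_path="output.wav",
        model_path="assets/weights/voice.pth",
        index_path="assets/logs/voice/added_IVF128_Flat_nprobe_1_voice_v2.index",
        embedder_model="contentvec_base", pitch=0, f0_method="rmvpe",
        index_rate=0.5, volume_envelope=1, protect=0.33, hop_length=128,
        f0_autotune=False, f0_autotune_strength=1.0, filter_radius=3,
        clean_audio=False, clean_strength=0.7, export_format="wav",
    )
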
main/inference/create_dataset.py
DELETED
@@ -1,240 +0,0 @@
-import os
-import sys
-import time
-import yt_dlp
-import shutil
-import librosa
-import logging
-import argparse
-import warnings
-import logging.handlers
-
-from soundfile import read, write
-from distutils.util import strtobool
-
-sys.path.append(os.getcwd())
-
-from main.configs.config import Config
-from main.library.utils import process_audio, merge_audio
-
-translations = Config().translations
-dataset_temp = os.path.join("dataset_temp")
-logger = logging.getLogger(__name__)
-
-if logger.hasHandlers(): logger.handlers.clear()
-else:
-    console_handler = logging.StreamHandler()
-    console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
-    console_handler.setFormatter(console_formatter)
-    console_handler.setLevel(logging.INFO)
-    file_handler = logging.handlers.RotatingFileHandler(os.path.join("assets", "logs", "create_dataset.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
-    file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
-    file_handler.setFormatter(file_formatter)
-    file_handler.setLevel(logging.DEBUG)
-    logger.addHandler(console_handler)
-    logger.addHandler(file_handler)
-    logger.setLevel(logging.DEBUG)
-
-def parse_arguments():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--input_audio", type=str, required=True)
-    parser.add_argument("--output_dataset", type=str, default="./dataset")
-    parser.add_argument("--sample_rate", type=int, default=44100)
-    parser.add_argument("--clean_dataset", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--clean_strength", type=float, default=0.7)
-    parser.add_argument("--separator_reverb", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--kim_vocal_version", type=int, default=2)
-    parser.add_argument("--overlap", type=float, default=0.25)
-    parser.add_argument("--segments_size", type=int, default=256)
-    parser.add_argument("--mdx_hop_length", type=int, default=1024)
-    parser.add_argument("--mdx_batch_size", type=int, default=1)
-    parser.add_argument("--denoise_mdx", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--skip", type=lambda x: bool(strtobool(x)), default=False)
-    parser.add_argument("--skip_start_audios", type=str, default="0")
-    parser.add_argument("--skip_end_audios", type=str, default="0")
-
-    return parser.parse_args()
-
-def main():
-    pid_path = os.path.join("assets", "create_dataset_pid.txt")
-    with open(pid_path, "w") as pid_file:
-        pid_file.write(str(os.getpid()))
-
-    args = parse_arguments()
-    input_audio, output_dataset, sample_rate, clean_dataset, clean_strength, separator_reverb, kim_vocal_version, overlap, segments_size, hop_length, batch_size, denoise_mdx, skip, skip_start_audios, skip_end_audios = args.input_audio, args.output_dataset, args.sample_rate, args.clean_dataset, args.clean_strength, args.separator_reverb, args.kim_vocal_version, args.overlap, args.segments_size, args.mdx_hop_length, args.mdx_batch_size, args.denoise_mdx, args.skip, args.skip_start_audios, args.skip_end_audios
-    log_data = {translations['audio_path']: input_audio, translations['output_path']: output_dataset, translations['sr']: sample_rate, translations['clear_dataset']: clean_dataset, translations['dereveb_audio']: separator_reverb, translations['segments_size']: segments_size, translations['overlap']: overlap, "Hop length": hop_length, translations['batch_size']: batch_size, translations['denoise_mdx']: denoise_mdx, translations['skip']: skip}
-
-    if clean_dataset: log_data[translations['clean_strength']] = clean_strength
-    if skip:
-        log_data[translations['skip_start']] = skip_start_audios
-        log_data[translations['skip_end']] = skip_end_audios
-
-    for key, value in log_data.items():
-        logger.debug(f"{key}: {value}")
-
-    if kim_vocal_version not in [1, 2]: raise ValueError(translations["version_not_valid"])
-    start_time = time.time()
-
-    try:
-        paths = []
-
-        if not os.path.exists(dataset_temp): os.makedirs(dataset_temp, exist_ok=True)
-        urls = input_audio.replace(", ", ",").split(",")
-
-        for url in urls:
-            path = downloader(url, urls.index(url))
-            paths.append(path)
-
-        if skip:
-            skip_start_audios = skip_start_audios.replace(", ", ",").split(",")
-            skip_end_audios = skip_end_audios.replace(", ", ",").split(",")
-
-            if len(skip_start_audios) < len(paths) or len(skip_end_audios) < len(paths):
-                logger.warning(translations["skip<audio"])
-                sys.exit(1)
-            elif len(skip_start_audios) > len(paths) or len(skip_end_audios) > len(paths):
-                logger.warning(translations["skip>audio"])
-                sys.exit(1)
-            else:
-                for audio, skip_start_audio, skip_end_audio in zip(paths, skip_start_audios, skip_end_audios):
-                    skip_start(audio, skip_start_audio)
-                    skip_end(audio, skip_end_audio)
-
-        separator_paths = []
-
-        for audio in paths:
-            vocals = separator_music_main(audio, dataset_temp, segments_size, overlap, denoise_mdx, kim_vocal_version, hop_length, batch_size, sample_rate)
-            if separator_reverb: vocals = separator_reverb_audio(vocals, dataset_temp, segments_size, overlap, denoise_mdx, hop_length, batch_size, sample_rate)
-            separator_paths.append(vocals)
-
-        paths = separator_paths
-        processed_paths = []
-
-        for audio in paths:
-            cut_files, time_stamps = process_audio(logger, audio, os.path.dirname(audio))
-            processed_paths.append(merge_audio(cut_files, time_stamps, audio, os.path.splitext(audio)[0] + "_processed" + ".wav", "wav"))
-
-        paths = processed_paths
-
-        for audio_path in paths:
-            data, sample_rate = read(audio_path)
-            data = librosa.to_mono(data.T)
-
-            if clean_dataset:
-                from main.tools.noisereduce import reduce_noise
-                data = reduce_noise(y=data, prop_decrease=clean_strength)
-
-            write(audio_path, data, sample_rate)
-    except Exception as e:
-        logger.error(f"{translations['create_dataset_error']}: {e}")
-
-        import traceback
-        logger.error(traceback.format_exc())
-    finally:
-        for audio in paths:
-            shutil.move(audio, output_dataset)
-
-        if os.path.exists(dataset_temp): shutil.rmtree(dataset_temp, ignore_errors=True)
-
-    elapsed_time = time.time() - start_time
-    if os.path.exists(pid_path): os.remove(pid_path)
-    logger.info(translations["create_dataset_success"].format(elapsed_time=f"{elapsed_time:.2f}"))
-
-def downloader(url, name):
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore")
-
-        ydl_opts = {"format": "bestaudio/best", "outtmpl": os.path.join(dataset_temp, f"{name}"), "postprocessors": [{"key": "FFmpegExtractAudio", "preferredcodec": "wav", "preferredquality": "192"}], "no_warnings": True, "noplaylist": True, "noplaylist": True, "verbose": False}
-        logger.info(f"{translations['starting_download']}: {url}...")
-
-        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
-            ydl.extract_info(url)
-            logger.info(f"{translations['download_success']}: {url}")
-
-        return os.path.join(dataset_temp, f"{name}" + ".wav")
-
-def skip_start(input_file, seconds):
-    data, sr = read(input_file)
-    total_duration = len(data) / sr
-
-    if seconds <= 0: logger.warning(translations["=<0"])
-    elif seconds >= total_duration: logger.warning(translations["skip_warning"].format(seconds=seconds, total_duration=f"{total_duration:.2f}"))
-    else:
-        logger.info(f"{translations['skip_start']}: {input_file}...")
-        write(input_file, data[int(seconds * sr):], sr)
-
-    logger.info(translations["skip_start_audio"].format(input_file=input_file))
-
-def skip_end(input_file, seconds):
-    data, sr = read(input_file)
-    total_duration = len(data) / sr
-
-    if seconds <= 0: logger.warning(translations["=<0"])
-    elif seconds > total_duration: logger.warning(translations["skip_warning"].format(seconds=seconds, total_duration=f"{total_duration:.2f}"))
-    else:
-        logger.info(f"{translations['skip_end']}: {input_file}...")
-        write(input_file, data[:-int(seconds * sr)], sr)
-
-    logger.info(translations["skip_end_audio"].format(input_file=input_file))
-
-def separator_music_main(input, output, segments_size, overlap, denoise, version, hop_length, batch_size, sample_rate):
-    if not os.path.exists(input):
-        logger.warning(translations["input_not_valid"])
-        return None
-
-    if not os.path.exists(output):
-        logger.warning(translations["output_not_valid"])
-        return None
-
-    model = f"Kim_Vocal_{version}.onnx"
-    output_separator = separator_main(audio_file=input, model_filename=model, output_format="wav", output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise, sample_rate=sample_rate)
-
-    for f in output_separator:
-        path = os.path.join(output, f)
-        if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
-
-        if '_(Instrumental)_' in f: os.rename(path, os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav")
-        elif '_(Vocals)_' in f:
-            rename_file = os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav"
-            os.rename(path, rename_file)
-
-    return rename_file
-
-def separator_reverb_audio(input, output, segments_size, overlap, denoise, hop_length, batch_size, sample_rate):
-    if not os.path.exists(input):
-        logger.warning(translations["input_not_valid"])
-        return None
-
-    if not os.path.exists(output):
-        logger.warning(translations["output_not_valid"])
-        return None
-
-    logger.info(f"{translations['dereverb']}: {input}...")
-    output_dereverb = separator_main(audio_file=input, model_filename="Reverb_HQ_By_FoxJoy.onnx", output_format="wav", output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=hop_length, mdx_hop_length=batch_size, mdx_enable_denoise=denoise, sample_rate=sample_rate)
-
-    for f in output_dereverb:
-        path = os.path.join(output, f)
-        if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
-
-        if '_(Reverb)_' in f: os.rename(path, os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav")
-        elif '_(No Reverb)_' in f:
-            rename_file = os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav"
-            os.rename(path, rename_file)
-
-    logger.info(f"{translations['dereverb_success']}: {rename_file}")
-    return rename_file
-
-def separator_main(audio_file=None, model_filename="Kim_Vocal_1.onnx", output_format="wav", output_dir=".", mdx_segment_size=256, mdx_overlap=0.25, mdx_batch_size=1, mdx_hop_length=1024, mdx_enable_denoise=True, sample_rate=44100):
-    from main.library.algorithm.separator import Separator
-
-    try:
-        separator = Separator(logger=logger, log_formatter=file_formatter, log_level=logging.INFO, output_dir=output_dir, output_format=output_format, output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=sample_rate, mdx_params={"hop_length": mdx_hop_length, "segment_size": mdx_segment_size, "overlap": mdx_overlap, "batch_size": mdx_batch_size, "enable_denoise": mdx_enable_denoise})
-        separator.load_model(model_filename=model_filename)
-        return separator.separate(audio_file)
-    except:
-        logger.debug(translations["default_setting"])
-        separator = Separator(logger=logger, log_formatter=file_formatter, log_level=logging.INFO, output_dir=output_dir, output_format=output_format, output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=44100, mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": mdx_enable_denoise})
-        separator.load_model(model_filename=model_filename)
-        return separator.separate(audio_file)
-
-if __name__ == "__main__": main()
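Note: downloader above delegates everything to yt-dlp plus its ffmpeg post-processor: the best audio stream is fetched and transcoded to WAV, so the returned path is always the outtmpl value with ".wav" appended. A trimmed standalone version (URL and output name are placeholders):

    import yt_dlp

    opts = {
        "format": "bestaudio/best",
        "outtmpl": "dataset_temp/0",      # post-processing yields dataset_temp/0.wav
        "postprocessors": [{"key": "FFmpegExtractAudio", "preferredcodec": "wav", "preferredquality": "192"}],
        "noplaylist": True,
        "no_warnings": True,
    }
    with yt_dlp.YoutubeDL(opts) as ydl:
        ydl.extract_info("https://example.com/some-audio")   # downloads by default
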
main/inference/create_index.py
DELETED
@@ -1,100 +0,0 @@
-import os
-import sys
-import faiss
-import logging
-import argparse
-import logging.handlers
-
-import numpy as np
-
-from multiprocessing import cpu_count
-from sklearn.cluster import MiniBatchKMeans
-
-sys.path.append(os.getcwd())
-
-from main.configs.config import Config
-translations = Config().translations
-
-
-def parse_arguments():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model_name", type=str, required=True)
-    parser.add_argument("--rvc_version", type=str, default="v2")
-    parser.add_argument("--index_algorithm", type=str, default="Auto")
-
-    return parser.parse_args()
-
-def main():
-    args = parse_arguments()
-
-    exp_dir = os.path.join("assets", "logs", args.model_name)
-    version = args.rvc_version
-    index_algorithm = args.index_algorithm
-    logger = logging.getLogger(__name__)
-
-    if logger.hasHandlers(): logger.handlers.clear()
-    else:
-        console_handler = logging.StreamHandler()
-        console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
-        console_handler.setFormatter(console_formatter)
-        console_handler.setLevel(logging.INFO)
-        file_handler = logging.handlers.RotatingFileHandler(os.path.join(exp_dir, "create_index.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
-        file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
-        file_handler.setFormatter(file_formatter)
-        file_handler.setLevel(logging.DEBUG)
-        logger.addHandler(console_handler)
-        logger.addHandler(file_handler)
-        logger.setLevel(logging.DEBUG)
-
-    log_data = {translations['modelname']: args.model_name, translations['model_path']: exp_dir, translations['training_version']: version, translations['index_algorithm_info']: index_algorithm}
-    for key, value in log_data.items():
-        logger.debug(f"{key}: {value}")
-
-    try:
-        npys = []
-
-        feature_dir = os.path.join(exp_dir, f"{version}_extracted")
-        model_name = os.path.basename(exp_dir)
-
-        for name in sorted(os.listdir(feature_dir)):
-            npys.append(np.load(os.path.join(feature_dir, name)))
-
-        big_npy = np.concatenate(npys, axis=0)
-        big_npy_idx = np.arange(big_npy.shape[0])
-
-        np.random.shuffle(big_npy_idx)
-        big_npy = big_npy[big_npy_idx]
-
-        if big_npy.shape[0] > 2e5 and (index_algorithm == "Auto" or index_algorithm == "KMeans"): big_npy = (MiniBatchKMeans(n_clusters=10000, verbose=True, batch_size=256 * cpu_count(), compute_labels=False, init="random").fit(big_npy).cluster_centers_)
-        np.save(os.path.join(exp_dir, "total_fea.npy"), big_npy)
-
-        n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39)
-        index_trained = faiss.index_factory(256 if version == "v1" else 768, f"IVF{n_ivf},Flat")
-
-        index_ivf_trained = faiss.extract_index_ivf(index_trained)
-        index_ivf_trained.nprobe = 1
-
-        index_trained.train(big_npy)
-        faiss.write_index(index_trained, os.path.join(exp_dir, f"trained_IVF{n_ivf}_Flat_nprobe_{index_ivf_trained.nprobe}_{model_name}_{version}.index"))
-
-        index_added = faiss.index_factory(256 if version == "v1" else 768, f"IVF{n_ivf},Flat")
-        index_ivf_added = faiss.extract_index_ivf(index_added)
-        index_ivf_added.nprobe = 1
-
-        index_added.train(big_npy)
-        batch_size_add = 8192
-
-        for i in range(0, big_npy.shape[0], batch_size_add):
-            index_added.add(big_npy[i : i + batch_size_add])
-
-        index_filepath_added = os.path.join(exp_dir, f"added_IVF{n_ivf}_Flat_nprobe_{index_ivf_added.nprobe}_{model_name}_{version}.index")
-        faiss.write_index(index_added, index_filepath_added)
-
-        logger.info(f"{translations['save_index']} '{index_filepath_added}'")
-    except Exception as e:
-        logger.error(f"{translations['create_index_error']}: {e}")
-
-        import traceback
-        logger.debug(traceback.format_exc())
-
-if __name__ == "__main__": main()
main/inference/extract.py
DELETED
@@ -1,450 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import re
|
3 |
-
import sys
|
4 |
-
import time
|
5 |
-
import tqdm
|
6 |
-
import torch
|
7 |
-
import shutil
|
8 |
-
import librosa
|
9 |
-
import logging
|
10 |
-
import argparse
|
11 |
-
import warnings
|
12 |
-
import parselmouth
|
13 |
-
import logging.handlers
|
14 |
-
|
15 |
-
import numpy as np
|
16 |
-
import soundfile as sf
|
17 |
-
import torch.nn.functional as F
|
18 |
-
|
19 |
-
from random import shuffle
|
20 |
-
from multiprocessing import Pool
|
21 |
-
from distutils.util import strtobool
|
22 |
-
from fairseq import checkpoint_utils
|
23 |
-
from functools import partial
|
24 |
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
25 |
-
|
26 |
-
sys.path.append(os.getcwd())
|
27 |
-
|
28 |
-
from main.configs.config import Config
|
29 |
-
from main.library.predictors.FCPE import FCPE
|
30 |
-
from main.library.predictors.RMVPE import RMVPE
|
31 |
-
from main.library.predictors.WORLD import PYWORLD
|
32 |
-
from main.library.predictors.CREPE import predict, mean, median
|
33 |
-
from main.library.utils import check_predictors, check_embedders, load_audio
|
34 |
-
|
35 |
-
logger = logging.getLogger(__name__)
|
36 |
-
translations = Config().translations
|
37 |
-
logger.propagate = False
|
38 |
-
|
39 |
-
warnings.filterwarnings("ignore")
|
40 |
-
for l in ["torch", "faiss", "httpx", "fairseq", "httpcore", "faiss.loader", "numba.core", "urllib3"]:
|
41 |
-
logging.getLogger(l).setLevel(logging.ERROR)
|
42 |
-
|
43 |
-
def parse_arguments():
|
44 |
-
parser = argparse.ArgumentParser()
|
45 |
-
parser.add_argument("--model_name", type=str, required=True)
|
46 |
-
parser.add_argument("--rvc_version", type=str, default="v2")
|
47 |
-
parser.add_argument("--f0_method", type=str, default="rmvpe")
|
48 |
-
parser.add_argument("--pitch_guidance", type=lambda x: bool(strtobool(x)), default=True)
|
49 |
-
parser.add_argument("--hop_length", type=int, default=128)
|
50 |
-
parser.add_argument("--cpu_cores", type=int, default=2)
|
51 |
-
parser.add_argument("--gpu", type=str, default="-")
|
52 |
-
parser.add_argument("--sample_rate", type=int, required=True)
|
53 |
-
parser.add_argument("--embedder_model", type=str, default="contentvec_base")
|
54 |
-
|
55 |
-
return parser.parse_args()
|
56 |
-
|
57 |
-
def generate_config(rvc_version, sample_rate, model_path):
|
58 |
-
config_save_path = os.path.join(model_path, "config.json")
|
59 |
-
if not os.path.exists(config_save_path): shutil.copy(os.path.join("main", "configs", rvc_version, f"{sample_rate}.json"), config_save_path)
|
60 |
-
|
61 |
-
def generate_filelist(pitch_guidance, model_path, rvc_version, sample_rate):
|
62 |
-
gt_wavs_dir, feature_dir = os.path.join(model_path, "sliced_audios"), os.path.join(model_path, f"{rvc_version}_extracted")
|
63 |
-
f0_dir, f0nsf_dir = None, None
|
64 |
-
|
65 |
-
if pitch_guidance: f0_dir, f0nsf_dir = os.path.join(model_path, "f0"), os.path.join(model_path, "f0_voiced")
|
66 |
-
|
67 |
-
gt_wavs_files, feature_files = set(name.split(".")[0] for name in os.listdir(gt_wavs_dir)), set(name.split(".")[0] for name in os.listdir(feature_dir))
|
68 |
-
names = gt_wavs_files & feature_files & set(name.split(".")[0] for name in os.listdir(f0_dir)) & set(name.split(".")[0] for name in os.listdir(f0nsf_dir)) if pitch_guidance else gt_wavs_files & feature_files
|
69 |
-
|
70 |
-
options = []
|
71 |
-
mute_base_path = os.path.join("assets", "logs", "mute")
|
72 |
-
|
73 |
-
for name in names:
|
74 |
-
options.append(f"{gt_wavs_dir}/{name}.wav|{feature_dir}/{name}.npy|{f0_dir}/{name}.wav.npy|{f0nsf_dir}/{name}.wav.npy|0" if pitch_guidance else f"{gt_wavs_dir}/{name}.wav|{feature_dir}/{name}.npy|0")
|
75 |
-
|
76 |
-
mute_audio_path, mute_feature_path = os.path.join(mute_base_path, "sliced_audios", f"mute{sample_rate}.wav"), os.path.join(mute_base_path, f"{rvc_version}_extracted", "mute.npy")
|
77 |
-
|
78 |
-
for _ in range(2):
|
79 |
-
options.append(f"{mute_audio_path}|{mute_feature_path}|{os.path.join(mute_base_path, 'f0', 'mute.wav.npy')}|{os.path.join(mute_base_path, 'f0_voiced', 'mute.wav.npy')}|0" if pitch_guidance else f"{mute_audio_path}|{mute_feature_path}|0")
|
80 |
-
|
81 |
-
shuffle(options)
|
82 |
-
|
83 |
-
with open(os.path.join(model_path, "filelist.txt"), "w") as f:
|
84 |
-
f.write("\n".join(options))
|
85 |
-
|
86 |
-
def setup_paths(exp_dir, version = None):
|
87 |
-
wav_path = os.path.join(exp_dir, "sliced_audios_16k")
|
88 |
-
|
89 |
-
if version:
|
90 |
-
out_path = os.path.join(exp_dir, f"{version}_extracted")
|
91 |
-
os.makedirs(out_path, exist_ok=True)
|
92 |
-
|
93 |
-
return wav_path, out_path
|
94 |
-
else:
|
95 |
-
output_root1, output_root2 = os.path.join(exp_dir, "f0"), os.path.join(exp_dir, "f0_voiced")
|
96 |
-
os.makedirs(output_root1, exist_ok=True); os.makedirs(output_root2, exist_ok=True)
|
97 |
-
|
98 |
-
return wav_path, output_root1, output_root2
|
99 |
-
|
100 |
-
def read_wave(wav_path, normalize = False):
|
101 |
-
wav, sr = sf.read(wav_path)
|
102 |
-
assert sr == 16000, translations["sr_not_16000"]
|
103 |
-
|
104 |
-
feats = torch.from_numpy(wav).float()
|
105 |
-
|
106 |
-
if feats.dim() == 2: feats = feats.mean(-1)
|
107 |
-
feats = feats.view(1, -1)
|
108 |
-
|
109 |
-
if normalize: feats = F.layer_norm(feats, feats.shape)
|
110 |
-
return feats
|
111 |
-
|
112 |
-
def get_device(gpu_index):
|
113 |
-
if gpu_index == "cpu": return "cpu"
|
114 |
-
|
115 |
-
try:
|
116 |
-
index = int(gpu_index)
|
117 |
-
|
118 |
-
if index < torch.cuda.device_count(): return f"cuda:{index}"
|
119 |
-
else: logger.warning(translations["gpu_not_valid"])
|
120 |
-
except ValueError:
|
121 |
-
logger.warning(translations["gpu_not_valid"])
|
122 |
-
return "cpu"

class FeatureInput:
    def __init__(self, sample_rate=16000, hop_size=160, device="cpu"):
        self.fs = sample_rate
        self.hop = hop_size
        self.f0_bin = 256
        self.f0_max = 1100.0
        self.f0_min = 50.0
        self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700)
        self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700)
        self.device = device

    def get_providers(self):
        import onnxruntime

        ort_providers = onnxruntime.get_available_providers()

        if "CUDAExecutionProvider" in ort_providers: providers = ["CUDAExecutionProvider"]
        elif "CoreMLExecutionProvider" in ort_providers: providers = ["CoreMLExecutionProvider"]
        else: providers = ["CPUExecutionProvider"]

        return providers

    def compute_f0_hybrid(self, methods_str, np_arr, hop_length):
        methods_str = re.search(r"hybrid\[(.+)\]", methods_str)
        if methods_str: methods = [method.strip() for method in methods_str.group(1).split("+")]

        f0_computation_stack, resampled_stack = [], []
        logger.debug(translations["hybrid_methods"].format(methods=methods))

        for method in methods:
            f0 = None

            if method == "pm": f0 = self.get_pm(np_arr)
            elif method == "dio": f0 = self.get_pyworld(np_arr, "dio")
            elif method == "mangio-crepe-full": f0 = self.get_mangio_crepe(np_arr, int(hop_length), "full")
            elif method == "mangio-crepe-full-onnx": f0 = self.get_mangio_crepe(np_arr, int(hop_length), "full", onnx=True)
            elif method == "mangio-crepe-large": f0 = self.get_mangio_crepe(np_arr, int(hop_length), "large")
            elif method == "mangio-crepe-large-onnx": f0 = self.get_mangio_crepe(np_arr, int(hop_length), "large", onnx=True)
            elif method == "mangio-crepe-medium": f0 = self.get_mangio_crepe(np_arr, int(hop_length), "medium")
            elif method == "mangio-crepe-medium-onnx": f0 = self.get_mangio_crepe(np_arr, int(hop_length), "medium", onnx=True)
            elif method == "mangio-crepe-small": f0 = self.get_mangio_crepe(np_arr, int(hop_length), "small")
            elif method == "mangio-crepe-small-onnx": f0 = self.get_mangio_crepe(np_arr, int(hop_length), "small", onnx=True)
            elif method == "mangio-crepe-tiny": f0 = self.get_mangio_crepe(np_arr, int(hop_length), "tiny")
            elif method == "mangio-crepe-tiny-onnx": f0 = self.get_mangio_crepe(np_arr, int(hop_length), "tiny", onnx=True)
            elif method == "crepe-full": f0 = self.get_crepe(np_arr, "full")
            elif method == "crepe-full-onnx": f0 = self.get_crepe(np_arr, "full", onnx=True)
            elif method == "crepe-large": f0 = self.get_crepe(np_arr, "large")
            elif method == "crepe-large-onnx": f0 = self.get_crepe(np_arr, "large", onnx=True)
            elif method == "crepe-medium": f0 = self.get_crepe(np_arr, "medium")
            elif method == "crepe-medium-onnx": f0 = self.get_crepe(np_arr, "medium", onnx=True)
            elif method == "crepe-small": f0 = self.get_crepe(np_arr, "small")
            elif method == "crepe-small-onnx": f0 = self.get_crepe(np_arr, "small", onnx=True)
            elif method == "crepe-tiny": f0 = self.get_crepe(np_arr, "tiny")
            elif method == "crepe-tiny-onnx": f0 = self.get_crepe(np_arr, "tiny", onnx=True)
            elif method == "fcpe": f0 = self.get_fcpe(np_arr, int(hop_length))
            elif method == "fcpe-onnx": f0 = self.get_fcpe(np_arr, int(hop_length), onnx=True)
            elif method == "fcpe-legacy": f0 = self.get_fcpe(np_arr, int(hop_length), legacy=True)
            elif method == "fcpe-legacy-onnx": f0 = self.get_fcpe(np_arr, int(hop_length), onnx=True, legacy=True)
            elif method == "rmvpe": f0 = self.get_rmvpe(np_arr)
            elif method == "rmvpe-onnx": f0 = self.get_rmvpe(np_arr, onnx=True)
            elif method == "rmvpe-legacy": f0 = self.get_rmvpe(np_arr, legacy=True)
            elif method == "rmvpe-legacy-onnx": f0 = self.get_rmvpe(np_arr, legacy=True, onnx=True)
            elif method == "harvest": f0 = self.get_pyworld(np_arr, "harvest")
            elif method == "yin": f0 = self.get_yin(np_arr, int(hop_length))
            elif method == "pyin": return self.get_pyin(np_arr, int(hop_length))
            else: raise ValueError(translations["method_not_valid"])

            f0_computation_stack.append(f0)

        for f0 in f0_computation_stack:
            resampled_stack.append(np.interp(np.linspace(0, len(f0), (np_arr.size // self.hop)), np.arange(len(f0)), f0))

        return resampled_stack[0] if len(resampled_stack) == 1 else np.nanmedian(np.vstack(resampled_stack), axis=0)
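
The hybrid mode fuses several f0 estimators by resampling every track to a common frame count and taking the per-frame median (np.nanmedian skips NaN/unvoiced frames). A minimal sketch with toy curves; target_len stands in for np_arr.size // self.hop:

import numpy as np

target_len = 100  # hypothetical frame count (np_arr.size // hop)
tracks = [np.linspace(100, 200, n) for n in (90, 110, 100)]  # toy f0 curves in Hz
resampled = [np.interp(np.linspace(0, len(t), target_len), np.arange(len(t)), t) for t in tracks]
fused = np.nanmedian(np.vstack(resampled), axis=0)
print(fused.shape)  # (100,)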

    def compute_f0(self, np_arr, f0_method, hop_length):
        if f0_method == "pm": return self.get_pm(np_arr)
        elif f0_method == "dio": return self.get_pyworld(np_arr, "dio")
        elif f0_method == "mangio-crepe-full": return self.get_mangio_crepe(np_arr, int(hop_length), "full")
        elif f0_method == "mangio-crepe-full-onnx": return self.get_mangio_crepe(np_arr, int(hop_length), "full", onnx=True)
        elif f0_method == "mangio-crepe-large": return self.get_mangio_crepe(np_arr, int(hop_length), "large")
        elif f0_method == "mangio-crepe-large-onnx": return self.get_mangio_crepe(np_arr, int(hop_length), "large", onnx=True)
        elif f0_method == "mangio-crepe-medium": return self.get_mangio_crepe(np_arr, int(hop_length), "medium")
        elif f0_method == "mangio-crepe-medium-onnx": return self.get_mangio_crepe(np_arr, int(hop_length), "medium", onnx=True)
        elif f0_method == "mangio-crepe-small": return self.get_mangio_crepe(np_arr, int(hop_length), "small")
        elif f0_method == "mangio-crepe-small-onnx": return self.get_mangio_crepe(np_arr, int(hop_length), "small", onnx=True)
        elif f0_method == "mangio-crepe-tiny": return self.get_mangio_crepe(np_arr, int(hop_length), "tiny")
        elif f0_method == "mangio-crepe-tiny-onnx": return self.get_mangio_crepe(np_arr, int(hop_length), "tiny", onnx=True)
        elif f0_method == "crepe-full": return self.get_crepe(np_arr, "full")
        elif f0_method == "crepe-full-onnx": return self.get_crepe(np_arr, "full", onnx=True)
        elif f0_method == "crepe-large": return self.get_crepe(np_arr, "large")
        elif f0_method == "crepe-large-onnx": return self.get_crepe(np_arr, "large", onnx=True)
        elif f0_method == "crepe-medium": return self.get_crepe(np_arr, "medium")
        elif f0_method == "crepe-medium-onnx": return self.get_crepe(np_arr, "medium", onnx=True)
        elif f0_method == "crepe-small": return self.get_crepe(np_arr, "small")
        elif f0_method == "crepe-small-onnx": return self.get_crepe(np_arr, "small", onnx=True)
        elif f0_method == "crepe-tiny": return self.get_crepe(np_arr, "tiny")
        elif f0_method == "crepe-tiny-onnx": return self.get_crepe(np_arr, "tiny", onnx=True)
        elif f0_method == "fcpe": return self.get_fcpe(np_arr, int(hop_length))
        elif f0_method == "fcpe-onnx": return self.get_fcpe(np_arr, int(hop_length), onnx=True)
        elif f0_method == "fcpe-legacy": return self.get_fcpe(np_arr, int(hop_length), legacy=True)
        elif f0_method == "fcpe-legacy-onnx": return self.get_fcpe(np_arr, int(hop_length), onnx=True, legacy=True)
        elif f0_method == "rmvpe": return self.get_rmvpe(np_arr)
        elif f0_method == "rmvpe-onnx": return self.get_rmvpe(np_arr, onnx=True)
        elif f0_method == "rmvpe-legacy": return self.get_rmvpe(np_arr, legacy=True)
        elif f0_method == "rmvpe-legacy-onnx": return self.get_rmvpe(np_arr, legacy=True, onnx=True)
        elif f0_method == "harvest": return self.get_pyworld(np_arr, "harvest")
        elif f0_method == "yin": return self.get_yin(np_arr, int(hop_length))
        elif f0_method == "pyin": return self.get_pyin(np_arr, int(hop_length))
        elif "hybrid" in f0_method: return self.compute_f0_hybrid(f0_method, np_arr, int(hop_length))
        else: raise ValueError(translations["method_not_valid"])

    def get_pm(self, x):
        f0 = (parselmouth.Sound(x, self.fs).to_pitch_ac(time_step=(160 / 16000 * 1000) / 1000, voicing_threshold=0.6, pitch_floor=50, pitch_ceiling=1100).selected_array["frequency"])
        pad_size = ((x.size // self.hop) - len(f0) + 1) // 2

        if pad_size > 0 or (x.size // self.hop) - len(f0) - pad_size > 0: f0 = np.pad(f0, [[pad_size, (x.size // self.hop) - len(f0) - pad_size]], mode="constant")
        return f0

    def get_mangio_crepe(self, x, hop_length, model="full", onnx=False):
        providers = self.get_providers() if onnx else None

        audio = torch.from_numpy(x.astype(np.float32)).to(self.device)
        audio /= torch.quantile(torch.abs(audio), 0.999)
        audio = audio.unsqueeze(0)

        source = predict(audio, self.fs, hop_length, self.f0_min, self.f0_max, model=model, batch_size=hop_length * 2, device=self.device, pad=True, providers=providers, onnx=onnx).squeeze(0).cpu().float().numpy()
        source[source < 0.001] = np.nan

        return np.nan_to_num(np.interp(np.arange(0, len(source) * (x.size // self.hop), len(source)) / (x.size // self.hop), np.arange(0, len(source)), source))

    def get_crepe(self, x, model="full", onnx=False):
        providers = self.get_providers() if onnx else None

        f0, pd = predict(torch.tensor(np.copy(x))[None].float(), self.fs, 160, self.f0_min, self.f0_max, model, batch_size=512, device=self.device, return_periodicity=True, providers=providers, onnx=onnx)
        f0, pd = mean(f0, 3), median(pd, 3)
        f0[pd < 0.1] = 0

        return f0[0].cpu().numpy()

    def get_fcpe(self, x, hop_length, legacy=False, onnx=False):
        providers = self.get_providers() if onnx else None

        model_fcpe = FCPE(os.path.join("assets", "models", "predictors", "fcpe" + (".onnx" if onnx else ".pt")), hop_length=int(hop_length), f0_min=int(self.f0_min), f0_max=int(self.f0_max), dtype=torch.float32, device=self.device, sample_rate=self.fs, threshold=0.03, providers=providers, onnx=onnx) if legacy else FCPE(os.path.join("assets", "models", "predictors", "fcpe" + (".onnx" if onnx else ".pt")), hop_length=160, f0_min=0, f0_max=8000, dtype=torch.float32, device=self.device, sample_rate=self.fs, threshold=0.006, providers=providers, onnx=onnx)
        f0 = model_fcpe.compute_f0(x, p_len=(x.size // self.hop))

        del model_fcpe
        return f0

    def get_rmvpe(self, x, legacy=False, onnx=False):
        providers = self.get_providers() if onnx else None

        rmvpe_model = RMVPE(os.path.join("assets", "models", "predictors", "rmvpe" + (".onnx" if onnx else ".pt")), device=self.device, onnx=onnx, providers=providers)
        f0 = rmvpe_model.infer_from_audio_with_pitch(x, thred=0.03, f0_min=self.f0_min, f0_max=self.f0_max) if legacy else rmvpe_model.infer_from_audio(x, thred=0.03)

        del rmvpe_model
        return f0

    def get_pyworld(self, x, model="harvest"):
        pw = PYWORLD()

        if model == "harvest": f0, t = pw.harvest(x.astype(np.double), fs=self.fs, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=1000 * self.hop / self.fs)
        elif model == "dio": f0, t = pw.dio(x.astype(np.double), fs=self.fs, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=1000 * self.hop / self.fs)
        else: raise ValueError(translations["method_not_valid"])

        return pw.stonemask(x.astype(np.double), self.fs, t, f0)

    def get_yin(self, x, hop_length):
        source = np.array(librosa.yin(x.astype(np.double), sr=self.fs, fmin=self.f0_min, fmax=self.f0_max, hop_length=hop_length))
        source[source < 0.001] = np.nan

        return np.nan_to_num(np.interp(np.arange(0, len(source) * (x.size // self.hop), len(source)) / (x.size // self.hop), np.arange(0, len(source)), source))

    def get_pyin(self, x, hop_length):
        f0, _, _ = librosa.pyin(x.astype(np.double), fmin=self.f0_min, fmax=self.f0_max, sr=self.fs, hop_length=hop_length)

        source = np.array(f0)
        source[source < 0.001] = np.nan

        return np.nan_to_num(np.interp(np.arange(0, len(source) * (x.size // self.hop), len(source)) / (x.size // self.hop), np.arange(0, len(source)), source))

    def coarse_f0(self, f0):
        return np.rint(np.clip(((1127 * np.log(1 + f0 / 700)) - self.f0_mel_min) * (self.f0_bin - 2) / (self.f0_mel_max - self.f0_mel_min) + 1, 1, self.f0_bin - 1)).astype(int)
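
coarse_f0 quantizes each f0 value onto a 256-bin mel scale: Hz is mapped to mel via 1127 * ln(1 + f/700), then rescaled between f0_min=50 and f0_max=1100 Hz into bins 1..255. A worked standalone example of the same formula:

import numpy as np

f0_min, f0_max, f0_bin = 50.0, 1100.0, 256
mel_min = 1127 * np.log(1 + f0_min / 700)
mel_max = 1127 * np.log(1 + f0_max / 700)

f0 = np.array([0.0, 50.0, 220.0, 1100.0])  # Hz; 0.0 = unvoiced frame
coarse = np.rint(np.clip((1127 * np.log(1 + f0 / 700) - mel_min) * (f0_bin - 2) / (mel_max - mel_min) + 1, 1, f0_bin - 1)).astype(int)
print(coarse)  # unvoiced frames clip to bin 1, f0_max lands on bin 255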

    def process_file(self, file_info, f0_method, hop_length):
        inp_path, opt_path1, opt_path2, np_arr = file_info
        if os.path.exists(opt_path1 + ".npy") and os.path.exists(opt_path2 + ".npy"): return

        try:
            feature_pit = self.compute_f0(np_arr, f0_method, hop_length)
            np.save(opt_path2, feature_pit, allow_pickle=False)
            np.save(opt_path1, self.coarse_f0(feature_pit), allow_pickle=False)
        except Exception as e:
            raise RuntimeError(f"{translations['extract_file_error']} {inp_path}: {e}")

    def process_files(self, files, f0_method, hop_length, pbar):
        for file_info in files:
            self.process_file(file_info, f0_method, hop_length)
            pbar.update()

def run_pitch_extraction(exp_dir, f0_method, hop_length, num_processes, gpus):
    input_root, *output_roots = setup_paths(exp_dir)
    output_root1, output_root2 = output_roots if len(output_roots) == 2 else (output_roots[0], None)
    paths = [(os.path.join(input_root, name), os.path.join(output_root1, name) if output_root1 else None, os.path.join(output_root2, name) if output_root2 else None, load_audio(os.path.join(input_root, name))) for name in sorted(os.listdir(input_root)) if "spec" not in name]
    logger.info(translations["extract_f0_method"].format(num_processes=num_processes, f0_method=f0_method))

    start_time = time.time()

    if gpus != "-":
        gpus = gpus.split("-")
        process_partials = []

        pbar = tqdm.tqdm(total=len(paths), desc=translations["extract_f0"], ncols=100, unit="p")

        for idx, gpu in enumerate(gpus):
            feature_input = FeatureInput(device=get_device(gpu))
            process_partials.append((feature_input, paths[idx::len(gpus)]))

        with ThreadPoolExecutor() as executor:
            for future in as_completed([executor.submit(FeatureInput.process_files, feature_input, part_paths, f0_method, hop_length, pbar) for feature_input, part_paths in process_partials]):
                pbar.update(1)
                logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))
                future.result()

        pbar.close()
    else:
        with tqdm.tqdm(total=len(paths), desc=translations["extract_f0"], ncols=100, unit="p") as pbar:
            with Pool(processes=num_processes) as pool:
                for _ in pool.imap_unordered(partial(FeatureInput(device="cpu").process_file, f0_method=f0_method, hop_length=hop_length), paths):
                    pbar.update(1)
                    logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))

    elapsed_time = time.time() - start_time
    logger.info(translations["extract_f0_success"].format(elapsed_time=f"{elapsed_time:.2f}"))

def process_file_embedding(file, wav_path, out_path, model, device, version, saved_cfg):
    out_file_path = os.path.join(out_path, file.replace("wav", "npy"))
    if os.path.exists(out_file_path): return

    feats = read_wave(os.path.join(wav_path, file), normalize=saved_cfg.task.normalize).to(device).float()
    inputs = {"source": feats, "padding_mask": torch.BoolTensor(feats.shape).fill_(False).to(device), "output_layer": 9 if version == "v1" else 12}

    with torch.no_grad():
        model = model.to(device).float().eval()
        logits = model.extract_features(**inputs)
        feats = model.final_proj(logits[0]) if version == "v1" else logits[0]

    feats = feats.squeeze(0).float().cpu().numpy()

    if not np.isnan(feats).any(): np.save(out_file_path, feats, allow_pickle=False)
    else: logger.warning(f"{file} {translations['NaN']}")

def run_embedding_extraction(exp_dir, version, gpus, embedder_model):
    wav_path, out_path = setup_paths(exp_dir, version)
    logger.info(translations["start_extract_hubert"])

    start_time = time.time()

    try:
        models, saved_cfg, _ = checkpoint_utils.load_model_ensemble_and_task([os.path.join("assets", "models", "embedders", embedder_model + '.pt')], suffix="")
    except Exception as e:
        raise ImportError(translations["read_model_error"].format(e=e))

    devices = [get_device(gpu) for gpu in (gpus.split("-") if gpus != "-" else ["cpu"])]
    paths = sorted([file for file in os.listdir(wav_path) if file.endswith(".wav")])

    if not paths:
        logger.warning(translations["not_found_audio_file"])
        sys.exit(1)

    pbar = tqdm.tqdm(total=len(paths) * len(devices), desc=translations["extract_hubert"], ncols=100, unit="p")

    for task in [(file, wav_path, out_path, models[0], device, version, saved_cfg) for file in paths for device in devices]:
        try:
            process_file_embedding(*task)
        except Exception as e:
            raise RuntimeError(f"{translations['process_error']} {task[0]}: {e}")

        pbar.update(1)
        logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))

    pbar.close()
    elapsed_time = time.time() - start_time
    logger.info(translations["extract_hubert_success"].format(elapsed_time=f"{elapsed_time:.2f}"))

if __name__ == "__main__":
    args = parse_arguments()
    exp_dir = os.path.join("assets", "logs", args.model_name)
    f0_method, hop_length, num_processes, gpus, version, pitch_guidance, sample_rate, embedder_model = args.f0_method, args.hop_length, args.cpu_cores, args.gpu, args.rvc_version, args.pitch_guidance, args.sample_rate, args.embedder_model

    check_predictors(f0_method)
    check_embedders(embedder_model)

    if logger.hasHandlers(): logger.handlers.clear()
    else:
        console_handler = logging.StreamHandler()
        console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
        console_handler.setFormatter(console_formatter)
        console_handler.setLevel(logging.INFO)
        file_handler = logging.handlers.RotatingFileHandler(os.path.join(exp_dir, "extract.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
        file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
        file_handler.setFormatter(file_formatter)
        file_handler.setLevel(logging.DEBUG)
        logger.addHandler(console_handler)
        logger.addHandler(file_handler)
        logger.setLevel(logging.DEBUG)

    log_data = {translations['modelname']: args.model_name, translations['export_process']: exp_dir, translations['f0_method']: f0_method, translations['pretrain_sr']: sample_rate, translations['cpu_core']: num_processes, "Gpu": gpus, "Hop length": hop_length, translations['training_version']: version, translations['extract_f0']: pitch_guidance, translations['hubert_model']: embedder_model}
    for key, value in log_data.items():
        logger.debug(f"{key}: {value}")

    pid_path = os.path.join(exp_dir, "extract_pid.txt")
    with open(pid_path, "w") as pid_file:
        pid_file.write(str(os.getpid()))

    try:
        run_pitch_extraction(exp_dir, f0_method, hop_length, num_processes, gpus)
        run_embedding_extraction(exp_dir, version, gpus, embedder_model)
        generate_config(version, sample_rate, exp_dir)
        generate_filelist(pitch_guidance, exp_dir, version, sample_rate)
    except Exception as e:
        logger.error(f"{translations['extract_error']}: {e}")

        import traceback
        logger.debug(traceback.format_exc())

    if os.path.exists(pid_path): os.remove(pid_path)
    logger.info(f"{translations['extract_success']} {args.model_name}.")
main/inference/preprocess.py
DELETED
@@ -1,290 +0,0 @@
import os
import sys
import time
import logging
import librosa
import argparse
import logging.handlers

import numpy as np
import soundfile as sf
import multiprocessing as mp

from tqdm import tqdm
from scipy import signal
from scipy.io import wavfile
from distutils.util import strtobool
from concurrent.futures import ProcessPoolExecutor, as_completed

sys.path.append(os.getcwd())

from main.configs.config import Config

logger = logging.getLogger(__name__)
for l in ["numba.core.byteflow", "numba.core.ssa", "numba.core.interpreter"]:
    logging.getLogger(l).setLevel(logging.ERROR)

OVERLAP, MAX_AMPLITUDE, ALPHA, HIGH_PASS_CUTOFF, SAMPLE_RATE_16K = 0.3, 0.9, 0.75, 48, 16000
translations = Config().translations

def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--dataset_path", type=str, default="./dataset")
    parser.add_argument("--sample_rate", type=int, required=True)
    parser.add_argument("--cpu_cores", type=int, default=2)
    parser.add_argument("--cut_preprocess", type=lambda x: bool(strtobool(x)), default=True)
    parser.add_argument("--process_effects", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--clean_dataset", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--clean_strength", type=float, default=0.7)

    return parser.parse_args()

def load_audio(file, sample_rate):
    try:
        audio, sr = sf.read(file.strip(" ").strip('"').strip("\n").strip('"').strip(" "))

        if len(audio.shape) > 1: audio = librosa.to_mono(audio.T)
        if sr != sample_rate: audio = librosa.resample(audio, orig_sr=sr, target_sr=sample_rate, res_type="soxr_vhq")
    except Exception as e:
        raise RuntimeError(f"{translations['errors_loading_audio']}: {e}")

    return audio.flatten()

class Slicer:
    def __init__(self, sr, threshold = -40.0, min_length = 5000, min_interval = 300, hop_size = 20, max_sil_kept = 5000):
        if not min_length >= min_interval >= hop_size: raise ValueError(translations["min_length>=min_interval>=hop_size"])
        if not max_sil_kept >= hop_size: raise ValueError(translations["max_sil_kept>=hop_size"])

        min_interval = sr * min_interval / 1000
        self.threshold = 10 ** (threshold / 20.0)
        self.hop_size = round(sr * hop_size / 1000)
        self.win_size = min(round(min_interval), 4 * self.hop_size)
        self.min_length = round(sr * min_length / 1000 / self.hop_size)
        self.min_interval = round(min_interval / self.hop_size)
        self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)

    def _apply_slice(self, waveform, begin, end):
        start_idx = begin * self.hop_size

        if len(waveform.shape) > 1: return waveform[:, start_idx:min(waveform.shape[1], end * self.hop_size)]
        else: return waveform[start_idx:min(waveform.shape[0], end * self.hop_size)]

    def slice(self, waveform):
        samples = waveform.mean(axis=0) if len(waveform.shape) > 1 else waveform
        if samples.shape[0] <= self.min_length: return [waveform]

        rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
        sil_tags = []
        silence_start, clip_start = None, 0

        for i, rms in enumerate(rms_list):
            if rms < self.threshold:
                if silence_start is None: silence_start = i
                continue

            if silence_start is None: continue

            is_leading_silence = silence_start == 0 and i > self.max_sil_kept
            need_slice_middle = (i - silence_start >= self.min_interval and i - clip_start >= self.min_length)

            if not is_leading_silence and not need_slice_middle:
                silence_start = None
                continue

            if i - silence_start <= self.max_sil_kept:
                pos = rms_list[silence_start : i + 1].argmin() + silence_start

                sil_tags.append((0, pos) if silence_start == 0 else (pos, pos))
                clip_start = pos
            elif i - silence_start <= self.max_sil_kept * 2:
                pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()

                pos += i - self.max_sil_kept
                pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)

                if silence_start == 0:
                    sil_tags.append((0, pos_r))
                    clip_start = pos_r
                else:
                    sil_tags.append((min((rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start), pos), max(pos_r, pos)))
                    clip_start = max(pos_r, pos)
            else:
                pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)

                sil_tags.append((0, pos_r) if silence_start == 0 else ((rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start), pos_r))
                clip_start = pos_r

            silence_start = None
        total_frames = rms_list.shape[0]
        if (silence_start is not None and total_frames - silence_start >= self.min_interval): sil_tags.append((rms_list[silence_start : min(total_frames, silence_start + self.max_sil_kept) + 1].argmin() + silence_start, total_frames + 1))

        if not sil_tags: return [waveform]
        else:
            chunks = []
            if sil_tags[0][0] > 0: chunks.append(self._apply_slice(waveform, 0, sil_tags[0][0]))

            for i in range(len(sil_tags) - 1):
                chunks.append(self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]))

            if sil_tags[-1][1] < total_frames: chunks.append(self._apply_slice(waveform, sil_tags[-1][1], total_frames))
            return chunks

def get_rms(y, frame_length=2048, hop_length=512, pad_mode="constant"):
    y = np.pad(y, (int(frame_length // 2), int(frame_length // 2)), mode=pad_mode)
    axis = -1

    x_shape_trimmed = list(y.shape)
    x_shape_trimmed[axis] -= frame_length - 1

    xw = np.lib.stride_tricks.as_strided(y, shape=tuple(x_shape_trimmed) + tuple([frame_length]), strides=y.strides + tuple([y.strides[axis]]))
    xw = np.moveaxis(xw, -1, axis - 1 if axis < 0 else axis + 1)

    slices = [slice(None)] * xw.ndim
    slices[axis] = slice(0, None, hop_length)

    return np.sqrt(np.mean(np.abs(xw[tuple(slices)]) ** 2, axis=-2, keepdims=True))
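
get_rms is a stride-trick implementation of frame-wise RMS. A simplified, loop-based equivalent (same padding and frame math, minus the as_strided machinery):

import numpy as np

def rms_frames(y, frame_length=2048, hop_length=512):
    # Pad by half a frame on each side, then take the root-mean-square per frame.
    y = np.pad(y, (frame_length // 2, frame_length // 2), mode="constant")
    n_frames = 1 + (len(y) - frame_length) // hop_length
    return np.array([
        np.sqrt(np.mean(y[i * hop_length : i * hop_length + frame_length] ** 2))
        for i in range(n_frames)
    ])

print(rms_frames(np.random.randn(16000)).shape)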

class PreProcess:
    def __init__(self, sr, exp_dir, per):
        self.slicer = Slicer(sr=sr, threshold=-42, min_length=1500, min_interval=400, hop_size=15, max_sil_kept=500)
        self.sr = sr
        self.b_high, self.a_high = signal.butter(N=5, Wn=HIGH_PASS_CUTOFF, btype="high", fs=self.sr)
        self.per = per
        self.exp_dir = exp_dir
        self.device = "cpu"
        self.gt_wavs_dir = os.path.join(exp_dir, "sliced_audios")
        self.wavs16k_dir = os.path.join(exp_dir, "sliced_audios_16k")
        os.makedirs(self.gt_wavs_dir, exist_ok=True)
        os.makedirs(self.wavs16k_dir, exist_ok=True)

    def _normalize_audio(self, audio):
        tmp_max = np.abs(audio).max()
        if tmp_max > 2.5: return None
        return (audio / tmp_max * (MAX_AMPLITUDE * ALPHA)) + (1 - ALPHA) * audio
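
_normalize_audio is a soft peak normalization: the clip is scaled so its peak hits MAX_AMPLITUDE * ALPHA = 0.9 * 0.75, then blended 75/25 with the untouched signal, which softens the gain change; clips whose raw peak exceeds 2.5 are rejected. A numeric illustration with a toy clip:

import numpy as np

MAX_AMPLITUDE, ALPHA = 0.9, 0.75
audio = np.array([0.1, -0.5, 0.25])
tmp_max = np.abs(audio).max()  # 0.5
out = (audio / tmp_max * (MAX_AMPLITUDE * ALPHA)) + (1 - ALPHA) * audio
print(np.abs(out).max())  # 0.8 for this toy clip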

    def process_audio_segment(self, normalized_audio, sid, idx0, idx1):
        if normalized_audio is None:
            logger.debug(f"{sid}-{idx0}-{idx1}-filtered")
            return

        wavfile.write(os.path.join(self.gt_wavs_dir, f"{sid}_{idx0}_{idx1}.wav"), self.sr, normalized_audio.astype(np.float32))
        wavfile.write(os.path.join(self.wavs16k_dir, f"{sid}_{idx0}_{idx1}.wav"), SAMPLE_RATE_16K, librosa.resample(normalized_audio, orig_sr=self.sr, target_sr=SAMPLE_RATE_16K, res_type="soxr_vhq").astype(np.float32))

    def process_audio(self, path, idx0, sid, cut_preprocess, process_effects, clean_dataset, clean_strength):
        try:
            audio = load_audio(path, self.sr)

            if process_effects:
                audio = signal.lfilter(self.b_high, self.a_high, audio)
                audio = self._normalize_audio(audio)

            if clean_dataset:
                from main.tools.noisereduce import reduce_noise
                audio = reduce_noise(y=audio, sr=self.sr, prop_decrease=clean_strength)

            idx1 = 0
            if cut_preprocess:
                for audio_segment in self.slicer.slice(audio):
                    i = 0

                    while 1:
                        start = int(self.sr * (self.per - OVERLAP) * i)
                        i += 1

                        if len(audio_segment[start:]) > (self.per + OVERLAP) * self.sr:
                            self.process_audio_segment(audio_segment[start : start + int(self.per * self.sr)], sid, idx0, idx1)
                            idx1 += 1
                        else:
                            self.process_audio_segment(audio_segment[start:], sid, idx0, idx1)
                            idx1 += 1
                            break
            else: self.process_audio_segment(audio, sid, idx0, idx1)
        except Exception as e:
            raise RuntimeError(f"{translations['process_audio_error']}: {e}")

def process_file(args):
    pp, file, cut_preprocess, process_effects, clean_dataset, clean_strength = (args)
    file_path, idx0, sid = file
    pp.process_audio(file_path, idx0, sid, cut_preprocess, process_effects, clean_dataset, clean_strength)

def preprocess_training_set(input_root, sr, num_processes, exp_dir, per, cut_preprocess, process_effects, clean_dataset, clean_strength):
    start_time = time.time()

    pp = PreProcess(sr, exp_dir, per)
    logger.info(translations["start_preprocess"].format(num_processes=num_processes))
    files = []
    idx = 0

    for root, _, filenames in os.walk(input_root):
        try:
            sid = 0 if root == input_root else int(os.path.basename(root))

            for f in filenames:
                if f.lower().endswith(("wav", "mp3", "flac", "ogg", "opus", "m4a", "mp4", "aac", "alac", "wma", "aiff", "webm", "ac3")):
                    files.append((os.path.join(root, f), idx, sid))
                    idx += 1
        except ValueError:
            raise ValueError(f"{translations['not_integer']} '{os.path.basename(root)}'.")

    with tqdm(total=len(files), desc=translations["preprocess"], ncols=100, unit="f") as pbar:
        with ProcessPoolExecutor(max_workers=num_processes) as executor:
            futures = [executor.submit(process_file, (pp, file, cut_preprocess, process_effects, clean_dataset, clean_strength)) for file in files]
            for future in as_completed(futures):
                try:
                    future.result()
                except Exception as e:
                    raise RuntimeError(f"{translations['process_error']}: {e}")
                pbar.update(1)
                logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))

    elapsed_time = time.time() - start_time
    logger.info(translations["preprocess_success"].format(elapsed_time=f"{elapsed_time:.2f}"))

if __name__ == "__main__":
    args = parse_arguments()
    experiment_directory = os.path.join("assets", "logs", args.model_name)
    num_processes = args.cpu_cores
    num_processes = mp.cpu_count() if num_processes is None else int(num_processes)
    dataset = args.dataset_path
    sample_rate = args.sample_rate
    cut_preprocess = args.cut_preprocess
    preprocess_effects = args.process_effects
    clean_dataset = args.clean_dataset
    clean_strength = args.clean_strength

    os.makedirs(experiment_directory, exist_ok=True)

    if logger.hasHandlers(): logger.handlers.clear()
    else:
        console_handler = logging.StreamHandler()
        console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
        console_handler.setFormatter(console_formatter)
        console_handler.setLevel(logging.INFO)
        file_handler = logging.handlers.RotatingFileHandler(os.path.join(experiment_directory, "preprocess.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
        file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
        file_handler.setFormatter(file_formatter)
        file_handler.setLevel(logging.DEBUG)
        logger.addHandler(console_handler)
        logger.addHandler(file_handler)
        logger.setLevel(logging.DEBUG)

    log_data = {translations['modelname']: args.model_name, translations['export_process']: experiment_directory, translations['dataset_folder']: dataset, translations['pretrain_sr']: sample_rate, translations['cpu_core']: num_processes, translations['split_audio']: cut_preprocess, translations['preprocess_effect']: preprocess_effects, translations['clear_audio']: clean_dataset}
    if clean_dataset: log_data[translations['clean_strength']] = clean_strength

    for key, value in log_data.items():
        logger.debug(f"{key}: {value}")

    pid_path = os.path.join(experiment_directory, "preprocess_pid.txt")
    with open(pid_path, "w") as pid_file:
        pid_file.write(str(os.getpid()))

    try:
        preprocess_training_set(dataset, sample_rate, num_processes, experiment_directory, 3.7, cut_preprocess, preprocess_effects, clean_dataset, clean_strength)
    except Exception as e:
        logger.error(f"{translations['process_audio_error']} {e}")
        import traceback
        logger.debug(traceback.format_exc())

    if os.path.exists(pid_path): os.remove(pid_path)
    logger.info(f"{translations['preprocess_model_success']} {args.model_name}")
main/inference/separator_music.py
DELETED
@@ -1,290 +0,0 @@
import os
import sys
import time
import logging
import argparse
import logging.handlers

from pydub import AudioSegment
from distutils.util import strtobool

sys.path.append(os.getcwd())

from main.configs.config import Config
from main.library.utils import pydub_convert
from main.library.algorithm.separator import Separator

translations = Config().translations
logger = logging.getLogger(__name__)

if logger.hasHandlers(): logger.handlers.clear()
else:
    console_handler = logging.StreamHandler()
    console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
    console_handler.setFormatter(console_formatter)
    console_handler.setLevel(logging.INFO)
    file_handler = logging.handlers.RotatingFileHandler(os.path.join("assets", "logs", "separator.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
    file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
    file_handler.setFormatter(file_formatter)
    file_handler.setLevel(logging.DEBUG)
    logger.addHandler(console_handler)
    logger.addHandler(file_handler)
    logger.setLevel(logging.DEBUG)

demucs_models = {"HT-Tuned": "htdemucs_ft.yaml", "HT-Normal": "htdemucs.yaml", "HD_MMI": "hdemucs_mmi.yaml", "HT_6S": "htdemucs_6s.yaml"}
mdx_models = {"Main_340": "UVR-MDX-NET_Main_340.onnx", "Main_390": "UVR-MDX-NET_Main_390.onnx", "Main_406": "UVR-MDX-NET_Main_406.onnx", "Main_427": "UVR-MDX-NET_Main_427.onnx", "Main_438": "UVR-MDX-NET_Main_438.onnx", "Inst_full_292": "UVR-MDX-NET-Inst_full_292.onnx", "Inst_HQ_1": "UVR-MDX-NET_Inst_HQ_1.onnx", "Inst_HQ_2": "UVR-MDX-NET_Inst_HQ_2.onnx", "Inst_HQ_3": "UVR-MDX-NET_Inst_HQ_3.onnx", "Inst_HQ_4": "UVR-MDX-NET-Inst_HQ_4.onnx", "Inst_HQ_5": "UVR-MDX-NET-Inst_HQ_5.onnx", "Kim_Vocal_1": "Kim_Vocal_1.onnx", "Kim_Vocal_2": "Kim_Vocal_2.onnx", "Kim_Inst": "Kim_Inst.onnx", "Inst_187_beta": "UVR-MDX-NET_Inst_187_beta.onnx", "Inst_82_beta": "UVR-MDX-NET_Inst_82_beta.onnx", "Inst_90_beta": "UVR-MDX-NET_Inst_90_beta.onnx", "Voc_FT": "UVR-MDX-NET-Voc_FT.onnx", "Crowd_HQ": "UVR-MDX-NET_Crowd_HQ_1.onnx", "MDXNET_9482": "UVR_MDXNET_9482.onnx", "Inst_1": "UVR-MDX-NET-Inst_1.onnx", "Inst_2": "UVR-MDX-NET-Inst_2.onnx", "Inst_3": "UVR-MDX-NET-Inst_3.onnx", "MDXNET_1_9703": "UVR_MDXNET_1_9703.onnx", "MDXNET_2_9682": "UVR_MDXNET_2_9682.onnx", "MDXNET_3_9662": "UVR_MDXNET_3_9662.onnx", "Inst_Main": "UVR-MDX-NET-Inst_Main.onnx", "MDXNET_Main": "UVR_MDXNET_Main.onnx"}
kara_models = {"Version-1": "UVR_MDXNET_KARA.onnx", "Version-2": "UVR_MDXNET_KARA_2.onnx"}

def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, default="./audios")
    parser.add_argument("--format", type=str, default="wav")
    parser.add_argument("--shifts", type=int, default=2)
    parser.add_argument("--segments_size", type=int, default=256)
    parser.add_argument("--overlap", type=float, default=0.25)
    parser.add_argument("--mdx_hop_length", type=int, default=1024)
    parser.add_argument("--mdx_batch_size", type=int, default=1)
    parser.add_argument("--clean_audio", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--clean_strength", type=float, default=0.7)
    parser.add_argument("--model_name", type=str, default="HT-Normal")
    parser.add_argument("--kara_model", type=str, default="Version-1")
    parser.add_argument("--backing", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--mdx_denoise", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--reverb", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--backing_reverb", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--sample_rate", type=int, default=44100)

    return parser.parse_args()

def main():
    start_time = time.time()
    pid_path = os.path.join("assets", "separate_pid.txt")

    with open(pid_path, "w") as pid_file:
        pid_file.write(str(os.getpid()))

    try:
        args = parse_arguments()
        input_path, output_path, export_format, shifts, segments_size, overlap, hop_length, batch_size, clean_audio, clean_strength, model_name, kara_model, backing, mdx_denoise, reverb, backing_reverb, sample_rate = args.input_path, args.output_path, args.format, args.shifts, args.segments_size, args.overlap, args.mdx_hop_length, args.mdx_batch_size, args.clean_audio, args.clean_strength, args.model_name, args.kara_model, args.backing, args.mdx_denoise, args.reverb, args.backing_reverb, args.sample_rate

        if backing_reverb and not reverb:
            logger.warning(translations["turn_on_dereverb"])
            sys.exit(1)

        if backing_reverb and not backing:
            logger.warning(translations["turn_on_separator_backing"])
            sys.exit(1)

        log_data = {translations['audio_path']: input_path, translations['output_path']: output_path, translations['export_format']: export_format, translations['shift']: shifts, translations['segments_size']: segments_size, translations['overlap']: overlap, translations['modelname']: model_name, translations['denoise_mdx']: mdx_denoise, "Hop length": hop_length, translations['batch_size']: batch_size, translations['sr']: sample_rate}

        if clean_audio:
            log_data[translations['clear_audio']] = clean_audio
            log_data[translations['clean_strength']] = clean_strength

        if backing:
            log_data[translations['backing_model_ver']] = kara_model
            log_data[translations['separator_backing']] = backing

        if reverb:
            log_data[translations['dereveb_audio']] = reverb
            log_data[translations['dereveb_backing']] = backing_reverb

        for key, value in log_data.items():
            logger.debug(f"{key}: {value}")

        if model_name in ["HT-Tuned", "HT-Normal", "HD_MMI", "HT_6S"]: vocals, instruments = separator_music_demucs(input_path, output_path, export_format, shifts, overlap, segments_size, model_name, sample_rate)
        else: vocals, instruments = separator_music_mdx(input_path, output_path, export_format, segments_size, overlap, mdx_denoise, model_name, hop_length, batch_size, sample_rate)

        if backing: main_vocals, backing_vocals = separator_backing(vocals, output_path, export_format, segments_size, overlap, mdx_denoise, kara_model, hop_length, batch_size, sample_rate)
        if reverb: vocals_no_reverb, main_vocals_no_reverb, backing_vocals_no_reverb = separator_reverb(output_path, export_format, segments_size, overlap, mdx_denoise, reverb, backing_reverb, hop_length, batch_size, sample_rate)

        original_output = os.path.join(output_path, f"Original_Vocals_No_Reverb.{export_format}") if reverb else os.path.join(output_path, f"Original_Vocals.{export_format}")
        main_output = os.path.join(output_path, f"Main_Vocals_No_Reverb.{export_format}") if reverb and backing_reverb else os.path.join(output_path, f"Main_Vocals.{export_format}")
        backing_output = os.path.join(output_path, f"Backing_Vocals_No_Reverb.{export_format}") if reverb and backing_reverb else os.path.join(output_path, f"Backing_Vocals.{export_format}")

        if clean_audio:
            import soundfile as sf
            logger.info(f"{translations['clear_audio']}...")

            vocal_data, vocal_sr = sf.read(vocals_no_reverb if reverb else vocals)
            main_data, main_sr = sf.read(main_vocals_no_reverb if reverb and backing else main_vocals)
            backing_data, backing_sr = sf.read(backing_vocals_no_reverb if reverb and backing_reverb else backing_vocals)

            from main.tools.noisereduce import reduce_noise
            sf.write(original_output, reduce_noise(y=vocal_data, prop_decrease=clean_strength), vocal_sr, format=export_format)

            if backing:
                sf.write(main_output, reduce_noise(y=main_data, sr=main_sr, prop_decrease=clean_strength), main_sr, format=export_format)
                sf.write(backing_output, reduce_noise(y=backing_data, sr=backing_sr, prop_decrease=clean_strength), backing_sr, format=export_format)

            logger.info(translations["clean_audio_success"])
        return original_output, instruments, main_output, backing_output
    except Exception as e:
        logger.error(f"{translations['separator_error']}: {e}")
        import traceback
        logger.debug(traceback.format_exc())

    if os.path.exists(pid_path): os.remove(pid_path)

    elapsed_time = time.time() - start_time
    logger.info(translations["separator_success"].format(elapsed_time=f"{elapsed_time:.2f}"))

def separator_music_demucs(input, output, format, shifts, overlap, segments_size, demucs_model, sample_rate):
    if not os.path.exists(input):
        logger.warning(translations["input_not_valid"])
        sys.exit(1)

    if not os.path.exists(output):
        logger.warning(translations["output_not_valid"])
        sys.exit(1)

    for i in [f"Original_Vocals.{format}", f"Instruments.{format}"]:
        if os.path.exists(os.path.join(output, i)): os.remove(os.path.join(output, i))

    logger.info(f"{translations['separator_process_2']}...")
    demucs_output = separator_main(audio_file=input, model_filename=demucs_models.get(demucs_model), output_format=format, output_dir=output, demucs_segment_size=(segments_size / 2), demucs_shifts=shifts, demucs_overlap=overlap, sample_rate=sample_rate)

    for f in demucs_output:
        path = os.path.join(output, f)
        if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))

        if '_(Drums)_' in f: drums = path
        elif '_(Bass)_' in f: bass = path
        elif '_(Other)_' in f: other = path
        elif '_(Vocals)_' in f: os.rename(path, os.path.join(output, f"Original_Vocals.{format}"))

    pydub_convert(AudioSegment.from_file(drums)).overlay(pydub_convert(AudioSegment.from_file(bass))).overlay(pydub_convert(AudioSegment.from_file(other))).export(os.path.join(output, f"Instruments.{format}"), format=format)

    for f in [drums, bass, other]:
        if os.path.exists(f): os.remove(f)

    logger.info(translations["separator_success_2"])
    return os.path.join(output, f"Original_Vocals.{format}"), os.path.join(output, f"Instruments.{format}")

def separator_backing(input, output, format, segments_size, overlap, denoise, kara_model, hop_length, batch_size, sample_rate):
    if not os.path.exists(input):
        logger.warning(translations["input_not_valid"])
        sys.exit(1)

    if not os.path.exists(output):
        logger.warning(translations["output_not_valid"])
        sys.exit(1)

    for f in [f"Main_Vocals.{format}", f"Backing_Vocals.{format}"]:
        if os.path.exists(os.path.join(output, f)): os.remove(os.path.join(output, f))

    model_2 = kara_models.get(kara_model)
    logger.info(f"{translations['separator_process_backing']}...")

    backing_outputs = separator_main(audio_file=input, model_filename=model_2, output_format=format, output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise, sample_rate=sample_rate)
    main_output = os.path.join(output, f"Main_Vocals.{format}")
    backing_output = os.path.join(output, f"Backing_Vocals.{format}")

    for f in backing_outputs:
        path = os.path.join(output, f)
        if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
        if '_(Instrumental)_' in f: os.rename(path, backing_output)
        elif '_(Vocals)_' in f: os.rename(path, main_output)

    logger.info(translations["separator_process_backing_success"])
    return main_output, backing_output

def separator_music_mdx(input, output, format, segments_size, overlap, denoise, mdx_model, hop_length, batch_size, sample_rate):
    if not os.path.exists(input):
        logger.warning(translations["input_not_valid"])
        sys.exit(1)

    if not os.path.exists(output):
        logger.warning(translations["output_not_valid"])
        sys.exit(1)

    for i in [f"Original_Vocals.{format}", f"Instruments.{format}"]:
        if os.path.exists(os.path.join(output, i)): os.remove(os.path.join(output, i))

    model_3 = mdx_models.get(mdx_model)
    logger.info(f"{translations['separator_process_2']}...")

    output_music = separator_main(audio_file=input, model_filename=model_3, output_format=format, output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise, sample_rate=sample_rate)
    original_output, instruments_output = os.path.join(output, f"Original_Vocals.{format}"), os.path.join(output, f"Instruments.{format}")

    for f in output_music:
        path = os.path.join(output, f)
        if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
        if '_(Instrumental)_' in f: os.rename(path, instruments_output)
        elif '_(Vocals)_' in f: os.rename(path, original_output)

    logger.info(translations["separator_process_backing_success"])
    return original_output, instruments_output

def separator_reverb(output, format, segments_size, overlap, denoise, original, backing_reverb, hop_length, batch_size, sample_rate):
    if not os.path.exists(output):
        logger.warning(translations["output_not_valid"])
        sys.exit(1)

    for i in [f"Original_Vocals_Reverb.{format}", f"Main_Vocals_Reverb.{format}", f"Original_Vocals_No_Reverb.{format}", f"Main_Vocals_No_Reverb.{format}"]:
        if os.path.exists(os.path.join(output, i)): os.remove(os.path.join(output, i))

    dereveb_path = []

    if original:
        try:
            dereveb_path.append(os.path.join(output, [f for f in os.listdir(output) if 'Original_Vocals' in f][0]))
        except IndexError:
            logger.warning(translations["not_found_original_vocal"])
            sys.exit(1)

    if backing_reverb:
        try:
            dereveb_path.append(os.path.join(output, [f for f in os.listdir(output) if 'Main_Vocals' in f][0]))
        except IndexError:
            logger.warning(translations["not_found_main_vocal"])
            sys.exit(1)

    if backing_reverb:
        try:
            dereveb_path.append(os.path.join(output, [f for f in os.listdir(output) if 'Backing_Vocals' in f][0]))
        except IndexError:
            logger.warning(translations["not_found_backing_vocal"])
            sys.exit(1)

    for path in dereveb_path:
        if not os.path.exists(path):
            logger.warning(translations["not_found"].format(name=path))
            sys.exit(1)

        if "Original_Vocals" in path:
            reverb_path, no_reverb_path = os.path.join(output, f"Original_Vocals_Reverb.{format}"), os.path.join(output, f"Original_Vocals_No_Reverb.{format}")
            start_title, end_title = translations["process_original"], translations["process_original_success"]
        elif "Main_Vocals" in path:
            reverb_path, no_reverb_path = os.path.join(output, f"Main_Vocals_Reverb.{format}"), os.path.join(output, f"Main_Vocals_No_Reverb.{format}")
            start_title, end_title = translations["process_main"], translations["process_main_success"]
        elif "Backing_Vocals" in path:
            reverb_path, no_reverb_path = os.path.join(output, f"Backing_Vocals_Reverb.{format}"), os.path.join(output, f"Backing_Vocals_No_Reverb.{format}")
            start_title, end_title = translations["process_backing"], translations["process_backing_success"]

        logger.info(start_title)
        output_dereveb = separator_main(audio_file=path, model_filename="Reverb_HQ_By_FoxJoy.onnx", output_format=format, output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise, sample_rate=sample_rate)

        for f in output_dereveb:
            path = os.path.join(output, f)
            if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))

            if '_(Reverb)_' in f: os.rename(path, reverb_path)
            elif '_(No Reverb)_' in f: os.rename(path, no_reverb_path)

        logger.info(end_title)
    return (os.path.join(output, f"Original_Vocals_No_Reverb.{format}") if original else None), (os.path.join(output, f"Main_Vocals_No_Reverb.{format}") if backing_reverb else None), (os.path.join(output, f"Backing_Vocals_No_Reverb.{format}") if backing_reverb else None)

def separator_main(audio_file=None, model_filename="UVR-MDX-NET_Main_340.onnx", output_format="wav", output_dir=".", mdx_segment_size=256, mdx_overlap=0.25, mdx_batch_size=1, mdx_hop_length=1024, mdx_enable_denoise=True, demucs_segment_size=256, demucs_shifts=2, demucs_overlap=0.25, sample_rate=44100):
    try:
        separator = Separator(logger=logger, log_formatter=file_formatter, log_level=logging.INFO, output_dir=output_dir, output_format=output_format, output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=sample_rate, mdx_params={"hop_length": mdx_hop_length, "segment_size": mdx_segment_size, "overlap": mdx_overlap, "batch_size": mdx_batch_size, "enable_denoise": mdx_enable_denoise}, demucs_params={"segment_size": demucs_segment_size, "shifts": demucs_shifts, "overlap": demucs_overlap, "segments_enabled": True})
        separator.load_model(model_filename=model_filename)
        return separator.separate(audio_file)
    except:
        logger.debug(translations["default_setting"])
        separator = Separator(logger=logger, log_formatter=file_formatter, log_level=logging.INFO, output_dir=output_dir, output_format=output_format, output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=44100, mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": mdx_enable_denoise}, demucs_params={"segment_size": 128, "shifts": 2, "overlap": 0.25, "segments_enabled": True})
        separator.load_model(model_filename=model_filename)
        return separator.separate(audio_file)

if __name__ == "__main__": main()
|
main/inference/train.py
DELETED
@@ -1,1000 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import sys
|
3 |
-
import glob
|
4 |
-
import json
|
5 |
-
import torch
|
6 |
-
import hashlib
|
7 |
-
import logging
|
8 |
-
import argparse
|
9 |
-
import datetime
|
10 |
-
import warnings
|
11 |
-
import logging.handlers
|
12 |
-
|
13 |
-
import numpy as np
|
14 |
-
import soundfile as sf
|
15 |
-
import matplotlib.pyplot as plt
|
16 |
-
import torch.distributed as dist
|
17 |
-
import torch.utils.data as tdata
|
18 |
-
import torch.multiprocessing as mp
|
19 |
-
import torch.utils.checkpoint as checkpoint
|
20 |
-
|
21 |
-
from tqdm import tqdm
|
22 |
-
from collections import OrderedDict
|
23 |
-
from random import randint, shuffle
|
24 |
-
from torch.cuda.amp import GradScaler, autocast
|
25 |
-
from torch.utils.tensorboard import SummaryWriter
|
26 |
-
|
27 |
-
from time import time as ttime
|
28 |
-
from torch.nn import functional as F
|
29 |
-
from distutils.util import strtobool
|
30 |
-
from librosa.filters import mel as librosa_mel_fn
|
31 |
-
from torch.nn.parallel import DistributedDataParallel as DDP
|
32 |
-
from torch.nn.utils.parametrizations import spectral_norm, weight_norm
|
33 |
-
|
34 |
-
sys.path.append(os.getcwd())
|
35 |
-
from main.configs.config import Config
|
36 |
-
from main.library.algorithm.residuals import LRELU_SLOPE
|
37 |
-
from main.library.algorithm.synthesizers import Synthesizer
|
38 |
-
from main.library.algorithm.commons import get_padding, slice_segments, clip_grad_value
|
39 |
-
|
40 |
-
MATPLOTLIB_FLAG = False
|
41 |
-
translations = Config().translations
|
42 |
-
warnings.filterwarnings("ignore")
|
43 |
-
logging.getLogger("torch").setLevel(logging.ERROR)
|
44 |
-
|
45 |
-
class HParams:
|
46 |
-
def __init__(self, **kwargs):
|
47 |
-
for k, v in kwargs.items():
|
48 |
-
self[k] = HParams(**v) if isinstance(v, dict) else v
|
49 |
-
|
50 |
-
def keys(self):
|
51 |
-
return self.__dict__.keys()
|
52 |
-
|
53 |
-
def items(self):
|
54 |
-
return self.__dict__.items()
|
55 |
-
|
56 |
-
def values(self):
|
57 |
-
return self.__dict__.values()
|
58 |
-
|
59 |
-
def __len__(self):
|
60 |
-
return len(self.__dict__)
|
61 |
-
|
62 |
-
def __getitem__(self, key):
|
63 |
-
return self.__dict__[key]
|
64 |
-
|
65 |
-
def __setitem__(self, key, value):
|
66 |
-
self.__dict__[key] = value
|
67 |
-
|
68 |
-
def __contains__(self, key):
|
69 |
-
return key in self.__dict__
|
70 |
-
|
71 |
-
def __repr__(self):
|
72 |
-
return repr(self.__dict__)
|
73 |
-
|
74 |
-
def parse_arguments():
|
75 |
-
parser = argparse.ArgumentParser()
|
76 |
-
parser.add_argument("--model_name", type=str, required=True)
|
77 |
-
parser.add_argument("--rvc_version", type=str, default="v2")
|
78 |
-
parser.add_argument("--save_every_epoch", type=int, required=True)
|
79 |
-
parser.add_argument("--save_only_latest", type=lambda x: bool(strtobool(x)), default=True)
|
80 |
-
parser.add_argument("--save_every_weights", type=lambda x: bool(strtobool(x)), default=True)
|
81 |
-
parser.add_argument("--total_epoch", type=int, default=300)
|
82 |
-
parser.add_argument("--sample_rate", type=int, required=True)
|
83 |
-
parser.add_argument("--batch_size", type=int, default=8)
|
84 |
-
parser.add_argument("--gpu", type=str, default="0")
|
85 |
-
parser.add_argument("--pitch_guidance", type=lambda x: bool(strtobool(x)), default=True)
|
86 |
-
parser.add_argument("--g_pretrained_path", type=str, default="")
|
87 |
-
parser.add_argument("--d_pretrained_path", type=str, default="")
|
88 |
-
parser.add_argument("--overtraining_detector", type=lambda x: bool(strtobool(x)), default=False)
|
89 |
-
parser.add_argument("--overtraining_threshold", type=int, default=50)
|
90 |
-
parser.add_argument("--cleanup", type=lambda x: bool(strtobool(x)), default=False)
|
91 |
-
parser.add_argument("--cache_data_in_gpu", type=lambda x: bool(strtobool(x)), default=False)
|
92 |
-
parser.add_argument("--model_author", type=str)
|
93 |
-
parser.add_argument("--vocoder", type=str, default="Default")
|
94 |
-
parser.add_argument("--checkpointing", type=lambda x: bool(strtobool(x)), default=False)
|
95 |
-
|
96 |
-
return parser.parse_args()
|
97 |
-
|
98 |
-
args = parse_arguments()
|
99 |
-
model_name, save_every_epoch, total_epoch, pretrainG, pretrainD, version, gpus, batch_size, sample_rate, pitch_guidance, save_only_latest, save_every_weights, cache_data_in_gpu, overtraining_detector, overtraining_threshold, cleanup, model_author, vocoder, checkpointing = args.model_name, args.save_every_epoch, args.total_epoch, args.g_pretrained_path, args.d_pretrained_path, args.rvc_version, args.gpu, args.batch_size, args.sample_rate, args.pitch_guidance, args.save_only_latest, args.save_every_weights, args.cache_data_in_gpu, args.overtraining_detector, args.overtraining_threshold, args.cleanup, args.model_author, args.vocoder, args.checkpointing
|
100 |
-
|
101 |
-
experiment_dir = os.path.join("assets", "logs", model_name)
|
102 |
-
training_file_path = os.path.join(experiment_dir, "training_data.json")
|
103 |
-
config_save_path = os.path.join(experiment_dir, "config.json")
|
104 |
-
|
105 |
-
os.environ["CUDA_VISIBLE_DEVICES"] = gpus.replace("-", ",")
|
106 |
-
n_gpus = len(gpus.split("-"))
|
107 |
-
|
108 |
-
torch.backends.cudnn.deterministic = False
|
109 |
-
torch.backends.cudnn.benchmark = False
|
110 |
-
|
111 |
-
lowest_value = {"step": 0, "value": float("inf"), "epoch": 0}
|
112 |
-
global_step, last_loss_gen_all, overtrain_save_epoch = 0, 0, 0
|
113 |
-
loss_gen_history, smoothed_loss_gen_history, loss_disc_history, smoothed_loss_disc_history = [], [], [], []
|
114 |
-
|
115 |
-
with open(config_save_path, "r") as f:
|
116 |
-
config = json.load(f)
|
117 |
-
|
118 |
-
config = HParams(**config)
|
119 |
-
config.data.training_files = os.path.join(experiment_dir, "filelist.txt")
|
120 |
-
logger = logging.getLogger(__name__)
|
121 |
-
|
122 |
-
if logger.hasHandlers(): logger.handlers.clear()
|
123 |
-
else:
|
124 |
-
console_handler = logging.StreamHandler()
|
125 |
-
console_handler.setFormatter(logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
|
126 |
-
console_handler.setLevel(logging.INFO)
|
127 |
-
file_handler = logging.handlers.RotatingFileHandler(os.path.join(experiment_dir, "train.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
|
128 |
-
file_handler.setFormatter(logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
|
129 |
-
file_handler.setLevel(logging.DEBUG)
|
130 |
-
logger.addHandler(console_handler)
|
131 |
-
logger.addHandler(file_handler)
|
132 |
-
logger.setLevel(logging.DEBUG)
|
133 |
-
|
134 |
-
log_data = {translations['modelname']: model_name, translations["save_every_epoch"]: save_every_epoch, translations["total_e"]: total_epoch, translations["dorg"].format(pretrainG=pretrainG, pretrainD=pretrainD): "", translations['training_version']: version, "Gpu": gpus, translations['batch_size']: batch_size, translations['pretrain_sr']: sample_rate, translations['training_f0']: pitch_guidance, translations['save_only_latest']: save_only_latest, translations['save_every_weights']: save_every_weights, translations['cache_in_gpu']: cache_data_in_gpu, translations['overtraining_detector']: overtraining_detector, translations['threshold']: overtraining_threshold, translations['cleanup_training']: cleanup, translations['memory_efficient_training']: checkpointing}
|
135 |
-
if model_author: log_data[translations["model_author"].format(model_author=model_author)] = ""
|
136 |
-
if vocoder != "Default": log_data[translations['vocoder']] = vocoder
|
137 |
-
|
138 |
-
for key, value in log_data.items():
|
139 |
-
logger.debug(f"{key}: {value}" if value != "" else f"{key} {value}")
|
140 |
-
|
141 |
-
def main():
|
142 |
-
global training_file_path, last_loss_gen_all, smoothed_loss_gen_history, loss_gen_history, loss_disc_history, smoothed_loss_disc_history, overtrain_save_epoch, model_author, vocoder, checkpointing
|
143 |
-
|
144 |
-
os.environ["MASTER_ADDR"] = "localhost"
|
145 |
-
os.environ["MASTER_PORT"] = str(randint(20000, 55555))
|
146 |
-
|
147 |
-
if torch.cuda.is_available(): device, n_gpus = torch.device("cuda"), torch.cuda.device_count()
|
148 |
-
elif torch.backends.mps.is_available(): device, n_gpus = torch.device("mps"), 1
|
149 |
-
else: device, n_gpus = torch.device("cpu"), 1
|
150 |
-
|
151 |
-
def start():
|
152 |
-
children = []
|
153 |
-
pid_data = {"process_pids": []}
|
154 |
-
|
155 |
-
with open(config_save_path, "r") as pid_file:
|
156 |
-
try:
|
157 |
-
pid_data.update(json.load(pid_file))
|
158 |
-
except json.JSONDecodeError:
|
159 |
-
pass
|
160 |
-
|
161 |
-
with open(config_save_path, "w") as pid_file:
|
162 |
-
for i in range(n_gpus):
|
163 |
-
subproc = mp.Process(target=run, args=(i, n_gpus, experiment_dir, pretrainG, pretrainD, pitch_guidance, total_epoch, save_every_weights, config, device, model_author, vocoder, checkpointing))
|
164 |
-
children.append(subproc)
|
165 |
-
|
166 |
-
subproc.start()
|
167 |
-
pid_data["process_pids"].append(subproc.pid)
|
168 |
-
|
169 |
-
json.dump(pid_data, pid_file, indent=4)
|
170 |
-
|
171 |
-
for i in range(n_gpus):
|
172 |
-
children[i].join()
|
173 |
-
|
174 |
-
def load_from_json(file_path):
|
175 |
-
if os.path.exists(file_path):
|
176 |
-
with open(file_path, "r") as f:
|
177 |
-
data = json.load(f)
|
178 |
-
return (data.get("loss_disc_history", []), data.get("smoothed_loss_disc_history", []), data.get("loss_gen_history", []), data.get("smoothed_loss_gen_history", []))
|
179 |
-
|
180 |
-
return [], [], [], []
|
181 |
-
|
182 |
-
def continue_overtrain_detector(training_file_path):
|
183 |
-
if overtraining_detector and os.path.exists(training_file_path): (loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history) = load_from_json(training_file_path)
|
184 |
-
|
185 |
-
n_gpus = torch.cuda.device_count()
|
186 |
-
|
187 |
-
if not torch.cuda.is_available() and torch.backends.mps.is_available(): n_gpus = 1
|
188 |
-
if n_gpus < 1:
|
189 |
-
logger.warning(translations["not_gpu"])
|
190 |
-
n_gpus = 1
|
191 |
-
|
192 |
-
if cleanup:
|
193 |
-
for root, dirs, files in os.walk(experiment_dir, topdown=False):
|
194 |
-
for name in files:
|
195 |
-
file_path = os.path.join(root, name)
|
196 |
-
_, file_extension = os.path.splitext(name)
|
197 |
-
if (file_extension == ".0" or (name.startswith("D_") and file_extension == ".pth") or (name.startswith("G_") and file_extension == ".pth") or (file_extension == ".index")): os.remove(file_path)
|
198 |
-
|
199 |
-
for name in dirs:
|
200 |
-
if name == "eval":
|
201 |
-
folder_path = os.path.join(root, name)
|
202 |
-
|
203 |
-
for item in os.listdir(folder_path):
|
204 |
-
item_path = os.path.join(folder_path, item)
|
205 |
-
if os.path.isfile(item_path): os.remove(item_path)
|
206 |
-
|
207 |
-
os.rmdir(folder_path)
|
208 |
-
|
209 |
-
continue_overtrain_detector(training_file_path)
|
210 |
-
start()
|
211 |
-
|
212 |
-
def plot_spectrogram_to_numpy(spectrogram):
|
213 |
-
global MATPLOTLIB_FLAG
|
214 |
-
|
215 |
-
if not MATPLOTLIB_FLAG:
|
216 |
-
plt.switch_backend("Agg")
|
217 |
-
MATPLOTLIB_FLAG = True
|
218 |
-
|
219 |
-
fig, ax = plt.subplots(figsize=(10, 2))
|
220 |
-
|
221 |
-
plt.colorbar(ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none"), ax=ax)
|
222 |
-
plt.xlabel("Frames")
|
223 |
-
plt.ylabel("Channels")
|
224 |
-
plt.tight_layout()
|
225 |
-
fig.canvas.draw()
|
226 |
-
plt.close(fig)
|
227 |
-
|
228 |
-
return np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
229 |
-
|
230 |
-
def verify_checkpoint_shapes(checkpoint_path, model):
|
231 |
-
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
232 |
-
checkpoint_state_dict = checkpoint["model"]
|
233 |
-
try:
|
234 |
-
model_state_dict = model.module.load_state_dict(checkpoint_state_dict) if hasattr(model, "module") else model.load_state_dict(checkpoint_state_dict)
|
235 |
-
except RuntimeError:
|
236 |
-
logger.error(translations["checkpointing_err"])
|
237 |
-
sys.exit(1)
|
238 |
-
else: del checkpoint, checkpoint_state_dict, model_state_dict
|
239 |
-
|
240 |
-
def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sample_rate=22050):
|
241 |
-
for k, v in scalars.items():
|
242 |
-
writer.add_scalar(k, v, global_step)
|
243 |
-
|
244 |
-
for k, v in histograms.items():
|
245 |
-
writer.add_histogram(k, v, global_step)
|
246 |
-
|
247 |
-
for k, v in images.items():
|
248 |
-
writer.add_image(k, v, global_step, dataformats="HWC")
|
249 |
-
|
250 |
-
for k, v in audios.items():
|
251 |
-
writer.add_audio(k, v, global_step, audio_sample_rate)
|
252 |
-
|
253 |
-
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
|
254 |
-
assert os.path.isfile(checkpoint_path), translations["not_found_checkpoint"].format(checkpoint_path=checkpoint_path)
|
255 |
-
checkpoint_dict = replace_keys_in_dict(replace_keys_in_dict(torch.load(checkpoint_path, map_location="cpu"), ".weight_v", ".parametrizations.weight.original1"), ".weight_g", ".parametrizations.weight.original0")
|
256 |
-
new_state_dict = {k: checkpoint_dict["model"].get(k, v) for k, v in (model.module.state_dict() if hasattr(model, "module") else model.state_dict()).items()}
|
257 |
-
|
258 |
-
if hasattr(model, "module"): model.module.load_state_dict(new_state_dict, strict=False)
|
259 |
-
else: model.load_state_dict(new_state_dict, strict=False)
|
260 |
-
|
261 |
-
if optimizer and load_opt == 1: optimizer.load_state_dict(checkpoint_dict.get("optimizer", {}))
|
262 |
-
logger.debug(translations["save_checkpoint"].format(checkpoint_path=checkpoint_path, checkpoint_dict=checkpoint_dict['iteration']))
|
263 |
-
return (model, optimizer, checkpoint_dict.get("learning_rate", 0), checkpoint_dict["iteration"])
|
264 |
-
|
265 |
-
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
|
266 |
-
state_dict = (model.module.state_dict() if hasattr(model, "module") else model.state_dict())
|
267 |
-
torch.save(replace_keys_in_dict(replace_keys_in_dict({"model": state_dict, "iteration": iteration, "optimizer": optimizer.state_dict(), "learning_rate": learning_rate}, ".parametrizations.weight.original1", ".weight_v"), ".parametrizations.weight.original0", ".weight_g"), checkpoint_path)
|
268 |
-
logger.info(translations["save_model"].format(checkpoint_path=checkpoint_path, iteration=iteration))
|
269 |
-
|
270 |
-
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
|
271 |
-
checkpoints = sorted(glob.glob(os.path.join(dir_path, regex)), key=lambda f: int("".join(filter(str.isdigit, f))))
|
272 |
-
return checkpoints[-1] if checkpoints else None
|
273 |
-
|
274 |
-
def load_wav_to_torch(full_path):
|
275 |
-
data, sample_rate = sf.read(full_path, dtype='float32')
|
276 |
-
return torch.FloatTensor(data.astype(np.float32)), sample_rate
|
277 |
-
|
278 |
-
def load_filepaths_and_text(filename, split="|"):
|
279 |
-
with open(filename, encoding="utf-8") as f:
|
280 |
-
return [line.strip().split(split) for line in f]
|
281 |
-
|
282 |
-
def feature_loss(fmap_r, fmap_g):
|
283 |
-
loss = 0
|
284 |
-
for dr, dg in zip(fmap_r, fmap_g):
|
285 |
-
for rl, gl in zip(dr, dg):
|
286 |
-
loss += torch.mean(torch.abs(rl.float().detach() - gl.float()))
|
287 |
-
return loss * 2
|
288 |
-
|
289 |
-
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
290 |
-
loss = 0
|
291 |
-
r_losses, g_losses = [], []
|
292 |
-
|
293 |
-
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
294 |
-
dr = dr.float()
|
295 |
-
dg = dg.float()
|
296 |
-
r_loss = torch.mean((1 - dr) ** 2)
|
297 |
-
g_loss = torch.mean(dg**2)
|
298 |
-
loss += r_loss + g_loss
|
299 |
-
r_losses.append(r_loss.item())
|
300 |
-
g_losses.append(g_loss.item())
|
301 |
-
return loss, r_losses, g_losses
|
302 |
-
|
303 |
-
def generator_loss(disc_outputs):
|
304 |
-
loss = 0
|
305 |
-
gen_losses = []
|
306 |
-
|
307 |
-
for dg in disc_outputs:
|
308 |
-
l = torch.mean((1 - dg.float()) ** 2)
|
309 |
-
gen_losses.append(l)
|
310 |
-
loss += l
|
311 |
-
return loss, gen_losses
|
312 |
-
|
313 |
-
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
|
314 |
-
z_p = z_p.float()
|
315 |
-
logs_q = logs_q.float()
|
316 |
-
m_p = m_p.float()
|
317 |
-
logs_p = logs_p.float()
|
318 |
-
z_mask = z_mask.float()
|
319 |
-
kl = logs_p - logs_q - 0.5
|
320 |
-
kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
|
321 |
-
return torch.sum(kl * z_mask) / torch.sum(z_mask)
|
322 |
-
|
323 |
-
class TextAudioLoaderMultiNSFsid(tdata.Dataset):
|
324 |
-
def __init__(self, hparams):
|
325 |
-
self.audiopaths_and_text = load_filepaths_and_text(hparams.training_files)
|
326 |
-
self.max_wav_value = hparams.max_wav_value
|
327 |
-
self.sample_rate = hparams.sample_rate
|
328 |
-
self.filter_length = hparams.filter_length
|
329 |
-
self.hop_length = hparams.hop_length
|
330 |
-
self.win_length = hparams.win_length
|
331 |
-
self.sample_rate = hparams.sample_rate
|
332 |
-
self.min_text_len = getattr(hparams, "min_text_len", 1)
|
333 |
-
self.max_text_len = getattr(hparams, "max_text_len", 5000)
|
334 |
-
self._filter()
|
335 |
-
|
336 |
-
def _filter(self):
|
337 |
-
audiopaths_and_text_new, lengths = [], []
|
338 |
-
for audiopath, text, pitch, pitchf, dv in self.audiopaths_and_text:
|
339 |
-
if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
|
340 |
-
audiopaths_and_text_new.append([audiopath, text, pitch, pitchf, dv])
|
341 |
-
lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
|
342 |
-
|
343 |
-
self.audiopaths_and_text = audiopaths_and_text_new
|
344 |
-
self.lengths = lengths
|
345 |
-
|
346 |
-
def get_sid(self, sid):
|
347 |
-
try:
|
348 |
-
sid = torch.LongTensor([int(sid)])
|
349 |
-
except ValueError as e:
|
350 |
-
logger.error(translations["sid_error"].format(sid=sid, e=e))
|
351 |
-
sid = torch.LongTensor([0])
|
352 |
-
return sid
|
353 |
-
|
354 |
-
def get_audio_text_pair(self, audiopath_and_text):
|
355 |
-
phone, pitch, pitchf = self.get_labels(audiopath_and_text[1], audiopath_and_text[2], audiopath_and_text[3])
|
356 |
-
spec, wav = self.get_audio(audiopath_and_text[0])
|
357 |
-
dv = self.get_sid(audiopath_and_text[4])
|
358 |
-
len_phone = phone.size()[0]
|
359 |
-
len_spec = spec.size()[-1]
|
360 |
-
|
361 |
-
if len_phone != len_spec:
|
362 |
-
len_min = min(len_phone, len_spec)
|
363 |
-
len_wav = len_min * self.hop_length
|
364 |
-
spec, wav, phone = spec[:, :len_min], wav[:, :len_wav], phone[:len_min, :]
|
365 |
-
pitch, pitchf = pitch[:len_min], pitchf[:len_min]
|
366 |
-
return (spec, wav, phone, pitch, pitchf, dv)
|
367 |
-
|
368 |
-
def get_labels(self, phone, pitch, pitchf):
|
369 |
-
phone = np.repeat(np.load(phone), 2, axis=0)
|
370 |
-
n_num = min(phone.shape[0], 900)
|
371 |
-
return torch.FloatTensor(phone[:n_num, :]), torch.LongTensor(np.load(pitch)[:n_num]), torch.FloatTensor(np.load(pitchf)[:n_num])
|
372 |
-
|
373 |
-
def get_audio(self, filename):
|
374 |
-
audio, sample_rate = load_wav_to_torch(filename)
|
375 |
-
if sample_rate != self.sample_rate: raise ValueError(translations["sr_does_not_match"].format(sample_rate=sample_rate, sample_rate2=self.sample_rate))
|
376 |
-
audio_norm = audio.unsqueeze(0)
|
377 |
-
spec_filename = filename.replace(".wav", ".spec.pt")
|
378 |
-
|
379 |
-
if os.path.exists(spec_filename):
|
380 |
-
try:
|
381 |
-
spec = torch.load(spec_filename)
|
382 |
-
except Exception as e:
|
383 |
-
logger.error(translations["spec_error"].format(spec_filename=spec_filename, e=e))
|
384 |
-
spec = torch.squeeze(spectrogram_torch(audio_norm, self.filter_length, self.hop_length, self.win_length, center=False), 0)
|
385 |
-
torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
|
386 |
-
else:
|
387 |
-
spec = torch.squeeze(spectrogram_torch(audio_norm, self.filter_length, self.hop_length, self.win_length, center=False), 0)
|
388 |
-
torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
|
389 |
-
return spec, audio_norm
|
390 |
-
|
391 |
-
def __getitem__(self, index):
|
392 |
-
return self.get_audio_text_pair(self.audiopaths_and_text[index])
|
393 |
-
|
394 |
-
def __len__(self):
|
395 |
-
return len(self.audiopaths_and_text)
|
396 |
-
|
397 |
-
class TextAudioCollateMultiNSFsid:
|
398 |
-
def __init__(self, return_ids=False):
|
399 |
-
self.return_ids = return_ids
|
400 |
-
|
401 |
-
def __call__(self, batch):
|
402 |
-
_, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True)
|
403 |
-
spec_lengths, wave_lengths = torch.LongTensor(len(batch)), torch.LongTensor(len(batch))
|
404 |
-
spec_padded, wave_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max([x[0].size(1) for x in batch])), torch.FloatTensor(len(batch), 1, max([x[1].size(1) for x in batch]))
|
405 |
-
spec_padded.zero_()
|
406 |
-
wave_padded.zero_()
|
407 |
-
max_phone_len = max([x[2].size(0) for x in batch])
|
408 |
-
phone_lengths, phone_padded = torch.LongTensor(len(batch)), torch.FloatTensor(len(batch), max_phone_len, batch[0][2].shape[1])
|
409 |
-
pitch_padded, pitchf_padded = torch.LongTensor(len(batch), max_phone_len), torch.FloatTensor(len(batch), max_phone_len)
|
410 |
-
phone_padded.zero_()
|
411 |
-
pitch_padded.zero_()
|
412 |
-
pitchf_padded.zero_()
|
413 |
-
sid = torch.LongTensor(len(batch))
|
414 |
-
|
415 |
-
for i in range(len(ids_sorted_decreasing)):
|
416 |
-
row = batch[ids_sorted_decreasing[i]]
|
417 |
-
spec = row[0]
|
418 |
-
spec_padded[i, :, : spec.size(1)] = spec
|
419 |
-
spec_lengths[i] = spec.size(1)
|
420 |
-
wave = row[1]
|
421 |
-
wave_padded[i, :, : wave.size(1)] = wave
|
422 |
-
wave_lengths[i] = wave.size(1)
|
423 |
-
phone = row[2]
|
424 |
-
phone_padded[i, : phone.size(0), :] = phone
|
425 |
-
phone_lengths[i] = phone.size(0)
|
426 |
-
pitch = row[3]
|
427 |
-
pitch_padded[i, : pitch.size(0)] = pitch
|
428 |
-
pitchf = row[4]
|
429 |
-
pitchf_padded[i, : pitchf.size(0)] = pitchf
|
430 |
-
sid[i] = row[5]
|
431 |
-
return (phone_padded, phone_lengths, pitch_padded, pitchf_padded, spec_padded, spec_lengths, wave_padded, wave_lengths, sid)
|
432 |
-
|
433 |
-
class TextAudioLoader(tdata.Dataset):
|
434 |
-
def __init__(self, hparams):
|
435 |
-
self.audiopaths_and_text = load_filepaths_and_text(hparams.training_files)
|
436 |
-
self.max_wav_value = hparams.max_wav_value
|
437 |
-
self.sample_rate = hparams.sample_rate
|
438 |
-
self.filter_length = hparams.filter_length
|
439 |
-
self.hop_length = hparams.hop_length
|
440 |
-
self.win_length = hparams.win_length
|
441 |
-
self.sample_rate = hparams.sample_rate
|
442 |
-
self.min_text_len = getattr(hparams, "min_text_len", 1)
|
443 |
-
self.max_text_len = getattr(hparams, "max_text_len", 5000)
|
444 |
-
self._filter()
|
445 |
-
|
446 |
-
def _filter(self):
|
447 |
-
audiopaths_and_text_new, lengths = [], []
|
448 |
-
for entry in self.audiopaths_and_text:
|
449 |
-
if len(entry) >= 3:
|
450 |
-
audiopath, text, dv = entry[:3]
|
451 |
-
if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
|
452 |
-
audiopaths_and_text_new.append([audiopath, text, dv])
|
453 |
-
lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
|
454 |
-
|
455 |
-
self.audiopaths_and_text = audiopaths_and_text_new
|
456 |
-
self.lengths = lengths
|
457 |
-
|
458 |
-
def get_sid(self, sid):
|
459 |
-
try:
|
460 |
-
sid = torch.LongTensor([int(sid)])
|
461 |
-
except ValueError as e:
|
462 |
-
logger.error(translations["sid_error"].format(sid=sid, e=e))
|
463 |
-
sid = torch.LongTensor([0])
|
464 |
-
return sid
|
465 |
-
|
466 |
-
def get_audio_text_pair(self, audiopath_and_text):
|
467 |
-
phone = self.get_labels(audiopath_and_text[1])
|
468 |
-
spec, wav = self.get_audio(audiopath_and_text[0])
|
469 |
-
dv = self.get_sid(audiopath_and_text[2])
|
470 |
-
len_phone = phone.size()[0]
|
471 |
-
len_spec = spec.size()[-1]
|
472 |
-
|
473 |
-
if len_phone != len_spec:
|
474 |
-
len_min = min(len_phone, len_spec)
|
475 |
-
len_wav = len_min * self.hop_length
|
476 |
-
spec = spec[:, :len_min]
|
477 |
-
wav = wav[:, :len_wav]
|
478 |
-
phone = phone[:len_min, :]
|
479 |
-
return (spec, wav, phone, dv)
|
480 |
-
|
481 |
-
def get_labels(self, phone):
|
482 |
-
phone = np.repeat(np.load(phone), 2, axis=0)
|
483 |
-
return torch.FloatTensor(phone[:min(phone.shape[0], 900), :])
|
484 |
-
|
485 |
-
def get_audio(self, filename):
|
486 |
-
audio, sample_rate = load_wav_to_torch(filename)
|
487 |
-
if sample_rate != self.sample_rate: raise ValueError(translations["sr_does_not_match"].format(sample_rate=sample_rate, sample_rate2=self.sample_rate))
|
488 |
-
audio_norm = audio.unsqueeze(0)
|
489 |
-
spec_filename = filename.replace(".wav", ".spec.pt")
|
490 |
-
|
491 |
-
if os.path.exists(spec_filename):
|
492 |
-
try:
|
493 |
-
spec = torch.load(spec_filename)
|
494 |
-
except Exception as e:
|
495 |
-
logger.error(translations["spec_error"].format(spec_filename=spec_filename, e=e))
|
496 |
-
spec = torch.squeeze(spectrogram_torch(audio_norm, self.filter_length, self.hop_length, self.win_length, center=False), 0)
|
497 |
-
torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
|
498 |
-
else:
|
499 |
-
spec = torch.squeeze(spectrogram_torch(audio_norm, self.filter_length, self.hop_length, self.win_length, center=False), 0)
|
500 |
-
torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
|
501 |
-
return spec, audio_norm
|
502 |
-
|
503 |
-
def __getitem__(self, index):
|
504 |
-
return self.get_audio_text_pair(self.audiopaths_and_text[index])
|
505 |
-
|
506 |
-
def __len__(self):
|
507 |
-
return len(self.audiopaths_and_text)
|
508 |
-
|
509 |
-
class TextAudioCollate:
|
510 |
-
def __init__(self, return_ids=False):
|
511 |
-
self.return_ids = return_ids
|
512 |
-
|
513 |
-
def __call__(self, batch):
|
514 |
-
_, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True)
|
515 |
-
spec_lengths, wave_lengths = torch.LongTensor(len(batch)), torch.LongTensor(len(batch))
|
516 |
-
spec_padded, wave_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max([x[0].size(1) for x in batch])), torch.FloatTensor(len(batch), 1, max([x[1].size(1) for x in batch]))
|
517 |
-
spec_padded.zero_()
|
518 |
-
wave_padded.zero_()
|
519 |
-
max_phone_len = max([x[2].size(0) for x in batch])
|
520 |
-
phone_lengths, phone_padded = torch.LongTensor(len(batch)), torch.FloatTensor(len(batch), max_phone_len, batch[0][2].shape[1])
|
521 |
-
phone_padded.zero_()
|
522 |
-
sid = torch.LongTensor(len(batch))
|
523 |
-
for i in range(len(ids_sorted_decreasing)):
|
524 |
-
row = batch[ids_sorted_decreasing[i]]
|
525 |
-
spec = row[0]
|
526 |
-
spec_padded[i, :, : spec.size(1)] = spec
|
527 |
-
spec_lengths[i] = spec.size(1)
|
528 |
-
wave = row[1]
|
529 |
-
wave_padded[i, :, : wave.size(1)] = wave
|
530 |
-
wave_lengths[i] = wave.size(1)
|
531 |
-
phone = row[2]
|
532 |
-
phone_padded[i, : phone.size(0), :] = phone
|
533 |
-
phone_lengths[i] = phone.size(0)
|
534 |
-
sid[i] = row[3]
|
535 |
-
return (phone_padded, phone_lengths, spec_padded, spec_lengths, wave_padded, wave_lengths, sid)
|
536 |
-
|
537 |
-
class DistributedBucketSampler(tdata.distributed.DistributedSampler):
|
538 |
-
def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
|
539 |
-
super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
|
540 |
-
self.lengths = dataset.lengths
|
541 |
-
self.batch_size = batch_size
|
542 |
-
self.boundaries = boundaries
|
543 |
-
self.buckets, self.num_samples_per_bucket = self._create_buckets()
|
544 |
-
self.total_size = sum(self.num_samples_per_bucket)
|
545 |
-
self.num_samples = self.total_size // self.num_replicas
|
546 |
-
|
547 |
-
def _create_buckets(self):
|
548 |
-
buckets = [[] for _ in range(len(self.boundaries) - 1)]
|
549 |
-
for i in range(len(self.lengths)):
|
550 |
-
idx_bucket = self._bisect(self.lengths[i])
|
551 |
-
if idx_bucket != -1: buckets[idx_bucket].append(i)
|
552 |
-
|
553 |
-
for i in range(len(buckets) - 1, -1, -1):
|
554 |
-
if len(buckets[i]) == 0:
|
555 |
-
buckets.pop(i)
|
556 |
-
self.boundaries.pop(i + 1)
|
557 |
-
|
558 |
-
num_samples_per_bucket = []
|
559 |
-
for i in range(len(buckets)):
|
560 |
-
len_bucket = len(buckets[i])
|
561 |
-
total_batch_size = self.num_replicas * self.batch_size
|
562 |
-
num_samples_per_bucket.append(len_bucket + ((total_batch_size - (len_bucket % total_batch_size)) % total_batch_size))
|
563 |
-
return buckets, num_samples_per_bucket
|
564 |
-
|
565 |
-
def __iter__(self):
|
566 |
-
g = torch.Generator()
|
567 |
-
g.manual_seed(self.epoch)
|
568 |
-
indices, batches = [], []
|
569 |
-
if self.shuffle:
|
570 |
-
for bucket in self.buckets:
|
571 |
-
indices.append(torch.randperm(len(bucket), generator=g).tolist())
|
572 |
-
else:
|
573 |
-
for bucket in self.buckets:
|
574 |
-
indices.append(list(range(len(bucket))))
|
575 |
-
|
576 |
-
for i in range(len(self.buckets)):
|
577 |
-
bucket = self.buckets[i]
|
578 |
-
len_bucket = len(bucket)
|
579 |
-
ids_bucket = indices[i]
|
580 |
-
rem = self.num_samples_per_bucket[i] - len_bucket
|
581 |
-
ids_bucket = (ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[: (rem % len_bucket)])[self.rank :: self.num_replicas]
|
582 |
-
|
583 |
-
for j in range(len(ids_bucket) // self.batch_size):
|
584 |
-
batches.append([bucket[idx] for idx in ids_bucket[j * self.batch_size : (j + 1) * self.batch_size]])
|
585 |
-
|
586 |
-
if self.shuffle: batches = [batches[i] for i in torch.randperm(len(batches), generator=g).tolist()]
|
587 |
-
self.batches = batches
|
588 |
-
assert len(self.batches) * self.batch_size == self.num_samples
|
589 |
-
return iter(self.batches)
|
590 |
-
|
591 |
-
def _bisect(self, x, lo=0, hi=None):
|
592 |
-
if hi is None: hi = len(self.boundaries) - 1
|
593 |
-
|
594 |
-
if hi > lo:
|
595 |
-
mid = (hi + lo) // 2
|
596 |
-
if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: return mid
|
597 |
-
elif x <= self.boundaries[mid]: return self._bisect(x, lo, mid)
|
598 |
-
else: return self._bisect(x, mid + 1, hi)
|
599 |
-
else: return -1
|
600 |
-
|
601 |
-
def __len__(self):
|
602 |
-
return self.num_samples // self.batch_size
|
603 |
-
|
604 |
-
class MultiPeriodDiscriminator(torch.nn.Module):
|
605 |
-
def __init__(self, version, use_spectral_norm=False, checkpointing=False):
|
606 |
-
super(MultiPeriodDiscriminator, self).__init__()
|
607 |
-
self.checkpointing = checkpointing
|
608 |
-
periods = ([2, 3, 5, 7, 11, 17] if version == "v1" else [2, 3, 5, 7, 11, 17, 23, 37])
|
609 |
-
self.discriminators = torch.nn.ModuleList([DiscriminatorS(use_spectral_norm=use_spectral_norm, checkpointing=checkpointing)] + [DiscriminatorP(p, use_spectral_norm=use_spectral_norm, checkpointing=checkpointing) for p in periods])
|
610 |
-
|
611 |
-
def forward(self, y, y_hat):
|
612 |
-
y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
|
613 |
-
for d in self.discriminators:
|
614 |
-
if self.training and self.checkpointing:
|
615 |
-
def forward_discriminator(d, y, y_hat):
|
616 |
-
y_d_r, fmap_r = d(y)
|
617 |
-
y_d_g, fmap_g = d(y_hat)
|
618 |
-
return y_d_r, fmap_r, y_d_g, fmap_g
|
619 |
-
y_d_r, fmap_r, y_d_g, fmap_g = checkpoint.checkpoint(forward_discriminator, d, y, y_hat, use_reentrant=False)
|
620 |
-
else:
|
621 |
-
y_d_r, fmap_r = d(y)
|
622 |
-
y_d_g, fmap_g = d(y_hat)
|
623 |
-
|
624 |
-
y_d_rs.append(y_d_r)
|
625 |
-
y_d_gs.append(y_d_g)
|
626 |
-
fmap_rs.append(fmap_r)
|
627 |
-
fmap_gs.append(fmap_g)
|
628 |
-
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
629 |
-
|
630 |
-
class DiscriminatorS(torch.nn.Module):
|
631 |
-
def __init__(self, use_spectral_norm=False, checkpointing=False):
|
632 |
-
super(DiscriminatorS, self).__init__()
|
633 |
-
self.checkpointing = checkpointing
|
634 |
-
norm_f = spectral_norm if use_spectral_norm else weight_norm
|
635 |
-
self.convs = torch.nn.ModuleList([norm_f(torch.nn.Conv1d(1, 16, 15, 1, padding=7)), norm_f(torch.nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)), norm_f(torch.nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)), norm_f(torch.nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)), norm_f(torch.nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), norm_f(torch.nn.Conv1d(1024, 1024, 5, 1, padding=2))])
|
636 |
-
self.conv_post = norm_f(torch.nn.Conv1d(1024, 1, 3, 1, padding=1))
|
637 |
-
self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE)
|
638 |
-
|
639 |
-
def forward(self, x):
|
640 |
-
fmap = []
|
641 |
-
for conv in self.convs:
|
642 |
-
x = checkpoint.checkpoint(self.lrelu, checkpoint.checkpoint(conv, x, use_reentrant = False), use_reentrant = False) if self.training and self.checkpointing else self.lrelu(conv(x))
|
643 |
-
fmap.append(x)
|
644 |
-
|
645 |
-
x = self.conv_post(x)
|
646 |
-
fmap.append(x)
|
647 |
-
return torch.flatten(x, 1, -1), fmap
|
648 |
-
|
649 |
-
class DiscriminatorP(torch.nn.Module):
|
650 |
-
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, checkpointing=False):
|
651 |
-
super(DiscriminatorP, self).__init__()
|
652 |
-
self.period = period
|
653 |
-
self.checkpointing = checkpointing
|
654 |
-
norm_f = spectral_norm if use_spectral_norm else weight_norm
|
655 |
-
self.convs = torch.nn.ModuleList([norm_f(torch.nn.Conv2d(in_ch, out_ch, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))) for in_ch, out_ch in zip([1, 32, 128, 512, 1024], [32, 128, 512, 1024, 1024])])
|
656 |
-
self.conv_post = norm_f(torch.nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
657 |
-
self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE)
|
658 |
-
|
659 |
-
def forward(self, x):
|
660 |
-
fmap = []
|
661 |
-
b, c, t = x.shape
|
662 |
-
if t % self.period != 0: x = torch.nn.functional.pad(x, (0, (self.period - (t % self.period))), "reflect")
|
663 |
-
x = x.view(b, c, -1, self.period)
|
664 |
-
for conv in self.convs:
|
665 |
-
x = checkpoint.checkpoint(self.lrelu, checkpoint.checkpoint(conv, x, use_reentrant = False), use_reentrant = False) if self.training and self.checkpointing else self.lrelu(conv(x))
|
666 |
-
fmap.append(x)
|
667 |
-
|
668 |
-
x = self.conv_post(x)
|
669 |
-
fmap.append(x)
|
670 |
-
return torch.flatten(x, 1, -1), fmap
|
671 |
-
|
672 |
-
class EpochRecorder:
|
673 |
-
def __init__(self):
|
674 |
-
self.last_time = ttime()
|
675 |
-
|
676 |
-
def record(self):
|
677 |
-
now_time = ttime()
|
678 |
-
elapsed_time = now_time - self.last_time
|
679 |
-
self.last_time = now_time
|
680 |
-
return translations["time_or_speed_training"].format(current_time=datetime.datetime.now().strftime("%H:%M:%S"), elapsed_time_str=str(datetime.timedelta(seconds=int(round(elapsed_time, 1)))))
|
681 |
-
|
682 |
-
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
683 |
-
return torch.log(torch.clamp(x, min=clip_val) * C)
|
684 |
-
|
685 |
-
def dynamic_range_decompression_torch(x, C=1):
|
686 |
-
return torch.exp(x) / C
|
687 |
-
|
688 |
-
def spectral_normalize_torch(magnitudes):
|
689 |
-
return dynamic_range_compression_torch(magnitudes)
|
690 |
-
|
691 |
-
def spectral_de_normalize_torch(magnitudes):
|
692 |
-
return dynamic_range_decompression_torch(magnitudes)
|
693 |
-
|
694 |
-
mel_basis, hann_window = {}, {}
|
695 |
-
|
696 |
-
def spectrogram_torch(y, n_fft, hop_size, win_size, center=False):
|
697 |
-
global hann_window
|
698 |
-
|
699 |
-
wnsize_dtype_device = str(win_size) + "_" + str(y.dtype) + "_" + str(y.device)
|
700 |
-
if wnsize_dtype_device not in hann_window: hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
|
701 |
-
spec = torch.stft(torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect").squeeze(1), n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], center=center, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
|
702 |
-
return torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6)
|
703 |
-
|
704 |
-
def spec_to_mel_torch(spec, n_fft, num_mels, sample_rate, fmin, fmax):
|
705 |
-
global mel_basis
|
706 |
-
|
707 |
-
fmax_dtype_device = str(fmax) + "_" + str(spec.dtype) + "_" + str(spec.device)
|
708 |
-
if fmax_dtype_device not in mel_basis: mel_basis[fmax_dtype_device] = torch.from_numpy(librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)).to(dtype=spec.dtype, device=spec.device)
|
709 |
-
return spectral_normalize_torch(torch.matmul(mel_basis[fmax_dtype_device], spec))
|
710 |
-
|
711 |
-
def mel_spectrogram_torch(y, n_fft, num_mels, sample_rate, hop_size, win_size, fmin, fmax, center=False):
|
712 |
-
return spec_to_mel_torch(spectrogram_torch(y, n_fft, hop_size, win_size, center), n_fft, num_mels, sample_rate, fmin, fmax)
|
713 |
-
|
714 |
-
def replace_keys_in_dict(d, old_key_part, new_key_part):
|
715 |
-
updated_dict = OrderedDict() if isinstance(d, OrderedDict) else {}
|
716 |
-
for key, value in d.items():
|
717 |
-
updated_dict[(key.replace(old_key_part, new_key_part) if isinstance(key, str) else key)] = (replace_keys_in_dict(value, old_key_part, new_key_part) if isinstance(value, dict) else value)
|
718 |
-
return updated_dict
|
719 |
-
|
720 |
-
def extract_model(ckpt, sr, pitch_guidance, name, model_path, epoch, step, version, hps, model_author, vocoder):
|
721 |
-
try:
|
722 |
-
logger.info(translations["savemodel"].format(model_dir=model_path, epoch=epoch, step=step))
|
723 |
-
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
724 |
-
|
725 |
-
opt = OrderedDict(weight={key: value.half() for key, value in ckpt.items() if "enc_q" not in key})
|
726 |
-
opt["config"] = [hps.data.filter_length // 2 + 1, 32, hps.model.inter_channels, hps.model.hidden_channels, hps.model.filter_channels, hps.model.n_heads, hps.model.n_layers, hps.model.kernel_size, hps.model.p_dropout, hps.model.resblock, hps.model.resblock_kernel_sizes, hps.model.resblock_dilation_sizes, hps.model.upsample_rates, hps.model.upsample_initial_channel, hps.model.upsample_kernel_sizes, hps.model.spk_embed_dim, hps.model.gin_channels, hps.data.sample_rate]
|
727 |
-
opt["epoch"] = f"{epoch}epoch"
|
728 |
-
opt["step"] = step
|
729 |
-
opt["sr"] = sr
|
730 |
-
opt["f0"] = int(pitch_guidance)
|
731 |
-
opt["version"] = version
|
732 |
-
opt["creation_date"] = datetime.datetime.now().isoformat()
|
733 |
-
opt["model_hash"] = hashlib.sha256(f"{str(ckpt)} {epoch} {step} {datetime.datetime.now().isoformat()}".encode()).hexdigest()
|
734 |
-
opt["model_name"] = name
|
735 |
-
opt["author"] = model_author
|
736 |
-
opt["vocoder"] = vocoder
|
737 |
-
|
738 |
-
torch.save(replace_keys_in_dict(replace_keys_in_dict(opt, ".parametrizations.weight.original1", ".weight_v"), ".parametrizations.weight.original0", ".weight_g"), model_path)
|
739 |
-
except Exception as e:
|
740 |
-
logger.error(f"{translations['extract_model_error']}: {e}")
|
741 |
-
|
742 |
-
def run(rank, n_gpus, experiment_dir, pretrainG, pretrainD, pitch_guidance, custom_total_epoch, custom_save_every_weights, config, device, model_author, vocoder, checkpointing):
|
743 |
-
global global_step
|
744 |
-
|
745 |
-
if rank == 0: writer_eval = SummaryWriter(log_dir=os.path.join(experiment_dir, "eval"))
|
746 |
-
else: writer_eval = None
|
747 |
-
|
748 |
-
dist.init_process_group(backend="gloo", init_method="env://", world_size=n_gpus, rank=rank)
|
749 |
-
torch.manual_seed(config.train.seed)
|
750 |
-
if torch.cuda.is_available(): torch.cuda.set_device(rank)
|
751 |
-
|
752 |
-
train_dataset = TextAudioLoaderMultiNSFsid(config.data)
|
753 |
-
train_loader = tdata.DataLoader(train_dataset, num_workers=4, shuffle=False, pin_memory=True, collate_fn=TextAudioCollateMultiNSFsid(), batch_sampler=DistributedBucketSampler(train_dataset, batch_size * n_gpus, [100, 200, 300, 400, 500, 600, 700, 800, 900], num_replicas=n_gpus, rank=rank, shuffle=True), persistent_workers=True, prefetch_factor=8)
|
754 |
-
|
755 |
-
net_g, net_d = Synthesizer(config.data.filter_length // 2 + 1, config.train.segment_size // config.data.hop_length, **config.model, use_f0=pitch_guidance, sr=sample_rate, vocoder=vocoder, checkpointing=checkpointing), MultiPeriodDiscriminator(version, config.model.use_spectral_norm, checkpointing=checkpointing)
|
756 |
-
|
757 |
-
if torch.cuda.is_available(): net_g, net_d = net_g.cuda(rank), net_d.cuda(rank)
|
758 |
-
else: net_g, net_d = net_g.to(device), net_d.to(device)
|
759 |
-
optim_g, optim_d = torch.optim.AdamW(net_g.parameters(), config.train.learning_rate, betas=config.train.betas, eps=config.train.eps), torch.optim.AdamW(net_d.parameters(), config.train.learning_rate, betas=config.train.betas, eps=config.train.eps)
|
760 |
-
net_g, net_d = (DDP(net_g, device_ids=[rank]), DDP(net_d, device_ids=[rank])) if torch.cuda.is_available() else (DDP(net_g), DDP(net_d))
|
761 |
-
|
762 |
-
try:
|
763 |
-
logger.info(translations["start_training"])
|
764 |
-
_, _, _, epoch_str = load_checkpoint(latest_checkpoint_path(experiment_dir, "D_*.pth"), net_d, optim_d)
|
765 |
-
_, _, _, epoch_str = load_checkpoint(latest_checkpoint_path(experiment_dir, "G_*.pth"), net_g, optim_g)
|
766 |
-
epoch_str += 1
|
767 |
-
global_step = (epoch_str - 1) * len(train_loader)
|
768 |
-
except:
|
769 |
-
epoch_str, global_step = 1, 0
|
770 |
-
|
771 |
-
if pretrainG != "" and pretrainG != "None":
|
772 |
-
if rank == 0:
|
773 |
-
verify_checkpoint_shapes(pretrainG, net_g)
|
774 |
-
logger.info(translations["import_pretrain"].format(dg="G", pretrain=pretrainG))
|
775 |
-
|
776 |
-
if hasattr(net_g, "module"): net_g.module.load_state_dict(torch.load(pretrainG, map_location="cpu")["model"])
|
777 |
-
else: net_g.load_state_dict(torch.load(pretrainG, map_location="cpu")["model"])
|
778 |
-
else: logger.warning(translations["not_using_pretrain"].format(dg="G"))
|
779 |
-
|
780 |
-
if pretrainD != "" and pretrainD != "None":
|
781 |
-
if rank == 0:
|
782 |
-
verify_checkpoint_shapes(pretrainD, net_d)
|
783 |
-
logger.info(translations["import_pretrain"].format(dg="D", pretrain=pretrainD))
|
784 |
-
|
785 |
-
if hasattr(net_d, "module"): net_d.module.load_state_dict(torch.load(pretrainD, map_location="cpu")["model"])
|
786 |
-
else: net_d.load_state_dict(torch.load(pretrainD, map_location="cpu")["model"])
|
787 |
-
else: logger.warning(translations["not_using_pretrain"].format(dg="D"))
|
788 |
-
|
789 |
-
scheduler_g, scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.train.lr_decay, last_epoch=epoch_str - 2), torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.train.lr_decay, last_epoch=epoch_str - 2)
|
790 |
-
|
791 |
-
optim_d.step()
|
792 |
-
optim_g.step()
|
793 |
-
|
794 |
-
scaler = GradScaler(enabled=False)
|
795 |
-
cache = []
|
796 |
-
|
797 |
-
for info in train_loader:
|
798 |
-
phone, phone_lengths, pitch, pitchf, _, _, _, _, sid = info
|
799 |
-
reference = (phone.cuda(rank, non_blocking=True), phone_lengths.cuda(rank, non_blocking=True), (pitch.cuda(rank, non_blocking=True) if pitch_guidance else None), (pitchf.cuda(rank, non_blocking=True) if pitch_guidance else None), sid.cuda(rank, non_blocking=True)) if device.type == "cuda" else (phone.to(device), phone_lengths.to(device), (pitch.to(device) if pitch_guidance else None), (pitchf.to(device) if pitch_guidance else None), sid.to(device))
|
800 |
-
break
|
801 |
-
|
802 |
-
for epoch in range(epoch_str, total_epoch + 1):
|
803 |
-
train_and_evaluate(rank, epoch, config, [net_g, net_d], [optim_g, optim_d], scaler, train_loader, writer_eval, cache, custom_save_every_weights, custom_total_epoch, device, reference, model_author, vocoder)
|
804 |
-
scheduler_g.step()
|
805 |
-
scheduler_d.step()
|
806 |
-
|
807 |
-
def train_and_evaluate(rank, epoch, hps, nets, optims, scaler, train_loader, writer, cache, custom_save_every_weights, custom_total_epoch, device, reference, model_author, vocoder):
|
808 |
-
global global_step, lowest_value, loss_disc, consecutive_increases_gen, consecutive_increases_disc
|
809 |
-
|
810 |
-
if epoch == 1:
|
811 |
-
lowest_value = {"step": 0, "value": float("inf"), "epoch": 0}
|
812 |
-
last_loss_gen_all, consecutive_increases_gen, consecutive_increases_disc = 0.0, 0, 0
|
813 |
-
|
814 |
-
net_g, net_d = nets
|
815 |
-
optim_g, optim_d = optims
|
816 |
-
train_loader.batch_sampler.set_epoch(epoch)
|
817 |
-
|
818 |
-
net_g.train()
|
819 |
-
net_d.train()
|
820 |
-
|
821 |
-
if device.type == "cuda" and cache_data_in_gpu:
|
822 |
-
data_iterator = cache
|
823 |
-
if cache == []:
|
824 |
-
for batch_idx, info in enumerate(train_loader):
|
825 |
-
cache.append((batch_idx, [tensor.cuda(rank, non_blocking=True) for tensor in info]))
|
826 |
-
else: shuffle(cache)
|
827 |
-
else: data_iterator = enumerate(train_loader)
|
828 |
-
|
829 |
-
with tqdm(total=len(train_loader), leave=False) as pbar:
|
830 |
-
for batch_idx, info in data_iterator:
|
831 |
-
if device.type == "cuda" and not cache_data_in_gpu: info = [tensor.cuda(rank, non_blocking=True) for tensor in info]
|
832 |
-
elif device.type != "cuda": info = [tensor.to(device) for tensor in info]
|
833 |
-
|
834 |
-
phone, phone_lengths, pitch, pitchf, spec, spec_lengths, wave, _, sid = info
|
835 |
-
pitch = pitch if pitch_guidance else None
|
836 |
-
pitchf = pitchf if pitch_guidance else None
|
837 |
-
|
838 |
-
with autocast(enabled=False):
|
839 |
-
model_output = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
|
840 |
-
y_hat, ids_slice, _, z_mask, (_, z_p, m_p, logs_p, _, logs_q) = model_output
|
841 |
-
|
842 |
-
mel = spec_to_mel_torch(spec, config.data.filter_length, config.data.n_mel_channels, config.data.sample_rate, config.data.mel_fmin, config.data.mel_fmax)
|
843 |
-
y_mel = slice_segments(mel, ids_slice, config.train.segment_size // config.data.hop_length, dim=3)
|
844 |
-
|
845 |
-
with autocast(enabled=False):
|
846 |
-
y_hat_mel = mel_spectrogram_torch(y_hat.float().squeeze(1), config.data.filter_length, config.data.n_mel_channels, config.data.sample_rate, config.data.hop_length, config.data.win_length, config.data.mel_fmin, config.data.mel_fmax)
|
847 |
-
|
848 |
-
wave = slice_segments(wave, ids_slice * config.data.hop_length, config.train.segment_size, dim=3)
|
849 |
-
y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
|
850 |
-
|
851 |
-
with autocast(enabled=False):
|
852 |
-
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
|
853 |
-
|
854 |
-
optim_d.zero_grad()
|
855 |
-
scaler.scale(loss_disc).backward()
|
856 |
-
scaler.unscale_(optim_d)
|
857 |
-
grad_norm_d = clip_grad_value(net_d.parameters(), None)
|
858 |
-
scaler.step(optim_d)
|
859 |
-
|
860 |
-
with autocast(enabled=False):
|
861 |
-
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
|
862 |
-
with autocast(enabled=False):
|
863 |
-
loss_mel = F.l1_loss(y_mel, y_hat_mel) * config.train.c_mel
|
864 |
-
loss_kl = (kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl)
|
865 |
-
loss_fm = feature_loss(fmap_r, fmap_g)
|
866 |
-
loss_gen, losses_gen = generator_loss(y_d_hat_g)
|
867 |
-
loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
|
868 |
-
|
869 |
-
if loss_gen_all < lowest_value["value"]:
|
870 |
-
lowest_value["value"] = loss_gen_all
|
871 |
-
lowest_value["step"] = global_step
|
872 |
-
lowest_value["epoch"] = epoch
|
873 |
-
if epoch > lowest_value["epoch"]: logger.warning(translations["training_warning"])
|
874 |
-
|
875 |
-
optim_g.zero_grad()
|
876 |
-
scaler.scale(loss_gen_all).backward()
|
877 |
-
scaler.unscale_(optim_g)
|
878 |
-
grad_norm_g = clip_grad_value(net_g.parameters(), None)
|
879 |
-
scaler.step(optim_g)
|
880 |
-
scaler.update()
|
881 |
-
|
882 |
-
if rank == 0:
|
883 |
-
if global_step % config.train.log_interval == 0:
|
884 |
-
if loss_mel > 75: loss_mel = 75
|
885 |
-
if loss_kl > 9: loss_kl = 9
|
886 |
-
|
887 |
-
scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc, "learning_rate": optim_g.param_groups[0]["lr"], "grad/norm_d": grad_norm_d, "grad/norm_g": grad_norm_g, "loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl}
|
888 |
-
scalar_dict.update({f"loss/g/{i}": v for i, v in enumerate(losses_gen)})
|
889 |
-
scalar_dict.update({f"loss/d_r/{i}": v for i, v in enumerate(losses_disc_r)})
|
890 |
-
scalar_dict.update({f"loss/d_g/{i}": v for i, v in enumerate(losses_disc_g)})
|
891 |
-
|
892 |
-
with torch.no_grad():
|
893 |
-
o, *_ = net_g.module.infer(*reference) if hasattr(net_g, "module") else net_g.infer(*reference)
|
894 |
-
|
895 |
-
summarize(writer=writer, global_step=global_step, images={"slice/mel_org": plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), "slice/mel_gen": plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), "all/mel": plot_spectrogram_to_numpy(mel[0].data.cpu().numpy())}, scalars=scalar_dict, audios={f"gen/audio_{global_step:07d}": o[0, :, :]}, audio_sample_rate=config.data.sample_rate)
|
896 |
-
|
897 |
-
global_step += 1
|
898 |
-
pbar.update(1)
|
899 |
-
|
900 |
-
def check_overtraining(smoothed_loss_history, threshold, epsilon=0.004):
|
901 |
-
if len(smoothed_loss_history) < threshold + 1: return False
|
902 |
-
|
903 |
-
for i in range(-threshold, -1):
|
904 |
-
if smoothed_loss_history[i + 1] > smoothed_loss_history[i]: return True
|
905 |
-
if abs(smoothed_loss_history[i + 1] - smoothed_loss_history[i]) >= epsilon: return False
|
906 |
-
|
907 |
-
return True
|
908 |
-
|
909 |
-
def update_exponential_moving_average(smoothed_loss_history, new_value, smoothing=0.987):
|
910 |
-
smoothed_value = new_value if not smoothed_loss_history else (smoothing * smoothed_loss_history[-1] + (1 - smoothing) * new_value)
|
911 |
-
smoothed_loss_history.append(smoothed_value)
|
912 |
-
return smoothed_value
|
913 |
-
|
914 |
-
def save_to_json(file_path, loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history):
|
915 |
-
with open(file_path, "w") as f:
|
916 |
-
json.dump({"loss_disc_history": loss_disc_history, "smoothed_loss_disc_history": smoothed_loss_disc_history, "loss_gen_history": loss_gen_history, "smoothed_loss_gen_history": smoothed_loss_gen_history}, f)
|
917 |
-
|
918 |
-
model_add, model_del = [], []
|
919 |
-
done = False
|
920 |
-
|
921 |
-
if rank == 0:
|
922 |
-
if epoch % save_every_epoch == False:
|
923 |
-
checkpoint_suffix = f"{'latest' if save_only_latest else global_step}.pth"
|
924 |
-
save_checkpoint(net_g, optim_g, config.train.learning_rate, epoch, os.path.join(experiment_dir, "G_" + checkpoint_suffix))
|
925 |
-
save_checkpoint(net_d, optim_d, config.train.learning_rate, epoch, os.path.join(experiment_dir, "D_" + checkpoint_suffix))
|
926 |
-
if custom_save_every_weights: model_add.append(os.path.join("assets", "weights", f"{model_name}_{epoch}e_{global_step}s.pth"))
|
927 |
-
|
928 |
-
if overtraining_detector and epoch > 1:
|
929 |
-
current_loss_disc = float(loss_disc)
|
930 |
-
loss_disc_history.append(current_loss_disc)
|
931 |
-
smoothed_value_disc = update_exponential_moving_average(smoothed_loss_disc_history, current_loss_disc)
|
932 |
-
is_overtraining_disc = check_overtraining(smoothed_loss_disc_history, overtraining_threshold * 2)
|
933 |
-
|
934 |
-
if is_overtraining_disc: consecutive_increases_disc += 1
|
935 |
-
        else: consecutive_increases_disc = 0

        current_loss_gen = float(lowest_value["value"])
        loss_gen_history.append(current_loss_gen)
        smoothed_value_gen = update_exponential_moving_average(smoothed_loss_gen_history, current_loss_gen)
        is_overtraining_gen = check_overtraining(smoothed_loss_gen_history, overtraining_threshold, 0.01)

        if is_overtraining_gen: consecutive_increases_gen += 1
        else: consecutive_increases_gen = 0

        if epoch % save_every_epoch == 0: save_to_json(training_file_path, loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history)

        if (is_overtraining_gen and consecutive_increases_gen == overtraining_threshold or is_overtraining_disc and consecutive_increases_disc == (overtraining_threshold * 2)):
            logger.info(translations["overtraining_find"].format(epoch=epoch, smoothed_value_gen=f"{smoothed_value_gen:.3f}", smoothed_value_disc=f"{smoothed_value_disc:.3f}"))
            done = True
        else:
            logger.info(translations["best_epoch"].format(epoch=epoch, smoothed_value_gen=f"{smoothed_value_gen:.3f}", smoothed_value_disc=f"{smoothed_value_disc:.3f}"))
            for file in glob.glob(os.path.join("assets", "weights", f"{model_name}_*e_*s_best_epoch.pth")):
                model_del.append(file)

            model_add.append(os.path.join("assets", "weights", f"{model_name}_{epoch}e_{global_step}s_best_epoch.pth"))

    if epoch >= custom_total_epoch:
        logger.info(translations["success_training"].format(epoch=epoch, global_step=global_step, loss_gen_all=round(loss_gen_all.item(), 3)))
        logger.info(translations["training_info"].format(lowest_value_rounded=round(float(lowest_value["value"]), 3), lowest_value_epoch=lowest_value['epoch'], lowest_value_step=lowest_value['step']))

        pid_file_path = os.path.join(experiment_dir, "config.json")

        with open(pid_file_path, "r") as pid_file:
            pid_data = json.load(pid_file)

        with open(pid_file_path, "w") as pid_file:
            pid_data.pop("process_pids", None)
            json.dump(pid_data, pid_file, indent=4)

        model_add.append(os.path.join("assets", "weights", f"{model_name}_{epoch}e_{global_step}s.pth"))
        done = True

    for m in model_del:
        os.remove(m)

    if model_add:
        ckpt = (net_g.module.state_dict() if hasattr(net_g, "module") else net_g.state_dict())

        for m in model_add:
            extract_model(ckpt=ckpt, sr=sample_rate, pitch_guidance=pitch_guidance == True, name=model_name, model_path=m, epoch=epoch, step=global_step, version=version, hps=hps, model_author=model_author, vocoder=vocoder)

    lowest_value_rounded = round(float(lowest_value["value"]), 3)
    epoch_recorder = EpochRecorder()

    if epoch > 1 and overtraining_detector: logger.info(translations["model_training_info"].format(model_name=model_name, epoch=epoch, global_step=global_step, epoch_recorder=epoch_recorder.record(), lowest_value_rounded=lowest_value_rounded, lowest_value_epoch=lowest_value['epoch'], lowest_value_step=lowest_value['step'], remaining_epochs_gen=(overtraining_threshold - consecutive_increases_gen), remaining_epochs_disc=((overtraining_threshold * 2) - consecutive_increases_disc), smoothed_value_gen=f"{smoothed_value_gen:.3f}", smoothed_value_disc=f"{smoothed_value_disc:.3f}"))
    elif epoch > 1 and overtraining_detector == False: logger.info(translations["model_training_info_2"].format(model_name=model_name, epoch=epoch, global_step=global_step, epoch_recorder=epoch_recorder.record(), lowest_value_rounded=lowest_value_rounded, lowest_value_epoch=lowest_value['epoch'], lowest_value_step=lowest_value['step']))
    else: logger.info(translations["model_training_info_3"].format(model_name=model_name, epoch=epoch, global_step=global_step, epoch_recorder=epoch_recorder.record()))

    last_loss_gen_all = loss_gen_all
    if done: os._exit(0)

if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn")
    try:
        main()
    except Exception as e:
        logger.error(f"{translations['training_error']} {e}")

        import traceback
        logger.debug(traceback.format_exc())
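The block above is the core of the overtraining detector this commit removes: generator and discriminator losses are smoothed with an exponential moving average, and training stops once the smoothed loss has risen for overtraining_threshold consecutive epochs (twice that for the discriminator). Below is a minimal, self-contained sketch of that logic. The real update_exponential_moving_average and check_overtraining helpers live earlier in train.py and are not part of this hunk, so the bodies of update_ema and check_overtraining here, the smoothing constant, and the sample losses are illustrative assumptions that only mirror how the loop consumes them.

def update_ema(history, value, smoothing=0.3):
    # Append an exponentially smoothed copy of `value` to `history` and return it.
    smoothed = value if not history else smoothing * history[-1] + (1 - smoothing) * value
    history.append(smoothed)
    return smoothed

def check_overtraining(smoothed, threshold, epsilon=0.01):
    # True when the smoothed loss rose by more than `epsilon` at each of the
    # last `threshold` steps.
    if len(smoothed) < threshold + 1: return False
    recent = smoothed[-(threshold + 1):]
    return all(b - a > epsilon for a, b in zip(recent, recent[1:]))

history, consecutive = [], 0
for epoch, loss in enumerate([1.0, 0.8, 0.7, 0.72, 0.75, 0.79, 0.84, 0.9, 0.97], start=1):
    value = update_ema(history, loss)
    consecutive = consecutive + 1 if check_overtraining(history, 3) else 0
    if consecutive == 3:
        print(f"overtraining suspected at epoch {epoch} (smoothed loss {value:.3f})")
        break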
main/library/algorithm/commons.py
DELETED
@@ -1,50 +0,0 @@
import torch


def init_weights(m, mean=0.0, std=0.01):
    if m.__class__.__name__.find("Conv") != -1: m.weight.data.normal_(mean, std)

def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)

def convert_pad_shape(pad_shape):
    return [item for sublist in pad_shape[::-1] for item in sublist]

def slice_segments(x, ids_str, segment_size=4, dim=2):
    if dim == 2: ret = torch.zeros_like(x[:, :segment_size])
    elif dim == 3: ret = torch.zeros_like(x[:, :, :segment_size])

    for i in range(x.size(0)):
        idx_str = ids_str[i].item()
        idx_end = idx_str + segment_size

        if dim == 2: ret[i] = x[i, idx_str:idx_end]
        else: ret[i] = x[i, :, idx_str:idx_end]

    return ret

def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, _, t = x.size()
    if x_lengths is None: x_lengths = t
    ids_str = (torch.rand([b]).to(device=x.device) * (x_lengths - segment_size + 1)).to(dtype=torch.long)
    return slice_segments(x, ids_str, segment_size, dim=3), ids_str

@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    return torch.tanh(in_act[:, :n_channels_int, :]) * torch.sigmoid(in_act[:, n_channels_int:, :])

def convert_pad_shape(pad_shape):  # note: this redefinition shadows the list-comprehension convert_pad_shape above
    return torch.tensor(pad_shape).flip(0).reshape(-1).int().tolist()

def sequence_mask(length, max_length=None):
    if max_length is None: max_length = length.max()
    return torch.arange(max_length, dtype=length.dtype, device=length.device).unsqueeze(0) < length.unsqueeze(1)

def clip_grad_value(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor): parameters = [parameters]
    norm_type = float(norm_type)
    if clip_value is not None: clip_value = float(clip_value)
    total_norm = 0

    for p in list(filter(lambda p: p.grad is not None, parameters)):
        total_norm += (p.grad.data.norm(norm_type)).item() ** norm_type
        if clip_value is not None: p.grad.data.clamp_(min=-clip_value, max=clip_value)

    return total_norm ** (1.0 / norm_type)
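A quick sketch of how the two slicing helpers above cooperate during GAN training: rand_slice_segments draws one random window per batch element from the latent and returns the start indices, and slice_segments then cuts the matching windows from an aligned target. The import path is the pre-deletion location of this file; the tensor shapes are illustrative.

import torch
from main.library.algorithm.commons import rand_slice_segments, slice_segments

z = torch.randn(4, 192, 400)             # (batch, channels, frames) latent
mel = torch.randn(4, 80, 400)            # aligned mel target
z_slice, ids_str = rand_slice_segments(z, x_lengths=torch.full((4,), 400), segment_size=32)
mel_slice = slice_segments(mel, ids_str, segment_size=32, dim=3)
print(z_slice.shape, mel_slice.shape)    # torch.Size([4, 192, 32]) torch.Size([4, 80, 32])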
main/library/algorithm/modules.py
DELETED
@@ -1,70 +0,0 @@
import os
import sys
import torch

sys.path.append(os.getcwd())

from .commons import fused_add_tanh_sigmoid_multiply


class WaveNet(torch.nn.Module):
    def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
        super(WaveNet, self).__init__()
        assert kernel_size % 2 == 1
        self.hidden_channels = hidden_channels
        self.kernel_size = (kernel_size,)
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout
        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = torch.nn.Dropout(p_dropout)

        if gin_channels != 0: self.cond_layer = torch.nn.utils.parametrizations.weight_norm(torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1), name="weight")

        dilations = [dilation_rate**i for i in range(n_layers)]
        paddings = [(kernel_size * d - d) // 2 for d in dilations]

        for i in range(n_layers):
            in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilations[i], padding=paddings[i])
            in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            res_skip_channels = (hidden_channels if i == n_layers - 1 else 2 * hidden_channels)
            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)

            res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None, **kwargs):
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None: g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)

            if g is not None:
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
            else: g_l = torch.zeros_like(x_in)

            res_skip_acts = self.res_skip_layers[i](self.drop(fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)))

            if i < self.n_layers - 1:
                x = (x + (res_skip_acts[:, : self.hidden_channels, :])) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else: output = output + res_skip_acts

        return output * x_mask

    def remove_weight_norm(self):
        # Caveat: the convs above were wrapped with parametrizations.weight_norm, whose
        # matching removal call is torch.nn.utils.parametrize.remove_parametrizations(l, "weight");
        # the legacy remove_weight_norm used below only strips the old hook-based wrapper.
        if self.gin_channels != 0: torch.nn.utils.remove_weight_norm(self.cond_layer)

        for l in self.in_layers:
            torch.nn.utils.remove_weight_norm(l)

        for l in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(l)
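A minimal sketch of driving the WaveNet block above the way the flow and posterior encoders do: channel-first features plus a broadcastable frame mask, with an optional global conditioning vector g. The import path is the pre-deletion one and the hyperparameters are illustrative.

import torch
from main.library.algorithm.modules import WaveNet

net = WaveNet(hidden_channels=192, kernel_size=5, dilation_rate=1, n_layers=16, gin_channels=256)
x = torch.randn(2, 192, 100)     # (batch, hidden_channels, frames)
x_mask = torch.ones(2, 1, 100)   # 1 = real frame, 0 = padding
g = torch.randn(2, 256, 1)       # per-utterance speaker embedding
out = net(x, x_mask, g=g)
print(out.shape)                 # torch.Size([2, 192, 100])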
main/library/algorithm/mrf_hifigan.py
DELETED
@@ -1,160 +0,0 @@
import math
import torch

import numpy as np
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

from torch.nn.utils import remove_weight_norm
from torch.nn.utils.parametrizations import weight_norm


LRELU_SLOPE = 0.1

class MRFLayer(torch.nn.Module):
    def __init__(self, channels, kernel_size, dilation):
        super().__init__()
        self.conv1 = weight_norm(torch.nn.Conv1d(channels, channels, kernel_size, padding=(kernel_size * dilation - dilation) // 2, dilation=dilation))
        self.conv2 = weight_norm(torch.nn.Conv1d(channels, channels, kernel_size, padding=kernel_size // 2, dilation=1))

    def forward(self, x):
        return x + self.conv2(F.leaky_relu(self.conv1(F.leaky_relu(x, LRELU_SLOPE)), LRELU_SLOPE))

    def remove_weight_norm(self):
        remove_weight_norm(self.conv1)
        remove_weight_norm(self.conv2)

class MRFBlock(torch.nn.Module):
    def __init__(self, channels, kernel_size, dilations):
        super().__init__()
        self.layers = torch.nn.ModuleList()

        for dilation in dilations:
            self.layers.append(MRFLayer(channels, kernel_size, dilation))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)

        return x

    def remove_weight_norm(self):
        for layer in self.layers:
            layer.remove_weight_norm()

class SineGenerator(torch.nn.Module):
    def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0):
        super(SineGenerator, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        return torch.ones_like(f0) * (f0 > self.voiced_threshold)

    def _f02sine(self, f0_values):
        rad_values = (f0_values / self.sampling_rate) % 1
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)

        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini

        tmp_over_one = torch.cumsum(rad_values, 1) % 1
        tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0

        cumsum_shift = torch.zeros_like(rad_values)
        cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0

        return torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)

    def forward(self, f0):
        with torch.no_grad():
            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
            f0_buf[:, :, 0] = f0[:, :, 0]

            for idx in np.arange(self.harmonic_num):
                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)

            sine_waves = self._f02sine(f0_buf) * self.sine_amp
            uv = self._f02uv(f0)

            sine_waves = sine_waves * uv + ((uv * self.noise_std + (1 - uv) * self.sine_amp / 3) * torch.randn_like(sine_waves))

        return sine_waves

class SourceModuleHnNSF(torch.nn.Module):
    def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshold=0):
        super(SourceModuleHnNSF, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        self.l_sin_gen = SineGenerator(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshold)
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        return self.l_tanh(self.l_linear(self.l_sin_gen(x).to(dtype=self.l_linear.weight.dtype)))

class HiFiGANMRFGenerator(torch.nn.Module):
    def __init__(self, in_channel, upsample_initial_channel, upsample_rates, upsample_kernel_sizes, resblock_kernel_sizes, resblock_dilations, gin_channels, sample_rate, harmonic_num, checkpointing=False):
        super().__init__()
        self.num_kernels = len(resblock_kernel_sizes)

        self.f0_upsample = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
        self.m_source = SourceModuleHnNSF(sample_rate, harmonic_num)

        self.conv_pre = weight_norm(torch.nn.Conv1d(in_channel, upsample_initial_channel, kernel_size=7, stride=1, padding=3))
        self.checkpointing = checkpointing

        self.upsamples = torch.nn.ModuleList()
        self.noise_convs = torch.nn.ModuleList()

        stride_f0s = [math.prod(upsample_rates[i + 1 :]) if i + 1 < len(upsample_rates) else 1 for i in range(len(upsample_rates))]

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.upsamples.append(weight_norm(torch.nn.ConvTranspose1d(upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), kernel_size=k, stride=u, padding=(((k - u) // 2) if u % 2 == 0 else (u // 2 + u % 2)), output_padding=u % 2)))
            stride = stride_f0s[i]

            kernel = 1 if stride == 1 else stride * 2 - stride % 2
            self.noise_convs.append(torch.nn.Conv1d(1, upsample_initial_channel // (2 ** (i + 1)), kernel_size=kernel, stride=stride, padding=(0 if stride == 1 else (kernel - stride) // 2)))

        self.mrfs = torch.nn.ModuleList()

        for i in range(len(self.upsamples)):
            channel = upsample_initial_channel // (2 ** (i + 1))
            self.mrfs.append(torch.nn.ModuleList([MRFBlock(channel, kernel_size=k, dilations=d) for k, d in zip(resblock_kernel_sizes, resblock_dilations)]))

        self.conv_post = weight_norm(torch.nn.Conv1d(channel, 1, kernel_size=7, stride=1, padding=3))
        if gin_channels != 0: self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, f0, g=None):
        har_source = self.m_source(self.f0_upsample(f0[:, None, :]).transpose(-1, -2)).transpose(-1, -2)

        x = self.conv_pre(x)
        if g is not None: x = x + self.cond(g)

        for ups, mrf, noise_conv in zip(self.upsamples, self.mrfs, self.noise_convs):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = checkpoint.checkpoint(ups, x, use_reentrant=False) if self.training and self.checkpointing else ups(x)
            x += noise_conv(har_source)

            def mrf_sum(x, layers):
                return sum(layer(x) for layer in layers) / self.num_kernels

            x = checkpoint.checkpoint(mrf_sum, x, mrf, use_reentrant=False) if self.training and self.checkpointing else mrf_sum(x, mrf)

        return torch.tanh(self.conv_post(F.leaky_relu(x)))

    def remove_weight_norm(self):
        remove_weight_norm(self.conv_pre)

        for up in self.upsamples:
            remove_weight_norm(up)

        for mrf in self.mrfs:
            for block in mrf:  # each self.mrfs entry is a ModuleList of MRFBlocks, so strip them one by one
                block.remove_weight_norm()

        remove_weight_norm(self.conv_post)
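SineGenerator._f02sine above synthesizes the harmonic excitation by integrating instantaneous frequency: per-sample f0 becomes a phase increment f0/sr, a cumulative sum accumulates the running phase (with a shift correction to keep the sum numerically in range), and sin(2*pi*phase) stays continuous across pitch changes. A standalone sketch of the core idea, without the wrap-around correction or the harmonic stacking:

import torch

sr = 32000
f0 = torch.cat([torch.full((100,), 220.0), torch.full((100,), 330.0)])  # abrupt pitch jump
phase_inc = f0 / sr                           # cycles advanced per sample
phase = torch.cumsum(phase_inc, dim=0) % 1.0  # running phase in [0, 1)
sine = 0.1 * torch.sin(2 * torch.pi * phase)  # sine_amp = 0.1, as in the class above
print(sine.shape)                             # torch.Size([200]); the waveform has no click at the jump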
main/library/algorithm/refinegan.py
DELETED
@@ -1,180 +0,0 @@
import os
import sys
import math
import torch

import numpy as np
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations

sys.path.append(os.getcwd())

from .commons import get_padding

class ResBlock(torch.nn.Module):
    def __init__(self, *, in_channels, out_channels, kernel_size=7, dilation=(1, 3, 5), leaky_relu_slope=0.2):
        super(ResBlock, self).__init__()
        self.leaky_relu_slope = leaky_relu_slope
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.convs1 = torch.nn.ModuleList([weight_norm(torch.nn.Conv1d(in_channels=in_channels if idx == 0 else out_channels, out_channels=out_channels, kernel_size=kernel_size, stride=1, dilation=d, padding=get_padding(kernel_size, d))) for idx, d in enumerate(dilation)])
        self.convs1.apply(self.init_weights)

        self.convs2 = torch.nn.ModuleList([weight_norm(torch.nn.Conv1d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, stride=1, dilation=d, padding=get_padding(kernel_size, d))) for _, d in enumerate(dilation)])
        self.convs2.apply(self.init_weights)

    def forward(self, x):
        for idx, (c1, c2) in enumerate(zip(self.convs1, self.convs2)):
            xt = c2(F.leaky_relu(c1(F.leaky_relu(x, self.leaky_relu_slope)), self.leaky_relu_slope))
            x = (xt + x) if idx != 0 or self.in_channels == self.out_channels else xt

        return x

    def remove_parametrizations(self):
        for c1, c2 in zip(self.convs1, self.convs2):
            remove_parametrizations(c1, "weight")  # parametrize.remove_parametrizations requires the tensor name
            remove_parametrizations(c2, "weight")

    def init_weights(self, m):
        if type(m) == torch.nn.Conv1d:
            m.weight.data.normal_(0, 0.01)
            m.bias.data.fill_(0.0)

class AdaIN(torch.nn.Module):
    def __init__(self, *, channels, leaky_relu_slope=0.2):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(channels))
        self.activation = torch.nn.LeakyReLU(leaky_relu_slope)

    def forward(self, x):
        return self.activation(x + (torch.randn_like(x) * self.weight[None, :, None]))

class ParallelResBlock(torch.nn.Module):
    def __init__(self, *, in_channels, out_channels, kernel_sizes=(3, 7, 11), dilation=(1, 3, 5), leaky_relu_slope=0.2):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.input_conv = torch.nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=7, stride=1, padding=3)
        self.blocks = torch.nn.ModuleList([torch.nn.Sequential(AdaIN(channels=out_channels), ResBlock(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, dilation=dilation, leaky_relu_slope=leaky_relu_slope), AdaIN(channels=out_channels)) for kernel_size in kernel_sizes])

    def forward(self, x):
        return torch.mean(torch.stack([block(self.input_conv(x)) for block in self.blocks]), dim=0)

    def remove_parametrizations(self):
        for block in self.blocks:
            block[1].remove_parametrizations()

class SineGenerator(torch.nn.Module):
    def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0):
        super(SineGenerator, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.merge = torch.nn.Sequential(torch.nn.Linear(self.dim, 1, bias=False), torch.nn.Tanh())

    def _f02uv(self, f0):
        return torch.ones_like(f0) * (f0 > self.voiced_threshold)

    def _f02sine(self, f0_values):
        rad_values = (f0_values / self.sampling_rate) % 1
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)

        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini

        tmp_over_one = torch.cumsum(rad_values, 1) % 1
        tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0

        cumsum_shift = torch.zeros_like(rad_values)
        cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0

        return torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)

    def forward(self, f0):
        with torch.no_grad():
            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
            f0_buf[:, :, 0] = f0[:, :, 0]

            for idx in np.arange(self.harmonic_num):
                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)

            sine_waves = self._f02sine(f0_buf) * self.sine_amp
            uv = self._f02uv(f0)

            sine_waves = sine_waves * uv + (uv * self.noise_std + (1 - uv) * self.sine_amp / 3) * torch.randn_like(sine_waves)
            sine_waves = sine_waves - sine_waves.mean(dim=1, keepdim=True)

        return self.merge(sine_waves)

class RefineGANGenerator(torch.nn.Module):
    def __init__(self, *, sample_rate=44100, upsample_rates=(8, 8, 2, 2), leaky_relu_slope=0.2, num_mels=128, gin_channels=256, checkpointing=False, upsample_initial_channel=512):
        super().__init__()
        self.upsample_rates = upsample_rates
        self.checkpointing = checkpointing
        self.leaky_relu_slope = leaky_relu_slope
        self.upp = np.prod(upsample_rates)
        self.m_source = SineGenerator(sample_rate)
        self.pre_conv = weight_norm(torch.nn.Conv1d(in_channels=1, out_channels=upsample_initial_channel // 2, kernel_size=7, stride=1, padding=3, bias=False))
        channels = upsample_initial_channel
        self.downsample_blocks = torch.nn.ModuleList([])

        stride_f0s = [math.prod(upsample_rates[i + 1 :]) if i + 1 < len(upsample_rates) else 1 for i in range(len(upsample_rates))]

        for i, _ in enumerate(upsample_rates):
            stride = stride_f0s[i]
            kernel = 1 if stride == 1 else stride * 2 - stride % 2

            self.downsample_blocks.append(torch.nn.Conv1d(in_channels=1, out_channels=channels // 2 ** (i + 2), kernel_size=kernel, stride=stride, padding=0 if stride == 1 else (kernel - stride) // 2))

        self.mel_conv = weight_norm(torch.nn.Conv1d(in_channels=num_mels, out_channels=channels // 2, kernel_size=7, stride=1, padding=3))
        if gin_channels != 0: self.cond = torch.nn.Conv1d(256, channels // 2, 1)  # hard-coded input width; assumes gin_channels == 256

        self.upsample_blocks = torch.nn.ModuleList([])
        self.upsample_conv_blocks = torch.nn.ModuleList([])
        self.filters = torch.nn.ModuleList([])

        for rate in upsample_rates:
            new_channels = channels // 2
            self.upsample_blocks.append(torch.nn.Upsample(scale_factor=rate, mode="linear"))

            low_pass = torch.nn.Conv1d(channels, channels, kernel_size=15, padding=7, groups=channels, bias=False)
            low_pass.weight.data.fill_(1.0 / 15)

            self.filters.append(low_pass)

            self.upsample_conv_blocks.append(ParallelResBlock(in_channels=channels + channels // 4, out_channels=new_channels, kernel_sizes=(3, 7, 11), dilation=(1, 3, 5), leaky_relu_slope=leaky_relu_slope))
            channels = new_channels

        self.conv_post = weight_norm(torch.nn.Conv1d(in_channels=channels, out_channels=1, kernel_size=7, stride=1, padding=3))

    def forward(self, mel, f0, g=None):
        har_source = self.m_source(f0.transpose(1, 2)).transpose(1, 2)
        x = F.interpolate(self.pre_conv(har_source), size=mel.shape[-1], mode="linear")

        mel = self.mel_conv(mel)
        if g is not None: mel += self.cond(g)

        x = torch.cat([mel, x], dim=1)

        for ups, res, down, flt in zip(self.upsample_blocks, self.upsample_conv_blocks, self.downsample_blocks, self.filters):
            # checkpoint refers to the torch.utils.checkpoint module, so the wrapped calls go through checkpoint.checkpoint(...)
            x = checkpoint.checkpoint(res, torch.cat([checkpoint.checkpoint(flt, checkpoint.checkpoint(ups, x, use_reentrant=False), use_reentrant=False), down(har_source)], dim=1), use_reentrant=False) if self.training and self.checkpointing else res(torch.cat([flt(ups(x)), down(har_source)], dim=1))

        return torch.tanh_(self.conv_post(F.leaky_relu_(x, self.leaky_relu_slope)))

    def remove_parametrizations(self):
        # pre_conv, mel_conv and conv_post are the only weight-normed convs here;
        # the plain Conv1d layers in downsample_blocks carry no parametrization.
        remove_parametrizations(self.pre_conv, "weight")
        remove_parametrizations(self.mel_conv, "weight")
        remove_parametrizations(self.conv_post, "weight")

        for block in self.upsample_conv_blocks:
            block.remove_parametrizations()
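RefineGANGenerator.forward above wraps the upsample, filter, and residual stages in torch.utils.checkpoint whenever self.training and self.checkpointing are set: activations inside each wrapped call are recomputed during backward instead of being stored, trading extra compute for memory. A minimal standalone sketch of the same pattern (the toy conv block is illustrative):

import torch
import torch.utils.checkpoint as checkpoint

block = torch.nn.Sequential(torch.nn.Conv1d(8, 8, 3, padding=1), torch.nn.LeakyReLU(0.2))
x = torch.randn(2, 8, 1000, requires_grad=True)

y = checkpoint.checkpoint(block, x, use_reentrant=False)  # activations are recomputed on backward
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 8, 1000])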
main/library/algorithm/residuals.py
DELETED
@@ -1,140 +0,0 @@
import os
import sys
import torch

from torch.nn.utils import remove_weight_norm
from torch.nn.utils.parametrizations import weight_norm

sys.path.append(os.getcwd())

from .modules import WaveNet
from .commons import get_padding, init_weights


LRELU_SLOPE = 0.1

def create_conv1d_layer(channels, kernel_size, dilation):
    return weight_norm(torch.nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation, padding=get_padding(kernel_size, dilation)))

def apply_mask(tensor, mask):
    return tensor * mask if mask is not None else tensor

class ResBlockBase(torch.nn.Module):
    def __init__(self, channels, kernel_size, dilations):
        super(ResBlockBase, self).__init__()

        self.convs1 = torch.nn.ModuleList([create_conv1d_layer(channels, kernel_size, d) for d in dilations])
        self.convs1.apply(init_weights)

        self.convs2 = torch.nn.ModuleList([create_conv1d_layer(channels, kernel_size, 1) for _ in dilations])
        self.convs2.apply(init_weights)

    def forward(self, x, x_mask=None):
        for c1, c2 in zip(self.convs1, self.convs2):
            x = c2(apply_mask(torch.nn.functional.leaky_relu(c1(apply_mask(torch.nn.functional.leaky_relu(x, LRELU_SLOPE), x_mask)), LRELU_SLOPE), x_mask)) + x

        return apply_mask(x, x_mask)

    def remove_weight_norm(self):
        for conv in self.convs1 + self.convs2:
            remove_weight_norm(conv)

class ResBlock(ResBlockBase):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock, self).__init__(channels, kernel_size, dilation)

class Log(torch.nn.Module):
    def forward(self, x, x_mask, reverse=False, **kwargs):
        if not reverse:
            y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
            return y, torch.sum(-y, [1, 2])
        else: return torch.exp(x) * x_mask

class Flip(torch.nn.Module):
    def forward(self, x, *args, reverse=False, **kwargs):
        x = torch.flip(x, [1])

        if not reverse: return x, torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
        else: return x

class ElementwiseAffine(torch.nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.m = torch.nn.Parameter(torch.zeros(channels, 1))
        self.logs = torch.nn.Parameter(torch.zeros(channels, 1))

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if not reverse: return ((self.m + torch.exp(self.logs) * x) * x_mask), torch.sum(self.logs * x_mask, [1, 2])
        else: return (x - self.m) * torch.exp(-self.logs) * x_mask

class ResidualCouplingBlock(torch.nn.Module):
    def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0):
        super(ResidualCouplingBlock, self).__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels
        self.flows = torch.nn.ModuleList()

        for _ in range(n_flows):
            self.flows.append(ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
            self.flows.append(Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow.forward(x, x_mask, g=g, reverse=reverse)

        return x

    def remove_weight_norm(self):
        for i in range(self.n_flows):
            self.flows[i * 2].remove_weight_norm()

    def __prepare_scriptable__(self):
        for i in range(self.n_flows):
            for hook in self.flows[i * 2]._forward_pre_hooks.values():
                if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(self.flows[i * 2])

        return self

class ResidualCouplingLayer(torch.nn.Module):
    def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=0, mean_only=False):
        assert channels % 2 == 0, "Channels/2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = torch.nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = WaveNet(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
        self.post = torch.nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)

        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        stats = self.post(self.enc((self.pre(x0) * x_mask), x_mask, g=g)) * x_mask

        if not self.mean_only: m, logs = torch.split(stats, [self.half_channels] * 2, 1)
        else:
            m = stats
            logs = torch.zeros_like(m)

        if not reverse: return torch.cat([x0, (m + x1 * torch.exp(logs) * x_mask)], 1), torch.sum(logs, [1, 2])
        else: return torch.cat([x0, ((x1 - m) * torch.exp(-logs) * x_mask)], 1)

    def remove_weight_norm(self):
        self.enc.remove_weight_norm()
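Because each ResidualCouplingLayer above only shifts one half of the channels by a function of the other half (mean_only=True, so logs stays zero), the whole block is exactly invertible. A round-trip sketch, assuming the pre-deletion import path; the hyperparameters are illustrative:

import torch
from main.library.algorithm.residuals import ResidualCouplingBlock

flow = ResidualCouplingBlock(channels=192, hidden_channels=192, kernel_size=5, dilation_rate=1, n_layers=3).eval()
z = torch.randn(1, 192, 50)
mask = torch.ones(1, 1, 50)

with torch.no_grad():
    z_fwd = flow(z, mask)                     # latent -> flow space
    z_back = flow(z_fwd, mask, reverse=True)  # and back again
print(torch.allclose(z, z_back, atol=1e-4))  # True: the coupling is invertible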
main/library/algorithm/separator.py
DELETED
@@ -1,330 +0,0 @@
import os
import sys
import time
import yaml
import torch
import codecs
import hashlib
import logging
import platform
import warnings
import requests
import onnxruntime

from importlib import metadata, import_module

now_dir = os.getcwd()
sys.path.append(now_dir)

from main.configs.config import Config
translations = Config().translations

class Separator:
    def __init__(self, logger=logging.getLogger(__name__), log_level=logging.INFO, log_formatter=None, model_file_dir="assets/models/uvr5", output_dir=None, output_format="wav", output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=44100, mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False}, demucs_params={"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True}):
        self.logger = logger
        self.log_level = log_level
        self.log_formatter = log_formatter
        self.log_handler = logging.StreamHandler()

        if self.log_formatter is None: self.log_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(module)s - %(message)s")
        self.log_handler.setFormatter(self.log_formatter)

        if not self.logger.hasHandlers(): self.logger.addHandler(self.log_handler)
        if log_level > logging.DEBUG: warnings.filterwarnings("ignore")

        self.logger.info(translations["separator_info"].format(output_dir=output_dir, output_format=output_format))
        self.model_file_dir = model_file_dir

        if output_dir is None:
            output_dir = now_dir
            self.logger.info(translations["output_dir_is_none"])

        self.output_dir = output_dir

        os.makedirs(self.model_file_dir, exist_ok=True)
        os.makedirs(self.output_dir, exist_ok=True)

        self.output_format = output_format
        self.output_bitrate = output_bitrate

        if self.output_format is None: self.output_format = "wav"
        self.normalization_threshold = normalization_threshold
        if normalization_threshold <= 0 or normalization_threshold > 1: raise ValueError(translations[">0or=1"])

        self.output_single_stem = output_single_stem
        if output_single_stem is not None: self.logger.debug(translations["output_single"].format(output_single_stem=output_single_stem))

        self.invert_using_spec = invert_using_spec
        if self.invert_using_spec: self.logger.debug(translations["step2"])

        self.sample_rate = int(sample_rate)
        self.arch_specific_params = {"MDX": mdx_params, "Demucs": demucs_params}
        self.torch_device = None
        self.torch_device_cpu = None
        self.torch_device_mps = None
        self.onnx_execution_provider = None
        self.model_instance = None
        self.model_is_uvr_vip = False
        self.model_friendly_name = None
        self.setup_accelerated_inferencing_device()

    def setup_accelerated_inferencing_device(self):
        system_info = self.get_system_info()
        self.log_onnxruntime_packages()
        self.setup_torch_device(system_info)

    def get_system_info(self):
        os_name = platform.system()
        os_version = platform.version()
        self.logger.info(f"{translations['os']}: {os_name} {os_version}")
        system_info = platform.uname()
        self.logger.info(translations["platform_info"].format(system_info=system_info, node=system_info.node, release=system_info.release, machine=system_info.machine, processor=system_info.processor))
        python_version = platform.python_version()
        self.logger.info(f"{translations['name_ver'].format(name='python')}: {python_version}")
        pytorch_version = torch.__version__
        self.logger.info(f"{translations['name_ver'].format(name='pytorch')}: {pytorch_version}")

        return system_info

    def log_onnxruntime_packages(self):
        onnxruntime_gpu_package = self.get_package_distribution("onnxruntime-gpu")
        onnxruntime_cpu_package = self.get_package_distribution("onnxruntime")

        if onnxruntime_gpu_package is not None: self.logger.info(f"{translations['install_onnx'].format(pu='GPU')}: {onnxruntime_gpu_package.version}")
        if onnxruntime_cpu_package is not None: self.logger.info(f"{translations['install_onnx'].format(pu='CPU')}: {onnxruntime_cpu_package.version}")

    def setup_torch_device(self, system_info):
        hardware_acceleration_enabled = False
        ort_providers = onnxruntime.get_available_providers()
        self.torch_device_cpu = torch.device("cpu")

        if torch.cuda.is_available():
            self.configure_cuda(ort_providers)
            hardware_acceleration_enabled = True
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and system_info.processor == "arm":
            self.configure_mps(ort_providers)
            hardware_acceleration_enabled = True

        if not hardware_acceleration_enabled:
            self.logger.info(translations["running_in_cpu"])
            self.torch_device = self.torch_device_cpu
            self.onnx_execution_provider = ["CPUExecutionProvider"]

    def configure_cuda(self, ort_providers):
        self.logger.info(translations["running_in_cuda"])
        self.torch_device = torch.device("cuda")

        if "CUDAExecutionProvider" in ort_providers:
            self.logger.info(translations["onnx_have"].format(have='CUDAExecutionProvider'))
            self.onnx_execution_provider = ["CUDAExecutionProvider"]
        else: self.logger.warning(translations["onnx_not_have"].format(have='CUDAExecutionProvider'))

    def configure_mps(self, ort_providers):
        self.logger.info(translations["set_torch_mps"])
        self.torch_device_mps = torch.device("mps")
        self.torch_device = self.torch_device_mps

        if "CoreMLExecutionProvider" in ort_providers:
            self.logger.info(translations["onnx_have"].format(have='CoreMLExecutionProvider'))
            self.onnx_execution_provider = ["CoreMLExecutionProvider"]
        else: self.logger.warning(translations["onnx_not_have"].format(have='CoreMLExecutionProvider'))

    def get_package_distribution(self, package_name):
        try:
            return metadata.distribution(package_name)
        except metadata.PackageNotFoundError:
            self.logger.debug(translations["python_not_install"].format(package_name=package_name))
            return None

    def get_model_hash(self, model_path):
        self.logger.debug(translations["hash"].format(model_path=model_path))

        try:
            with open(model_path, "rb") as f:
                f.seek(-10000 * 1024, 2)
                return hashlib.md5(f.read()).hexdigest()
        except IOError as e:
            self.logger.error(translations["ioerror"].format(e=e))
            return hashlib.md5(open(model_path, "rb").read()).hexdigest()

    def download_file_if_not_exists(self, url, output_path):
        if os.path.isfile(output_path):
            self.logger.debug(translations["cancel_download"].format(output_path=output_path))
            return

        self.logger.debug(translations["download_model"].format(url=url, output_path=output_path))
        response = requests.get(url, stream=True, timeout=300)

        if response.status_code == 200:
            from tqdm import tqdm

            progress_bar = tqdm(total=int(response.headers.get("content-length", 0)), ncols=100, unit="byte")

            with open(output_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    progress_bar.update(len(chunk))
                    f.write(chunk)

            progress_bar.close()
        else: raise RuntimeError(translations["download_error"].format(url=url, status_code=response.status_code))

    def print_uvr_vip_message(self):
        if self.model_is_uvr_vip:
            self.logger.warning(translations["vip_model"].format(model_friendly_name=self.model_friendly_name))
            self.logger.warning(translations["vip_print"])

    def list_supported_model_files(self):
        response = requests.get(codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/enj/znva/wfba/hie_zbqryf.wfba", "rot13"))
        response.raise_for_status()
        model_downloads_list = response.json()
        self.logger.debug(translations["load_download_json"])

        return {"MDX": {**model_downloads_list["mdx_download_list"], **model_downloads_list["mdx_download_vip_list"]}, "Demucs": {key: value for key, value in model_downloads_list["demucs_download_list"].items() if key.startswith("Demucs v4")}}

    def download_model_files(self, model_filename):
        model_path = os.path.join(self.model_file_dir, model_filename)
        supported_model_files_grouped = self.list_supported_model_files()

        yaml_config_filename = None
        self.logger.debug(translations["search_model"].format(model_filename=model_filename))

        for model_type, model_list in supported_model_files_grouped.items():
            for model_friendly_name, model_download_list in model_list.items():
                self.model_is_uvr_vip = "VIP" in model_friendly_name
                model_repo_url_prefix = codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/hie5_zbqryf", "rot13")

                if isinstance(model_download_list, str) and model_download_list == model_filename:
                    self.logger.debug(translations["single_model"].format(model_friendly_name=model_friendly_name))
                    self.model_friendly_name = model_friendly_name

                    try:
                        self.download_file_if_not_exists(f"{model_repo_url_prefix}/MDX/{model_filename}", model_path)
                    except RuntimeError:
                        self.logger.warning(translations["not_found_model"])
                        self.download_file_if_not_exists(f"{model_repo_url_prefix}/Demucs/{model_filename}", model_path)

                    self.print_uvr_vip_message()
                    self.logger.debug(translations["single_model_path"].format(model_path=model_path))

                    return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
                elif isinstance(model_download_list, dict):
                    this_model_matches_input_filename = False

                    for file_name, file_url in model_download_list.items():
                        if file_name == model_filename or file_url == model_filename:
                            self.logger.debug(translations["find_model"].format(model_filename=model_filename, model_friendly_name=model_friendly_name))
                            this_model_matches_input_filename = True

                    if this_model_matches_input_filename:
                        self.logger.debug(translations["find_models"].format(model_friendly_name=model_friendly_name))
                        self.model_friendly_name = model_friendly_name
                        self.print_uvr_vip_message()

                        for config_key, config_value in model_download_list.items():
                            self.logger.debug(f"{translations['find_path']}: {config_key} -> {config_value}")

                            if config_value.startswith("http"): self.download_file_if_not_exists(config_value, os.path.join(self.model_file_dir, config_key))
                            elif config_key.endswith(".ckpt"):
                                try:
                                    self.download_file_if_not_exists(f"{model_repo_url_prefix}/Demucs/{config_key}", os.path.join(self.model_file_dir, config_key))
                                except RuntimeError:
                                    self.logger.warning(translations["not_found_model_warehouse"])

                                if model_filename.endswith(".yaml"):
                                    self.logger.warning(translations["yaml_warning"].format(model_filename=model_filename))
                                    self.logger.warning(translations["yaml_warning_2"].format(config_key=config_key))
                                    self.logger.warning(translations["yaml_warning_3"])

                                    model_filename = config_key
                                    model_path = os.path.join(self.model_file_dir, f"{model_filename}")

                                yaml_config_filename = config_value
                                yaml_config_filepath = os.path.join(self.model_file_dir, yaml_config_filename)

                                try:
                                    self.download_file_if_not_exists(f"{model_repo_url_prefix}/mdx_c_configs/{yaml_config_filename}", yaml_config_filepath)
                                except RuntimeError:
                                    self.logger.debug(translations["yaml_debug"])
                            else: self.download_file_if_not_exists(f"{model_repo_url_prefix}/Demucs/{config_value}", os.path.join(self.model_file_dir, config_value))

                        self.logger.debug(translations["download_model_friendly"].format(model_friendly_name=model_friendly_name, model_path=model_path))
                        return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename

        raise ValueError(translations["not_found_model_2"].format(model_filename=model_filename))

    def load_model_data_from_yaml(self, yaml_config_filename):
        model_data_yaml_filepath = os.path.join(self.model_file_dir, yaml_config_filename) if not os.path.exists(yaml_config_filename) else yaml_config_filename
        self.logger.debug(translations["load_yaml"].format(model_data_yaml_filepath=model_data_yaml_filepath))

        model_data = yaml.load(open(model_data_yaml_filepath, encoding="utf-8"), Loader=yaml.FullLoader)
        self.logger.debug(translations["load_yaml_2"].format(model_data=model_data))

        if "roformer" in model_data_yaml_filepath: model_data["is_roformer"] = True
        return model_data

    def load_model_data_using_hash(self, model_path):
        self.logger.debug(translations["hash_md5"])
        model_hash = self.get_model_hash(model_path)

        self.logger.debug(translations["model_hash"].format(model_path=model_path, model_hash=model_hash))
        mdx_model_data_path = codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/enj/znva/wfba/zbqry_qngn.wfba", "rot13")
        self.logger.debug(translations["mdx_data"].format(mdx_model_data_path=mdx_model_data_path))

        response = requests.get(mdx_model_data_path)
        response.raise_for_status()

        mdx_model_data_object = response.json()
        self.logger.debug(translations["load_mdx"])

        if model_hash in mdx_model_data_object: model_data = mdx_model_data_object[model_hash]
        else: raise ValueError(translations["model_not_support"].format(model_hash=model_hash))

        self.logger.debug(translations["uvr_json"].format(model_hash=model_hash, model_data=model_data))
        return model_data

    def load_model(self, model_filename):
        self.logger.info(translations["loading_model"].format(model_filename=model_filename))
        load_model_start_time = time.perf_counter()
        model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
        self.logger.debug(translations["download_model_friendly_2"].format(model_friendly_name=model_friendly_name, model_path=model_path))

        if model_path.lower().endswith(".yaml"): yaml_config_filename = model_path

        common_params = {"logger": self.logger, "log_level": self.log_level, "torch_device": self.torch_device, "torch_device_cpu": self.torch_device_cpu, "torch_device_mps": self.torch_device_mps, "onnx_execution_provider": self.onnx_execution_provider, "model_name": model_filename.split(".")[0], "model_path": model_path, "model_data": self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename is not None else self.load_model_data_using_hash(model_path), "output_format": self.output_format, "output_bitrate": self.output_bitrate, "output_dir": self.output_dir, "normalization_threshold": self.normalization_threshold, "output_single_stem": self.output_single_stem, "invert_using_spec": self.invert_using_spec, "sample_rate": self.sample_rate}
        separator_classes = {"MDX": "mdx_separator.MDXSeparator", "Demucs": "demucs_separator.DemucsSeparator"}

        if model_type not in self.arch_specific_params or model_type not in separator_classes: raise ValueError(translations["model_type_not_support"].format(model_type=model_type))
        if model_type == "Demucs" and sys.version_info < (3, 10): raise Exception(translations["demucs_not_support_python<3.10"])

        self.logger.debug(f"{translations['import_module']} {model_type}: {separator_classes[model_type]}")
        module_name, class_name = separator_classes[model_type].split(".")
        separator_class = getattr(import_module(f"main.library.architectures.{module_name}"), class_name)

        self.logger.debug(f"{translations['initialization']} {model_type}: {separator_class}")
        self.model_instance = separator_class(common_config=common_params, arch_config=self.arch_specific_params[model_type])

        self.logger.debug(translations["loading_model_success"])
        self.logger.info(f"{translations['loading_model_duration']}: {time.strftime('%H:%M:%S', time.gmtime(int(time.perf_counter() - load_model_start_time)))}")

    def separate(self, audio_file_path):
        self.logger.info(f"{translations['starting_separator']}: {audio_file_path}")
        separate_start_time = time.perf_counter()

        self.logger.debug(translations["normalization"].format(normalization_threshold=self.normalization_threshold))
        output_files = self.model_instance.separate(audio_file_path)

        self.model_instance.clear_gpu_cache()
        self.model_instance.clear_file_specific_paths()

        self.print_uvr_vip_message()

        self.logger.debug(translations["separator_success_3"])
        self.logger.info(f"{translations['separator_duration']}: {time.strftime('%H:%M:%S', time.gmtime(int(time.perf_counter() - separate_start_time)))}")
        return output_files

    def download_model_and_data(self, model_filename):
        self.logger.info(translations["loading_separator_model"].format(model_filename=model_filename))
        model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)

        if model_path.lower().endswith(".yaml"): yaml_config_filename = model_path
        self.logger.info(translations["downloading_model"].format(model_type=model_type, model_friendly_name=model_friendly_name, model_path=model_path, model_data_dict_size=len(self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename is not None else self.load_model_data_using_hash(model_path))))
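A minimal sketch of driving the class above end to end: construct a Separator, let load_model resolve, download, and dispatch the model to the MDX or Demucs backend, then call separate on one file. The import path is the pre-deletion location, and the model filename is a hypothetical entry from the MDX download list.

import logging
from main.library.algorithm.separator import Separator

sep = Separator(log_level=logging.INFO, output_dir="audios", output_format="wav", output_single_stem="Vocals")
sep.load_model("Kim_Vocal_2.onnx")           # hypothetical filename; any name listed in uvr_models.json works
for stem_path in sep.separate("audios/song.wav"):
    print("wrote", stem_path)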
main/library/algorithm/synthesizers.py
DELETED
@@ -1,450 +0,0 @@
import os
import sys
import math
import torch

import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

from torch.nn.utils import remove_weight_norm
from torch.nn.utils.parametrizations import weight_norm

sys.path.append(os.getcwd())

from .modules import WaveNet
from .refinegan import RefineGANGenerator
from .mrf_hifigan import HiFiGANMRFGenerator
from .residuals import ResidualCouplingBlock, ResBlock, LRELU_SLOPE
from .commons import init_weights, slice_segments, rand_slice_segments, sequence_mask, convert_pad_shape

class Generator(torch.nn.Module):
    def __init__(self, initial_channel, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = torch.nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
        self.ups_and_resblocks = torch.nn.ModuleList()

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups_and_resblocks.append(weight_norm(torch.nn.ConvTranspose1d(upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), k, u, padding=(k - u) // 2)))
            ch = upsample_initial_channel // (2 ** (i + 1))

            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.ups_and_resblocks.append(ResBlock(ch, k, d))

        self.conv_post = torch.nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups_and_resblocks.apply(init_weights)
        if gin_channels != 0: self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None: x = x + self.cond(g)
        resblock_idx = 0

        for _ in range(self.num_upsamples):
            x = self.ups_and_resblocks[resblock_idx](F.leaky_relu(x, LRELU_SLOPE))
            resblock_idx += 1
            xs = 0

            for _ in range(self.num_kernels):
                xs += self.ups_and_resblocks[resblock_idx](x)
                resblock_idx += 1

            x = xs / self.num_kernels

        return torch.tanh(self.conv_post(F.leaky_relu(x)))

    def __prepare_scriptable__(self):
        for l in self.ups_and_resblocks:
            for hook in l._forward_pre_hooks.values():
                if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(l)

        return self

    def remove_weight_norm(self):
        for l in self.ups_and_resblocks:
            remove_weight_norm(l)
|
67 |
-
|
68 |
-
class SineGen(torch.nn.Module):
|
69 |
-
def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0, flag_for_pulse=False):
|
70 |
-
super(SineGen, self).__init__()
|
71 |
-
self.sine_amp = sine_amp
|
72 |
-
self.noise_std = noise_std
|
73 |
-
self.harmonic_num = harmonic_num
|
74 |
-
self.dim = self.harmonic_num + 1
|
75 |
-
self.sample_rate = samp_rate
|
76 |
-
self.voiced_threshold = voiced_threshold
|
77 |
-
|
78 |
-
def _f02uv(self, f0):
|
79 |
-
return torch.ones_like(f0) * (f0 > self.voiced_threshold)
|
80 |
-
|
81 |
-
def forward(self, f0, upp):
|
82 |
-
with torch.no_grad():
|
83 |
-
f0 = f0[:, None].transpose(1, 2)
|
84 |
-
|
85 |
-
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
|
86 |
-
f0_buf[:, :, 0] = f0[:, :, 0]
|
87 |
-
f0_buf[:, :, 1:] = (f0_buf[:, :, 0:1] * torch.arange(2, self.harmonic_num + 2, device=f0.device)[None, None, :])
|
88 |
-
|
89 |
-
rad_values = (f0_buf / float(self.sample_rate)) % 1
|
90 |
-
rand_ini = torch.rand(f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device)
|
91 |
-
rand_ini[:, 0] = 0
|
92 |
-
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
93 |
-
|
94 |
-
tmp_over_one = torch.cumsum(rad_values, 1)
|
95 |
-
tmp_over_one *= upp
|
96 |
-
tmp_over_one = F.interpolate(tmp_over_one.transpose(2, 1), scale_factor=float(upp), mode="linear", align_corners=True).transpose(2, 1)
|
97 |
-
|
98 |
-
rad_values = F.interpolate(rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest").transpose(2, 1)
|
99 |
-
tmp_over_one %= 1
|
100 |
-
|
101 |
-
cumsum_shift = torch.zeros_like(rad_values)
|
102 |
-
cumsum_shift[:, 1:, :] = ((tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0) * -1.0
|
103 |
-
|
104 |
-
uv = F.interpolate(self._f02uv(f0).transpose(2, 1), scale_factor=float(upp), mode="nearest").transpose(2, 1)
|
105 |
-
sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi) * self.sine_amp
|
106 |
-
sine_waves = sine_waves * uv + ((uv * self.noise_std + (1 - uv) * self.sine_amp / 3) * torch.randn_like(sine_waves))
|
107 |
-
|
108 |
-
return sine_waves
|
109 |
-
|
110 |
-
class SourceModuleHnNSF(torch.nn.Module):
|
111 |
-
def __init__(self, sample_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0):
|
112 |
-
super(SourceModuleHnNSF, self).__init__()
|
113 |
-
self.sine_amp = sine_amp
|
114 |
-
self.noise_std = add_noise_std
|
115 |
-
self.l_sin_gen = SineGen(sample_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod)
|
116 |
-
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
117 |
-
self.l_tanh = torch.nn.Tanh()
|
118 |
-
|
119 |
-
def forward(self, x, upsample_factor = 1):
|
120 |
-
return self.l_tanh(self.l_linear(self.l_sin_gen(x, upsample_factor).to(dtype=self.l_linear.weight.dtype)))
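
# A standalone check of the harmonic source above (a sketch, not part of the
# original module): with harmonic_num=0 it turns a frame-level f0 contour of
# shape (batch, frames) into a sample-level excitation of shape
# (batch, frames * upsample_factor, 1).
if __name__ == "__main__":
    demo_source = SourceModuleHnNSF(sample_rate=16000, harmonic_num=0)
    demo_f0 = torch.full((1, 100), 220.0)                         # 100 frames at a constant 220 Hz
    demo_excitation = demo_source(demo_f0, upsample_factor=160)   # 160-sample hop at 16 kHz
    print(demo_excitation.shape)                                  # torch.Size([1, 16000, 1])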

class GeneratorNSF(torch.nn.Module):
    def __init__(self, initial_channel, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels, sr, checkpointing = False):
        super(GeneratorNSF, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.f0_upsamp = torch.nn.Upsample(scale_factor=math.prod(upsample_rates))
        self.m_source = SourceModuleHnNSF(sample_rate=sr, harmonic_num=0)
        self.conv_pre = torch.nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
        self.checkpointing = checkpointing
        self.ups = torch.nn.ModuleList()
        self.noise_convs = torch.nn.ModuleList()
        channels = [upsample_initial_channel // (2 ** (i + 1)) for i in range(len(upsample_rates))]
        stride_f0s = [math.prod(upsample_rates[i + 1 :]) if i + 1 < len(upsample_rates) else 1 for i in range(len(upsample_rates))]

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(weight_norm(torch.nn.ConvTranspose1d(upsample_initial_channel // (2**i), channels[i], k, u, padding=(k - u) // 2)))
            self.noise_convs.append(torch.nn.Conv1d(1, channels[i], kernel_size=(stride_f0s[i] * 2 if stride_f0s[i] > 1 else 1), stride=stride_f0s[i], padding=(stride_f0s[i] // 2 if stride_f0s[i] > 1 else 0)))

        self.resblocks = torch.nn.ModuleList([ResBlock(channels[i], k, d) for i in range(len(self.ups)) for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes)])
        self.conv_post = torch.nn.Conv1d(channels[-1], 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0: self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)

        self.upp = math.prod(upsample_rates)
        self.lrelu_slope = LRELU_SLOPE

    def forward(self, x, f0, g = None):
        har_source = self.m_source(f0, self.upp).transpose(1, 2)
        x = self.conv_pre(x)
        if g is not None: x = x + self.cond(g)

        for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
            x = F.leaky_relu(x, self.lrelu_slope)
            x = checkpoint.checkpoint(ups, x, use_reentrant=False) if self.training and self.checkpointing else ups(x)
            x += noise_convs(har_source)

            def resblock_forward(x, blocks):
                return sum(block(x) for block in blocks) / len(blocks)

            blocks = self.resblocks[i * self.num_kernels:(i + 1) * self.num_kernels]
            x = checkpoint.checkpoint(resblock_forward, x, blocks, use_reentrant=False) if self.training and self.checkpointing else resblock_forward(x, blocks)

        return torch.tanh(self.conv_post(F.leaky_relu(x)))

    def remove_weight_norm(self):
        for l in self.ups:
            remove_weight_norm(l)

        for l in self.resblocks:
            l.remove_weight_norm()

class LayerNorm(torch.nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = torch.nn.Parameter(torch.ones(channels))
        self.beta = torch.nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        return F.layer_norm(x, (x.size(-1),), self.gamma, self.beta, self.eps).transpose(1, -1)

class MultiHeadAttention(torch.nn.Module):
    def __init__(self, channels, out_channels, n_heads, p_dropout=0.0, window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
        super().__init__()
        assert channels % n_heads == 0
        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.attn = None
        self.k_channels = channels // n_heads
        self.conv_q = torch.nn.Conv1d(channels, channels, 1)
        self.conv_k = torch.nn.Conv1d(channels, channels, 1)
        self.conv_v = torch.nn.Conv1d(channels, channels, 1)
        self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
        self.drop = torch.nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
            self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)

        torch.nn.init.xavier_uniform_(self.conv_q.weight)
        torch.nn.init.xavier_uniform_(self.conv_k.weight)
        torch.nn.init.xavier_uniform_(self.conv_v.weight)

        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        q, k, v = self.conv_q(x), self.conv_k(c), self.conv_v(c)
        x, self.attn = self.attention(q, k, v, mask=attn_mask)
        return self.conv_o(x)

    def attention(self, query, key, value, mask=None):
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
        if self.window_size is not None:
            assert (t_s == t_t), "(t_s == t_t)"
            scores = scores + self._relative_position_to_absolute_position(self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), self._get_relative_embeddings(self.emb_rel_k, t_s)))

        if self.proximal_bias:
            assert t_s == t_t, "t_s == t_t"
            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert (t_s == t_t), "(t_s == t_t)"
                scores = scores.masked_fill((torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)) == 0, -1e4)

        p_attn = self.drop(F.softmax(scores, dim=-1))
        output = torch.matmul(p_attn, value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3))

        if self.window_size is not None: output = output + self._matmul_with_relative_values(self._absolute_position_to_relative_position(p_attn), self._get_relative_embeddings(self.emb_rel_v, t_s))
        return (output.transpose(2, 3).contiguous().view(b, d, t_t)), p_attn

    def _matmul_with_relative_values(self, x, y):
        return torch.matmul(x, y.unsqueeze(0))

    def _matmul_with_relative_keys(self, x, y):
        return torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))

    def _get_relative_embeddings(self, relative_embeddings, length):
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        return (F.pad(relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) if pad_length > 0 else relative_embeddings)[:, slice_start_position:(slice_start_position + 2 * length - 1)]

    def _relative_position_to_absolute_position(self, x):
        batch, heads, length, _ = x.size()
        return F.pad(F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])).view([batch, heads, length * 2 * length]), convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])).view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]

    def _absolute_position_to_relative_position(self, x):
        batch, heads, length, _ = x.size()
        return F.pad(F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])).view([batch, heads, length**2 + length * (length - 1)]), convert_pad_shape([[0, 0], [0, 0], [length, 0]])).view([batch, heads, length, 2 * length])[:, :, :, 1:]

    def _attention_bias_proximal(self, length):
        r = torch.arange(length, dtype=torch.float32)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs((torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)))), 0), 0)

class FFN(torch.nn.Module):
    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0, activation=None, causal=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal
        self.padding = self._causal_padding if causal else self._same_padding
        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = torch.nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(self.padding(x * x_mask))
        return self.conv_2(self.padding(self.drop(((x * torch.sigmoid(1.702 * x)) if self.activation == "gelu" else torch.relu(x))) * x_mask)) * x_mask

    def _causal_padding(self, x):
        if self.kernel_size == 1: return x

        return F.pad(x, convert_pad_shape([[0, 0], [0, 0], [(self.kernel_size - 1), 0]]))

    def _same_padding(self, x):
        if self.kernel_size == 1: return x

        return F.pad(x, convert_pad_shape([[0, 0], [0, 0], [((self.kernel_size - 1) // 2), (self.kernel_size // 2)]]))

class Encoder(torch.nn.Module):
    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.0, window_size=10, **kwargs):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.drop = torch.nn.Dropout(p_dropout)
        self.attn_layers = torch.nn.ModuleList()
        self.norm_layers_1 = torch.nn.ModuleList()
        self.ffn_layers = torch.nn.ModuleList()
        self.norm_layers_2 = torch.nn.ModuleList()

        for _ in range(self.n_layers):
            self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
            self.norm_layers_1.append(LayerNorm(hidden_channels))

            self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask

        for i in range(self.n_layers):
            x = self.norm_layers_1[i](x + self.drop(self.attn_layers[i](x, x, attn_mask)))
            x = self.norm_layers_2[i](x + self.drop(self.ffn_layers[i](x, x_mask)))

        return x * x_mask

class TextEncoder(torch.nn.Module):
    def __init__(self, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, embedding_dim, f0=True):
        super(TextEncoder, self).__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = float(p_dropout)
        self.emb_phone = torch.nn.Linear(embedding_dim, hidden_channels)
        self.lrelu = torch.nn.LeakyReLU(0.1, inplace=True)
        if f0: self.emb_pitch = torch.nn.Embedding(256, hidden_channels)
        self.encoder = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout))
        self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, phone, pitch, lengths):
        x = self.emb_phone(phone) if pitch is None else (self.emb_phone(phone) + self.emb_pitch(pitch))
        x = torch.transpose(self.lrelu((x * math.sqrt(self.hidden_channels))), 1, -1)
        x_mask = torch.unsqueeze(sequence_mask(lengths, x.size(2)), 1).to(x.dtype)
        m, logs = torch.split((self.proj(self.encoder(x * x_mask, x_mask)) * x_mask), self.out_channels, dim=1)
        return m, logs, x_mask

class PosteriorEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0):
        super(PosteriorEncoder, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.pre = torch.nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = WaveNet(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
        self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g = None):
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        m, logs = torch.split((self.proj(self.enc((self.pre(x) * x_mask), x_mask, g=g)) * x_mask), self.out_channels, dim=1)
        return ((m + torch.randn_like(m) * torch.exp(logs)) * x_mask), m, logs, x_mask

    def remove_weight_norm(self):
        self.enc.remove_weight_norm()

class Synthesizer(torch.nn.Module):
    def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, use_f0, text_enc_hidden_dim=768, vocoder="Default", checkpointing = False, **kwargs):
        super(Synthesizer, self).__init__()
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = float(p_dropout)
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        self.spk_embed_dim = spk_embed_dim
        self.use_f0 = use_f0
        self.enc_p = TextEncoder(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout), text_enc_hidden_dim, f0=use_f0)

        if use_f0:
            if vocoder == "RefineGAN": self.dec = RefineGANGenerator(sample_rate=sr, upsample_rates=upsample_rates, num_mels=inter_channels, checkpointing=checkpointing)
            elif vocoder == "MRF HiFi-GAN": self.dec = HiFiGANMRFGenerator(in_channel=inter_channels, upsample_initial_channel=upsample_initial_channel, upsample_rates=upsample_rates, upsample_kernel_sizes=upsample_kernel_sizes, resblock_kernel_sizes=resblock_kernel_sizes, resblock_dilations=resblock_dilation_sizes, gin_channels=gin_channels, sample_rate=sr, harmonic_num=8, checkpointing=checkpointing)
            else: self.dec = GeneratorNSF(inter_channels, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels, sr=sr, checkpointing=checkpointing)
        else: self.dec = Generator(inter_channels, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)

        self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels)
        self.emb_g = torch.nn.Embedding(self.spk_embed_dim, gin_channels)

    def remove_weight_norm(self):
        self.dec.remove_weight_norm()
        self.flow.remove_weight_norm()
        self.enc_q.remove_weight_norm()

    @torch.jit.ignore
    def forward(self, phone, phone_lengths, pitch = None, pitchf = None, y = None, y_lengths = None, ds = None):
        g = self.emb_g(ds).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)

        if y is not None:
            z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
            z_slice, ids_slice = rand_slice_segments(z, y_lengths, self.segment_size)
            return (self.dec(z_slice, slice_segments(pitchf, ids_slice, self.segment_size, 2), g=g) if self.use_f0 else self.dec(z_slice, g=g)), ids_slice, x_mask, y_mask, (z, self.flow(z, y_mask, g=g), m_p, logs_p, m_q, logs_q)
        else: return None, None, x_mask, None, (None, None, m_p, logs_p, None, None)

    @torch.jit.export
    def infer(self, phone, phone_lengths, pitch = None, nsff0 = None, sid = None, rate = None):
        g = self.emb_g(sid).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask

        if rate is not None:
            assert isinstance(rate, torch.Tensor)
            head = int(z_p.shape[2] * (1.0 - rate.item()))
            z_p = z_p[:, :, head:]
            x_mask = x_mask[:, :, head:]
            if self.use_f0: nsff0 = nsff0[:, head:]

        if self.use_f0:
            z = self.flow(z_p, x_mask, g=g, reverse=True)
            o = self.dec(z * x_mask, nsff0, g=g)
        else:
            z = self.flow(z_p, x_mask, g=g, reverse=True)
            o = self.dec(z * x_mask, g=g)

        return o, x_mask, (z, z_p, m_p, logs_p)
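
Synthesizer wires the pieces together: the text encoder yields the prior (m_p, logs_p), the coupling flow maps the sampled latent back, and the selected vocoder renders the waveform. An inference sketch under assumptions: the hyperparameters are illustrative RVC-style values rather than this project's actual config, and the sibling modules (.modules, .residuals, .refinegan, .mrf_hifigan, .commons) must be importable for it to run.

# Hedged inference sketch; every constant below is illustrative.
net = Synthesizer(
    spec_channels=1025, segment_size=32, inter_channels=192,
    hidden_channels=192, filter_channels=768, n_heads=2, n_layers=6,
    kernel_size=3, p_dropout=0.0, resblock="1",
    resblock_kernel_sizes=[3, 7, 11], resblock_dilation_sizes=[[1, 3, 5]] * 3,
    upsample_rates=[10, 10, 2, 2], upsample_initial_channel=512,
    upsample_kernel_sizes=[16, 16, 4, 4], spk_embed_dim=109,
    gin_channels=256, sr=40000, use_f0=True,
).eval()

phone = torch.randn(1, 50, 768)         # HuBERT-style features, (batch, frames, text_enc_hidden_dim)
pitch = torch.randint(1, 255, (1, 50))  # coarse pitch bins for emb_pitch (0..255)
nsff0 = torch.full((1, 50), 220.0)      # frame-level f0 in Hz for the NSF source
sid = torch.tensor([0])                 # speaker id for emb_g
lengths = torch.tensor([50])

with torch.no_grad():
    audio, _, _ = net.infer(phone, lengths, pitch, nsff0, sid)
print(audio.shape)                      # (1, 1, 50 * prod(upsample_rates)) = (1, 1, 20000)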
main/library/architectures/demucs_separator.py
DELETED
@@ -1,160 +0,0 @@
import os
import sys
import yaml
import torch

import numpy as np

from pathlib import Path
from hashlib import sha256

sys.path.append(os.getcwd())

from main.configs.config import Config
from main.library.uvr5_separator import spec_utils, common_separator
from main.library.uvr5_separator.demucs import hdemucs, states, apply

translations = Config().translations
sys.path.insert(0, os.path.join(os.getcwd(), "main", "library", "uvr5_separator"))
DEMUCS_4_SOURCE_MAPPER = {common_separator.CommonSeparator.BASS_STEM: 0, common_separator.CommonSeparator.DRUM_STEM: 1, common_separator.CommonSeparator.OTHER_STEM: 2, common_separator.CommonSeparator.VOCAL_STEM: 3}

class DemucsSeparator(common_separator.CommonSeparator):
    def __init__(self, common_config, arch_config):
        super().__init__(config=common_config)
        self.segment_size = arch_config.get("segment_size", "Default")
        self.shifts = arch_config.get("shifts", 2)
        self.overlap = arch_config.get("overlap", 0.25)
        self.segments_enabled = arch_config.get("segments_enabled", True)
        self.logger.debug(translations["demucs_info"].format(segment_size=self.segment_size, segments_enabled=self.segments_enabled))
        self.logger.debug(translations["demucs_info_2"].format(shifts=self.shifts, overlap=self.overlap))
        self.demucs_source_map = DEMUCS_4_SOURCE_MAPPER
        self.audio_file_path = None
        self.audio_file_base = None
        self.demucs_model_instance = None
        self.logger.info(translations["start_demucs"])

    def separate(self, audio_file_path):
        self.logger.debug(translations["start_separator"])
        source = None
        inst_source = {}
        self.audio_file_path = audio_file_path
        self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
        self.logger.debug(translations["prepare_mix"])
        mix = self.prepare_mix(self.audio_file_path)
        self.logger.debug(translations["demix"].format(shape=mix.shape))
        self.logger.debug(translations["cancel_mix"])
        self.demucs_model_instance = hdemucs.HDemucs(sources=["drums", "bass", "other", "vocals"])
        self.demucs_model_instance = get_demucs_model(name=os.path.splitext(os.path.basename(self.model_path))[0], repo=Path(os.path.dirname(self.model_path)))
        self.demucs_model_instance = apply.demucs_segments(self.segment_size, self.demucs_model_instance)
        self.demucs_model_instance.to(self.torch_device)
        self.demucs_model_instance.eval()
        self.logger.debug(translations["model_review"])
        source = self.demix_demucs(mix)
        del self.demucs_model_instance
        self.clear_gpu_cache()
        self.logger.debug(translations["del_gpu_cache_after_demix"])
        output_files = []
        self.logger.debug(translations["process_output_file"])

        if isinstance(inst_source, np.ndarray):
            self.logger.debug(translations["process_ver"])
            inst_source[self.demucs_source_map[common_separator.CommonSeparator.VOCAL_STEM]] = spec_utils.reshape_sources(inst_source[self.demucs_source_map[common_separator.CommonSeparator.VOCAL_STEM]], source[self.demucs_source_map[common_separator.CommonSeparator.VOCAL_STEM]])
            source = inst_source

        if isinstance(source, np.ndarray):
            source_length = len(source)
            self.logger.debug(translations["source_length"].format(source_length=source_length))
            self.logger.debug(translations["set_map"].format(part=source_length))
            match source_length:
                case 2: self.demucs_source_map = {common_separator.CommonSeparator.INST_STEM: 0, common_separator.CommonSeparator.VOCAL_STEM: 1}
                case 6: self.demucs_source_map = {common_separator.CommonSeparator.BASS_STEM: 0, common_separator.CommonSeparator.DRUM_STEM: 1, common_separator.CommonSeparator.OTHER_STEM: 2, common_separator.CommonSeparator.VOCAL_STEM: 3, common_separator.CommonSeparator.GUITAR_STEM: 4, common_separator.CommonSeparator.PIANO_STEM: 5}
                case _: self.demucs_source_map = DEMUCS_4_SOURCE_MAPPER

        self.logger.debug(translations["process_all_part"])
        for stem_name, stem_value in self.demucs_source_map.items():
            if self.output_single_stem is not None:
                if stem_name.lower() != self.output_single_stem.lower():
                    self.logger.debug(translations["skip_part"].format(stem_name=stem_name, output_single_stem=self.output_single_stem))
                    continue
            stem_path = os.path.join(f"{self.audio_file_base}_({stem_name})_{self.model_name}.{self.output_format.lower()}")
            self.final_process(stem_path, source[stem_value].T, stem_name)
            output_files.append(stem_path)
        return output_files

    def demix_demucs(self, mix):
        self.logger.debug(translations["starting_demix_demucs"])
        processed = {}
        mix = torch.tensor(mix, dtype=torch.float32)
        ref = mix.mean(0)
        mix = (mix - ref.mean()) / ref.std()
        mix_infer = mix
        with torch.no_grad():
            self.logger.debug(translations["model_infer"])
            sources = apply.apply_model(model=self.demucs_model_instance, mix=mix_infer[None], shifts=self.shifts, split=self.segments_enabled, overlap=self.overlap, static_shifts=1 if self.shifts == 0 else self.shifts, set_progress_bar=None, device=self.torch_device, progress=True)[0]
            sources = (sources * ref.std() + ref.mean()).cpu().numpy()
            sources[[0, 1]] = sources[[1, 0]]
        processed[mix] = sources[:, :, 0:None].copy()
        return np.concatenate([s[:, :, 0:None] for s in list(processed.values())], axis=-1)
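
# The per-track standardization used in demix_demucs above, in isolation (a
# sketch, not part of the original module): the mix is shifted and scaled by
# its mono reference statistics before inference, and the inverse transform is
# applied to the separated sources, so the model always sees roughly
# zero-mean, unit-variance input.
if __name__ == "__main__":
    demo_mix = torch.randn(2, 44100)                   # (channels, samples)
    demo_ref = demo_mix.mean(0)
    demo_normed = (demo_mix - demo_ref.mean()) / demo_ref.std()
    demo_restored = demo_normed * demo_ref.std() + demo_ref.mean()
    assert torch.allclose(demo_restored, demo_mix, atol=1e-5)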

class LocalRepo:
    def __init__(self, root):
        self.root = root
        self.scan()

    def scan(self):
        self._models, self._checksums = {}, {}
        for file in self.root.iterdir():
            if file.suffix == ".th":
                if "-" in file.stem:
                    xp_sig, checksum = file.stem.split("-")
                    self._checksums[xp_sig] = checksum
                else: xp_sig = file.stem

                if xp_sig in self._models: raise RuntimeError(translations["del_all_but_one"].format(xp_sig=xp_sig))
                self._models[xp_sig] = file

    def has_model(self, sig):
        return sig in self._models

    def get_model(self, sig):
        try:
            file = self._models[sig]
        except KeyError:
            raise RuntimeError(translations["not_found_model_signature"].format(sig=sig))

        if sig in self._checksums: check_checksum(file, self._checksums[sig])
        return states.load_model(file)

class BagOnlyRepo:
    def __init__(self, root, model_repo):
        self.root = root
        self.model_repo = model_repo
        self.scan()

    def scan(self):
        self._bags = {}
        for file in self.root.iterdir():
            if file.suffix == ".yaml": self._bags[file.stem] = file

    def get_model(self, name):
        try:
            yaml_file = self._bags[name]
        except KeyError:
            raise RuntimeError(translations["name_not_pretrained"].format(name=name))
        bag = yaml.safe_load(open(yaml_file))
        return apply.BagOfModels([self.model_repo.get_model(sig) for sig in bag["models"]], bag.get("weights"), bag.get("segment"))

def check_checksum(path, checksum):
    sha = sha256()
    with open(path, "rb") as file:
        while 1:
            buf = file.read(2**20)
            if not buf: break
            sha.update(buf)

    actual_checksum = sha.hexdigest()[: len(checksum)]
    if actual_checksum != checksum: raise RuntimeError(translations["invalid_checksum"].format(path=path, checksum=checksum, actual_checksum=actual_checksum))

def get_demucs_model(name, repo = None):
    model_repo = LocalRepo(repo)
    return (model_repo.get_model(name) if model_repo.has_model(name) else BagOnlyRepo(repo, model_repo).get_model(name)).eval()
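
The two repo classes encode a simple naming convention: a local weight file "<signature>-<checksum>.th" carries a sha256 prefix that check_checksum verifies before states.load_model is called, while a "<name>.yaml" bag lists several signatures to ensemble. A hedged usage sketch (the folder and signature below are made up):

repo = LocalRepo(Path("assets/models/demucs"))  # hypothetical weights folder
if repo.has_model("abcd1234"):                  # signature taken from "abcd1234-<checksum>.th"
    model = repo.get_model("abcd1234")          # raises RuntimeError on checksum mismatch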
main/library/architectures/mdx_separator.py
DELETED
@@ -1,320 +0,0 @@
import os
import sys
import onnx
import torch
import platform
import onnx2torch

import numpy as np
import onnxruntime as ort

from tqdm import tqdm

sys.path.append(os.getcwd())

from main.configs.config import Config
from main.library.uvr5_separator import spec_utils
from main.library.uvr5_separator.common_separator import CommonSeparator

translations = Config().translations

class MDXSeparator(CommonSeparator):
    def __init__(self, common_config, arch_config):
        super().__init__(config=common_config)
        self.segment_size = arch_config.get("segment_size")
        self.overlap = arch_config.get("overlap")
        self.batch_size = arch_config.get("batch_size", 1)
        self.hop_length = arch_config.get("hop_length")
        self.enable_denoise = arch_config.get("enable_denoise")
        self.logger.debug(translations["mdx_info"].format(batch_size=self.batch_size, segment_size=self.segment_size))
        self.logger.debug(translations["mdx_info_2"].format(overlap=self.overlap, hop_length=self.hop_length, enable_denoise=self.enable_denoise))
        self.compensate = self.model_data["compensate"]
        self.dim_f = self.model_data["mdx_dim_f_set"]
        self.dim_t = 2 ** self.model_data["mdx_dim_t_set"]
        self.n_fft = self.model_data["mdx_n_fft_scale_set"]
        self.config_yaml = self.model_data.get("config_yaml", None)
        self.logger.debug(f"{translations['mdx_info_3']}: compensate = {self.compensate}, dim_f = {self.dim_f}, dim_t = {self.dim_t}, n_fft = {self.n_fft}")
        self.logger.debug(f"{translations['mdx_info_3']}: config_yaml = {self.config_yaml}")
        self.load_model()
        self.n_bins = 0
        self.trim = 0
        self.chunk_size = 0
        self.gen_size = 0
        self.stft = None
        self.primary_source = None
        self.secondary_source = None
        self.audio_file_path = None
        self.audio_file_base = None

    def load_model(self):
        self.logger.debug(translations["load_model_onnx"])

        if self.segment_size == self.dim_t:
            ort_session_options = ort.SessionOptions()
            ort_session_options.log_severity_level = 3 if self.log_level > 10 else 0
            ort_inference_session = ort.InferenceSession(self.model_path, providers=self.onnx_execution_provider, sess_options=ort_session_options)
            self.model_run = lambda spek: ort_inference_session.run(None, {"input": spek.cpu().numpy()})[0]
            self.logger.debug(translations["load_model_onnx_success"])
        else:
            self.model_run = onnx2torch.convert(onnx.load(self.model_path)) if platform.system() == 'Windows' else onnx2torch.convert(self.model_path)
            self.model_run.to(self.torch_device).eval()
            self.logger.debug(translations["onnx_to_pytorch"])

    def separate(self, audio_file_path):
        self.audio_file_path = audio_file_path
        self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
        self.logger.debug(translations["mix"].format(audio_file_path=self.audio_file_path))
        mix = self.prepare_mix(self.audio_file_path)
        self.logger.debug(translations["normalization_demix"])
        mix = spec_utils.normalize(wave=mix, max_peak=self.normalization_threshold)
        source = self.demix(mix)
        self.logger.debug(translations["mix_success"])
        output_files = []
        self.logger.debug(translations["process_output_file"])

        if not isinstance(self.primary_source, np.ndarray):
            self.logger.debug(translations["primary_source"])
            self.primary_source = spec_utils.normalize(wave=source, max_peak=self.normalization_threshold).T

        if not isinstance(self.secondary_source, np.ndarray):
            self.logger.debug(translations["secondary_source"])
            raw_mix = self.demix(mix, is_match_mix=True)

            if self.invert_using_spec:
                self.logger.debug(translations["invert_using_spec"])
                self.secondary_source = spec_utils.invert_stem(raw_mix, source)
            else:
                self.logger.debug(translations["invert_using_spec_2"])
                self.secondary_source = mix.T - source.T

        if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
            self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")
            self.logger.info(translations["save_secondary_stem_output_path"].format(stem_name=self.secondary_stem_name, stem_output_path=self.secondary_stem_output_path))
            self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
            output_files.append(self.secondary_stem_output_path)

        if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
            self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")
            if not isinstance(self.primary_source, np.ndarray): self.primary_source = source.T

            self.logger.info(translations["save_secondary_stem_output_path"].format(stem_name=self.primary_stem_name, stem_output_path=self.primary_stem_output_path))
            self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
            output_files.append(self.primary_stem_output_path)

        return output_files

    def initialize_model_settings(self):
        self.logger.debug(translations["starting_model"])

        self.n_bins = self.n_fft // 2 + 1
        self.trim = self.n_fft // 2

        self.chunk_size = self.hop_length * (self.segment_size - 1)
        self.gen_size = self.chunk_size - 2 * self.trim

        self.stft = STFT(self.logger, self.n_fft, self.hop_length, self.dim_f, self.torch_device)

        self.logger.debug(f"{translations['input_info']}: n_fft = {self.n_fft} hop_length = {self.hop_length} dim_f = {self.dim_f}")
        self.logger.debug(f"{translations['model_settings']}: n_bins = {self.n_bins}, Trim = {self.trim}, chunk_size = {self.chunk_size}, gen_size = {self.gen_size}")

    def initialize_mix(self, mix, is_ckpt=False):
        self.logger.debug(translations["initialize_mix"].format(is_ckpt=is_ckpt, shape=mix.shape))

        if mix.shape[0] != 2:
            error_message = translations["!=2"].format(shape=mix.shape[0])
            self.logger.error(error_message)
            raise ValueError(error_message)

        if is_ckpt:
            self.logger.debug(translations["process_check"])
            pad = self.gen_size + self.trim - (mix.shape[-1] % self.gen_size)
            self.logger.debug(f"{translations['cache']}: {pad}")

            mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)

            num_chunks = mixture.shape[-1] // self.gen_size
            self.logger.debug(translations["shape"].format(shape=mixture.shape, num_chunks=num_chunks))

            mix_waves = [mixture[:, i * self.gen_size : i * self.gen_size + self.chunk_size] for i in range(num_chunks)]
        else:
            self.logger.debug(translations["process_no_check"])
            mix_waves = []
            n_sample = mix.shape[1]

            pad = self.gen_size - n_sample % self.gen_size
            self.logger.debug(translations["n_sample_or_pad"].format(n_sample=n_sample, pad=pad))

            mix_p = np.concatenate((np.zeros((2, self.trim)), mix, np.zeros((2, pad)), np.zeros((2, self.trim))), 1)
            self.logger.debug(f"{translations['shape_2']}: {mix_p.shape}")

            i = 0
            while i < n_sample + pad:
                mix_waves.append(np.array(mix_p[:, i : i + self.chunk_size]))

                self.logger.debug(translations["process_part"].format(mix_waves=len(mix_waves), i=i, ii=i + self.chunk_size))
                i += self.gen_size

        mix_waves_tensor = torch.tensor(mix_waves, dtype=torch.float32).to(self.torch_device)
        self.logger.debug(translations["mix_waves_to_tensor"].format(shape=mix_waves_tensor.shape))

        return mix_waves_tensor, pad

    def demix(self, mix, is_match_mix=False):
        self.logger.debug(f"{translations['demix_is_match_mix']}: {is_match_mix}...")
        self.initialize_model_settings()
        self.logger.debug(f"{translations['mix_shape']}: {mix.shape}")
        tar_waves_ = []

        if is_match_mix:
            chunk_size = self.hop_length * (self.segment_size - 1)
            overlap = 0.02
            self.logger.debug(translations["chunk_size_or_overlap"].format(chunk_size=chunk_size, overlap=overlap))
        else:
            chunk_size = self.chunk_size
            overlap = self.overlap
            self.logger.debug(translations["chunk_size_or_overlap_standard"].format(chunk_size=chunk_size, overlap=overlap))

        gen_size = chunk_size - 2 * self.trim
        self.logger.debug(f"{translations['calc_size']}: {gen_size}")

        mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, gen_size + self.trim - ((mix.shape[-1]) % gen_size)), dtype="float32")), 1)
        self.logger.debug(f"{translations['mix_cache']}: {mixture.shape}")

        step = int((1 - overlap) * chunk_size)
        self.logger.debug(translations["step_or_overlap"].format(step=step, overlap=overlap))

        result = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
        divider = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)

        total = 0
        total_chunks = (mixture.shape[-1] + step - 1) // step
        self.logger.debug(f"{translations['all_process_part']}: {total_chunks}")

        for i in tqdm(range(0, mixture.shape[-1], step), ncols=100, unit="f"):
            total += 1
            start = i
            end = min(i + chunk_size, mixture.shape[-1])
            self.logger.debug(translations["process_part_2"].format(total=total, total_chunks=total_chunks, start=start, end=end))

            chunk_size_actual = end - start
            window = None

            if overlap != 0:
                window = np.hanning(chunk_size_actual)
                window = np.tile(window[None, None, :], (1, 2, 1))
                self.logger.debug(translations["window"])

            mix_part_ = mixture[:, start:end]

            if end != i + chunk_size:
                pad_size = (i + chunk_size) - end
                mix_part_ = np.concatenate((mix_part_, np.zeros((2, pad_size), dtype="float32")), axis=-1)

            mix_waves = torch.tensor([mix_part_], dtype=torch.float32).to(self.torch_device).split(self.batch_size)

            total_batches = len(mix_waves)
            self.logger.debug(f"{translations['mix_or_batch']}: {total_batches}")

            with torch.no_grad():
                batches_processed = 0

                for mix_wave in mix_waves:
                    batches_processed += 1
                    self.logger.debug(f"{translations['mix_wave']} {batches_processed}/{total_batches}")

                    tar_waves = self.run_model(mix_wave, is_match_mix=is_match_mix)

                    if window is not None:
                        tar_waves[..., :chunk_size_actual] *= window
                        divider[..., start:end] += window
                    else: divider[..., start:end] += 1

                    result[..., start:end] += tar_waves[..., : end - start]

        self.logger.debug(translations["normalization_2"])
        tar_waves = result / divider
        tar_waves_.append(tar_waves)

        tar_waves = np.concatenate(np.vstack(tar_waves_)[:, :, self.trim : -self.trim], axis=-1)[:, : mix.shape[-1]]

        source = tar_waves[:, 0:None]
        self.logger.debug(f"{translations['tar_waves']}: {tar_waves.shape}")

        if not is_match_mix:
            source *= self.compensate
            self.logger.debug(translations["mix_match"])

        self.logger.debug(translations["mix_success"])
        return source

    def run_model(self, mix, is_match_mix=False):
        spek = self.stft(mix.to(self.torch_device))
        self.logger.debug(translations["stft_2"].format(shape=spek.shape))

        spek[:, :, :3, :] *= 0

        if is_match_mix:
            spec_pred = spek.cpu().numpy()
            self.logger.debug(translations["is_match_mix"])
        else:
            if self.enable_denoise:
                spec_pred_neg = self.model_run(-spek)
                spec_pred_pos = self.model_run(spek)
                spec_pred = (spec_pred_neg * -0.5) + (spec_pred_pos * 0.5)
                self.logger.debug(translations["enable_denoise"])
            else:
                spec_pred = self.model_run(spek)
                self.logger.debug(translations["no_denoise"])

        result = self.stft.inverse(torch.tensor(spec_pred).to(self.torch_device)).cpu().detach().numpy()
        self.logger.debug(f"{translations['stft']}: {result.shape}")

        return result
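
# The overlap-add bookkeeping in demix above, in miniature (a sketch, not part
# of the original module): each chunk is weighted by a Hann window, the window
# weights are accumulated in a divider, and the final division restores unity
# gain wherever the accumulated weight is nonzero.
if __name__ == "__main__":
    demo_len, demo_chunk, demo_step = 1000, 256, 128
    demo_signal = np.random.randn(demo_len)
    demo_result, demo_divider = np.zeros(demo_len), np.zeros(demo_len)

    for start in range(0, demo_len, demo_step):
        end = min(start + demo_chunk, demo_len)
        window = np.hanning(end - start)
        demo_result[start:end] += demo_signal[start:end] * window
        demo_divider[start:end] += window

    # np.maximum guards the zero-weight window endpoints the real code leaves untouched
    recovered = demo_result / np.maximum(demo_divider, 1e-8)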

class STFT:
    def __init__(self, logger, n_fft, hop_length, dim_f, device):
        self.logger = logger
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.dim_f = dim_f
        self.device = device
        self.hann_window = torch.hann_window(window_length=self.n_fft, periodic=True)

    def __call__(self, input_tensor):
        is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"]

        if is_non_standard_device: input_tensor = input_tensor.cpu()

        batch_dimensions = input_tensor.shape[:-2]
        channel_dim, time_dim = input_tensor.shape[-2:]

        permuted_stft_output = torch.stft(input_tensor.reshape([-1, time_dim]), n_fft=self.n_fft, hop_length=self.hop_length, window=self.hann_window.to(input_tensor.device), center=True, return_complex=False).permute([0, 3, 1, 2])
        final_output = permuted_stft_output.reshape([*batch_dimensions, channel_dim, 2, -1, permuted_stft_output.shape[-1]]).reshape([*batch_dimensions, channel_dim * 2, -1, permuted_stft_output.shape[-1]])

        if is_non_standard_device: final_output = final_output.to(self.device)
        return final_output[..., : self.dim_f, :]

    def pad_frequency_dimension(self, input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins):
        return torch.cat([input_tensor, torch.zeros([*batch_dimensions, channel_dim, num_freq_bins - freq_dim, time_dim]).to(input_tensor.device)], -2)

    def calculate_inverse_dimensions(self, input_tensor):
        channel_dim, freq_dim, time_dim = input_tensor.shape[-3:]

        return input_tensor.shape[:-3], channel_dim, freq_dim, time_dim, self.n_fft // 2 + 1

    def prepare_for_istft(self, padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim):
        permuted_tensor = padded_tensor.reshape([*batch_dimensions, channel_dim // 2, 2, num_freq_bins, time_dim]).reshape([-1, 2, num_freq_bins, time_dim]).permute([0, 2, 3, 1])

        return permuted_tensor[..., 0] + permuted_tensor[..., 1] * 1.0j

    def inverse(self, input_tensor):
        is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"]
        if is_non_standard_device: input_tensor = input_tensor.cpu()

        batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins = self.calculate_inverse_dimensions(input_tensor)
        final_output = torch.istft(self.prepare_for_istft(self.pad_frequency_dimension(input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins), batch_dimensions, channel_dim, num_freq_bins, time_dim), n_fft=self.n_fft, hop_length=self.hop_length, window=self.hann_window.to(input_tensor.device), center=True).reshape([*batch_dimensions, 2, -1])

        if is_non_standard_device: final_output = final_output.to(self.device)

        return final_output
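
The STFT helper packs the real and imaginary parts of each stereo channel into four real-valued feature channels so the ONNX models never see complex tensors, and inverse() undoes the packing. A round-trip sketch with assumed MDX-style settings (the n_fft/hop_length/dim_f values below are illustrative):

stft = STFT(logger=None, n_fft=6144, hop_length=1024, dim_f=2048, device="cpu")
wave = torch.randn(1, 2, 1024 * 255)  # (batch, stereo channels, samples)
spec = stft(wave)                     # (1, 4, 2048, 256): re/im interleaved per channel
recon = stft.inverse(spec)            # back to (1, 2, 1024 * 255)
print(spec.shape, recon.shape)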
main/library/predictors/CREPE.py
DELETED
@@ -1,210 +0,0 @@
import os
import torch
import librosa
import functools
import scipy.stats

import numpy as np

CENTS_PER_BIN, MAX_FMAX, PITCH_BINS, SAMPLE_RATE, WINDOW_SIZE = 20, 2006, 360, 16000, 1024

class Crepe(torch.nn.Module):
    def __init__(self, model='full'):
        super().__init__()
        if model == 'full':
            in_channels = [1, 1024, 128, 128, 128, 256]
            out_channels = [1024, 128, 128, 128, 256, 512]
            self.in_features = 2048
        elif model == 'large':
            in_channels = [1, 768, 96, 96, 96, 192]
            out_channels = [768, 96, 96, 96, 192, 384]
            self.in_features = 1536
        elif model == 'medium':
            in_channels = [1, 512, 64, 64, 64, 128]
            out_channels = [512, 64, 64, 64, 128, 256]
            self.in_features = 1024
        elif model == 'small':
            in_channels = [1, 256, 32, 32, 32, 64]
            out_channels = [256, 32, 32, 32, 64, 128]
            self.in_features = 512
        elif model == 'tiny':
            in_channels = [1, 128, 16, 16, 16, 32]
            out_channels = [128, 16, 16, 16, 32, 64]
            self.in_features = 256

        kernel_sizes = [(512, 1)] + 5 * [(64, 1)]
        strides = [(4, 1)] + 5 * [(1, 1)]

        batch_norm_fn = functools.partial(torch.nn.BatchNorm2d, eps=0.0010000000474974513, momentum=0.0)

        self.conv1 = torch.nn.Conv2d(in_channels=in_channels[0], out_channels=out_channels[0], kernel_size=kernel_sizes[0], stride=strides[0])
        self.conv1_BN = batch_norm_fn(num_features=out_channels[0])
        self.conv2 = torch.nn.Conv2d(in_channels=in_channels[1], out_channels=out_channels[1], kernel_size=kernel_sizes[1], stride=strides[1])
        self.conv2_BN = batch_norm_fn(num_features=out_channels[1])

        self.conv3 = torch.nn.Conv2d(in_channels=in_channels[2], out_channels=out_channels[2], kernel_size=kernel_sizes[2], stride=strides[2])
        self.conv3_BN = batch_norm_fn(num_features=out_channels[2])
        self.conv4 = torch.nn.Conv2d(in_channels=in_channels[3], out_channels=out_channels[3], kernel_size=kernel_sizes[3], stride=strides[3])
        self.conv4_BN = batch_norm_fn(num_features=out_channels[3])

        self.conv5 = torch.nn.Conv2d(in_channels=in_channels[4], out_channels=out_channels[4], kernel_size=kernel_sizes[4], stride=strides[4])
        self.conv5_BN = batch_norm_fn(num_features=out_channels[4])
        self.conv6 = torch.nn.Conv2d(in_channels=in_channels[5], out_channels=out_channels[5], kernel_size=kernel_sizes[5], stride=strides[5])
        self.conv6_BN = batch_norm_fn(num_features=out_channels[5])

        self.classifier = torch.nn.Linear(in_features=self.in_features, out_features=PITCH_BINS)

    def forward(self, x, embed=False):
        x = self.embed(x)
        if embed: return x

        return torch.sigmoid(self.classifier(self.layer(x, self.conv6, self.conv6_BN).permute(0, 2, 1, 3).reshape(-1, self.in_features)))

    def embed(self, x):
        x = x[:, None, :, None]

        return self.layer(self.layer(self.layer(self.layer(self.layer(x, self.conv1, self.conv1_BN, (0, 0, 254, 254)), self.conv2, self.conv2_BN), self.conv3, self.conv3_BN), self.conv4, self.conv4_BN), self.conv5, self.conv5_BN)

    def layer(self, x, conv, batch_norm, padding=(0, 0, 31, 32)):
        return torch.nn.functional.max_pool2d(batch_norm(torch.nn.functional.relu(conv(torch.nn.functional.pad(x, padding)))), (2, 1), (2, 1))

def viterbi(logits):
    if not hasattr(viterbi, 'transition'):
        xx, yy = np.meshgrid(range(360), range(360))
        transition = np.maximum(12 - abs(xx - yy), 0)
        viterbi.transition = transition / transition.sum(axis=1, keepdims=True)

    with torch.no_grad():
        probs = torch.nn.functional.softmax(logits, dim=1)

    bins = torch.tensor(np.array([librosa.sequence.viterbi(sequence, viterbi.transition).astype(np.int64) for sequence in probs.cpu().numpy()]), device=probs.device)
    return bins, bins_to_frequency(bins)

def predict(audio, sample_rate, hop_length=None, fmin=50, fmax=MAX_FMAX, model='full', return_periodicity=False, batch_size=None, device='cpu', pad=True, providers=None, onnx=False):
    results = []

    if onnx:
        import onnxruntime as ort

        sess_options = ort.SessionOptions()
        sess_options.log_severity_level = 3

        session = ort.InferenceSession(os.path.join("assets", "models", "predictors", f"crepe_{model}.onnx"), sess_options=sess_options, providers=providers)

        for frames in preprocess(audio, sample_rate, hop_length, batch_size, device, pad):
            result = postprocess(torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: frames.cpu().numpy()})[0].transpose(1, 0)[None]), fmin, fmax, return_periodicity)
            results.append((result[0], result[1]) if isinstance(result, tuple) else result)

        del session

        if return_periodicity:
            pitch, periodicity = zip(*results)
            return torch.cat(pitch, 1), torch.cat(periodicity, 1)

        return torch.cat(results, 1)
    else:
        with torch.no_grad():
            for frames in preprocess(audio, sample_rate, hop_length, batch_size, device, pad):
                result = postprocess(infer(frames, model, device, embed=False).reshape(audio.size(0), -1, PITCH_BINS).transpose(1, 2), fmin, fmax, return_periodicity)
                results.append((result[0].to(audio.device), result[1].to(audio.device)) if isinstance(result, tuple) else result.to(audio.device))

        if return_periodicity:
            pitch, periodicity = zip(*results)
            return torch.cat(pitch, 1), torch.cat(periodicity, 1)

        return torch.cat(results, 1)

def bins_to_frequency(bins):
    cents = CENTS_PER_BIN * bins + 1997.3794084376191
    return 10 * 2 ** ((cents + cents.new_tensor(scipy.stats.triang.rvs(c=0.5, loc=-CENTS_PER_BIN, scale=2 * CENTS_PER_BIN, size=cents.size()))) / 1200)
|
120 |
-
|
121 |
-
def frequency_to_bins(frequency, quantize_fn=torch.floor):
|
122 |
-
return quantize_fn(((1200 * torch.log2(frequency / 10)) - 1997.3794084376191) / CENTS_PER_BIN).int()
|
123 |
-
|
124 |
-
def infer(frames, model='full', device='cpu', embed=False):
|
125 |
-
if not hasattr(infer, 'model') or not hasattr(infer, 'capacity') or (hasattr(infer, 'capacity') and infer.capacity != model): load_model(device, model)
|
126 |
-
infer.model = infer.model.to(device)
|
127 |
-
|
128 |
-
return infer.model(frames, embed=embed)
|
129 |
-
|
130 |
-
def load_model(device, capacity='full'):
|
131 |
-
infer.capacity = capacity
|
132 |
-
infer.model = Crepe(capacity)
|
133 |
-
infer.model.load_state_dict(torch.load(os.path.join("assets", "models", "predictors", f"crepe_{capacity}.pth"), map_location=device))
|
134 |
-
infer.model = infer.model.to(torch.device(device))
|
135 |
-
infer.model.eval()
|
136 |
-
|
137 |
-
def postprocess(probabilities, fmin=0, fmax=MAX_FMAX, return_periodicity=False):
|
138 |
-
probabilities = probabilities.detach()
|
139 |
-
|
140 |
-
probabilities[:, :frequency_to_bins(torch.tensor(fmin))] = -float('inf')
|
141 |
-
probabilities[:, frequency_to_bins(torch.tensor(fmax), torch.ceil):] = -float('inf')
|
142 |
-
|
143 |
-
bins, pitch = viterbi(probabilities)
|
144 |
-
|
145 |
-
if not return_periodicity: return pitch
|
146 |
-
return pitch, periodicity(probabilities, bins)
|
147 |
-
|
148 |
-
def preprocess(audio, sample_rate, hop_length=None, batch_size=None, device='cpu', pad=True):
|
149 |
-
hop_length = sample_rate // 100 if hop_length is None else hop_length
|
150 |
-
|
151 |
-
if sample_rate != SAMPLE_RATE:
|
152 |
-
audio = torch.tensor(librosa.resample(audio.detach().cpu().numpy().squeeze(0), orig_sr=sample_rate, target_sr=SAMPLE_RATE, res_type="soxr_vhq"), device=audio.device).unsqueeze(0)
|
153 |
-
hop_length = int(hop_length * SAMPLE_RATE / sample_rate)
|
154 |
-
|
155 |
-
if pad:
|
156 |
-
total_frames = 1 + int(audio.size(1) // hop_length)
|
157 |
-
audio = torch.nn.functional.pad(audio, (WINDOW_SIZE // 2, WINDOW_SIZE // 2))
|
158 |
-
else: total_frames = 1 + int((audio.size(1) - WINDOW_SIZE) // hop_length)
|
159 |
-
|
160 |
-
batch_size = total_frames if batch_size is None else batch_size
|
161 |
-
|
162 |
-
for i in range(0, total_frames, batch_size):
|
163 |
-
frames = torch.nn.functional.unfold(audio[:, None, None, max(0, i * hop_length):min(audio.size(1), (i + batch_size - 1) * hop_length + WINDOW_SIZE)], kernel_size=(1, WINDOW_SIZE), stride=(1, hop_length))
|
164 |
-
frames = frames.transpose(1, 2).reshape(-1, WINDOW_SIZE).to(device)
|
165 |
-
frames -= frames.mean(dim=1, keepdim=True)
|
166 |
-
frames /= torch.max(torch.tensor(1e-10, device=frames.device), frames.std(dim=1, keepdim=True))
|
167 |
-
|
168 |
-
yield frames
|
169 |
-
|
170 |
-
def periodicity(probabilities, bins):
|
171 |
-
probs_stacked = probabilities.transpose(1, 2).reshape(-1, PITCH_BINS)
|
172 |
-
periodicity = probs_stacked.gather(1, bins.reshape(-1, 1).to(torch.int64))
|
173 |
-
|
174 |
-
return periodicity.reshape(probabilities.size(0), probabilities.size(2))
|
175 |
-
|
176 |
-
def mean(signals, win_length=9):
|
177 |
-
assert signals.dim() == 2
|
178 |
-
|
179 |
-
signals = signals.unsqueeze(1)
|
180 |
-
mask = ~torch.isnan(signals)
|
181 |
-
padding = win_length // 2
|
182 |
-
|
183 |
-
ones_kernel = torch.ones(signals.size(1), 1, win_length, device=signals.device)
|
184 |
-
avg_pooled = torch.nn.functional.conv1d(torch.where(mask, signals, torch.zeros_like(signals)), ones_kernel, stride=1, padding=padding) / torch.nn.functional.conv1d(mask.float(), ones_kernel, stride=1, padding=padding).clamp(min=1)
|
185 |
-
avg_pooled[avg_pooled == 0] = float("nan")
|
186 |
-
|
187 |
-
return avg_pooled.squeeze(1)
|
188 |
-
|
189 |
-
def median(signals, win_length):
|
190 |
-
assert signals.dim() == 2
|
191 |
-
|
192 |
-
signals = signals.unsqueeze(1)
|
193 |
-
mask = ~torch.isnan(signals)
|
194 |
-
padding = win_length // 2
|
195 |
-
|
196 |
-
x = torch.nn.functional.pad(torch.where(mask, signals, torch.zeros_like(signals)), (padding, padding), mode="reflect")
|
197 |
-
mask = torch.nn.functional.pad(mask.float(), (padding, padding), mode="constant", value=0)
|
198 |
-
|
199 |
-
x = x.unfold(2, win_length, 1)
|
200 |
-
mask = mask.unfold(2, win_length, 1)
|
201 |
-
|
202 |
-
x = x.contiguous().view(x.size()[:3] + (-1,))
|
203 |
-
mask = mask.contiguous().view(mask.size()[:3] + (-1,))
|
204 |
-
|
205 |
-
x_sorted, _ = torch.sort(torch.where(mask.bool(), x.float(), float("inf")).to(x), dim=-1)
|
206 |
-
|
207 |
-
median_pooled = x_sorted.gather(-1, ((mask.sum(dim=-1) - 1) // 2).clamp(min=0).unsqueeze(-1).long()).squeeze(-1)
|
208 |
-
median_pooled[torch.isinf(median_pooled)] = float("nan")
|
209 |
-
|
210 |
-
return median_pooled.squeeze(1)
|
main/library/predictors/FCPE.py
DELETED
@@ -1,670 +0,0 @@
-import os
-import math
-import torch
-import librosa
-
-import numpy as np
-import soundfile as sf
-import torch.nn.functional as F
-
-from torch import nn, einsum
-from functools import partial
-from einops import rearrange, repeat, pack, unpack
-from torch.nn.utils.parametrizations import weight_norm
-
-os.environ["LRU_CACHE_CAPACITY"] = "3"
-
-def exists(val):
-    return val is not None
-
-def default(value, d):
-    return value if exists(value) else d
-
-def max_neg_value(tensor):
-    return -torch.finfo(tensor.dtype).max
-
-def l2norm(tensor):
-    return F.normalize(tensor, dim = -1).type(tensor.dtype)
-
-def pad_to_multiple(tensor, multiple, dim=-1, value=0):
-    seqlen = tensor.shape[dim]
-    m = seqlen / multiple
-    if m.is_integer(): return False, tensor
-    return True, F.pad(tensor, (*((0,) * (-1 - dim) * 2), 0, (math.ceil(m) * multiple - seqlen)), value = value)
-
-def look_around(x, backward = 1, forward = 0, pad_value = -1, dim = 2):
-    t = x.shape[1]
-    dims = (len(x.shape) - dim) * (0, 0)
-    padded_x = F.pad(x, (*dims, backward, forward), value = pad_value)
-    return torch.cat([padded_x[:, ind:(ind + t), ...] for ind in range(forward + backward + 1)], dim = dim)
-
-def rotate_half(x):
-    x1, x2 = rearrange(x, 'b ... (r d) -> b ... r d', r = 2).unbind(dim = -2)
-    return torch.cat((-x2, x1), dim = -1)
-
-def apply_rotary_pos_emb(q, k, freqs, scale = 1):
-    q_len = q.shape[-2]
-    q_freqs = freqs[..., -q_len:, :]
-    inv_scale = scale ** -1
-    if scale.ndim == 2: scale = scale[-q_len:, :]
-    q = (q * q_freqs.cos() * scale) + (rotate_half(q) * q_freqs.sin() * scale)
-    k = (k * freqs.cos() * inv_scale) + (rotate_half(k) * freqs.sin() * inv_scale)
-    return q, k
-
-class LocalAttention(nn.Module):
-    def __init__(self, window_size, causal = False, look_backward = 1, look_forward = None, dropout = 0., shared_qk = False, rel_pos_emb_config = None, dim = None, autopad = False, exact_windowsize = False, scale = None, use_rotary_pos_emb = True, use_xpos = False, xpos_scale_base = None):
-        super().__init__()
-        look_forward = default(look_forward, 0 if causal else 1)
-        assert not (causal and look_forward > 0)
-        self.scale = scale
-        self.window_size = window_size
-        self.autopad = autopad
-        self.exact_windowsize = exact_windowsize
-        self.causal = causal
-        self.look_backward = look_backward
-        self.look_forward = look_forward
-        self.dropout = nn.Dropout(dropout)
-        self.shared_qk = shared_qk
-        self.rel_pos = None
-        self.use_xpos = use_xpos
-
-        if use_rotary_pos_emb and (exists(rel_pos_emb_config) or exists(dim)):
-            if exists(rel_pos_emb_config): dim = rel_pos_emb_config[0]
-            self.rel_pos = SinusoidalEmbeddings(dim, use_xpos = use_xpos, scale_base = default(xpos_scale_base, window_size // 2))
-
-    def forward(self, q, k, v, mask = None, input_mask = None, attn_bias = None, window_size = None):
-        mask = default(mask, input_mask)
-        assert not (exists(window_size) and not self.use_xpos)
-
-        _, autopad, pad_value, window_size, causal, look_backward, look_forward, shared_qk = q.shape, self.autopad, -1, default(window_size, self.window_size), self.causal, self.look_backward, self.look_forward, self.shared_qk
-        (q, packed_shape), (k, _), (v, _) = map(lambda t: pack([t], '* n d'), (q, k, v))
-
-        if autopad:
-            orig_seq_len = q.shape[1]
-            (_, q), (_, k), (_, v) = map(lambda t: pad_to_multiple(t, self.window_size, dim = -2), (q, k, v))
-
-        b, n, dim_head, device, dtype = *q.shape, q.device, q.dtype
-        scale = default(self.scale, dim_head ** -0.5)
-        assert (n % window_size) == 0
-        windows = n // window_size
-        if shared_qk: k = l2norm(k)
-
-        seq = torch.arange(n, device = device)
-        b_t = rearrange(seq, '(w n) -> 1 w n', w = windows, n = window_size)
-        bq, bk, bv = map(lambda t: rearrange(t, 'b (w n) d -> b w n d', w = windows), (q, k, v))
-        bq = bq * scale
-        look_around_kwargs = dict(backward = look_backward, forward = look_forward, pad_value = pad_value)
-        bk = look_around(bk, **look_around_kwargs)
-        bv = look_around(bv, **look_around_kwargs)
-
-        if exists(self.rel_pos):
-            pos_emb, xpos_scale = self.rel_pos(bk)
-            bq, bk = apply_rotary_pos_emb(bq, bk, pos_emb, scale = xpos_scale)
-
-        bq_t = b_t
-        bq_k = look_around(b_t, **look_around_kwargs)
-        bq_t = rearrange(bq_t, '... i -> ... i 1')
-        bq_k = rearrange(bq_k, '... j -> ... 1 j')
-        pad_mask = bq_k == pad_value
-        sim = einsum('b h i e, b h j e -> b h i j', bq, bk)
-
-        if exists(attn_bias):
-            heads = attn_bias.shape[0]
-            assert (b % heads) == 0
-            attn_bias = repeat(attn_bias, 'h i j -> (b h) 1 i j', b = b // heads)
-            sim = sim + attn_bias
-
-        mask_value = max_neg_value(sim)
-
-        if shared_qk:
-            self_mask = bq_t == bq_k
-            sim = sim.masked_fill(self_mask, -5e4)
-            del self_mask
-
-        if causal:
-            causal_mask = bq_t < bq_k
-            if self.exact_windowsize: causal_mask = causal_mask | (bq_t > (bq_k + (self.window_size * self.look_backward)))
-            sim = sim.masked_fill(causal_mask, mask_value)
-            del causal_mask
-
-        sim = sim.masked_fill(((bq_k - (self.window_size * self.look_forward)) > bq_t) | (bq_t > (bq_k + (self.window_size * self.look_backward))) | pad_mask, mask_value) if not causal and self.exact_windowsize else sim.masked_fill(pad_mask, mask_value)
-
-        if exists(mask):
-            batch = mask.shape[0]
-            assert (b % batch) == 0
-            h = b // mask.shape[0]
-            if autopad: _, mask = pad_to_multiple(mask, window_size, dim = -1, value = False)
-            mask = repeat(rearrange(look_around(rearrange(mask, '... (w n) -> (...) w n', w = windows, n = window_size), **{**look_around_kwargs, 'pad_value': False}), '... j -> ... 1 j'), 'b ... -> (b h) ...', h = h)
-            sim = sim.masked_fill(~mask, mask_value)
-            del mask
-
-        out = rearrange(einsum('b h i j, b h j e -> b h i e', self.dropout(sim.softmax(dim = -1)), bv), 'b w n d -> b (w n) d')
-        if autopad: out = out[:, :orig_seq_len, :]
-        out, *_ = unpack(out, packed_shape, '* n d')
-        return out
-
-class SinusoidalEmbeddings(nn.Module):
-    def __init__(self, dim, scale_base = None, use_xpos = False, theta = 10000):
-        super().__init__()
-        inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
-        self.register_buffer('inv_freq', inv_freq)
-        self.use_xpos = use_xpos
-        self.scale_base = scale_base
-        assert not (use_xpos and not exists(scale_base))
-        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
-        self.register_buffer('scale', scale, persistent = False)
-
-    def forward(self, x):
-        seq_len, device = x.shape[-2], x.device
-        t = torch.arange(seq_len, device = x.device).type_as(self.inv_freq)
-        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
-        freqs = torch.cat((freqs, freqs), dim = -1)
-
-        if not self.use_xpos: return freqs, torch.ones(1, device = device)
-
-        power = (t - (seq_len // 2)) / self.scale_base
-        scale = self.scale ** rearrange(power, 'n -> n 1')
-        return freqs, torch.cat((scale, scale), dim = -1)
-
-def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False):
-    try:
-        data, sample_rate = sf.read(full_path, always_2d=True)
-    except Exception as e:
-        print(f"{full_path}: {e}")
-
-        if return_empty_on_exception: return [], sample_rate or target_sr or 48000
-        else: raise
-
-    data = data[:, 0] if len(data.shape) > 1 else data
-    assert len(data) > 2
-
-    max_mag = (-np.iinfo(data.dtype).min if np.issubdtype(data.dtype, np.integer) else max(np.amax(data), -np.amin(data)))
-    data = torch.FloatTensor(data.astype(np.float32)) / ((2**31) + 1 if max_mag > (2**15) else ((2**15) + 1 if max_mag > 1.01 else 1.0))
-
-    if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception: return [], sample_rate or target_sr or 48000
-
-    if target_sr is not None and sample_rate != target_sr:
-        data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sample_rate, target_sr=target_sr))
-        sample_rate = target_sr
-
-    return data, sample_rate
-
-def dynamic_range_compression(x, C=1, clip_val=1e-5):
-    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
-
-def dynamic_range_decompression(x, C=1):
-    return np.exp(x) / C
-
-def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
-    return torch.log(torch.clamp(x, min=clip_val) * C)
-
-def dynamic_range_decompression_torch(x, C=1):
-    return torch.exp(x) / C
-
-class STFT:
-    def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
-        self.target_sr = sr
-        self.n_mels = n_mels
-        self.n_fft = n_fft
-        self.win_size = win_size
-        self.hop_length = hop_length
-        self.fmin = fmin
-        self.fmax = fmax
-        self.clip_val = clip_val
-        self.mel_basis = {}
-        self.hann_window = {}
-
-    def get_mel(self, y, keyshift=0, speed=1, center=False, train=False):
-        n_fft = self.n_fft
-        win_size = self.win_size
-        hop_length = self.hop_length
-        fmax = self.fmax
-        factor = 2 ** (keyshift / 12)
-        win_size_new = int(np.round(win_size * factor))
-        hop_length_new = int(np.round(hop_length * speed))
-        mel_basis = self.mel_basis if not train else {}
-        hann_window = self.hann_window if not train else {}
-        mel_basis_key = str(fmax) + "_" + str(y.device)
-
-        if mel_basis_key not in mel_basis:
-            from librosa.filters import mel as librosa_mel_fn
-            mel_basis[mel_basis_key] = torch.from_numpy(librosa_mel_fn(sr=self.target_sr, n_fft=n_fft, n_mels=self.n_mels, fmin=self.fmin, fmax=fmax)).float().to(y.device)
-
-        keyshift_key = str(keyshift) + "_" + str(y.device)
-        if keyshift_key not in hann_window: hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device)
-
-        pad_left = (win_size_new - hop_length_new) // 2
-        pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left)
-        spec = torch.stft(torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode="reflect" if pad_right < y.size(-1) else "constant").squeeze(1), int(np.round(n_fft * factor)), hop_length=hop_length_new, win_length=win_size_new, window=hann_window[keyshift_key], center=center, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
-        spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + (1e-9))
-
-        if keyshift != 0:
-            size = n_fft // 2 + 1
-            resize = spec.size(1)
-            spec = (F.pad(spec, (0, 0, 0, size - resize)) if resize < size else spec[:, :size, :]) * win_size / win_size_new
-
-        return dynamic_range_compression_torch(torch.matmul(mel_basis[mel_basis_key], spec), clip_val=self.clip_val)
-
-    def __call__(self, audiopath):
-        audio, _ = load_wav_to_torch(audiopath, target_sr=self.target_sr)
-        return self.get_mel(audio.unsqueeze(0)).squeeze(0)
-
-stft = STFT()
-
-def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device=None):
-    b, h, *_ = data.shape
-
-    data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.0
-    ratio = projection_matrix.shape[0] ** -0.5
-
-    data_dash = torch.einsum("...id,...jd->...ij", (data_normalizer * data), repeat(projection_matrix, "j d -> b h j d", b=b, h=h).type_as(data))
-    diag_data = ((torch.sum(data**2, dim=-1) / 2.0) * (data_normalizer**2)).unsqueeze(dim=-1)
-
-    return (ratio * (torch.exp(data_dash - diag_data - torch.max(data_dash, dim=-1, keepdim=True).values) + eps) if is_query else ratio * (torch.exp(data_dash - diag_data + eps))).type_as(data)
-
-def orthogonal_matrix_chunk(cols, qr_uniform_q=False, device=None):
-    unstructured_block = torch.randn((cols, cols), device=device)
-
-    q, r = torch.linalg.qr(unstructured_block.cpu(), mode="reduced")
-    q, r = map(lambda t: t.to(device), (q, r))
-
-    if qr_uniform_q:
-        d = torch.diag(r, 0)
-        q *= d.sign()
-
-    return q.t()
-
-def empty(tensor):
-    return tensor.numel() == 0
-
-def cast_tuple(val):
-    return (val,) if not isinstance(val, tuple) else val
-
-class PCmer(nn.Module):
-    def __init__(self, num_layers, num_heads, dim_model, dim_keys, dim_values, residual_dropout, attention_dropout):
-        super().__init__()
-        self.num_layers = num_layers
-        self.num_heads = num_heads
-        self.dim_model = dim_model
-        self.dim_values = dim_values
-        self.dim_keys = dim_keys
-        self.residual_dropout = residual_dropout
-        self.attention_dropout = attention_dropout
-        self._layers = nn.ModuleList([_EncoderLayer(self) for _ in range(num_layers)])
-
-    def forward(self, phone, mask=None):
-        for layer in self._layers:
-            phone = layer(phone, mask)
-
-        return phone
-
-class _EncoderLayer(nn.Module):
-    def __init__(self, parent: PCmer):
-        super().__init__()
-        self.conformer = ConformerConvModule(parent.dim_model)
-        self.norm = nn.LayerNorm(parent.dim_model)
-        self.dropout = nn.Dropout(parent.residual_dropout)
-        self.attn = SelfAttention(dim=parent.dim_model, heads=parent.num_heads, causal=False)
-
-    def forward(self, phone, mask=None):
-        phone = phone + (self.attn(self.norm(phone), mask=mask))
-        return phone + (self.conformer(phone))
-
-def calc_same_padding(kernel_size):
-    pad = kernel_size // 2
-    return (pad, pad - (kernel_size + 1) % 2)
-
-class Swish(nn.Module):
-    def forward(self, x):
-        return x * x.sigmoid()
-
-class Transpose(nn.Module):
-    def __init__(self, dims):
-        super().__init__()
-        assert len(dims) == 2, "dims == 2"
-
-        self.dims = dims
-
-    def forward(self, x):
-        return x.transpose(*self.dims)
-
-class GLU(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, x):
-        out, gate = x.chunk(2, dim=self.dim)
-        return out * gate.sigmoid()
-
-class DepthWiseConv1d(nn.Module):
-    def __init__(self, chan_in, chan_out, kernel_size, padding):
-        super().__init__()
-        self.padding = padding
-        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)
-
-    def forward(self, x):
-        return self.conv(F.pad(x, self.padding))
-
-class ConformerConvModule(nn.Module):
-    def __init__(self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.0):
-        super().__init__()
-        inner_dim = dim * expansion_factor
-        self.net = nn.Sequential(nn.LayerNorm(dim), Transpose((1, 2)), nn.Conv1d(dim, inner_dim * 2, 1), GLU(dim=1), DepthWiseConv1d(inner_dim, inner_dim, kernel_size=kernel_size, padding=(calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0))), Swish(), nn.Conv1d(inner_dim, dim, 1), Transpose((1, 2)), nn.Dropout(dropout))
-
-    def forward(self, x):
-        return self.net(x)
-
-def linear_attention(q, k, v):
-    return torch.einsum("...ed,...nd->...ne", k, q) if v is None else torch.einsum("...de,...nd,...n->...ne", torch.einsum("...nd,...ne->...de", k, v), q, 1.0 / (torch.einsum("...nd,...d->...n", q, k.sum(dim=-2).type_as(q)) + 1e-8))
-
-def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling=0, qr_uniform_q=False, device=None):
-    nb_full_blocks = int(nb_rows / nb_columns)
-    block_list = []
-
-    for _ in range(nb_full_blocks):
-        block_list.append(orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device))
-
-    remaining_rows = nb_rows - nb_full_blocks * nb_columns
-    if remaining_rows > 0: block_list.append(orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device)[:remaining_rows])
-
-    if scaling == 0: multiplier = torch.randn((nb_rows, nb_columns), device=device).norm(dim=1)
-    elif scaling == 1: multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device=device)
-    else: raise ValueError(f"{scaling} != 0, 1")
-
-    return torch.diag(multiplier) @ torch.cat(block_list)
-
-class FastAttention(nn.Module):
-    def __init__(self, dim_heads, nb_features=None, ortho_scaling=0, causal=False, generalized_attention=False, kernel_fn=nn.ReLU(), qr_uniform_q=False, no_projection=False):
-        super().__init__()
-        nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))
-        self.dim_heads = dim_heads
-        self.nb_features = nb_features
-        self.ortho_scaling = ortho_scaling
-        self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows=self.nb_features, nb_columns=dim_heads, scaling=ortho_scaling, qr_uniform_q=qr_uniform_q)
-        projection_matrix = self.create_projection()
-        self.register_buffer("projection_matrix", projection_matrix)
-        self.generalized_attention = generalized_attention
-        self.kernel_fn = kernel_fn
-        self.no_projection = no_projection
-        self.causal = causal
-
-    @torch.no_grad()
-    def redraw_projection_matrix(self):
-        projections = self.create_projection()
-        self.projection_matrix.copy_(projections)
-
-        del projections
-
-    def forward(self, q, k, v):
-        if self.no_projection: q, k = q.softmax(dim=-1), (torch.exp(k) if self.causal else k.softmax(dim=-2))
-        else:
-            create_kernel = partial(softmax_kernel, projection_matrix=self.projection_matrix, device=q.device)
-            q, k = create_kernel(q, is_query=True), create_kernel(k, is_query=False)
-
-        attn_fn = linear_attention if not self.causal else self.causal_linear_fn
-        return attn_fn(q, k, None) if v is None else attn_fn(q, k, v)
-
-class SelfAttention(nn.Module):
-    def __init__(self, dim, causal=False, heads=8, dim_head=64, local_heads=0, local_window_size=256, nb_features=None, feature_redraw_interval=1000, generalized_attention=False, kernel_fn=nn.ReLU(), qr_uniform_q=False, dropout=0.0, no_projection=False):
-        super().__init__()
-        assert dim % heads == 0
-        dim_head = default(dim_head, dim // heads)
-        inner_dim = dim_head * heads
-        self.fast_attention = FastAttention(dim_head, nb_features, causal=causal, generalized_attention=generalized_attention, kernel_fn=kernel_fn, qr_uniform_q=qr_uniform_q, no_projection=no_projection)
-        self.heads = heads
-        self.global_heads = heads - local_heads
-        self.local_attn = (LocalAttention(window_size=local_window_size, causal=causal, autopad=True, dropout=dropout, look_forward=int(not causal), rel_pos_emb_config=(dim_head, local_heads)) if local_heads > 0 else None)
-        self.to_q = nn.Linear(dim, inner_dim)
-        self.to_k = nn.Linear(dim, inner_dim)
-        self.to_v = nn.Linear(dim, inner_dim)
-        self.to_out = nn.Linear(inner_dim, dim)
-        self.dropout = nn.Dropout(dropout)
-
-    @torch.no_grad()
-    def redraw_projection_matrix(self):
-        self.fast_attention.redraw_projection_matrix()
-
-    def forward(self, x, context=None, mask=None, context_mask=None, name=None, inference=False, **kwargs):
-        _, _, _, h, gh = *x.shape, self.heads, self.global_heads
-        cross_attend = exists(context)
-
-        context = default(context, x)
-        context_mask = default(context_mask, mask) if not cross_attend else context_mask
-
-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (self.to_q(x), self.to_k(context), self.to_v(context)))
-        (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v))
-
-        attn_outs = []
-
-        if not empty(q):
-            if exists(context_mask): v.masked_fill_(~context_mask[:, None, :, None], 0.0)
-
-            if cross_attend: pass
-            else: out = self.fast_attention(q, k, v)
-
-            attn_outs.append(out)
-
-        if not empty(lq):
-            assert (not cross_attend), "not cross_attend"
-
-            out = self.local_attn(lq, lk, lv, input_mask=mask)
-            attn_outs.append(out)
-
-        return self.dropout(self.to_out(rearrange(torch.cat(attn_outs, dim=1), "b h n d -> b n (h d)")))
-
-def l2_regularization(model, l2_alpha):
-    l2_loss = []
-
-    for module in model.modules():
-        if type(module) is nn.Conv2d: l2_loss.append((module.weight**2).sum() / 2.0)
-
-    return l2_alpha * sum(l2_loss)
-
-class _FCPE(nn.Module):
-    def __init__(self, input_channel=128, out_dims=360, n_layers=12, n_chans=512, use_siren=False, use_full=False, loss_mse_scale=10, loss_l2_regularization=False, loss_l2_regularization_scale=1, loss_grad1_mse=False, loss_grad1_mse_scale=1, f0_max=1975.5, f0_min=32.70, confidence=False, threshold=0.05, use_input_conv=True):
-        super().__init__()
-        if use_siren: raise ValueError("Siren not support")
-        if use_full: raise ValueError("Model full not support")
-        self.loss_mse_scale = loss_mse_scale if (loss_mse_scale is not None) else 10
-        self.loss_l2_regularization = (loss_l2_regularization if (loss_l2_regularization is not None) else False)
-        self.loss_l2_regularization_scale = (loss_l2_regularization_scale if (loss_l2_regularization_scale is not None) else 1)
-        self.loss_grad1_mse = loss_grad1_mse if (loss_grad1_mse is not None) else False
-        self.loss_grad1_mse_scale = (loss_grad1_mse_scale if (loss_grad1_mse_scale is not None) else 1)
-        self.f0_max = f0_max if (f0_max is not None) else 1975.5
-        self.f0_min = f0_min if (f0_min is not None) else 32.70
-        self.confidence = confidence if (confidence is not None) else False
-        self.threshold = threshold if (threshold is not None) else 0.05
-        self.use_input_conv = use_input_conv if (use_input_conv is not None) else True
-        self.cent_table_b = torch.Tensor(np.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0], out_dims))
-        self.register_buffer("cent_table", self.cent_table_b)
-        _leaky = nn.LeakyReLU()
-        self.stack = nn.Sequential(nn.Conv1d(input_channel, n_chans, 3, 1, 1), nn.GroupNorm(4, n_chans), _leaky, nn.Conv1d(n_chans, n_chans, 3, 1, 1))
-        self.decoder = PCmer(num_layers=n_layers, num_heads=8, dim_model=n_chans, dim_keys=n_chans, dim_values=n_chans, residual_dropout=0.1, attention_dropout=0.1)
-        self.norm = nn.LayerNorm(n_chans)
-        self.n_out = out_dims
-        self.dense_out = weight_norm(nn.Linear(n_chans, self.n_out))
-
-    def forward(self, mel, infer=True, gt_f0=None, return_hz_f0=False, cdecoder="local_argmax"):
-        if cdecoder == "argmax": self.cdecoder = self.cents_decoder
-        elif cdecoder == "local_argmax": self.cdecoder = self.cents_local_decoder
-
-        x = torch.sigmoid(self.dense_out(self.norm(self.decoder((self.stack(mel.transpose(1, 2)).transpose(1, 2) if self.use_input_conv else mel)))))
-
-        if not infer:
-            loss_all = self.loss_mse_scale * F.binary_cross_entropy(x, self.gaussian_blurred_cent(self.f0_to_cent(gt_f0)))
-            if self.loss_l2_regularization: loss_all = loss_all + l2_regularization(model=self, l2_alpha=self.loss_l2_regularization_scale)
-            x = loss_all
-
-        if infer:
-            x = self.cent_to_f0(self.cdecoder(x))
-            x = (1 + x / 700).log() if not return_hz_f0 else x
-
-        return x
-
-    def cents_decoder(self, y, mask=True):
-        B, N, _ = y.size()
-        rtn = torch.sum(self.cent_table[None, None, :].expand(B, N, -1) * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True)
-
-        if mask:
-            confident = torch.max(y, dim=-1, keepdim=True)[0]
-            confident_mask = torch.ones_like(confident)
-
-            confident_mask[confident <= self.threshold] = float("-INF")
-            rtn = rtn * confident_mask
-
-        return (rtn, confident) if self.confidence else rtn
-
-    def cents_local_decoder(self, y, mask=True):
-        B, N, _ = y.size()
-
-        confident, max_index = torch.max(y, dim=-1, keepdim=True)
-        local_argmax_index = torch.clamp(torch.arange(0, 9).to(max_index.device) + (max_index - 4), 0, self.n_out - 1)
-
-        y_l = torch.gather(y, -1, local_argmax_index)
-        rtn = torch.sum(torch.gather(self.cent_table[None, None, :].expand(B, N, -1), -1, local_argmax_index) * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True)
-
-        if mask:
-            confident_mask = torch.ones_like(confident)
-            confident_mask[confident <= self.threshold] = float("-INF")
-
-            rtn = rtn * confident_mask
-
-        return (rtn, confident) if self.confidence else rtn
-
-    def cent_to_f0(self, cent):
-        return 10.0 * 2 ** (cent / 1200.0)
-
-    def f0_to_cent(self, f0):
-        return 1200.0 * torch.log2(f0 / 10.0)
-
-    def gaussian_blurred_cent(self, cents):
-        B, N, _ = cents.size()
-        return torch.exp(-torch.square(self.cent_table[None, None, :].expand(B, N, -1) - cents) / 1250) * (cents > 0.1) & (cents < (1200.0 * np.log2(self.f0_max / 10.0))).float()
-
-class FCPEInfer:
-    def __init__(self, model_path, device=None, dtype=torch.float32, providers=None, onnx=False):
-        if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.wav2mel = Wav2Mel(device=device, dtype=dtype)
-        self.device = device
-        self.dtype = dtype
-        self.onnx = onnx
-
-        if self.onnx:
-            import onnxruntime as ort
-
-            sess_options = ort.SessionOptions()
-            sess_options.log_severity_level = 3
-
-            self.model = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
-        else:
-            ckpt = torch.load(model_path, map_location=torch.device(self.device))
-            self.args = DotDict(ckpt["config"])
-            model = _FCPE(input_channel=self.args.model.input_channel, out_dims=self.args.model.out_dims, n_layers=self.args.model.n_layers, n_chans=self.args.model.n_chans, use_siren=self.args.model.use_siren, use_full=self.args.model.use_full, loss_mse_scale=self.args.loss.loss_mse_scale, loss_l2_regularization=self.args.loss.loss_l2_regularization, loss_l2_regularization_scale=self.args.loss.loss_l2_regularization_scale, loss_grad1_mse=self.args.loss.loss_grad1_mse, loss_grad1_mse_scale=self.args.loss.loss_grad1_mse_scale, f0_max=self.args.model.f0_max, f0_min=self.args.model.f0_min, confidence=self.args.model.confidence)
-
-            model.to(self.device).to(self.dtype)
-            model.load_state_dict(ckpt["model"])
-
-            model.eval()
-            self.model = model
-
-    @torch.no_grad()
-    def __call__(self, audio, sr, threshold=0.05):
-        if not self.onnx: self.model.threshold = threshold
-        mel = self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype)
-
-        return torch.as_tensor(self.model.run(["pitchf"], {"mel": mel.detach().cpu().numpy(), "threshold": np.array(threshold, dtype=np.float32)})[0], dtype=self.dtype, device=self.device).squeeze() if self.onnx else self.model(mel=mel, infer=True, return_hz_f0=True)
-
-class Wav2Mel:
-    def __init__(self, device=None, dtype=torch.float32):
-        self.sample_rate = 16000
-        self.hop_size = 160
-        if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.device = device
-        self.dtype = dtype
-        self.stft = STFT(16000, 128, 1024, 1024, 160, 0, 8000)
-        self.resample_kernel = {}
-
-    def extract_nvstft(self, audio, keyshift=0, train=False):
-        return self.stft.get_mel(audio, keyshift=keyshift, train=train).transpose(1, 2)
-
-    def extract_mel(self, audio, sample_rate, keyshift=0, train=False):
-        audio = audio.to(self.dtype).to(self.device)
-
-        if sample_rate == self.sample_rate: audio_res = audio
-        else:
-            key_str = str(sample_rate)
-
-            if key_str not in self.resample_kernel:
-                from torchaudio.transforms import Resample
-                self.resample_kernel[key_str] = Resample(sample_rate, self.sample_rate, lowpass_filter_width=128)
-
-            self.resample_kernel[key_str] = (self.resample_kernel[key_str].to(self.dtype).to(self.device))
-            audio_res = self.resample_kernel[key_str](audio)
-
-        mel = self.extract_nvstft(audio_res, keyshift=keyshift, train=train)
-        n_frames = int(audio.shape[1] // self.hop_size) + 1
-        mel = (torch.cat((mel, mel[:, -1:, :]), 1) if n_frames > int(mel.shape[1]) else mel)
-
-        return mel[:, :n_frames, :] if n_frames < int(mel.shape[1]) else mel
-
-    def __call__(self, audio, sample_rate, keyshift=0, train=False):
-        return self.extract_mel(audio, sample_rate, keyshift=keyshift, train=train)
-
-class DotDict(dict):
-    def __getattr__(*args):
-        val = dict.get(*args)
-        return DotDict(val) if type(val) is dict else val
-
-    __setattr__ = dict.__setitem__
-    __delattr__ = dict.__delitem__
-
-class FCPE:
-    def __init__(self, model_path, hop_length=512, f0_min=50, f0_max=1100, dtype=torch.float32, device=None, sample_rate=44100, threshold=0.05, providers=None, onnx=False):
-        self.fcpe = FCPEInfer(model_path, device=device, dtype=dtype, providers=providers, onnx=onnx)
-        self.hop_length = hop_length
-        self.f0_min = f0_min
-        self.f0_max = f0_max
-        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-        self.threshold = threshold
-        self.sample_rate = sample_rate
-        self.dtype = dtype
-        self.name = "fcpe"
-
-    def repeat_expand(self, content, target_len, mode = "nearest"):
-        ndim = content.ndim
-        content = (content[None, None] if ndim == 1 else content[None] if ndim == 2 else content)
-
-        assert content.ndim == 3
-        is_np = isinstance(content, np.ndarray)
-
-        results = torch.nn.functional.interpolate(torch.from_numpy(content) if is_np else content, size=target_len, mode=mode)
-        results = results.numpy() if is_np else results
-        return results[0, 0] if ndim == 1 else results[0] if ndim == 2 else results
-
-    def post_process(self, x, sample_rate, f0, pad_to):
-        f0 = (torch.from_numpy(f0).float().to(x.device) if isinstance(f0, np.ndarray) else f0)
-        f0 = self.repeat_expand(f0, pad_to) if pad_to is not None else f0
-
-        vuv_vector = torch.zeros_like(f0)
-        vuv_vector[f0 > 0.0] = 1.0
-        vuv_vector[f0 <= 0.0] = 0.0
-
-        nzindex = torch.nonzero(f0).squeeze()
-        f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy()
-        vuv_vector = F.interpolate(vuv_vector[None, None, :], size=pad_to)[0][0]
-
-        if f0.shape[0] <= 0: return np.zeros(pad_to), vuv_vector.cpu().numpy()
-        if f0.shape[0] == 1: return np.ones(pad_to) * f0[0], vuv_vector.cpu().numpy()
-
-        return np.interp(np.arange(pad_to) * self.hop_length / sample_rate, self.hop_length / sample_rate * nzindex.cpu().numpy(), f0, left=f0[0], right=f0[-1]), vuv_vector.cpu().numpy()
-
-    def compute_f0(self, wav, p_len=None):
-        x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
-        p_len = x.shape[0] // self.hop_length if p_len is None else p_len
-
-        f0 = self.fcpe(x, sr=self.sample_rate, threshold=self.threshold)
-        f0 = f0[:] if f0.dim() == 1 else f0[0, :, 0]
-
-        if torch.all(f0 == 0): return f0.cpu().numpy() if p_len is None else np.zeros(p_len), (f0.cpu().numpy() if p_len is None else np.zeros(p_len))
-        return self.post_process(x, self.sample_rate, f0, p_len)[0]
main/library/predictors/RMVPE.py
DELETED
@@ -1,260 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
|
3 |
-
import numpy as np
|
4 |
-
import torch.nn as nn
|
5 |
-
import torch.nn.functional as F
|
6 |
-
|
7 |
-
from librosa.filters import mel
|
8 |
-
|
9 |
-
N_MELS, N_CLASS = 128, 360
|
10 |
-
|
11 |
-
class ConvBlockRes(nn.Module):
|
12 |
-
def __init__(self, in_channels, out_channels, momentum=0.01):
|
13 |
-
super(ConvBlockRes, self).__init__()
|
14 |
-
self.conv = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU(), nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU())
|
15 |
-
|
16 |
-
if in_channels != out_channels:
|
17 |
-
self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
|
18 |
-
self.is_shortcut = True
|
19 |
-
else: self.is_shortcut = False
|
20 |
-
|
21 |
-
def forward(self, x):
|
22 |
-
return self.conv(x) + self.shortcut(x) if self.is_shortcut else self.conv(x) + x
|
23 |
-
|
24 |
-
class ResEncoderBlock(nn.Module):
|
25 |
-
def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
|
26 |
-
super(ResEncoderBlock, self).__init__()
|
27 |
-
self.n_blocks = n_blocks
|
28 |
-
self.conv = nn.ModuleList()
|
29 |
-
self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
|
30 |
-
|
31 |
-
for _ in range(n_blocks - 1):
|
32 |
-
self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
|
33 |
-
|
34 |
-
self.kernel_size = kernel_size
|
35 |
-
if self.kernel_size is not None: self.pool = nn.AvgPool2d(kernel_size=kernel_size)
|
36 |
-
|
37 |
-
def forward(self, x):
|
38 |
-
for i in range(self.n_blocks):
|
39 |
-
x = self.conv[i](x)
|
40 |
-
|
41 |
-
if self.kernel_size is not None: return x, self.pool(x)
|
42 |
-
else: return x
|
43 |
-
|
44 |
-
class Encoder(nn.Module):
|
45 |
-
def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
|
46 |
-
super(Encoder, self).__init__()
|
47 |
-
self.n_encoders = n_encoders
|
48 |
-
self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
|
49 |
-
|
50 |
-
self.layers = nn.ModuleList()
|
51 |
-
self.latent_channels = []
|
52 |
-
|
53 |
-
for _ in range(self.n_encoders):
|
54 |
-
self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum))
|
55 |
-
self.latent_channels.append([out_channels, in_size])
|
56 |
-
in_channels = out_channels
|
57 |
-
out_channels *= 2
|
58 |
-
in_size //= 2
|
59 |
-
|
60 |
-
self.out_size = in_size
|
61 |
-
self.out_channel = out_channels
|
62 |
-
|
63 |
-
def forward(self, x):
|
64 |
-
concat_tensors = []
|
65 |
-
x = self.bn(x)
|
66 |
-
|
67 |
-
for i in range(self.n_encoders):
|
68 |
-
t, x = self.layers[i](x)
|
69 |
-
concat_tensors.append(t)
|
70 |
-
|
71 |
-
return x, concat_tensors
|
72 |
-
|
73 |
-
class Intermediate(nn.Module):
|
74 |
-
def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
|
75 |
-
super(Intermediate, self).__init__()
|
76 |
-
self.n_inters = n_inters
|
77 |
-
self.layers = nn.ModuleList()
|
78 |
-
self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))
|
79 |
-
|
80 |
-
for _ in range(self.n_inters - 1):
|
81 |
-
self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))
|
82 |
-
|
83 |
-
def forward(self, x):
|
84 |
-
for i in range(self.n_inters):
|
85 |
-
x = self.layers[i](x)
|
86 |
-
|
87 |
-
return x
|
88 |
-
|
89 |
-
class ResDecoderBlock(nn.Module):
|
90 |
-
def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
|
91 |
-
super(ResDecoderBlock, self).__init__()
|
92 |
-
out_padding = (0, 1) if stride == (1, 2) else (1, 1)
|
93 |
-
self.n_blocks = n_blocks
|
94 |
-
|
95 |
-
self.conv1 = nn.Sequential(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=stride, padding=(1, 1), output_padding=out_padding, bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU())
|
96 |
-
self.conv2 = nn.ModuleList()
|
97 |
-
self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
|
98 |
-
|
99 |
-
for _ in range(n_blocks - 1):
|
100 |
-
self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
|
101 |
-
|
102 |
-
def forward(self, x, concat_tensor):
|
103 |
-
x = torch.cat((self.conv1(x), concat_tensor), dim=1)
|
104 |
-
|
105 |
-
for i in range(self.n_blocks):
|
106 |
-
x = self.conv2[i](x)
|
107 |
-
|
108 |
-
return x
|
109 |
-
|
110 |
-
class Decoder(nn.Module):
|
111 |
-
def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
|
112 |
-
super(Decoder, self).__init__()
|
113 |
-
self.layers = nn.ModuleList()
|
114 |
-
self.n_decoders = n_decoders
|
115 |
-
|
116 |
-
for _ in range(self.n_decoders):
|
117 |
-
out_channels = in_channels // 2
|
118 |
-
self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum))
|
119 |
-
in_channels = out_channels
|
120 |
-
|
121 |
-
def forward(self, x, concat_tensors):
|
122 |
-
for i in range(self.n_decoders):
|
123 |
-
x = self.layers[i](x, concat_tensors[-1 - i])
|
124 |
-
|
125 |
-
return x
|
126 |
-
|
127 |
-
class DeepUnet(nn.Module):
|
128 |
-
def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
|
129 |
-
super(DeepUnet, self).__init__()
|
130 |
-
self.encoder = Encoder(in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels)
|
131 |
-
self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks)
|
132 |
-
self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
|
133 |
-
|
134 |
-
def forward(self, x):
|
135 |
-
x, concat_tensors = self.encoder(x)
|
136 |
-
return self.decoder(self.intermediate(x), concat_tensors)
|
137 |
-
|
138 |
-
class E2E(nn.Module):
|
139 |
-
def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
|
140 |
-
super(E2E, self).__init__()
|
141 |
-
self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
|
142 |
-
self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
|
143 |
-
self.fc = nn.Sequential(BiGRU(3 * 128, 256, n_gru), nn.Linear(512, N_CLASS), nn.Dropout(0.25), nn.Sigmoid()) if n_gru else nn.Sequential(nn.Linear(3 * N_MELS, N_CLASS), nn.Dropout(0.25), nn.Sigmoid())
|
144 |
-
|
145 |
-
def forward(self, mel):
|
146 |
-
return self.fc(self.cnn(self.unet(mel.transpose(-1, -2).unsqueeze(1))).transpose(1, 2).flatten(-2))
|
147 |
-
|
148 |
-
class MelSpectrogram(torch.nn.Module):
|
149 |
-
def __init__(self, n_mel_channels, sample_rate, win_length, hop_length, n_fft=None, mel_fmin=0, mel_fmax=None, clamp=1e-5):
|
150 |
-
super().__init__()
|
151 |
-
n_fft = win_length if n_fft is None else n_fft
|
152 |
        self.hann_window = {}
        mel_basis = mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax, htk=True)
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)
        self.n_fft = win_length if n_fft is None else n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.sample_rate = sample_rate
        self.n_mel_channels = n_mel_channels
        self.clamp = clamp

    def forward(self, audio, keyshift=0, speed=1, center=True):
        # Shift the analysis window by `keyshift` semitones and scale the hop by `speed`.
        factor = 2 ** (keyshift / 12)
        win_length_new = int(np.round(self.win_length * factor))
        keyshift_key = str(keyshift) + "_" + str(audio.device)
        if keyshift_key not in self.hann_window: self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)

        fft = torch.stft(audio, n_fft=int(np.round(self.n_fft * factor)), hop_length=int(np.round(self.hop_length * speed)), win_length=win_length_new, window=self.hann_window[keyshift_key], center=center, return_complex=True)
        magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))

        if keyshift != 0:
            size = self.n_fft // 2 + 1
            resize = magnitude.size(1)
            if resize < size: magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
            magnitude = magnitude[:, :size, :] * self.win_length / win_length_new

        mel_output = torch.matmul(self.mel_basis, magnitude)
        return torch.log(torch.clamp(mel_output, min=self.clamp))

class RMVPE:
    def __init__(self, model_path, device=None, providers=None, onnx=False):
        self.resample_kernel = {}
        self.onnx = onnx

        if self.onnx:
            import onnxruntime as ort

            sess_options = ort.SessionOptions()
            sess_options.log_severity_level = 3

            self.model = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
        else:
            model = E2E(4, 1, (2, 2))
            ckpt = torch.load(model_path, map_location="cpu")
            model.load_state_dict(ckpt)
            model.eval()
            self.model = model.to(device)

        self.resample_kernel = {}
        self.device = device
        self.mel_extractor = MelSpectrogram(N_MELS, 16000, 1024, 160, None, 30, 8000).to(device)
        cents_mapping = 20 * np.arange(N_CLASS) + 1997.3794084376191
        self.cents_mapping = np.pad(cents_mapping, (4, 4))

    def mel2hidden(self, mel):
        with torch.no_grad():
            n_frames = mel.shape[-1]
            mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect")

            hidden = self.model.run([self.model.get_outputs()[0].name], input_feed={self.model.get_inputs()[0].name: mel.cpu().numpy()})[0] if self.onnx else self.model(mel.float())
            return hidden[:, :n_frames]

    def decode(self, hidden, thred=0.03):
        # Cents -> Hz; bins that decode to exactly 10 Hz mark unvoiced frames.
        f0 = 10 * (2 ** (self.to_local_average_cents(hidden, thred=thred) / 1200))
        f0[f0 == 10] = 0

        return f0

    def infer_from_audio(self, audio, thred=0.03):
        hidden = self.mel2hidden(self.mel_extractor(torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True))

        return self.decode(hidden.squeeze(0).cpu().numpy() if not self.onnx else hidden[0], thred=thred)

    def infer_from_audio_with_pitch(self, audio, thred=0.03, f0_min=50, f0_max=1100):
        hidden = self.mel2hidden(self.mel_extractor(torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True))

        f0 = self.decode(hidden.squeeze(0).cpu().numpy() if not self.onnx else hidden[0], thred=thred)
        f0[(f0 < f0_min) | (f0 > f0_max)] = 0

        return f0

    def to_local_average_cents(self, salience, thred=0.05):
        # Refine the argmax bin with a salience-weighted average over a 9-bin window around it.
        center = np.argmax(salience, axis=1)
        salience = np.pad(salience, ((0, 0), (4, 4)))

        center += 4
        todo_salience, todo_cents_mapping = [], []

        starts = center - 4
        ends = center + 5

        for idx in range(salience.shape[0]):
            todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
            todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])

        todo_salience = np.array(todo_salience)

        divided = np.sum(todo_salience * np.array(todo_cents_mapping), 1) / np.sum(todo_salience, 1)
        divided[np.max(salience, axis=1) <= thred] = 0

        return divided

class BiGRU(nn.Module):
    def __init__(self, input_features, hidden_features, num_layers):
        super(BiGRU, self).__init__()
        self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)

    def forward(self, x):
        return self.gru(x)[0]
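For context, a minimal usage sketch of the predictor class above (the checkpoint path is an assumption; the class expects 16 kHz mono float audio and emits one F0 estimate per 160-sample hop):

import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
rmvpe = RMVPE("assets/models/predictors/rmvpe.pt", device=device)

audio = np.zeros(16000, dtype=np.float32)  # placeholder: one second of 16 kHz mono audio
f0 = rmvpe.infer_from_audio_with_pitch(audio, thred=0.03, f0_min=50, f0_max=1100)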
main/library/predictors/WORLD.py
DELETED
@@ -1,90 +0,0 @@
import os
import torch
import shutil
import ctypes
import platform

import numpy as np
import tempfile as tf

class DioOption(ctypes.Structure):
    _fields_ = [("F0Floor", ctypes.c_double), ("F0Ceil", ctypes.c_double), ("ChannelsInOctave", ctypes.c_double), ("FramePeriod", ctypes.c_double), ("Speed", ctypes.c_int), ("AllowedRange", ctypes.c_double)]

class HarvestOption(ctypes.Structure):
    _fields_ = [("F0Floor", ctypes.c_double), ("F0Ceil", ctypes.c_double), ("FramePeriod", ctypes.c_double)]

class PYWORLD:
    def __init__(self):
        # world.pth bundles the compiled WORLD libraries; pick the one matching this platform.
        model = torch.load(os.path.join("assets", "models", "predictors", "world.pth"), map_location="cpu")

        model_type, suffix = (("world_64" if platform.architecture()[0] == "64bit" else "world_86"), ".dll") if platform.system() == "Windows" else ("world_linux", ".so")

        self.temp_folder = os.path.join("assets", "models", "predictors", "temp")
        os.makedirs(self.temp_folder, exist_ok=True)

        # Write the library bytes to a temp file so ctypes can load them.
        with tf.NamedTemporaryFile(delete=False, suffix=suffix, dir=self.temp_folder) as temp_file:
            temp_file.write(model[model_type])
            temp_path = temp_file.name

        self.world_dll = ctypes.CDLL(temp_path)

    def harvest(self, x, fs, f0_floor=50, f0_ceil=1100, frame_period=10):
        self.world_dll.Harvest.argtypes = [ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.c_int, ctypes.POINTER(HarvestOption), ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double)]
        self.world_dll.Harvest.restype = None

        self.world_dll.InitializeHarvestOption.argtypes = [ctypes.POINTER(HarvestOption)]
        self.world_dll.InitializeHarvestOption.restype = None

        self.world_dll.GetSamplesForHarvest.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_double]
        self.world_dll.GetSamplesForHarvest.restype = ctypes.c_int

        option = HarvestOption()
        self.world_dll.InitializeHarvestOption(ctypes.byref(option))

        option.F0Floor = f0_floor
        option.F0Ceil = f0_ceil
        option.FramePeriod = frame_period

        f0_length = self.world_dll.GetSamplesForHarvest(fs, len(x), option.FramePeriod)
        f0 = (ctypes.c_double * f0_length)()
        tpos = (ctypes.c_double * f0_length)()

        self.world_dll.Harvest((ctypes.c_double * len(x))(*x), len(x), fs, ctypes.byref(option), tpos, f0)
        return np.array(f0, dtype=np.float32), np.array(tpos, dtype=np.float32)

    def dio(self, x, fs, f0_floor=50, f0_ceil=1100, channels_in_octave=2, frame_period=10, speed=1, allowed_range=0.1):
        self.world_dll.Dio.argtypes = [ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.c_int, ctypes.POINTER(DioOption), ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double)]
        self.world_dll.Dio.restype = None

        self.world_dll.InitializeDioOption.argtypes = [ctypes.POINTER(DioOption)]
        self.world_dll.InitializeDioOption.restype = None

        self.world_dll.GetSamplesForDIO.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_double]
        self.world_dll.GetSamplesForDIO.restype = ctypes.c_int

        option = DioOption()
        self.world_dll.InitializeDioOption(ctypes.byref(option))

        option.F0Floor = f0_floor
        option.F0Ceil = f0_ceil
        option.ChannelsInOctave = channels_in_octave
        option.FramePeriod = frame_period
        option.Speed = speed
        option.AllowedRange = allowed_range

        f0_length = self.world_dll.GetSamplesForDIO(fs, len(x), option.FramePeriod)
        f0 = (ctypes.c_double * f0_length)()
        tpos = (ctypes.c_double * f0_length)()

        self.world_dll.Dio((ctypes.c_double * len(x))(*x), len(x), fs, ctypes.byref(option), tpos, f0)
        return np.array(f0, dtype=np.float32), np.array(tpos, dtype=np.float32)

    def stonemask(self, x, fs, tpos, f0):
        self.world_dll.StoneMask.argtypes = [ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.POINTER(ctypes.c_double)]
        self.world_dll.StoneMask.restype = None

        out_f0 = (ctypes.c_double * len(f0))()
        self.world_dll.StoneMask((ctypes.c_double * len(x))(*x), len(x), fs, (ctypes.c_double * len(tpos))(*tpos), (ctypes.c_double * len(f0))(*f0), len(f0), out_f0)

        if os.path.exists(self.temp_folder): shutil.rmtree(self.temp_folder, ignore_errors=True)
        return np.array(out_f0, dtype=np.float32)
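A minimal sketch of the typical pipeline this wrapper exposes, assuming the bundled world.pth payload is present; Harvest produces a coarse F0 contour that StoneMask then refines per frame:

import numpy as np

world = PYWORLD()
x = np.zeros(16000, dtype=np.float64)  # placeholder: one second of 16 kHz mono audio
f0, tpos = world.harvest(x, 16000, f0_floor=50, f0_ceil=1100, frame_period=10)
f0 = world.stonemask(x, 16000, tpos, f0)  # refined F0, one value per 10 ms frame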
main/library/utils.py
DELETED
@@ -1,100 +0,0 @@
import os
import re
import sys
import codecs
import librosa
import logging

import numpy as np
import soundfile as sf

from pydub import AudioSegment, silence

sys.path.append(os.getcwd())

from main.tools import huggingface
from main.configs.config import Config

for l in ["httpx", "httpcore"]:
    logging.getLogger(l).setLevel(logging.ERROR)

translations = Config().translations


def check_predictors(method):
    def download(predictors):
        if not os.path.exists(os.path.join("assets", "models", "predictors", predictors)): huggingface.HF_download_file(codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/cerqvpgbef/", "rot13") + predictors, os.path.join("assets", "models", "predictors", predictors))

    model_dict = {
        **dict.fromkeys(["rmvpe", "rmvpe-legacy"], "rmvpe.pt"),
        **dict.fromkeys(["rmvpe-onnx", "rmvpe-legacy-onnx"], "rmvpe.onnx"),
        **dict.fromkeys(["fcpe", "fcpe-legacy"], "fcpe.pt"),
        **dict.fromkeys(["fcpe-onnx", "fcpe-legacy-onnx"], "fcpe.onnx"),
        **dict.fromkeys(["crepe-full", "mangio-crepe-full"], "crepe_full.pth"),
        **dict.fromkeys(["crepe-full-onnx", "mangio-crepe-full-onnx"], "crepe_full.onnx"),
        **dict.fromkeys(["crepe-large", "mangio-crepe-large"], "crepe_large.pth"),
        **dict.fromkeys(["crepe-large-onnx", "mangio-crepe-large-onnx"], "crepe_large.onnx"),
        **dict.fromkeys(["crepe-medium", "mangio-crepe-medium"], "crepe_medium.pth"),
        **dict.fromkeys(["crepe-medium-onnx", "mangio-crepe-medium-onnx"], "crepe_medium.onnx"),
        **dict.fromkeys(["crepe-small", "mangio-crepe-small"], "crepe_small.pth"),
        **dict.fromkeys(["crepe-small-onnx", "mangio-crepe-small-onnx"], "crepe_small.onnx"),
        **dict.fromkeys(["crepe-tiny", "mangio-crepe-tiny"], "crepe_tiny.pth"),
        **dict.fromkeys(["crepe-tiny-onnx", "mangio-crepe-tiny-onnx"], "crepe_tiny.onnx"),
        **dict.fromkeys(["harvest", "dio"], "world.pth"),
    }

    if "hybrid" in method:
        # A hybrid method looks like "hybrid[rmvpe+fcpe]"; download every sub-method's model.
        methods_str = re.search(r"hybrid\[(.+)\]", method)
        methods = [m.strip() for m in methods_str.group(1).split("+")] if methods_str else []
        for method in methods:
            if method in model_dict: download(model_dict[method])
    elif method in model_dict: download(model_dict[method])

def check_embedders(hubert):
    if hubert in ["contentvec_base", "hubert_base", "japanese_hubert_base", "korean_hubert_base", "chinese_hubert_base", "Hidden_Rabbit_last", "portuguese_hubert_base"]:
        model_path = os.path.join("assets", "models", "embedders", hubert + '.pt')
        if not os.path.exists(model_path): huggingface.HF_download_file(codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/rzorqqref/", "rot13") + f"{hubert}.pt", model_path)

def load_audio(file):
    try:
        file = file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
        if not os.path.isfile(file): raise FileNotFoundError(translations["not_found"].format(name=file))

        audio, sr = sf.read(file)

        if len(audio.shape) > 1: audio = librosa.to_mono(audio.T)
        if sr != 16000: audio = librosa.resample(audio, orig_sr=sr, target_sr=16000, res_type="soxr_vhq")
    except Exception as e:
        raise RuntimeError(f"{translations['errors_loading_audio']}: {e}")
    return audio.flatten()

def process_audio(logger, file_path, output_path):
    try:
        song = pydub_convert(AudioSegment.from_file(file_path))
        cut_files, time_stamps = [], []

        for i, (start_i, end_i) in enumerate(silence.detect_nonsilent(song, min_silence_len=750, silence_thresh=-70)):
            chunk = song[start_i:end_i]
            if len(chunk) > 10:
                chunk_file_path = os.path.join(output_path, f"chunk{i}.wav")
                if os.path.exists(chunk_file_path): os.remove(chunk_file_path)
                chunk.export(chunk_file_path, format="wav")
                cut_files.append(chunk_file_path)
                time_stamps.append((start_i, end_i))
            else: logger.debug(translations["skip_file"].format(i=i, chunk=len(chunk)))
        logger.info(f"{translations['split_total']}: {len(cut_files)}")
        return cut_files, time_stamps
    except Exception as e:
        raise RuntimeError(f"{translations['process_audio_error']}: {e}")

def merge_audio(files_list, time_stamps, original_file_path, output_path, format):
    try:
        def extract_number(filename):
            match = re.search(r'_(\d+)', filename)
            return int(match.group(1)) if match else 0

        total_duration = len(AudioSegment.from_file(original_file_path))
        combined = AudioSegment.empty()
        current_position = 0

        for file, (start_i, end_i) in zip(sorted(files_list, key=extract_number), time_stamps):
            if start_i > current_position: combined += AudioSegment.silent(duration=start_i - current_position)
            combined += AudioSegment.from_file(file)
            current_position = end_i

        if current_position < total_duration: combined += AudioSegment.silent(duration=total_duration - current_position)
        combined.export(output_path, format=format)
        return output_path
    except Exception as e:
        raise RuntimeError(f"{translations['merge_error']}: {e}")

def pydub_convert(audio):
    samples = np.frombuffer(audio.raw_data, dtype=np.int16)

    if samples.dtype != np.int16: samples = (samples * 32767).astype(np.int16)

    return AudioSegment(samples.tobytes(), frame_rate=audio.frame_rate, sample_width=samples.dtype.itemsize, channels=audio.channels)
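process_audio and merge_audio are designed to round-trip: chunks cut at silences can be processed independently and then re-assembled at their recorded time stamps, with silence filling the gaps. A hedged sketch (file names are placeholders):

import logging
import os

logger = logging.getLogger(__name__)
os.makedirs("chunks", exist_ok=True)  # process_audio writes chunk files into this directory

cut_files, time_stamps = process_audio(logger, "input.wav", "chunks")
merge_audio(cut_files, time_stamps, "input.wav", "output.wav", format="wav")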
main/library/uvr5_separator/common_separator.py
DELETED
@@ -1,250 +0,0 @@
import os
import gc
import sys
import torch
import librosa

import numpy as np
import soundfile as sf

from pydub import AudioSegment

sys.path.append(os.getcwd())

from .spec_utils import normalize
from main.configs.config import Config

translations = Config().translations

class CommonSeparator:
    ALL_STEMS = "All Stems"
    VOCAL_STEM = "Vocals"
    INST_STEM = "Instrumental"
    OTHER_STEM = "Other"
    BASS_STEM = "Bass"
    DRUM_STEM = "Drums"
    GUITAR_STEM = "Guitar"
    PIANO_STEM = "Piano"
    SYNTH_STEM = "Synthesizer"
    STRINGS_STEM = "Strings"
    WOODWINDS_STEM = "Woodwinds"
    BRASS_STEM = "Brass"
    WIND_INST_STEM = "Wind Inst"
    NO_OTHER_STEM = "No Other"
    NO_BASS_STEM = "No Bass"
    NO_DRUM_STEM = "No Drums"
    NO_GUITAR_STEM = "No Guitar"
    NO_PIANO_STEM = "No Piano"
    NO_SYNTH_STEM = "No Synthesizer"
    NO_STRINGS_STEM = "No Strings"
    NO_WOODWINDS_STEM = "No Woodwinds"
    NO_WIND_INST_STEM = "No Wind Inst"
    NO_BRASS_STEM = "No Brass"
    PRIMARY_STEM = "Primary Stem"
    SECONDARY_STEM = "Secondary Stem"
    LEAD_VOCAL_STEM = "lead_only"
    BV_VOCAL_STEM = "backing_only"
    LEAD_VOCAL_STEM_I = "with_lead_vocals"
    BV_VOCAL_STEM_I = "with_backing_vocals"
    LEAD_VOCAL_STEM_LABEL = "Lead Vocals"
    BV_VOCAL_STEM_LABEL = "Backing Vocals"
    NO_STEM = "No "
    STEM_PAIR_MAPPER = {VOCAL_STEM: INST_STEM, INST_STEM: VOCAL_STEM, LEAD_VOCAL_STEM: BV_VOCAL_STEM, BV_VOCAL_STEM: LEAD_VOCAL_STEM, PRIMARY_STEM: SECONDARY_STEM}
    NON_ACCOM_STEMS = (VOCAL_STEM, OTHER_STEM, BASS_STEM, DRUM_STEM, GUITAR_STEM, PIANO_STEM, SYNTH_STEM, STRINGS_STEM, WOODWINDS_STEM, BRASS_STEM, WIND_INST_STEM)

    def __init__(self, config):
        self.logger = config.get("logger")
        self.log_level = config.get("log_level")
        self.torch_device = config.get("torch_device")
        self.torch_device_cpu = config.get("torch_device_cpu")
        self.torch_device_mps = config.get("torch_device_mps")
        self.onnx_execution_provider = config.get("onnx_execution_provider")
        self.model_name = config.get("model_name")
        self.model_path = config.get("model_path")
        self.model_data = config.get("model_data")
        self.output_dir = config.get("output_dir")
        self.output_format = config.get("output_format")
        self.output_bitrate = config.get("output_bitrate")
        self.normalization_threshold = config.get("normalization_threshold")
        self.enable_denoise = config.get("enable_denoise")
        self.output_single_stem = config.get("output_single_stem")
        self.invert_using_spec = config.get("invert_using_spec")
        self.sample_rate = config.get("sample_rate")
        self.primary_stem_name = None
        self.secondary_stem_name = None

        if "training" in self.model_data and "instruments" in self.model_data["training"]:
            instruments = self.model_data["training"]["instruments"]
            if instruments:
                self.primary_stem_name = instruments[0]
                self.secondary_stem_name = instruments[1] if len(instruments) > 1 else self.secondary_stem(self.primary_stem_name)

        if self.primary_stem_name is None:
            self.primary_stem_name = self.model_data.get("primary_stem", "Vocals")
            self.secondary_stem_name = self.secondary_stem(self.primary_stem_name)

        self.is_karaoke = self.model_data.get("is_karaoke", False)
        self.is_bv_model = self.model_data.get("is_bv_model", False)
        self.bv_model_rebalance = self.model_data.get("is_bv_model_rebalanced", 0)
        self.logger.debug(translations["info"].format(model_name=self.model_name, model_path=self.model_path))
        self.logger.debug(translations["info_2"].format(output_dir=self.output_dir, output_format=self.output_format))
        self.logger.debug(translations["info_3"].format(normalization_threshold=self.normalization_threshold))
        self.logger.debug(translations["info_4"].format(enable_denoise=self.enable_denoise, output_single_stem=self.output_single_stem))
        self.logger.debug(translations["info_5"].format(invert_using_spec=self.invert_using_spec, sample_rate=self.sample_rate))
        self.logger.debug(translations["info_6"].format(primary_stem_name=self.primary_stem_name, secondary_stem_name=self.secondary_stem_name))
        self.logger.debug(translations["info_7"].format(is_karaoke=self.is_karaoke, is_bv_model=self.is_bv_model, bv_model_rebalance=self.bv_model_rebalance))
        self.audio_file_path = None
        self.audio_file_base = None
        self.primary_source = None
        self.secondary_source = None
        self.primary_stem_output_path = None
        self.secondary_stem_output_path = None
        self.cached_sources_map = {}

    def secondary_stem(self, primary_stem):
        primary_stem = primary_stem if primary_stem else self.NO_STEM
        return self.STEM_PAIR_MAPPER[primary_stem] if primary_stem in self.STEM_PAIR_MAPPER else primary_stem.replace(self.NO_STEM, "") if self.NO_STEM in primary_stem else f"{self.NO_STEM}{primary_stem}"

    def separate(self, audio_file_path):
        pass

    def final_process(self, stem_path, source, stem_name):
        self.logger.debug(translations["success_process"].format(stem_name=stem_name))
        self.write_audio(stem_path, source)
        return {stem_name: source}

    def cached_sources_clear(self):
        self.cached_sources_map = {}

    def cached_source_callback(self, model_architecture, model_name=None):
        model, sources = None, None
        mapper = self.cached_sources_map[model_architecture]
        for key, value in mapper.items():
            if model_name in key:
                model = key
                sources = value

        return model, sources

    def cached_model_source_holder(self, model_architecture, sources, model_name=None):
        self.cached_sources_map[model_architecture] = {**self.cached_sources_map.get(model_architecture, {}), **{model_name: sources}}

    def prepare_mix(self, mix):
        audio_path = mix
        if not isinstance(mix, np.ndarray):
            self.logger.debug(f"{translations['load_audio']}: {mix}")
            mix, sr = librosa.load(mix, mono=False, sr=self.sample_rate)
            self.logger.debug(translations["load_audio_success"].format(sr=sr, shape=mix.shape))
        else:
            self.logger.debug(translations["convert_mix"])
            mix = mix.T
            self.logger.debug(translations["convert_shape"].format(shape=mix.shape))

        if isinstance(audio_path, str):
            if not np.any(mix):
                error_msg = translations["audio_not_valid"].format(audio_path=audio_path)
                self.logger.error(error_msg)
                raise ValueError(error_msg)
            else: self.logger.debug(translations["audio_valid"])

        if mix.ndim == 1:
            self.logger.debug(translations["mix_single"])
            mix = np.asfortranarray([mix, mix])
            self.logger.debug(translations["convert_mix_audio"])

        self.logger.debug(translations["mix_success_2"])
        return mix

    def write_audio(self, stem_path, stem_source):
        duration_seconds = librosa.get_duration(filename=self.audio_file_path)
        duration_hours = duration_seconds / 3600
        self.logger.info(translations["duration"].format(duration_hours=f"{duration_hours:.2f}", duration_seconds=f"{duration_seconds:.2f}"))

        if duration_hours >= 1:
            self.logger.debug(translations["write"].format(name="soundfile"))
            self.write_audio_soundfile(stem_path, stem_source)
        else:
            self.logger.info(translations["write"].format(name="pydub"))
            self.write_audio_pydub(stem_path, stem_source)

    def write_audio_pydub(self, stem_path, stem_source):
        self.logger.debug(f"{translations['write_audio'].format(name='write_audio_pydub')} {stem_path}")
        stem_source = normalize(wave=stem_source, max_peak=self.normalization_threshold)

        if np.max(np.abs(stem_source)) < 1e-6:
            self.logger.warning(translations["original_not_valid"])
            return

        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)
            stem_path = os.path.join(self.output_dir, stem_path)

        self.logger.debug(f"{translations['shape_audio']}: {stem_source.shape}")
        self.logger.debug(f"{translations['convert_data']}: {stem_source.dtype}")

        if stem_source.dtype != np.int16:
            stem_source = (stem_source * 32767).astype(np.int16)
            self.logger.debug(translations["original_source_to_int16"])

        stem_source_interleaved = np.empty((2 * stem_source.shape[0],), dtype=np.int16)
        stem_source_interleaved[0::2] = stem_source[:, 0]
        stem_source_interleaved[1::2] = stem_source[:, 1]
        self.logger.debug(f"{translations['shape_audio_2']}: {stem_source_interleaved.shape}")

        try:
            audio_segment = AudioSegment(stem_source_interleaved.tobytes(), frame_rate=self.sample_rate, sample_width=stem_source.dtype.itemsize, channels=2)
            self.logger.debug(translations["create_audiosegment"])
        except (IOError, ValueError) as e:
            self.logger.error(f"{translations['create_audiosegment_error']}: {e}")
            return

        file_format = stem_path.lower().split(".")[-1]

        if file_format == "m4a": file_format = "mp4"
        elif file_format == "mka": file_format = "matroska"

        try:
            audio_segment.export(stem_path, format=file_format, bitrate="320k" if file_format == "mp3" and self.output_bitrate is None else self.output_bitrate)
            self.logger.debug(f"{translations['export_success']} {stem_path}")
        except (IOError, ValueError) as e:
            self.logger.error(f"{translations['export_error']}: {e}")

    def write_audio_soundfile(self, stem_path, stem_source):
        self.logger.debug(f"{translations['write_audio'].format(name='write_audio_soundfile')}: {stem_path}")

        if stem_source.shape[1] == 2:
            if stem_source.flags["F_CONTIGUOUS"]: stem_source = np.ascontiguousarray(stem_source)
            else:
                stereo_interleaved = np.empty((2 * stem_source.shape[0],), dtype=np.int16)
                stereo_interleaved[0::2] = stem_source[:, 0]
                stereo_interleaved[1::2] = stem_source[:, 1]
                stem_source = stereo_interleaved

        self.logger.debug(f"{translations['shape_audio_2']}: {stem_source.shape}")

        try:
            sf.write(stem_path, stem_source, self.sample_rate)
            self.logger.debug(f"{translations['export_success']} {stem_path}")
        except Exception as e:
            self.logger.error(f"{translations['export_error']}: {e}")

    def clear_gpu_cache(self):
        self.logger.debug(translations["clean"])
        gc.collect()

        if self.torch_device == torch.device("mps"):
            self.logger.debug(translations["clean_cache"].format(name="MPS"))
            torch.mps.empty_cache()

        if self.torch_device == torch.device("cuda"):
            self.logger.debug(translations["clean_cache"].format(name="CUDA"))
            torch.cuda.empty_cache()

    def clear_file_specific_paths(self):
        self.logger.info(translations["del_path"])
        self.audio_file_path = None
        self.audio_file_base = None
        self.primary_source = None
        self.secondary_source = None
        self.primary_stem_output_path = None
        self.secondary_stem_output_path = None
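The stem-pair rule in secondary_stem above is its own inverse: mapped pairs come from STEM_PAIR_MAPPER, and any other stem toggles the "No " prefix. Restated standalone for illustration (this helper is not part of the original file):

def complement(stem, pairs={"Vocals": "Instrumental", "Instrumental": "Vocals"}):
    if stem in pairs: return pairs[stem]
    return stem[len("No "):] if stem.startswith("No ") else f"No {stem}"

assert complement("Drums") == "No Drums"
assert complement("No Drums") == "Drums"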
main/library/uvr5_separator/demucs/apply.py
DELETED
@@ -1,250 +0,0 @@
import tqdm
import torch
import random

from torch import nn
from torch.nn import functional as F
from concurrent.futures import ThreadPoolExecutor

from .utils import center_trim

class DummyPoolExecutor:
    class DummyResult:
        def __init__(self, func, *args, **kwargs):
            self.func = func
            self.args = args
            self.kwargs = kwargs

        def result(self):
            return self.func(*self.args, **self.kwargs)

    def __init__(self, workers=0):
        pass

    def submit(self, func, *args, **kwargs):
        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        return

class BagOfModels(nn.Module):
    def __init__(self, models, weights=None, segment=None):
        super().__init__()
        assert len(models) > 0
        first = models[0]

        for other in models:
            assert other.sources == first.sources
            assert other.samplerate == first.samplerate
            assert other.audio_channels == first.audio_channels

            if segment is not None: other.segment = segment

        self.audio_channels = first.audio_channels
        self.samplerate = first.samplerate
        self.sources = first.sources
        self.models = nn.ModuleList(models)

        if weights is None: weights = [[1.0 for _ in first.sources] for _ in models]
        else:
            assert len(weights) == len(models)

            for weight in weights:
                assert len(weight) == len(first.sources)

        self.weights = weights

    def forward(self, x):
        pass

class TensorChunk:
    def __init__(self, tensor, offset=0, length=None):
        total_length = tensor.shape[-1]
        assert offset >= 0
        assert offset < total_length

        length = total_length - offset if length is None else min(total_length - offset, length)

        if isinstance(tensor, TensorChunk):
            self.tensor = tensor.tensor
            self.offset = offset + tensor.offset
        else:
            self.tensor = tensor
            self.offset = offset

        self.length = length
        self.device = tensor.device

    @property
    def shape(self):
        shape = list(self.tensor.shape)
        shape[-1] = self.length
        return shape

    def padded(self, target_length):
        delta = target_length - self.length
        total_length = self.tensor.shape[-1]
        assert delta >= 0

        start = self.offset - delta // 2
        end = start + target_length

        correct_start = max(0, start)
        correct_end = min(total_length, end)

        pad_left = correct_start - start
        pad_right = end - correct_end

        out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))

        assert out.shape[-1] == target_length
        return out

def tensor_chunk(tensor_or_chunk):
    if isinstance(tensor_or_chunk, TensorChunk): return tensor_or_chunk
    else:
        assert isinstance(tensor_or_chunk, torch.Tensor)
        return TensorChunk(tensor_or_chunk)

def apply_model(model, mix, shifts=1, split=True, overlap=0.25, transition_power=1.0, static_shifts=1, set_progress_bar=None, device=None, progress=False, num_workers=0, pool=None):
    # Progress state is shared across recursive calls through module-level globals.
    global fut_length, bag_num, prog_bar

    device = mix.device if device is None else torch.device(device)
    if pool is None: pool = ThreadPoolExecutor(num_workers) if num_workers > 0 and device.type == "cpu" else DummyPoolExecutor()

    kwargs = {
        "shifts": shifts,
        "split": split,
        "overlap": overlap,
        "transition_power": transition_power,
        "progress": progress,
        "device": device,
        "pool": pool,
        "set_progress_bar": set_progress_bar,
        "static_shifts": static_shifts,
    }

    if isinstance(model, BagOfModels):
        estimates, fut_length, prog_bar, current_model = 0, 0, 0, 0
        totals = [0] * len(model.sources)
        bag_num = len(model.models)

        for sub_model, weight in zip(model.models, model.weights):
            original_model_device = next(iter(sub_model.parameters())).device
            sub_model.to(device)
            fut_length += fut_length
            current_model += 1
            out = apply_model(sub_model, mix, **kwargs)
            sub_model.to(original_model_device)

            for k, inst_weight in enumerate(weight):
                out[:, k, :, :] *= inst_weight
                totals[k] += inst_weight

            estimates += out
            del out

        for k in range(estimates.shape[1]):
            estimates[:, k, :, :] /= totals[k]

        return estimates

    model.to(device)
    model.eval()
    assert transition_power >= 1
    batch, channels, length = mix.shape

    if shifts:
        # Average predictions over random sub-half-second shifts of the input.
        kwargs["shifts"] = 0
        max_shift = int(0.5 * model.samplerate)
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(length + 2 * max_shift)
        out = 0

        for _ in range(shifts):
            offset = random.randint(0, max_shift)
            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
            shifted_out = apply_model(model, shifted, **kwargs)
            out += shifted_out[..., max_shift - offset :]

        out /= shifts
        return out
    elif split:
        # Process overlapping segments and blend them with a triangular weight.
        kwargs["split"] = False
        out = torch.zeros(batch, len(model.sources), channels, length, device=mix.device)
        sum_weight = torch.zeros(length, device=mix.device)
        segment = int(model.samplerate * model.segment)
        stride = int((1 - overlap) * segment)
        offsets = range(0, length, stride)
        weight = torch.cat([torch.arange(1, segment // 2 + 1, device=device), torch.arange(segment - segment // 2, 0, -1, device=device)])
        assert len(weight) == segment
        weight = (weight / weight.max()) ** transition_power
        futures = []

        for offset in offsets:
            chunk = TensorChunk(mix, offset, segment)
            future = pool.submit(apply_model, model, chunk, **kwargs)
            futures.append((future, offset))

        if progress: futures = tqdm.tqdm(futures)

        for future, offset in futures:
            if set_progress_bar:
                fut_length = len(futures) * bag_num * static_shifts
                prog_bar += 1
                set_progress_bar(0.1, (0.8 / fut_length * prog_bar))

            chunk_out = future.result()
            chunk_length = chunk_out.shape[-1]

            out[..., offset : offset + segment] += (weight[:chunk_length] * chunk_out).to(mix.device)
            sum_weight[offset : offset + segment] += weight[:chunk_length].to(mix.device)

        assert sum_weight.min() > 0

        out /= sum_weight
        return out
    else:
        valid_length = model.valid_length(length) if hasattr(model, "valid_length") else length
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(valid_length).to(device)

        with torch.no_grad():
            out = model(padded_mix)

        return center_trim(out, length)

def demucs_segments(demucs_segment, demucs_model):
    # Parse the requested segment length; "Default" (or anything non-numeric) keeps the model's own setting.
    try:
        segment = None if demucs_segment == "Default" else int(demucs_segment)
    except (TypeError, ValueError):
        segment = None

    if segment is not None:
        if isinstance(demucs_model, BagOfModels):
            for sub in demucs_model.models:
                sub.segment = segment
        else:
            demucs_model.segment = segment

    return demucs_model
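apply_model takes a (batch, channels, samples) tensor and returns (batch, sources, channels, samples); with split=True, overlapping chunks are blended by the triangular weight so seams are down-weighted. A usage sketch (demucs_model is assumed to be an already-loaded Demucs or BagOfModels instance):

import torch

mix = torch.zeros(1, 2, 44100)  # placeholder: one second of stereo audio at 44.1 kHz
with torch.no_grad():
    stems = apply_model(demucs_model, mix, shifts=1, split=True, overlap=0.25, device="cpu")
# stems.shape == (1, len(demucs_model.sources), 2, 44100)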
main/library/uvr5_separator/demucs/demucs.py
DELETED
@@ -1,370 +0,0 @@
import math
import torch
import inspect

from torch import nn

from torch.nn import functional as F

from .utils import center_trim
from .states import capture_init


def unfold(a, kernel_size, stride):
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, tgt_length - length))
    strides = list(a.stride())
    assert strides[-1] == 1
    strides = strides[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)

def rescale_conv(conv, reference):
    scale = (conv.weight.std().detach() / reference) ** 0.5
    conv.weight.data /= scale
    if conv.bias is not None: conv.bias.data /= scale

def rescale_module(module, reference):
    for sub in module.modules():
        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)): rescale_conv(sub, reference)

class BLSTM(nn.Module):
    def __init__(self, dim, layers=1, max_steps=None, skip=False):
        super().__init__()
        assert max_steps is None or max_steps % 4 == 0
        self.max_steps = max_steps
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, dim)
        self.skip = skip

    def forward(self, x):
        B, C, T = x.shape
        y = x
        framed = False

        if self.max_steps is not None and T > self.max_steps:
            width = self.max_steps
            stride = width // 2
            frames = unfold(x, width, stride)
            nframes = frames.shape[2]
            framed = True
            x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)

        x = x.permute(2, 0, 1)
        x = self.lstm(x)[0]
        x = self.linear(x)
        x = x.permute(1, 2, 0)

        if framed:
            out = []
            frames = x.reshape(B, -1, C, width)
            limit = stride // 2

            for k in range(nframes):
                if k == 0: out.append(frames[:, k, :, :-limit])
                elif k == nframes - 1: out.append(frames[:, k, :, limit:])
                else: out.append(frames[:, k, :, limit:-limit])

            out = torch.cat(out, -1)
            out = out[..., :T]
            x = out

        if self.skip: x = x + y
        return x

class LayerScale(nn.Module):
    def __init__(self, channels, init=0):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data[:] = init

    def forward(self, x):
        return self.scale[:, None] * x

class DConv(nn.Module):
    def __init__(self, channels, compress=4, depth=2, init=1e-4, norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True, kernel=3, dilate=True):
        super().__init__()
        assert kernel % 2 == 1
        self.channels = channels
        self.compress = compress
        self.depth = abs(depth)
        dilate = depth > 0
        norm_fn = lambda d: nn.Identity()
        if norm: norm_fn = lambda d: nn.GroupNorm(1, d)
        hidden = int(channels / compress)
        act = nn.GELU if gelu else nn.ReLU
        self.layers = nn.ModuleList([])

        for d in range(self.depth):
            dilation = 2**d if dilate else 1
            padding = dilation * (kernel // 2)

            mods = [nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding), norm_fn(hidden), act(), nn.Conv1d(hidden, 2 * channels, 1), norm_fn(2 * channels), nn.GLU(1), LayerScale(channels, init)]

            if attn: mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
            if lstm: mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
            layer = nn.Sequential(*mods)
            self.layers.append(layer)

    def forward(self, x):
        for layer in self.layers:
            x = x + layer(x)

        return x

class LocalState(nn.Module):
    def __init__(self, channels, heads=4, nfreqs=0, ndecay=4):
        super().__init__()
        assert channels % heads == 0, (channels, heads)
        self.heads = heads
        self.nfreqs = nfreqs
        self.ndecay = ndecay
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)

        if nfreqs: self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)

        if ndecay:
            self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
            self.query_decay.weight.data *= 0.01
            assert self.query_decay.bias is not None
            self.query_decay.bias.data[:] = -2

        self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)

    def forward(self, x):
        B, C, T = x.shape
        heads = self.heads
        indexes = torch.arange(T, device=x.device, dtype=x.dtype)
        delta = indexes[:, None] - indexes[None, :]
        queries = self.query(x).view(B, heads, -1, T)
        keys = self.key(x).view(B, heads, -1, T)
        dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
        dots /= keys.shape[2] ** 0.5

        if self.nfreqs:
            periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
            freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
            freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs**0.5
            dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)

        if self.ndecay:
            decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
            decay_q = self.query_decay(x).view(B, heads, -1, T)
            decay_q = torch.sigmoid(decay_q) / 2
            decay_kernel = -decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5
            dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)

        dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
        weights = torch.softmax(dots, dim=2)
        content = self.content(x).view(B, heads, -1, T)
        result = torch.einsum("bhts,bhct->bhcs", weights, content)

        if self.nfreqs:
            time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
            result = torch.cat([result, time_sig], 2)

        result = result.reshape(B, -1, T)
        return x + self.proj(result)

class Demucs(nn.Module):
    @capture_init
    def __init__(self, sources, audio_channels=2, channels=64, growth=2.0, depth=6, rewrite=True, lstm_layers=0, kernel_size=8, stride=4, context=1, gelu=True, glu=True, norm_starts=4, norm_groups=4, dconv_mode=1, dconv_depth=2, dconv_comp=4, dconv_attn=4, dconv_lstm=4, dconv_init=1e-4, normalize=True, resample=True, rescale=0.1, samplerate=44100, segment=4 * 10):
        super().__init__()
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.resample = resample
        self.channels = channels
        self.normalize = normalize
        self.samplerate = samplerate
        self.segment = segment
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.skip_scales = nn.ModuleList()

        if glu:
            activation = nn.GLU(dim=1)
            ch_scale = 2
        else:
            activation = nn.ReLU()
            ch_scale = 1

        act2 = nn.GELU if gelu else nn.ReLU

        in_channels = audio_channels
        padding = 0

        for index in range(depth):
            norm_fn = lambda d: nn.Identity()
            if index >= norm_starts: norm_fn = lambda d: nn.GroupNorm(norm_groups, d)

            encode = []
            encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), norm_fn(channels), act2()]
            attn = index >= dconv_attn
            lstm = index >= dconv_lstm

            if dconv_mode & 1: encode += [DConv(channels, depth=dconv_depth, init=dconv_init, compress=dconv_comp, attn=attn, lstm=lstm)]
            if rewrite: encode += [nn.Conv1d(channels, ch_scale * channels, 1), norm_fn(ch_scale * channels), activation]
            self.encoder.append(nn.Sequential(*encode))

            decode = []
            out_channels = in_channels if index > 0 else len(self.sources) * audio_channels

            if rewrite: decode += [nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context), norm_fn(ch_scale * channels), activation]
            if dconv_mode & 2: decode += [DConv(channels, depth=dconv_depth, init=dconv_init, compress=dconv_comp, attn=attn, lstm=lstm)]
            decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride, padding=padding)]

            if index > 0: decode += [norm_fn(out_channels), act2()]
            self.decoder.insert(0, nn.Sequential(*decode))
            in_channels = channels
            channels = int(growth * channels)

        channels = in_channels
        self.lstm = BLSTM(channels, lstm_layers) if lstm_layers else None
        if rescale: rescale_module(self, reference=rescale)

    def valid_length(self, length):
        if self.resample: length *= 2

        for _ in range(self.depth):
            length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(1, length)

        for _ in range(self.depth):
            length = (length - 1) * self.stride + self.kernel_size

        if self.resample: length = math.ceil(length / 2)
        return int(length)

    def forward(self, mix):
        x = mix
        length = x.shape[-1]

        if self.normalize:
            mono = mix.mean(dim=1, keepdim=True)
            mean = mono.mean(dim=-1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            x = (x - mean) / (1e-5 + std)
        else:
            mean = 0
            std = 1

        delta = self.valid_length(length) - length
        x = F.pad(x, (delta // 2, delta - delta // 2))

        if self.resample: x = resample_frac(x, 1, 2)
        saved = []

        for encode in self.encoder:
            x = encode(x)
            saved.append(x)

        if self.lstm: x = self.lstm(x)

        for decode in self.decoder:
            skip = saved.pop(-1)
            skip = center_trim(skip, x)
            x = decode(x + skip)

        if self.resample: x = resample_frac(x, 2, 1)

        x = x * std + mean
        x = center_trim(x, length)
        x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
        return x

    def load_state_dict(self, state, strict=True):
        # Older checkpoints index the rewrite conv at position 2; remap those keys to position 3.
        for idx in range(self.depth):
            for a in ["encoder", "decoder"]:
                for b in ["bias", "weight"]:
                    new = f"{a}.{idx}.3.{b}"
                    old = f"{a}.{idx}.2.{b}"

                    if old in state and new not in state: state[new] = state.pop(old)
        super().load_state_dict(state, strict=strict)

class ResampleFrac(torch.nn.Module):
    def __init__(self, old_sr, new_sr, zeros=24, rolloff=0.945):
        super().__init__()
        gcd = math.gcd(old_sr, new_sr)
        self.old_sr = old_sr // gcd
        self.new_sr = new_sr // gcd
        self.zeros = zeros
        self.rolloff = rolloff
        self._init_kernels()

    def _init_kernels(self):
        if self.old_sr == self.new_sr: return

        kernels = []
        sr = min(self.new_sr, self.old_sr)
        sr *= self.rolloff

        self._width = math.ceil(self.zeros * self.old_sr / sr)
        idx = torch.arange(-self._width, self._width + self.old_sr).float()

        for i in range(self.new_sr):
            t = ((-i/self.new_sr + idx/self.old_sr) * sr).clamp_(-self.zeros, self.zeros)
            t *= math.pi

            kernel = sinc(t) * (torch.cos(t/self.zeros/2)**2)
            kernel.div_(kernel.sum())
            kernels.append(kernel)

        self.register_buffer("kernel", torch.stack(kernels).view(self.new_sr, 1, -1))

    def forward(self, x, output_length=None, full=False):
        if self.old_sr == self.new_sr: return x
        shape = x.shape
        length = x.shape[-1]

        x = x.reshape(-1, length)
        y = F.conv1d(F.pad(x[:, None], (self._width, self._width + self.old_sr), mode='replicate'), self.kernel, stride=self.old_sr).transpose(1, 2).reshape(list(shape[:-1]) + [-1])

        float_output_length = torch.as_tensor(self.new_sr * length / self.old_sr)
        max_output_length = torch.ceil(float_output_length).long()
        default_output_length = torch.floor(float_output_length).long()

        if output_length is None: applied_output_length = max_output_length if full else default_output_length
        elif output_length < 0 or output_length > max_output_length: raise ValueError("output_length < 0 or output_length > max_output_length")
        else:
            applied_output_length = torch.tensor(output_length)
            if full: raise ValueError("full=True")

        return y[..., :applied_output_length]

    def __repr__(self):
        return simple_repr(self)

def sinc(x):
    return torch.where(x == 0, torch.tensor(1., device=x.device, dtype=x.dtype), torch.sin(x) / x)

def simple_repr(obj, attrs=None, overrides={}):
    params = inspect.signature(obj.__class__).parameters
    attrs_repr = []

    if attrs is None: attrs = list(params.keys())
    for attr in attrs:
        display = False

        if attr in overrides: value = overrides[attr]
        elif hasattr(obj, attr): value = getattr(obj, attr)
        else: continue

        if attr in params:
            param = params[attr]
            if param.default is inspect._empty or value != param.default: display = True
        else: display = True

        if display: attrs_repr.append(f"{attr}={value}")
    return f"{obj.__class__.__name__}({','.join(attrs_repr)})"

def resample_frac(x, old_sr, new_sr, zeros=24, rolloff=0.945, output_length=None, full=False):
    return ResampleFrac(old_sr, new_sr, zeros, rolloff).to(x)(x, output_length, full)
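ResampleFrac resamples by the reduced integer ratio new_sr/old_sr with a raised-cosine windowed-sinc kernel, and the resample_frac helper builds a fresh kernel per call, which suits one-off use. A quick sketch:

import torch

x = torch.randn(2, 44100)           # one second of stereo audio at 44.1 kHz
y = resample_frac(x, 44100, 16000)  # last axis becomes 16000 samples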
main/library/uvr5_separator/demucs/hdemucs.py
DELETED
@@ -1,760 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
import torch
|
3 |
-
|
4 |
-
from torch import nn
|
5 |
-
from copy import deepcopy
|
6 |
-
|
7 |
-
from torch.nn import functional as F
|
8 |
-
|
9 |
-
from .states import capture_init
|
10 |
-
from .demucs import DConv, rescale_module
|
11 |
-
|
12 |
-
|
13 |
-
def spectro(x, n_fft=512, hop_length=None, pad=0):
|
14 |
-
*other, length = x.shape
|
15 |
-
x = x.reshape(-1, length)
|
16 |
-
device_type = x.device.type
|
17 |
-
is_other_gpu = not device_type in ["cuda", "cpu"]
|
18 |
-
if is_other_gpu: x = x.cpu()
|
19 |
-
z = torch.stft(x, n_fft * (1 + pad), hop_length or n_fft // 4, window=torch.hann_window(n_fft).to(x), win_length=n_fft, normalized=True, center=True, return_complex=True, pad_mode="reflect")
|
20 |
-
_, freqs, frame = z.shape
|
21 |
-
return z.view(*other, freqs, frame)
|
22 |
-
|
23 |
-
def ispectro(z, hop_length=None, length=None, pad=0):
|
24 |
-
*other, freqs, frames = z.shape
|
25 |
-
n_fft = 2 * freqs - 2
|
26 |
-
z = z.view(-1, freqs, frames)
|
27 |
-
win_length = n_fft // (1 + pad)
|
28 |
-
device_type = z.device.type
|
29 |
-
is_other_gpu = not device_type in ["cuda", "cpu"]
|
30 |
-
if is_other_gpu: z = z.cpu()
|
31 |
-
x = torch.istft(z, n_fft, hop_length, window=torch.hann_window(win_length).to(z.real), win_length=win_length, normalized=True, length=length, center=True)
|
32 |
-
_, length = x.shape
|
33 |
-
return x.view(*other, length)
|
34 |
-
|
def atan2(y, x):
    pi = 2 * torch.asin(torch.tensor(1.0))
    x += ((x == 0) & (y == 0)) * 1.0
    out = torch.atan(y / x)
    out += ((y >= 0) & (x < 0)) * pi
    out -= ((y < 0) & (x < 0)) * pi
    out *= 1 - ((y > 0) & (x == 0)) * 1.0
    out += ((y > 0) & (x == 0)) * (pi / 2)
    out *= 1 - ((y < 0) & (x == 0)) * 1.0
    out += ((y < 0) & (x == 0)) * (-pi / 2)
    return out

def _norm(x):
    return torch.abs(x[..., 0]) ** 2 + torch.abs(x[..., 1]) ** 2

def _mul_add(a, b, out = None):
    target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
    if out is None or out.shape != target_shape: out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)

    if out is a:
        real_a = a[..., 0]
        out[..., 0] = out[..., 0] + (real_a * b[..., 0] - a[..., 1] * b[..., 1])
        out[..., 1] = out[..., 1] + (real_a * b[..., 1] + a[..., 1] * b[..., 0])
    else:
        out[..., 0] = out[..., 0] + (a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1])
        out[..., 1] = out[..., 1] + (a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0])

    return out

def _mul(a, b, out = None):
    target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
    if out is None or out.shape != target_shape: out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)

    if out is a:
        real_a = a[..., 0]
        out[..., 0] = real_a * b[..., 0] - a[..., 1] * b[..., 1]
        out[..., 1] = real_a * b[..., 1] + a[..., 1] * b[..., 0]
    else:
        out[..., 0] = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]
        out[..., 1] = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0]

    return out

def _inv(z, out = None):
    ez = _norm(z)
    if out is None or out.shape != z.shape: out = torch.zeros_like(z)

    out[..., 0] = z[..., 0] / ez
    out[..., 1] = -z[..., 1] / ez

    return out

def _conj(z, out = None):
    if out is None or out.shape != z.shape: out = torch.zeros_like(z)

    out[..., 0] = z[..., 0]
    out[..., 1] = -z[..., 1]

    return out

def _invert(M, out = None):
    nb_channels = M.shape[-2]
    if out is None or out.shape != M.shape: out = torch.empty_like(M)

    if nb_channels == 1: out = _inv(M, out)
    elif nb_channels == 2:
        det = _mul(M[..., 0, 0, :], M[..., 1, 1, :])
        det = det - _mul(M[..., 0, 1, :], M[..., 1, 0, :])
        invDet = _inv(det)
        out[..., 0, 0, :] = _mul(invDet, M[..., 1, 1, :], out[..., 0, 0, :])
        out[..., 1, 0, :] = _mul(-invDet, M[..., 1, 0, :], out[..., 1, 0, :])
        out[..., 0, 1, :] = _mul(-invDet, M[..., 0, 1, :], out[..., 0, 1, :])
        out[..., 1, 1, :] = _mul(invDet, M[..., 0, 0, :], out[..., 1, 1, :])
    else: raise Exception("Torch == 2 Channels")
    return out
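All of these helpers represent complex numbers as a trailing dimension of size 2 holding (real, imag) pairs rather than torch's native complex dtype. A quick hypothetical check that _mul agrees with native complex multiplication (not from the source, just a sanity sketch):

import torch

a = torch.randn(5, 2)
b = torch.randn(5, 2)
# Reference: multiply as native complex numbers, then back to (real, imag).
ref = torch.view_as_real(torch.view_as_complex(a) * torch.view_as_complex(b))
print(torch.allclose(_mul(a, b), ref))  # True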
def expectation_maximization(y, x, iterations = 2, eps = 1e-10, batch_size = 200):
    (nb_frames, nb_bins, nb_channels) = x.shape[:-1]
    nb_sources = y.shape[-1]
    regularization = torch.cat((torch.eye(nb_channels, dtype=x.dtype, device=x.device)[..., None], torch.zeros((nb_channels, nb_channels, 1), dtype=x.dtype, device=x.device)), dim=2)
    regularization = torch.sqrt(torch.as_tensor(eps)) * (regularization[None, None, ...].expand((-1, nb_bins, -1, -1, -1)))
    R = [torch.zeros((nb_bins, nb_channels, nb_channels, 2), dtype=x.dtype, device=x.device) for j in range(nb_sources)]
    weight = torch.zeros((nb_bins,), dtype=x.dtype, device=x.device)
    v = torch.zeros((nb_frames, nb_bins, nb_sources), dtype=x.dtype, device=x.device)

    for _ in range(iterations):
        v = torch.mean(torch.abs(y[..., 0, :]) ** 2 + torch.abs(y[..., 1, :]) ** 2, dim=-2)
        for j in range(nb_sources):
            R[j] = torch.tensor(0.0, device=x.device)

            weight = torch.tensor(eps, device=x.device)
            pos = 0
            batch_size = batch_size if batch_size else nb_frames

            while pos < nb_frames:
                t = torch.arange(pos, min(nb_frames, pos + batch_size))
                pos = int(t[-1]) + 1

                R[j] = R[j] + torch.sum(_covariance(y[t, ..., j]), dim=0)
                weight = weight + torch.sum(v[t, ..., j], dim=0)

            R[j] = R[j] / weight[..., None, None, None]
            weight = torch.zeros_like(weight)

        if y.requires_grad: y = y.clone()

        pos = 0

        while pos < nb_frames:
            t = torch.arange(pos, min(nb_frames, pos + batch_size))
            pos = int(t[-1]) + 1

            y[t, ...] = torch.tensor(0.0, device=x.device, dtype=x.dtype)

            Cxx = regularization

            for j in range(nb_sources):
                Cxx = Cxx + (v[t, ..., j, None, None, None] * R[j][None, ...].clone())

            inv_Cxx = _invert(Cxx)

            for j in range(nb_sources):
                gain = torch.zeros_like(inv_Cxx)
                indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels), torch.arange(nb_channels))

                for index in indices:
                    gain[:, :, index[0], index[1], :] = _mul_add(R[j][None, :, index[0], index[2], :].clone(), inv_Cxx[:, :, index[2], index[1], :], gain[:, :, index[0], index[1], :])

                gain = gain * v[t, ..., None, None, None, j]

                for i in range(nb_channels):
                    y[t, ..., j] = _mul_add(gain[..., i, :], x[t, ..., i, None, :], y[t, ..., j])

    return y, v, R

def wiener(targets_spectrograms, mix_stft, iterations = 1, softmask = False, residual = False, scale_factor = 10.0, eps = 1e-10):
    if softmask: y = mix_stft[..., None] * (targets_spectrograms / (eps + torch.sum(targets_spectrograms, dim=-1, keepdim=True).to(mix_stft.dtype)))[..., None, :]
    else:
        angle = atan2(mix_stft[..., 1], mix_stft[..., 0])[..., None]
        nb_sources = targets_spectrograms.shape[-1]
        y = torch.zeros(mix_stft.shape + (nb_sources,), dtype=mix_stft.dtype, device=mix_stft.device)
        y[..., 0, :] = targets_spectrograms * torch.cos(angle)
        y[..., 1, :] = targets_spectrograms * torch.sin(angle)

    if residual: y = torch.cat([y, mix_stft[..., None] - y.sum(dim=-1, keepdim=True)], dim=-1)
    if iterations == 0: return y

    max_abs = torch.max(torch.as_tensor(1.0, dtype=mix_stft.dtype, device=mix_stft.device), torch.sqrt(_norm(mix_stft)).max() / scale_factor)
    mix_stft = mix_stft / max_abs
    y = y / max_abs
    y = expectation_maximization(y, mix_stft, iterations, eps=eps)[0]
    y = y * max_abs

    return y

def _covariance(y_j):
    (nb_frames, nb_bins, nb_channels) = y_j.shape[:-1]

    Cj = torch.zeros((nb_frames, nb_bins, nb_channels, nb_channels, 2), dtype=y_j.dtype, device=y_j.device)
    indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels))

    for index in indices:
        Cj[:, :, index[0], index[1], :] = _mul_add(y_j[:, :, index[0], :], _conj(y_j[:, :, index[1], :]), Cj[:, :, index[0], index[1], :])

    return Cj
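wiener takes magnitude estimates shaped (frames, bins, channels, sources) plus the mixture STFT as (frames, bins, channels, 2) real/imag pairs, and returns per-source spectrograms shaped (frames, bins, channels, 2, sources). A hypothetical usage sketch with random data (all sizes are illustrative; the models' _wiener methods below feed it real output in windows of 300 frames):

import torch

frames, bins, channels, sources = 50, 513, 2, 4
mag = torch.rand(frames, bins, channels, sources)   # non-negative magnitude estimates
mix = torch.randn(frames, bins, channels, 2)        # mixture STFT as (real, imag)
y = wiener(mag, mix, iterations=1)
print(y.shape)  # torch.Size([50, 513, 2, 2, 4])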
def pad1d(x, paddings, mode = "constant", value = 0.0):
    x0 = x
    length = x.shape[-1]
    padding_left, padding_right = paddings

    if mode == "reflect":
        max_pad = max(padding_left, padding_right)

        if length <= max_pad:
            extra_pad = max_pad - length + 1
            extra_pad_right = min(padding_right, extra_pad)
            extra_pad_left = extra_pad - extra_pad_right
            paddings = (padding_left - extra_pad_left, padding_right - extra_pad_right)
            x = F.pad(x, (extra_pad_left, extra_pad_right))

    out = F.pad(x, paddings, mode, value)

    assert out.shape[-1] == length + padding_left + padding_right
    assert (out[..., padding_left : padding_left + length] == x0).all()
    return out

class ScaledEmbedding(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, scale = 10.0, smooth=False):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)

        if smooth:
            weight = torch.cumsum(self.embedding.weight.data, dim=0)
            weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None]
            self.embedding.weight.data[:] = weight

        self.embedding.weight.data /= scale
        self.scale = scale

    @property
    def weight(self):
        return self.embedding.weight * self.scale

    def forward(self, x):
        return self.embedding(x) * self.scale
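pad1d exists because F.pad with mode="reflect" raises when the requested padding reaches the input length; the wrapper first zero-pads so the reflection always has enough material, then asserts the original samples survived. A hypothetical demo of exactly that case:

import torch
import torch.nn.functional as F

x = torch.arange(3.0).view(1, 1, 3)
# F.pad(x, (5, 5), mode="reflect") would raise here, since reflect padding
# must be smaller than the input size. pad1d handles it:
out = pad1d(x, (5, 5), mode="reflect")
print(out.shape)  # torch.Size([1, 1, 13])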
class HEncLayer(nn.Module):
    def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False, freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True, rewrite=True):
        super().__init__()
        norm_fn = lambda d: nn.Identity()
        if norm: norm_fn = lambda d: nn.GroupNorm(norm_groups, d)
        pad = kernel_size // 4 if pad else 0

        klass = nn.Conv1d
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.norm = norm
        self.pad = pad

        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            pad = [pad, 0]
            klass = nn.Conv2d

        self.conv = klass(chin, chout, kernel_size, stride, pad)
        if self.empty: return

        self.norm1 = norm_fn(chout)
        self.rewrite = None

        if rewrite:
            self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = norm_fn(2 * chout)

        self.dconv = None
        if dconv: self.dconv = DConv(chout, **dconv_kw)

    def forward(self, x, inject=None):
        if not self.freq and x.dim() == 4:
            B, C, Fr, T = x.shape
            x = x.view(B, -1, T)

        if not self.freq:
            le = x.shape[-1]
            if not le % self.stride == 0: x = F.pad(x, (0, self.stride - (le % self.stride)))

        y = self.conv(x)
        if self.empty: return y

        if inject is not None:
            assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)

            if inject.dim() == 3 and y.dim() == 4: inject = inject[:, :, None]
            y = y + inject

        y = F.gelu(self.norm1(y))

        if self.dconv:
            if self.freq:
                B, C, Fr, T = y.shape
                y = y.permute(0, 2, 1, 3).reshape(-1, C, T)

            y = self.dconv(y)
            if self.freq: y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)

        if self.rewrite:
            z = self.norm2(self.rewrite(y))
            z = F.glu(z, dim=1)
        else: z = y

        return z
class MultiWrap(nn.Module):
    def __init__(self, layer, split_ratios):
        super().__init__()
        self.split_ratios = split_ratios
        self.layers = nn.ModuleList()
        self.conv = isinstance(layer, HEncLayer)
        assert not layer.norm
        assert layer.freq
        assert layer.pad

        if not self.conv: assert not layer.context_freq

        for _ in range(len(split_ratios) + 1):
            lay = deepcopy(layer)

            if self.conv: lay.conv.padding = (0, 0)
            else: lay.pad = False

            for m in lay.modules():
                if hasattr(m, "reset_parameters"): m.reset_parameters()

            self.layers.append(lay)

    def forward(self, x, skip=None, length=None):
        B, C, Fr, T = x.shape
        ratios = list(self.split_ratios) + [1]
        start = 0
        outs = []

        for ratio, layer in zip(ratios, self.layers):
            if self.conv:
                pad = layer.kernel_size // 4

                if ratio == 1:
                    limit = Fr
                    frames = -1
                else:
                    limit = int(round(Fr * ratio))
                    le = limit - start

                    if start == 0: le += pad

                    frames = round((le - layer.kernel_size) / layer.stride + 1)
                    limit = start + (frames - 1) * layer.stride + layer.kernel_size

                    if start == 0: limit -= pad

                assert limit - start > 0, (limit, start)
                assert limit <= Fr, (limit, Fr)

                y = x[:, :, start:limit, :]

                if start == 0: y = F.pad(y, (0, 0, pad, 0))
                if ratio == 1: y = F.pad(y, (0, 0, 0, pad))

                outs.append(layer(y))
                start = limit - layer.kernel_size + layer.stride
            else:
                limit = Fr if ratio == 1 else int(round(Fr * ratio))

                last = layer.last
                layer.last = True

                y = x[:, :, start:limit]
                s = skip[:, :, start:limit]
                out, _ = layer(y, s, None)

                if outs:
                    outs[-1][:, :, -layer.stride :] += out[:, :, : layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1)
                    out = out[:, :, layer.stride :]

                if ratio == 1: out = out[:, :, : -layer.stride // 2, :]
                if start == 0: out = out[:, :, layer.stride // 2 :, :]

                outs.append(out)
                layer.last = last
                start = limit

        out = torch.cat(outs, dim=2)
        if not self.conv and not last: out = F.gelu(out)

        if self.conv: return out
        else: return out, None
class HDecLayer(nn.Module):
    def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False, freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True, context_freq=True, rewrite=True):
        super().__init__()
        norm_fn = lambda d: nn.Identity()

        if norm: norm_fn = lambda d: nn.GroupNorm(norm_groups, d)
        pad = kernel_size // 4 if pad else 0

        self.pad = pad
        self.last = last
        self.freq = freq
        self.chin = chin
        self.empty = empty
        self.stride = stride
        self.kernel_size = kernel_size
        self.norm = norm
        self.context_freq = context_freq
        klass = nn.Conv1d
        klass_tr = nn.ConvTranspose1d

        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            klass = nn.Conv2d
            klass_tr = nn.ConvTranspose2d

        self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
        self.norm2 = norm_fn(chout)

        if self.empty: return
        self.rewrite = None

        if rewrite:
            if context_freq: self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
            else: self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1, [0, context])

            self.norm1 = norm_fn(2 * chin)

        self.dconv = None
        if dconv: self.dconv = DConv(chin, **dconv_kw)

    def forward(self, x, skip, length):
        if self.freq and x.dim() == 3:
            B, C, T = x.shape
            x = x.view(B, self.chin, -1, T)

        if not self.empty:
            x = x + skip

            y = F.glu(self.norm1(self.rewrite(x)), dim=1) if self.rewrite else x

            if self.dconv:
                if self.freq:
                    B, C, Fr, T = y.shape
                    y = y.permute(0, 2, 1, 3).reshape(-1, C, T)

                y = self.dconv(y)

                if self.freq: y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        else:
            y = x
            assert skip is None

        z = self.norm2(self.conv_tr(y))

        if self.freq:
            if self.pad: z = z[..., self.pad : -self.pad, :]
        else:
            z = z[..., self.pad : self.pad + length]
            assert z.shape[-1] == length, (z.shape[-1], length)

        if not self.last: z = F.gelu(z)
        return z, y
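HEncLayer compresses the frequency axis by its stride and HDecLayer's transposed convolution undoes it, which is what lets HDemucs below stack the two as a U-Net with skip connections. A shape-only hypothetical sketch (channel sizes are made up; dconv=0 keeps it independent of DConv):

import torch

enc = HEncLayer(4, 48, dconv=0)                 # kernel 8, stride 4 on the freq axis
dec = HDecLayer(48, 4, last=True, dconv=0)

x = torch.randn(1, 4, 2048, 100)                # (batch, channels, freq bins, frames)
h = enc(x)                                      # -> (1, 48, 512, 100)
z, _ = dec(torch.zeros_like(h), h, 100)         # skip connection carries h back up
print(h.shape, z.shape)                         # (1, 48, 512, 100) (1, 4, 2048, 100)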
class HDemucs(nn.Module):
    @capture_init
    def __init__(self, sources, audio_channels=2, channels=48, channels_time=None, growth=2, nfft=4096, wiener_iters=0, end_iters=0, wiener_residual=False, cac=True, depth=6, rewrite=True, hybrid=True, hybrid_old=False, multi_freqs=None, multi_freqs_depth=2, freq_emb=0.2, emb_scale=10, emb_smooth=True, kernel_size=8, time_stride=2, stride=4, context=1, context_enc=0, norm_starts=4, norm_groups=4, dconv_mode=1, dconv_depth=2, dconv_comp=4, dconv_attn=4, dconv_lstm=4, dconv_init=1e-4, rescale=0.1, samplerate=44100, segment=4 * 10):
        super().__init__()
        self.cac = cac
        self.wiener_residual = wiener_residual
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.channels = channels
        self.samplerate = samplerate
        self.segment = segment
        self.nfft = nfft
        self.hop_length = nfft // 4
        self.wiener_iters = wiener_iters
        self.end_iters = end_iters
        self.freq_emb = None
        self.hybrid = hybrid
        self.hybrid_old = hybrid_old
        if hybrid_old: assert hybrid
        if hybrid: assert wiener_iters == end_iters
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        if hybrid:
            self.tencoder = nn.ModuleList()
            self.tdecoder = nn.ModuleList()

        chin = audio_channels
        chin_z = chin

        if self.cac: chin_z *= 2

        chout = channels_time or channels
        chout_z = channels
        freqs = nfft // 2

        for index in range(depth):
            lstm = index >= dconv_lstm
            attn = index >= dconv_attn
            norm = index >= norm_starts
            freq = freqs > 1
            stri = stride
            ker = kernel_size

            if not freq:
                assert freqs == 1

                ker = time_stride * 2
                stri = time_stride

            pad = True
            last_freq = False

            if freq and freqs <= kernel_size:
                ker = freqs
                pad = False
                last_freq = True

            kw = {
                "kernel_size": ker,
                "stride": stri,
                "freq": freq,
                "pad": pad,
                "norm": norm,
                "rewrite": rewrite,
                "norm_groups": norm_groups,
                "dconv_kw": {"lstm": lstm, "attn": attn, "depth": dconv_depth, "compress": dconv_comp, "init": dconv_init, "gelu": True},
            }

            kwt = dict(kw)
            kwt["freq"] = 0
            kwt["kernel_size"] = kernel_size
            kwt["stride"] = stride
            kwt["pad"] = True
            kw_dec = dict(kw)

            multi = False

            if multi_freqs and index < multi_freqs_depth:
                multi = True
                kw_dec["context_freq"] = False

            if last_freq:
                chout_z = max(chout, chout_z)
                chout = chout_z

            enc = HEncLayer(chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw)
            if hybrid and freq:
                tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc, empty=last_freq, **kwt)
                self.tencoder.append(tenc)

            if multi: enc = MultiWrap(enc, multi_freqs)

            self.encoder.append(enc)
            if index == 0:
                chin = self.audio_channels * len(self.sources)
                chin_z = chin

                if self.cac: chin_z *= 2

            dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2, last=index == 0, context=context, **kw_dec)
            if multi: dec = MultiWrap(dec, multi_freqs)

            if hybrid and freq:
                tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq, last=index == 0, context=context, **kwt)
                self.tdecoder.insert(0, tdec)

            self.decoder.insert(0, dec)
            chin = chout
            chin_z = chout_z
            chout = int(growth * chout)
            chout_z = int(growth * chout_z)

            if freq:
                if freqs <= kernel_size: freqs = 1
                else: freqs //= stride

            if index == 0 and freq_emb:
                self.freq_emb = ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
                self.freq_emb_scale = freq_emb

        if rescale: rescale_module(self, reference=rescale)

    def _spec(self, x):
        hl = self.hop_length
        nfft = self.nfft

        if self.hybrid:
            assert hl == nfft // 4
            le = int(math.ceil(x.shape[-1] / hl))
            pad = hl // 2 * 3
            x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect") if not self.hybrid_old else pad1d(x, (pad, pad + le * hl - x.shape[-1]))

        z = spectro(x, nfft, hl)[..., :-1, :]
        if self.hybrid:
            assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
            z = z[..., 2 : 2 + le]

        return z

    def _ispec(self, z, length=None, scale=0):
        hl = self.hop_length // (4**scale)
        z = F.pad(z, (0, 0, 0, 1))

        if self.hybrid:
            z = F.pad(z, (2, 2))
            pad = hl // 2 * 3
            le = hl * int(math.ceil(length / hl)) + 2 * pad if not self.hybrid_old else hl * int(math.ceil(length / hl))
            x = ispectro(z, hl, length=le)
            x = x[..., pad : pad + length] if not self.hybrid_old else x[..., :length]
        else: x = ispectro(z, hl, length)

        return x

    def _magnitude(self, z):
        if self.cac:
            B, C, Fr, T = z.shape
            m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
            m = m.reshape(B, C * 2, Fr, T)
        else: m = z.abs()

        return m

    def _mask(self, z, m):
        niters = self.wiener_iters
        if self.cac:
            B, S, C, Fr, T = m.shape
            out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
            out = torch.view_as_complex(out.contiguous())
            return out

        if self.training: niters = self.end_iters

        if niters < 0:
            z = z[:, None]
            return z / (1e-8 + z.abs()) * m
        else: return self._wiener(m, z, niters)

    def _wiener(self, mag_out, mix_stft, niters):
        init = mix_stft.dtype
        wiener_win_len = 300
        residual = self.wiener_residual
        B, S, C, Fq, T = mag_out.shape
        mag_out = mag_out.permute(0, 4, 3, 2, 1)
        mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
        outs = []

        for sample in range(B):
            pos = 0
            out = []

            for pos in range(0, T, wiener_win_len):
                frame = slice(pos, pos + wiener_win_len)
                z_out = wiener(mag_out[sample, frame], mix_stft[sample, frame], niters, residual=residual)
                out.append(z_out.transpose(-1, -2))

            outs.append(torch.cat(out, dim=0))

        out = torch.view_as_complex(torch.stack(outs, 0))
        out = out.permute(0, 4, 3, 2, 1).contiguous()

        if residual: out = out[:, :-1]
        assert list(out.shape) == [B, S, C, Fq, T]
        return out.to(init)

    def forward(self, mix):
        x = mix
        length = x.shape[-1]
        z = self._spec(mix)
        mag = self._magnitude(z).to(mix.device)
        x = mag
        B, C, Fq, T = x.shape
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)

        if self.hybrid:
            xt = mix
            meant = xt.mean(dim=(1, 2), keepdim=True)
            stdt = xt.std(dim=(1, 2), keepdim=True)
            xt = (xt - meant) / (1e-5 + stdt)

        saved, saved_t, lengths, lengths_t = [], [], [], []

        for idx, encode in enumerate(self.encoder):
            lengths.append(x.shape[-1])
            inject = None

            if self.hybrid and idx < len(self.tencoder):
                lengths_t.append(xt.shape[-1])
                tenc = self.tencoder[idx]
                xt = tenc(xt)

                if not tenc.empty: saved_t.append(xt)
                else: inject = xt

            x = encode(x, inject)

            if idx == 0 and self.freq_emb is not None:
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)

        x = torch.zeros_like(x)
        if self.hybrid: xt = torch.zeros_like(x)

        for idx, decode in enumerate(self.decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))

            if self.hybrid: offset = self.depth - len(self.tdecoder)

            if self.hybrid and idx >= offset:
                tdec = self.tdecoder[idx - offset]
                length_t = lengths_t.pop(-1)

                if tdec.empty:
                    assert pre.shape[2] == 1, pre.shape

                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)

        assert len(saved) == 0
        assert len(lengths_t) == 0
        assert len(saved_t) == 0

        S = len(self.sources)
        x = x.view(B, S, -1, Fq, T)
        x = x * std[:, None] + mean[:, None]
        device_type = x.device.type
        device_load = f"{device_type}:{x.device.index}" if not device_type == "mps" else device_type
        x_is_other_gpu = not device_type in ["cuda", "cpu"]
        if x_is_other_gpu: x = x.cpu()
        zout = self._mask(z, x)
        x = self._ispec(zout, length)
        if x_is_other_gpu: x = x.to(device_load)

        if self.hybrid:
            xt = xt.view(B, S, -1, length)
            xt = xt * stdt[:, None] + meant[:, None]
            x = xt + x

        return x
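End to end, HDemucs maps a mixture of shape (batch, channels, samples) to (batch, sources, channels, samples). A hypothetical smoke test with untrained weights (the four source names are the usual stem convention, not something this file fixes; the output is only shape-meaningful):

import torch

model = HDemucs(sources=["drums", "bass", "other", "vocals"])
mix = torch.randn(1, 2, 44100)          # one second of stereo audio
with torch.no_grad():
    out = model(mix)
print(out.shape)  # torch.Size([1, 4, 2, 44100])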
main/library/uvr5_separator/demucs/htdemucs.py
DELETED
@@ -1,600 +0,0 @@
import os
import sys
import math
import torch
import random

import numpy as np

from torch import nn
from einops import rearrange
from fractions import Fraction
from torch.nn import functional as F

sys.path.append(os.getcwd())

from .states import capture_init
from .demucs import rescale_module
from main.configs.config import Config
from .hdemucs import pad1d, spectro, ispectro, wiener, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer

translations = Config().translations

def create_sin_embedding(length, dim, shift = 0, device="cpu", max_period=10000):
    assert dim % 2 == 0
    pos = shift + torch.arange(length, device=device).view(-1, 1, 1)
    half_dim = dim // 2
    adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
    phase = pos / (max_period ** (adim / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)

def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000):
    if d_model % 4 != 0: raise ValueError(translations["dims"].format(dims=d_model))
    pe = torch.zeros(d_model, height, width)
    d_model = int(d_model / 2)
    div_term = torch.exp(torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model))
    pos_w = torch.arange(0.0, width).unsqueeze(1)
    pos_h = torch.arange(0.0, height).unsqueeze(1)
    pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    pe[d_model + 1 :: 2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)

    return pe[None, :].to(device)

def create_sin_embedding_cape(length, dim, batch_size, mean_normalize, augment, max_global_shift = 0.0, max_local_shift = 0.0, max_scale = 1.0, device = "cpu", max_period = 10000.0):
    assert dim % 2 == 0
    pos = 1.0 * torch.arange(length).view(-1, 1, 1)
    pos = pos.repeat(1, batch_size, 1)
    if mean_normalize: pos -= torch.nanmean(pos, dim=0, keepdim=True)

    if augment:
        delta = np.random.uniform(-max_global_shift, +max_global_shift, size=[1, batch_size, 1])
        delta_local = np.random.uniform(-max_local_shift, +max_local_shift, size=[length, batch_size, 1])
        log_lambdas = np.random.uniform(-np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1])
        pos = (pos + delta + delta_local) * np.exp(log_lambdas)

    pos = pos.to(device)
    half_dim = dim // 2
    adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
    phase = pos / (max_period ** (adim / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1).float()
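The 1-D helpers return (length, batch, dim) and the 2-D helper returns (1, d_model, height, width), matching how CrossTransformerEncoder below adds them to the temporal and spectral branches. A hypothetical shape check (all sizes illustrative):

import torch

p1 = create_sin_embedding(100, 512)        # -> (100, 1, 512)
p2 = create_2d_sin_embedding(64, 32, 100)  # -> (1, 64, 32, 100)
p3 = create_sin_embedding_cape(100, 512, batch_size=8, mean_normalize=True, augment=False)
print(p1.shape, p2.shape, p3.shape)        # p3 -> (100, 8, 512)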
class MyGroupNorm(nn.GroupNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        x = x.transpose(1, 2)
        return super().forward(x).transpose(1, 2)

class LayerScale(nn.Module):
    def __init__(self, channels, init = 0, channel_last=False):
        super().__init__()
        self.channel_last = channel_last
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data[:] = init

    def forward(self, x):
        if self.channel_last: return self.scale * x
        else: return self.scale[:, None] * x

class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=F.relu, group_norm=0, norm_first=False, norm_out=False, layer_norm_eps=1e-5, layer_scale=False, init_values=1e-4, device=None, dtype=None, sparse=False, mask_type="diag", mask_random_seed=42, sparse_attn_window=500, global_window=50, auto_sparsity=False, sparsity=0.95, batch_first=False):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation=activation, layer_norm_eps=layer_norm_eps, batch_first=batch_first, norm_first=norm_first, device=device, dtype=dtype)
        self.auto_sparsity = auto_sparsity

        if group_norm:
            self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)

        self.norm_out = None
        if self.norm_first & norm_out: self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)

        self.gamma_1 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        self.gamma_2 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        x = src
        T, B, C = x.shape

        if self.norm_first:
            x = x + self.gamma_1(self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
            x = x + self.gamma_2(self._ff_block(self.norm2(x)))
            if self.norm_out: x = self.norm_out(x)
        else:
            x = self.norm1(x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask)))
            x = self.norm2(x + self.gamma_2(self._ff_block(x)))

        return x

class CrossTransformerEncoder(nn.Module):
    def __init__(self, dim, emb = "sin", hidden_scale = 4.0, num_heads = 8, num_layers = 6, cross_first = False, dropout = 0.0, max_positions = 1000, norm_in = True, norm_in_group = False, group_norm = False, norm_first = False, norm_out = False, max_period = 10000.0, weight_decay = 0.0, lr = None, layer_scale = False, gelu = True, sin_random_shift = 0, weight_pos_embed = 1.0, cape_mean_normalize = True, cape_augment = True, cape_glob_loc_scale = [5000.0, 1.0, 1.4], sparse_self_attn = False, sparse_cross_attn = False, mask_type = "diag", mask_random_seed = 42, sparse_attn_window = 500, global_window = 50, auto_sparsity = False, sparsity = 0.95):
        super().__init__()
        assert dim % num_heads == 0
        hidden_dim = int(dim * hidden_scale)
        self.num_layers = num_layers
        self.classic_parity = 1 if cross_first else 0
        self.emb = emb
        self.max_period = max_period
        self.weight_decay = weight_decay
        self.weight_pos_embed = weight_pos_embed
        self.sin_random_shift = sin_random_shift

        if emb == "cape":
            self.cape_mean_normalize = cape_mean_normalize
            self.cape_augment = cape_augment
            self.cape_glob_loc_scale = cape_glob_loc_scale

        if emb == "scaled": self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)

        self.lr = lr
        activation = F.gelu if gelu else F.relu

        if norm_in:
            self.norm_in = nn.LayerNorm(dim)
            self.norm_in_t = nn.LayerNorm(dim)
        elif norm_in_group:
            self.norm_in = MyGroupNorm(int(norm_in_group), dim)
            self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
        else:
            self.norm_in = nn.Identity()
            self.norm_in_t = nn.Identity()

        self.layers = nn.ModuleList()
        self.layers_t = nn.ModuleList()

        kwargs_common = {
            "d_model": dim,
            "nhead": num_heads,
            "dim_feedforward": hidden_dim,
            "dropout": dropout,
            "activation": activation,
            "group_norm": group_norm,
            "norm_first": norm_first,
            "norm_out": norm_out,
            "layer_scale": layer_scale,
            "mask_type": mask_type,
            "mask_random_seed": mask_random_seed,
            "sparse_attn_window": sparse_attn_window,
            "global_window": global_window,
            "sparsity": sparsity,
            "auto_sparsity": auto_sparsity,
            "batch_first": True,
        }

        kwargs_classic_encoder = dict(kwargs_common)
        kwargs_classic_encoder.update({"sparse": sparse_self_attn})
        kwargs_cross_encoder = dict(kwargs_common)
        kwargs_cross_encoder.update({"sparse": sparse_cross_attn})

        for idx in range(num_layers):
            if idx % 2 == self.classic_parity:
                self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
                self.layers_t.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
            else:
                self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
                self.layers_t.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))

    def forward(self, x, xt):
        B, C, Fr, T1 = x.shape

        pos_emb_2d = create_2d_sin_embedding(C, Fr, T1, x.device, self.max_period)
        pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c")

        x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
        x = self.norm_in(x)
        x = x + self.weight_pos_embed * pos_emb_2d

        B, C, T2 = xt.shape
        xt = rearrange(xt, "b c t2 -> b t2 c")

        pos_emb = self._get_pos_embedding(T2, B, C, x.device)
        pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c")

        xt = self.norm_in_t(xt)
        xt = xt + self.weight_pos_embed * pos_emb

        for idx in range(self.num_layers):
            if idx % 2 == self.classic_parity:
                x = self.layers[idx](x)
                xt = self.layers_t[idx](xt)
            else:
                old_x = x
                x = self.layers[idx](x, xt)
                xt = self.layers_t[idx](xt, old_x)

        x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1)
        xt = rearrange(xt, "b t2 c -> b c t2")
        return x, xt

    def _get_pos_embedding(self, T, B, C, device):
        if self.emb == "sin":
            shift = random.randrange(self.sin_random_shift + 1)
            pos_emb = create_sin_embedding(T, C, shift=shift, device=device, max_period=self.max_period)
        elif self.emb == "cape":
            if self.training: pos_emb = create_sin_embedding_cape(T, C, B, device=device, max_period=self.max_period, mean_normalize=self.cape_mean_normalize, augment=self.cape_augment, max_global_shift=self.cape_glob_loc_scale[0], max_local_shift=self.cape_glob_loc_scale[1], max_scale=self.cape_glob_loc_scale[2])
            else: pos_emb = create_sin_embedding_cape(T, C, B, device=device, max_period=self.max_period, mean_normalize=self.cape_mean_normalize, augment=False)
        elif self.emb == "scaled":
            pos = torch.arange(T, device=device)
            pos_emb = self.position_embeddings(pos)[:, None]

        return pos_emb

    def make_optim_group(self):
        group = {"params": list(self.parameters()), "weight_decay": self.weight_decay}
        if self.lr is not None: group["lr"] = self.lr
        return group
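With cross_first=False, even-indexed layers run self-attention separately on each branch and odd-indexed layers exchange information, each branch attending to the other (the cross-attention layer class follows below). A hypothetical smoke test; dim must be divisible by num_heads, and all sizes here are made up:

import torch

enc = CrossTransformerEncoder(dim=64, num_heads=8, num_layers=2)
x = torch.randn(1, 64, 8, 50)     # spectral branch: (batch, dim, freqs, frames)
xt = torch.randn(1, 64, 200)      # temporal branch: (batch, dim, steps)
x2, xt2 = enc(x, xt)
print(x2.shape, xt2.shape)        # both shapes are preserved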
class CrossTransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward = 2048, dropout = 0.1, activation=F.relu, layer_norm_eps = 1e-5, layer_scale = False, init_values = 1e-4, norm_first = False, group_norm = False, norm_out = False, sparse=False, mask_type="diag", mask_random_seed=42, sparse_attn_window=500, global_window=50, sparsity=0.95, auto_sparsity=None, device=None, dtype=None, batch_first=False):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.auto_sparsity = auto_sparsity
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
        self.norm_first = norm_first

        if group_norm:
            self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
        else:
            self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)

        self.norm_out = None
        if self.norm_first & norm_out:
            self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)

        self.gamma_1 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        self.gamma_2 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        if isinstance(activation, str): self.activation = self._get_activation_fn(activation)
        else: self.activation = activation

    def forward(self, q, k, mask=None):
        if self.norm_first:
            x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask))
            x = x + self.gamma_2(self._ff_block(self.norm3(x)))

            if self.norm_out: x = self.norm_out(x)
        else:
            x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask)))
            x = self.norm2(x + self.gamma_2(self._ff_block(x)))

        return x

    def _ca_block(self, q, k, attn_mask=None):
        x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0]
        return self.dropout1(x)

    def _ff_block(self, x):
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

    def _get_activation_fn(self, activation):
        if activation == "relu": return F.relu
        elif activation == "gelu": return F.gelu
        raise RuntimeError(translations["activation"].format(activation=activation))
287 |
-
class HTDemucs(nn.Module):
|
288 |
-
@capture_init
|
289 |
-
def __init__(self, sources, audio_channels=2, channels=48, channels_time=None, growth=2, nfft=4096, wiener_iters=0, end_iters=0, wiener_residual=False, cac=True, depth=4, rewrite=True, multi_freqs=None, multi_freqs_depth=3, freq_emb=0.2, emb_scale=10, emb_smooth=True, kernel_size=8, time_stride=2, stride=4, context=1, context_enc=0, norm_starts=4, norm_groups=4, dconv_mode=1, dconv_depth=2, dconv_comp=8, dconv_init=1e-3, bottom_channels=0, t_layers=5, t_emb="sin", t_hidden_scale=4.0, t_heads=8, t_dropout=0.0, t_max_positions=10000, t_norm_in=True, t_norm_in_group=False, t_group_norm=False, t_norm_first=True, t_norm_out=True, t_max_period=10000.0, t_weight_decay=0.0, t_lr=None, t_layer_scale=True, t_gelu=True, t_weight_pos_embed=1.0, t_sin_random_shift=0, t_cape_mean_normalize=True, t_cape_augment=True, t_cape_glob_loc_scale=[5000.0, 1.0, 1.4], t_sparse_self_attn=False, t_sparse_cross_attn=False, t_mask_type="diag", t_mask_random_seed=42, t_sparse_attn_window=500, t_global_window=100, t_sparsity=0.95, t_auto_sparsity=False, t_cross_first=False, rescale=0.1, samplerate=44100, segment=4 * 10, use_train_segment=True):
|
290 |
-
super().__init__()
|
291 |
-
self.cac = cac
|
292 |
-
self.wiener_residual = wiener_residual
|
293 |
-
self.audio_channels = audio_channels
|
294 |
-
self.sources = sources
|
295 |
-
self.kernel_size = kernel_size
|
296 |
-
self.context = context
|
297 |
-
self.stride = stride
|
298 |
-
self.depth = depth
|
299 |
-
self.bottom_channels = bottom_channels
|
300 |
-
self.channels = channels
|
301 |
-
self.samplerate = samplerate
|
302 |
-
self.segment = segment
|
303 |
-
self.use_train_segment = use_train_segment
|
304 |
-
self.nfft = nfft
|
305 |
-
self.hop_length = nfft // 4
|
306 |
-
self.wiener_iters = wiener_iters
|
307 |
-
self.end_iters = end_iters
|
308 |
-
self.freq_emb = None
|
309 |
-
assert wiener_iters == end_iters
|
310 |
-
self.encoder = nn.ModuleList()
|
311 |
-
self.decoder = nn.ModuleList()
|
312 |
-
self.tencoder = nn.ModuleList()
|
313 |
-
self.tdecoder = nn.ModuleList()
|
314 |
-
chin = audio_channels
|
315 |
-
chin_z = chin
|
316 |
-
if self.cac: chin_z *= 2
|
317 |
-
chout = channels_time or channels
|
318 |
-
chout_z = channels
|
319 |
-
freqs = nfft // 2
|
320 |
-
|
321 |
-
for index in range(depth):
|
322 |
-
norm = index >= norm_starts
|
323 |
-
freq = freqs > 1
|
324 |
-
stri = stride
|
325 |
-
ker = kernel_size
|
326 |
-
|
327 |
-
if not freq:
|
328 |
-
assert freqs == 1
|
329 |
-
ker = time_stride * 2
|
330 |
-
stri = time_stride
|
331 |
-
|
332 |
-
pad = True
|
333 |
-
last_freq = False
|
334 |
-
|
335 |
-
if freq and freqs <= kernel_size:
|
336 |
-
ker = freqs
|
337 |
-
pad = False
|
338 |
-
last_freq = True
|
339 |
-
|
340 |
-
kw = {
|
341 |
-
"kernel_size": ker,
|
342 |
-
"stride": stri,
|
343 |
-
"freq": freq,
|
344 |
-
"pad": pad,
|
345 |
-
"norm": norm,
|
346 |
-
"rewrite": rewrite,
|
347 |
-
"norm_groups": norm_groups,
|
348 |
-
"dconv_kw": {"depth": dconv_depth, "compress": dconv_comp, "init": dconv_init, "gelu": True},
|
349 |
-
}
|
350 |
-
|
351 |
-
kwt = dict(kw)
|
352 |
-
kwt["freq"] = 0
|
353 |
-
kwt["kernel_size"] = kernel_size
|
354 |
-
kwt["stride"] = stride
|
355 |
-
kwt["pad"] = True
|
356 |
-
kw_dec = dict(kw)
|
357 |
-
multi = False
|
358 |
-
|
359 |
-
if multi_freqs and index < multi_freqs_depth:
|
360 |
-
multi = True
|
361 |
-
kw_dec["context_freq"] = False
|
362 |
-
|
363 |
-
if last_freq:
|
364 |
-
chout_z = max(chout, chout_z)
|
365 |
-
chout = chout_z
|
366 |
-
|
367 |
-
enc = HEncLayer(chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw)
|
368 |
-
if freq:
|
369 |
-
tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc, empty=last_freq, **kwt)
|
370 |
-
self.tencoder.append(tenc)
|
371 |
-
|
372 |
-
if multi: enc = MultiWrap(enc, multi_freqs)
|
373 |
-
|
374 |
-
self.encoder.append(enc)
|
375 |
-
if index == 0:
|
376 |
-
chin = self.audio_channels * len(self.sources)
|
377 |
-
chin_z = chin
|
378 |
-
if self.cac: chin_z *= 2
|
379 |
-
|
380 |
-
dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2, last=index == 0, context=context, **kw_dec)
|
381 |
-
if multi: dec = MultiWrap(dec, multi_freqs)
|
382 |
-
|
383 |
-
if freq:
|
384 |
-
tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq, last=index == 0, context=context, **kwt)
|
385 |
-
self.tdecoder.insert(0, tdec)
|
386 |
-
|
387 |
-
self.decoder.insert(0, dec)
|
388 |
-
chin = chout
|
389 |
-
chin_z = chout_z
|
390 |
-
chout = int(growth * chout)
|
391 |
-
chout_z = int(growth * chout_z)
|
392 |
-
|
393 |
-
if freq:
|
394 |
-
if freqs <= kernel_size: freqs = 1
|
395 |
-
else: freqs //= stride
|
396 |
-
|
397 |
-
if index == 0 and freq_emb:
|
398 |
-
self.freq_emb = ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
|
399 |
-
self.freq_emb_scale = freq_emb
|
400 |
-
|
401 |
-
if rescale: rescale_module(self, reference=rescale)
|
402 |
-
transformer_channels = channels * growth ** (depth - 1)
|
403 |
-
|
404 |
-
if bottom_channels:
|
405 |
-
self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
|
406 |
-
self.channel_downsampler = nn.Conv1d(bottom_channels, transformer_channels, 1)
|
407 |
-
self.channel_upsampler_t = nn.Conv1d(transformer_channels, bottom_channels, 1)
|
408 |
-
self.channel_downsampler_t = nn.Conv1d(bottom_channels, transformer_channels, 1)
|
409 |
-
transformer_channels = bottom_channels
|
410 |
-
|
411 |
-
if t_layers > 0: self.crosstransformer = CrossTransformerEncoder(dim=transformer_channels, emb=t_emb, hidden_scale=t_hidden_scale, num_heads=t_heads, num_layers=t_layers, cross_first=t_cross_first, dropout=t_dropout, max_positions=t_max_positions, norm_in=t_norm_in, norm_in_group=t_norm_in_group, group_norm=t_group_norm, norm_first=t_norm_first, norm_out=t_norm_out, max_period=t_max_period, weight_decay=t_weight_decay, lr=t_lr, layer_scale=t_layer_scale, gelu=t_gelu, sin_random_shift=t_sin_random_shift, weight_pos_embed=t_weight_pos_embed, cape_mean_normalize=t_cape_mean_normalize, cape_augment=t_cape_augment, cape_glob_loc_scale=t_cape_glob_loc_scale, sparse_self_attn=t_sparse_self_attn, sparse_cross_attn=t_sparse_cross_attn, mask_type=t_mask_type, mask_random_seed=t_mask_random_seed, sparse_attn_window=t_sparse_attn_window, global_window=t_global_window, sparsity=t_sparsity, auto_sparsity=t_auto_sparsity)
|
412 |
-
else: self.crosstransformer = None
|
413 |
-
|
414 |
-
def _spec(self, x):
|
415 |
-
hl = self.hop_length
|
416 |
-
nfft = self.nfft
|
417 |
-
assert hl == nfft // 4
|
418 |
-
le = int(math.ceil(x.shape[-1] / hl))
|
419 |
-
pad = hl // 2 * 3
|
420 |
-
x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")
|
421 |
-
z = spectro(x, nfft, hl)[..., :-1, :]
|
422 |
-
assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
|
423 |
-
z = z[..., 2 : 2 + le]
|
424 |
-
return z
|
425 |
-
|
426 |
-
def _ispec(self, z, length=None, scale=0):
|
427 |
-
hl = self.hop_length // (4**scale)
|
428 |
-
z = F.pad(z, (0, 0, 0, 1))
|
429 |
-
z = F.pad(z, (2, 2))
|
430 |
-
pad = hl // 2 * 3
|
431 |
-
le = hl * int(math.ceil(length / hl)) + 2 * pad
|
432 |
-
x = ispectro(z, hl, length=le)
|
433 |
-
x = x[..., pad : pad + length]
|
434 |
-
return x
|
435 |
-
|
436 |
-
def _magnitude(self, z):
|
437 |
-
if self.cac:
|
438 |
-
B, C, Fr, T = z.shape
|
439 |
-
m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
|
440 |
-
m = m.reshape(B, C * 2, Fr, T)
|
441 |
-
else: m = z.abs()
|
442 |
-
return m
|
443 |
-
|
444 |
-
def _mask(self, z, m):
|
445 |
-
niters = self.wiener_iters
|
446 |
-
if self.cac:
|
447 |
-
B, S, C, Fr, T = m.shape
|
448 |
-
out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
|
449 |
-
out = torch.view_as_complex(out.contiguous())
|
450 |
-
return out
|
451 |
-
|
452 |
-
if self.training: niters = self.end_iters
|
453 |
-
|
454 |
-
if niters < 0:
|
455 |
-
z = z[:, None]
|
456 |
-
return z / (1e-8 + z.abs()) * m
|
457 |
-
else: return self._wiener(m, z, niters)
|
458 |
-
|
459 |
-
def _wiener(self, mag_out, mix_stft, niters):
|
460 |
-
init = mix_stft.dtype
|
461 |
-
wiener_win_len = 300
|
462 |
-
residual = self.wiener_residual
|
463 |
-
B, S, C, Fq, T = mag_out.shape
|
464 |
-
mag_out = mag_out.permute(0, 4, 3, 2, 1)
|
465 |
-
mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
|
466 |
-
|
467 |
-
outs = []
|
468 |
-
|
469 |
-
for sample in range(B):
|
470 |
-
pos = 0
|
471 |
-
out = []
|
472 |
-
|
473 |
-
for pos in range(0, T, wiener_win_len):
|
474 |
-
frame = slice(pos, pos + wiener_win_len)
|
475 |
-
z_out = wiener(mag_out[sample, frame], mix_stft[sample, frame], niters, residual=residual)
|
476 |
-
out.append(z_out.transpose(-1, -2))
|
477 |
-
|
478 |
-
outs.append(torch.cat(out, dim=0))
|
479 |
-
|
480 |
-
out = torch.view_as_complex(torch.stack(outs, 0))
|
481 |
-
out = out.permute(0, 4, 3, 2, 1).contiguous()
|
482 |
-
|
483 |
-
if residual: out = out[:, :-1]
|
484 |
-
assert list(out.shape) == [B, S, C, Fq, T]
|
485 |
-
return out.to(init)
|
486 |
-
|
487 |
-
def valid_length(self, length):
|
488 |
-
if not self.use_train_segment: return length
|
489 |
-
|
490 |
-
training_length = int(self.segment * self.samplerate)
|
491 |
-
if training_length < length: raise ValueError(translations["length_or_training_length"].format(length=length, training_length=training_length))
|
492 |
-
|
493 |
-
return training_length
|
494 |
-
|
495 |
-
def forward(self, mix):
|
496 |
-
length = mix.shape[-1]
|
497 |
-
length_pre_pad = None
|
498 |
-
|
499 |
-
if self.use_train_segment:
|
500 |
-
if self.training: self.segment = Fraction(mix.shape[-1], self.samplerate)
|
501 |
-
else:
|
502 |
-
training_length = int(self.segment * self.samplerate)
|
503 |
-
|
504 |
-
-        if mix.shape[-1] < training_length:
-            length_pre_pad = mix.shape[-1]
-            mix = F.pad(mix, (0, training_length - length_pre_pad))
-
-        z = self._spec(mix)
-        mag = self._magnitude(z).to(mix.device)
-        x = mag
-        B, C, Fq, T = x.shape
-        mean = x.mean(dim=(1, 2, 3), keepdim=True)
-        std = x.std(dim=(1, 2, 3), keepdim=True)
-        x = (x - mean) / (1e-5 + std)
-        xt = mix
-        meant = xt.mean(dim=(1, 2), keepdim=True)
-        stdt = xt.std(dim=(1, 2), keepdim=True)
-        xt = (xt - meant) / (1e-5 + stdt)
-
-        saved, saved_t, lengths, lengths_t = [], [], [], []
-
-        for idx, encode in enumerate(self.encoder):
-            lengths.append(x.shape[-1])
-            inject = None
-
-            if idx < len(self.tencoder):
-                lengths_t.append(xt.shape[-1])
-                tenc = self.tencoder[idx]
-                xt = tenc(xt)
-
-                if not tenc.empty: saved_t.append(xt)
-                else: inject = xt
-
-            x = encode(x, inject)
-            if idx == 0 and self.freq_emb is not None:
-                frs = torch.arange(x.shape[-2], device=x.device)
-                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
-                x = x + self.freq_emb_scale * emb
-
-            saved.append(x)
-
-        if self.crosstransformer:
-            if self.bottom_channels:
-                b, c, f, t = x.shape
-                x = rearrange(x, "b c f t-> b c (f t)")
-                x = self.channel_upsampler(x)
-                x = rearrange(x, "b c (f t)-> b c f t", f=f)
-                xt = self.channel_upsampler_t(xt)
-
-            x, xt = self.crosstransformer(x, xt)
-
-            if self.bottom_channels:
-                x = rearrange(x, "b c f t-> b c (f t)")
-                x = self.channel_downsampler(x)
-                x = rearrange(x, "b c (f t)-> b c f t", f=f)
-                xt = self.channel_downsampler_t(xt)
-
-        for idx, decode in enumerate(self.decoder):
-            skip = saved.pop(-1)
-            x, pre = decode(x, skip, lengths.pop(-1))
-            offset = self.depth - len(self.tdecoder)
-
-            if idx >= offset:
-                tdec = self.tdecoder[idx - offset]
-                length_t = lengths_t.pop(-1)
-
-                if tdec.empty:
-                    assert pre.shape[2] == 1, pre.shape
-                    pre = pre[:, :, 0]
-                    xt, _ = tdec(pre, None, length_t)
-                else:
-                    skip = saved_t.pop(-1)
-                    xt, _ = tdec(xt, skip, length_t)
-
-        assert len(saved) == 0
-        assert len(lengths_t) == 0
-        assert len(saved_t) == 0
-
-        S = len(self.sources)
-        x = x.view(B, S, -1, Fq, T)
-        x = x * std[:, None] + mean[:, None]
-        device_type = x.device.type
-        device_load = f"{device_type}:{x.device.index}" if not device_type == "mps" else device_type
-        x_is_other_gpu = not device_type in ["cuda", "cpu"]
-        if x_is_other_gpu: x = x.cpu()
-        zout = self._mask(z, x)
-
-        if self.use_train_segment: x = self._ispec(zout, length) if self.training else self._ispec(zout, training_length)
-        else: x = self._ispec(zout, length)
-
-        if x_is_other_gpu: x = x.to(device_load)
-
-        if self.use_train_segment: xt = xt.view(B, S, -1, length) if self.training else xt.view(B, S, -1, training_length)
-        else: xt = xt.view(B, S, -1, length)
-
-        xt = xt * stdt[:, None] + meant[:, None]
-        x = xt + x
-
-        if length_pre_pad: x = x[..., :length_pre_pad]
-        return x
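The padding bookkeeping above is the usual fixed-segment trick: a mix shorter than the training segment is zero-padded before the STFT and the separated output is trimmed back afterwards. A minimal standalone sketch of that pattern (the `model` and `run_fixed_segment` names are illustrative stand-ins, not part of the deleted module):

import torch
import torch.nn.functional as F

def run_fixed_segment(model, mix, training_length):
    # Zero-pad a short mix up to the segment length the model expects.
    length_pre_pad = None
    if mix.shape[-1] < training_length:
        length_pre_pad = mix.shape[-1]
        mix = F.pad(mix, (0, training_length - length_pre_pad))
    out = model(mix)
    # Trim the separated sources back to the original length.
    if length_pre_pad:
        out = out[..., :length_pre_pad]
    return out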
main/library/uvr5_separator/demucs/states.py
DELETED
@@ -1,55 +0,0 @@
-import os
-import sys
-import torch
-import inspect
-import warnings
-import functools
-
-from pathlib import Path
-
-sys.path.append(os.getcwd())
-
-from main.configs.config import Config
-translations = Config().translations
-
-def load_model(path_or_package, strict=False):
-    if isinstance(path_or_package, dict): package = path_or_package
-    elif isinstance(path_or_package, (str, Path)):
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            package = torch.load(path_or_package, map_location="cpu")
-    else: raise ValueError(f"{translations['type_not_valid']} {path_or_package}.")
-    klass = package["klass"]
-    args = package["args"]
-    kwargs = package["kwargs"]
-    if strict: model = klass(*args, **kwargs)
-    else:
-        sig = inspect.signature(klass)
-        for key in list(kwargs):
-            if key not in sig.parameters:
-                warnings.warn(translations["del_parameter"] + key)
-                del kwargs[key]
-        model = klass(*args, **kwargs)
-    state = package["state"]
-    set_state(model, state)
-    return model
-
-def restore_quantized_state(model, state):
-    assert "meta" in state
-    quantizer = state["meta"]["klass"](model, **state["meta"]["init_kwargs"])
-    quantizer.restore_quantized_state(state)
-    quantizer.detach()
-
-def set_state(model, state, quantizer=None):
-    if state.get("__quantized"):
-        if quantizer is not None: quantizer.restore_quantized_state(model, state["quantized"])
-        else: restore_quantized_state(model, state)
-    else: model.load_state_dict(state)
-    return state
-
-def capture_init(init):
-    @functools.wraps(init)
-    def __init__(self, *args, **kwargs):
-        self._init_args_kwargs = (args, kwargs)
-        init(self, *args, **kwargs)
-    return __init__
main/library/uvr5_separator/demucs/utils.py
DELETED
@@ -1,8 +0,0 @@
-import torch
-
-def center_trim(tensor, reference):
-    ref_size = reference.size(-1) if isinstance(reference, torch.Tensor) else reference
-    delta = tensor.size(-1) - ref_size
-    if delta < 0: raise ValueError(f"tensor > parameter: {delta}.")
-    if delta: tensor = tensor[..., delta // 2 : -(delta - delta // 2)]
-    return tensor
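`center_trim` crops a tensor symmetrically on its last axis so decoder outputs and skip connections line up; the reference may be another tensor or a plain length. For example:

import torch

x = torch.zeros(1, 2, 1000)
ref = torch.zeros(1, 2, 994)
print(center_trim(x, ref).shape)   # torch.Size([1, 2, 994]) -- 3 samples cut from each side
print(center_trim(x, 995).shape)   # torch.Size([1, 2, 995]) -- an int reference also works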
main/library/uvr5_separator/spec_utils.py
DELETED
@@ -1,900 +0,0 @@
-import os
-import six
-import sys
-import librosa
-import tempfile
-import platform
-import subprocess
-
-import numpy as np
-import soundfile as sf
-
-from scipy.signal import correlate, hilbert
-
-sys.path.append(os.getcwd())
-
-from main.configs.config import Config
-translations = Config().translations
-
-OPERATING_SYSTEM = platform.system()
-SYSTEM_ARCH = platform.platform()
-SYSTEM_PROC = platform.processor()
-ARM = "arm"
-AUTO_PHASE = "Automatic"
-POSITIVE_PHASE = "Positive Phase"
-NEGATIVE_PHASE = "Negative Phase"
-NONE_P = ("None",)
-LOW_P = ("Shifts: Low",)
-MED_P = ("Shifts: Medium",)
-HIGH_P = ("Shifts: High",)
-VHIGH_P = "Shifts: Very High"
-MAXIMUM_P = "Shifts: Maximum"
-BASE_PATH_RUB = sys._MEIPASS if getattr(sys, 'frozen', False) else os.path.dirname(os.path.abspath(__file__))
-DEVNULL = open(os.devnull, 'w') if six.PY2 else subprocess.DEVNULL
-MAX_SPEC = "Max Spec"
-MIN_SPEC = "Min Spec"
-LIN_ENSE = "Linear Ensemble"
-MAX_WAV = MAX_SPEC
-MIN_WAV = MIN_SPEC
-AVERAGE = "Average"
-
-progress_value, last_update_time = 0, 0
-wav_resolution = "sinc_fastest"
-wav_resolution_float_resampling = wav_resolution
-
-def crop_center(h1, h2):
-    h1_shape = h1.size()
-    h2_shape = h2.size()
-
-    if h1_shape[3] == h2_shape[3]: return h1
-    elif h1_shape[3] < h2_shape[3]: raise ValueError("h1_shape[3] > h2_shape[3]")
-
-    s_time = (h1_shape[3] - h2_shape[3]) // 2
-
-    h1 = h1[:, :, :, s_time:s_time + h2_shape[3]]
-    return h1
-
-def preprocess(X_spec):
-    return np.abs(X_spec), np.angle(X_spec)
-
-def make_padding(width, cropsize, offset):
-    roi_size = cropsize - offset * 2
-
-    if roi_size == 0: roi_size = cropsize
-    return offset, roi_size - (width % roi_size) + offset, roi_size
-
-def normalize(wave, max_peak=1.0):
-    maxv = np.abs(wave).max()
-
-    if maxv > max_peak: wave *= max_peak / maxv
-    return wave
-
-def auto_transpose(audio_array):
-    if audio_array.shape[1] == 2: return audio_array.T
-    return audio_array
-
-def write_array_to_mem(audio_data, subtype):
-    if isinstance(audio_data, np.ndarray):
-        import io
-
-        audio_buffer = io.BytesIO()
-        sf.write(audio_buffer, audio_data, 44100, subtype=subtype, format="WAV")
-
-        audio_buffer.seek(0)
-        return audio_buffer
-    else: return audio_data
-
-def spectrogram_to_image(spec, mode="magnitude"):
-    if mode == "magnitude": y = np.log10((np.abs(spec) if np.iscomplexobj(spec) else spec)**2 + 1e-8)
-    elif mode == "phase": y = np.angle(spec) if np.iscomplexobj(spec) else spec
-
-    y -= y.min()
-    y *= 255 / y.max()
-    img = np.uint8(y)
-
-    if y.ndim == 3:
-        img = img.transpose(1, 2, 0)
-        img = np.concatenate([np.max(img, axis=2, keepdims=True), img], axis=2)
-
-    return img
-
-def reduce_vocal_aggressively(X, y, softmask):
-    y_mag_tmp = np.abs(y)
-    v_mag_tmp = np.abs(X - y)
-
-    return np.clip(y_mag_tmp - v_mag_tmp * (v_mag_tmp > y_mag_tmp) * softmask, 0, np.inf) * np.exp(1.0j * np.angle(y))
-
-def merge_artifacts(y_mask, thres=0.01, min_range=64, fade_size=32):
-    mask = y_mask
-
-    try:
-        if min_range < fade_size * 2: raise ValueError("min_range >= fade_size * 2")
-
-        idx = np.where(y_mask.min(axis=(0, 1)) > thres)[0]
-        start_idx = np.insert(idx[np.where(np.diff(idx) != 1)[0] + 1], 0, idx[0])
-        end_idx = np.append(idx[np.where(np.diff(idx) != 1)[0]], idx[-1])
-        artifact_idx = np.where(end_idx - start_idx > min_range)[0]
-        weight = np.zeros_like(y_mask)
-
-        if len(artifact_idx) > 0:
-            start_idx = start_idx[artifact_idx]
-            end_idx = end_idx[artifact_idx]
-            old_e = None
-
-            for s, e in zip(start_idx, end_idx):
-                if old_e is not None and s - old_e < fade_size: s = old_e - fade_size * 2
-
-                if s != 0: weight[:, :, s : s + fade_size] = np.linspace(0, 1, fade_size)
-                else: s -= fade_size
-
-                if e != y_mask.shape[2]: weight[:, :, e - fade_size : e] = np.linspace(1, 0, fade_size)
-                else: e += fade_size
-
-                weight[:, :, s + fade_size : e - fade_size] = 1
-                old_e = e
-
-            v_mask = 1 - y_mask
-            y_mask += weight * v_mask
-            mask = y_mask
-    except Exception as e:
-        import traceback
-        print(translations["not_success"], f'{type(e).__name__}: "{e}"\n{traceback.format_exc()}"')
-
-    return mask
-
-def align_wave_head_and_tail(a, b):
-    l = min([a[0].size, b[0].size])
-    return a[:l, :l], b[:l, :l]
-
-def convert_channels(spec, mp, band):
-    cc = mp.param["band"][band].get("convert_channels")
-
-    if "mid_side_c" == cc:
-        spec_left = np.add(spec[0], spec[1] * 0.25)
-        spec_right = np.subtract(spec[1], spec[0] * 0.25)
-    elif "mid_side" == cc:
-        spec_left = np.add(spec[0], spec[1]) / 2
-        spec_right = np.subtract(spec[0], spec[1])
-    elif "stereo_n" == cc:
-        spec_left = np.add(spec[0], spec[1] * 0.25) / 0.9375
-        spec_right = np.add(spec[1], spec[0] * 0.25) / 0.9375
-    else: return spec
-
-    return np.asfortranarray([spec_left, spec_right])
-
-def combine_spectrograms(specs, mp, is_v51_model=False):
-    l = min([specs[i].shape[2] for i in specs])
-    spec_c = np.zeros(shape=(2, mp.param["bins"] + 1, l), dtype=np.complex64)
-    offset = 0
-    bands_n = len(mp.param["band"])
-
-    for d in range(1, bands_n + 1):
-        h = mp.param["band"][d]["crop_stop"] - mp.param["band"][d]["crop_start"]
-        spec_c[:, offset : offset + h, :l] = specs[d][:, mp.param["band"][d]["crop_start"] : mp.param["band"][d]["crop_stop"], :l]
-        offset += h
-
-    if offset > mp.param["bins"]: raise ValueError("offset > mp.param['bins']")
-
-    if mp.param["pre_filter_start"] > 0:
-        if is_v51_model: spec_c *= get_lp_filter_mask(spec_c.shape[1], mp.param["pre_filter_start"], mp.param["pre_filter_stop"])
-        else:
-            if bands_n == 1: spec_c = fft_lp_filter(spec_c, mp.param["pre_filter_start"], mp.param["pre_filter_stop"])
-            else:
-                import math
-                gp = 1
-
-                for b in range(mp.param["pre_filter_start"] + 1, mp.param["pre_filter_stop"]):
-                    g = math.pow(10, -(b - mp.param["pre_filter_start"]) * (3.5 - gp) / 20.0)
-                    gp = g
-                    spec_c[:, b, :] *= g
-
-    return np.asfortranarray(spec_c)
-
-def wave_to_spectrogram(wave, hop_length, n_fft, mp, band, is_v51_model=False):
-    if wave.ndim == 1: wave = np.asfortranarray([wave, wave])
-
-    if not is_v51_model:
-        if mp.param["reverse"]:
-            wave_left = np.flip(np.asfortranarray(wave[0]))
-            wave_right = np.flip(np.asfortranarray(wave[1]))
-        elif mp.param["mid_side"]:
-            wave_left = np.asfortranarray(np.add(wave[0], wave[1]) / 2)
-            wave_right = np.asfortranarray(np.subtract(wave[0], wave[1]))
-        elif mp.param["mid_side_b2"]:
-            wave_left = np.asfortranarray(np.add(wave[1], wave[0] * 0.5))
-            wave_right = np.asfortranarray(np.subtract(wave[0], wave[1] * 0.5))
-        else:
-            wave_left = np.asfortranarray(wave[0])
-            wave_right = np.asfortranarray(wave[1])
-    else:
-        wave_left = np.asfortranarray(wave[0])
-        wave_right = np.asfortranarray(wave[1])
-
-    spec_left = librosa.stft(wave_left, n_fft=n_fft, hop_length=hop_length)
-    spec_right = librosa.stft(wave_right, n_fft=n_fft, hop_length=hop_length)
-
-    spec = np.asfortranarray([spec_left, spec_right])
-
-    if is_v51_model: spec = convert_channels(spec, mp, band)
-    return spec
-
-def spectrogram_to_wave(spec, hop_length=1024, mp={}, band=0, is_v51_model=True):
-    spec_left = np.asfortranarray(spec[0])
-    spec_right = np.asfortranarray(spec[1])
-
-    wave_left = librosa.istft(spec_left, hop_length=hop_length)
-    wave_right = librosa.istft(spec_right, hop_length=hop_length)
-
-    if is_v51_model:
-        cc = mp.param["band"][band].get("convert_channels")
-
-        if "mid_side_c" == cc: return np.asfortranarray([np.subtract(wave_left / 1.0625, wave_right / 4.25), np.add(wave_right / 1.0625, wave_left / 4.25)])
-        elif "mid_side" == cc: return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)])
-        elif "stereo_n" == cc: return np.asfortranarray([np.subtract(wave_left, wave_right * 0.25), np.subtract(wave_right, wave_left * 0.25)])
-    else:
-        if mp.param["reverse"]: return np.asfortranarray([np.flip(wave_left), np.flip(wave_right)])
-        elif mp.param["mid_side"]: return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)])
-        elif mp.param["mid_side_b2"]: return np.asfortranarray([np.add(wave_right / 1.25, 0.4 * wave_left), np.subtract(wave_left / 1.25, 0.4 * wave_right)])
-
-    return np.asfortranarray([wave_left, wave_right])
-
-def cmb_spectrogram_to_wave(spec_m, mp, extra_bins_h=None, extra_bins=None, is_v51_model=False):
-    bands_n = len(mp.param["band"])
-    offset = 0
-
-    for d in range(1, bands_n + 1):
-        bp = mp.param["band"][d]
-        spec_s = np.zeros(shape=(2, bp["n_fft"] // 2 + 1, spec_m.shape[2]), dtype=complex)
-        h = bp["crop_stop"] - bp["crop_start"]
-        spec_s[:, bp["crop_start"] : bp["crop_stop"], :] = spec_m[:, offset : offset + h, :]
-        offset += h
-
-        if d == bands_n:
-            if extra_bins_h:
-                max_bin = bp["n_fft"] // 2
-                spec_s[:, max_bin - extra_bins_h : max_bin, :] = extra_bins[:, :extra_bins_h, :]
-
-            if bp["hpf_start"] > 0:
-                if is_v51_model: spec_s *= get_hp_filter_mask(spec_s.shape[1], bp["hpf_start"], bp["hpf_stop"] - 1)
-                else: spec_s = fft_hp_filter(spec_s, bp["hpf_start"], bp["hpf_stop"] - 1)
-
-            wave = spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model) if bands_n == 1 else np.add(wave, spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model))
-        else:
-            sr = mp.param["band"][d + 1]["sr"]
-            if d == 1:
-                if is_v51_model: spec_s *= get_lp_filter_mask(spec_s.shape[1], bp["lpf_start"], bp["lpf_stop"])
-                else: spec_s = fft_lp_filter(spec_s, bp["lpf_start"], bp["lpf_stop"])
-
-                try:
-                    wave = librosa.resample(spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model), orig_sr=bp["sr"], target_sr=sr, res_type=wav_resolution)
-                except ValueError as e:
-                    print(f"{translations['resample_error']}: {e}")
-                    print(f"{translations['shapes']} Spec_s: {spec_s.shape}, SR: {sr}, {translations['wav_resolution']}: {wav_resolution}")
-            else:
-                if is_v51_model:
-                    spec_s *= get_hp_filter_mask(spec_s.shape[1], bp["hpf_start"], bp["hpf_stop"] - 1)
-                    spec_s *= get_lp_filter_mask(spec_s.shape[1], bp["lpf_start"], bp["lpf_stop"])
-                else:
-                    spec_s = fft_hp_filter(spec_s, bp["hpf_start"], bp["hpf_stop"] - 1)
-                    spec_s = fft_lp_filter(spec_s, bp["lpf_start"], bp["lpf_stop"])
-
-                try:
-                    wave = librosa.resample(np.add(wave, spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model)), orig_sr=bp["sr"], target_sr=sr, res_type=wav_resolution)
-                except ValueError as e:
-                    print(f"{translations['resample_error']}: {e}")
-                    print(f"{translations['shapes']} Spec_s: {spec_s.shape}, SR: {sr}, {translations['wav_resolution']}: {wav_resolution}")
-
-    return wave
-
-def get_lp_filter_mask(n_bins, bin_start, bin_stop):
-    return np.concatenate([np.ones((bin_start - 1, 1)), np.linspace(1, 0, bin_stop - bin_start + 1)[:, None], np.zeros((n_bins - bin_stop, 1))], axis=0)
-
-def get_hp_filter_mask(n_bins, bin_start, bin_stop):
-    return np.concatenate([np.zeros((bin_stop + 1, 1)), np.linspace(0, 1, 1 + bin_start - bin_stop)[:, None], np.ones((n_bins - bin_start - 2, 1))], axis=0)
-
-def fft_lp_filter(spec, bin_start, bin_stop):
-    g = 1.0
-
-    for b in range(bin_start, bin_stop):
-        g -= 1 / (bin_stop - bin_start)
-        spec[:, b, :] = g * spec[:, b, :]
-
-    spec[:, bin_stop:, :] *= 0
-    return spec
-
-def fft_hp_filter(spec, bin_start, bin_stop):
-    g = 1.0
-
-    for b in range(bin_start, bin_stop, -1):
-        g -= 1 / (bin_start - bin_stop)
-        spec[:, b, :] = g * spec[:, b, :]
-
-    spec[:, 0 : bin_stop + 1, :] *= 0
-    return spec
-
-def spectrogram_to_wave_old(spec, hop_length=1024):
-    if spec.ndim == 2: wave = librosa.istft(spec, hop_length=hop_length)
-    elif spec.ndim == 3: wave = np.asfortranarray([librosa.istft(np.asfortranarray(spec[0]), hop_length=hop_length), librosa.istft(np.asfortranarray(spec[1]), hop_length=hop_length)])
-
-    return wave
-
-def wave_to_spectrogram_old(wave, hop_length, n_fft):
-    return np.asfortranarray([librosa.stft(np.asfortranarray(wave[0]), n_fft=n_fft, hop_length=hop_length), librosa.stft(np.asfortranarray(wave[1]), n_fft=n_fft, hop_length=hop_length)])
-
-def mirroring(a, spec_m, input_high_end, mp):
-    if "mirroring" == a:
-        mirror = np.flip(np.abs(spec_m[:, mp.param["pre_filter_start"] - 10 - input_high_end.shape[1] : mp.param["pre_filter_start"] - 10, :]), 1) * np.exp(1.0j * np.angle(input_high_end))
-
-        return np.where(np.abs(input_high_end) <= np.abs(mirror), input_high_end, mirror)
-
-    if "mirroring2" == a:
-        mi = np.multiply(np.flip(np.abs(spec_m[:, mp.param["pre_filter_start"] - 10 - input_high_end.shape[1] : mp.param["pre_filter_start"] - 10, :]), 1), input_high_end * 1.7)
-
-        return np.where(np.abs(input_high_end) <= np.abs(mi), input_high_end, mi)
-
-def adjust_aggr(mask, is_non_accom_stem, aggressiveness):
-    aggr = aggressiveness["value"] * 2
-
-    if aggr != 0:
-        if is_non_accom_stem:
-            aggr = 1 - aggr
-
-        if np.any(aggr > 10) or np.any(aggr < -10): print(f"{translations['warnings']}: {aggr}")
-
-        aggr = [aggr, aggr]
-
-        if aggressiveness["aggr_correction"] is not None:
-            aggr[0] += aggressiveness["aggr_correction"]["left"]
-            aggr[1] += aggressiveness["aggr_correction"]["right"]
-
-        for ch in range(2):
-            mask[ch, : aggressiveness["split_bin"]] = np.power(mask[ch, : aggressiveness["split_bin"]], 1 + aggr[ch] / 3)
-            mask[ch, aggressiveness["split_bin"] :] = np.power(mask[ch, aggressiveness["split_bin"] :], 1 + aggr[ch])
-
-    return mask
-
-def stft(wave, nfft, hl):
-    return np.asfortranarray([librosa.stft(np.asfortranarray(wave[0]), n_fft=nfft, hop_length=hl), librosa.stft(np.asfortranarray(wave[1]), n_fft=nfft, hop_length=hl)])
-
-def istft(spec, hl):
-    return np.asfortranarray([librosa.istft(np.asfortranarray(spec[0]), hop_length=hl), librosa.istft(np.asfortranarray(spec[1]), hop_length=hl)])
-
-def spec_effects(wave, algorithm="Default", value=None):
-    if np.isnan(wave).any() or np.isinf(wave).any(): print(f"{translations['warnings_2']}: {wave.shape}")
-    spec = [stft(wave[0], 2048, 1024), stft(wave[1], 2048, 1024)]
-
-    if algorithm == "Min_Mag": wave = istft(np.where(np.abs(spec[1]) <= np.abs(spec[0]), spec[1], spec[0]), 1024)
-    elif algorithm == "Max_Mag": wave = istft(np.where(np.abs(spec[1]) >= np.abs(spec[0]), spec[1], spec[0]), 1024)
-    elif algorithm == "Default": wave = (wave[1] * value) + (wave[0] * (1 - value))
-    elif algorithm == "Invert_p":
-        X_mag, y_mag = np.abs(spec[0]), np.abs(spec[1])
-        wave = istft(spec[1] - np.where(X_mag >= y_mag, X_mag, y_mag) * np.exp(1.0j * np.angle(spec[0])), 1024)
-
-    return wave
-
-def spectrogram_to_wave_no_mp(spec, n_fft=2048, hop_length=1024):
-    wave = librosa.istft(spec, n_fft=n_fft, hop_length=hop_length)
-    if wave.ndim == 1: wave = np.asfortranarray([wave, wave])
-
-    return wave
-
-def wave_to_spectrogram_no_mp(wave):
-    spec = librosa.stft(wave, n_fft=2048, hop_length=1024)
-
-    if spec.ndim == 1: spec = np.asfortranarray([spec, spec])
-    return spec
-
-def invert_audio(specs, invert_p=True):
-    ln = min([specs[0].shape[2], specs[1].shape[2]])
-    specs[0] = specs[0][:, :, :ln]
-    specs[1] = specs[1][:, :, :ln]
-
-    if invert_p:
-        X_mag, y_mag = np.abs(specs[0]), np.abs(specs[1])
-        v_spec = specs[1] - np.where(X_mag >= y_mag, X_mag, y_mag) * np.exp(1.0j * np.angle(specs[0]))
-    else:
-        specs[1] = reduce_vocal_aggressively(specs[0], specs[1], 0.2)
-        v_spec = specs[0] - specs[1]
-
-    return v_spec
-
-def invert_stem(mixture, stem):
-    return -spectrogram_to_wave_no_mp(invert_audio([wave_to_spectrogram_no_mp(mixture), wave_to_spectrogram_no_mp(stem)])).T
-
-def ensembling(a, inputs, is_wavs=False):
-    for i in range(1, len(inputs)):
-        if i == 1: input = inputs[0]
-
-        if is_wavs:
-            ln = min([input.shape[1], inputs[i].shape[1]])
-            input = input[:, :ln]
-            inputs[i] = inputs[i][:, :ln]
-        else:
-            ln = min([input.shape[2], inputs[i].shape[2]])
-            input = input[:, :, :ln]
-            inputs[i] = inputs[i][:, :, :ln]
-
-        if MIN_SPEC == a: input = np.where(np.abs(inputs[i]) <= np.abs(input), inputs[i], input)
-        if MAX_SPEC == a: input = np.where(np.abs(inputs[i]) >= np.abs(input), inputs[i], input)
-
-    return input
-
-def ensemble_for_align(waves):
-    specs = []
-
-    for wav in waves:
-        spec = wave_to_spectrogram_no_mp(wav.T)
-        specs.append(spec)
-
-    wav_aligned = spectrogram_to_wave_no_mp(ensembling(MIN_SPEC, specs)).T
-    wav_aligned = match_array_shapes(wav_aligned, waves[1], is_swap=True)
-
-    return wav_aligned
-
-def ensemble_inputs(audio_input, algorithm, is_normalization, wav_type_set, save_path, is_wave=False, is_array=False):
-    wavs_ = []
-
-    if algorithm == AVERAGE:
-        output = average_audio(audio_input)
-        samplerate = 44100
-    else:
-        specs = []
-
-        for i in range(len(audio_input)):
-            wave, samplerate = librosa.load(audio_input[i], mono=False, sr=44100)
-            wavs_.append(wave)
-            specs.append(wave if is_wave else wave_to_spectrogram_no_mp(wave))
-
-        wave_shapes = [w.shape[1] for w in wavs_]
-        target_shape = wavs_[wave_shapes.index(max(wave_shapes))]
-
-        output = ensembling(algorithm, specs, is_wavs=True) if is_wave else spectrogram_to_wave_no_mp(ensembling(algorithm, specs))
-        output = to_shape(output, target_shape.shape)
-
-    sf.write(save_path, normalize(output.T, is_normalization), samplerate, subtype=wav_type_set)
-
-def to_shape(x, target_shape):
-    padding_list = []
-
-    for x_dim, target_dim in zip(x.shape, target_shape):
-        padding_list.append((0, target_dim - x_dim))
-
-    return np.pad(x, tuple(padding_list), mode="constant")
-
-def to_shape_minimize(x, target_shape):
-    padding_list = []
-
-    for x_dim, target_dim in zip(x.shape, target_shape):
-        padding_list.append((0, target_dim - x_dim))
-
-    return np.pad(x, tuple(padding_list), mode="constant")
-
-def detect_leading_silence(audio, sr, silence_threshold=0.007, frame_length=1024):
-    if len(audio.shape) == 2:
-        channel = np.argmax(np.sum(np.abs(audio), axis=1))
-        audio = audio[channel]
-
-    for i in range(0, len(audio), frame_length):
-        if np.max(np.abs(audio[i : i + frame_length])) > silence_threshold: return (i / sr) * 1000
-
-    return (len(audio) / sr) * 1000
-
-def adjust_leading_silence(target_audio, reference_audio, silence_threshold=0.01, frame_length=1024):
-    def find_silence_end(audio):
-        if len(audio.shape) == 2:
-            channel = np.argmax(np.sum(np.abs(audio), axis=1))
-            audio_mono = audio[channel]
-        else: audio_mono = audio
-
-        for i in range(0, len(audio_mono), frame_length):
-            if np.max(np.abs(audio_mono[i : i + frame_length])) > silence_threshold: return i
-
-        return len(audio_mono)
-
-    ref_silence_end = find_silence_end(reference_audio)
-    target_silence_end = find_silence_end(target_audio)
-    silence_difference = ref_silence_end - target_silence_end
-
-    try:
-        silence_difference_p = ((ref_silence_end / 44100) * 1000) - ((target_silence_end / 44100) * 1000)
-    except Exception as e:
-        pass
-
-    if silence_difference > 0: return np.hstack((np.zeros((target_audio.shape[0], silence_difference)) if len(target_audio.shape) == 2 else np.zeros(silence_difference), target_audio))
-    elif silence_difference < 0: return target_audio[:, -silence_difference:] if len(target_audio.shape) == 2 else target_audio[-silence_difference:]
-    else: return target_audio
-
-def match_array_shapes(array_1, array_2, is_swap=False):
-    if is_swap: array_1, array_2 = array_1.T, array_2.T
-
-    if array_1.shape[1] > array_2.shape[1]: array_1 = array_1[:, : array_2.shape[1]]
-    elif array_1.shape[1] < array_2.shape[1]:
-        padding = array_2.shape[1] - array_1.shape[1]
-        array_1 = np.pad(array_1, ((0, 0), (0, padding)), "constant", constant_values=0)
-
-    if is_swap: array_1, array_2 = array_1.T, array_2.T
-
-    return array_1
-
-def match_mono_array_shapes(array_1, array_2):
-    if len(array_1) > len(array_2): array_1 = array_1[: len(array_2)]
-    elif len(array_1) < len(array_2):
-        padding = len(array_2) - len(array_1)
-        array_1 = np.pad(array_1, (0, padding), "constant", constant_values=0)
-
-    return array_1
-
-def change_pitch_semitones(y, sr, semitone_shift):
-    factor = 2 ** (semitone_shift / 12)
-    y_pitch_tuned = []
-
-    for y_channel in y:
-        y_pitch_tuned.append(librosa.resample(y_channel, orig_sr=sr, target_sr=sr * factor, res_type=wav_resolution_float_resampling))
-
-    y_pitch_tuned = np.array(y_pitch_tuned)
-    new_sr = sr * factor
-
-    return y_pitch_tuned, new_sr
-
-def augment_audio(export_path, audio_file, rate, is_normalization, wav_type_set, save_format=None, is_pitch=False, is_time_correction=True):
-    wav, sr = librosa.load(audio_file, sr=44100, mono=False)
-    if wav.ndim == 1: wav = np.asfortranarray([wav, wav])
-
-    if not is_time_correction: wav_mix = change_pitch_semitones(wav, 44100, semitone_shift=-rate)[0]
-    else:
-        if is_pitch: wav_1, wav_2 = pitch_shift(wav[0], sr, rate, rbargs=None), pitch_shift(wav[1], sr, rate, rbargs=None)
-        else: wav_1, wav_2 = time_stretch(wav[0], sr, rate, rbargs=None), time_stretch(wav[1], sr, rate, rbargs=None)
-
-        if wav_1.shape > wav_2.shape: wav_2 = to_shape(wav_2, wav_1.shape)
-        if wav_1.shape < wav_2.shape: wav_1 = to_shape(wav_1, wav_2.shape)
-
-        wav_mix = np.asfortranarray([wav_1, wav_2])
-
-    sf.write(export_path, normalize(wav_mix.T, is_normalization), sr, subtype=wav_type_set)
-    save_format(export_path)
-
-def average_audio(audio):
-    waves, wave_shapes, final_waves = [], [], []
-
-    for i in range(len(audio)):
-        wave = librosa.load(audio[i], sr=44100, mono=False)
-        waves.append(wave[0])
-        wave_shapes.append(wave[0].shape[1])
-
-    wave_shapes_index = wave_shapes.index(max(wave_shapes))
-    target_shape = waves[wave_shapes_index]
-
-    waves.pop(wave_shapes_index)
-    final_waves.append(target_shape)
-
-    for n_array in waves:
-        wav_target = to_shape(n_array, target_shape.shape)
-        final_waves.append(wav_target)
-
-    waves = sum(final_waves)
-    return waves / len(audio)
-
-def average_dual_sources(wav_1, wav_2, value):
-    if wav_1.shape > wav_2.shape: wav_2 = to_shape(wav_2, wav_1.shape)
-    if wav_1.shape < wav_2.shape: wav_1 = to_shape(wav_1, wav_2.shape)
-
-    return (wav_1 * value) + (wav_2 * (1 - value))
-
-def reshape_sources(wav_1, wav_2):
-    if wav_1.shape > wav_2.shape: wav_2 = to_shape(wav_2, wav_1.shape)
-
-    if wav_1.shape < wav_2.shape:
-        ln = min([wav_1.shape[1], wav_2.shape[1]])
-        wav_2 = wav_2[:, :ln]
-
-    ln = min([wav_1.shape[1], wav_2.shape[1]])
-    wav_1 = wav_1[:, :ln]
-    wav_2 = wav_2[:, :ln]
-
-    return wav_2
-
-def reshape_sources_ref(wav_1_shape, wav_2):
-    if wav_1_shape > wav_2.shape: wav_2 = to_shape(wav_2, wav_1_shape)
-    return wav_2
-
-def combine_arrarys(audio_sources, is_swap=False):
-    source = np.zeros_like(max(audio_sources, key=np.size))
-
-    for v in audio_sources:
-        v = match_array_shapes(v, source, is_swap=is_swap)
-        source += v
-
-    return source
-
-def combine_audio(paths, audio_file_base=None, wav_type_set="FLOAT", save_format=None):
-    source = combine_arrarys([load_audio(i) for i in paths])
-    save_path = f"{audio_file_base}_combined.wav"
-    sf.write(save_path, source.T, 44100, subtype=wav_type_set)
-    save_format(save_path)
-
-def reduce_mix_bv(inst_source, voc_source, reduction_rate=0.9):
-    return combine_arrarys([inst_source * (1 - reduction_rate), voc_source], is_swap=True)
-
-def organize_inputs(inputs):
-    input_list = {"target": None, "reference": None, "reverb": None, "inst": None}
-
-    for i in inputs:
-        if i.endswith("_(Vocals).wav"): input_list["reference"] = i
-        elif "_RVC_" in i: input_list["target"] = i
-        elif i.endswith("reverbed_stem.wav"): input_list["reverb"] = i
-        elif i.endswith("_(Instrumental).wav"): input_list["inst"] = i
-
-    return input_list
-
-def check_if_phase_inverted(wav1, wav2, is_mono=False):
-    if not is_mono:
-        wav1 = np.mean(wav1, axis=0)
-        wav2 = np.mean(wav2, axis=0)
-
-    return np.corrcoef(wav1[:1000], wav2[:1000])[0, 1] < 0
-
-def align_audio(file1, file2, file2_aligned, file_subtracted, wav_type_set, is_save_aligned, command_Text, save_format, align_window, align_intro_val, db_analysis, set_progress_bar, phase_option, phase_shifts, is_match_silence, is_spec_match):
-    global progress_value
-    progress_value = 0
-    is_mono = False
-
-    def get_diff(a, b):
-        return np.correlate(a, b, "full").argmax() - (b.shape[0] - 1)
-
-    def progress_bar(length):
-        global progress_value
-        progress_value += 1
-
-        if (0.90 / length * progress_value) >= 0.9: length = progress_value + 1
-        set_progress_bar(0.1, (0.9 / length * progress_value))
-
-    wav1, sr1 = librosa.load(file1, sr=44100, mono=False)
-    wav2, sr2 = librosa.load(file2, sr=44100, mono=False)
-
-    if wav1.ndim == 1 and wav2.ndim == 1: is_mono = True
-    elif wav1.ndim == 1: wav1 = np.asfortranarray([wav1, wav1])
-    elif wav2.ndim == 1: wav2 = np.asfortranarray([wav2, wav2])
-
-    if phase_option == AUTO_PHASE:
-        if check_if_phase_inverted(wav1, wav2, is_mono=is_mono): wav2 = -wav2
-    elif phase_option == POSITIVE_PHASE: wav2 = +wav2
-    elif phase_option == NEGATIVE_PHASE: wav2 = -wav2
-
-    if is_match_silence: wav2 = adjust_leading_silence(wav2, wav1)
-
-    wav1_length = int(librosa.get_duration(y=wav1, sr=44100))
-    wav2_length = int(librosa.get_duration(y=wav2, sr=44100))
-
-    if not is_mono:
-        wav1 = wav1.transpose()
-        wav2 = wav2.transpose()
-
-    wav2_org = wav2.copy()
-
-    command_Text(translations["process_file"])
-    seconds_length = min(wav1_length, wav2_length)
-    wav2_aligned_sources = []
-
-    for sec_len in align_intro_val:
-        sec_seg = 1 if sec_len == 1 else int(seconds_length // sec_len)
-        index = sr1 * sec_seg
-
-        if is_mono:
-            samp1, samp2 = wav1[index : index + sr1], wav2[index : index + sr1]
-            diff = get_diff(samp1, samp2)
-        else:
-            index = sr1 * sec_seg
-            samp1, samp2 = wav1[index : index + sr1, 0], wav2[index : index + sr1, 0]
-            samp1_r, samp2_r = wav1[index : index + sr1, 1], wav2[index : index + sr1, 1]
-            diff, _ = get_diff(samp1, samp2), get_diff(samp1_r, samp2_r)
-
-        if diff > 0: wav2_aligned = np.append(np.zeros(diff) if is_mono else np.zeros((diff, 2)), wav2_org, axis=0)
-        elif diff < 0: wav2_aligned = wav2_org[-diff:]
-        else: wav2_aligned = wav2_org
-
-        if not any(np.array_equal(wav2_aligned, source) for source in wav2_aligned_sources): wav2_aligned_sources.append(wav2_aligned)
-
-    unique_sources = len(wav2_aligned_sources)
-    sub_mapper_big_mapper = {}
-
-    for s in wav2_aligned_sources:
-        wav2_aligned = match_mono_array_shapes(s, wav1) if is_mono else match_array_shapes(s, wav1, is_swap=True)
-
-        if align_window:
-            wav_sub = time_correction(wav1, wav2_aligned, seconds_length, align_window=align_window, db_analysis=db_analysis, progress_bar=progress_bar, unique_sources=unique_sources, phase_shifts=phase_shifts)
-            sub_mapper_big_mapper = {**sub_mapper_big_mapper, **{np.abs(wav_sub).mean(): wav_sub}}
-        else:
-            wav2_aligned = wav2_aligned * np.power(10, db_analysis[0] / 20)
-
-            for db_adjustment in db_analysis[1]:
-                sub_mapper_big_mapper = {**sub_mapper_big_mapper, **{np.abs(wav_sub).mean(): wav1 - (wav2_aligned * (10 ** (db_adjustment / 20)))}}
-
-    wav_sub = ensemble_for_align(list(sub_mapper_big_mapper.values())) if is_spec_match and len(list(sub_mapper_big_mapper.values())) >= 2 else ensemble_wav(list(sub_mapper_big_mapper.values()))
-    wav_sub = np.clip(wav_sub, -1, +1)
-
-    command_Text(translations["save_instruments"])
-
-    if is_save_aligned or is_spec_match:
-        wav1 = match_mono_array_shapes(wav1, wav_sub) if is_mono else match_array_shapes(wav1, wav_sub, is_swap=True)
-        wav2_aligned = wav1 - wav_sub
-
-        if is_spec_match:
-            if wav1.ndim == 1 and wav2.ndim == 1:
-                wav2_aligned = np.asfortranarray([wav2_aligned, wav2_aligned]).T
-                wav1 = np.asfortranarray([wav1, wav1]).T
-
-            wav2_aligned = ensemble_for_align([wav2_aligned, wav1])
-            wav_sub = wav1 - wav2_aligned
-
-        if is_save_aligned:
-            sf.write(file2_aligned, wav2_aligned, sr1, subtype=wav_type_set)
-            save_format(file2_aligned)
-
-    sf.write(file_subtracted, wav_sub, sr1, subtype=wav_type_set)
-    save_format(file_subtracted)
-
-def phase_shift_hilbert(signal, degree):
-    analytic_signal = hilbert(signal)
-    return np.cos(np.radians(degree)) * analytic_signal.real - np.sin(np.radians(degree)) * analytic_signal.imag
-
-def get_phase_shifted_tracks(track, phase_shift):
-    if phase_shift == 180: return [track, -track]
-
-    step = phase_shift
-    end = 180 - (180 % step) if 180 % step == 0 else 181
-    phase_range = range(step, end, step)
-    flipped_list = [track, -track]
-
-    for i in phase_range:
-        flipped_list.extend([phase_shift_hilbert(track, i), phase_shift_hilbert(track, -i)])
-
-    return flipped_list
-
-def time_correction(mix, instrumental, seconds_length, align_window, db_analysis, sr=44100, progress_bar=None, unique_sources=None, phase_shifts=NONE_P):
-    def align_tracks(track1, track2):
-        shifted_tracks = {}
-        track2 = track2 * np.power(10, db_analysis[0] / 20)
-        track2_flipped = [track2] if phase_shifts == 190 else get_phase_shifted_tracks(track2, phase_shifts)
-
-        for db_adjustment in db_analysis[1]:
-            for t in track2_flipped:
-                track2_adjusted = t * (10 ** (db_adjustment / 20))
-                track2_shifted = np.roll(track2_adjusted, shift=np.argmax(np.abs(correlate(track1, track2_adjusted))) - (len(track1) - 1))
-                shifted_tracks[np.abs(track1 - track2_shifted).mean()] = track2_shifted
-
-        return shifted_tracks[min(shifted_tracks.keys())]
-
-    assert mix.shape == instrumental.shape, translations["assert"].format(mixshape=mix.shape, instrumentalshape=instrumental.shape)
-    seconds_length = seconds_length // 2
-
-    sub_mapper = {}
-    progress_update_interval, total_iterations = 120, 0
-
-    if len(align_window) > 2: progress_update_interval = 320
-
-    for secs in align_window:
-        step = secs / 2
-        window_size = int(sr * secs)
-        step_size = int(sr * step)
-
-        if len(mix.shape) == 1: total_iterations += ((len(range(0, len(mix) - window_size, step_size)) // progress_update_interval) * unique_sources)
-        else: total_iterations += ((len(range(0, len(mix[:, 0]) - window_size, step_size)) * 2 // progress_update_interval) * unique_sources)
-
-    for secs in align_window:
-        sub = np.zeros_like(mix)
-        divider = np.zeros_like(mix)
-        window_size = int(sr * secs)
-        step_size = int(sr * secs / 2)
-        window = np.hanning(window_size)
-
-        if len(mix.shape) == 1:
-            counter = 0
-
-            for i in range(0, len(mix) - window_size, step_size):
-                counter += 1
-                if counter % progress_update_interval == 0: progress_bar(total_iterations)
-
-                window_mix = mix[i : i + window_size] * window
-                window_instrumental = instrumental[i : i + window_size] * window
-                window_instrumental_aligned = align_tracks(window_mix, window_instrumental)
-                sub[i : i + window_size] += window_mix - window_instrumental_aligned
-                divider[i : i + window_size] += window
-        else:
-            counter = 0
-
-            for ch in range(mix.shape[1]):
-                for i in range(0, len(mix[:, ch]) - window_size, step_size):
-                    counter += 1
-
-                    if counter % progress_update_interval == 0: progress_bar(total_iterations)
-
-                    window_mix = mix[i : i + window_size, ch] * window
-                    window_instrumental = instrumental[i : i + window_size, ch] * window
-                    window_instrumental_aligned = align_tracks(window_mix, window_instrumental)
-                    sub[i : i + window_size, ch] += window_mix - window_instrumental_aligned
-                    divider[i : i + window_size, ch] += window
-
-    return ensemble_wav(list({**sub_mapper, **{np.abs(sub).mean(): np.where(divider > 1e-6, sub / divider, sub)}}.values()), split_size=12)
-
-def ensemble_wav(waveforms, split_size=240):
-    waveform_thirds = {i: np.array_split(waveform, split_size) for i, waveform in enumerate(waveforms)}
-    final_waveform = []
-    for third_idx in range(split_size):
-        final_waveform.append(waveform_thirds[np.argmin([np.abs(waveform_thirds[i][third_idx]).mean() for i in range(len(waveforms))])][third_idx])
-
-    return np.concatenate(final_waveform)
-
-def ensemble_wav_min(waveforms):
-    for i in range(1, len(waveforms)):
-        if i == 1: wave = waveforms[0]
-        ln = min(len(wave), len(waveforms[i]))
-        wave = wave[:ln]
-        waveforms[i] = waveforms[i][:ln]
-        wave = np.where(np.abs(waveforms[i]) <= np.abs(wave), waveforms[i], wave)
-
-    return wave
-
-def align_audio_test(wav1, wav2, sr1=44100):
-    def get_diff(a, b):
-        return np.correlate(a, b, "full").argmax() - (b.shape[0] - 1)
-
-    wav1 = wav1.transpose()
-    wav2 = wav2.transpose()
-    wav2_org = wav2.copy()
-    index = sr1
-    diff = get_diff(wav1[index : index + sr1, 0], wav2[index : index + sr1, 0])
-
-    if diff > 0: wav2_aligned = np.append(np.zeros((diff, 1)), wav2_org, axis=0)
-    elif diff < 0: wav2_aligned = wav2_org[-diff:]
-    else: wav2_aligned = wav2_org
-    return wav2_aligned
-
-def load_audio(audio_file):
-    wav, _ = librosa.load(audio_file, sr=44100, mono=False)
-    if wav.ndim == 1: wav = np.asfortranarray([wav, wav])
-    return wav
-
-def __rubberband(y, sr, **kwargs):
-    assert sr > 0
-    fd, infile = tempfile.mkstemp(suffix='.wav')
-    os.close(fd)
-    fd, outfile = tempfile.mkstemp(suffix='.wav')
-    os.close(fd)
-
-    sf.write(infile, y, sr)
-
-    try:
-        arguments = [os.path.join(BASE_PATH_RUB, 'rubberband'), '-q']
-        for key, value in six.iteritems(kwargs):
-            arguments.append(str(key))
-            arguments.append(str(value))
-
-        arguments.extend([infile, outfile])
-        subprocess.check_call(arguments, stdout=DEVNULL, stderr=DEVNULL)
-
-        y_out, _ = sf.read(outfile, always_2d=True)
-        if y.ndim == 1: y_out = np.squeeze(y_out)
-    except OSError as exc:
-        six.raise_from(RuntimeError(translations["rubberband"]), exc)
-    finally:
-        os.unlink(infile)
-        os.unlink(outfile)
-
-    return y_out
-
-def time_stretch(y, sr, rate, rbargs=None):
-    if rate <= 0: raise ValueError(translations["rate"])
-    if rate == 1.0: return y
-    if rbargs is None: rbargs = dict()
-
-    rbargs.setdefault('--tempo', rate)
-    return __rubberband(y, sr, **rbargs)
-
-def pitch_shift(y, sr, n_steps, rbargs=None):
-    if n_steps == 0: return y
-    if rbargs is None: rbargs = dict()
-
-    rbargs.setdefault('--pitch', n_steps)
-    return __rubberband(y, sr, **rbargs)
main/tools/edge_tts.py
DELETED
@@ -1,180 +0,0 @@
|
|
1 |
-
import re
|
2 |
-
import ssl
|
3 |
-
import json
|
4 |
-
import time
|
5 |
-
import uuid
|
6 |
-
import codecs
|
7 |
-
import certifi
|
8 |
-
import aiohttp
|
9 |
-
|
10 |
-
from io import TextIOWrapper
|
11 |
-
from dataclasses import dataclass
|
12 |
-
from contextlib import nullcontext
|
13 |
-
from xml.sax.saxutils import escape
|
14 |
-
|
15 |
-
|
16 |
-
@dataclass
|
17 |
-
class TTSConfig:
|
18 |
-
def __init__(self, voice, rate, volume, pitch):
|
19 |
-
self.voice = voice
|
20 |
-
self.rate = rate
|
21 |
-
self.volume = volume
|
22 |
-
self.pitch = pitch
|
23 |
-
|
24 |
-
@staticmethod
|
25 |
-
def validate_string_param(param_name, param_value, pattern):
|
26 |
-
if re.match(pattern, param_value) is None: raise ValueError(f"{param_name} '{param_value}'.")
|
27 |
-
return param_value
|
28 |
-
|
29 |
-
def __post_init__(self):
|
30 |
-
match = re.match(r"^([a-z]{2,})-([A-Z]{2,})-(.+Neural)$", self.voice)
|
31 |
-
if match is not None:
|
32 |
-
region = match.group(2)
|
33 |
-
name = match.group(3)
|
34 |
-
|
35 |
-
if name.find("-") != -1:
|
36 |
-
region = region + "-" + name[: name.find("-")]
|
37 |
-
name = name[name.find("-") + 1 :]
|
38 |
-
|
39 |
-
self.voice = ("Microsoft Server Speech Text to Speech Voice" + f" ({match.group(1)}-{region}, {name})")
|
40 |
-
|
41 |
-
self.validate_string_param("voice", self.voice, r"^Microsoft Server Speech Text to Speech Voice \(.+,.+\)$")
|
42 |
-
self.validate_string_param("rate", self.rate, r"^[+-]\d+%$")
|
43 |
-
self.validate_string_param("volume", self.volume, r"^[+-]\d+%$")
|
44 |
-
self.validate_string_param("pitch", self.pitch, r"^[+-]\d+Hz$")
|
45 |
-
|
46 |
-
def get_headers_and_data(data, header_length):
|
47 |
-
headers = {}
|
48 |
-
|
49 |
-
for line in data[:header_length].split(b"\r\n"):
|
50 |
-
key, value = line.split(b":", 1)
|
51 |
-
headers[key] = value
|
52 |
-
|
53 |
-
return headers, data[header_length + 2 :]
|
54 |
-
|
55 |
-
def date_to_string():
|
56 |
-
return time.strftime("%a %b %d %Y %H:%M:%S GMT+0000 (Coordinated Universal Time)", time.gmtime())
|
57 |
-
|
58 |
-
def mkssml(tc, escaped_text):
|
59 |
-
if isinstance(escaped_text, bytes): escaped_text = escaped_text.decode("utf-8")
|
60 |
-
return (f"<speak version='1.0' xmlns='{codecs.decode('uggc://jjj.j3.bet/2001/10/flagurfvf', 'rot13')}' xml:lang='en-US'>" f"<voice name='{tc.voice}'>" f"<prosody pitch='{tc.pitch}' rate='{tc.rate}' volume='{tc.volume}'>" f"{escaped_text}" "</prosody>" "</voice>" "</speak>")
|
61 |
-
|
62 |
-
def connect_id():
|
63 |
-
return str(uuid.uuid4()).replace("-", "")
|
64 |
-
|
65 |
-
def ssml_headers_plus_data(request_id, timestamp, ssml):
|
66 |
-
return (f"X-RequestId:{request_id}\r\n" "Content-Type:application/ssml+xml\r\n" f"X-Timestamp:{timestamp}Z\r\n" "Path:ssml\r\n\r\n" f"{ssml}")
|
67 |
-
|
68 |
-
def remove_incompatible_characters(string):
|
69 |
-
if isinstance(string, bytes): string = string.decode("utf-8")
|
70 |
-
chars = list(string)
|
71 |
-
|
72 |
-
for idx, char in enumerate(chars):
|
73 |
-
code = ord(char)
|
74 |
-
if (0 <= code <= 8) or (11 <= code <= 12) or (14 <= code <= 31): chars[idx] = " "
|
75 |
-
|
76 |
-
return "".join(chars)
|
77 |
-
|
78 |
-
def split_text_by_byte_length(text, byte_length):
|
79 |
-
if isinstance(text, str): text = text.encode("utf-8")
|
80 |
-
if byte_length <= 0: raise ValueError("byte_length > 0")
|
81 |
-
|
82 |
-
while len(text) > byte_length:
|
83 |
-
split_at = text.rfind(b" ", 0, byte_length)
|
84 |
-
split_at = split_at if split_at != -1 else byte_length
|
85 |
-
|
86 |
-
while b"&" in text[:split_at]:
|
87 |
-
ampersand_index = text.rindex(b"&", 0, split_at)
|
88 |
-
if text.find(b";", ampersand_index, split_at) != -1: break
|
89 |
-
|
90 |
-
split_at = ampersand_index - 1
|
91 |
-
if split_at == 0: break
|
92 |
-
|
93 |
-
new_text = text[:split_at].strip()
|
94 |
-
|
        if new_text: yield new_text
        if split_at == 0: split_at = 1

        text = text[split_at:]

    new_text = text.strip()
    if new_text: yield new_text

class Communicate:
    def __init__(self, text, voice, *, rate="+0%", volume="+0%", pitch="+0Hz", proxy=None, connect_timeout=10, receive_timeout=60):
        self.tts_config = TTSConfig(voice, rate, volume, pitch)
        self.texts = split_text_by_byte_length(escape(remove_incompatible_characters(text)), 2**16 - (len(ssml_headers_plus_data(connect_id(), date_to_string(), mkssml(self.tts_config, ""))) + 50))
        self.proxy = proxy
        self.session_timeout = aiohttp.ClientTimeout(total=None, connect=None, sock_connect=connect_timeout, sock_read=receive_timeout)
        self.state = {"partial_text": None, "offset_compensation": 0, "last_duration_offset": 0, "stream_was_called": False}

    def __parse_metadata(self, data):
        for meta_obj in json.loads(data)["Metadata"]:
            meta_type = meta_obj["Type"]
            if meta_type == "WordBoundary": return {"type": meta_type, "offset": (meta_obj["Data"]["Offset"] + self.state["offset_compensation"]), "duration": meta_obj["Data"]["Duration"], "text": meta_obj["Data"]["text"]["Text"]}
            if meta_type in ("SessionEnd",): continue

    async def __stream(self):
        async def send_command_request():
            await websocket.send_str(f"X-Timestamp:{date_to_string()}\r\n" "Content-Type:application/json; charset=utf-8\r\n" "Path:speech.config\r\n\r\n" '{"context":{"synthesis":{"audio":{"metadataoptions":{' '"sentenceBoundaryEnabled":false,"wordBoundaryEnabled":true},' '"outputFormat":"audio-24khz-48kbitrate-mono-mp3"' "}}}}\r\n")

        async def send_ssml_request():
            await websocket.send_str(ssml_headers_plus_data(connect_id(), date_to_string(), mkssml(self.tts_config, self.state["partial_text"])))

        audio_was_received = False
        ssl_ctx = ssl.create_default_context(cafile=certifi.where())

        async with aiohttp.ClientSession(trust_env=True, timeout=self.session_timeout) as session, session.ws_connect(f"wss://speech.platform.bing.com/consumer/speech/synthesize/readaloud/edge/v1?TrustedClientToken=6A5AA1D4EAFF4E9FB37E23D68491D6F4&ConnectionId={connect_id()}", compress=15, proxy=self.proxy, headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36" " (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36" " Edg/130.0.0.0", "Accept-Encoding": "gzip, deflate, br", "Accept-Language": "en-US,en;q=0.9", "Pragma": "no-cache", "Cache-Control": "no-cache", "Origin": "chrome-extension://jdiccldimpdaibmpdkjnbmckianbfold"}, ssl=ssl_ctx) as websocket:
            await send_command_request()
            await send_ssml_request()

            async for received in websocket:
                if received.type == aiohttp.WSMsgType.TEXT:
                    encoded_data: bytes = received.data.encode("utf-8")
                    parameters, data = get_headers_and_data(encoded_data, encoded_data.find(b"\r\n\r\n"))
                    path = parameters.get(b"Path", None)

                    if path == b"audio.metadata":
                        parsed_metadata = self.__parse_metadata(data)
                        yield parsed_metadata
                        self.state["last_duration_offset"] = (parsed_metadata["offset"] + parsed_metadata["duration"])
                    elif path == b"turn.end":
                        self.state["offset_compensation"] = self.state["last_duration_offset"]
                        self.state["offset_compensation"] += 8_750_000
                        break
                elif received.type == aiohttp.WSMsgType.BINARY:
                    if len(received.data) < 2: raise Exception("received.data < 2")

                    header_length = int.from_bytes(received.data[:2], "big")
                    if header_length > len(received.data): raise Exception("header_length > received.data")

                    parameters, data = get_headers_and_data(received.data, header_length)
                    if parameters.get(b"Path") != b"audio": raise Exception("Path != audio")

                    content_type = parameters.get(b"Content-Type", None)
                    if content_type not in [b"audio/mpeg", None]: raise Exception("content_type != audio/mpeg")

                    if content_type is None and len(data) == 0: continue

                    if len(data) == 0: raise Exception("data = 0")
                    audio_was_received = True
                    yield {"type": "audio", "data": data}

            if not audio_was_received: raise Exception("!audio_was_received")

    async def stream(self):
        if self.state["stream_was_called"]: raise RuntimeError("stream_was_called")
        self.state["stream_was_called"] = True

        for self.state["partial_text"] in self.texts:
            async for message in self.__stream():
                yield message

    async def save(self, audio_fname, metadata_fname = None):
        metadata = (open(metadata_fname, "w", encoding="utf-8") if metadata_fname is not None else nullcontext())
        with metadata, open(audio_fname, "wb") as audio:
            async for message in self.stream():
                if message["type"] == "audio": audio.write(message["data"])
                elif (isinstance(metadata, TextIOWrapper) and message["type"] == "WordBoundary"):
                    json.dump(message, metadata)
                    metadata.write("\n")
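For reference, the deleted Communicate class above is an async, generator-based client; the following is a minimal usage sketch, not part of the original file, assuming the module were still importable from main/tools/edge_tts and that the voice string follows Edge TTS's usual "xx-XX-NameNeural" naming:

import asyncio
from main.tools.edge_tts import Communicate

async def main():
    # save() consumes stream(): audio chunks go to the MP3 file, and
    # WordBoundary metadata (when a second path is given) to a JSONL file.
    await Communicate("Hello world", "en-US-AriaNeural").save("output.mp3", "metadata.json")

asyncio.run(main())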
main/tools/gdown.py
DELETED
@@ -1,110 +0,0 @@
import os
import re
import sys
import json
import tqdm
import codecs
import tempfile
import requests

from urllib.parse import urlparse, parse_qs, unquote

sys.path.append(os.getcwd())

from main.configs.config import Config
translations = Config().translations

def parse_url(url):
    parsed = urlparse(url)
    is_download_link = parsed.path.endswith("/uc")
    if not parsed.hostname in ("drive.google.com", "docs.google.com"): return None, is_download_link
    file_id = parse_qs(parsed.query).get("id", [None])[0]

    if file_id is None:
        for pattern in (r"^/file/d/(.*?)/(edit|view)$", r"^/file/u/[0-9]+/d/(.*?)/(edit|view)$", r"^/document/d/(.*?)/(edit|htmlview|view)$", r"^/document/u/[0-9]+/d/(.*?)/(edit|htmlview|view)$", r"^/presentation/d/(.*?)/(edit|htmlview|view)$", r"^/presentation/u/[0-9]+/d/(.*?)/(edit|htmlview|view)$", r"^/spreadsheets/d/(.*?)/(edit|htmlview|view)$", r"^/spreadsheets/u/[0-9]+/d/(.*?)/(edit|htmlview|view)$"):
            match = re.match(pattern, parsed.path)
            if match:
                file_id = match.group(1)
                break
    return file_id, is_download_link

def get_url_from_gdrive_confirmation(contents):
    for pattern in (r'href="(\/uc\?export=download[^"]+)', r'href="/open\?id=([^"]+)"', r'"downloadUrl":"([^"]+)'):
        match = re.search(pattern, contents)
        if match:
            url = match.group(1)
            if pattern == r'href="/open\?id=([^"]+)"': url = (codecs.decode("uggcf://qevir.hfrepbagrag.tbbtyr.pbz/qbjaybnq?vq=", "rot13") + url + "&confirm=t&uuid=" + re.search(r'<input\s+type="hidden"\s+name="uuid"\s+value="([^"]+)"', contents).group(1))
            elif pattern == r'"downloadUrl":"([^"]+)': url = url.replace("\\u003d", "=").replace("\\u0026", "&")
            else: url = codecs.decode("uggcf://qbpf.tbbtyr.pbz", "rot13") + url.replace("&amp;", "&")
            return url

    match = re.search(r'<p class="uc-error-subcaption">(.*)</p>', contents)
    if match: raise Exception(match.group(1))
    raise Exception(translations["gdown_error"])

def _get_session(use_cookies, return_cookies_file=False):
    sess = requests.session()
    sess.headers.update({"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6)"})
    cookies_file = os.path.join(os.path.expanduser("~"), ".cache/gdown/cookies.json")

    if os.path.exists(cookies_file) and use_cookies:
        with open(cookies_file) as f:
            for k, v in json.load(f):
                sess.cookies[k] = v
    return (sess, cookies_file) if return_cookies_file else sess

def gdown_download(url=None, id=None, output=None):
    if not (id is None) ^ (url is None): raise ValueError(translations["gdown_value_error"])
    if id is not None: url = f"{codecs.decode('uggcf://qevir.tbbtyr.pbz/hp?vq=', 'rot13')}{id}"

    url_origin = url
    sess, cookies_file = _get_session(use_cookies=True, return_cookies_file=True)
    gdrive_file_id, is_gdrive_download_link = parse_url(url)

    if gdrive_file_id:
        url = f"{codecs.decode('uggcf://qevir.tbbtyr.pbz/hp?vq=', 'rot13')}{gdrive_file_id}"
        url_origin = url
        is_gdrive_download_link = True

    while 1:
        res = sess.get(url, stream=True, verify=True)
        if url == url_origin and res.status_code == 500:
            url = f"{codecs.decode('uggcf://qevir.tbbtyr.pbz/bcra?vq=', 'rot13')}{gdrive_file_id}"
            continue

        os.makedirs(os.path.dirname(cookies_file), exist_ok=True)
        with open(cookies_file, "w") as f:
            json.dump([(k, v) for k, v in sess.cookies.items() if not k.startswith("download_warning_")], f, indent=2)

        if "Content-Disposition" in res.headers: break
        if not (gdrive_file_id and is_gdrive_download_link): break

        try:
            url = get_url_from_gdrive_confirmation(res.text)
        except Exception as e:
            raise Exception(e)

    if gdrive_file_id and is_gdrive_download_link:
        content_disposition = unquote(res.headers["Content-Disposition"])
        filename_from_url = (re.search(r"filename\*=UTF-8''(.*)", content_disposition) or re.search(r'filename=["\']?(.*?)["\']?$', content_disposition)).group(1).replace(os.path.sep, "_")
    else: filename_from_url = os.path.basename(url)

    output = os.path.join(output or ".", filename_from_url)
    tmp_file = tempfile.mktemp(suffix=tempfile.template, prefix=os.path.basename(output), dir=os.path.dirname(output))
    f = open(tmp_file, "ab")

    if tmp_file is not None and f.tell() != 0: res = sess.get(url, headers={"Range": f"bytes={f.tell()}-"}, stream=True, verify=True)
    print(translations["to"], os.path.abspath(output), file=sys.stderr)

    try:
        with tqdm.tqdm(total=int(res.headers.get("Content-Length", 0)), ncols=100, unit="byte") as pbar:
            for chunk in res.iter_content(chunk_size=512 * 1024):
                f.write(chunk)
                pbar.update(len(chunk))

        pbar.close()
        if tmp_file: f.close()
    finally:
        os.rename(tmp_file, output)
        sess.close()
    return output
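A hypothetical invocation of the deleted downloader (FILE_ID is a placeholder, not a real Drive id); gdown_download accepts either a full Drive URL or a bare id, but not both:

from main.tools.gdown import gdown_download

# Resolves Drive's confirmation page when needed, streams into a temp file,
# then renames it into place and returns the final path.
path = gdown_download(id="FILE_ID", output="assets/models")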
main/tools/google_tts.py
DELETED
@@ -1,30 +0,0 @@
import os
import codecs
import librosa
import requests

import soundfile as sf

def google_tts(text, lang="vi", speed=1, pitch=0, output_file="output.mp3"):
    try:
        response = requests.get(codecs.decode("uggcf://genafyngr.tbbtyr.pbz/genafyngr_ggf", "rot13"), params={"ie": "UTF-8", "q": text, "tl": lang, "ttsspeed": speed, "client": "tw-ob"}, headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36"})

        if response.status_code == 200:
            with open(output_file, "wb") as f:
                f.write(response.content)

            format = os.path.splitext(os.path.basename(output_file))[-1].lower().replace('.', '')

            if pitch != 0: pitch_shift(input_file=output_file, output_file=output_file, pitch=pitch, export_format=format)
            if speed != 1: change_speed(input_file=output_file, output_file=output_file, speed=speed, export_format=format)
        else: raise ValueError(f"{response.status_code}, {response.text}")
    except Exception as e:
        raise RuntimeError(e)

def pitch_shift(input_file, output_file, pitch, export_format):
    y, sr = librosa.load(input_file, sr=None)
    sf.write(file=output_file, data=librosa.effects.pitch_shift(y, sr=sr, n_steps=pitch), samplerate=sr, format=export_format)

def change_speed(input_file, output_file, speed, export_format):
    y, sr = librosa.load(input_file, sr=None)
    sf.write(file=output_file, data=librosa.effects.time_stretch(y, rate=speed), samplerate=sr, format=export_format)
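A usage sketch for the deleted wrapper, assuming it were still present; pitch_shift and change_speed only run for non-default values, and the export format is inferred from the output extension:

from main.tools.google_tts import google_tts

# Vietnamese TTS ("xin chào" = "hello"), shifted up one semitone; writes hello.mp3 in place.
google_tts("xin chào", lang="vi", speed=1, pitch=1, output_file="hello.mp3")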
main/tools/huggingface.py
DELETED
@@ -1,24 +0,0 @@
import os
import requests
import tqdm


def HF_download_file(url, output_path=None):
    url = url.replace("/blob/", "/resolve/").replace("?download=true", "").strip()

    if output_path is None: output_path = os.path.basename(url)
    else: output_path = os.path.join(output_path, os.path.basename(url)) if os.path.isdir(output_path) else output_path

    response = requests.get(url, stream=True, timeout=300)

    if response.status_code == 200:
        progress_bar = tqdm.tqdm(total=int(response.headers.get("content-length", 0)), ncols=100, unit="byte")

        with open(output_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                progress_bar.update(len(chunk))
                f.write(chunk)

        progress_bar.close()
        return output_path
    else: raise ValueError(response.status_code)
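A usage sketch (the URL is illustrative, not a real repository): blob-style links are rewritten to resolve links before streaming, and the local path is returned:

from main.tools.huggingface import HF_download_file

# Saves into assets/models/ when given a directory, else to the exact path.
HF_download_file("https://huggingface.co/user/repo/blob/main/model.pth", "assets/models")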
main/tools/mediafire.py
DELETED
@@ -1,30 +0,0 @@
import os
import sys
import requests
from bs4 import BeautifulSoup


def Mediafire_Download(url, output=None, filename=None):
    if not filename: filename = url.split('/')[-2]
    if not output: output = os.path.dirname(os.path.realpath(__file__))
    output_file = os.path.join(output, filename)

    sess = requests.session()
    sess.headers.update({"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6)"})

    try:
        with requests.get(BeautifulSoup(sess.get(url).content, "html.parser").find(id="downloadButton").get("href"), stream=True) as r:
            r.raise_for_status()
            with open(output_file, "wb") as f:
                total_length = int(r.headers.get('content-length'))
                download_progress = 0

                for chunk in r.iter_content(chunk_size=1024):
                    download_progress += len(chunk)
                    f.write(chunk)
                    sys.stdout.write(f"\r[{filename}]: {int(100 * download_progress/total_length)}% ({round(download_progress/1024/1024, 2)}mb/{round(total_length/1024/1024, 2)}mb)")
                    sys.stdout.flush()
                sys.stdout.write("\n")
        return output_file
    except Exception as e:
        raise RuntimeError(e)
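A hypothetical call to the deleted helper (the share link is a placeholder); when no filename is passed it takes the second-to-last URL segment, and progress is written to stdout:

from main.tools.mediafire import Mediafire_Download

# Returns the path of the saved file.
Mediafire_Download("https://www.mediafire.com/file/abc123/model.zip/file", output="downloads")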
main/tools/meganz.py
DELETED
@@ -1,160 +0,0 @@
import os
import re
import sys
import json
import tqdm
import time
import codecs
import random
import base64
import struct
import shutil
import requests
import tempfile

from Crypto.Cipher import AES
from Crypto.Util import Counter

sys.path.append(os.getcwd())

from main.configs.config import Config
translations = Config().translations

def makebyte(x):
    return codecs.latin_1_encode(x)[0]

def a32_to_str(a):
    return struct.pack('>%dI' % len(a), *a)

def get_chunks(size):
    p, s = 0, 0x20000

    while p + s < size:
        yield (p, s)
        p += s

        if s < 0x100000: s += 0x20000

    yield (p, size - p)

def decrypt_attr(attr, key):
    attr = codecs.latin_1_decode(AES.new(a32_to_str(key), AES.MODE_CBC, makebyte('\0' * 16)).decrypt(attr))[0].rstrip('\0')

    return json.loads(attr[4:]) if attr[:6] == 'MEGA{"' else False

def _api_request(data):
    sequence_num = random.randint(0, 0xFFFFFFFF)
    params = {'id': sequence_num}
    sequence_num += 1

    if not isinstance(data, list): data = [data]

    for attempt in range(60):
        try:
            json_resp = json.loads(requests.post(f'https://g.api.mega.co.nz/cs', params=params, data=json.dumps(data), timeout=160).text)

            try:
                if isinstance(json_resp, list): int_resp = json_resp[0] if isinstance(json_resp[0], int) else None
                elif isinstance(json_resp, int): int_resp = json_resp
            except IndexError:
                int_resp = None

            if int_resp is not None:
                if int_resp == 0: return int_resp
                if int_resp == -3: raise RuntimeError('int_resp==-3')
                raise Exception(int_resp)

            return json_resp[0]
        except (RuntimeError, requests.exceptions.RequestException):
            if attempt == 60 - 1: raise
            delay = 2 * (2 ** attempt)
            time.sleep(delay)

def base64_url_decode(data):
    data += '=='[(2 - len(data) * 3) % 4:]

    for search, replace in (('-', '+'), ('_', '/'), (',', '')):
        data = data.replace(search, replace)

    return base64.b64decode(data)

def str_to_a32(b):
    if isinstance(b, str): b = makebyte(b)
    if len(b) % 4: b += b'\0' * (4 - len(b) % 4)

    return struct.unpack('>%dI' % (len(b) / 4), b)

def mega_download_file(file_handle, file_key, dest_path=None, dest_filename=None, file=None):
    if file is None:
        file_key = str_to_a32(base64_url_decode(file_key))
        file_data = _api_request({'a': 'g', 'g': 1, 'p': file_handle})

        k = (file_key[0] ^ file_key[4], file_key[1] ^ file_key[5], file_key[2] ^ file_key[6], file_key[3] ^ file_key[7])
        iv = file_key[4:6] + (0, 0)
        meta_mac = file_key[6:8]
    else:
        file_data = _api_request({'a': 'g', 'g': 1, 'n': file['h']})
        k = file['k']
        iv = file['iv']
        meta_mac = file['meta_mac']

    if 'g' not in file_data: raise Exception(translations["file_not_access"])
    file_size = file_data['s']

    attribs = decrypt_attr(base64_url_decode(file_data['at']), k)

    file_name = dest_filename if dest_filename is not None else attribs['n']
    input_file = requests.get(file_data['g'], stream=True).raw

    if dest_path is None: dest_path = ''
    else: dest_path += '/'

    temp_output_file = tempfile.NamedTemporaryFile(mode='w+b', prefix='megapy_', delete=False)
    k_str = a32_to_str(k)

    aes = AES.new(k_str, AES.MODE_CTR, counter=Counter.new(128, initial_value=((iv[0] << 32) + iv[1]) << 64))
    mac_str = b'\0' * 16
    mac_encryptor = AES.new(k_str, AES.MODE_CBC, mac_str)

    iv_str = a32_to_str([iv[0], iv[1], iv[0], iv[1]])
    pbar = tqdm.tqdm(total=file_size, ncols=100, unit="byte")

    for _, chunk_size in get_chunks(file_size):
        chunk = aes.decrypt(input_file.read(chunk_size))
        temp_output_file.write(chunk)

        pbar.update(len(chunk))
        encryptor = AES.new(k_str, AES.MODE_CBC, iv_str)

        for i in range(0, len(chunk)-16, 16):
            block = chunk[i:i + 16]
            encryptor.encrypt(block)

        if file_size > 16: i += 16
        else: i = 0

        block = chunk[i:i + 16]
        if len(block) % 16: block += b'\0' * (16 - (len(block) % 16))

        mac_str = mac_encryptor.encrypt(encryptor.encrypt(block))

    file_mac = str_to_a32(mac_str)
    temp_output_file.close()

    if (file_mac[0] ^ file_mac[1], file_mac[2] ^ file_mac[3]) != meta_mac: raise ValueError(translations["mac_not_match"])

    file_path = os.path.join(dest_path, file_name)
    if os.path.exists(file_path): os.remove(file_path)

    shutil.move(temp_output_file.name, file_path)

def mega_download_url(url, dest_path=None, dest_filename=None):
    if '/file/' in url:
        url = url.replace(' ', '')
        file_id = re.findall(r'\W\w\w\w\w\w\w\w\w\W', url)[0][1:-1]

        path = f'{file_id}!{url[re.search(file_id, url).end() + 1:]}'.split('!')
    elif '!' in url: path = re.findall(r'/#!(.*)', url)[0].split('!')
    else: raise Exception(translations["missing_url"])

    return mega_download_file(file_handle=path[0], file_key=path[1], dest_path=dest_path, dest_filename=dest_filename)
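A hypothetical call to the URL-level entry point above (HANDLE#KEY stands in for a real MEGA file link); the handle/key pair is parsed from the link, the stream is AES-CTR decrypted, and the chunk MAC is verified before the file is moved into place:

from main.tools.meganz import mega_download_url

mega_download_url("https://mega.nz/file/HANDLE#KEY", dest_path="downloads")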
main/tools/noisereduce.py
DELETED
@@ -1,200 +0,0 @@
import torch
import tempfile
import numpy as np

from tqdm.auto import tqdm
from joblib import Parallel, delayed
from torch.nn.functional import conv1d, conv2d

from main.configs.config import Config
translations = Config().translations

@torch.no_grad()
def amp_to_db(x, eps = torch.finfo(torch.float32).eps, top_db = 40):
    x_db = 20 * torch.log10(x.abs() + eps)
    return torch.max(x_db, (x_db.max(-1).values - top_db).unsqueeze(-1))

@torch.no_grad()
def temperature_sigmoid(x, x0, temp_coeff):
    return torch.sigmoid((x - x0) / temp_coeff)

@torch.no_grad()
def linspace(start, stop, num = 50, endpoint = True, **kwargs):
    return torch.linspace(start, stop, num, **kwargs) if endpoint else torch.linspace(start, stop, num + 1, **kwargs)[:-1]

def _smoothing_filter(n_grad_freq, n_grad_time):
    smoothing_filter = np.outer(np.concatenate([np.linspace(0, 1, n_grad_freq + 1, endpoint=False), np.linspace(1, 0, n_grad_freq + 2)])[1:-1], np.concatenate([np.linspace(0, 1, n_grad_time + 1, endpoint=False), np.linspace(1, 0, n_grad_time + 2)])[1:-1])
    return smoothing_filter / np.sum(smoothing_filter)

class SpectralGate:
    def __init__(self, y, sr, prop_decrease, chunk_size, padding, n_fft, win_length, hop_length, time_constant_s, freq_mask_smooth_hz, time_mask_smooth_ms, tmp_folder, use_tqdm, n_jobs):
        self.sr = sr
        self.flat = False
        y = np.array(y)

        if len(y.shape) == 1:
            self.y = np.expand_dims(y, 0)
            self.flat = True
        elif len(y.shape) > 2: raise ValueError(translations["waveform"])
        else: self.y = y

        self._dtype = y.dtype
        self.n_channels, self.n_frames = self.y.shape
        self._chunk_size = chunk_size
        self.padding = padding
        self.n_jobs = n_jobs
        self.use_tqdm = use_tqdm
        self._tmp_folder = tmp_folder
        self._n_fft = n_fft
        self._win_length = self._n_fft if win_length is None else win_length
        self._hop_length = (self._win_length // 4) if hop_length is None else hop_length
        self._time_constant_s = time_constant_s
        self._prop_decrease = prop_decrease

        if (freq_mask_smooth_hz is None) & (time_mask_smooth_ms is None): self.smooth_mask = False
        else: self._generate_mask_smoothing_filter(freq_mask_smooth_hz, time_mask_smooth_ms)

    def _generate_mask_smoothing_filter(self, freq_mask_smooth_hz, time_mask_smooth_ms):
        if freq_mask_smooth_hz is None: n_grad_freq = 1
        else:
            n_grad_freq = int(freq_mask_smooth_hz / (self.sr / (self._n_fft / 2)))
            if n_grad_freq < 1: raise ValueError(translations["freq_mask_smooth_hz"].format(hz=int((self.sr / (self._n_fft / 2)))))

        if time_mask_smooth_ms is None: n_grad_time = 1
        else:
            n_grad_time = int(time_mask_smooth_ms / ((self._hop_length / self.sr) * 1000))
            if n_grad_time < 1: raise ValueError(translations["time_mask_smooth_ms"].format(ms=int((self._hop_length / self.sr) * 1000)))

        if (n_grad_time == 1) & (n_grad_freq == 1): self.smooth_mask = False
        else:
            self.smooth_mask = True
            self._smoothing_filter = _smoothing_filter(n_grad_freq, n_grad_time)

    def _read_chunk(self, i1, i2):
        i1b = 0 if i1 < 0 else i1
        i2b = self.n_frames if i2 > self.n_frames else i2
        chunk = np.zeros((self.n_channels, i2 - i1))
        chunk[:, i1b - i1: i2b - i1] = self.y[:, i1b:i2b]
        return chunk

    def filter_chunk(self, start_frame, end_frame):
        i1 = start_frame - self.padding
        return self._do_filter(self._read_chunk(i1, (end_frame + self.padding)))[:, start_frame - i1: end_frame - i1]

    def _get_filtered_chunk(self, ind):
        start0 = ind * self._chunk_size
        end0 = (ind + 1) * self._chunk_size
        return self.filter_chunk(start_frame=start0, end_frame=end0)

    def _do_filter(self, chunk):
        pass

    def _iterate_chunk(self, filtered_chunk, pos, end0, start0, ich):
        filtered_chunk[:, pos: pos + end0 - start0] = self._get_filtered_chunk(ich)[:, start0:end0]
        pos += end0 - start0

    def get_traces(self, start_frame=None, end_frame=None):
        if start_frame is None: start_frame = 0
        if end_frame is None: end_frame = self.n_frames

        if self._chunk_size is not None:
            if end_frame - start_frame > self._chunk_size:
                ich1 = int(start_frame / self._chunk_size)
                ich2 = int((end_frame - 1) / self._chunk_size)

                with tempfile.NamedTemporaryFile(prefix=self._tmp_folder) as fp:
                    filtered_chunk = np.memmap(fp, dtype=self._dtype, shape=(self.n_channels, int(end_frame - start_frame)), mode="w+")
                    pos_list, start_list, end_list = [], [], []
                    pos = 0

                    for ich in range(ich1, ich2 + 1):
                        start0 = (start_frame - ich * self._chunk_size) if ich == ich1 else 0
                        end0 = end_frame - ich * self._chunk_size if ich == ich2 else self._chunk_size
                        pos_list.append(pos)
                        start_list.append(start0)
                        end_list.append(end0)
                        pos += end0 - start0

                    Parallel(n_jobs=self.n_jobs)(delayed(self._iterate_chunk)(filtered_chunk, pos, end0, start0, ich) for pos, start0, end0, ich in zip(tqdm(pos_list, disable=not (self.use_tqdm)), start_list, end_list, range(ich1, ich2 + 1)))
                    return filtered_chunk.astype(self._dtype).flatten() if self.flat else filtered_chunk.astype(self._dtype)

        filtered_chunk = self.filter_chunk(start_frame=0, end_frame=end_frame)
        return filtered_chunk.astype(self._dtype).flatten() if self.flat else filtered_chunk.astype(self._dtype)

class TG(torch.nn.Module):
    @torch.no_grad()
    def __init__(self, sr, nonstationary = False, n_std_thresh_stationary = 1.5, n_thresh_nonstationary = 1.3, temp_coeff_nonstationary = 0.1, n_movemean_nonstationary = 20, prop_decrease = 1.0, n_fft = 1024, win_length = None, hop_length = None, freq_mask_smooth_hz = 500, time_mask_smooth_ms = 50):
        super().__init__()
        self.sr = sr
        self.nonstationary = nonstationary
        assert 0.0 <= prop_decrease <= 1.0
        self.prop_decrease = prop_decrease
        self.n_fft = n_fft
        self.win_length = self.n_fft if win_length is None else win_length
        self.hop_length = self.win_length // 4 if hop_length is None else hop_length
        self.n_std_thresh_stationary = n_std_thresh_stationary
        self.temp_coeff_nonstationary = temp_coeff_nonstationary
        self.n_movemean_nonstationary = n_movemean_nonstationary
        self.n_thresh_nonstationary = n_thresh_nonstationary
        self.freq_mask_smooth_hz = freq_mask_smooth_hz
        self.time_mask_smooth_ms = time_mask_smooth_ms
        self.register_buffer("smoothing_filter", self._generate_mask_smoothing_filter())

    @torch.no_grad()
    def _generate_mask_smoothing_filter(self):
        if self.freq_mask_smooth_hz is None and self.time_mask_smooth_ms is None: return None
        n_grad_freq = (1 if self.freq_mask_smooth_hz is None else int(self.freq_mask_smooth_hz / (self.sr / (self.n_fft / 2))))
        if n_grad_freq < 1: raise ValueError(translations["freq_mask_smooth_hz"].format(hz=int((self.sr / (self.n_fft / 2)))))

        n_grad_time = (1 if self.time_mask_smooth_ms is None else int(self.time_mask_smooth_ms / ((self.hop_length / self.sr) * 1000)))
        if n_grad_time < 1: raise ValueError(translations["time_mask_smooth_ms"].format(ms=int((self.hop_length / self.sr) * 1000)))
        if n_grad_time == 1 and n_grad_freq == 1: return None

        smoothing_filter = torch.outer(torch.cat([linspace(0, 1, n_grad_freq + 1, endpoint=False), linspace(1, 0, n_grad_freq + 2)])[1:-1], torch.cat([linspace(0, 1, n_grad_time + 1, endpoint=False), linspace(1, 0, n_grad_time + 2)])[1:-1]).unsqueeze(0).unsqueeze(0)
        return smoothing_filter / smoothing_filter.sum()

    @torch.no_grad()
    def _stationary_mask(self, X_db, xn = None):
        XN_db = amp_to_db(torch.stft(xn, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, return_complex=True, pad_mode="constant", center=True, window=torch.hann_window(self.win_length).to(xn.device))).to(dtype=X_db.dtype) if xn is not None else X_db
        std_freq_noise, mean_freq_noise = torch.std_mean(XN_db, dim=-1)
        return torch.gt(X_db, (mean_freq_noise + std_freq_noise * self.n_std_thresh_stationary).unsqueeze(2))

    @torch.no_grad()
    def _nonstationary_mask(self, X_abs):
        X_smoothed = (conv1d(X_abs.reshape(-1, 1, X_abs.shape[-1]), torch.ones(self.n_movemean_nonstationary, dtype=X_abs.dtype, device=X_abs.device).view(1, 1, -1), padding="same").view(X_abs.shape) / self.n_movemean_nonstationary)
        return temperature_sigmoid(((X_abs - X_smoothed) / X_smoothed), self.n_thresh_nonstationary, self.temp_coeff_nonstationary)

    def forward(self, x, xn = None):
        assert x.ndim == 2
        if x.shape[-1] < self.win_length * 2: raise Exception(f"{translations['x']} {self.win_length * 2}")
        assert xn is None or xn.ndim == 1 or xn.ndim == 2
        if xn is not None and xn.shape[-1] < self.win_length * 2: raise Exception(f"{translations['xn']} {self.win_length * 2}")

        X = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, return_complex=True, pad_mode="constant", center=True, window=torch.hann_window(self.win_length).to(x.device))
        sig_mask = self._nonstationary_mask(X.abs()) if self.nonstationary else self._stationary_mask(amp_to_db(X), xn)

        sig_mask = self.prop_decrease * (sig_mask * 1.0 - 1.0) + 1.0
        if self.smoothing_filter is not None: sig_mask = conv2d(sig_mask.unsqueeze(1), self.smoothing_filter.to(sig_mask.dtype), padding="same")

        Y = X * sig_mask.squeeze(1)
        return torch.istft(Y, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, center=True, window=torch.hann_window(self.win_length).to(Y.device)).to(dtype=x.dtype)

class StreamedTorchGate(SpectralGate):
    def __init__(self, y, sr, stationary=False, y_noise=None, prop_decrease=1.0, time_constant_s=2.0, freq_mask_smooth_hz=500, time_mask_smooth_ms=50, thresh_n_mult_nonstationary=2, sigmoid_slope_nonstationary=10, n_std_thresh_stationary=1.5, tmp_folder=None, chunk_size=600000, padding=30000, n_fft=1024, win_length=None, hop_length=None, clip_noise_stationary=True, use_tqdm=False, n_jobs=1, device="cpu"):
        super().__init__(y=y, sr=sr, chunk_size=chunk_size, padding=padding, n_fft=n_fft, win_length=win_length, hop_length=hop_length, time_constant_s=time_constant_s, freq_mask_smooth_hz=freq_mask_smooth_hz, time_mask_smooth_ms=time_mask_smooth_ms, tmp_folder=tmp_folder, prop_decrease=prop_decrease, use_tqdm=use_tqdm, n_jobs=n_jobs)
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        if y_noise is not None:
            if y_noise.shape[-1] > y.shape[-1] and clip_noise_stationary: y_noise = y_noise[: y.shape[-1]]
            y_noise = torch.from_numpy(y_noise).to(device)
            if len(y_noise.shape) == 1: y_noise = y_noise.unsqueeze(0)

        self.y_noise = y_noise
        self.tg = TG(sr=sr, nonstationary=not stationary, n_std_thresh_stationary=n_std_thresh_stationary, n_thresh_nonstationary=thresh_n_mult_nonstationary, temp_coeff_nonstationary=1 / sigmoid_slope_nonstationary, n_movemean_nonstationary=int(time_constant_s / self._hop_length * sr), prop_decrease=prop_decrease, n_fft=self._n_fft, win_length=self._win_length, hop_length=self._hop_length, freq_mask_smooth_hz=freq_mask_smooth_hz, time_mask_smooth_ms=time_mask_smooth_ms).to(device)

    def _do_filter(self, chunk):
        if type(chunk) is np.ndarray: chunk = torch.from_numpy(chunk).to(self.device)
        return self.tg(x=chunk, xn=self.y_noise).cpu().detach().numpy()

def reduce_noise(y, sr, stationary=False, y_noise=None, prop_decrease=1.0, time_constant_s=2.0, freq_mask_smooth_hz=500, time_mask_smooth_ms=50, thresh_n_mult_nonstationary=2, sigmoid_slope_nonstationary=10, tmp_folder=None, chunk_size=600000, padding=30000, n_fft=1024, win_length=None, hop_length=None, clip_noise_stationary=True, use_tqdm=False, device="cpu"):
    return StreamedTorchGate(y=y, sr=sr, stationary=stationary, y_noise=y_noise, prop_decrease=prop_decrease, time_constant_s=time_constant_s, freq_mask_smooth_hz=freq_mask_smooth_hz, time_mask_smooth_ms=time_mask_smooth_ms, thresh_n_mult_nonstationary=thresh_n_mult_nonstationary, sigmoid_slope_nonstationary=sigmoid_slope_nonstationary, tmp_folder=tmp_folder, chunk_size=chunk_size, padding=padding, n_fft=n_fft, win_length=win_length, hop_length=hop_length, clip_noise_stationary=clip_noise_stationary, use_tqdm=use_tqdm, n_jobs=1, device=device).get_traces()
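A usage sketch for the top-level reduce_noise helper, which wraps StreamedTorchGate; stationary=False (the default) selects the nonstationary mask, and passing y_noise instead supplies a stationary noise profile:

import numpy as np
from main.tools.noisereduce import reduce_noise

sr = 44100
y = np.random.randn(sr * 3).astype(np.float32)  # three seconds of placeholder audio
denoised = reduce_noise(y=y, sr=sr, prop_decrease=1.0, device="cpu")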
main/tools/pixeldrain.py
DELETED
@@ -1,16 +0,0 @@
import os
import requests

def pixeldrain(url, output_dir):
    try:
        response = requests.get(f"https://pixeldrain.com/api/file/{url.split('pixeldrain.com/u/')[1]}")

        if response.status_code == 200:
            file_path = os.path.join(output_dir, (response.headers.get("Content-Disposition").split("filename=")[-1].strip('";')))

            with open(file_path, "wb") as newfile:
                newfile.write(response.content)
            return file_path
        else: return None
    except Exception as e:
        raise RuntimeError(e)
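A hypothetical call (abc123 stands in for a real file id); the saved filename comes from the Content-Disposition header, and the path is returned, or None on a non-200 response:

from main.tools.pixeldrain import pixeldrain

pixeldrain("https://pixeldrain.com/u/abc123", output_dir="downloads")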