AnhP committed on
Commit cdbcc8b · verified · 1 Parent(s): 162797d

Delete main

Files changed (50)
  1. main/app/app.py +0 -0
  2. main/app/tensorboard.py +0 -30
  3. main/configs/config.json +0 -26
  4. main/configs/config.py +0 -70
  5. main/configs/v1/32000.json +0 -46
  6. main/configs/v1/40000.json +0 -46
  7. main/configs/v1/44100.json +0 -46
  8. main/configs/v1/48000.json +0 -46
  9. main/configs/v2/32000.json +0 -42
  10. main/configs/v2/40000.json +0 -42
  11. main/configs/v2/44100.json +0 -42
  12. main/configs/v2/48000.json +0 -42
  13. main/inference/audio_effects.py +0 -170
  14. main/inference/convert.py +0 -650
  15. main/inference/create_dataset.py +0 -240
  16. main/inference/create_index.py +0 -100
  17. main/inference/extract.py +0 -450
  18. main/inference/preprocess.py +0 -290
  19. main/inference/separator_music.py +0 -290
  20. main/inference/train.py +0 -1000
  21. main/library/algorithm/commons.py +0 -50
  22. main/library/algorithm/modules.py +0 -70
  23. main/library/algorithm/mrf_hifigan.py +0 -160
  24. main/library/algorithm/refinegan.py +0 -180
  25. main/library/algorithm/residuals.py +0 -140
  26. main/library/algorithm/separator.py +0 -330
  27. main/library/algorithm/synthesizers.py +0 -450
  28. main/library/architectures/demucs_separator.py +0 -160
  29. main/library/architectures/mdx_separator.py +0 -320
  30. main/library/predictors/CREPE.py +0 -210
  31. main/library/predictors/FCPE.py +0 -670
  32. main/library/predictors/RMVPE.py +0 -260
  33. main/library/predictors/WORLD.py +0 -90
  34. main/library/utils.py +0 -100
  35. main/library/uvr5_separator/common_separator.py +0 -250
  36. main/library/uvr5_separator/demucs/apply.py +0 -250
  37. main/library/uvr5_separator/demucs/demucs.py +0 -370
  38. main/library/uvr5_separator/demucs/hdemucs.py +0 -760
  39. main/library/uvr5_separator/demucs/htdemucs.py +0 -600
  40. main/library/uvr5_separator/demucs/states.py +0 -55
  41. main/library/uvr5_separator/demucs/utils.py +0 -8
  42. main/library/uvr5_separator/spec_utils.py +0 -900
  43. main/tools/edge_tts.py +0 -180
  44. main/tools/gdown.py +0 -110
  45. main/tools/google_tts.py +0 -30
  46. main/tools/huggingface.py +0 -24
  47. main/tools/mediafire.py +0 -30
  48. main/tools/meganz.py +0 -160
  49. main/tools/noisereduce.py +0 -200
  50. main/tools/pixeldrain.py +0 -16
main/app/app.py DELETED
The diff for this file is too large to render. See raw diff
 
main/app/tensorboard.py DELETED
@@ -1,30 +0,0 @@
- import os
- import sys
- import json
- import logging
- import webbrowser
-
- from tensorboard import program
-
- sys.path.append(os.getcwd())
-
- from main.configs.config import Config
- translations = Config().translations
-
- with open(os.path.join("main", "configs", "config.json"), "r") as f:
-     configs = json.load(f)
-
- def launch_tensorboard():
-     for l in ["root", "tensorboard"]:
-         logging.getLogger(l).setLevel(logging.ERROR)
-
-     tb = program.TensorBoard()
-     tb.configure(argv=[None, "--logdir", "assets/logs", f"--port={configs["tensorboard_port"]}"])
-     url = tb.launch()
-
-     print(f"{translations['tensorboard_url']}: {url}")
-     if "--open" in sys.argv: webbrowser.open(url)
-
-     return f"{translations['tensorboard_url']}: {url}"
-
- if __name__ == "__main__": launch_tensorboard()

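For reference, the deleted launcher drives TensorBoard through its Python API rather than the command line. A minimal standalone sketch of that pattern (log directory and port defaults taken from the deleted script and config):

from tensorboard import program

def launch(logdir="assets/logs", port=6870):
    # Build an in-process TensorBoard server; argv mirrors the CLI flags,
    # with the first slot reserved for the program name.
    tb = program.TensorBoard()
    tb.configure(argv=[None, "--logdir", logdir, f"--port={port}"])
    return tb.launch()  # returns the URL the server is listening on

if __name__ == "__main__":
    print(launch())
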
main/configs/config.json DELETED
@@ -1,26 +0,0 @@
- {
- "language": "vi-VN",
- "support_language": ["en-US", "vi-VN"],
- "theme": "NoCrypt/miku",
- "themes": ["NoCrypt/miku", "gstaff/xkcd", "JohnSmith9982/small_and_pretty", "ParityError/Interstellar", "earneleh/paris", "shivi/calm_seafoam", "Hev832/Applio", "YTheme/Minecraft", "gstaff/sketch", "SebastianBravo/simci_css", "allenai/gradio-theme", "Nymbo/Nymbo_Theme_5", "lone17/kotaemon", "Zarkel/IBM_Carbon_Theme", "SherlockRamos/Feliz", "freddyaboulton/dracula_revamped", "freddyaboulton/bad-theme-space", "gradio/dracula_revamped", "abidlabs/dracula_revamped", "gradio/dracula_test", "gradio/seafoam", "gradio/glass", "gradio/monochrome", "gradio/soft", "gradio/default", "gradio/base", "abidlabs/pakistan", "dawood/microsoft_windows", "ysharma/steampunk", "ysharma/huggingface", "abidlabs/Lime", "freddyaboulton/this-theme-does-not-exist-2", "aliabid94/new-theme", "aliabid94/test2", "aliabid94/test3", "aliabid94/test4", "abidlabs/banana", "freddyaboulton/test-blue", "gstaff/whiteboard", "ysharma/llamas", "abidlabs/font-test", "YenLai/Superhuman", "bethecloud/storj_theme", "sudeepshouche/minimalist", "knotdgaf/gradiotest", "ParityError/Anime", "Ajaxon6255/Emerald_Isle", "ParityError/LimeFace", "finlaymacklon/smooth_slate", "finlaymacklon/boxy_violet", "derekzen/stardust", "EveryPizza/Cartoony-Gradio-Theme", "Ifeanyi/Cyanister", "Tshackelton/IBMPlex-DenseReadable", "snehilsanyal/scikit-learn", "Himhimhim/xkcd", "nota-ai/theme", "rawrsor1/Everforest", "rottenlittlecreature/Moon_Goblin", "abidlabs/test-yellow", "abidlabs/test-yellow3", "idspicQstitho/dracula_revamped", "kfahn/AnimalPose", "HaleyCH/HaleyCH_Theme", "simulKitke/dracula_test", "braintacles/CrimsonNight", "wentaohe/whiteboardv2", "reilnuud/polite", "remilia/Ghostly", "Franklisi/darkmode", "coding-alt/soft", "xiaobaiyuan/theme_land", "step-3-profit/Midnight-Deep", "xiaobaiyuan/theme_demo", "Taithrah/Minimal", "Insuz/SimpleIndigo", "zkunn/Alipay_Gradio_theme", "Insuz/Mocha", "xiaobaiyuan/theme_brief", "Ama434/434-base-Barlow", "Ama434/def_barlow", "Ama434/neutral-barlow", "dawood/dracula_test", "nuttea/Softblue", "BlueDancer/Alien_Diffusion", "naughtondale/monochrome", "Dagfinn1962/standard", "default"],
- "mdx_model": ["Main_340", "Main_390", "Main_406", "Main_427", "Main_438", "Inst_full_292", "Inst_HQ_1", "Inst_HQ_2", "Inst_HQ_3", "Inst_HQ_4", "Inst_HQ_5", "Kim_Vocal_1", "Kim_Vocal_2", "Kim_Inst", "Inst_187_beta", "Inst_82_beta", "Inst_90_beta", "Voc_FT", "Crowd_HQ", "Inst_1", "Inst_2", "Inst_3", "MDXNET_1_9703", "MDXNET_2_9682", "MDXNET_3_9662", "Inst_Main", "MDXNET_Main", "MDXNET_9482"],
- "demucs_model": ["HT-Normal", "HT-Tuned", "HD_MMI", "HT_6S"],
- "edge_tts": ["af-ZA-AdriNeural", "af-ZA-WillemNeural", "sq-AL-AnilaNeural", "sq-AL-IlirNeural", "am-ET-AmehaNeural", "am-ET-MekdesNeural", "ar-DZ-AminaNeural", "ar-DZ-IsmaelNeural", "ar-BH-AliNeural", "ar-BH-LailaNeural", "ar-EG-SalmaNeural", "ar-EG-ShakirNeural", "ar-IQ-BasselNeural", "ar-IQ-RanaNeural", "ar-JO-SanaNeural", "ar-JO-TaimNeural", "ar-KW-FahedNeural", "ar-KW-NouraNeural", "ar-LB-LaylaNeural", "ar-LB-RamiNeural", "ar-LY-ImanNeural", "ar-LY-OmarNeural", "ar-MA-JamalNeural", "ar-MA-MounaNeural", "ar-OM-AbdullahNeural", "ar-OM-AyshaNeural", "ar-QA-AmalNeural", "ar-QA-MoazNeural", "ar-SA-HamedNeural", "ar-SA-ZariyahNeural", "ar-SY-AmanyNeural", "ar-SY-LaithNeural", "ar-TN-HediNeural", "ar-TN-ReemNeural", "ar-AE-FatimaNeural", "ar-AE-HamdanNeural", "ar-YE-MaryamNeural", "ar-YE-SalehNeural", "az-AZ-BabekNeural", "az-AZ-BanuNeural", "bn-BD-NabanitaNeural", "bn-BD-PradeepNeural", "bn-IN-BashkarNeural", "bn-IN-TanishaaNeural", "bs-BA-GoranNeural", "bs-BA-VesnaNeural", "bg-BG-BorislavNeural", "bg-BG-KalinaNeural", "my-MM-NilarNeural", "my-MM-ThihaNeural", "ca-ES-EnricNeural", "ca-ES-JoanaNeural", "zh-HK-HiuGaaiNeural", "zh-HK-HiuMaanNeural", "zh-HK-WanLungNeural", "zh-CN-XiaoxiaoNeural", "zh-CN-XiaoyiNeural", "zh-CN-YunjianNeural", "zh-CN-YunxiNeural", "zh-CN-YunxiaNeural", "zh-CN-YunyangNeural", "zh-CN-liaoning-XiaobeiNeural", "zh-TW-HsiaoChenNeural", "zh-TW-YunJheNeural", "zh-TW-HsiaoYuNeural", "zh-CN-shaanxi-XiaoniNeural", "hr-HR-GabrijelaNeural", "hr-HR-SreckoNeural", "cs-CZ-AntoninNeural", "cs-CZ-VlastaNeural", "da-DK-ChristelNeural", "da-DK-JeppeNeural", "nl-BE-ArnaudNeural", "nl-BE-DenaNeural", "nl-NL-ColetteNeural", "nl-NL-FennaNeural", "nl-NL-MaartenNeural", "en-AU-NatashaNeural", "en-AU-WilliamNeural", "en-CA-ClaraNeural", "en-CA-LiamNeural", "en-HK-SamNeural", "en-HK-YanNeural", "en-IN-NeerjaExpressiveNeural", "en-IN-NeerjaNeural", "en-IN-PrabhatNeural", "en-IE-ConnorNeural", "en-IE-EmilyNeural", "en-KE-AsiliaNeural", "en-KE-ChilembaNeural", "en-NZ-MitchellNeural", "en-NZ-MollyNeural", "en-NG-AbeoNeural", "en-NG-EzinneNeural", "en-PH-JamesNeural", "en-PH-RosaNeural", "en-SG-LunaNeural", "en-SG-WayneNeural", "en-ZA-LeahNeural", "en-ZA-LukeNeural", "en-TZ-ElimuNeural", "en-TZ-ImaniNeural", "en-GB-LibbyNeural", "en-GB-MaisieNeural", "en-GB-RyanNeural", "en-GB-SoniaNeural", "en-GB-ThomasNeural", "en-US-AvaMultilingualNeural", "en-US-AndrewMultilingualNeural", "en-US-EmmaMultilingualNeural", "en-US-BrianMultilingualNeural", "en-US-AvaNeural", "en-US-AndrewNeural", "en-US-EmmaNeural", "en-US-BrianNeural", "en-US-AnaNeural", "en-US-AriaNeural", "en-US-ChristopherNeural", "en-US-EricNeural", "en-US-GuyNeural", "en-US-JennyNeural", "en-US-MichelleNeural", "en-US-RogerNeural", "en-US-SteffanNeural", "et-EE-AnuNeural", "et-EE-KertNeural", "fil-PH-AngeloNeural", "fil-PH-BlessicaNeural", "fi-FI-HarriNeural", "fi-FI-NooraNeural", "fr-BE-CharlineNeural", "fr-BE-GerardNeural", "fr-CA-ThierryNeural", "fr-CA-AntoineNeural", "fr-CA-JeanNeural", "fr-CA-SylvieNeural", "fr-FR-VivienneMultilingualNeural", "fr-FR-RemyMultilingualNeural", "fr-FR-DeniseNeural", "fr-FR-EloiseNeural", "fr-FR-HenriNeural", "fr-CH-ArianeNeural", "fr-CH-FabriceNeural", "gl-ES-RoiNeural", "gl-ES-SabelaNeural", "ka-GE-EkaNeural", "ka-GE-GiorgiNeural", "de-AT-IngridNeural", "de-AT-JonasNeural", "de-DE-SeraphinaMultilingualNeural", "de-DE-FlorianMultilingualNeural", "de-DE-AmalaNeural", "de-DE-ConradNeural", "de-DE-KatjaNeural", "de-DE-KillianNeural", "de-CH-JanNeural", "de-CH-LeniNeural", "el-GR-AthinaNeural", 
"el-GR-NestorasNeural", "gu-IN-DhwaniNeural", "gu-IN-NiranjanNeural", "he-IL-AvriNeural", "he-IL-HilaNeural", "hi-IN-MadhurNeural", "hi-IN-SwaraNeural", "hu-HU-NoemiNeural", "hu-HU-TamasNeural", "is-IS-GudrunNeural", "is-IS-GunnarNeural", "id-ID-ArdiNeural", "id-ID-GadisNeural", "ga-IE-ColmNeural", "ga-IE-OrlaNeural", "it-IT-GiuseppeNeural", "it-IT-DiegoNeural", "it-IT-ElsaNeural", "it-IT-IsabellaNeural", "ja-JP-KeitaNeural", "ja-JP-NanamiNeural", "jv-ID-DimasNeural", "jv-ID-SitiNeural", "kn-IN-GaganNeural", "kn-IN-SapnaNeural", "kk-KZ-AigulNeural", "kk-KZ-DauletNeural", "km-KH-PisethNeural", "km-KH-SreymomNeural", "ko-KR-HyunsuNeural", "ko-KR-InJoonNeural", "ko-KR-SunHiNeural", "lo-LA-ChanthavongNeural", "lo-LA-KeomanyNeural", "lv-LV-EveritaNeural", "lv-LV-NilsNeural", "lt-LT-LeonasNeural", "lt-LT-OnaNeural", "mk-MK-AleksandarNeural", "mk-MK-MarijaNeural", "ms-MY-OsmanNeural", "ms-MY-YasminNeural", "ml-IN-MidhunNeural", "ml-IN-SobhanaNeural", "mt-MT-GraceNeural", "mt-MT-JosephNeural", "mr-IN-AarohiNeural", "mr-IN-ManoharNeural", "mn-MN-BataaNeural", "mn-MN-YesuiNeural", "ne-NP-HemkalaNeural", "ne-NP-SagarNeural", "nb-NO-FinnNeural", "nb-NO-PernilleNeural", "ps-AF-GulNawazNeural", "ps-AF-LatifaNeural", "fa-IR-DilaraNeural", "fa-IR-FaridNeural", "pl-PL-MarekNeural", "pl-PL-ZofiaNeural", "pt-BR-ThalitaNeural", "pt-BR-AntonioNeural", "pt-BR-FranciscaNeural", "pt-PT-DuarteNeural", "pt-PT-RaquelNeural", "ro-RO-AlinaNeural", "ro-RO-EmilNeural", "ru-RU-DmitryNeural", "ru-RU-SvetlanaNeural", "sr-RS-NicholasNeural", "sr-RS-SophieNeural", "si-LK-SameeraNeural", "si-LK-ThiliniNeural", "sk-SK-LukasNeural", "sk-SK-ViktoriaNeural", "sl-SI-PetraNeural", "sl-SI-RokNeural", "so-SO-MuuseNeural", "so-SO-UbaxNeural", "es-AR-ElenaNeural", "es-AR-TomasNeural", "es-BO-MarceloNeural", "es-BO-SofiaNeural", "es-CL-CatalinaNeural", "es-CL-LorenzoNeural", "es-ES-XimenaNeural", "es-CO-GonzaloNeural", "es-CO-SalomeNeural", "es-CR-JuanNeural", "es-CR-MariaNeural", "es-CU-BelkysNeural", "es-CU-ManuelNeural", "es-DO-EmilioNeural", "es-DO-RamonaNeural", "es-EC-AndreaNeural", "es-EC-LuisNeural", "es-SV-LorenaNeural", "es-SV-RodrigoNeural", "es-GQ-JavierNeural", "es-GQ-TeresaNeural", "es-GT-AndresNeural", "es-GT-MartaNeural", "es-HN-CarlosNeural", "es-HN-KarlaNeural", "es-MX-DaliaNeural", "es-MX-JorgeNeural", "es-NI-FedericoNeural", "es-NI-YolandaNeural", "es-PA-MargaritaNeural", "es-PA-RobertoNeural", "es-PY-MarioNeural", "es-PY-TaniaNeural", "es-PE-AlexNeural", "es-PE-CamilaNeural", "es-PR-KarinaNeural", "es-PR-VictorNeural", "es-ES-AlvaroNeural", "es-ES-ElviraNeural", "es-US-AlonsoNeural", "es-US-PalomaNeural", "es-UY-MateoNeural", "es-UY-ValentinaNeural", "es-VE-PaolaNeural", "es-VE-SebastianNeural", "su-ID-JajangNeural", "su-ID-TutiNeural", "sw-KE-RafikiNeural", "sw-KE-ZuriNeural", "sw-TZ-DaudiNeural", "sw-TZ-RehemaNeural", "sv-SE-MattiasNeural", "sv-SE-SofieNeural", "ta-IN-PallaviNeural", "ta-IN-ValluvarNeural", "ta-MY-KaniNeural", "ta-MY-SuryaNeural", "ta-SG-AnbuNeural", "ta-SG-VenbaNeural", "ta-LK-KumarNeural", "ta-LK-SaranyaNeural", "te-IN-MohanNeural", "te-IN-ShrutiNeural", "th-TH-NiwatNeural", "th-TH-PremwadeeNeural", "tr-TR-AhmetNeural", "tr-TR-EmelNeural", "uk-UA-OstapNeural", "uk-UA-PolinaNeural", "ur-IN-GulNeural", "ur-IN-SalmanNeural", "ur-PK-AsadNeural", "ur-PK-UzmaNeural", "uz-UZ-MadinaNeural", "uz-UZ-SardorNeural", "vi-VN-HoaiMyNeural", "vi-VN-NamMinhNeural", "cy-GB-AledNeural", "cy-GB-NiaNeural", "zu-ZA-ThandoNeural", "zu-ZA-ThembaNeural"],
- "google_tts_voice": ["af", "am", "ar", "bg", "bn", "bs", "ca", "cs", "cy", "da", "de", "el", "en", "es", "et", "eu", "fi", "fr", "fr-CA", "gl", "gu", "ha", "hi", "hr", "hu", "id", "is", "it", "iw", "ja", "jw", "km", "kn", "ko", "la", "lt", "lv", "ml", "mr", "ms", "my", "ne", "nl", "no", "pa", "pl", "pt", "pt-PT", "ro", "ru", "si", "sk", "sq", "sr", "su", "sv", "sw", "ta", "te", "th", "tl", "tr", "uk", "ur", "vi", "yue", "zh-CN", "zh-TW", "zh"],
- "separator_tab": true,
- "convert_tab": true,
- "tts_tab": true,
- "effects_tab": true,
- "create_dataset_tab": true,
- "training_tab": true,
- "fushion_tab": true,
- "read_tab": true,
- "downloads_tab": true,
- "settings_tab": true,
- "report_bug_tab": true,
- "app_port": 7860,
- "tensorboard_port": 6870,
- "num_of_restart": 5,
- "server_name": "0.0.0.0",
- "app_show_error": true
- }

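This file is what the deleted UI and scripts read at startup. A minimal sketch of consuming it, using the same relative path as the deleted modules (key names are the ones listed above):

import json
import os

# Load the UI configuration the way the deleted modules do.
with open(os.path.join("main", "configs", "config.json"), "r") as f:
    cfg = json.load(f)

# Tab switches are plain booleans; ports and server binding are scalars.
enabled_tabs = [key for key, value in cfg.items() if key.endswith("_tab") and value]
print(cfg["language"], cfg["app_port"], cfg["server_name"], enabled_tabs)
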
main/configs/config.py DELETED
@@ -1,70 +0,0 @@
- import os
- import json
- import torch
-
- version_config_paths = [os.path.join(version, size) for version in ["v1", "v2"] for size in ["32000.json", "40000.json", "44100.json", "48000.json"]]
-
- def singleton(cls):
-     instances = {}
-     def get_instance(*args, **kwargs):
-         if cls not in instances: instances[cls] = cls(*args, **kwargs)
-         return instances[cls]
-     return get_instance
-
- @singleton
- class Config:
-     def __init__(self):
-         self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
-         self.gpu_name = (torch.cuda.get_device_name(int(self.device.split(":")[-1])) if self.device.startswith("cuda") else None)
-         self.translations = self.multi_language()
-         self.json_config = self.load_config_json()
-         self.gpu_mem = None
-         self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
-
-     def multi_language(self):
-         try:
-             with open(os.path.join("main", "configs", "config.json"), "r") as f:
-                 configs = json.load(f)
-
-             lang = configs.get("language", "vi-VN")
-             if len([l for l in os.listdir(os.path.join("assets", "languages")) if l.endswith(".json")]) < 1: raise FileNotFoundError("Không tìm thấy bất cứ gói ngôn ngữ nào(No package languages found)")
-
-             if not lang: lang = "vi-VN"
-             if lang not in configs["support_language"]: raise ValueError("Ngôn ngữ không được hỗ trợ(Language not supported)")
-
-             lang_path = os.path.join("assets", "languages", f"{lang}.json")
-             if not os.path.exists(lang_path): lang_path = os.path.join("assets", "languages", "vi-VN.json")
-
-             with open(lang_path, encoding="utf-8") as f:
-                 translations = json.load(f)
-         except json.JSONDecodeError:
-             print(self.translations["empty_json"].format(file=lang))
-             pass
-         return translations
-
-     def load_config_json(self):
-         configs = {}
-         for config_file in version_config_paths:
-             try:
-                 with open(os.path.join("main", "configs", config_file), "r") as f:
-                     configs[config_file] = json.load(f)
-             except json.JSONDecodeError:
-                 print(self.translations["empty_json"].format(file=config_file))
-                 pass
-         return configs
-
-     def device_config(self):
-         if self.device.startswith("cuda"): self.set_cuda_config()
-         elif self.has_mps(): self.device = "mps"
-         else: self.device = "cpu"
-
-         if self.gpu_mem is not None and self.gpu_mem <= 4: return 1, 5, 30, 32
-         return 1, 6, 38, 41
-
-     def set_cuda_config(self):
-         i_device = int(self.device.split(":")[-1])
-         self.gpu_name = torch.cuda.get_device_name(i_device)
-         self.gpu_mem = torch.cuda.get_device_properties(i_device).total_memory // (1024**3)
-
-     def has_mps(self):
-         return torch.backends.mps.is_available()

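The @singleton decorator above is a closure-based cache keyed by class, so every Config() call after the first returns the same object. A self-contained sketch of the same pattern (Settings is a hypothetical example class, not part of the repository):

def singleton(cls):
    instances = {}

    def get_instance(*args, **kwargs):
        # Construct the class once, then hand back the cached instance.
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance

@singleton
class Settings:
    def __init__(self):
        self.loaded = True

# Both calls resolve to the same cached object.
assert Settings() is Settings()
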
main/configs/v1/32000.json DELETED
@@ -1,46 +0,0 @@
- {
-     "train": {
-         "log_interval": 200,
-         "seed": 1234,
-         "epochs": 20000,
-         "learning_rate": 0.0001,
-         "betas": [0.8, 0.99],
-         "eps": 1e-09,
-         "batch_size": 4,
-         "lr_decay": 0.999875,
-         "segment_size": 12800,
-         "init_lr_ratio": 1,
-         "warmup_epochs": 0,
-         "c_mel": 45,
-         "c_kl": 1.0
-     },
-     "data": {
-         "max_wav_value": 32768.0,
-         "sample_rate": 32000,
-         "filter_length": 1024,
-         "hop_length": 320,
-         "win_length": 1024,
-         "n_mel_channels": 80,
-         "mel_fmin": 0.0,
-         "mel_fmax": null
-     },
-     "model": {
-         "inter_channels": 192,
-         "hidden_channels": 192,
-         "filter_channels": 768,
-         "text_enc_hidden_dim": 256,
-         "n_heads": 2,
-         "n_layers": 6,
-         "kernel_size": 3,
-         "p_dropout": 0,
-         "resblock": "1",
-         "resblock_kernel_sizes": [3, 7, 11],
-         "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-         "upsample_rates": [10, 4, 2, 2, 2],
-         "upsample_initial_channel": 512,
-         "upsample_kernel_sizes": [16, 16, 4, 4, 4],
-         "use_spectral_norm": false,
-         "gin_channels": 256,
-         "spk_embed_dim": 109
-     }
- }

main/configs/v1/40000.json DELETED
@@ -1,46 +0,0 @@
- {
-     "train": {
-         "log_interval": 200,
-         "seed": 1234,
-         "epochs": 20000,
-         "learning_rate": 0.0001,
-         "betas": [0.8, 0.99],
-         "eps": 1e-09,
-         "batch_size": 4,
-         "lr_decay": 0.999875,
-         "segment_size": 12800,
-         "init_lr_ratio": 1,
-         "warmup_epochs": 0,
-         "c_mel": 45,
-         "c_kl": 1.0
-     },
-     "data": {
-         "max_wav_value": 32768.0,
-         "sample_rate": 40000,
-         "filter_length": 2048,
-         "hop_length": 400,
-         "win_length": 2048,
-         "n_mel_channels": 125,
-         "mel_fmin": 0.0,
-         "mel_fmax": null
-     },
-     "model": {
-         "inter_channels": 192,
-         "hidden_channels": 192,
-         "filter_channels": 768,
-         "text_enc_hidden_dim": 256,
-         "n_heads": 2,
-         "n_layers": 6,
-         "kernel_size": 3,
-         "p_dropout": 0,
-         "resblock": "1",
-         "resblock_kernel_sizes": [3, 7, 11],
-         "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-         "upsample_rates": [10, 10, 2, 2],
-         "upsample_initial_channel": 512,
-         "upsample_kernel_sizes": [16, 16, 4, 4],
-         "use_spectral_norm": false,
-         "gin_channels": 256,
-         "spk_embed_dim": 109
-     }
- }

main/configs/v1/44100.json DELETED
@@ -1,46 +0,0 @@
- {
-     "train": {
-         "log_interval": 200,
-         "seed": 1234,
-         "epochs": 20000,
-         "learning_rate": 0.0001,
-         "betas": [0.8, 0.99],
-         "eps": 1e-09,
-         "batch_size": 4,
-         "lr_decay": 0.999875,
-         "segment_size": 15876,
-         "init_lr_ratio": 1,
-         "warmup_epochs": 0,
-         "c_mel": 45,
-         "c_kl": 1.0
-     },
-     "data": {
-         "max_wav_value": 32768.0,
-         "sample_rate": 44100,
-         "filter_length": 2048,
-         "hop_length": 441,
-         "win_length": 2048,
-         "n_mel_channels": 160,
-         "mel_fmin": 0.0,
-         "mel_fmax": null
-     },
-     "model": {
-         "inter_channels": 192,
-         "hidden_channels": 192,
-         "filter_channels": 768,
-         "text_enc_hidden_dim": 256,
-         "n_heads": 2,
-         "n_layers": 6,
-         "kernel_size": 3,
-         "p_dropout": 0,
-         "resblock": "1",
-         "resblock_kernel_sizes": [3, 7, 11],
-         "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-         "upsample_rates": [7, 7, 3, 3],
-         "upsample_initial_channel": 512,
-         "upsample_kernel_sizes": [14, 14, 6, 6],
-         "use_spectral_norm": false,
-         "gin_channels": 256,
-         "spk_embed_dim": 109
-     }
- }

main/configs/v1/48000.json DELETED
@@ -1,46 +0,0 @@
- {
-     "train": {
-         "log_interval": 200,
-         "seed": 1234,
-         "epochs": 20000,
-         "learning_rate": 0.0001,
-         "betas": [0.8, 0.99],
-         "eps": 1e-09,
-         "batch_size": 4,
-         "lr_decay": 0.999875,
-         "segment_size": 11520,
-         "init_lr_ratio": 1,
-         "warmup_epochs": 0,
-         "c_mel": 45,
-         "c_kl": 1.0
-     },
-     "data": {
-         "max_wav_value": 32768.0,
-         "sample_rate": 48000,
-         "filter_length": 2048,
-         "hop_length": 480,
-         "win_length": 2048,
-         "n_mel_channels": 128,
-         "mel_fmin": 0.0,
-         "mel_fmax": null
-     },
-     "model": {
-         "inter_channels": 192,
-         "hidden_channels": 192,
-         "filter_channels": 768,
-         "text_enc_hidden_dim": 256,
-         "n_heads": 2,
-         "n_layers": 6,
-         "kernel_size": 3,
-         "p_dropout": 0,
-         "resblock": "1",
-         "resblock_kernel_sizes": [3, 7, 11],
-         "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-         "upsample_rates": [10, 6, 2, 2, 2],
-         "upsample_initial_channel": 512,
-         "upsample_kernel_sizes": [16, 16, 4, 4, 4],
-         "use_spectral_norm": false,
-         "gin_channels": 256,
-         "spk_embed_dim": 109
-     }
- }

main/configs/v2/32000.json DELETED
@@ -1,42 +0,0 @@
- {
-     "train": {
-         "log_interval": 200,
-         "seed": 1234,
-         "learning_rate": 0.0001,
-         "betas": [0.8, 0.99],
-         "eps": 1e-09,
-         "lr_decay": 0.999875,
-         "segment_size": 12800,
-         "c_mel": 45,
-         "c_kl": 1.0
-     },
-     "data": {
-         "max_wav_value": 32768.0,
-         "sample_rate": 32000,
-         "filter_length": 1024,
-         "hop_length": 320,
-         "win_length": 1024,
-         "n_mel_channels": 80,
-         "mel_fmin": 0.0,
-         "mel_fmax": null
-     },
-     "model": {
-         "inter_channels": 192,
-         "hidden_channels": 192,
-         "filter_channels": 768,
-         "text_enc_hidden_dim": 768,
-         "n_heads": 2,
-         "n_layers": 6,
-         "kernel_size": 3,
-         "p_dropout": 0,
-         "resblock": "1",
-         "resblock_kernel_sizes": [3, 7, 11],
-         "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-         "upsample_rates": [10, 8, 2, 2],
-         "upsample_initial_channel": 512,
-         "upsample_kernel_sizes": [20, 16, 4, 4],
-         "use_spectral_norm": false,
-         "gin_channels": 256,
-         "spk_embed_dim": 109
-     }
- }

main/configs/v2/40000.json DELETED
@@ -1,42 +0,0 @@
- {
-     "train": {
-         "log_interval": 200,
-         "seed": 1234,
-         "learning_rate": 0.0001,
-         "betas": [0.8, 0.99],
-         "eps": 1e-09,
-         "lr_decay": 0.999875,
-         "segment_size": 12800,
-         "c_mel": 45,
-         "c_kl": 1.0
-     },
-     "data": {
-         "max_wav_value": 32768.0,
-         "sample_rate": 40000,
-         "filter_length": 2048,
-         "hop_length": 400,
-         "win_length": 2048,
-         "n_mel_channels": 125,
-         "mel_fmin": 0.0,
-         "mel_fmax": null
-     },
-     "model": {
-         "inter_channels": 192,
-         "hidden_channels": 192,
-         "filter_channels": 768,
-         "text_enc_hidden_dim": 768,
-         "n_heads": 2,
-         "n_layers": 6,
-         "kernel_size": 3,
-         "p_dropout": 0,
-         "resblock": "1",
-         "resblock_kernel_sizes": [3, 7, 11],
-         "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-         "upsample_rates": [10, 10, 2, 2],
-         "upsample_initial_channel": 512,
-         "upsample_kernel_sizes": [16, 16, 4, 4],
-         "use_spectral_norm": false,
-         "gin_channels": 256,
-         "spk_embed_dim": 109
-     }
- }

main/configs/v2/44100.json DELETED
@@ -1,42 +0,0 @@
- {
-     "train": {
-         "log_interval": 200,
-         "seed": 1234,
-         "learning_rate": 0.0001,
-         "betas": [0.8, 0.99],
-         "eps": 1e-09,
-         "lr_decay": 0.999875,
-         "segment_size": 15876,
-         "c_mel": 45,
-         "c_kl": 1.0
-     },
-     "data": {
-         "max_wav_value": 32768.0,
-         "sample_rate": 44100,
-         "filter_length": 2048,
-         "hop_length": 441,
-         "win_length": 2048,
-         "n_mel_channels": 160,
-         "mel_fmin": 0.0,
-         "mel_fmax": null
-     },
-     "model": {
-         "inter_channels": 192,
-         "hidden_channels": 192,
-         "filter_channels": 768,
-         "text_enc_hidden_dim": 768,
-         "n_heads": 2,
-         "n_layers": 6,
-         "kernel_size": 3,
-         "p_dropout": 0,
-         "resblock": "1",
-         "resblock_kernel_sizes": [3, 7, 11],
-         "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-         "upsample_rates": [7, 7, 3, 3],
-         "upsample_initial_channel": 512,
-         "upsample_kernel_sizes": [14, 14, 6, 6],
-         "use_spectral_norm": false,
-         "gin_channels": 256,
-         "spk_embed_dim": 109
-     }
- }

main/configs/v2/48000.json DELETED
@@ -1,42 +0,0 @@
- {
-     "train": {
-         "log_interval": 200,
-         "seed": 1234,
-         "learning_rate": 0.0001,
-         "betas": [0.8, 0.99],
-         "eps": 1e-09,
-         "lr_decay": 0.999875,
-         "segment_size": 17280,
-         "c_mel": 45,
-         "c_kl": 1.0
-     },
-     "data": {
-         "max_wav_value": 32768.0,
-         "sample_rate": 48000,
-         "filter_length": 2048,
-         "hop_length": 480,
-         "win_length": 2048,
-         "n_mel_channels": 128,
-         "mel_fmin": 0.0,
-         "mel_fmax": null
-     },
-     "model": {
-         "inter_channels": 192,
-         "hidden_channels": 192,
-         "filter_channels": 768,
-         "text_enc_hidden_dim": 768,
-         "n_heads": 2,
-         "n_layers": 6,
-         "kernel_size": 3,
-         "p_dropout": 0,
-         "resblock": "1",
-         "resblock_kernel_sizes": [3, 7, 11],
-         "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-         "upsample_rates": [12, 10, 2, 2],
-         "upsample_initial_channel": 512,
-         "upsample_kernel_sizes": [24, 20, 4, 4],
-         "use_spectral_norm": false,
-         "gin_channels": 256,
-         "spk_embed_dim": 109
-     }
- }

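One consistency property visible across the eight configs above: the product of "upsample_rates" equals "hop_length" for each sample rate (10·4·2·2·2 = 320 at 32 kHz, 7·7·3·3 = 441 at 44.1 kHz, 12·10·2·2 = 480 for v2 at 48 kHz), i.e. the decoder upsamples frame-rate features back to the waveform rate. A small sketch that checks this for every file, assuming the deleted directory layout is still on disk:

import json
import math
import os

# Verify prod(upsample_rates) == hop_length for every training config.
for version in ("v1", "v2"):
    for name in ("32000.json", "40000.json", "44100.json", "48000.json"):
        with open(os.path.join("main", "configs", version, name)) as f:
            cfg = json.load(f)
        rates = cfg["model"]["upsample_rates"]
        assert math.prod(rates) == cfg["data"]["hop_length"], (version, name)
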
main/inference/audio_effects.py DELETED
@@ -1,170 +0,0 @@
- import os
- import sys
- import librosa
- import argparse
-
- import numpy as np
- import soundfile as sf
-
- from distutils.util import strtobool
- from scipy.signal import butter, filtfilt
- from pedalboard import Pedalboard, Chorus, Distortion, Reverb, PitchShift, Delay, Limiter, Gain, Bitcrush, Clipping, Compressor, Phaser, HighpassFilter
-
- sys.path.append(os.getcwd())
-
- from main.configs.config import Config
- from main.library.utils import pydub_convert
-
- translations = Config().translations
-
- def parse_arguments():
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--input_path", type=str, required=True)
-     parser.add_argument("--output_path", type=str, default="./audios/apply_effects.wav")
-     parser.add_argument("--export_format", type=str, default="wav")
-     parser.add_argument("--resample", type=lambda x: bool(strtobool(x)), default=False)
-     parser.add_argument("--resample_sr", type=int, default=0)
-     parser.add_argument("--chorus", type=lambda x: bool(strtobool(x)), default=False)
-     parser.add_argument("--chorus_depth", type=float, default=0.5)
-     parser.add_argument("--chorus_rate", type=float, default=1.5)
-     parser.add_argument("--chorus_mix", type=float, default=0.5)
-     parser.add_argument("--chorus_delay", type=int, default=10)
-     parser.add_argument("--chorus_feedback", type=float, default=0)
-     parser.add_argument("--distortion", type=lambda x: bool(strtobool(x)), default=False)
-     parser.add_argument("--drive_db", type=int, default=20)
-     parser.add_argument("--reverb", type=lambda x: bool(strtobool(x)), default=False)
-     parser.add_argument("--reverb_room_size", type=float, default=0.5)
-     parser.add_argument("--reverb_damping", type=float, default=0.5)
-     parser.add_argument("--reverb_wet_level", type=float, default=0.33)
-     parser.add_argument("--reverb_dry_level", type=float, default=0.67)
-     parser.add_argument("--reverb_width", type=float, default=1)
-     parser.add_argument("--reverb_freeze_mode", type=lambda x: bool(strtobool(x)), default=False)
-     parser.add_argument("--pitchshift", type=lambda x: bool(strtobool(x)), default=False)
-     parser.add_argument("--pitch_shift", type=int, default=0)
-     parser.add_argument("--delay", type=lambda x: bool(strtobool(x)), default=False)
-     parser.add_argument("--delay_seconds", type=float, default=0.5)
-     parser.add_argument("--delay_feedback", type=float, default=0.5)
-     parser.add_argument("--delay_mix", type=float, default=0.5)
-     parser.add_argument("--compressor", type=lambda x: bool(strtobool(x)), default=False)
-     parser.add_argument("--compressor_threshold", type=int, default=-20)
-     parser.add_argument("--compressor_ratio", type=float, default=4)
-     parser.add_argument("--compressor_attack_ms", type=float, default=10)
-     parser.add_argument("--compressor_release_ms", type=int, default=200)
-     parser.add_argument("--limiter", type=lambda x: bool(strtobool(x)), default=False)
-     parser.add_argument("--limiter_threshold", type=int, default=0)
-     parser.add_argument("--limiter_release", type=int, default=100)
-     parser.add_argument("--gain", type=lambda x: bool(strtobool(x)), default=False)
-     parser.add_argument("--gain_db", type=int, default=0)
-     parser.add_argument("--bitcrush", type=lambda x: bool(strtobool(x)), default=False)
-     parser.add_argument("--bitcrush_bit_depth", type=int, default=16)
-     parser.add_argument("--clipping", type=lambda x: bool(strtobool(x)), default=False)
-     parser.add_argument("--clipping_threshold", type=int, default=-10)
-     parser.add_argument("--phaser", type=lambda x: bool(strtobool(x)), default=False)
-     parser.add_argument("--phaser_rate_hz", type=float, default=0.5)
-     parser.add_argument("--phaser_depth", type=float, default=0.5)
-     parser.add_argument("--phaser_centre_frequency_hz", type=int, default=1000)
-     parser.add_argument("--phaser_feedback", type=float, default=0)
-     parser.add_argument("--phaser_mix", type=float, default=0.5)
-     parser.add_argument("--treble_bass_boost", type=lambda x: bool(strtobool(x)), default=False)
-     parser.add_argument("--bass_boost_db", type=int, default=0)
-     parser.add_argument("--bass_boost_frequency", type=int, default=100)
-     parser.add_argument("--treble_boost_db", type=int, default=0)
-     parser.add_argument("--treble_boost_frequency", type=int, default=3000)
-     parser.add_argument("--fade_in_out", type=lambda x: bool(strtobool(x)), default=False)
-     parser.add_argument("--fade_in_duration", type=float, default=2000)
-     parser.add_argument("--fade_out_duration", type=float, default=2000)
-     parser.add_argument("--audio_combination", type=lambda x: bool(strtobool(x)), default=False)
-     parser.add_argument("--audio_combination_input", type=str)
-
-     return parser.parse_args()
-
- def process_audio(input_path, output_path, resample, resample_sr, chorus_depth, chorus_rate, chorus_mix, chorus_delay, chorus_feedback, distortion_drive, reverb_room_size, reverb_damping, reverb_wet_level, reverb_dry_level, reverb_width, reverb_freeze_mode, pitch_shift, delay_seconds, delay_feedback, delay_mix, compressor_threshold, compressor_ratio, compressor_attack_ms, compressor_release_ms, limiter_threshold, limiter_release, gain_db, bitcrush_bit_depth, clipping_threshold, phaser_rate_hz, phaser_depth, phaser_centre_frequency_hz, phaser_feedback, phaser_mix, bass_boost_db, bass_boost_frequency, treble_boost_db, treble_boost_frequency, fade_in_duration, fade_out_duration, export_format, chorus, distortion, reverb, pitchshift, delay, compressor, limiter, gain, bitcrush, clipping, phaser, treble_bass_boost, fade_in_out, audio_combination, audio_combination_input):
-     def bass_boost(audio, gain_db, frequency, sample_rate):
-         if gain_db >= 1:
-             b, a = butter(4, frequency / (0.5 * sample_rate), btype='low')
-             return filtfilt(b, a, audio) * 10 ** (gain_db / 20)
-         else: return audio
-
-     def treble_boost(audio, gain_db, frequency, sample_rate):
-         if gain_db >= 1:
-             b, a = butter(4, frequency / (0.5 * sample_rate), btype='high')
-             return filtfilt(b, a, audio) * 10 ** (gain_db / 20)
-         else: return audio
-
-     def fade_out_effect(audio, sr, duration=3.0):
-         length = int(duration * sr)
-         end = audio.shape[0]
-
-         if length > end: length = end
-         start = end - length
-
-         audio[start:end] = audio[start:end] * np.linspace(1.0, 0.0, length)
-         return audio
-
-     def fade_in_effect(audio, sr, duration=3.0):
-         length = int(duration * sr)
-         start = 0
-
-         if length > audio.shape[0]: length = audio.shape[0]
-         end = length
-
-         audio[start:end] = audio[start:end] * np.linspace(0.0, 1.0, length)
-         return audio
-
-     if not input_path or not os.path.exists(input_path):
-         print(translations["input_not_valid"])
-         sys.exit(1)
-
-     if not output_path:
-         print(translations["output_not_valid"])
-         sys.exit(1)
-
-     if os.path.exists(output_path): os.remove(output_path)
-
-     try:
-         audio, sample_rate = sf.read(input_path)
-     except Exception as e:
-         raise RuntimeError(translations["errors_loading_audio"].format(e=e))
-
-     try:
-         board = Pedalboard([HighpassFilter()])
-
-         if chorus: board.append(Chorus(depth=chorus_depth, rate_hz=chorus_rate, mix=chorus_mix, centre_delay_ms=chorus_delay, feedback=chorus_feedback))
-         if distortion: board.append(Distortion(drive_db=distortion_drive))
-         if reverb: board.append(Reverb(room_size=reverb_room_size, damping=reverb_damping, wet_level=reverb_wet_level, dry_level=reverb_dry_level, width=reverb_width, freeze_mode=1 if reverb_freeze_mode else 0))
-         if pitchshift: board.append(PitchShift(semitones=pitch_shift))
-         if delay: board.append(Delay(delay_seconds=delay_seconds, feedback=delay_feedback, mix=delay_mix))
-         if compressor: board.append(Compressor(threshold_db=compressor_threshold, ratio=compressor_ratio, attack_ms=compressor_attack_ms, release_ms=compressor_release_ms))
-         if limiter: board.append(Limiter(threshold_db=limiter_threshold, release_ms=limiter_release))
-         if gain: board.append(Gain(gain_db=gain_db))
-         if bitcrush: board.append(Bitcrush(bit_depth=bitcrush_bit_depth))
-         if clipping: board.append(Clipping(threshold_db=clipping_threshold))
-         if phaser: board.append(Phaser(rate_hz=phaser_rate_hz, depth=phaser_depth, centre_frequency_hz=phaser_centre_frequency_hz, feedback=phaser_feedback, mix=phaser_mix))
-
-         processed_audio = board(audio, sample_rate)
-
-         if treble_bass_boost:
-             processed_audio = bass_boost(processed_audio, bass_boost_db, bass_boost_frequency, sample_rate)
-             processed_audio = treble_boost(processed_audio, treble_boost_db, treble_boost_frequency, sample_rate)
-
-         if fade_in_out:
-             processed_audio = fade_in_effect(processed_audio, sample_rate, fade_in_duration)
-             processed_audio = fade_out_effect(processed_audio, sample_rate, fade_out_duration)
-
-         if resample_sr != sample_rate and resample_sr > 0 and resample:
-             target_sr = min([8000, 11025, 12000, 16000, 22050, 24000, 32000, 44100, 48000, 96000], key=lambda x: abs(x - resample_sr))
-             processed_audio = librosa.resample(processed_audio, orig_sr=sample_rate, target_sr=target_sr, res_type="soxr_vhq")
-             sample_rate = target_sr
-
-         sf.write(output_path.replace("wav", export_format), processed_audio, sample_rate, format=export_format)
-
-         if audio_combination:
-             from pydub import AudioSegment
-             pydub_convert(AudioSegment.from_file(audio_combination_input)).overlay(pydub_convert(AudioSegment.from_file(output_path.replace("wav", export_format)))).export(output_path.replace("wav", export_format), format=export_format)
-     except Exception as e:
-         raise RuntimeError(translations["apply_error"].format(e=e))
-     return output_path
-
- if __name__ == "__main__":
-     args = parse_arguments()
- process_audio(input_path=args.input_path, output_path=args.output_path, resample=args.resample, resample_sr=args.resample_sr, chorus_depth=args.chorus_depth, chorus_rate=args.chorus_rate, chorus_mix=args.chorus_mix, chorus_delay=args.chorus_delay, chorus_feedback=args.chorus_feedback, distortion_drive=args.drive_db, reverb_room_size=args.reverb_room_size, reverb_damping=args.reverb_damping, reverb_wet_level=args.reverb_wet_level, reverb_dry_level=args.reverb_dry_level, reverb_width=args.reverb_width, reverb_freeze_mode=args.reverb_freeze_mode, pitch_shift=args.pitch_shift, delay_seconds=args.delay_seconds, delay_feedback=args.delay_feedback, delay_mix=args.delay_mix, compressor_threshold=args.compressor_threshold, compressor_ratio=args.compressor_ratio, compressor_attack_ms=args.compressor_attack_ms, compressor_release_ms=args.compressor_release_ms, limiter_threshold=args.limiter_threshold, limiter_release=args.limiter_release, gain_db=args.gain_db, bitcrush_bit_depth=args.bitcrush_bit_depth, clipping_threshold=args.clipping_threshold, phaser_rate_hz=args.phaser_rate_hz, phaser_depth=args.phaser_depth, phaser_centre_frequency_hz=args.phaser_centre_frequency_hz, phaser_feedback=args.phaser_feedback, phaser_mix=args.phaser_mix, bass_boost_db=args.bass_boost_db, bass_boost_frequency=args.bass_boost_frequency, treble_boost_db=args.treble_boost_db, treble_boost_frequency=args.treble_boost_frequency, fade_in_duration=args.fade_in_duration, fade_out_duration=args.fade_out_duration, export_format=args.export_format, chorus=args.chorus, distortion=args.distortion, reverb=args.reverb, pitchshift=args.pitchshift, delay=args.delay, compressor=args.compressor, limiter=args.limiter, gain=args.gain, bitcrush=args.bitcrush, clipping=args.clipping, phaser=args.phaser, treble_bass_boost=args.treble_bass_boost, fade_in_out=args.fade_in_out, audio_combination=args.audio_combination, audio_combination_input=args.audio_combination_input)
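The core of the deleted script is a Pedalboard effect chain applied to a soundfile buffer. A minimal sketch of that flow (input/output file names are placeholders; the effect parameters are illustrative, not the script's defaults):

import soundfile as sf
from pedalboard import Pedalboard, HighpassFilter, Reverb, Gain

# Read, run the chain at the file's native sample rate, write back.
audio, sample_rate = sf.read("input.wav")
board = Pedalboard([HighpassFilter(), Reverb(room_size=0.5), Gain(gain_db=3)])
processed = board(audio, sample_rate)
sf.write("output.wav", processed, sample_rate)
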
main/inference/convert.py DELETED
@@ -1,650 +0,0 @@
- import re
- import os
- import sys
- import time
- import faiss
- import torch
- import shutil
- import librosa
- import logging
- import argparse
- import warnings
- import parselmouth
- import onnxruntime
- import logging.handlers
-
- import numpy as np
- import soundfile as sf
- import torch.nn.functional as F
-
- from tqdm import tqdm
- from scipy import signal
- from distutils.util import strtobool
- from fairseq import checkpoint_utils
-
- warnings.filterwarnings("ignore")
- sys.path.append(os.getcwd())
-
- from main.configs.config import Config
- from main.library.predictors.FCPE import FCPE
- from main.library.predictors.RMVPE import RMVPE
- from main.library.predictors.WORLD import PYWORLD
- from main.library.algorithm.synthesizers import Synthesizer
- from main.library.predictors.CREPE import predict, mean, median
- from main.library.utils import check_predictors, check_embedders, load_audio, process_audio, merge_audio
-
- bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=16000)
- config = Config()
- translations = config.translations
- logger = logging.getLogger(__name__)
- logger.propagate = False
-
- for l in ["torch", "faiss", "httpx", "fairseq", "httpcore", "faiss.loader", "numba.core", "urllib3"]:
-     logging.getLogger(l).setLevel(logging.ERROR)
-
- if logger.hasHandlers(): logger.handlers.clear()
- else:
-     console_handler = logging.StreamHandler()
-     console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
-     console_handler.setFormatter(console_formatter)
-     console_handler.setLevel(logging.INFO)
-     file_handler = logging.handlers.RotatingFileHandler(os.path.join("assets", "logs", "convert.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
-     file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
-     file_handler.setFormatter(file_formatter)
-     file_handler.setLevel(logging.DEBUG)
-     logger.addHandler(console_handler)
-     logger.addHandler(file_handler)
-     logger.setLevel(logging.DEBUG)
-
- def parse_arguments():
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--pitch", type=int, default=0)
-     parser.add_argument("--filter_radius", type=int, default=3)
-     parser.add_argument("--index_rate", type=float, default=0.5)
-     parser.add_argument("--volume_envelope", type=float, default=1)
-     parser.add_argument("--protect", type=float, default=0.33)
-     parser.add_argument("--hop_length", type=int, default=64)
-     parser.add_argument("--f0_method", type=str, default="rmvpe")
-     parser.add_argument("--embedder_model", type=str, default="contentvec_base")
-     parser.add_argument("--input_path", type=str, required=True)
-     parser.add_argument("--output_path", type=str, default="./audios/output.wav")
-     parser.add_argument("--export_format", type=str, default="wav")
-     parser.add_argument("--pth_path", type=str, required=True)
-     parser.add_argument("--index_path", type=str)
-     parser.add_argument("--f0_autotune", type=lambda x: bool(strtobool(x)), default=False)
-     parser.add_argument("--f0_autotune_strength", type=float, default=1)
-     parser.add_argument("--clean_audio", type=lambda x: bool(strtobool(x)), default=False)
-     parser.add_argument("--clean_strength", type=float, default=0.7)
-     parser.add_argument("--resample_sr", type=int, default=0)
-     parser.add_argument("--split_audio", type=lambda x: bool(strtobool(x)), default=False)
-     parser.add_argument("--checkpointing", type=lambda x: bool(strtobool(x)), default=False)
-
-     return parser.parse_args()
-
84
- def main():
85
- args = parse_arguments()
86
- pitch, filter_radius, index_rate, volume_envelope, protect, hop_length, f0_method, input_path, output_path, pth_path, index_path, f0_autotune, f0_autotune_strength, clean_audio, clean_strength, export_format, embedder_model, resample_sr, split_audio, checkpointing = args.pitch, args.filter_radius, args.index_rate, args.volume_envelope,args.protect, args.hop_length, args.f0_method, args.input_path, args.output_path, args.pth_path, args.index_path, args.f0_autotune, args.f0_autotune_strength, args.clean_audio, args.clean_strength, args.export_format, args.embedder_model, args.resample_sr, args.split_audio, args.checkpointing
87
-
88
- log_data = {translations['pitch']: pitch, translations['filter_radius']: filter_radius, translations['index_strength']: index_rate, translations['volume_envelope']: volume_envelope, translations['protect']: protect, "Hop length": hop_length, translations['f0_method']: f0_method, translations['audio_path']: input_path, translations['output_path']: output_path.replace('wav', export_format), translations['model_path']: pth_path, translations['indexpath']: index_path, translations['autotune']: f0_autotune, translations['clear_audio']: clean_audio, translations['export_format']: export_format, translations['hubert_model']: embedder_model, translations['split_audio']: split_audio, translations['memory_efficient_training']: checkpointing}
89
-
90
- if clean_audio: log_data[translations['clean_strength']] = clean_strength
91
- if resample_sr != 0: log_data[translations['sample_rate']] = resample_sr
92
- if f0_autotune: log_data[translations['autotune_rate_info']] = f0_autotune_strength
93
-
94
- for key, value in log_data.items():
95
- logger.debug(f"{key}: {value}")
96
-
97
- check_predictors(f0_method)
98
- check_embedders(embedder_model)
99
-
100
- run_convert_script(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, input_path=input_path, output_path=output_path, pth_path=pth_path, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, embedder_model=embedder_model, resample_sr=resample_sr, split_audio=split_audio, checkpointing=checkpointing)
101
-
102
- def run_batch_convert(params):
103
- path, audio_temp, export_format, cut_files, pitch, filter_radius, index_rate, volume_envelope, protect, hop_length, f0_method, pth_path, index_path, f0_autotune, f0_autotune_strength, clean_audio, clean_strength, embedder_model, resample_sr, checkpointing = params["path"], params["audio_temp"], params["export_format"], params["cut_files"], params["pitch"], params["filter_radius"], params["index_rate"], params["volume_envelope"], params["protect"], params["hop_length"], params["f0_method"], params["pth_path"], params["index_path"], params["f0_autotune"], params["f0_autotune_strength"], params["clean_audio"], params["clean_strength"], params["embedder_model"], params["resample_sr"], params["checkpointing"]
104
-
105
- segment_output_path = os.path.join(audio_temp, f"output_{cut_files.index(path)}.{export_format}")
106
- if os.path.exists(segment_output_path): os.remove(segment_output_path)
107
-
108
- VoiceConverter().convert_audio(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, audio_input_path=path, audio_output_path=segment_output_path, model_path=pth_path, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, embedder_model=embedder_model, resample_sr=resample_sr, checkpointing=checkpointing)
109
- os.remove(path)
110
-
111
- if os.path.exists(segment_output_path): return segment_output_path
112
- else:
113
- logger.warning(f"{translations['not_found_convert_file']}: {segment_output_path}")
114
- sys.exit(1)
115
-
116
- def run_convert_script(pitch, filter_radius, index_rate, volume_envelope, protect, hop_length, f0_method, input_path, output_path, pth_path, index_path, f0_autotune, f0_autotune_strength, clean_audio, clean_strength, export_format, embedder_model, resample_sr, split_audio, checkpointing):
117
- cvt = VoiceConverter()
118
- start_time = time.time()
119
-
120
- pid_path = os.path.join("assets", "convert_pid.txt")
121
- with open(pid_path, "w") as pid_file:
122
- pid_file.write(str(os.getpid()))
123
-
124
- if not pth_path or not os.path.exists(pth_path) or os.path.isdir(pth_path) or not pth_path.endswith(".pth"):
125
- logger.warning(translations["provide_file"].format(filename=translations["model"]))
126
- sys.exit(1)
127
-
128
- output_dir = os.path.dirname(output_path) or output_path
129
- if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True)
130
-
131
- processed_segments = []
132
- audio_temp = os.path.join("audios_temp")
133
- if not os.path.exists(audio_temp) and split_audio: os.makedirs(audio_temp, exist_ok=True)
134
-
135
- if os.path.isdir(input_path):
136
- try:
137
- logger.info(translations["convert_batch"])
138
- audio_files = [f for f in os.listdir(input_path) if f.lower().endswith(("wav", "mp3", "flac", "ogg", "opus", "m4a", "mp4", "aac", "alac", "wma", "aiff", "webm", "ac3"))]
139
-
140
- if not audio_files:
141
- logger.warning(translations["not_found_audio"])
142
- sys.exit(1)
143
-
144
- logger.info(translations["found_audio"].format(audio_files=len(audio_files)))
145
-
146
- for audio in audio_files:
147
- audio_path = os.path.join(input_path, audio)
148
- output_audio = os.path.join(input_path, os.path.splitext(audio)[0] + f"_output.{export_format}")
149
-
150
- if split_audio:
151
- try:
152
- cut_files, time_stamps = process_audio(logger, audio_path, audio_temp)
153
- params_list = [{"path": path, "audio_temp": audio_temp, "export_format": export_format, "cut_files": cut_files, "pitch": pitch, "filter_radius": filter_radius, "index_rate": index_rate, "volume_envelope": volume_envelope, "protect": protect, "hop_length": hop_length, "f0_method": f0_method, "pth_path": pth_path, "index_path": index_path, "f0_autotune": f0_autotune, "f0_autotune_strength": f0_autotune_strength, "clean_audio": clean_audio, "clean_strength": clean_strength, "embedder_model": embedder_model, "resample_sr": resample_sr, "checkpointing": checkpointing} for path in cut_files]
154
-
155
- with tqdm(total=len(params_list), desc=translations["convert_audio"], ncols=100, unit="a") as pbar:
156
- for params in params_list:
157
- results = run_batch_convert(params)
158
- processed_segments.append(results)
159
- pbar.update(1)
160
- logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))
161
-
162
- merge_audio(processed_segments, time_stamps, audio_path, output_audio, export_format)
163
- except Exception as e:
164
- logger.error(translations["error_convert_batch"].format(e=e))
165
- finally:
166
- if os.path.exists(audio_temp): shutil.rmtree(audio_temp, ignore_errors=True)
167
- else:
168
- try:
169
- logger.info(f"{translations['convert_audio']} '{audio_path}'...")
170
- if os.path.exists(output_audio): os.remove(output_audio)
171
-
172
- with tqdm(total=1, desc=translations["convert_audio"], ncols=100, unit="a") as pbar:
173
- cvt.convert_audio(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, audio_input_path=audio_path, audio_output_path=output_audio, model_path=pth_path, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, embedder_model=embedder_model, resample_sr=resample_sr, checkpointing=checkpointing)
174
- pbar.update(1)
175
- logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))
176
- except Exception as e:
177
- logger.error(translations["error_convert"].format(e=e))
178
-
179
- elapsed_time = time.time() - start_time
180
- logger.info(translations["convert_batch_success"].format(elapsed_time=f"{elapsed_time:.2f}", output_path=output_path.replace('wav', export_format)))
181
- except Exception as e:
182
- logger.error(translations["error_convert_batch_2"].format(e=e))
183
- else:
184
- logger.info(f"{translations['convert_audio']} '{input_path}'...")
185
-
186
- if not os.path.exists(input_path):
187
- logger.warning(translations["not_found_audio"])
188
- sys.exit(1)
189
-
190
- if os.path.isdir(output_path): output_path = os.path.join(output_path, f"output.{export_format}")
191
- if os.path.exists(output_path): os.remove(output_path)
192
-
193
- if split_audio:
194
- try:
195
- cut_files, time_stamps = process_audio(logger, input_path, audio_temp)
196
- params_list = [{"path": path, "audio_temp": audio_temp, "export_format": export_format, "cut_files": cut_files, "pitch": pitch, "filter_radius": filter_radius, "index_rate": index_rate, "volume_envelope": volume_envelope, "protect": protect, "hop_length": hop_length, "f0_method": f0_method, "pth_path": pth_path, "index_path": index_path, "f0_autotune": f0_autotune, "f0_autotune_strength": f0_autotune_strength, "clean_audio": clean_audio, "clean_strength": clean_strength, "embedder_model": embedder_model, "resample_sr": resample_sr, "checkpointing": checkpointing} for path in cut_files]
197
-
198
- with tqdm(total=len(params_list), desc=translations["convert_audio"], ncols=100, unit="a") as pbar:
199
- for params in params_list:
200
- results = run_batch_convert(params)
201
- processed_segments.append(results)
202
- pbar.update(1)
203
- logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))
204
-
205
- merge_audio(processed_segments, time_stamps, input_path, output_path.replace("wav", export_format), export_format)
206
- except Exception as e:
207
- logger.error(translations["error_convert_batch"].format(e=e))
208
- finally:
209
- if os.path.exists(audio_temp): shutil.rmtree(audio_temp, ignore_errors=True)
210
- else:
211
- try:
212
- with tqdm(total=1, desc=translations["convert_audio"], ncols=100, unit="a") as pbar:
213
- cvt.convert_audio(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, audio_input_path=input_path, audio_output_path=output_path, model_path=pth_path, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, embedder_model=embedder_model, resample_sr=resample_sr, checkpointing=checkpointing)
214
- pbar.update(1)
215
-
216
- logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))
217
- except Exception as e:
218
- logger.error(translations["error_convert"].format(e=e))
219
-
220
- if os.path.exists(pid_path): os.remove(pid_path)
221
- elapsed_time = time.time() - start_time
222
- logger.info(translations["convert_audio_success"].format(input_path=input_path, elapsed_time=f"{elapsed_time:.2f}", output_path=output_path.replace('wav', export_format)))
223
-
224
- def change_rms(source_audio, source_rate, target_audio, target_rate, rate):
225
- rms2 = F.interpolate(torch.from_numpy(librosa.feature.rms(y=target_audio, frame_length=target_rate // 2 * 2, hop_length=target_rate // 2)).float().unsqueeze(0), size=target_audio.shape[0], mode="linear").squeeze()
226
- return (target_audio * (torch.pow(F.interpolate(torch.from_numpy(librosa.feature.rms(y=source_audio, frame_length=source_rate // 2 * 2, hop_length=source_rate // 2)).float().unsqueeze(0), size=target_audio.shape[0], mode="linear").squeeze(), 1 - rate) * torch.pow(torch.maximum(rms2, torch.zeros_like(rms2) + 1e-6), rate - 1)).numpy())
227
-
228
- class Autotune:
229
- def __init__(self, ref_freqs):
230
- self.ref_freqs = ref_freqs
231
- self.note_dict = self.ref_freqs
232
-
233
- def autotune_f0(self, f0, f0_autotune_strength):
234
- autotuned_f0 = np.zeros_like(f0)
235
-
236
- for i, freq in enumerate(f0):
237
- autotuned_f0[i] = freq + (min(self.note_dict, key=lambda x: abs(x - freq)) - freq) * f0_autotune_strength
238
-
239
- return autotuned_f0
240
-
241
- class VC:
242
- def __init__(self, tgt_sr, config):
243
- self.x_pad = config.x_pad
244
- self.x_query = config.x_query
245
- self.x_center = config.x_center
246
- self.x_max = config.x_max
247
- self.sample_rate = 16000
248
- self.window = 160
249
- self.t_pad = self.sample_rate * self.x_pad
250
- self.t_pad_tgt = tgt_sr * self.x_pad
251
- self.t_pad2 = self.t_pad * 2
252
- self.t_query = self.sample_rate * self.x_query
253
- self.t_center = self.sample_rate * self.x_center
254
- self.t_max = self.sample_rate * self.x_max
255
- self.time_step = self.window / self.sample_rate * 1000
256
- self.f0_min = 50
257
- self.f0_max = 1100
258
- self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700)
259
- self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700)
260
- self.device = config.device
261
- self.ref_freqs = [49.00, 51.91, 55.00, 58.27, 61.74, 65.41, 69.30, 73.42, 77.78, 82.41, 87.31, 92.50, 98.00, 103.83, 110.00, 116.54, 123.47, 130.81, 138.59, 146.83, 155.56, 164.81, 174.61, 185.00, 196.00, 207.65, 220.00, 233.08, 246.94, 261.63, 277.18, 293.66, 311.13, 329.63, 349.23, 369.99, 392.00, 415.30, 440.00, 466.16, 493.88, 523.25, 554.37, 587.33, 622.25, 659.25, 698.46, 739.99, 783.99, 830.61, 880.00, 932.33, 987.77, 1046.50]
262
- self.autotune = Autotune(self.ref_freqs)
263
- self.note_dict = self.autotune.note_dict
264
-
265
- def get_providers(self):
266
- ort_providers = onnxruntime.get_available_providers()
267
-
268
- if "CUDAExecutionProvider" in ort_providers: providers = ["CUDAExecutionProvider"]
269
- elif "CoreMLExecutionProvider" in ort_providers: providers = ["CoreMLExecutionProvider"]
270
- else: providers = ["CPUExecutionProvider"]
271
-
272
- return providers
273
-
274
- def get_f0_pm(self, x, p_len):
275
- f0 = (parselmouth.Sound(x, self.sample_rate).to_pitch_ac(time_step=self.window / self.sample_rate * 1000 / 1000, voicing_threshold=0.6, pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array["frequency"])
276
- pad_size = (p_len - len(f0) + 1) // 2
277
-
278
- if pad_size > 0 or p_len - len(f0) - pad_size > 0: f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
279
- return f0
280
-
281
- def get_f0_mangio_crepe(self, x, p_len, hop_length, model="full", onnx=False):
282
- providers = self.get_providers() if onnx else None
283
-
284
- x = x.astype(np.float32)
285
- x /= np.quantile(np.abs(x), 0.999)
286
-
287
- audio = torch.unsqueeze(torch.from_numpy(x).to(self.device, copy=True), dim=0)
288
- if audio.ndim == 2 and audio.shape[0] > 1: audio = torch.mean(audio, dim=0, keepdim=True).detach()
289
-
290
- p_len = p_len or x.shape[0] // hop_length
291
- source = np.array(predict(audio.detach(), self.sample_rate, hop_length, self.f0_min, self.f0_max, model, batch_size=hop_length * 2, device=self.device, pad=True, providers=providers, onnx=onnx).squeeze(0).cpu().float().numpy())
292
- source[source < 0.001] = np.nan
293
- return np.nan_to_num(np.interp(np.arange(0, len(source) * p_len, len(source)) / p_len, np.arange(0, len(source)), source))
294
-
295
- def get_f0_crepe(self, x, model="full", onnx=False):
296
- providers = self.get_providers() if onnx else None
297
-
298
- f0, pd = predict(torch.tensor(np.copy(x))[None].float(), self.sample_rate, self.window, self.f0_min, self.f0_max, model, batch_size=512, device=self.device, return_periodicity=True, providers=providers, onnx=onnx)
299
- f0, pd = mean(f0, 3), median(pd, 3)
300
- f0[pd < 0.1] = 0
301
-
302
- return f0[0].cpu().numpy()
303
-
304
- def get_f0_fcpe(self, x, p_len, hop_length, onnx=False, legacy=False):
305
- providers = self.get_providers() if onnx else None
306
-
307
- model_fcpe = FCPE(os.path.join("assets", "models", "predictors", "fcpe" + (".onnx" if onnx else ".pt")), hop_length=int(hop_length), f0_min=int(self.f0_min), f0_max=int(self.f0_max), dtype=torch.float32, device=self.device, sample_rate=self.sample_rate, threshold=0.03, providers=providers, onnx=onnx) if legacy else FCPE(os.path.join("assets", "models", "predictors", "fcpe" + (".onnx" if onnx else ".pt")), hop_length=self.window, f0_min=0, f0_max=8000, dtype=torch.float32, device=self.device, sample_rate=self.sample_rate, threshold=0.006, providers=providers, onnx=onnx)
308
- f0 = model_fcpe.compute_f0(x, p_len=p_len)
309
-
310
- del model_fcpe
311
- return f0
312
-
313
- def get_f0_rmvpe(self, x, legacy=False, onnx=False):
314
- providers = self.get_providers() if onnx else None
315
-
316
- rmvpe_model = RMVPE(os.path.join("assets", "models", "predictors", "rmvpe" + (".onnx" if onnx else ".pt")), device=self.device, onnx=onnx, providers=providers)
317
- f0 = rmvpe_model.infer_from_audio_with_pitch(x, thred=0.03, f0_min=self.f0_min, f0_max=self.f0_max) if legacy else rmvpe_model.infer_from_audio(x, thred=0.03)
318
-
319
- del rmvpe_model
320
- return f0
321
-
322
- def get_f0_pyworld(self, x, filter_radius, model="harvest"):
323
- pw = PYWORLD()
324
-
325
- if model == "harvest": f0, t = pw.harvest(x.astype(np.double), fs=self.sample_rate, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=10)
326
- elif model == "dio": f0, t = pw.dio(x.astype(np.double), fs=self.sample_rate, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=10)
327
- else: raise ValueError(translations["method_not_valid"])
328
-
329
- f0 = pw.stonemask(x.astype(np.double), self.sample_rate, t, f0)
330
-
331
- if filter_radius > 2 or model == "dio": f0 = signal.medfilt(f0, 3)
332
- return f0
333
-
334
- def get_f0_yin(self, x, hop_length, p_len):
335
- source = np.array(librosa.yin(x.astype(np.double), sr=self.sample_rate, fmin=self.f0_min, fmax=self.f0_max, hop_length=hop_length))
336
- source[source < 0.001] = np.nan
337
-
338
- return np.nan_to_num(np.interp(np.arange(0, len(source) * p_len, len(source)) / p_len, np.arange(0, len(source)), source))
339
-
340
- def get_f0_pyin(self, x, hop_length, p_len):
341
- f0, _, _ = librosa.pyin(x.astype(np.double), fmin=self.f0_min, fmax=self.f0_max, sr=self.sample_rate, hop_length=hop_length)
342
- source = np.array(f0)
343
- source[source < 0.001] = np.nan
344
-
345
- return np.nan_to_num(np.interp(np.arange(0, len(source) * p_len, len(source)) / p_len, np.arange(0, len(source)), source))
346
-
347
- def get_f0_hybrid(self, methods_str, x, p_len, hop_length, filter_radius):
348
- methods_str = re.search("hybrid\[(.+)\]", methods_str)
349
- if methods_str: methods = [method.strip() for method in methods_str.group(1).split("+")]
350
-
351
- f0_computation_stack, resampled_stack = [], []
352
- logger.debug(translations["hybrid_methods"].format(methods=methods))
353
-
354
- x = x.astype(np.float32)
355
- x /= np.quantile(np.abs(x), 0.999)
356
-
357
- for method in methods:
358
- f0 = None
359
-
360
- if method == "pm": f0 = self.get_f0_pm(x, p_len)
361
- elif method == "dio": f0 = self.get_f0_pyworld(x, filter_radius, "dio")
362
- elif method == "mangio-crepe-tiny": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "tiny")
363
- elif method == "mangio-crepe-tiny-onnx": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "tiny", onnx=True)
364
- elif method == "mangio-crepe-small": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "small")
365
- elif method == "mangio-crepe-small-onnx": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "small", onnx=True)
366
- elif method == "mangio-crepe-medium": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "medium")
367
- elif method == "mangio-crepe-medium-onnx": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "medium", onnx=True)
368
- elif method == "mangio-crepe-large": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "large")
369
- elif method == "mangio-crepe-large-onnx": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "large", onnx=True)
370
- elif method == "mangio-crepe-full": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "full")
371
- elif method == "mangio-crepe-full-onnx": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "full", onnx=True)
372
- elif method == "crepe-tiny": f0 = self.get_f0_crepe(x, "tiny")
373
- elif method == "crepe-tiny-onnx": f0 = self.get_f0_crepe(x, "tiny", onnx=True)
374
- elif method == "crepe-small": f0 = self.get_f0_crepe(x, "small")
375
- elif method == "crepe-small-onnx": f0 = self.get_f0_crepe(x, "small", onnx=True)
376
- elif method == "crepe-medium": f0 = self.get_f0_crepe(x, "medium")
377
- elif method == "crepe-medium-onnx": f0 = self.get_f0_crepe(x, "medium", onnx=True)
378
- elif method == "crepe-large": f0 = self.get_f0_crepe(x, "large")
379
- elif method == "crepe-large-onnx": f0 = self.get_f0_crepe(x, "large", onnx=True)
380
- elif method == "crepe-full": f0 = self.get_f0_crepe(x, "full")
381
- elif method == "crepe-full-onnx": f0 = self.get_f0_crepe(x, "full", onnx=True)
382
- elif method == "fcpe": f0 = self.get_f0_fcpe(x, p_len, int(hop_length))
383
- elif method == "fcpe-onnx": f0 = self.get_f0_fcpe(x, p_len, int(hop_length), onnx=True)
384
- elif method == "fcpe-legacy": f0 = self.get_f0_fcpe(x, p_len, int(hop_length), legacy=True)
385
- elif method == "fcpe-legacy-onnx": f0 = self.get_f0_fcpe(x, p_len, int(hop_length), onnx=True, legacy=True)
386
- elif method == "rmvpe": f0 = self.get_f0_rmvpe(x)
387
- elif method == "rmvpe-onnx": f0 = self.get_f0_rmvpe(x, onnx=True)
388
- elif method == "rmvpe-legacy": f0 = self.get_f0_rmvpe(x, legacy=True)
389
- elif method == "rmvpe-legacy-onnx": f0 = self.get_f0_rmvpe(x, legacy=True, onnx=True)
390
- elif method == "harvest": f0 = self.get_f0_pyworld(x, filter_radius, "harvest")
391
- elif method == "yin": f0 = self.get_f0_yin(x, int(hop_length), p_len)
392
- elif method == "pyin": f0 = self.get_f0_pyin(x, int(hop_length), p_len)
393
- else: raise ValueError(translations["method_not_valid"])
394
-
395
- f0_computation_stack.append(f0)
396
-
397
- for f0 in f0_computation_stack:
398
- resampled_stack.append(np.interp(np.linspace(0, len(f0), p_len), np.arange(len(f0)), f0))
399
-
400
- return resampled_stack[0] if len(resampled_stack) == 1 else np.nanmedian(np.vstack(resampled_stack), axis=0)
401
-
402
- def get_f0(self, x, p_len, pitch, f0_method, filter_radius, hop_length, f0_autotune, f0_autotune_strength):
403
- if f0_method == "pm": f0 = self.get_f0_pm(x, p_len)
404
- elif f0_method == "dio": f0 = self.get_f0_pyworld(x, filter_radius, "dio")
405
- elif f0_method == "mangio-crepe-tiny": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "tiny")
406
- elif f0_method == "mangio-crepe-tiny-onnx": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "tiny", onnx=True)
407
- elif f0_method == "mangio-crepe-small": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "small")
408
- elif f0_method == "mangio-crepe-small-onnx": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "small", onnx=True)
409
- elif f0_method == "mangio-crepe-medium": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "medium")
410
- elif f0_method == "mangio-crepe-medium-onnx": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "medium", onnx=True)
411
- elif f0_method == "mangio-crepe-large": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "large")
412
- elif f0_method == "mangio-crepe-large-onnx": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "large", onnx=True)
413
- elif f0_method == "mangio-crepe-full": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "full")
414
- elif f0_method == "mangio-crepe-full-onnx": f0 = self.get_f0_mangio_crepe(x, p_len, int(hop_length), "full", onnx=True)
415
- elif f0_method == "crepe-tiny": f0 = self.get_f0_crepe(x, "tiny")
416
- elif f0_method == "crepe-tiny-onnx": f0 = self.get_f0_crepe(x, "tiny", onnx=True)
417
- elif f0_method == "crepe-small": f0 = self.get_f0_crepe(x, "small")
418
- elif f0_method == "crepe-small-onnx": f0 = self.get_f0_crepe(x, "small", onnx=True)
419
- elif f0_method == "crepe-medium": f0 = self.get_f0_crepe(x, "medium")
420
- elif f0_method == "crepe-medium-onnx": f0 = self.get_f0_crepe(x, "medium", onnx=True)
421
- elif f0_method == "crepe-large": f0 = self.get_f0_crepe(x, "large")
422
- elif f0_method == "crepe-large-onnx": f0 = self.get_f0_crepe(x, "large", onnx=True)
423
- elif f0_method == "crepe-full": f0 = self.get_f0_crepe(x, "full")
424
- elif f0_method == "crepe-full-onnx": f0 = self.get_f0_crepe(x, "full", onnx=True)
425
- elif f0_method == "fcpe": f0 = self.get_f0_fcpe(x, p_len, int(hop_length))
426
- elif f0_method == "fcpe-onnx": f0 = self.get_f0_fcpe(x, p_len, int(hop_length), onnx=True)
427
- elif f0_method == "fcpe-legacy": f0 = self.get_f0_fcpe(x, p_len, int(hop_length), legacy=True)
428
- elif f0_method == "fcpe-legacy-onnx": f0 = self.get_f0_fcpe(x, p_len, int(hop_length), onnx=True, legacy=True)
429
- elif f0_method == "rmvpe": f0 = self.get_f0_rmvpe(x)
430
- elif f0_method == "rmvpe-onnx": f0 = self.get_f0_rmvpe(x, onnx=True)
431
- elif f0_method == "rmvpe-legacy": f0 = self.get_f0_rmvpe(x, legacy=True)
432
- elif f0_method == "rmvpe-legacy-onnx": f0 = self.get_f0_rmvpe(x, legacy=True, onnx=True)
433
- elif f0_method == "harvest": f0 = self.get_f0_pyworld(x, filter_radius, "harvest")
434
- elif f0_method == "yin": f0 = self.get_f0_yin(x, int(hop_length), p_len)
435
- elif f0_method == "pyin": f0 = self.get_f0_pyin(x, int(hop_length), p_len)
436
- elif "hybrid" in f0_method: f0 = self.get_f0_hybrid(f0_method, x, p_len, hop_length, filter_radius)
437
- else: raise ValueError(translations["method_not_valid"])
438
-
439
- if f0_autotune: f0 = self.autotune.autotune_f0(f0, f0_autotune_strength)
440
-
441
- f0 *= pow(2, pitch / 12)
442
- f0_mel = 1127 * np.log(1 + f0 / 700)
443
- f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - self.f0_mel_min) * 254 / (self.f0_mel_max - self.f0_mel_min) + 1
444
- f0_mel[f0_mel <= 1] = 1
445
- f0_mel[f0_mel > 255] = 255
446
-
447
- return np.rint(f0_mel).astype(np.int32), f0.copy()
448
-
449
- def voice_conversion(self, model, net_g, sid, audio0, pitch, pitchf, index, big_npy, index_rate, version, protect):
450
- pitch_guidance = pitch is not None and pitchf is not None
451
- feats = torch.from_numpy(audio0).float()
452
-
453
- if feats.dim() == 2: feats = feats.mean(-1)
454
- assert feats.dim() == 1, feats.dim()
455
-
456
- feats = feats.view(1, -1)
457
- padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False)
458
- inputs = {"source": feats.to(self.device), "padding_mask": padding_mask, "output_layer": 9 if version == "v1" else 12}
459
-
460
- with torch.no_grad():
461
- logits = model.extract_features(**inputs)
462
- feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
463
-
464
- if protect < 0.5 and pitch_guidance: feats0 = feats.clone()
465
-
466
- if (not isinstance(index, type(None)) and not isinstance(big_npy, type(None)) and index_rate != 0):
467
- npy = feats[0].cpu().numpy()
468
- score, ix = index.search(npy, k=8)
469
- weight = np.square(1 / score)
470
- weight /= weight.sum(axis=1, keepdims=True)
471
- npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
472
- feats = (torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate + (1 - index_rate) * feats)
473
-
474
- feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
475
- if protect < 0.5 and pitch_guidance: feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
476
-
477
- p_len = audio0.shape[0] // self.window
478
-
479
- if feats.shape[1] < p_len:
480
- p_len = feats.shape[1]
481
- if pitch_guidance:
482
- pitch = pitch[:, :p_len]
483
- pitchf = pitchf[:, :p_len]
484
-
485
- if protect < 0.5 and pitch_guidance:
486
- pitchff = pitchf.clone()
487
- pitchff[pitchf > 0] = 1
488
- pitchff[pitchf < 1] = protect
489
- pitchff = pitchff.unsqueeze(-1)
490
- feats = feats * pitchff + feats0 * (1 - pitchff)
491
- feats = feats.to(feats0.dtype)
492
-
493
- p_len = torch.tensor([p_len], device=self.device).long()
494
- audio1 = ((net_g.infer(feats, p_len, pitch if pitch_guidance else None, pitchf if pitch_guidance else None, sid)[0][0, 0]).data.cpu().float().numpy())
495
-
496
- del feats, p_len, padding_mask
497
- if torch.cuda.is_available(): torch.cuda.empty_cache()
498
- return audio1
499
-
500
- def pipeline(self, model, net_g, sid, audio, pitch, f0_method, file_index, index_rate, pitch_guidance, filter_radius, tgt_sr, resample_sr, volume_envelope, version, protect, hop_length, f0_autotune, f0_autotune_strength):
501
- if file_index != "" and os.path.exists(file_index) and index_rate != 0:
502
- try:
503
- index = faiss.read_index(file_index)
504
- big_npy = index.reconstruct_n(0, index.ntotal)
505
- except Exception as e:
506
- logger.error(translations["read_faiss_index_error"].format(e=e))
507
- index = big_npy = None
508
- else: index = big_npy = None
509
-
510
- audio = signal.filtfilt(bh, ah, audio)
511
- audio_pad = np.pad(audio, (self.window // 2, self.window // 2), mode="reflect")
512
- opt_ts, audio_opt = [], []
513
-
514
- if audio_pad.shape[0] > self.t_max:
515
- audio_sum = np.zeros_like(audio)
516
-
517
- for i in range(self.window):
518
- audio_sum += audio_pad[i : i - self.window]
519
-
520
- for t in range(self.t_center, audio.shape[0], self.t_center):
521
- opt_ts.append(t - self.t_query + np.where(np.abs(audio_sum[t - self.t_query : t + self.t_query]) == np.abs(audio_sum[t - self.t_query : t + self.t_query]).min())[0][0])
522
-
523
- s = 0
524
- t = None
525
-
526
- audio_pad = np.pad(audio, (self.t_pad, self.t_pad), mode="reflect")
527
- sid = torch.tensor(sid, device=self.device).unsqueeze(0).long()
528
- p_len = audio_pad.shape[0] // self.window
529
-
530
- if pitch_guidance:
531
- pitch, pitchf = self.get_f0(audio_pad, p_len, pitch, f0_method, filter_radius, hop_length, f0_autotune, f0_autotune_strength)
532
- pitch, pitchf = pitch[:p_len], pitchf[:p_len]
533
-
534
- if self.device == "mps": pitchf = pitchf.astype(np.float32)
535
- pitch, pitchf = torch.tensor(pitch, device=self.device).unsqueeze(0).long(), torch.tensor(pitchf, device=self.device).unsqueeze(0).float()
536
-
537
- for t in opt_ts:
538
- t = t // self.window * self.window
539
- audio_opt.append(self.voice_conversion(model, net_g, sid, audio_pad[s : t + self.t_pad2 + self.window], pitch[:, s // self.window : (t + self.t_pad2) // self.window] if pitch_guidance else None, pitchf[:, s // self.window : (t + self.t_pad2) // self.window] if pitch_guidance else None, index, big_npy, index_rate, version, protect)[self.t_pad_tgt : -self.t_pad_tgt])
540
- s = t
541
-
542
- audio_opt.append(self.voice_conversion(model, net_g, sid, audio_pad[t:], (pitch[:, t // self.window :] if t is not None else pitch) if pitch_guidance else None, (pitchf[:, t // self.window :] if t is not None else pitchf) if pitch_guidance else None, index, big_npy, index_rate, version, protect)[self.t_pad_tgt : -self.t_pad_tgt])
543
- audio_opt = np.concatenate(audio_opt)
544
-
545
- if volume_envelope != 1: audio_opt = change_rms(audio, self.sample_rate, audio_opt, tgt_sr, volume_envelope)
546
- if resample_sr >= self.sample_rate and tgt_sr != resample_sr: audio_opt = librosa.resample(audio_opt, orig_sr=tgt_sr, target_sr=resample_sr, res_type="soxr_vhq")
547
-
548
- audio_max = np.abs(audio_opt).max() / 0.99
549
- if audio_max > 1: audio_opt /= audio_max
550
-
551
- if pitch_guidance: del pitch, pitchf
552
- del sid
553
-
554
- if torch.cuda.is_available(): torch.cuda.empty_cache()
555
- return audio_opt
556
-
557
- class VoiceConverter:
558
- def __init__(self):
559
- self.config = config
560
- self.hubert_model = None
561
- self.tgt_sr = None
562
- self.net_g = None
563
- self.vc = None
564
- self.cpt = None
565
- self.version = None
566
- self.n_spk = None
567
- self.use_f0 = None
568
- self.loaded_model = None
569
- self.vocoder = "Default"
570
- self.checkpointing = False
571
-
572
- def load_embedders(self, embedder_model):
573
- try:
574
- models, _, _ = checkpoint_utils.load_model_ensemble_and_task([os.path.join("assets", "models", "embedders", embedder_model + '.pt')], suffix="")
575
- except Exception as e:
576
- logger.error(translations["read_model_error"].format(e=e))
577
- self.hubert_model = models[0].to(self.config.device).float().eval()
578
-
579
- def convert_audio(self, audio_input_path, audio_output_path, model_path, index_path, embedder_model, pitch, f0_method, index_rate, volume_envelope, protect, hop_length, f0_autotune, f0_autotune_strength, filter_radius, clean_audio, clean_strength, export_format, resample_sr = 0, sid = 0, checkpointing = False):
580
- try:
581
- self.get_vc(model_path, sid)
582
- audio = load_audio(audio_input_path)
583
- self.checkpointing = checkpointing
584
-
585
- audio_max = np.abs(audio).max() / 0.95
586
- if audio_max > 1: audio /= audio_max
587
-
588
- if not self.hubert_model:
589
- if not os.path.exists(os.path.join("assets", "models", "embedders", embedder_model + '.pt')): raise FileNotFoundError(f"{translations['not_found'].format(name=translations['model'])}: {embedder_model}")
590
- self.load_embedders(embedder_model)
591
-
592
- if resample_sr >= 16000 and self.tgt_sr != resample_sr: self.tgt_sr = resample_sr
593
- target_sr = min([8000, 11025, 12000, 16000, 22050, 24000, 32000, 44100, 48000, 96000], key=lambda x: abs(x - self.tgt_sr))
594
-
595
- audio_output = self.vc.pipeline(model=self.hubert_model, net_g=self.net_g, sid=sid, audio=audio, pitch=pitch, f0_method=f0_method, file_index=(index_path.strip().strip('"').strip("\n").strip('"').strip().replace("trained", "added")), index_rate=index_rate, pitch_guidance=self.use_f0, filter_radius=filter_radius, tgt_sr=self.tgt_sr, resample_sr=target_sr, volume_envelope=volume_envelope, version=self.version, protect=protect, hop_length=hop_length, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength)
596
-
597
- if clean_audio:
598
- from main.tools.noisereduce import reduce_noise
599
- audio_output = reduce_noise(y=audio_output, sr=target_sr, prop_decrease=clean_strength)
600
-
601
- sf.write(audio_output_path, audio_output, target_sr, format=export_format)
602
- except Exception as e:
603
- logger.error(translations["error_convert"].format(e=e))
604
-
605
- import traceback
606
- logger.debug(traceback.format_exc())
607
-
608
- def get_vc(self, weight_root, sid):
609
- if sid == "" or sid == []:
610
- self.cleanup()
611
- if torch.cuda.is_available(): torch.cuda.empty_cache()
612
-
613
- if not self.loaded_model or self.loaded_model != weight_root:
614
- self.load_model(weight_root)
615
- if self.cpt is not None: self.setup()
616
- self.loaded_model = weight_root
617
-
618
- def cleanup(self):
619
- if self.hubert_model is not None:
620
- del self.net_g, self.n_spk, self.vc, self.hubert_model, self.tgt_sr
621
- self.hubert_model = self.net_g = self.n_spk = self.vc = self.tgt_sr = None
622
- if torch.cuda.is_available(): torch.cuda.empty_cache()
623
-
624
- del self.net_g, self.cpt
625
- if torch.cuda.is_available(): torch.cuda.empty_cache()
626
- self.cpt = None
627
-
628
- def load_model(self, weight_root):
629
- self.cpt = (torch.load(weight_root, map_location="cpu") if os.path.isfile(weight_root) else None)
630
-
631
- def setup(self):
632
- if self.cpt is not None:
633
- self.tgt_sr = self.cpt["config"][-1]
634
- self.cpt["config"][-3] = self.cpt["weight"]["emb_g.weight"].shape[0]
635
-
636
- self.use_f0 = self.cpt.get("f0", 1)
637
- self.version = self.cpt.get("version", "v1")
638
- self.vocoder = self.cpt.get("vocoder", "Default")
639
-
640
- self.text_enc_hidden_dim = 768 if self.version == "v2" else 256
641
- self.net_g = Synthesizer(*self.cpt["config"], use_f0=self.use_f0, text_enc_hidden_dim=self.text_enc_hidden_dim, vocoder=self.vocoder, checkpointing=self.checkpointing)
642
- del self.net_g.enc_q
643
-
644
- self.net_g.load_state_dict(self.cpt["weight"], strict=False)
645
- self.net_g.eval().to(self.config.device).float()
646
-
647
- self.vc = VC(self.tgt_sr, self.config)
648
- self.n_spk = self.cpt["config"][-3]
649
-
650
- if __name__ == "__main__": main()
main/inference/create_dataset.py DELETED
@@ -1,240 +0,0 @@
1
- import os
2
- import sys
3
- import time
4
- import yt_dlp
5
- import shutil
6
- import librosa
7
- import logging
8
- import argparse
9
- import warnings
10
- import logging.handlers
11
-
12
- from soundfile import read, write
13
- from distutils.util import strtobool
14
-
15
- sys.path.append(os.getcwd())
16
-
17
- from main.configs.config import Config
18
- from main.library.utils import process_audio, merge_audio
19
-
20
- translations = Config().translations
21
- dataset_temp = os.path.join("dataset_temp")
22
- logger = logging.getLogger(__name__)
23
-
24
- if logger.hasHandlers(): logger.handlers.clear()
25
- else:
26
- console_handler = logging.StreamHandler()
27
- console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
28
- console_handler.setFormatter(console_formatter)
29
- console_handler.setLevel(logging.INFO)
30
- file_handler = logging.handlers.RotatingFileHandler(os.path.join("assets", "logs", "create_dataset.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
31
- file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
32
- file_handler.setFormatter(file_formatter)
33
- file_handler.setLevel(logging.DEBUG)
34
- logger.addHandler(console_handler)
35
- logger.addHandler(file_handler)
36
- logger.setLevel(logging.DEBUG)
37
-
38
- def parse_arguments():
39
- parser = argparse.ArgumentParser()
40
- parser.add_argument("--input_audio", type=str, required=True)
41
- parser.add_argument("--output_dataset", type=str, default="./dataset")
42
- parser.add_argument("--sample_rate", type=int, default=44100)
43
- parser.add_argument("--clean_dataset", type=lambda x: bool(strtobool(x)), default=False)
44
- parser.add_argument("--clean_strength", type=float, default=0.7)
45
- parser.add_argument("--separator_reverb", type=lambda x: bool(strtobool(x)), default=False)
46
- parser.add_argument("--kim_vocal_version", type=int, default=2)
47
- parser.add_argument("--overlap", type=float, default=0.25)
48
- parser.add_argument("--segments_size", type=int, default=256)
49
- parser.add_argument("--mdx_hop_length", type=int, default=1024)
50
- parser.add_argument("--mdx_batch_size", type=int, default=1)
51
- parser.add_argument("--denoise_mdx", type=lambda x: bool(strtobool(x)), default=False)
52
- parser.add_argument("--skip", type=lambda x: bool(strtobool(x)), default=False)
53
- parser.add_argument("--skip_start_audios", type=str, default="0")
54
- parser.add_argument("--skip_end_audios", type=str, default="0")
55
-
56
- return parser.parse_args()
57
-
58
- def main():
59
- pid_path = os.path.join("assets", "create_dataset_pid.txt")
60
- with open(pid_path, "w") as pid_file:
61
- pid_file.write(str(os.getpid()))
62
-
63
- args = parse_arguments()
64
- input_audio, output_dataset, sample_rate, clean_dataset, clean_strength, separator_reverb, kim_vocal_version, overlap, segments_size, hop_length, batch_size, denoise_mdx, skip, skip_start_audios, skip_end_audios = args.input_audio, args.output_dataset, args.sample_rate, args.clean_dataset, args.clean_strength, args.separator_reverb, args.kim_vocal_version, args.overlap, args.segments_size, args.mdx_hop_length, args.mdx_batch_size, args.denoise_mdx, args.skip, args.skip_start_audios, args.skip_end_audios
65
- log_data = {translations['audio_path']: input_audio, translations['output_path']: output_dataset, translations['sr']: sample_rate, translations['clear_dataset']: clean_dataset, translations['dereveb_audio']: separator_reverb, translations['segments_size']: segments_size, translations['overlap']: overlap, "Hop length": hop_length, translations['batch_size']: batch_size, translations['denoise_mdx']: denoise_mdx, translations['skip']: skip}
66
-
67
- if clean_dataset: log_data[translations['clean_strength']] = clean_strength
68
- if skip:
69
- log_data[translations['skip_start']] = skip_start_audios
70
- log_data[translations['skip_end']] = skip_end_audios
71
-
72
- for key, value in log_data.items():
73
- logger.debug(f"{key}: {value}")
74
-
75
- if kim_vocal_version not in [1, 2]: raise ValueError(translations["version_not_valid"])
76
- start_time = time.time()
77
-
78
- try:
79
- paths = []
80
-
81
- if not os.path.exists(dataset_temp): os.makedirs(dataset_temp, exist_ok=True)
82
- urls = input_audio.replace(", ", ",").split(",")
83
-
84
- for url in urls:
85
- path = downloader(url, urls.index(url))
86
- paths.append(path)
87
-
88
- if skip:
89
- skip_start_audios = skip_start_audios.replace(", ", ",").split(",")
90
- skip_end_audios = skip_end_audios.replace(", ", ",").split(",")
91
-
92
- if len(skip_start_audios) < len(paths) or len(skip_end_audios) < len(paths):
93
- logger.warning(translations["skip<audio"])
94
- sys.exit(1)
95
- elif len(skip_start_audios) > len(paths) or len(skip_end_audios) > len(paths):
96
- logger.warning(translations["skip>audio"])
97
- sys.exit(1)
98
- else:
99
- for audio, skip_start_audio, skip_end_audio in zip(paths, skip_start_audios, skip_end_audios):
100
- skip_start(audio, skip_start_audio)
101
- skip_end(audio, skip_end_audio)
102
-
103
- separator_paths = []
104
-
105
- for audio in paths:
106
- vocals = separator_music_main(audio, dataset_temp, segments_size, overlap, denoise_mdx, kim_vocal_version, hop_length, batch_size, sample_rate)
107
- if separator_reverb: vocals = separator_reverb_audio(vocals, dataset_temp, segments_size, overlap, denoise_mdx, hop_length, batch_size, sample_rate)
108
- separator_paths.append(vocals)
109
-
110
- paths = separator_paths
111
- processed_paths = []
112
-
113
- for audio in paths:
114
- cut_files, time_stamps = process_audio(logger, audio, os.path.dirname(audio))
115
- processed_paths.append(merge_audio(cut_files, time_stamps, audio, os.path.splitext(audio)[0] + "_processed" + ".wav", "wav"))
116
-
117
- paths = processed_paths
118
-
119
- for audio_path in paths:
120
- data, sample_rate = read(audio_path)
121
- data = librosa.to_mono(data.T)
122
-
123
- if clean_dataset:
124
- from main.tools.noisereduce import reduce_noise
125
- data = reduce_noise(y=data, prop_decrease=clean_strength)
126
-
127
- write(audio_path, data, sample_rate)
128
- except Exception as e:
129
- logger.error(f"{translations['create_dataset_error']}: {e}")
130
-
131
- import traceback
132
- logger.error(traceback.format_exc())
133
- finally:
134
- for audio in paths:
135
- shutil.move(audio, output_dataset)
136
-
137
- if os.path.exists(dataset_temp): shutil.rmtree(dataset_temp, ignore_errors=True)
138
-
139
- elapsed_time = time.time() - start_time
140
- if os.path.exists(pid_path): os.remove(pid_path)
141
- logger.info(translations["create_dataset_success"].format(elapsed_time=f"{elapsed_time:.2f}"))
142
-
143
- def downloader(url, name):
144
- with warnings.catch_warnings():
145
- warnings.simplefilter("ignore")
146
-
147
- ydl_opts = {"format": "bestaudio/best", "outtmpl": os.path.join(dataset_temp, f"{name}"), "postprocessors": [{"key": "FFmpegExtractAudio", "preferredcodec": "wav", "preferredquality": "192"}], "no_warnings": True, "noplaylist": True, "noplaylist": True, "verbose": False}
148
- logger.info(f"{translations['starting_download']}: {url}...")
149
-
150
- with yt_dlp.YoutubeDL(ydl_opts) as ydl:
151
- ydl.extract_info(url)
152
- logger.info(f"{translations['download_success']}: {url}")
153
-
154
- return os.path.join(dataset_temp, f"{name}" + ".wav")
155
-
156
- def skip_start(input_file, seconds):
157
- data, sr = read(input_file)
158
- total_duration = len(data) / sr
159
-
160
- if seconds <= 0: logger.warning(translations["=<0"])
161
- elif seconds >= total_duration: logger.warning(translations["skip_warning"].format(seconds=seconds, total_duration=f"{total_duration:.2f}"))
162
- else:
163
- logger.info(f"{translations['skip_start']}: {input_file}...")
164
- write(input_file, data[int(seconds * sr):], sr)
165
-
166
- logger.info(translations["skip_start_audio"].format(input_file=input_file))
167
-
168
- def skip_end(input_file, seconds):
169
- data, sr = read(input_file)
170
- total_duration = len(data) / sr
171
-
172
- if seconds <= 0: logger.warning(translations["=<0"])
173
- elif seconds > total_duration: logger.warning(translations["skip_warning"].format(seconds=seconds, total_duration=f"{total_duration:.2f}"))
174
- else:
175
- logger.info(f"{translations['skip_end']}: {input_file}...")
176
- write(input_file, data[:-int(seconds * sr)], sr)
177
-
178
- logger.info(translations["skip_end_audio"].format(input_file=input_file))
179
-
180
- def separator_music_main(input, output, segments_size, overlap, denoise, version, hop_length, batch_size, sample_rate):
181
- if not os.path.exists(input):
182
- logger.warning(translations["input_not_valid"])
183
- return None
184
-
185
- if not os.path.exists(output):
186
- logger.warning(translations["output_not_valid"])
187
- return None
188
-
189
- model = f"Kim_Vocal_{version}.onnx"
190
- output_separator = separator_main(audio_file=input, model_filename=model, output_format="wav", output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise, sample_rate=sample_rate)
191
-
192
- for f in output_separator:
193
- path = os.path.join(output, f)
194
- if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
195
-
196
- if '_(Instrumental)_' in f: os.rename(path, os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav")
197
- elif '_(Vocals)_' in f:
198
- rename_file = os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav"
199
- os.rename(path, rename_file)
200
-
201
- return rename_file
202
-
203
- def separator_reverb_audio(input, output, segments_size, overlap, denoise, hop_length, batch_size, sample_rate):
204
- if not os.path.exists(input):
205
- logger.warning(translations["input_not_valid"])
206
- return None
207
-
208
- if not os.path.exists(output):
209
- logger.warning(translations["output_not_valid"])
210
- return None
211
-
212
- logger.info(f"{translations['dereverb']}: {input}...")
213
- output_dereverb = separator_main(audio_file=input, model_filename="Reverb_HQ_By_FoxJoy.onnx", output_format="wav", output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=hop_length, mdx_hop_length=batch_size, mdx_enable_denoise=denoise, sample_rate=sample_rate)
214
-
215
- for f in output_dereverb:
216
- path = os.path.join(output, f)
217
- if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
218
-
219
- if '_(Reverb)_' in f: os.rename(path, os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav")
220
- elif '_(No Reverb)_' in f:
221
- rename_file = os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav"
222
- os.rename(path, rename_file)
223
-
224
- logger.info(f"{translations['dereverb_success']}: {rename_file}")
225
- return rename_file
226
-
227
- def separator_main(audio_file=None, model_filename="Kim_Vocal_1.onnx", output_format="wav", output_dir=".", mdx_segment_size=256, mdx_overlap=0.25, mdx_batch_size=1, mdx_hop_length=1024, mdx_enable_denoise=True, sample_rate=44100):
228
- from main.library.algorithm.separator import Separator
229
-
230
- try:
231
- separator = Separator(logger=logger, log_formatter=file_formatter, log_level=logging.INFO, output_dir=output_dir, output_format=output_format, output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=sample_rate, mdx_params={"hop_length": mdx_hop_length, "segment_size": mdx_segment_size, "overlap": mdx_overlap, "batch_size": mdx_batch_size, "enable_denoise": mdx_enable_denoise})
232
- separator.load_model(model_filename=model_filename)
233
- return separator.separate(audio_file)
234
- except:
235
- logger.debug(translations["default_setting"])
236
- separator = Separator(logger=logger, log_formatter=file_formatter, log_level=logging.INFO, output_dir=output_dir, output_format=output_format, output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=44100, mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": mdx_enable_denoise})
237
- separator.load_model(model_filename=model_filename)
238
- return separator.separate(audio_file)
239
-
240
- if __name__ == "__main__": main()
main/inference/create_index.py DELETED
@@ -1,100 +0,0 @@
1
- import os
2
- import sys
3
- import faiss
4
- import logging
5
- import argparse
6
- import logging.handlers
7
-
8
- import numpy as np
9
-
10
- from multiprocessing import cpu_count
11
- from sklearn.cluster import MiniBatchKMeans
12
-
13
- sys.path.append(os.getcwd())
14
-
15
- from main.configs.config import Config
16
- translations = Config().translations
17
-
18
-
19
- def parse_arguments():
20
- parser = argparse.ArgumentParser()
21
- parser.add_argument("--model_name", type=str, required=True)
22
- parser.add_argument("--rvc_version", type=str, default="v2")
23
- parser.add_argument("--index_algorithm", type=str, default="Auto")
24
-
25
- return parser.parse_args()
26
-
27
- def main():
28
- args = parse_arguments()
29
-
30
- exp_dir = os.path.join("assets", "logs", args.model_name)
31
- version = args.rvc_version
32
- index_algorithm = args.index_algorithm
33
- logger = logging.getLogger(__name__)
34
-
35
- if logger.hasHandlers(): logger.handlers.clear()
36
- else:
37
- console_handler = logging.StreamHandler()
38
- console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
39
- console_handler.setFormatter(console_formatter)
40
- console_handler.setLevel(logging.INFO)
41
- file_handler = logging.handlers.RotatingFileHandler(os.path.join(exp_dir, "create_index.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
42
- file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
43
- file_handler.setFormatter(file_formatter)
44
- file_handler.setLevel(logging.DEBUG)
45
- logger.addHandler(console_handler)
46
- logger.addHandler(file_handler)
47
- logger.setLevel(logging.DEBUG)
48
-
49
- log_data = {translations['modelname']: args.model_name, translations['model_path']: exp_dir, translations['training_version']: version, translations['index_algorithm_info']: index_algorithm}
50
- for key, value in log_data.items():
51
- logger.debug(f"{key}: {value}")
52
-
53
- try:
54
- npys = []
55
-
56
- feature_dir = os.path.join(exp_dir, f"{version}_extracted")
57
- model_name = os.path.basename(exp_dir)
58
-
59
- for name in sorted(os.listdir(feature_dir)):
60
- npys.append(np.load(os.path.join(feature_dir, name)))
61
-
62
- big_npy = np.concatenate(npys, axis=0)
63
- big_npy_idx = np.arange(big_npy.shape[0])
64
-
65
- np.random.shuffle(big_npy_idx)
66
- big_npy = big_npy[big_npy_idx]
67
-
68
- if big_npy.shape[0] > 2e5 and (index_algorithm == "Auto" or index_algorithm == "KMeans"): big_npy = (MiniBatchKMeans(n_clusters=10000, verbose=True, batch_size=256 * cpu_count(), compute_labels=False, init="random").fit(big_npy).cluster_centers_)
69
- np.save(os.path.join(exp_dir, "total_fea.npy"), big_npy)
70
-
71
- n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39)
72
- index_trained = faiss.index_factory(256 if version == "v1" else 768, f"IVF{n_ivf},Flat")
73
-
74
- index_ivf_trained = faiss.extract_index_ivf(index_trained)
75
- index_ivf_trained.nprobe = 1
76
-
77
- index_trained.train(big_npy)
78
- faiss.write_index(index_trained, os.path.join(exp_dir, f"trained_IVF{n_ivf}_Flat_nprobe_{index_ivf_trained.nprobe}_{model_name}_{version}.index"))
79
-
80
- index_added = faiss.index_factory(256 if version == "v1" else 768, f"IVF{n_ivf},Flat")
81
- index_ivf_added = faiss.extract_index_ivf(index_added)
82
- index_ivf_added.nprobe = 1
83
-
84
- index_added.train(big_npy)
85
- batch_size_add = 8192
86
-
87
- for i in range(0, big_npy.shape[0], batch_size_add):
88
- index_added.add(big_npy[i : i + batch_size_add])
89
-
90
- index_filepath_added = os.path.join(exp_dir, f"added_IVF{n_ivf}_Flat_nprobe_{index_ivf_added.nprobe}_{model_name}_{version}.index")
91
- faiss.write_index(index_added, index_filepath_added)
92
-
93
- logger.info(f"{translations['save_index']} '{index_filepath_added}'")
94
- except Exception as e:
95
- logger.error(f"{translations['create_index_error']}: {e}")
96
-
97
- import traceback
98
- logger.debug(traceback.format_exc())
99
-
100
- if __name__ == "__main__": main()
main/inference/extract.py DELETED
@@ -1,450 +0,0 @@
1
- import os
2
- import re
3
- import sys
4
- import time
5
- import tqdm
6
- import torch
7
- import shutil
8
- import librosa
9
- import logging
10
- import argparse
11
- import warnings
12
- import parselmouth
13
- import logging.handlers
14
-
15
- import numpy as np
16
- import soundfile as sf
17
- import torch.nn.functional as F
18
-
19
- from random import shuffle
20
- from multiprocessing import Pool
21
- from distutils.util import strtobool
22
- from fairseq import checkpoint_utils
23
- from functools import partial
24
- from concurrent.futures import ThreadPoolExecutor, as_completed
25
-
26
- sys.path.append(os.getcwd())
27
-
28
- from main.configs.config import Config
29
- from main.library.predictors.FCPE import FCPE
30
- from main.library.predictors.RMVPE import RMVPE
31
- from main.library.predictors.WORLD import PYWORLD
32
- from main.library.predictors.CREPE import predict, mean, median
33
- from main.library.utils import check_predictors, check_embedders, load_audio
34
-
35
- logger = logging.getLogger(__name__)
36
- translations = Config().translations
37
- logger.propagate = False
38
-
39
- warnings.filterwarnings("ignore")
40
- for l in ["torch", "faiss", "httpx", "fairseq", "httpcore", "faiss.loader", "numba.core", "urllib3"]:
41
- logging.getLogger(l).setLevel(logging.ERROR)
42
-
43
- def parse_arguments():
44
- parser = argparse.ArgumentParser()
45
- parser.add_argument("--model_name", type=str, required=True)
46
- parser.add_argument("--rvc_version", type=str, default="v2")
47
- parser.add_argument("--f0_method", type=str, default="rmvpe")
48
- parser.add_argument("--pitch_guidance", type=lambda x: bool(strtobool(x)), default=True)
49
- parser.add_argument("--hop_length", type=int, default=128)
50
- parser.add_argument("--cpu_cores", type=int, default=2)
51
- parser.add_argument("--gpu", type=str, default="-")
52
- parser.add_argument("--sample_rate", type=int, required=True)
53
- parser.add_argument("--embedder_model", type=str, default="contentvec_base")
54
-
55
- return parser.parse_args()
56
-
57
- def generate_config(rvc_version, sample_rate, model_path):
58
- config_save_path = os.path.join(model_path, "config.json")
59
- if not os.path.exists(config_save_path): shutil.copy(os.path.join("main", "configs", rvc_version, f"{sample_rate}.json"), config_save_path)
60
-
61
- def generate_filelist(pitch_guidance, model_path, rvc_version, sample_rate):
62
- gt_wavs_dir, feature_dir = os.path.join(model_path, "sliced_audios"), os.path.join(model_path, f"{rvc_version}_extracted")
63
- f0_dir, f0nsf_dir = None, None
64
-
65
- if pitch_guidance: f0_dir, f0nsf_dir = os.path.join(model_path, "f0"), os.path.join(model_path, "f0_voiced")
66
-
67
- gt_wavs_files, feature_files = set(name.split(".")[0] for name in os.listdir(gt_wavs_dir)), set(name.split(".")[0] for name in os.listdir(feature_dir))
68
- names = gt_wavs_files & feature_files & set(name.split(".")[0] for name in os.listdir(f0_dir)) & set(name.split(".")[0] for name in os.listdir(f0nsf_dir)) if pitch_guidance else gt_wavs_files & feature_files
69
-
70
- options = []
71
- mute_base_path = os.path.join("assets", "logs", "mute")
72
-
73
- for name in names:
74
- options.append(f"{gt_wavs_dir}/{name}.wav|{feature_dir}/{name}.npy|{f0_dir}/{name}.wav.npy|{f0nsf_dir}/{name}.wav.npy|0" if pitch_guidance else f"{gt_wavs_dir}/{name}.wav|{feature_dir}/{name}.npy|0")
75
-
76
- mute_audio_path, mute_feature_path = os.path.join(mute_base_path, "sliced_audios", f"mute{sample_rate}.wav"), os.path.join(mute_base_path, f"{rvc_version}_extracted", "mute.npy")
77
-
78
- for _ in range(2):
79
- options.append(f"{mute_audio_path}|{mute_feature_path}|{os.path.join(mute_base_path, 'f0', 'mute.wav.npy')}|{os.path.join(mute_base_path, 'f0_voiced', 'mute.wav.npy')}|0" if pitch_guidance else f"{mute_audio_path}|{mute_feature_path}|0")
80
-
81
- shuffle(options)
82
-
83
- with open(os.path.join(model_path, "filelist.txt"), "w") as f:
84
- f.write("\n".join(options))
85
-
86
- def setup_paths(exp_dir, version = None):
87
- wav_path = os.path.join(exp_dir, "sliced_audios_16k")
88
-
89
- if version:
90
- out_path = os.path.join(exp_dir, f"{version}_extracted")
91
- os.makedirs(out_path, exist_ok=True)
92
-
93
- return wav_path, out_path
94
- else:
95
- output_root1, output_root2 = os.path.join(exp_dir, "f0"), os.path.join(exp_dir, "f0_voiced")
96
- os.makedirs(output_root1, exist_ok=True); os.makedirs(output_root2, exist_ok=True)
97
-
98
- return wav_path, output_root1, output_root2
99
-
100
- def read_wave(wav_path, normalize = False):
101
- wav, sr = sf.read(wav_path)
102
- assert sr == 16000, translations["sr_not_16000"]
103
-
104
- feats = torch.from_numpy(wav).float()
105
-
106
- if feats.dim() == 2: feats = feats.mean(-1)
107
- feats = feats.view(1, -1)
108
-
109
- if normalize: feats = F.layer_norm(feats, feats.shape)
110
- return feats
111
-
112
- def get_device(gpu_index):
113
- if gpu_index == "cpu": return "cpu"
114
-
115
- try:
116
- index = int(gpu_index)
117
-
118
- if index < torch.cuda.device_count(): return f"cuda:{index}"
119
- else: logger.warning(translations["gpu_not_valid"])
120
- except ValueError:
121
- logger.warning(translations["gpu_not_valid"])
122
- return "cpu"
123
-
124
- class FeatureInput:
125
- def __init__(self, sample_rate=16000, hop_size=160, device="cpu"):
126
- self.fs = sample_rate
127
- self.hop = hop_size
128
- self.f0_bin = 256
129
- self.f0_max = 1100.0
130
- self.f0_min = 50.0
131
- self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700)
132
- self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700)
133
- self.device = device
134
-
135
- def get_providers(self):
136
- import onnxruntime
137
-
138
- ort_providers = onnxruntime.get_available_providers()
139
-
140
- if "CUDAExecutionProvider" in ort_providers: providers = ["CUDAExecutionProvider"]
141
- elif "CoreMLExecutionProvider" in ort_providers: providers = ["CoreMLExecutionProvider"]
142
- else: providers = ["CPUExecutionProvider"]
143
-
144
- return providers
145
-
146
- def compute_f0_hybrid(self, methods_str, np_arr, hop_length):
147
- methods_str = re.search("hybrid\[(.+)\]", methods_str)
148
- if methods_str: methods = [method.strip() for method in methods_str.group(1).split("+")]
149
-
150
- f0_computation_stack, resampled_stack = [], []
151
- logger.debug(translations["hybrid_methods"].format(methods=methods))
152
-
153
- for method in methods:
154
- f0 = None
155
-
156
- if method == "pm": f0 = self.get_pm(np_arr)
157
- elif method == "dio": f0 = self.get_pyworld(np_arr, "dio")
158
- elif method == "mangio-crepe-full": f0 = self.get_mangio_crepe(np_arr, int(hop_length), "full")
159
- elif method == "mangio-crepe-full-onnx": f0 = self.get_mangio_crepe(np_arr, int(hop_length), "full", onnx=True)
160
- elif method == "mangio-crepe-large": f0 = self.get_mangio_crepe(np_arr, int(hop_length), "large")
161
- elif method == "mangio-crepe-large-onnx": f0 = self.get_mangio_crepe(np_arr, int(hop_length), "large", onnx=True)
162
- elif method == "mangio-crepe-medium": f0 = self.get_mangio_crepe(np_arr, int(hop_length), "medium")
163
- elif method == "mangio-crepe-medium-onnx": f0 = self.get_mangio_crepe(np_arr, int(hop_length), "medium", onnx=True)
164
- elif method == "mangio-crepe-small": f0 = self.get_mangio_crepe(np_arr, int(hop_length), "small")
165
- elif method == "mangio-crepe-small-onnx": f0 = self.get_mangio_crepe(np_arr, int(hop_length), "small", onnx=True)
166
- elif method == "mangio-crepe-tiny": f0 = self.get_mangio_crepe(np_arr, int(hop_length), "tiny")
167
- elif method == "mangio-crepe-tiny-onnx": f0 = self.get_mangio_crepe(np_arr, int(hop_length), "tiny", onnx=True)
168
- elif method == "crepe-full": f0 = self.get_crepe(np_arr, "full")
169
- elif method == "crepe-full-onnx": f0 = self.get_crepe(np_arr, "full", onnx=True)
170
- elif method == "crepe-large": f0 = self.get_crepe(np_arr, "large")
171
- elif method == "crepe-large-onnx": f0 = self.get_crepe(np_arr, "large", onnx=True)
172
- elif method == "crepe-medium": f0 = self.get_crepe(np_arr, "medium")
173
- elif method == "crepe-medium-onnx": f0 = self.get_crepe(np_arr, "medium", onnx=True)
174
- elif method == "crepe-small": f0 = self.get_crepe(np_arr, "small")
175
- elif method == "crepe-small-onnx": f0 = self.get_crepe(np_arr, "small", onnx=True)
176
- elif method == "crepe-tiny": f0 = self.get_crepe(np_arr, "tiny")
177
- elif method == "crepe-tiny-onnx": f0 = self.get_crepe(np_arr, "tiny", onnx=True)
178
- elif method == "fcpe": f0 = self.get_fcpe(np_arr, int(hop_length))
179
- elif method == "fcpe-onnx": f0 = self.get_fcpe(np_arr, int(hop_length), onnx=True)
180
- elif method == "fcpe-legacy": f0 = self.get_fcpe(np_arr, int(hop_length), legacy=True)
181
- elif method == "fcpe-legacy-onnx": f0 = self.get_fcpe(np_arr, int(hop_length), onnx=True, legacy=True)
182
- elif method == "rmvpe": f0 = self.get_rmvpe(np_arr)
183
- elif method == "rmvpe-onnx": f0 = self.get_rmvpe(np_arr, onnx=True)
184
- elif method == "rmvpe-legacy": f0 = self.get_rmvpe(np_arr, legacy=True)
185
- elif method == "rmvpe-legacy-onnx": f0 = self.get_rmvpe(np_arr, legacy=True, onnx=True)
186
- elif method == "harvest": f0 = self.get_pyworld(np_arr, "harvest")
187
- elif method == "yin": f0 = self.get_yin(np_arr, int(hop_length))
188
- elif method == "pyin": return self.get_pyin(np_arr, int(hop_length))
189
- else: raise ValueError(translations["method_not_valid"])
190
-
191
- f0_computation_stack.append(f0)
192
-
193
- for f0 in f0_computation_stack:
194
- resampled_stack.append(np.interp(np.linspace(0, len(f0), (np_arr.size // self.hop)), np.arange(len(f0)), f0))
195
-
196
- return resampled_stack[0] if len(resampled_stack) == 1 else np.nanmedian(np.vstack(resampled_stack), axis=0)
197
-
198
- def compute_f0(self, np_arr, f0_method, hop_length):
199
- if f0_method == "pm": return self.get_pm(np_arr)
200
- elif f0_method == "dio": return self.get_pyworld(np_arr, "dio")
201
- elif f0_method == "mangio-crepe-full": return self.get_mangio_crepe(np_arr, int(hop_length), "full")
202
- elif f0_method == "mangio-crepe-full-onnx": return self.get_mangio_crepe(np_arr, int(hop_length), "full", onnx=True)
203
- elif f0_method == "mangio-crepe-large": return self.get_mangio_crepe(np_arr, int(hop_length), "large")
204
- elif f0_method == "mangio-crepe-large-onnx": return self.get_mangio_crepe(np_arr, int(hop_length), "large", onnx=True)
205
- elif f0_method == "mangio-crepe-medium": return self.get_mangio_crepe(np_arr, int(hop_length), "medium")
206
- elif f0_method == "mangio-crepe-medium-onnx": return self.get_mangio_crepe(np_arr, int(hop_length), "medium", onnx=True)
207
- elif f0_method == "mangio-crepe-small": return self.get_mangio_crepe(np_arr, int(hop_length), "small")
208
- elif f0_method == "mangio-crepe-small-onnx": return self.get_mangio_crepe(np_arr, int(hop_length), "small", onnx=True)
209
- elif f0_method == "mangio-crepe-tiny": return self.get_mangio_crepe(np_arr, int(hop_length), "tiny")
210
- elif f0_method == "mangio-crepe-tiny-onnx": return self.get_mangio_crepe(np_arr, int(hop_length), "tiny", onnx=True)
211
- elif f0_method == "crepe-full": return self.get_crepe(np_arr, "full")
212
- elif f0_method == "crepe-full-onnx": return self.get_crepe(np_arr, "full", onnx=True)
213
- elif f0_method == "crepe-large": return self.get_crepe(np_arr, "large")
214
- elif f0_method == "crepe-large-onnx": return self.get_crepe(np_arr, "large", onnx=True)
215
- elif f0_method == "crepe-medium": return self.get_crepe(np_arr, "medium")
216
- elif f0_method == "crepe-medium-onnx": return self.get_crepe(np_arr, "medium", onnx=True)
217
- elif f0_method == "crepe-small": return self.get_crepe(np_arr, "small")
218
- elif f0_method == "crepe-small-onnx": return self.get_crepe(np_arr, "small", onnx=True)
219
- elif f0_method == "crepe-tiny": return self.get_crepe(np_arr, "tiny")
220
- elif f0_method == "crepe-tiny-onnx": return self.get_crepe(np_arr, "tiny", onnx=True)
221
- elif f0_method == "fcpe": return self.get_fcpe(np_arr, int(hop_length))
222
- elif f0_method == "fcpe-onnx": return self.get_fcpe(np_arr, int(hop_length), onnx=True)
223
- elif f0_method == "fcpe-legacy": return self.get_fcpe(np_arr, int(hop_length), legacy=True)
224
- elif f0_method == "fcpe-legacy-onnx": return self.get_fcpe(np_arr, int(hop_length), onnx=True, legacy=True)
225
- elif f0_method == "rmvpe": return self.get_rmvpe(np_arr)
226
- elif f0_method == "rmvpe-onnx": return self.get_rmvpe(np_arr, onnx=True)
227
- elif f0_method == "rmvpe-legacy": return self.get_rmvpe(np_arr, legacy=True)
228
- elif f0_method == "rmvpe-legacy-onnx": return self.get_rmvpe(np_arr, legacy=True, onnx=True)
229
- elif f0_method == "harvest": return self.get_pyworld(np_arr, "harvest")
230
- elif f0_method == "yin": return self.get_yin(np_arr, int(hop_length))
231
- elif f0_method == "pyin": return self.get_pyin(np_arr, int(hop_length))
232
- elif "hybrid" in f0_method: return self.compute_f0_hybrid(f0_method, np_arr, int(hop_length))
233
- else: raise ValueError(translations["method_not_valid"])
234
-
235
- def get_pm(self, x):
236
- f0 = (parselmouth.Sound(x, self.fs).to_pitch_ac(time_step=(160 / 16000 * 1000) / 1000, voicing_threshold=0.6, pitch_floor=50, pitch_ceiling=1100).selected_array["frequency"])
237
- pad_size = ((x.size // self.hop) - len(f0) + 1) // 2
238
-
239
- if pad_size > 0 or (x.size // self.hop) - len(f0) - pad_size > 0: f0 = np.pad(f0, [[pad_size, (x.size // self.hop) - len(f0) - pad_size]], mode="constant")
240
- return f0
241
-
242
- def get_mangio_crepe(self, x, hop_length, model="full", onnx=False):
243
- providers = self.get_providers() if onnx else None
244
-
245
- audio = torch.from_numpy(x.astype(np.float32)).to(self.device)
246
- audio /= torch.quantile(torch.abs(audio), 0.999)
247
- audio = audio.unsqueeze(0)
248
-
249
- source = predict(audio, self.fs, hop_length, self.f0_min, self.f0_max, model=model, batch_size=hop_length * 2, device=self.device, pad=True, providers=providers, onnx=onnx).squeeze(0).cpu().float().numpy()
250
- source[source < 0.001] = np.nan
251
-
252
- return np.nan_to_num(np.interp(np.arange(0, len(source) * (x.size // self.hop), len(source)) / (x.size // self.hop), np.arange(0, len(source)), source))
253
-
254
- def get_crepe(self, x, model="full", onnx=False):
255
- providers = self.get_providers() if onnx else None
256
-
257
- f0, pd = predict(torch.tensor(np.copy(x))[None].float(), self.fs, 160, self.f0_min, self.f0_max, model, batch_size=512, device=self.device, return_periodicity=True, providers=providers, onnx=onnx)
258
- f0, pd = mean(f0, 3), median(pd, 3)
259
- f0[pd < 0.1] = 0
260
-
261
- return f0[0].cpu().numpy()
262
-
263
- def get_fcpe(self, x, hop_length, legacy=False, onnx=False):
264
- providers = self.get_providers() if onnx else None
265
-
266
- model_fcpe = FCPE(os.path.join("assets", "models", "predictors", "fcpe" + (".onnx" if onnx else ".pt")), hop_length=int(hop_length), f0_min=int(self.f0_min), f0_max=int(self.f0_max), dtype=torch.float32, device=self.device, sample_rate=self.fs, threshold=0.03, providers=providers, onnx=onnx) if legacy else FCPE(os.path.join("assets", "models", "predictors", "fcpe" + (".onnx" if onnx else ".pt")), hop_length=160, f0_min=0, f0_max=8000, dtype=torch.float32, device=self.device, sample_rate=self.fs, threshold=0.006, providers=providers, onnx=onnx)
267
- f0 = model_fcpe.compute_f0(x, p_len=(x.size // self.hop))
268
-
269
- del model_fcpe
270
- return f0
271
-
272
- def get_rmvpe(self, x, legacy=False, onnx=False):
273
- providers = self.get_providers() if onnx else None
274
-
275
- rmvpe_model = RMVPE(os.path.join("assets", "models", "predictors", "rmvpe" + (".onnx" if onnx else ".pt")), device=self.device, onnx=onnx, providers=providers)
276
- f0 = rmvpe_model.infer_from_audio_with_pitch(x, thred=0.03, f0_min=self.f0_min, f0_max=self.f0_max) if legacy else rmvpe_model.infer_from_audio(x, thred=0.03)
277
-
278
- del rmvpe_model
279
- return f0
280
-
281
- def get_pyworld(self, x, model="harvest"):
282
- pw = PYWORLD()
283
-
284
- if model == "harvest": f0, t = pw.harvest(x.astype(np.double), fs=self.fs, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=1000 * self.hop / self.fs)
285
- elif model == "dio": f0, t = pw.dio(x.astype(np.double), fs=self.fs, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=1000 * self.hop / self.fs)
286
- else: raise ValueError(translations["method_not_valid"])
287
-
288
- return pw.stonemask(x.astype(np.double), self.fs, t, f0)
289
-
290
- def get_yin(self, x, hop_length):
291
- source = np.array(librosa.yin(x.astype(np.double), sr=self.fs, fmin=self.f0_min, fmax=self.f0_max, hop_length=hop_length))
292
- source[source < 0.001] = np.nan
293
-
294
- return np.nan_to_num(np.interp(np.arange(0, len(source) * (x.size // self.hop), len(source)) / (x.size // self.hop), np.arange(0, len(source)), source))
295
-
296
- def get_pyin(self, x, hop_length):
297
- f0, _, _ = librosa.pyin(x.astype(np.double), fmin=self.f0_min, fmax=self.f0_max, sr=self.fs, hop_length=hop_length)
298
-
299
- source = np.array(f0)
300
- source[source < 0.001] = np.nan
301
-
302
- return np.nan_to_num(np.interp(np.arange(0, len(source) * (x.size // self.hop), len(source)) / (x.size // self.hop), np.arange(0, len(source)), source))
303
-
304
- def coarse_f0(self, f0):
305
- return np.rint(np.clip(((1127 * np.log(1 + f0 / 700)) - self.f0_mel_min) * (self.f0_bin - 2) / (self.f0_mel_max - self.f0_mel_min) + 1, 1, self.f0_bin - 1)).astype(int)
306
-
307
- def process_file(self, file_info, f0_method, hop_length):
308
- inp_path, opt_path1, opt_path2, np_arr = file_info
309
- if os.path.exists(opt_path1 + ".npy") and os.path.exists(opt_path2 + ".npy"): return
310
-
311
- try:
312
- feature_pit = self.compute_f0(np_arr, f0_method, hop_length)
313
- np.save(opt_path2, feature_pit, allow_pickle=False)
314
- np.save(opt_path1, self.coarse_f0(feature_pit), allow_pickle=False)
315
- except Exception as e:
316
- raise RuntimeError(f"{translations['extract_file_error']} {inp_path}: {e}")
317
-
318
- def process_files(self, files, f0_method, hop_length, pbar):
319
- for file_info in files:
320
- self.process_file(file_info, f0_method, hop_length)
321
- pbar.update()
322
-
323
- def run_pitch_extraction(exp_dir, f0_method, hop_length, num_processes, gpus):
324
- input_root, *output_roots = setup_paths(exp_dir)
325
- output_root1, output_root2 = output_roots if len(output_roots) == 2 else (output_roots[0], None)
326
- paths = [(os.path.join(input_root, name), os.path.join(output_root1, name) if output_root1 else None, os.path.join(output_root2, name) if output_root2 else None, load_audio(os.path.join(input_root, name))) for name in sorted(os.listdir(input_root)) if "spec" not in name]
327
- logger.info(translations["extract_f0_method"].format(num_processes=num_processes, f0_method=f0_method))
328
-
329
- start_time = time.time()
330
-
331
- if gpus != "-":
332
- gpus = gpus.split("-")
333
- process_partials = []
334
-
335
- pbar = tqdm.tqdm(total=len(paths), desc=translations["extract_f0"], ncols=100, unit="p")
336
-
337
- for idx, gpu in enumerate(gpus):
338
- feature_input = FeatureInput(device=get_device(gpu))
339
- process_partials.append((feature_input, paths[idx::len(gpus)]))
340
-
341
- with ThreadPoolExecutor() as executor:
342
- for future in as_completed([executor.submit(FeatureInput.process_files, feature_input, part_paths, f0_method, hop_length, pbar) for feature_input, part_paths in process_partials]):
343
- pbar.update(1)
344
- logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))
345
- future.result()
346
-
347
- pbar.close()
348
- else:
349
- with tqdm.tqdm(total=len(paths), desc=translations["extract_f0"], ncols=100, unit="p") as pbar:
350
- with Pool(processes=num_processes) as pool:
351
- for _ in pool.imap_unordered(partial(FeatureInput(device="cpu").process_file, f0_method=f0_method, hop_length=hop_length), paths):
352
- pbar.update(1)
353
- logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))
354
-
355
- elapsed_time = time.time() - start_time
356
- logger.info(translations["extract_f0_success"].format(elapsed_time=f"{elapsed_time:.2f}"))
357
-
358
- def process_file_embedding(file, wav_path, out_path, model, device, version, saved_cfg):
359
- out_file_path = os.path.join(out_path, file.replace("wav", "npy"))
360
- if os.path.exists(out_file_path): return
361
-
362
- feats = read_wave(os.path.join(wav_path, file), normalize=saved_cfg.task.normalize).to(device).float()
363
- inputs = {"source": feats, "padding_mask": torch.BoolTensor(feats.shape).fill_(False).to(device), "output_layer": 9 if version == "v1" else 12}
364
-
365
- with torch.no_grad():
366
- model = model.to(device).float().eval()
367
- logits = model.extract_features(**inputs)
368
- feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
369
-
370
- feats = feats.squeeze(0).float().cpu().numpy()
371
-
372
- if not np.isnan(feats).any(): np.save(out_file_path, feats, allow_pickle=False)
373
- else: logger.warning(f"{file} {translations['NaN']}")
374
-
375
- def run_embedding_extraction(exp_dir, version, gpus, embedder_model):
376
- wav_path, out_path = setup_paths(exp_dir, version)
377
- logger.info(translations["start_extract_hubert"])
378
-
379
- start_time = time.time()
380
-
381
- try:
382
- models, saved_cfg, _ = checkpoint_utils.load_model_ensemble_and_task([os.path.join("assets", "models", "embedders", embedder_model + '.pt')], suffix="")
383
- except Exception as e:
384
- raise ImportError(translations["read_model_error"].format(e=e))
385
-
386
- devices = [get_device(gpu) for gpu in (gpus.split("-") if gpus != "-" else ["cpu"])]
387
- paths = sorted([file for file in os.listdir(wav_path) if file.endswith(".wav")])
388
-
389
- if not paths:
390
- logger.warning(translations["not_found_audio_file"])
391
- sys.exit(1)
392
-
393
- pbar = tqdm.tqdm(total=len(paths) * len(devices), desc=translations["extract_hubert"], ncols=100, unit="p")
394
-
395
- for task in [(file, wav_path, out_path, models[0], device, version, saved_cfg) for file in paths for device in devices]:
396
- try:
397
- process_file_embedding(*task)
398
- except Exception as e:
399
- raise RuntimeError(f"{translations['process_error']} {task[0]}: {e}")
400
-
401
- pbar.update(1)
402
- logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))
403
-
404
- pbar.close()
405
- elapsed_time = time.time() - start_time
406
- logger.info(translations["extract_hubert_success"].format(elapsed_time=f"{elapsed_time:.2f}"))
407
-
408
- if __name__ == "__main__":
409
- args = parse_arguments()
410
- exp_dir = os.path.join("assets", "logs", args.model_name)
411
- f0_method, hop_length, num_processes, gpus, version, pitch_guidance, sample_rate, embedder_model = args.f0_method, args.hop_length, args.cpu_cores, args.gpu, args.rvc_version, args.pitch_guidance, args.sample_rate, args.embedder_model
412
-
413
- check_predictors(f0_method)
414
- check_embedders(embedder_model)
415
-
416
- if logger.hasHandlers(): logger.handlers.clear()
417
- else:
418
- console_handler = logging.StreamHandler()
419
- console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
420
- console_handler.setFormatter(console_formatter)
421
- console_handler.setLevel(logging.INFO)
422
- file_handler = logging.handlers.RotatingFileHandler(os.path.join(exp_dir, "extract.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
423
- file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
424
- file_handler.setFormatter(file_formatter)
425
- file_handler.setLevel(logging.DEBUG)
426
- logger.addHandler(console_handler)
427
- logger.addHandler(file_handler)
428
- logger.setLevel(logging.DEBUG)
429
-
430
- log_data = {translations['modelname']: args.model_name, translations['export_process']: exp_dir, translations['f0_method']: f0_method, translations['pretrain_sr']: sample_rate, translations['cpu_core']: num_processes, "Gpu": gpus, "Hop length": hop_length, translations['training_version']: version, translations['extract_f0']: pitch_guidance, translations['hubert_model']: embedder_model}
431
- for key, value in log_data.items():
432
- logger.debug(f"{key}: {value}")
433
-
434
- pid_path = os.path.join(exp_dir, "extract_pid.txt")
435
- with open(pid_path, "w") as pid_file:
436
- pid_file.write(str(os.getpid()))
437
-
438
- try:
439
- run_pitch_extraction(exp_dir, f0_method, hop_length, num_processes, gpus)
440
- run_embedding_extraction(exp_dir, version, gpus, embedder_model)
441
- generate_config(version, sample_rate, exp_dir)
442
- generate_filelist(pitch_guidance, exp_dir, version, sample_rate)
443
- except Exception as e:
444
- logger.error(f"{translations['extract_error']}: {e}")
445
-
446
- import traceback
447
- logger.debug(traceback.format_exc())
448
-
449
- if os.path.exists(pid_path): os.remove(pid_path)
450
- logger.info(f"{translations['extract_success']} {args.model_name}.")
main/inference/preprocess.py DELETED
@@ -1,290 +0,0 @@
1
- import os
2
- import sys
3
- import time
4
- import logging
5
- import librosa
6
- import argparse
7
- import logging.handlers
8
-
9
- import numpy as np
10
- import soundfile as sf
11
- import multiprocessing as mp
12
-
13
- from tqdm import tqdm
14
- from scipy import signal
15
- from scipy.io import wavfile
16
- from distutils.util import strtobool
17
- from concurrent.futures import ProcessPoolExecutor, as_completed
18
-
19
- sys.path.append(os.getcwd())
20
-
21
- from main.configs.config import Config
22
-
23
- logger = logging.getLogger(__name__)
24
- for l in ["numba.core.byteflow", "numba.core.ssa", "numba.core.interpreter"]:
25
- logging.getLogger(l).setLevel(logging.ERROR)
26
-
27
- OVERLAP, MAX_AMPLITUDE, ALPHA, HIGH_PASS_CUTOFF, SAMPLE_RATE_16K = 0.3, 0.9, 0.75, 48, 16000
28
- translations = Config().translations
29
-
30
- def parse_arguments():
31
- parser = argparse.ArgumentParser()
32
- parser.add_argument("--model_name", type=str, required=True)
33
- parser.add_argument("--dataset_path", type=str, default="./dataset")
34
- parser.add_argument("--sample_rate", type=int, required=True)
35
- parser.add_argument("--cpu_cores", type=int, default=2)
36
- parser.add_argument("--cut_preprocess", type=lambda x: bool(strtobool(x)), default=True)
37
- parser.add_argument("--process_effects", type=lambda x: bool(strtobool(x)), default=False)
38
- parser.add_argument("--clean_dataset", type=lambda x: bool(strtobool(x)), default=False)
39
- parser.add_argument("--clean_strength", type=float, default=0.7)
40
-
41
- return parser.parse_args()
42
-
43
- def load_audio(file, sample_rate):
44
- try:
45
- audio, sr = sf.read(file.strip(" ").strip('"').strip("\n").strip('"').strip(" "))
46
-
47
- if len(audio.shape) > 1: audio = librosa.to_mono(audio.T)
48
- if sr != sample_rate: audio = librosa.resample(audio, orig_sr=sr, target_sr=sample_rate, res_type="soxr_vhq")
49
- except Exception as e:
50
- raise RuntimeError(f"{translations['errors_loading_audio']}: {e}")
51
-
52
- return audio.flatten()
53
-
54
- class Slicer:
55
- def __init__(self, sr, threshold = -40.0, min_length = 5000, min_interval = 300, hop_size = 20, max_sil_kept = 5000):
56
- if not min_length >= min_interval >= hop_size: raise ValueError(translations["min_length>=min_interval>=hop_size"])
57
- if not max_sil_kept >= hop_size: raise ValueError(translations["max_sil_kept>=hop_size"])
58
-
59
- min_interval = sr * min_interval / 1000
60
- self.threshold = 10 ** (threshold / 20.0)
61
- self.hop_size = round(sr * hop_size / 1000)
62
- self.win_size = min(round(min_interval), 4 * self.hop_size)
63
- self.min_length = round(sr * min_length / 1000 / self.hop_size)
64
- self.min_interval = round(min_interval / self.hop_size)
65
- self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
66
-
67
- def _apply_slice(self, waveform, begin, end):
68
- start_idx = begin * self.hop_size
69
-
70
- if len(waveform.shape) > 1: return waveform[:, start_idx:min(waveform.shape[1], end * self.hop_size)]
71
- else: return waveform[start_idx:min(waveform.shape[0], end * self.hop_size)]
72
-
73
- def slice(self, waveform):
74
- samples = waveform.mean(axis=0) if len(waveform.shape) > 1 else waveform
75
- if samples.shape[0] <= self.min_length: return [waveform]
76
-
77
- rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
78
- sil_tags = []
79
- silence_start, clip_start = None, 0
80
-
81
- for i, rms in enumerate(rms_list):
82
- if rms < self.threshold:
83
- if silence_start is None: silence_start = i
84
- continue
85
-
86
- if silence_start is None: continue
87
-
88
- is_leading_silence = silence_start == 0 and i > self.max_sil_kept
89
- need_slice_middle = (i - silence_start >= self.min_interval and i - clip_start >= self.min_length)
90
-
91
- if not is_leading_silence and not need_slice_middle:
92
- silence_start = None
93
- continue
94
-
95
- if i - silence_start <= self.max_sil_kept:
96
- pos = rms_list[silence_start : i + 1].argmin() + silence_start
97
-
98
- sil_tags.append((0, pos) if silence_start == 0 else (pos, pos))
99
- clip_start = pos
100
- elif i - silence_start <= self.max_sil_kept * 2:
101
- pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()
102
-
103
- pos += i - self.max_sil_kept
104
- pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)
105
-
106
- if silence_start == 0:
107
- sil_tags.append((0, pos_r))
108
- clip_start = pos_r
109
- else:
110
- sil_tags.append((min((rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start), pos), max(pos_r, pos)))
111
- clip_start = max(pos_r, pos)
112
- else:
113
- pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)
114
-
115
- sil_tags.append((0, pos_r) if silence_start == 0 else ((rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start), pos_r))
116
- clip_start = pos_r
117
-
118
- silence_start = None
119
- total_frames = rms_list.shape[0]
120
- if (silence_start is not None and total_frames - silence_start >= self.min_interval): sil_tags.append((rms_list[silence_start : min(total_frames, silence_start + self.max_sil_kept) + 1].argmin() + silence_start, total_frames + 1))
121
-
122
- if not sil_tags: return [waveform]
123
- else:
124
- chunks = []
125
- if sil_tags[0][0] > 0: chunks.append(self._apply_slice(waveform, 0, sil_tags[0][0]))
126
-
127
- for i in range(len(sil_tags) - 1):
128
- chunks.append(self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]))
129
-
130
- if sil_tags[-1][1] < total_frames: chunks.append(self._apply_slice(waveform, sil_tags[-1][1], total_frames))
131
- return chunks
132
-
133
- def get_rms(y, frame_length=2048, hop_length=512, pad_mode="constant"):
134
- y = np.pad(y, (int(frame_length // 2), int(frame_length // 2)), mode=pad_mode)
135
- axis = -1
136
-
137
- x_shape_trimmed = list(y.shape)
138
- x_shape_trimmed[axis] -= frame_length - 1
139
-
140
- xw = np.lib.stride_tricks.as_strided(y, shape=tuple(x_shape_trimmed) + tuple([frame_length]), strides=y.strides + tuple([y.strides[axis]]))
141
- xw = np.moveaxis(xw, -1, axis - 1 if axis < 0 else axis + 1)
142
-
143
- slices = [slice(None)] * xw.ndim
144
- slices[axis] = slice(0, None, hop_length)
145
-
146
- return np.sqrt(np.mean(np.abs(xw[tuple(slices)]) ** 2, axis=-2, keepdims=True))
147
-
148
- class PreProcess:
149
- def __init__(self, sr, exp_dir, per):
150
- self.slicer = Slicer(sr=sr, threshold=-42, min_length=1500, min_interval=400, hop_size=15, max_sil_kept=500)
151
- self.sr = sr
152
- self.b_high, self.a_high = signal.butter(N=5, Wn=HIGH_PASS_CUTOFF, btype="high", fs=self.sr)
153
- self.per = per
154
- self.exp_dir = exp_dir
155
- self.device = "cpu"
156
- self.gt_wavs_dir = os.path.join(exp_dir, "sliced_audios")
157
- self.wavs16k_dir = os.path.join(exp_dir, "sliced_audios_16k")
158
- os.makedirs(self.gt_wavs_dir, exist_ok=True)
159
- os.makedirs(self.wavs16k_dir, exist_ok=True)
160
-
161
- def _normalize_audio(self, audio):
162
- tmp_max = np.abs(audio).max()
163
- if tmp_max > 2.5: return None
164
- return (audio / tmp_max * (MAX_AMPLITUDE * ALPHA)) + (1 - ALPHA) * audio
165
-
166
- def process_audio_segment(self, normalized_audio, sid, idx0, idx1):
167
- if normalized_audio is None:
168
- logger.debug(f"{sid}-{idx0}-{idx1}-filtered")
169
- return
170
-
171
- wavfile.write(os.path.join(self.gt_wavs_dir, f"{sid}_{idx0}_{idx1}.wav"), self.sr, normalized_audio.astype(np.float32))
172
- wavfile.write(os.path.join(self.wavs16k_dir, f"{sid}_{idx0}_{idx1}.wav"), SAMPLE_RATE_16K, librosa.resample(normalized_audio, orig_sr=self.sr, target_sr=SAMPLE_RATE_16K, res_type="soxr_vhq").astype(np.float32))
173
-
174
- def process_audio(self, path, idx0, sid, cut_preprocess, process_effects, clean_dataset, clean_strength):
175
- try:
176
- audio = load_audio(path, self.sr)
177
-
178
- if process_effects:
179
- audio = signal.lfilter(self.b_high, self.a_high, audio)
180
- audio = self._normalize_audio(audio)
181
-
182
- if clean_dataset:
183
- from main.tools.noisereduce import reduce_noise
184
- audio = reduce_noise(y=audio, sr=self.sr, prop_decrease=clean_strength)
185
-
186
- idx1 = 0
187
- if cut_preprocess:
188
- for audio_segment in self.slicer.slice(audio):
189
- i = 0
190
-
191
- while 1:
192
- start = int(self.sr * (self.per - OVERLAP) * i)
193
- i += 1
194
-
195
- if len(audio_segment[start:]) > (self.per + OVERLAP) * self.sr:
196
- self.process_audio_segment(audio_segment[start : start + int(self.per * self.sr)], sid, idx0, idx1)
197
- idx1 += 1
198
- else:
199
- self.process_audio_segment(audio_segment[start:], sid, idx0, idx1)
200
- idx1 += 1
201
- break
202
- else: self.process_audio_segment(audio, sid, idx0, idx1)
203
- except Exception as e:
204
- raise RuntimeError(f"{translations['process_audio_error']}: {e}")
205
-
206
- def process_file(args):
207
- pp, file, cut_preprocess, process_effects, clean_dataset, clean_strength = (args)
208
- file_path, idx0, sid = file
209
- pp.process_audio(file_path, idx0, sid, cut_preprocess, process_effects, clean_dataset, clean_strength)
210
-
211
- def preprocess_training_set(input_root, sr, num_processes, exp_dir, per, cut_preprocess, process_effects, clean_dataset, clean_strength):
212
- start_time = time.time()
213
-
214
- pp = PreProcess(sr, exp_dir, per)
215
- logger.info(translations["start_preprocess"].format(num_processes=num_processes))
216
- files = []
217
- idx = 0
218
-
219
- for root, _, filenames in os.walk(input_root):
220
- try:
221
- sid = 0 if root == input_root else int(os.path.basename(root))
222
-
223
- for f in filenames:
224
- if f.lower().endswith(("wav", "mp3", "flac", "ogg", "opus", "m4a", "mp4", "aac", "alac", "wma", "aiff", "webm", "ac3")):
225
- files.append((os.path.join(root, f), idx, sid))
226
- idx += 1
227
- except ValueError:
228
- raise ValueError(f"{translations['not_integer']} '{os.path.basename(root)}'.")
229
-
230
- with tqdm(total=len(files), desc=translations["preprocess"], ncols=100, unit="f") as pbar:
231
- with ProcessPoolExecutor(max_workers=num_processes) as executor:
232
- futures = [executor.submit(process_file, (pp, file, cut_preprocess, process_effects, clean_dataset, clean_strength)) for file in files]
233
- for future in as_completed(futures):
234
- try:
235
- future.result()
236
- except Exception as e:
237
- raise RuntimeError(f"{translations['process_error']}: {e}")
238
- pbar.update(1)
239
- logger.debug(pbar.format_meter(pbar.n, pbar.total, pbar.format_dict["elapsed"]))
240
-
241
- elapsed_time = time.time() - start_time
242
- logger.info(translations["preprocess_success"].format(elapsed_time=f"{elapsed_time:.2f}"))
243
-
244
- if __name__ == "__main__":
245
- args = parse_arguments()
246
- experiment_directory = os.path.join("assets", "logs", args.model_name)
247
- num_processes = args.cpu_cores
248
- num_processes = mp.cpu_count() if num_processes is None else int(num_processes)
249
- dataset = args.dataset_path
250
- sample_rate = args.sample_rate
251
- cut_preprocess = args.cut_preprocess
252
- preprocess_effects = args.process_effects
253
- clean_dataset = args.clean_dataset
254
- clean_strength = args.clean_strength
255
-
256
- os.makedirs(experiment_directory, exist_ok=True)
257
-
258
- if logger.hasHandlers(): logger.handlers.clear()
259
- else:
260
- console_handler = logging.StreamHandler()
261
- console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
262
- console_handler.setFormatter(console_formatter)
263
- console_handler.setLevel(logging.INFO)
264
- file_handler = logging.handlers.RotatingFileHandler(os.path.join(experiment_directory, "preprocess.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
265
- file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
266
- file_handler.setFormatter(file_formatter)
267
- file_handler.setLevel(logging.DEBUG)
268
- logger.addHandler(console_handler)
269
- logger.addHandler(file_handler)
270
- logger.setLevel(logging.DEBUG)
271
-
272
- log_data = {translations['modelname']: args.model_name, translations['export_process']: experiment_directory, translations['dataset_folder']: dataset, translations['pretrain_sr']: sample_rate, translations['cpu_core']: num_processes, translations['split_audio']: cut_preprocess, translations['preprocess_effect']: preprocess_effects, translations['clear_audio']: clean_dataset}
273
- if clean_dataset: log_data[translations['clean_strength']] = clean_strength
274
-
275
- for key, value in log_data.items():
276
- logger.debug(f"{key}: {value}")
277
-
278
- pid_path = os.path.join(experiment_directory, "preprocess_pid.txt")
279
- with open(pid_path, "w") as pid_file:
280
- pid_file.write(str(os.getpid()))
281
-
282
- try:
283
- preprocess_training_set(dataset, sample_rate, num_processes, experiment_directory, 3.7, cut_preprocess, preprocess_effects, clean_dataset, clean_strength)
284
- except Exception as e:
285
- logger.error(f"{translations['process_audio_error']} {e}")
286
- import traceback
287
- logger.debug(traceback.format_exc())
288
-
289
- if os.path.exists(pid_path): os.remove(pid_path)
290
- logger.info(f"{translations['preprocess_model_success']} {args.model_name}")
main/inference/separator_music.py DELETED
@@ -1,290 +0,0 @@
1
- import os
2
- import sys
3
- import time
4
- import logging
5
- import argparse
6
- import logging.handlers
7
-
8
- from pydub import AudioSegment
9
- from distutils.util import strtobool
10
-
11
- sys.path.append(os.getcwd())
12
-
13
- from main.configs.config import Config
14
- from main.library.utils import pydub_convert
15
- from main.library.algorithm.separator import Separator
16
-
17
- translations = Config().translations
18
- logger = logging.getLogger(__name__)
19
-
20
- if logger.hasHandlers(): logger.handlers.clear()
21
- else:
22
- console_handler = logging.StreamHandler()
23
- console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
24
- console_handler.setFormatter(console_formatter)
25
- console_handler.setLevel(logging.INFO)
26
- file_handler = logging.handlers.RotatingFileHandler(os.path.join("assets", "logs", "separator.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
27
- file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
28
- file_handler.setFormatter(file_formatter)
29
- file_handler.setLevel(logging.DEBUG)
30
- logger.addHandler(console_handler)
31
- logger.addHandler(file_handler)
32
- logger.setLevel(logging.DEBUG)
33
-
34
- demucs_models = {"HT-Tuned": "htdemucs_ft.yaml", "HT-Normal": "htdemucs.yaml", "HD_MMI": "hdemucs_mmi.yaml", "HT_6S": "htdemucs_6s.yaml"}
35
- mdx_models = {"Main_340": "UVR-MDX-NET_Main_340.onnx", "Main_390": "UVR-MDX-NET_Main_390.onnx", "Main_406": "UVR-MDX-NET_Main_406.onnx", "Main_427": "UVR-MDX-NET_Main_427.onnx", "Main_438": "UVR-MDX-NET_Main_438.onnx", "Inst_full_292": "UVR-MDX-NET-Inst_full_292.onnx", "Inst_HQ_1": "UVR-MDX-NET_Inst_HQ_1.onnx", "Inst_HQ_2": "UVR-MDX-NET_Inst_HQ_2.onnx", "Inst_HQ_3": "UVR-MDX-NET_Inst_HQ_3.onnx", "Inst_HQ_4": "UVR-MDX-NET-Inst_HQ_4.onnx", "Inst_HQ_5": "UVR-MDX-NET-Inst_HQ_5.onnx", "Kim_Vocal_1": "Kim_Vocal_1.onnx", "Kim_Vocal_2": "Kim_Vocal_2.onnx", "Kim_Inst": "Kim_Inst.onnx", "Inst_187_beta": "UVR-MDX-NET_Inst_187_beta.onnx", "Inst_82_beta": "UVR-MDX-NET_Inst_82_beta.onnx", "Inst_90_beta": "UVR-MDX-NET_Inst_90_beta.onnx", "Voc_FT": "UVR-MDX-NET-Voc_FT.onnx", "Crowd_HQ": "UVR-MDX-NET_Crowd_HQ_1.onnx", "MDXNET_9482": "UVR_MDXNET_9482.onnx", "Inst_1": "UVR-MDX-NET-Inst_1.onnx", "Inst_2": "UVR-MDX-NET-Inst_2.onnx", "Inst_3": "UVR-MDX-NET-Inst_3.onnx", "MDXNET_1_9703": "UVR_MDXNET_1_9703.onnx", "MDXNET_2_9682": "UVR_MDXNET_2_9682.onnx", "MDXNET_3_9662": "UVR_MDXNET_3_9662.onnx", "Inst_Main": "UVR-MDX-NET-Inst_Main.onnx", "MDXNET_Main": "UVR_MDXNET_Main.onnx"}
36
- kara_models = {"Version-1": "UVR_MDXNET_KARA.onnx", "Version-2": "UVR_MDXNET_KARA_2.onnx"}
37
-
38
- def parse_arguments():
39
- parser = argparse.ArgumentParser()
40
- parser.add_argument("--input_path", type=str, required=True)
41
- parser.add_argument("--output_path", type=str, default="./audios")
42
- parser.add_argument("--format", type=str, default="wav")
43
- parser.add_argument("--shifts", type=int, default=2)
44
- parser.add_argument("--segments_size", type=int, default=256)
45
- parser.add_argument("--overlap", type=float, default=0.25)
46
- parser.add_argument("--mdx_hop_length", type=int, default=1024)
47
- parser.add_argument("--mdx_batch_size", type=int, default=1)
48
- parser.add_argument("--clean_audio", type=lambda x: bool(strtobool(x)), default=False)
49
- parser.add_argument("--clean_strength", type=float, default=0.7)
50
- parser.add_argument("--model_name", type=str, default="HT-Normal")
51
- parser.add_argument("--kara_model", type=str, default="Version-1")
52
- parser.add_argument("--backing", type=lambda x: bool(strtobool(x)), default=False)
53
- parser.add_argument("--mdx_denoise", type=lambda x: bool(strtobool(x)), default=False)
54
- parser.add_argument("--reverb", type=lambda x: bool(strtobool(x)), default=False)
55
- parser.add_argument("--backing_reverb", type=lambda x: bool(strtobool(x)), default=False)
56
- parser.add_argument("--sample_rate", type=int, default=44100)
57
-
58
- return parser.parse_args()
59
-
60
- def main():
61
- start_time = time.time()
62
- pid_path = os.path.join("assets", "separate_pid.txt")
63
-
64
- with open(pid_path, "w") as pid_file:
65
- pid_file.write(str(os.getpid()))
66
-
67
- try:
68
- args = parse_arguments()
69
- input_path, output_path, export_format, shifts, segments_size, overlap, hop_length, batch_size, clean_audio, clean_strength, model_name, kara_model, backing, mdx_denoise, reverb, backing_reverb, sample_rate = args.input_path, args.output_path, args.format, args.shifts, args.segments_size, args.overlap, args.mdx_hop_length, args.mdx_batch_size, args.clean_audio, args.clean_strength, args.model_name, args.kara_model, args.backing, args.mdx_denoise, args.reverb, args.backing_reverb, args.sample_rate
70
-
71
- if backing_reverb and not reverb:
72
- logger.warning(translations["turn_on_dereverb"])
73
- sys.exit(1)
74
-
75
- if backing_reverb and not backing:
76
- logger.warning(translations["turn_on_separator_backing"])
77
- sys.exit(1)
78
-
79
- log_data = {translations['audio_path']: input_path, translations['output_path']: output_path, translations['export_format']: export_format, translations['shift']: shifts, translations['segments_size']: segments_size, translations['overlap']: overlap, translations['modelname']: model_name, translations['denoise_mdx']: mdx_denoise, "Hop length": hop_length, translations['batch_size']: batch_size, translations['sr']: sample_rate}
80
-
81
- if clean_audio:
82
- log_data[translations['clear_audio']] = clean_audio
83
- log_data[translations['clean_strength']] = clean_strength
84
-
85
- if backing:
86
- log_data[translations['backing_model_ver']] = kara_model
87
- log_data[translations['separator_backing']] = backing
88
-
89
- if reverb:
90
- log_data[translations['dereveb_audio']] = reverb
91
- log_data[translations['dereveb_backing']] = backing_reverb
92
-
93
- for key, value in log_data.items():
94
- logger.debug(f"{key}: {value}")
95
-
96
- if model_name in ["HT-Tuned", "HT-Normal", "HD_MMI", "HT_6S"]: vocals, instruments = separator_music_demucs(input_path, output_path, export_format, shifts, overlap, segments_size, model_name, sample_rate)
97
- else: vocals, instruments = separator_music_mdx(input_path, output_path, export_format, segments_size, overlap, mdx_denoise, model_name, hop_length, batch_size, sample_rate)
98
-
99
- if backing: main_vocals, backing_vocals = separator_backing(vocals, output_path, export_format, segments_size, overlap, mdx_denoise, kara_model, hop_length, batch_size, sample_rate)
100
- if reverb: vocals_no_reverb, main_vocals_no_reverb, backing_vocals_no_reverb = separator_reverb(output_path, export_format, segments_size, overlap, mdx_denoise, reverb, backing_reverb, hop_length, batch_size, sample_rate)
101
-
102
- original_output = os.path.join(output_path, f"Original_Vocals_No_Reverb.{export_format}") if reverb else os.path.join(output_path, f"Original_Vocals.{export_format}")
103
- main_output = os.path.join(output_path, f"Main_Vocals_No_Reverb.{export_format}") if reverb and backing_reverb else os.path.join(output_path, f"Main_Vocals.{export_format}")
104
- backing_output = os.path.join(output_path, f"Backing_Vocals_No_Reverb.{export_format}") if reverb and backing_reverb else os.path.join(output_path, f"Backing_Vocals.{export_format}")
105
-
106
- if clean_audio:
107
- import soundfile as sf
108
- logger.info(f"{translations['clear_audio']}...")
109
-
110
- vocal_data, vocal_sr = sf.read(vocals_no_reverb if reverb else vocals)
111
- main_data, main_sr = sf.read(main_vocals_no_reverb if reverb and backing else main_vocals)
112
- backing_data, backing_sr = sf.read(backing_vocals_no_reverb if reverb and backing_reverb else backing_vocals)
113
-
114
- from main.tools.noisereduce import reduce_noise
115
- sf.write(original_output, reduce_noise(y=vocal_data, prop_decrease=clean_strength), vocal_sr, format=export_format)
116
-
117
- if backing:
118
- sf.write(main_output, reduce_noise(y=main_data, sr=main_sr, prop_decrease=clean_strength), main_sr, format=export_format)
119
- sf.write(backing_output, reduce_noise(y=backing_data, sr=backing_sr, prop_decrease=clean_strength), backing_sr, format=export_format)
120
-
121
- logger.info(translations["clean_audio_success"])
122
- return original_output, instruments, main_output, backing_output
123
- except Exception as e:
124
- logger.error(f"{translations['separator_error']}: {e}")
125
- import traceback
126
- logger.debug(traceback.format_exc())
127
-
128
- if os.path.exists(pid_path): os.remove(pid_path)
129
-
130
- elapsed_time = time.time() - start_time
131
- logger.info(translations["separator_success"].format(elapsed_time=f"{elapsed_time:.2f}"))
132
-
133
- def separator_music_demucs(input, output, format, shifts, overlap, segments_size, demucs_model, sample_rate):
134
- if not os.path.exists(input):
135
- logger.warning(translations["input_not_valid"])
136
- sys.exit(1)
137
-
138
- if not os.path.exists(output):
139
- logger.warning(translations["output_not_valid"])
140
- sys.exit(1)
141
-
142
- for i in [f"Original_Vocals.{format}", f"Instruments.{format}"]:
143
- if os.path.exists(os.path.join(output, i)): os.remove(os.path.join(output, i))
144
-
145
- logger.info(f"{translations['separator_process_2']}...")
146
- demucs_output = separator_main(audio_file=input, model_filename=demucs_models.get(demucs_model), output_format=format, output_dir=output, demucs_segment_size=(segments_size / 2), demucs_shifts=shifts, demucs_overlap=overlap, sample_rate=sample_rate)
147
-
148
- for f in demucs_output:
149
- path = os.path.join(output, f)
150
- if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
151
-
152
- if '_(Drums)_' in f: drums = path
153
- elif '_(Bass)_' in f: bass = path
154
- elif '_(Other)_' in f: other = path
155
- elif '_(Vocals)_' in f: os.rename(path, os.path.join(output, f"Original_Vocals.{format}"))
156
-
157
- pydub_convert(AudioSegment.from_file(drums)).overlay(pydub_convert(AudioSegment.from_file(bass))).overlay(pydub_convert(AudioSegment.from_file(other))).export(os.path.join(output, f"Instruments.{format}"), format=format)
158
-
159
- for f in [drums, bass, other]:
160
- if os.path.exists(f): os.remove(f)
161
-
162
- logger.info(translations["separator_success_2"])
163
- return os.path.join(output, f"Original_Vocals.{format}"), os.path.join(output, f"Instruments.{format}")
164
-
165
- def separator_backing(input, output, format, segments_size, overlap, denoise, kara_model, hop_length, batch_size, sample_rate):
166
- if not os.path.exists(input):
167
- logger.warning(translations["input_not_valid"])
168
- sys.exit(1)
169
-
170
- if not os.path.exists(output):
171
- logger.warning(translations["output_not_valid"])
172
- sys.exit(1)
173
-
174
- for f in [f"Main_Vocals.{format}", f"Backing_Vocals.{format}"]:
175
- if os.path.exists(os.path.join(output, f)): os.remove(os.path.join(output, f))
176
-
177
- model_2 = kara_models.get(kara_model)
178
- logger.info(f"{translations['separator_process_backing']}...")
179
-
180
- backing_outputs = separator_main(audio_file=input, model_filename=model_2, output_format=format, output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise, sample_rate=sample_rate)
181
- main_output = os.path.join(output, f"Main_Vocals.{format}")
182
- backing_output = os.path.join(output, f"Backing_Vocals.{format}")
183
-
184
- for f in backing_outputs:
185
- path = os.path.join(output, f)
186
- if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
187
- if '_(Instrumental)_' in f: os.rename(path, backing_output)
188
- elif '_(Vocals)_' in f: os.rename(path, main_output)
189
-
190
- logger.info(translations["separator_process_backing_success"])
191
- return main_output, backing_output
192
-
193
- def separator_music_mdx(input, output, format, segments_size, overlap, denoise, mdx_model, hop_length, batch_size, sample_rate):
194
- if not os.path.exists(input):
195
- logger.warning(translations["input_not_valid"])
196
- sys.exit(1)
197
-
198
- if not os.path.exists(output):
199
- logger.warning(translations["output_not_valid"])
200
- sys.exit(1)
201
-
202
- for i in [f"Original_Vocals.{format}", f"Instruments.{format}"]:
203
- if os.path.exists(os.path.join(output, i)): os.remove(os.path.join(output, i))
204
-
205
- model_3 = mdx_models.get(mdx_model)
206
- logger.info(f"{translations['separator_process_2']}...")
207
-
208
- output_music = separator_main(audio_file=input, model_filename=model_3, output_format=format, output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise, sample_rate=sample_rate)
209
- original_output, instruments_output = os.path.join(output, f"Original_Vocals.{format}"), os.path.join(output, f"Instruments.{format}")
210
-
211
- for f in output_music:
212
- path = os.path.join(output, f)
213
- if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
214
- if '_(Instrumental)_' in f: os.rename(path, instruments_output)
215
- elif '_(Vocals)_' in f: os.rename(path, original_output)
216
-
217
- logger.info(translations["separator_process_backing_success"])
218
- return original_output, instruments_output
219
-
220
- def separator_reverb(output, format, segments_size, overlap, denoise, original, backing_reverb, hop_length, batch_size, sample_rate):
221
- if not os.path.exists(output):
222
- logger.warning(translations["output_not_valid"])
223
- sys.exit(1)
224
-
225
- for i in [f"Original_Vocals_Reverb.{format}", f"Main_Vocals_Reverb.{format}", f"Original_Vocals_No_Reverb.{format}", f"Main_Vocals_No_Reverb.{format}"]:
226
- if os.path.exists(os.path.join(output, i)): os.remove(os.path.join(output, i))
227
-
228
- dereveb_path = []
229
-
230
- if original:
231
- try:
232
- dereveb_path.append(os.path.join(output, [f for f in os.listdir(output) if 'Original_Vocals' in f][0]))
233
- except IndexError:
234
- logger.warning(translations["not_found_original_vocal"])
235
- sys.exit(1)
236
-
237
- if backing_reverb:
238
- try:
239
- dereveb_path.append(os.path.join(output, [f for f in os.listdir(output) if 'Main_Vocals' in f][0]))
240
- except IndexError:
241
- logger.warning(translations["not_found_main_vocal"])
242
- sys.exit(1)
243
-
244
- if backing_reverb:
245
- try:
246
- dereveb_path.append(os.path.join(output, [f for f in os.listdir(output) if 'Backing_Vocals' in f][0]))
247
- except IndexError:
248
- logger.warning(translations["not_found_backing_vocal"])
249
- sys.exit(1)
250
-
251
- for path in dereveb_path:
252
- if not os.path.exists(path):
253
- logger.warning(translations["not_found"].format(name=path))
254
- sys.exit(1)
255
-
256
- if "Original_Vocals" in path:
257
- reverb_path, no_reverb_path = os.path.join(output, f"Original_Vocals_Reverb.{format}"), os.path.join(output, f"Original_Vocals_No_Reverb.{format}")
258
- start_title, end_title = translations["process_original"], translations["process_original_success"]
259
- elif "Main_Vocals" in path:
260
- reverb_path, no_reverb_path = os.path.join(output, f"Main_Vocals_Reverb.{format}"), os.path.join(output, f"Main_Vocals_No_Reverb.{format}")
261
- start_title, end_title = translations["process_main"], translations["process_main_success"]
262
- elif "Backing_Vocals" in path:
263
- reverb_path, no_reverb_path = os.path.join(output, f"Backing_Vocals_Reverb.{format}"), os.path.join(output, f"Backing_Vocals_No_Reverb.{format}")
264
- start_title, end_title = translations["process_backing"], translations["process_backing_success"]
265
-
266
- logger.info(start_title)
267
- output_dereveb = separator_main(audio_file=path, model_filename="Reverb_HQ_By_FoxJoy.onnx", output_format=format, output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise, sample_rate=sample_rate)
268
-
269
- for f in output_dereveb:
270
- path = os.path.join(output, f)
271
- if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
272
-
273
- if '_(Reverb)_' in f: os.rename(path, reverb_path)
274
- elif '_(No Reverb)_' in f: os.rename(path, no_reverb_path)
275
-
276
- logger.info(end_title)
277
- return (os.path.join(output, f"Original_Vocals_No_Reverb.{format}") if original else None), (os.path.join(output, f"Main_Vocals_No_Reverb.{format}") if backing_reverb else None), (os.path.join(output, f"Backing_Vocals_No_Reverb.{format}") if backing_reverb else None)
278
-
279
- def separator_main(audio_file=None, model_filename="UVR-MDX-NET_Main_340.onnx", output_format="wav", output_dir=".", mdx_segment_size=256, mdx_overlap=0.25, mdx_batch_size=1, mdx_hop_length=1024, mdx_enable_denoise=True, demucs_segment_size=256, demucs_shifts=2, demucs_overlap=0.25, sample_rate=44100):
280
- try:
281
- separator = Separator(logger=logger, log_formatter=file_formatter, log_level=logging.INFO, output_dir=output_dir, output_format=output_format, output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=sample_rate, mdx_params={"hop_length": mdx_hop_length, "segment_size": mdx_segment_size, "overlap": mdx_overlap, "batch_size": mdx_batch_size, "enable_denoise": mdx_enable_denoise}, demucs_params={"segment_size": demucs_segment_size, "shifts": demucs_shifts, "overlap": demucs_overlap, "segments_enabled": True})
282
- separator.load_model(model_filename=model_filename)
283
- return separator.separate(audio_file)
284
- except:
285
- logger.debug(translations["default_setting"])
286
- separator = Separator(logger=logger, log_formatter=file_formatter, log_level=logging.INFO, output_dir=output_dir, output_format=output_format, output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=44100, mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": mdx_enable_denoise}, demucs_params={"segment_size": 128, "shifts": 2, "overlap": 0.25, "segments_enabled": True})
287
- separator.load_model(model_filename=model_filename)
288
- return separator.separate(audio_file)
289
-
290
- if __name__ == "__main__": main()
main/inference/train.py DELETED
@@ -1,1000 +0,0 @@
1
- import os
2
- import sys
3
- import glob
4
- import json
5
- import torch
6
- import hashlib
7
- import logging
8
- import argparse
9
- import datetime
10
- import warnings
11
- import logging.handlers
12
-
13
- import numpy as np
14
- import soundfile as sf
15
- import matplotlib.pyplot as plt
16
- import torch.distributed as dist
17
- import torch.utils.data as tdata
18
- import torch.multiprocessing as mp
19
- import torch.utils.checkpoint as checkpoint
20
-
21
- from tqdm import tqdm
22
- from collections import OrderedDict
23
- from random import randint, shuffle
24
- from torch.cuda.amp import GradScaler, autocast
25
- from torch.utils.tensorboard import SummaryWriter
26
-
27
- from time import time as ttime
28
- from torch.nn import functional as F
29
- from distutils.util import strtobool
30
- from librosa.filters import mel as librosa_mel_fn
31
- from torch.nn.parallel import DistributedDataParallel as DDP
32
- from torch.nn.utils.parametrizations import spectral_norm, weight_norm
33
-
34
- sys.path.append(os.getcwd())
35
- from main.configs.config import Config
36
- from main.library.algorithm.residuals import LRELU_SLOPE
37
- from main.library.algorithm.synthesizers import Synthesizer
38
- from main.library.algorithm.commons import get_padding, slice_segments, clip_grad_value
39
-
40
- MATPLOTLIB_FLAG = False
41
- translations = Config().translations
42
- warnings.filterwarnings("ignore")
43
- logging.getLogger("torch").setLevel(logging.ERROR)
44
-
45
- class HParams:
46
- def __init__(self, **kwargs):
47
- for k, v in kwargs.items():
48
- self[k] = HParams(**v) if isinstance(v, dict) else v
49
-
50
- def keys(self):
51
- return self.__dict__.keys()
52
-
53
- def items(self):
54
- return self.__dict__.items()
55
-
56
- def values(self):
57
- return self.__dict__.values()
58
-
59
- def __len__(self):
60
- return len(self.__dict__)
61
-
62
- def __getitem__(self, key):
63
- return self.__dict__[key]
64
-
65
- def __setitem__(self, key, value):
66
- self.__dict__[key] = value
67
-
68
- def __contains__(self, key):
69
- return key in self.__dict__
70
-
71
- def __repr__(self):
72
- return repr(self.__dict__)
73
-
74
- def parse_arguments():
75
- parser = argparse.ArgumentParser()
76
- parser.add_argument("--model_name", type=str, required=True)
77
- parser.add_argument("--rvc_version", type=str, default="v2")
78
- parser.add_argument("--save_every_epoch", type=int, required=True)
79
- parser.add_argument("--save_only_latest", type=lambda x: bool(strtobool(x)), default=True)
80
- parser.add_argument("--save_every_weights", type=lambda x: bool(strtobool(x)), default=True)
81
- parser.add_argument("--total_epoch", type=int, default=300)
82
- parser.add_argument("--sample_rate", type=int, required=True)
83
- parser.add_argument("--batch_size", type=int, default=8)
84
- parser.add_argument("--gpu", type=str, default="0")
85
- parser.add_argument("--pitch_guidance", type=lambda x: bool(strtobool(x)), default=True)
86
- parser.add_argument("--g_pretrained_path", type=str, default="")
87
- parser.add_argument("--d_pretrained_path", type=str, default="")
88
- parser.add_argument("--overtraining_detector", type=lambda x: bool(strtobool(x)), default=False)
89
- parser.add_argument("--overtraining_threshold", type=int, default=50)
90
- parser.add_argument("--cleanup", type=lambda x: bool(strtobool(x)), default=False)
91
- parser.add_argument("--cache_data_in_gpu", type=lambda x: bool(strtobool(x)), default=False)
92
- parser.add_argument("--model_author", type=str)
93
- parser.add_argument("--vocoder", type=str, default="Default")
94
- parser.add_argument("--checkpointing", type=lambda x: bool(strtobool(x)), default=False)
95
-
96
- return parser.parse_args()
97
-
98
- args = parse_arguments()
99
- model_name, save_every_epoch, total_epoch, pretrainG, pretrainD, version, gpus, batch_size, sample_rate, pitch_guidance, save_only_latest, save_every_weights, cache_data_in_gpu, overtraining_detector, overtraining_threshold, cleanup, model_author, vocoder, checkpointing = args.model_name, args.save_every_epoch, args.total_epoch, args.g_pretrained_path, args.d_pretrained_path, args.rvc_version, args.gpu, args.batch_size, args.sample_rate, args.pitch_guidance, args.save_only_latest, args.save_every_weights, args.cache_data_in_gpu, args.overtraining_detector, args.overtraining_threshold, args.cleanup, args.model_author, args.vocoder, args.checkpointing
100
-
101
- experiment_dir = os.path.join("assets", "logs", model_name)
102
- training_file_path = os.path.join(experiment_dir, "training_data.json")
103
- config_save_path = os.path.join(experiment_dir, "config.json")
104
-
105
- os.environ["CUDA_VISIBLE_DEVICES"] = gpus.replace("-", ",")
106
- n_gpus = len(gpus.split("-"))
107
-
108
- torch.backends.cudnn.deterministic = False
109
- torch.backends.cudnn.benchmark = False
110
-
111
- lowest_value = {"step": 0, "value": float("inf"), "epoch": 0}
112
- global_step, last_loss_gen_all, overtrain_save_epoch = 0, 0, 0
113
- loss_gen_history, smoothed_loss_gen_history, loss_disc_history, smoothed_loss_disc_history = [], [], [], []
114
-
115
- with open(config_save_path, "r") as f:
116
- config = json.load(f)
117
-
118
- config = HParams(**config)
119
- config.data.training_files = os.path.join(experiment_dir, "filelist.txt")
120
- logger = logging.getLogger(__name__)
121
-
122
- if logger.hasHandlers(): logger.handlers.clear()
123
- else:
124
- console_handler = logging.StreamHandler()
125
- console_handler.setFormatter(logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
126
- console_handler.setLevel(logging.INFO)
127
- file_handler = logging.handlers.RotatingFileHandler(os.path.join(experiment_dir, "train.log"), maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
128
- file_handler.setFormatter(logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
129
- file_handler.setLevel(logging.DEBUG)
130
- logger.addHandler(console_handler)
131
- logger.addHandler(file_handler)
132
- logger.setLevel(logging.DEBUG)
133
-
134
- log_data = {translations['modelname']: model_name, translations["save_every_epoch"]: save_every_epoch, translations["total_e"]: total_epoch, translations["dorg"].format(pretrainG=pretrainG, pretrainD=pretrainD): "", translations['training_version']: version, "Gpu": gpus, translations['batch_size']: batch_size, translations['pretrain_sr']: sample_rate, translations['training_f0']: pitch_guidance, translations['save_only_latest']: save_only_latest, translations['save_every_weights']: save_every_weights, translations['cache_in_gpu']: cache_data_in_gpu, translations['overtraining_detector']: overtraining_detector, translations['threshold']: overtraining_threshold, translations['cleanup_training']: cleanup, translations['memory_efficient_training']: checkpointing}
135
- if model_author: log_data[translations["model_author"].format(model_author=model_author)] = ""
136
- if vocoder != "Default": log_data[translations['vocoder']] = vocoder
137
-
138
- for key, value in log_data.items():
139
- logger.debug(f"{key}: {value}" if value != "" else f"{key} {value}")
140
-
141
- def main():
142
- global training_file_path, last_loss_gen_all, smoothed_loss_gen_history, loss_gen_history, loss_disc_history, smoothed_loss_disc_history, overtrain_save_epoch, model_author, vocoder, checkpointing
143
-
144
- os.environ["MASTER_ADDR"] = "localhost"
145
- os.environ["MASTER_PORT"] = str(randint(20000, 55555))
146
-
147
- if torch.cuda.is_available(): device, n_gpus = torch.device("cuda"), torch.cuda.device_count()
148
- elif torch.backends.mps.is_available(): device, n_gpus = torch.device("mps"), 1
149
- else: device, n_gpus = torch.device("cpu"), 1
150
-
151
- def start():
152
- children = []
153
- pid_data = {"process_pids": []}
154
-
155
- with open(config_save_path, "r") as pid_file:
156
- try:
157
- pid_data.update(json.load(pid_file))
158
- except json.JSONDecodeError:
159
- pass
160
-
161
- with open(config_save_path, "w") as pid_file:
162
- for i in range(n_gpus):
163
- subproc = mp.Process(target=run, args=(i, n_gpus, experiment_dir, pretrainG, pretrainD, pitch_guidance, total_epoch, save_every_weights, config, device, model_author, vocoder, checkpointing))
164
- children.append(subproc)
165
-
166
- subproc.start()
167
- pid_data["process_pids"].append(subproc.pid)
168
-
169
- json.dump(pid_data, pid_file, indent=4)
170
-
171
- for i in range(n_gpus):
172
- children[i].join()
173
-
174
- def load_from_json(file_path):
175
- if os.path.exists(file_path):
176
- with open(file_path, "r") as f:
177
- data = json.load(f)
178
- return (data.get("loss_disc_history", []), data.get("smoothed_loss_disc_history", []), data.get("loss_gen_history", []), data.get("smoothed_loss_gen_history", []))
179
-
180
- return [], [], [], []
181
-
182
- def continue_overtrain_detector(training_file_path):
183
- if overtraining_detector and os.path.exists(training_file_path): (loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history) = load_from_json(training_file_path)
184
-
185
- n_gpus = torch.cuda.device_count()
186
-
187
- if not torch.cuda.is_available() and torch.backends.mps.is_available(): n_gpus = 1
188
- if n_gpus < 1:
189
- logger.warning(translations["not_gpu"])
190
- n_gpus = 1
191
-
192
- if cleanup:
193
- for root, dirs, files in os.walk(experiment_dir, topdown=False):
194
- for name in files:
195
- file_path = os.path.join(root, name)
196
- _, file_extension = os.path.splitext(name)
197
- if (file_extension == ".0" or (name.startswith("D_") and file_extension == ".pth") or (name.startswith("G_") and file_extension == ".pth") or (file_extension == ".index")): os.remove(file_path)
198
-
199
- for name in dirs:
200
- if name == "eval":
201
- folder_path = os.path.join(root, name)
202
-
203
- for item in os.listdir(folder_path):
204
- item_path = os.path.join(folder_path, item)
205
- if os.path.isfile(item_path): os.remove(item_path)
206
-
207
- os.rmdir(folder_path)
208
-
209
- continue_overtrain_detector(training_file_path)
210
- start()
211
-
212
- def plot_spectrogram_to_numpy(spectrogram):
213
- global MATPLOTLIB_FLAG
214
-
215
- if not MATPLOTLIB_FLAG:
216
- plt.switch_backend("Agg")
217
- MATPLOTLIB_FLAG = True
218
-
219
- fig, ax = plt.subplots(figsize=(10, 2))
220
-
221
- plt.colorbar(ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none"), ax=ax)
222
- plt.xlabel("Frames")
223
- plt.ylabel("Channels")
224
- plt.tight_layout()
225
- fig.canvas.draw()
226
- plt.close(fig)
227
-
228
- return np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(fig.canvas.get_width_height()[::-1] + (3,))
229
-
230
- def verify_checkpoint_shapes(checkpoint_path, model):
231
- checkpoint = torch.load(checkpoint_path, map_location="cpu")
232
- checkpoint_state_dict = checkpoint["model"]
233
- try:
234
- model_state_dict = model.module.load_state_dict(checkpoint_state_dict) if hasattr(model, "module") else model.load_state_dict(checkpoint_state_dict)
235
- except RuntimeError:
236
- logger.error(translations["checkpointing_err"])
237
- sys.exit(1)
238
- else: del checkpoint, checkpoint_state_dict, model_state_dict
239
-
240
- def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sample_rate=22050):
241
- for k, v in scalars.items():
242
- writer.add_scalar(k, v, global_step)
243
-
244
- for k, v in histograms.items():
245
- writer.add_histogram(k, v, global_step)
246
-
247
- for k, v in images.items():
248
- writer.add_image(k, v, global_step, dataformats="HWC")
249
-
250
- for k, v in audios.items():
251
- writer.add_audio(k, v, global_step, audio_sample_rate)
252
-
253
- def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
254
- assert os.path.isfile(checkpoint_path), translations["not_found_checkpoint"].format(checkpoint_path=checkpoint_path)
255
- checkpoint_dict = replace_keys_in_dict(replace_keys_in_dict(torch.load(checkpoint_path, map_location="cpu"), ".weight_v", ".parametrizations.weight.original1"), ".weight_g", ".parametrizations.weight.original0")
256
- new_state_dict = {k: checkpoint_dict["model"].get(k, v) for k, v in (model.module.state_dict() if hasattr(model, "module") else model.state_dict()).items()}
257
-
258
- if hasattr(model, "module"): model.module.load_state_dict(new_state_dict, strict=False)
259
- else: model.load_state_dict(new_state_dict, strict=False)
260
-
261
- if optimizer and load_opt == 1: optimizer.load_state_dict(checkpoint_dict.get("optimizer", {}))
262
- logger.debug(translations["save_checkpoint"].format(checkpoint_path=checkpoint_path, checkpoint_dict=checkpoint_dict['iteration']))
263
- return (model, optimizer, checkpoint_dict.get("learning_rate", 0), checkpoint_dict["iteration"])
264
-
265
- def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
266
- state_dict = (model.module.state_dict() if hasattr(model, "module") else model.state_dict())
267
- torch.save(replace_keys_in_dict(replace_keys_in_dict({"model": state_dict, "iteration": iteration, "optimizer": optimizer.state_dict(), "learning_rate": learning_rate}, ".parametrizations.weight.original1", ".weight_v"), ".parametrizations.weight.original0", ".weight_g"), checkpoint_path)
268
- logger.info(translations["save_model"].format(checkpoint_path=checkpoint_path, iteration=iteration))
269
-
270
- def latest_checkpoint_path(dir_path, regex="G_*.pth"):
271
- checkpoints = sorted(glob.glob(os.path.join(dir_path, regex)), key=lambda f: int("".join(filter(str.isdigit, f))))
272
- return checkpoints[-1] if checkpoints else None
273
-
274
- def load_wav_to_torch(full_path):
275
- data, sample_rate = sf.read(full_path, dtype='float32')
276
- return torch.FloatTensor(data.astype(np.float32)), sample_rate
277
-
278
- def load_filepaths_and_text(filename, split="|"):
279
- with open(filename, encoding="utf-8") as f:
280
- return [line.strip().split(split) for line in f]
281
-
282
- def feature_loss(fmap_r, fmap_g):
283
- loss = 0
284
- for dr, dg in zip(fmap_r, fmap_g):
285
- for rl, gl in zip(dr, dg):
286
- loss += torch.mean(torch.abs(rl.float().detach() - gl.float()))
287
- return loss * 2
288
-
289
- def discriminator_loss(disc_real_outputs, disc_generated_outputs):
290
- loss = 0
291
- r_losses, g_losses = [], []
292
-
293
- for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
294
- dr = dr.float()
295
- dg = dg.float()
296
- r_loss = torch.mean((1 - dr) ** 2)
297
- g_loss = torch.mean(dg**2)
298
- loss += r_loss + g_loss
299
- r_losses.append(r_loss.item())
300
- g_losses.append(g_loss.item())
301
- return loss, r_losses, g_losses
302
-
303
- def generator_loss(disc_outputs):
304
- loss = 0
305
- gen_losses = []
306
-
307
- for dg in disc_outputs:
308
- l = torch.mean((1 - dg.float()) ** 2)
309
- gen_losses.append(l)
310
- loss += l
311
- return loss, gen_losses
312
-
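Editor's aside on the three loss helpers above: `discriminator_loss` and `generator_loss` follow the least-squares GAN formulation (real targets 1, fake targets 0), and `feature_loss` is an L1 distance between real and generated discriminator feature maps scaled by 2. The snippet below is a minimal, self-contained sanity check; the tensor shapes and values are illustrative assumptions, not taken from the deleted code.

```python
import torch

# Least-squares GAN terms, matching the helpers above:
#   D loss: mean((1 - D(real))^2) + mean(D(fake)^2)
#   G loss: mean((1 - D(fake))^2)
def lsgan_d_loss(d_real, d_fake):
    return torch.mean((1 - d_real) ** 2) + torch.mean(d_fake ** 2)

def lsgan_g_loss(d_fake):
    return torch.mean((1 - d_fake) ** 2)

# Illustrative shapes only: one discriminator output of 8 scores.
d_real = torch.full((8,), 0.9)   # close to 1 -> small real-term loss
d_fake = torch.full((8,), 0.1)   # close to 0 -> small fake-term loss
print(lsgan_d_loss(d_real, d_fake))  # tensor(0.0200)
print(lsgan_g_loss(d_fake))          # tensor(0.8100)
```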
313
- def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
314
- z_p = z_p.float()
315
- logs_q = logs_q.float()
316
- m_p = m_p.float()
317
- logs_p = logs_p.float()
318
- z_mask = z_mask.float()
319
- kl = logs_p - logs_q - 0.5
320
- kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
321
- return torch.sum(kl * z_mask) / torch.sum(z_mask)
322
-
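The `kl_loss` above appears to implement the usual VITS-style single-sample estimate of KL(q‖p) between diagonal Gaussians, evaluated at the flow-mapped latent `z_p` and averaged over masked (valid) frames, assuming `logs_q` / `logs_p` are log standard deviations. A minimal sketch with made-up shapes, showing the per-element term goes to zero in expectation when q and p coincide:

```python
import torch

def kl_elementwise(z_p, logs_q, m_p, logs_p):
    # Single-sample estimate of KL(q || p) per element, same form as kl_loss above:
    # log(sigma_p) - log(sigma_q) - 1/2 + (z - m_p)^2 / (2 sigma_p^2), with z ~ q.
    return logs_p - logs_q - 0.5 + 0.5 * (z_p - m_p) ** 2 * torch.exp(-2.0 * logs_p)

torch.manual_seed(0)
m, logs = torch.zeros(100_000), torch.zeros(100_000)   # q = p = N(0, 1)
z = m + torch.randn_like(m) * torch.exp(logs)          # sample from q
print(kl_elementwise(z, logs, m, logs).mean())         # ~0, since KL(p || p) = 0
```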
323
- class TextAudioLoaderMultiNSFsid(tdata.Dataset):
324
- def __init__(self, hparams):
325
- self.audiopaths_and_text = load_filepaths_and_text(hparams.training_files)
326
- self.max_wav_value = hparams.max_wav_value
327
- self.sample_rate = hparams.sample_rate
328
- self.filter_length = hparams.filter_length
329
- self.hop_length = hparams.hop_length
330
- self.win_length = hparams.win_length
331
- self.sample_rate = hparams.sample_rate
332
- self.min_text_len = getattr(hparams, "min_text_len", 1)
333
- self.max_text_len = getattr(hparams, "max_text_len", 5000)
334
- self._filter()
335
-
336
- def _filter(self):
337
- audiopaths_and_text_new, lengths = [], []
338
- for audiopath, text, pitch, pitchf, dv in self.audiopaths_and_text:
339
- if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
340
- audiopaths_and_text_new.append([audiopath, text, pitch, pitchf, dv])
341
- lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
342
-
343
- self.audiopaths_and_text = audiopaths_and_text_new
344
- self.lengths = lengths
345
-
346
- def get_sid(self, sid):
347
- try:
348
- sid = torch.LongTensor([int(sid)])
349
- except ValueError as e:
350
- logger.error(translations["sid_error"].format(sid=sid, e=e))
351
- sid = torch.LongTensor([0])
352
- return sid
353
-
354
- def get_audio_text_pair(self, audiopath_and_text):
355
- phone, pitch, pitchf = self.get_labels(audiopath_and_text[1], audiopath_and_text[2], audiopath_and_text[3])
356
- spec, wav = self.get_audio(audiopath_and_text[0])
357
- dv = self.get_sid(audiopath_and_text[4])
358
- len_phone = phone.size()[0]
359
- len_spec = spec.size()[-1]
360
-
361
- if len_phone != len_spec:
362
- len_min = min(len_phone, len_spec)
363
- len_wav = len_min * self.hop_length
364
- spec, wav, phone = spec[:, :len_min], wav[:, :len_wav], phone[:len_min, :]
365
- pitch, pitchf = pitch[:len_min], pitchf[:len_min]
366
- return (spec, wav, phone, pitch, pitchf, dv)
367
-
368
- def get_labels(self, phone, pitch, pitchf):
369
- phone = np.repeat(np.load(phone), 2, axis=0)
370
- n_num = min(phone.shape[0], 900)
371
- return torch.FloatTensor(phone[:n_num, :]), torch.LongTensor(np.load(pitch)[:n_num]), torch.FloatTensor(np.load(pitchf)[:n_num])
372
-
373
- def get_audio(self, filename):
374
- audio, sample_rate = load_wav_to_torch(filename)
375
- if sample_rate != self.sample_rate: raise ValueError(translations["sr_does_not_match"].format(sample_rate=sample_rate, sample_rate2=self.sample_rate))
376
- audio_norm = audio.unsqueeze(0)
377
- spec_filename = filename.replace(".wav", ".spec.pt")
378
-
379
- if os.path.exists(spec_filename):
380
- try:
381
- spec = torch.load(spec_filename)
382
- except Exception as e:
383
- logger.error(translations["spec_error"].format(spec_filename=spec_filename, e=e))
384
- spec = torch.squeeze(spectrogram_torch(audio_norm, self.filter_length, self.hop_length, self.win_length, center=False), 0)
385
- torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
386
- else:
387
- spec = torch.squeeze(spectrogram_torch(audio_norm, self.filter_length, self.hop_length, self.win_length, center=False), 0)
388
- torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
389
- return spec, audio_norm
390
-
391
- def __getitem__(self, index):
392
- return self.get_audio_text_pair(self.audiopaths_and_text[index])
393
-
394
- def __len__(self):
395
- return len(self.audiopaths_and_text)
396
-
397
- class TextAudioCollateMultiNSFsid:
398
- def __init__(self, return_ids=False):
399
- self.return_ids = return_ids
400
-
401
- def __call__(self, batch):
402
- _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True)
403
- spec_lengths, wave_lengths = torch.LongTensor(len(batch)), torch.LongTensor(len(batch))
404
- spec_padded, wave_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max([x[0].size(1) for x in batch])), torch.FloatTensor(len(batch), 1, max([x[1].size(1) for x in batch]))
405
- spec_padded.zero_()
406
- wave_padded.zero_()
407
- max_phone_len = max([x[2].size(0) for x in batch])
408
- phone_lengths, phone_padded = torch.LongTensor(len(batch)), torch.FloatTensor(len(batch), max_phone_len, batch[0][2].shape[1])
409
- pitch_padded, pitchf_padded = torch.LongTensor(len(batch), max_phone_len), torch.FloatTensor(len(batch), max_phone_len)
410
- phone_padded.zero_()
411
- pitch_padded.zero_()
412
- pitchf_padded.zero_()
413
- sid = torch.LongTensor(len(batch))
414
-
415
- for i in range(len(ids_sorted_decreasing)):
416
- row = batch[ids_sorted_decreasing[i]]
417
- spec = row[0]
418
- spec_padded[i, :, : spec.size(1)] = spec
419
- spec_lengths[i] = spec.size(1)
420
- wave = row[1]
421
- wave_padded[i, :, : wave.size(1)] = wave
422
- wave_lengths[i] = wave.size(1)
423
- phone = row[2]
424
- phone_padded[i, : phone.size(0), :] = phone
425
- phone_lengths[i] = phone.size(0)
426
- pitch = row[3]
427
- pitch_padded[i, : pitch.size(0)] = pitch
428
- pitchf = row[4]
429
- pitchf_padded[i, : pitchf.size(0)] = pitchf
430
- sid[i] = row[5]
431
- return (phone_padded, phone_lengths, pitch_padded, pitchf_padded, spec_padded, spec_lengths, wave_padded, wave_lengths, sid)
432
-
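A short note on the collate class above: items are ordered by spectrogram length (descending) and every field is copied into a zero tensor sized to the longest item in the batch. The sketch below illustrates just that padding idea; the shapes are invented for the example and are not the training shapes.

```python
import torch

def pad_to_max(tensors):
    # Sort by length (descending), then copy each item into a zero tensor
    # sized to the longest item in the batch, as the collate classes do above.
    lengths = torch.LongTensor([t.size(-1) for t in tensors])
    order = torch.argsort(lengths, descending=True)
    padded = torch.zeros(len(tensors), tensors[0].size(0), int(lengths.max()))
    for row, idx in enumerate(order):
        t = tensors[idx]
        padded[row, :, : t.size(-1)] = t
    return padded, lengths[order]

specs = [torch.randn(80, n) for n in (120, 95, 140)]  # fake (channels, frames) specs
padded, lens = pad_to_max(specs)
print(padded.shape, lens)  # torch.Size([3, 80, 140]) tensor([140, 120,  95])
```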
433
- class TextAudioLoader(tdata.Dataset):
434
- def __init__(self, hparams):
435
- self.audiopaths_and_text = load_filepaths_and_text(hparams.training_files)
436
- self.max_wav_value = hparams.max_wav_value
437
- self.sample_rate = hparams.sample_rate
438
- self.filter_length = hparams.filter_length
439
- self.hop_length = hparams.hop_length
440
- self.win_length = hparams.win_length
441
- self.sample_rate = hparams.sample_rate
442
- self.min_text_len = getattr(hparams, "min_text_len", 1)
443
- self.max_text_len = getattr(hparams, "max_text_len", 5000)
444
- self._filter()
445
-
446
- def _filter(self):
447
- audiopaths_and_text_new, lengths = [], []
448
- for entry in self.audiopaths_and_text:
449
- if len(entry) >= 3:
450
- audiopath, text, dv = entry[:3]
451
- if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
452
- audiopaths_and_text_new.append([audiopath, text, dv])
453
- lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
454
-
455
- self.audiopaths_and_text = audiopaths_and_text_new
456
- self.lengths = lengths
457
-
458
- def get_sid(self, sid):
459
- try:
460
- sid = torch.LongTensor([int(sid)])
461
- except ValueError as e:
462
- logger.error(translations["sid_error"].format(sid=sid, e=e))
463
- sid = torch.LongTensor([0])
464
- return sid
465
-
466
- def get_audio_text_pair(self, audiopath_and_text):
467
- phone = self.get_labels(audiopath_and_text[1])
468
- spec, wav = self.get_audio(audiopath_and_text[0])
469
- dv = self.get_sid(audiopath_and_text[2])
470
- len_phone = phone.size()[0]
471
- len_spec = spec.size()[-1]
472
-
473
- if len_phone != len_spec:
474
- len_min = min(len_phone, len_spec)
475
- len_wav = len_min * self.hop_length
476
- spec = spec[:, :len_min]
477
- wav = wav[:, :len_wav]
478
- phone = phone[:len_min, :]
479
- return (spec, wav, phone, dv)
480
-
481
- def get_labels(self, phone):
482
- phone = np.repeat(np.load(phone), 2, axis=0)
483
- return torch.FloatTensor(phone[:min(phone.shape[0], 900), :])
484
-
485
- def get_audio(self, filename):
486
- audio, sample_rate = load_wav_to_torch(filename)
487
- if sample_rate != self.sample_rate: raise ValueError(translations["sr_does_not_match"].format(sample_rate=sample_rate, sample_rate2=self.sample_rate))
488
- audio_norm = audio.unsqueeze(0)
489
- spec_filename = filename.replace(".wav", ".spec.pt")
490
-
491
- if os.path.exists(spec_filename):
492
- try:
493
- spec = torch.load(spec_filename)
494
- except Exception as e:
495
- logger.error(translations["spec_error"].format(spec_filename=spec_filename, e=e))
496
- spec = torch.squeeze(spectrogram_torch(audio_norm, self.filter_length, self.hop_length, self.win_length, center=False), 0)
497
- torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
498
- else:
499
- spec = torch.squeeze(spectrogram_torch(audio_norm, self.filter_length, self.hop_length, self.win_length, center=False), 0)
500
- torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
501
- return spec, audio_norm
502
-
503
- def __getitem__(self, index):
504
- return self.get_audio_text_pair(self.audiopaths_and_text[index])
505
-
506
- def __len__(self):
507
- return len(self.audiopaths_and_text)
508
-
509
- class TextAudioCollate:
510
- def __init__(self, return_ids=False):
511
- self.return_ids = return_ids
512
-
513
- def __call__(self, batch):
514
- _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True)
515
- spec_lengths, wave_lengths = torch.LongTensor(len(batch)), torch.LongTensor(len(batch))
516
- spec_padded, wave_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max([x[0].size(1) for x in batch])), torch.FloatTensor(len(batch), 1, max([x[1].size(1) for x in batch]))
517
- spec_padded.zero_()
518
- wave_padded.zero_()
519
- max_phone_len = max([x[2].size(0) for x in batch])
520
- phone_lengths, phone_padded = torch.LongTensor(len(batch)), torch.FloatTensor(len(batch), max_phone_len, batch[0][2].shape[1])
521
- phone_padded.zero_()
522
- sid = torch.LongTensor(len(batch))
523
- for i in range(len(ids_sorted_decreasing)):
524
- row = batch[ids_sorted_decreasing[i]]
525
- spec = row[0]
526
- spec_padded[i, :, : spec.size(1)] = spec
527
- spec_lengths[i] = spec.size(1)
528
- wave = row[1]
529
- wave_padded[i, :, : wave.size(1)] = wave
530
- wave_lengths[i] = wave.size(1)
531
- phone = row[2]
532
- phone_padded[i, : phone.size(0), :] = phone
533
- phone_lengths[i] = phone.size(0)
534
- sid[i] = row[3]
535
- return (phone_padded, phone_lengths, spec_padded, spec_lengths, wave_padded, wave_lengths, sid)
536
-
537
- class DistributedBucketSampler(tdata.distributed.DistributedSampler):
538
- def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
539
- super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
540
- self.lengths = dataset.lengths
541
- self.batch_size = batch_size
542
- self.boundaries = boundaries
543
- self.buckets, self.num_samples_per_bucket = self._create_buckets()
544
- self.total_size = sum(self.num_samples_per_bucket)
545
- self.num_samples = self.total_size // self.num_replicas
546
-
547
- def _create_buckets(self):
548
- buckets = [[] for _ in range(len(self.boundaries) - 1)]
549
- for i in range(len(self.lengths)):
550
- idx_bucket = self._bisect(self.lengths[i])
551
- if idx_bucket != -1: buckets[idx_bucket].append(i)
552
-
553
- for i in range(len(buckets) - 1, -1, -1):
554
- if len(buckets[i]) == 0:
555
- buckets.pop(i)
556
- self.boundaries.pop(i + 1)
557
-
558
- num_samples_per_bucket = []
559
- for i in range(len(buckets)):
560
- len_bucket = len(buckets[i])
561
- total_batch_size = self.num_replicas * self.batch_size
562
- num_samples_per_bucket.append(len_bucket + ((total_batch_size - (len_bucket % total_batch_size)) % total_batch_size))
563
- return buckets, num_samples_per_bucket
564
-
565
- def __iter__(self):
566
- g = torch.Generator()
567
- g.manual_seed(self.epoch)
568
- indices, batches = [], []
569
- if self.shuffle:
570
- for bucket in self.buckets:
571
- indices.append(torch.randperm(len(bucket), generator=g).tolist())
572
- else:
573
- for bucket in self.buckets:
574
- indices.append(list(range(len(bucket))))
575
-
576
- for i in range(len(self.buckets)):
577
- bucket = self.buckets[i]
578
- len_bucket = len(bucket)
579
- ids_bucket = indices[i]
580
- rem = self.num_samples_per_bucket[i] - len_bucket
581
- ids_bucket = (ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[: (rem % len_bucket)])[self.rank :: self.num_replicas]
582
-
583
- for j in range(len(ids_bucket) // self.batch_size):
584
- batches.append([bucket[idx] for idx in ids_bucket[j * self.batch_size : (j + 1) * self.batch_size]])
585
-
586
- if self.shuffle: batches = [batches[i] for i in torch.randperm(len(batches), generator=g).tolist()]
587
- self.batches = batches
588
- assert len(self.batches) * self.batch_size == self.num_samples
589
- return iter(self.batches)
590
-
591
- def _bisect(self, x, lo=0, hi=None):
592
- if hi is None: hi = len(self.boundaries) - 1
593
-
594
- if hi > lo:
595
- mid = (hi + lo) // 2
596
- if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: return mid
597
- elif x <= self.boundaries[mid]: return self._bisect(x, lo, mid)
598
- else: return self._bisect(x, mid + 1, hi)
599
- else: return -1
600
-
601
- def __len__(self):
602
- return self.num_samples // self.batch_size
603
-
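The bucket sampler above groups training items of similar length so batches waste little padding: each length is assigned to the bucket (boundaries[i], boundaries[i+1]] via binary search, and lengths outside all boundaries are discarded. A minimal sketch of that assignment, with illustrative boundaries and lengths:

```python
import bisect

boundaries = [100, 200, 300, 400, 500]
lengths = [120, 480, 90, 250, 333]

def bucket_of(length):
    # Same rule as DistributedBucketSampler._bisect: bucket i covers
    # (boundaries[i], boundaries[i + 1]]; -1 means the sample is dropped.
    idx = bisect.bisect_left(boundaries, length)
    return idx - 1 if 0 < idx < len(boundaries) else -1

print([bucket_of(l) for l in lengths])  # [0, 3, -1, 1, 2]
```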
604
- class MultiPeriodDiscriminator(torch.nn.Module):
605
- def __init__(self, version, use_spectral_norm=False, checkpointing=False):
606
- super(MultiPeriodDiscriminator, self).__init__()
607
- self.checkpointing = checkpointing
608
- periods = ([2, 3, 5, 7, 11, 17] if version == "v1" else [2, 3, 5, 7, 11, 17, 23, 37])
609
- self.discriminators = torch.nn.ModuleList([DiscriminatorS(use_spectral_norm=use_spectral_norm, checkpointing=checkpointing)] + [DiscriminatorP(p, use_spectral_norm=use_spectral_norm, checkpointing=checkpointing) for p in periods])
610
-
611
- def forward(self, y, y_hat):
612
- y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
613
- for d in self.discriminators:
614
- if self.training and self.checkpointing:
615
- def forward_discriminator(d, y, y_hat):
616
- y_d_r, fmap_r = d(y)
617
- y_d_g, fmap_g = d(y_hat)
618
- return y_d_r, fmap_r, y_d_g, fmap_g
619
- y_d_r, fmap_r, y_d_g, fmap_g = checkpoint.checkpoint(forward_discriminator, d, y, y_hat, use_reentrant=False)
620
- else:
621
- y_d_r, fmap_r = d(y)
622
- y_d_g, fmap_g = d(y_hat)
623
-
624
- y_d_rs.append(y_d_r)
625
- y_d_gs.append(y_d_g)
626
- fmap_rs.append(fmap_r)
627
- fmap_gs.append(fmap_g)
628
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
629
-
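The discriminators above optionally wrap their forward passes in activation checkpointing: during training, intermediate activations are recomputed in the backward pass instead of being stored, trading compute for memory. A tiny self-contained sketch of the same pattern (the module and shapes are illustrative):

```python
import torch
import torch.utils.checkpoint as checkpoint

layer = torch.nn.Conv1d(1, 4, kernel_size=5, padding=2)
x = torch.randn(2, 1, 1000, requires_grad=True)

# Same call pattern as in MultiPeriodDiscriminator.forward above.
y = checkpoint.checkpoint(layer, x, use_reentrant=False)  # same output as layer(x)
y.sum().backward()                                        # activations recomputed here
print(x.grad.shape)  # torch.Size([2, 1, 1000])
```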
630
- class DiscriminatorS(torch.nn.Module):
631
- def __init__(self, use_spectral_norm=False, checkpointing=False):
632
- super(DiscriminatorS, self).__init__()
633
- self.checkpointing = checkpointing
634
- norm_f = spectral_norm if use_spectral_norm else weight_norm
635
- self.convs = torch.nn.ModuleList([norm_f(torch.nn.Conv1d(1, 16, 15, 1, padding=7)), norm_f(torch.nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)), norm_f(torch.nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)), norm_f(torch.nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)), norm_f(torch.nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), norm_f(torch.nn.Conv1d(1024, 1024, 5, 1, padding=2))])
636
- self.conv_post = norm_f(torch.nn.Conv1d(1024, 1, 3, 1, padding=1))
637
- self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE)
638
-
639
- def forward(self, x):
640
- fmap = []
641
- for conv in self.convs:
642
- x = checkpoint.checkpoint(self.lrelu, checkpoint.checkpoint(conv, x, use_reentrant = False), use_reentrant = False) if self.training and self.checkpointing else self.lrelu(conv(x))
643
- fmap.append(x)
644
-
645
- x = self.conv_post(x)
646
- fmap.append(x)
647
- return torch.flatten(x, 1, -1), fmap
648
-
649
- class DiscriminatorP(torch.nn.Module):
650
- def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, checkpointing=False):
651
- super(DiscriminatorP, self).__init__()
652
- self.period = period
653
- self.checkpointing = checkpointing
654
- norm_f = spectral_norm if use_spectral_norm else weight_norm
655
- self.convs = torch.nn.ModuleList([norm_f(torch.nn.Conv2d(in_ch, out_ch, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))) for in_ch, out_ch in zip([1, 32, 128, 512, 1024], [32, 128, 512, 1024, 1024])])
656
- self.conv_post = norm_f(torch.nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
657
- self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE)
658
-
659
- def forward(self, x):
660
- fmap = []
661
- b, c, t = x.shape
662
- if t % self.period != 0: x = torch.nn.functional.pad(x, (0, (self.period - (t % self.period))), "reflect")
663
- x = x.view(b, c, -1, self.period)
664
- for conv in self.convs:
665
- x = checkpoint.checkpoint(self.lrelu, checkpoint.checkpoint(conv, x, use_reentrant = False), use_reentrant = False) if self.training and self.checkpointing else self.lrelu(conv(x))
666
- fmap.append(x)
667
-
668
- x = self.conv_post(x)
669
- fmap.append(x)
670
- return torch.flatten(x, 1, -1), fmap
671
-
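DiscriminatorP above folds the 1-D waveform into a 2-D grid so that 2-D convolutions can look across periods: the signal is reflect-padded to a multiple of `period` and reshaped to (time // period, period). A minimal sketch of just that reshaping, with an illustrative input length:

```python
import torch
import torch.nn.functional as F

def fold_by_period(x, period):
    # Pad the time axis to a multiple of `period`, then view it as a 2-D grid.
    b, c, t = x.shape
    if t % period != 0:
        x = F.pad(x, (0, period - t % period), "reflect")
    return x.view(b, c, -1, period)

x = torch.randn(1, 1, 16000)          # illustrative 1-second mono signal
print(fold_by_period(x, 3).shape)     # torch.Size([1, 1, 5334, 3])
```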
672
- class EpochRecorder:
673
- def __init__(self):
674
- self.last_time = ttime()
675
-
676
- def record(self):
677
- now_time = ttime()
678
- elapsed_time = now_time - self.last_time
679
- self.last_time = now_time
680
- return translations["time_or_speed_training"].format(current_time=datetime.datetime.now().strftime("%H:%M:%S"), elapsed_time_str=str(datetime.timedelta(seconds=int(round(elapsed_time, 1)))))
681
-
682
- def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
683
- return torch.log(torch.clamp(x, min=clip_val) * C)
684
-
685
- def dynamic_range_decompression_torch(x, C=1):
686
- return torch.exp(x) / C
687
-
688
- def spectral_normalize_torch(magnitudes):
689
- return dynamic_range_compression_torch(magnitudes)
690
-
691
- def spectral_de_normalize_torch(magnitudes):
692
- return dynamic_range_decompression_torch(magnitudes)
693
-
694
- mel_basis, hann_window = {}, {}
695
-
696
- def spectrogram_torch(y, n_fft, hop_size, win_size, center=False):
697
- global hann_window
698
-
699
- wnsize_dtype_device = str(win_size) + "_" + str(y.dtype) + "_" + str(y.device)
700
- if wnsize_dtype_device not in hann_window: hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
701
- spec = torch.stft(torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect").squeeze(1), n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], center=center, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
702
- return torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6)
703
-
704
- def spec_to_mel_torch(spec, n_fft, num_mels, sample_rate, fmin, fmax):
705
- global mel_basis
706
-
707
- fmax_dtype_device = str(fmax) + "_" + str(spec.dtype) + "_" + str(spec.device)
708
- if fmax_dtype_device not in mel_basis: mel_basis[fmax_dtype_device] = torch.from_numpy(librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)).to(dtype=spec.dtype, device=spec.device)
709
- return spectral_normalize_torch(torch.matmul(mel_basis[fmax_dtype_device], spec))
710
-
711
- def mel_spectrogram_torch(y, n_fft, num_mels, sample_rate, hop_size, win_size, fmin, fmax, center=False):
712
- return spec_to_mel_torch(spectrogram_torch(y, n_fft, hop_size, win_size, center), n_fft, num_mels, sample_rate, fmin, fmax)
713
-
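The spectrogram helpers above cache their expensive constants (the Hann window and the mel filterbank) in module-level dicts keyed by size, dtype, and device, so they are built once and reused on every call. A minimal sketch of that caching pattern; the names and sizes here are illustrative:

```python
import torch

_window_cache = {}

def cached_hann(win_size, ref):
    # Build the window once per (size, dtype, device) and reuse it afterwards,
    # mirroring the hann_window / mel_basis dicts above.
    key = (win_size, ref.dtype, ref.device)
    if key not in _window_cache:
        _window_cache[key] = torch.hann_window(win_size, dtype=ref.dtype, device=ref.device)
    return _window_cache[key]

y = torch.randn(1, 16000)
w1 = cached_hann(2048, y)
w2 = cached_hann(2048, y)
print(w1 is w2)  # True: the second call reuses the cached window
```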
714
- def replace_keys_in_dict(d, old_key_part, new_key_part):
715
- updated_dict = OrderedDict() if isinstance(d, OrderedDict) else {}
716
- for key, value in d.items():
717
- updated_dict[(key.replace(old_key_part, new_key_part) if isinstance(key, str) else key)] = (replace_keys_in_dict(value, old_key_part, new_key_part) if isinstance(value, dict) else value)
718
- return updated_dict
719
-
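`replace_keys_in_dict` above exists to translate between the legacy weight-norm key names (`.weight_g` / `.weight_v`) and the newer `torch.nn.utils.parametrizations` layout when checkpoints are loaded and saved. A toy example of the renaming (the state dict below is invented for illustration):

```python
legacy = {"enc.conv.weight_g": 1, "enc.conv.weight_v": 2, "enc.conv.bias": 3}

def rename(d, old, new):
    # Flat version of the recursive key rewrite used above.
    return {k.replace(old, new): v for k, v in d.items()}

modern = rename(rename(legacy, ".weight_v", ".parametrizations.weight.original1"),
                ".weight_g", ".parametrizations.weight.original0")
print(sorted(modern))
# ['enc.conv.bias', 'enc.conv.parametrizations.weight.original0',
#  'enc.conv.parametrizations.weight.original1']
```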
720
- def extract_model(ckpt, sr, pitch_guidance, name, model_path, epoch, step, version, hps, model_author, vocoder):
721
- try:
722
- logger.info(translations["savemodel"].format(model_dir=model_path, epoch=epoch, step=step))
723
- os.makedirs(os.path.dirname(model_path), exist_ok=True)
724
-
725
- opt = OrderedDict(weight={key: value.half() for key, value in ckpt.items() if "enc_q" not in key})
726
- opt["config"] = [hps.data.filter_length // 2 + 1, 32, hps.model.inter_channels, hps.model.hidden_channels, hps.model.filter_channels, hps.model.n_heads, hps.model.n_layers, hps.model.kernel_size, hps.model.p_dropout, hps.model.resblock, hps.model.resblock_kernel_sizes, hps.model.resblock_dilation_sizes, hps.model.upsample_rates, hps.model.upsample_initial_channel, hps.model.upsample_kernel_sizes, hps.model.spk_embed_dim, hps.model.gin_channels, hps.data.sample_rate]
727
- opt["epoch"] = f"{epoch}epoch"
728
- opt["step"] = step
729
- opt["sr"] = sr
730
- opt["f0"] = int(pitch_guidance)
731
- opt["version"] = version
732
- opt["creation_date"] = datetime.datetime.now().isoformat()
733
- opt["model_hash"] = hashlib.sha256(f"{str(ckpt)} {epoch} {step} {datetime.datetime.now().isoformat()}".encode()).hexdigest()
734
- opt["model_name"] = name
735
- opt["author"] = model_author
736
- opt["vocoder"] = vocoder
737
-
738
- torch.save(replace_keys_in_dict(replace_keys_in_dict(opt, ".parametrizations.weight.original1", ".weight_v"), ".parametrizations.weight.original0", ".weight_g"), model_path)
739
- except Exception as e:
740
- logger.error(f"{translations['extract_model_error']}: {e}")
741
-
742
- def run(rank, n_gpus, experiment_dir, pretrainG, pretrainD, pitch_guidance, custom_total_epoch, custom_save_every_weights, config, device, model_author, vocoder, checkpointing):
743
- global global_step
744
-
745
- if rank == 0: writer_eval = SummaryWriter(log_dir=os.path.join(experiment_dir, "eval"))
746
- else: writer_eval = None
747
-
748
- dist.init_process_group(backend="gloo", init_method="env://", world_size=n_gpus, rank=rank)
749
- torch.manual_seed(config.train.seed)
750
- if torch.cuda.is_available(): torch.cuda.set_device(rank)
751
-
752
- train_dataset = TextAudioLoaderMultiNSFsid(config.data)
753
- train_loader = tdata.DataLoader(train_dataset, num_workers=4, shuffle=False, pin_memory=True, collate_fn=TextAudioCollateMultiNSFsid(), batch_sampler=DistributedBucketSampler(train_dataset, batch_size * n_gpus, [100, 200, 300, 400, 500, 600, 700, 800, 900], num_replicas=n_gpus, rank=rank, shuffle=True), persistent_workers=True, prefetch_factor=8)
754
-
755
- net_g, net_d = Synthesizer(config.data.filter_length // 2 + 1, config.train.segment_size // config.data.hop_length, **config.model, use_f0=pitch_guidance, sr=sample_rate, vocoder=vocoder, checkpointing=checkpointing), MultiPeriodDiscriminator(version, config.model.use_spectral_norm, checkpointing=checkpointing)
756
-
757
- if torch.cuda.is_available(): net_g, net_d = net_g.cuda(rank), net_d.cuda(rank)
758
- else: net_g, net_d = net_g.to(device), net_d.to(device)
759
- optim_g, optim_d = torch.optim.AdamW(net_g.parameters(), config.train.learning_rate, betas=config.train.betas, eps=config.train.eps), torch.optim.AdamW(net_d.parameters(), config.train.learning_rate, betas=config.train.betas, eps=config.train.eps)
760
- net_g, net_d = (DDP(net_g, device_ids=[rank]), DDP(net_d, device_ids=[rank])) if torch.cuda.is_available() else (DDP(net_g), DDP(net_d))
761
-
762
- try:
763
- logger.info(translations["start_training"])
764
- _, _, _, epoch_str = load_checkpoint(latest_checkpoint_path(experiment_dir, "D_*.pth"), net_d, optim_d)
765
- _, _, _, epoch_str = load_checkpoint(latest_checkpoint_path(experiment_dir, "G_*.pth"), net_g, optim_g)
766
- epoch_str += 1
767
- global_step = (epoch_str - 1) * len(train_loader)
768
- except:
769
- epoch_str, global_step = 1, 0
770
-
771
- if pretrainG != "" and pretrainG != "None":
772
- if rank == 0:
773
- verify_checkpoint_shapes(pretrainG, net_g)
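`latest_checkpoint_path` above orders files by concatenating every digit in the path, which works here because all checkpoints share one directory. As an editor's aside, a slightly more defensive variant extracts the step number from the file name only; the file names below are illustrative, and this is an alternative sketch, not the method used by the deleted code:

```python
import os
import re

names = ["G_1000.pth", "G_200.pth", "G_20000.pth"]

def step_of(name):
    # Pull the step count out of the base name, ignoring digits in directories.
    m = re.search(r"(\d+)", os.path.basename(name))
    return int(m.group(1)) if m else -1

print(max(names, key=step_of))  # G_20000.pth
```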
774
- logger.info(translations["import_pretrain"].format(dg="G", pretrain=pretrainG))
775
-
776
- if hasattr(net_g, "module"): net_g.module.load_state_dict(torch.load(pretrainG, map_location="cpu")["model"])
777
- else: net_g.load_state_dict(torch.load(pretrainG, map_location="cpu")["model"])
778
- else: logger.warning(translations["not_using_pretrain"].format(dg="G"))
779
-
780
- if pretrainD != "" and pretrainD != "None":
781
- if rank == 0:
782
- verify_checkpoint_shapes(pretrainD, net_d)
783
- logger.info(translations["import_pretrain"].format(dg="D", pretrain=pretrainD))
784
-
785
- if hasattr(net_d, "module"): net_d.module.load_state_dict(torch.load(pretrainD, map_location="cpu")["model"])
786
- else: net_d.load_state_dict(torch.load(pretrainD, map_location="cpu")["model"])
787
- else: logger.warning(translations["not_using_pretrain"].format(dg="D"))
788
-
789
- scheduler_g, scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.train.lr_decay, last_epoch=epoch_str - 2), torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.train.lr_decay, last_epoch=epoch_str - 2)
790
-
791
- optim_d.step()
792
- optim_g.step()
793
-
794
- scaler = GradScaler(enabled=False)
795
- cache = []
796
-
797
- for info in train_loader:
798
- phone, phone_lengths, pitch, pitchf, _, _, _, _, sid = info
799
- reference = (phone.cuda(rank, non_blocking=True), phone_lengths.cuda(rank, non_blocking=True), (pitch.cuda(rank, non_blocking=True) if pitch_guidance else None), (pitchf.cuda(rank, non_blocking=True) if pitch_guidance else None), sid.cuda(rank, non_blocking=True)) if device.type == "cuda" else (phone.to(device), phone_lengths.to(device), (pitch.to(device) if pitch_guidance else None), (pitchf.to(device) if pitch_guidance else None), sid.to(device))
800
- break
801
-
802
- for epoch in range(epoch_str, total_epoch + 1):
803
- train_and_evaluate(rank, epoch, config, [net_g, net_d], [optim_g, optim_d], scaler, train_loader, writer_eval, cache, custom_save_every_weights, custom_total_epoch, device, reference, model_author, vocoder)
804
- scheduler_g.step()
805
- scheduler_d.step()
806
-
807
- def train_and_evaluate(rank, epoch, hps, nets, optims, scaler, train_loader, writer, cache, custom_save_every_weights, custom_total_epoch, device, reference, model_author, vocoder):
808
- global global_step, lowest_value, loss_disc, consecutive_increases_gen, consecutive_increases_disc
809
-
810
- if epoch == 1:
811
- lowest_value = {"step": 0, "value": float("inf"), "epoch": 0}
812
- last_loss_gen_all, consecutive_increases_gen, consecutive_increases_disc = 0.0, 0, 0
813
-
814
- net_g, net_d = nets
815
- optim_g, optim_d = optims
816
- train_loader.batch_sampler.set_epoch(epoch)
817
-
818
- net_g.train()
819
- net_d.train()
820
-
821
- if device.type == "cuda" and cache_data_in_gpu:
822
- data_iterator = cache
823
- if cache == []:
824
- for batch_idx, info in enumerate(train_loader):
825
- cache.append((batch_idx, [tensor.cuda(rank, non_blocking=True) for tensor in info]))
826
- else: shuffle(cache)
827
- else: data_iterator = enumerate(train_loader)
828
-
829
- with tqdm(total=len(train_loader), leave=False) as pbar:
830
- for batch_idx, info in data_iterator:
831
- if device.type == "cuda" and not cache_data_in_gpu: info = [tensor.cuda(rank, non_blocking=True) for tensor in info]
832
- elif device.type != "cuda": info = [tensor.to(device) for tensor in info]
833
-
834
- phone, phone_lengths, pitch, pitchf, spec, spec_lengths, wave, _, sid = info
835
- pitch = pitch if pitch_guidance else None
836
- pitchf = pitchf if pitch_guidance else None
837
-
838
- with autocast(enabled=False):
839
- model_output = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
840
- y_hat, ids_slice, _, z_mask, (_, z_p, m_p, logs_p, _, logs_q) = model_output
841
-
842
- mel = spec_to_mel_torch(spec, config.data.filter_length, config.data.n_mel_channels, config.data.sample_rate, config.data.mel_fmin, config.data.mel_fmax)
843
- y_mel = slice_segments(mel, ids_slice, config.train.segment_size // config.data.hop_length, dim=3)
844
-
845
- with autocast(enabled=False):
846
- y_hat_mel = mel_spectrogram_torch(y_hat.float().squeeze(1), config.data.filter_length, config.data.n_mel_channels, config.data.sample_rate, config.data.hop_length, config.data.win_length, config.data.mel_fmin, config.data.mel_fmax)
847
-
848
- wave = slice_segments(wave, ids_slice * config.data.hop_length, config.train.segment_size, dim=3)
849
- y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
850
-
851
- with autocast(enabled=False):
852
- loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
853
-
854
- optim_d.zero_grad()
855
- scaler.scale(loss_disc).backward()
856
- scaler.unscale_(optim_d)
857
- grad_norm_d = clip_grad_value(net_d.parameters(), None)
858
- scaler.step(optim_d)
859
-
860
- with autocast(enabled=False):
861
- y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
862
- with autocast(enabled=False):
863
- loss_mel = F.l1_loss(y_mel, y_hat_mel) * config.train.c_mel
864
- loss_kl = (kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl)
865
- loss_fm = feature_loss(fmap_r, fmap_g)
866
- loss_gen, losses_gen = generator_loss(y_d_hat_g)
867
- loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
868
-
869
- if loss_gen_all < lowest_value["value"]:
870
- lowest_value["value"] = loss_gen_all
871
- lowest_value["step"] = global_step
872
- lowest_value["epoch"] = epoch
873
- if epoch > lowest_value["epoch"]: logger.warning(translations["training_warning"])
874
-
875
- optim_g.zero_grad()
876
- scaler.scale(loss_gen_all).backward()
877
- scaler.unscale_(optim_g)
878
- grad_norm_g = clip_grad_value(net_g.parameters(), None)
879
- scaler.step(optim_g)
880
- scaler.update()
881
-
882
- if rank == 0:
883
- if global_step % config.train.log_interval == 0:
884
- if loss_mel > 75: loss_mel = 75
885
- if loss_kl > 9: loss_kl = 9
886
-
887
- scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc, "learning_rate": optim_g.param_groups[0]["lr"], "grad/norm_d": grad_norm_d, "grad/norm_g": grad_norm_g, "loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl}
888
- scalar_dict.update({f"loss/g/{i}": v for i, v in enumerate(losses_gen)})
889
- scalar_dict.update({f"loss/d_r/{i}": v for i, v in enumerate(losses_disc_r)})
890
- scalar_dict.update({f"loss/d_g/{i}": v for i, v in enumerate(losses_disc_g)})
891
-
892
- with torch.no_grad():
893
- o, *_ = net_g.module.infer(*reference) if hasattr(net_g, "module") else net_g.infer(*reference)
894
-
895
- summarize(writer=writer, global_step=global_step, images={"slice/mel_org": plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), "slice/mel_gen": plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), "all/mel": plot_spectrogram_to_numpy(mel[0].data.cpu().numpy())}, scalars=scalar_dict, audios={f"gen/audio_{global_step:07d}": o[0, :, :]}, audio_sample_rate=config.data.sample_rate)
896
-
897
- global_step += 1
898
- pbar.update(1)
899
-
900
- def check_overtraining(smoothed_loss_history, threshold, epsilon=0.004):
901
- if len(smoothed_loss_history) < threshold + 1: return False
902
-
903
- for i in range(-threshold, -1):
904
- if smoothed_loss_history[i + 1] > smoothed_loss_history[i]: return True
905
- if abs(smoothed_loss_history[i + 1] - smoothed_loss_history[i]) >= epsilon: return False
906
-
907
- return True
908
-
909
- def update_exponential_moving_average(smoothed_loss_history, new_value, smoothing=0.987):
910
- smoothed_value = new_value if not smoothed_loss_history else (smoothing * smoothed_loss_history[-1] + (1 - smoothing) * new_value)
911
- smoothed_loss_history.append(smoothed_value)
912
- return smoothed_value
913
-
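The overtraining detector above works on an exponential moving average of the losses, s_t = a * s_{t-1} + (1 - a) * x_t with a = 0.987, so a single noisy epoch barely moves the smoothed curve. A small numeric sketch with invented loss values:

```python
def ema(history, new_value, smoothing=0.987):
    # Same update rule as update_exponential_moving_average above.
    if not history:
        value = new_value
    else:
        value = smoothing * history[-1] + (1 - smoothing) * new_value
    history.append(value)
    return value

hist = []
for loss in (2.0, 2.0, 3.5, 2.0):        # one noisy spike at epoch 3
    print(round(ema(hist, loss), 4))     # 2.0, 2.0, 2.0195, 2.0192
```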
914
- def save_to_json(file_path, loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history):
915
- with open(file_path, "w") as f:
916
- json.dump({"loss_disc_history": loss_disc_history, "smoothed_loss_disc_history": smoothed_loss_disc_history, "loss_gen_history": loss_gen_history, "smoothed_loss_gen_history": smoothed_loss_gen_history}, f)
917
-
918
- model_add, model_del = [], []
919
- done = False
920
-
921
- if rank == 0:
922
- if epoch % save_every_epoch == False:
923
- checkpoint_suffix = f"{'latest' if save_only_latest else global_step}.pth"
924
- save_checkpoint(net_g, optim_g, config.train.learning_rate, epoch, os.path.join(experiment_dir, "G_" + checkpoint_suffix))
925
- save_checkpoint(net_d, optim_d, config.train.learning_rate, epoch, os.path.join(experiment_dir, "D_" + checkpoint_suffix))
926
- if custom_save_every_weights: model_add.append(os.path.join("assets", "weights", f"{model_name}_{epoch}e_{global_step}s.pth"))
927
-
928
- if overtraining_detector and epoch > 1:
929
- current_loss_disc = float(loss_disc)
930
- loss_disc_history.append(current_loss_disc)
931
- smoothed_value_disc = update_exponential_moving_average(smoothed_loss_disc_history, current_loss_disc)
932
- is_overtraining_disc = check_overtraining(smoothed_loss_disc_history, overtraining_threshold * 2)
933
-
934
- if is_overtraining_disc: consecutive_increases_disc += 1
935
- else: consecutive_increases_disc = 0
936
-
937
- current_loss_gen = float(lowest_value["value"])
938
- loss_gen_history.append(current_loss_gen)
939
- smoothed_value_gen = update_exponential_moving_average(smoothed_loss_gen_history, current_loss_gen)
940
- is_overtraining_gen = check_overtraining(smoothed_loss_gen_history, overtraining_threshold, 0.01)
941
-
942
- if is_overtraining_gen: consecutive_increases_gen += 1
943
- else: consecutive_increases_gen = 0
944
-
945
- if epoch % save_every_epoch == 0: save_to_json(training_file_path, loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history)
946
-
947
- if (is_overtraining_gen and consecutive_increases_gen == overtraining_threshold or is_overtraining_disc and consecutive_increases_disc == (overtraining_threshold * 2)):
948
- logger.info(translations["overtraining_find"].format(epoch=epoch, smoothed_value_gen=f"{smoothed_value_gen:.3f}", smoothed_value_disc=f"{smoothed_value_disc:.3f}"))
949
- done = True
950
- else:
951
- logger.info(translations["best_epoch"].format(epoch=epoch, smoothed_value_gen=f"{smoothed_value_gen:.3f}", smoothed_value_disc=f"{smoothed_value_disc:.3f}"))
952
- for file in glob.glob(os.path.join("assets", "weights", f"{model_name}_*e_*s_best_epoch.pth")):
953
- model_del.append(file)
954
-
955
- model_add.append(os.path.join("assets", "weights", f"{model_name}_{epoch}e_{global_step}s_best_epoch.pth"))
956
-
957
- if epoch >= custom_total_epoch:
958
- logger.info(translations["success_training"].format(epoch=epoch, global_step=global_step, loss_gen_all=round(loss_gen_all.item(), 3)))
959
- logger.info(translations["training_info"].format(lowest_value_rounded=round(float(lowest_value["value"]), 3), lowest_value_epoch=lowest_value['epoch'], lowest_value_step=lowest_value['step']))
960
-
961
- pid_file_path = os.path.join(experiment_dir, "config.json")
962
-
963
- with open(pid_file_path, "r") as pid_file:
964
- pid_data = json.load(pid_file)
965
-
966
- with open(pid_file_path, "w") as pid_file:
967
- pid_data.pop("process_pids", None)
968
- json.dump(pid_data, pid_file, indent=4)
969
-
970
- model_add.append(os.path.join("assets", "weights", f"{model_name}_{epoch}e_{global_step}s.pth"))
971
- done = True
972
-
973
- for m in model_del:
974
- os.remove(m)
975
-
976
- if model_add:
977
- ckpt = (net_g.module.state_dict() if hasattr(net_g, "module") else net_g.state_dict())
978
-
979
- for m in model_add:
980
- extract_model(ckpt=ckpt, sr=sample_rate, pitch_guidance=pitch_guidance == True, name=model_name, model_path=m, epoch=epoch, step=global_step, version=version, hps=hps, model_author=model_author, vocoder=vocoder)
981
-
982
- lowest_value_rounded = round(float(lowest_value["value"]), 3)
983
- epoch_recorder = EpochRecorder()
984
-
985
- if epoch > 1 and overtraining_detector: logger.info(translations["model_training_info"].format(model_name=model_name, epoch=epoch, global_step=global_step, epoch_recorder=epoch_recorder.record(), lowest_value_rounded=lowest_value_rounded, lowest_value_epoch=lowest_value['epoch'], lowest_value_step=lowest_value['step'], remaining_epochs_gen=(overtraining_threshold - consecutive_increases_gen), remaining_epochs_disc=((overtraining_threshold * 2) - consecutive_increases_disc), smoothed_value_gen=f"{smoothed_value_gen:.3f}", smoothed_value_disc=f"{smoothed_value_disc:.3f}"))
986
- elif epoch > 1 and overtraining_detector == False: logger.info(translations["model_training_info_2"].format(model_name=model_name, epoch=epoch, global_step=global_step, epoch_recorder=epoch_recorder.record(), lowest_value_rounded=lowest_value_rounded, lowest_value_epoch=lowest_value['epoch'], lowest_value_step=lowest_value['step']))
987
- else: logger.info(translations["model_training_info_3"].format(model_name=model_name, epoch=epoch, global_step=global_step, epoch_recorder=epoch_recorder.record()))
988
-
989
- last_loss_gen_all = loss_gen_all
990
- if done: os._exit(0)
991
-
992
- if __name__ == "__main__":
993
- torch.multiprocessing.set_start_method("spawn")
994
- try:
995
- main()
996
- except Exception as e:
997
- logger.error(f"{translations['training_error']} {e}")
998
-
999
- import traceback
1000
- logger.debug(traceback.format_exc())
main/library/algorithm/commons.py DELETED
@@ -1,50 +0,0 @@
1
- import torch
2
-
3
-
4
- def init_weights(m, mean=0.0, std=0.01):
5
- if m.__class__.__name__.find("Conv") != -1: m.weight.data.normal_(mean, std)
6
-
7
- def get_padding(kernel_size, dilation=1):
8
- return int((kernel_size * dilation - dilation) / 2)
9
-
10
- def convert_pad_shape(pad_shape):
11
- return [item for sublist in pad_shape[::-1] for item in sublist]
12
-
13
- def slice_segments(x, ids_str, segment_size = 4, dim = 2):
14
- if dim == 2: ret = torch.zeros_like(x[:, :segment_size])
15
- elif dim == 3: ret = torch.zeros_like(x[:, :, :segment_size])
16
- for i in range(x.size(0)):
17
- idx_str = ids_str[i].item()
18
- idx_end = idx_str + segment_size
19
- if dim == 2: ret[i] = x[i, idx_str:idx_end]
20
- else: ret[i] = x[i, :, idx_str:idx_end]
21
- return ret
22
-
23
- def rand_slice_segments(x, x_lengths=None, segment_size=4):
24
- b, _, t = x.size()
25
- if x_lengths is None: x_lengths = t
26
- ids_str = (torch.rand([b]).to(device=x.device) * (x_lengths - segment_size + 1)).to(dtype=torch.long)
27
- return slice_segments(x, ids_str, segment_size, dim=3), ids_str
28
-
29
- @torch.jit.script
30
- def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
31
- n_channels_int = n_channels[0]
32
- in_act = input_a + input_b
33
- return torch.tanh(in_act[:, :n_channels_int, :]) * torch.sigmoid(in_act[:, n_channels_int:, :])
34
-
35
- def convert_pad_shape(pad_shape):
36
- return torch.tensor(pad_shape).flip(0).reshape(-1).int().tolist()
37
-
38
- def sequence_mask(length, max_length = None):
39
- if max_length is None: max_length = length.max()
40
- return torch.arange(max_length, dtype=length.dtype, device=length.device).unsqueeze(0) < length.unsqueeze(1)
41
-
42
- def clip_grad_value(parameters, clip_value, norm_type=2):
43
- if isinstance(parameters, torch.Tensor): parameters = [parameters]
44
- norm_type = float(norm_type)
45
- if clip_value is not None: clip_value = float(clip_value)
46
- total_norm = 0
47
- for p in list(filter(lambda p: p.grad is not None, parameters)):
48
- total_norm += (p.grad.data.norm(norm_type)).item() ** norm_type
49
- if clip_value is not None: p.grad.data.clamp_(min=-clip_value, max=clip_value)
50
- return total_norm ** (1.0 / norm_type)
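Editor's note on `slice_segments` / `rand_slice_segments` above: they cut a fixed-size window at a random start position per batch element, so the decoder and discriminators only see short segments during training. A minimal sketch of the same idea with illustrative shapes:

```python
import torch

def rand_slice(x, segment_size):
    # Random start per item, fixed-size window along the last axis.
    b, _, t = x.size()
    ids_str = (torch.rand(b, device=x.device) * (t - segment_size + 1)).long()
    out = torch.stack([x[i, :, s : s + segment_size] for i, s in enumerate(ids_str.tolist())])
    return out, ids_str

x = torch.randn(4, 80, 400)                 # illustrative (batch, mels, frames)
seg, starts = rand_slice(x, segment_size=32)
print(seg.shape, starts.shape)              # torch.Size([4, 80, 32]) torch.Size([4])
```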
main/library/algorithm/modules.py DELETED
@@ -1,70 +0,0 @@
1
- import os
2
- import sys
3
- import torch
4
-
5
- sys.path.append(os.getcwd())
6
-
7
- from .commons import fused_add_tanh_sigmoid_multiply
8
-
9
-
10
- class WaveNet(torch.nn.Module):
11
- def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
12
- super(WaveNet, self).__init__()
13
- assert kernel_size % 2 == 1
14
- self.hidden_channels = hidden_channels
15
- self.kernel_size = (kernel_size,)
16
- self.dilation_rate = dilation_rate
17
- self.n_layers = n_layers
18
- self.gin_channels = gin_channels
19
- self.p_dropout = p_dropout
20
- self.in_layers = torch.nn.ModuleList()
21
- self.res_skip_layers = torch.nn.ModuleList()
22
- self.drop = torch.nn.Dropout(p_dropout)
23
-
24
- if gin_channels != 0: self.cond_layer = torch.nn.utils.parametrizations.weight_norm(torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1), name="weight")
25
-
26
- dilations = [dilation_rate**i for i in range(n_layers)]
27
- paddings = [(kernel_size * d - d) // 2 for d in dilations]
28
-
29
- for i in range(n_layers):
30
- in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilations[i], padding=paddings[i])
31
- in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
32
- self.in_layers.append(in_layer)
33
-
34
- res_skip_channels = (hidden_channels if i == n_layers - 1 else 2 * hidden_channels)
35
- res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
36
-
37
- res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
38
- self.res_skip_layers.append(res_skip_layer)
39
-
40
- def forward(self, x, x_mask, g=None, **kwargs):
41
- output = torch.zeros_like(x)
42
- n_channels_tensor = torch.IntTensor([self.hidden_channels])
43
-
44
- if g is not None: g = self.cond_layer(g)
45
-
46
- for i in range(self.n_layers):
47
- x_in = self.in_layers[i](x)
48
-
49
- if g is not None:
50
- cond_offset = i * 2 * self.hidden_channels
51
- g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
52
- else: g_l = torch.zeros_like(x_in)
53
-
54
- res_skip_acts = self.res_skip_layers[i](self.drop(fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)))
55
-
56
- if i < self.n_layers - 1:
57
- x = (x + (res_skip_acts[:, : self.hidden_channels, :])) * x_mask
58
- output = output + res_skip_acts[:, self.hidden_channels :, :]
59
- else: output = output + res_skip_acts
60
-
61
- return output * x_mask
62
-
63
- def remove_weight_norm(self):
64
- if self.gin_channels != 0: torch.nn.utils.remove_weight_norm(self.cond_layer)
65
-
66
- for l in self.in_layers:
67
- torch.nn.utils.remove_weight_norm(l)
68
-
69
- for l in self.res_skip_layers:
70
- torch.nn.utils.remove_weight_norm(l)
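The WaveNet block above relies on the gated activation from commons.py (`fused_add_tanh_sigmoid_multiply`): the 2*C-channel pre-activation is split into a tanh half and a sigmoid gate half and multiplied elementwise. A minimal sketch with illustrative shapes:

```python
import torch

def gated_activation(x_in, g_l, channels):
    # Add the conditioning, then tanh(first C channels) * sigmoid(remaining C).
    acts = x_in + g_l
    return torch.tanh(acts[:, :channels, :]) * torch.sigmoid(acts[:, channels:, :])

x_in = torch.randn(2, 2 * 192, 100)            # illustrative (batch, 2*C, frames)
g_l = torch.zeros_like(x_in)                   # no speaker conditioning
print(gated_activation(x_in, g_l, 192).shape)  # torch.Size([2, 192, 100])
```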
main/library/algorithm/mrf_hifigan.py DELETED
@@ -1,160 +0,0 @@
1
- import math
2
- import torch
3
-
4
- import numpy as np
5
- import torch.nn.functional as F
6
- import torch.utils.checkpoint as checkpoint
7
-
8
- from torch.nn.utils import remove_weight_norm
9
- from torch.nn.utils.parametrizations import weight_norm
10
-
11
-
12
- LRELU_SLOPE = 0.1
13
-
14
- class MRFLayer(torch.nn.Module):
15
- def __init__(self, channels, kernel_size, dilation):
16
- super().__init__()
17
- self.conv1 = weight_norm(torch.nn.Conv1d(channels, channels, kernel_size, padding=(kernel_size * dilation - dilation) // 2, dilation=dilation))
18
- self.conv2 = weight_norm(torch.nn.Conv1d(channels, channels, kernel_size, padding=kernel_size // 2, dilation=1))
19
-
20
- def forward(self, x):
21
- return x + self.conv2(F.leaky_relu(self.conv1(F.leaky_relu(x, LRELU_SLOPE)), LRELU_SLOPE))
22
-
23
- def remove_weight_norm(self):
24
- remove_weight_norm(self.conv1)
25
- remove_weight_norm(self.conv2)
26
-
27
- class MRFBlock(torch.nn.Module):
28
- def __init__(self, channels, kernel_size, dilations):
29
- super().__init__()
30
- self.layers = torch.nn.ModuleList()
31
-
32
- for dilation in dilations:
33
- self.layers.append(MRFLayer(channels, kernel_size, dilation))
34
-
35
- def forward(self, x):
36
- for layer in self.layers:
37
- x = layer(x)
38
-
39
- return x
40
-
41
- def remove_weight_norm(self):
42
- for layer in self.layers:
43
- layer.remove_weight_norm()
44
-
45
- class SineGenerator(torch.nn.Module):
46
- def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0):
47
- super(SineGenerator, self).__init__()
48
- self.sine_amp = sine_amp
49
- self.noise_std = noise_std
50
- self.harmonic_num = harmonic_num
51
- self.dim = self.harmonic_num + 1
52
- self.sampling_rate = samp_rate
53
- self.voiced_threshold = voiced_threshold
54
-
55
- def _f02uv(self, f0):
56
- return torch.ones_like(f0) * (f0 > self.voiced_threshold)
57
-
58
- def _f02sine(self, f0_values):
59
- rad_values = (f0_values / self.sampling_rate) % 1
60
- rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
61
-
62
- rand_ini[:, 0] = 0
63
- rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
64
-
65
- tmp_over_one = torch.cumsum(rad_values, 1) % 1
66
- tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
67
-
68
- cumsum_shift = torch.zeros_like(rad_values)
69
- cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
70
-
71
- return torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)
72
-
73
- def forward(self, f0):
74
- with torch.no_grad():
75
- f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
76
- f0_buf[:, :, 0] = f0[:, :, 0]
77
-
78
- for idx in np.arange(self.harmonic_num):
79
- f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
80
-
81
- sine_waves = self._f02sine(f0_buf) * self.sine_amp
82
- uv = self._f02uv(f0)
83
-
84
- sine_waves = sine_waves * uv + ((uv * self.noise_std + (1 - uv) * self.sine_amp / 3) * torch.randn_like(sine_waves))
85
-
86
- return sine_waves
87
-
88
- class SourceModuleHnNSF(torch.nn.Module):
89
- def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshold=0):
90
- super(SourceModuleHnNSF, self).__init__()
91
- self.sine_amp = sine_amp
92
- self.noise_std = add_noise_std
93
-
94
- self.l_sin_gen = SineGenerator(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshold)
95
- self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
96
- self.l_tanh = torch.nn.Tanh()
97
-
98
- def forward(self, x):
99
- return self.l_tanh(self.l_linear(self.l_sin_gen(x).to(dtype=self.l_linear.weight.dtype)))
100
-
101
- class HiFiGANMRFGenerator(torch.nn.Module):
102
- def __init__(self, in_channel, upsample_initial_channel, upsample_rates, upsample_kernel_sizes, resblock_kernel_sizes, resblock_dilations, gin_channels, sample_rate, harmonic_num, checkpointing=False):
103
- super().__init__()
104
- self.num_kernels = len(resblock_kernel_sizes)
105
-
106
- self.f0_upsample = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
107
- self.m_source = SourceModuleHnNSF(sample_rate, harmonic_num)
108
-
109
- self.conv_pre = weight_norm(torch.nn.Conv1d(in_channel, upsample_initial_channel, kernel_size=7, stride=1, padding=3))
110
- self.checkpointing = checkpointing
111
-
112
- self.upsamples = torch.nn.ModuleList()
113
- self.noise_convs = torch.nn.ModuleList()
114
-
115
- stride_f0s = [math.prod(upsample_rates[i + 1 :]) if i + 1 < len(upsample_rates) else 1 for i in range(len(upsample_rates))]
116
-
117
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
118
- self.upsamples.append(weight_norm(torch.nn.ConvTranspose1d(upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), kernel_size=k, stride=u, padding=(((k - u) // 2) if u % 2 == 0 else (u // 2 + u % 2)), output_padding=u % 2)))
119
- stride = stride_f0s[i]
120
-
121
- kernel = 1 if stride == 1 else stride * 2 - stride % 2
122
- self.noise_convs.append(torch.nn.Conv1d(1, upsample_initial_channel // (2 ** (i + 1)), kernel_size=kernel, stride=stride, padding=( 0 if stride == 1 else (kernel - stride) // 2)))
123
-
124
- self.mrfs = torch.nn.ModuleList()
125
-
126
- for i in range(len(self.upsamples)):
127
- channel = upsample_initial_channel // (2 ** (i + 1))
128
- self.mrfs.append(torch.nn.ModuleList([MRFBlock(channel, kernel_size=k, dilations=d) for k, d in zip(resblock_kernel_sizes, resblock_dilations)]))
129
-
130
- self.conv_post = weight_norm(torch.nn.Conv1d(channel, 1, kernel_size=7, stride=1, padding=3))
131
- if gin_channels != 0: self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)
132
-
133
- def forward(self, x, f0, g = None):
134
- har_source = self.m_source(self.f0_upsample(f0[:, None, :]).transpose(-1, -2)).transpose(-1, -2)
135
-
136
- x = self.conv_pre(x)
137
- if g is not None: x = x + self.cond(g)
138
-
139
- for ups, mrf, noise_conv in zip(self.upsamples, self.mrfs, self.noise_convs):
140
- x = F.leaky_relu(x, LRELU_SLOPE)
141
- x = checkpoint.checkpoint(ups, x, use_reentrant=False) if self.training and self.checkpointing else ups(x)
142
- x += noise_conv(har_source)
143
-
144
- def mrf_sum(x, layers):
145
- return sum(layer(x) for layer in layers) / self.num_kernels
146
-
147
- x = checkpoint.checkpoint(mrf_sum, x, mrf, use_reentrant=False) if self.training and self.checkpointing else mrf_sum(x, mrf)
148
-
149
- return torch.tanh(self.conv_post(F.leaky_relu(x)))
150
-
151
- def remove_weight_norm(self):
152
- remove_weight_norm(self.conv_pre)
153
-
154
- for up in self.upsamples:
155
- remove_weight_norm(up)
156
-
157
- for mrf in self.mrfs:
158
- mrf.remove_weight_norm()
159
-
160
- remove_weight_norm(self.conv_post)
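A brief note on the NSF-style excitation used by the generator above: the frame-level f0 is upsampled to the waveform rate, phase is accumulated as the cumulative sum of f0 / sample_rate, and sines are read off that phase. The sketch below shows only that core step with invented values; the real SineGenerator also adds a random initial phase, extra harmonics, and voiced/unvoiced noise.

```python
import math
import torch

sample_rate = 16000
f0 = torch.full((1, sample_rate, 1), 220.0)      # constant 220 Hz for 1 second
rad = (f0 / sample_rate).cumsum(dim=1) % 1.0     # accumulated phase, in cycles
sine = 0.1 * torch.sin(2 * math.pi * rad)        # sine_amp = 0.1, as above
print(sine.shape)                                # torch.Size([1, 16000, 1])
```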
main/library/algorithm/refinegan.py DELETED
@@ -1,180 +0,0 @@
1
- import os
2
- import sys
3
- import math
4
- import torch
5
-
6
- import numpy as np
7
- import torch.nn.functional as F
8
- import torch.utils.checkpoint as checkpoint
9
-
10
- from torch.nn.utils.parametrizations import weight_norm
11
- from torch.nn.utils.parametrize import remove_parametrizations
12
-
13
- sys.path.append(os.getcwd())
14
-
15
- from .commons import get_padding
16
-
17
- class ResBlock(torch.nn.Module):
18
- def __init__(self, *, in_channels, out_channels, kernel_size = 7, dilation = (1, 3, 5), leaky_relu_slope = 0.2):
19
- super(ResBlock, self).__init__()
20
- self.leaky_relu_slope = leaky_relu_slope
21
- self.in_channels = in_channels
22
- self.out_channels = out_channels
23
-
24
- self.convs1 = torch.nn.ModuleList([weight_norm(torch.nn.Conv1d(in_channels=in_channels if idx == 0 else out_channels, out_channels=out_channels, kernel_size=kernel_size, stride=1, dilation=d, padding=get_padding(kernel_size, d))) for idx, d in enumerate(dilation)])
25
- self.convs1.apply(self.init_weights)
26
-
27
- self.convs2 = torch.nn.ModuleList([weight_norm(torch.nn.Conv1d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, stride=1, dilation=d, padding=get_padding(kernel_size, d))) for _, d in enumerate(dilation)])
28
- self.convs2.apply(self.init_weights)
29
-
30
- def forward(self, x):
31
- for idx, (c1, c2) in enumerate(zip(self.convs1, self.convs2)):
32
- xt = c2(F.leaky_relu(c1(F.leaky_relu(x, self.leaky_relu_slope)), self.leaky_relu_slope))
33
- x = (xt + x) if idx != 0 or self.in_channels == self.out_channels else xt
34
-
35
- return x
36
-
37
- def remove_parametrizations(self):
38
- for c1, c2 in zip(self.convs1, self.convs2):
39
- remove_parametrizations(c1)
40
- remove_parametrizations(c2)
41
-
42
- def init_weights(self, m):
43
- if type(m) == torch.nn.Conv1d:
44
- m.weight.data.normal_(0, 0.01)
45
- m.bias.data.fill_(0.0)
46
-
47
- class AdaIN(torch.nn.Module):
48
- def __init__(self, *, channels, leaky_relu_slope = 0.2):
49
- super().__init__()
50
- self.weight = torch.nn.Parameter(torch.ones(channels))
51
- self.activation = torch.nn.LeakyReLU(leaky_relu_slope)
52
-
53
- def forward(self, x):
54
- return self.activation(x + (torch.randn_like(x) * self.weight[None, :, None]))
55
-
56
- class ParallelResBlock(torch.nn.Module):
57
- def __init__(self, *, in_channels, out_channels, kernel_sizes = (3, 7, 11), dilation = (1, 3, 5), leaky_relu_slope = 0.2):
58
- super().__init__()
59
- self.in_channels = in_channels
60
- self.out_channels = out_channels
61
- self.input_conv = torch.nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=7, stride=1, padding=3)
62
- self.blocks = torch.nn.ModuleList([torch.nn.Sequential(AdaIN(channels=out_channels), ResBlock(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, dilation=dilation, leaky_relu_slope=leaky_relu_slope), AdaIN(channels=out_channels)) for kernel_size in kernel_sizes])
63
-
64
- def forward(self, x):
65
- return torch.mean(torch.stack([block(self.input_conv(x)) for block in self.blocks]), dim=0)
66
-
67
- def remove_parametrizations(self):
68
- for block in self.blocks:
69
- block[1].remove_parametrizations()
70
-
71
- class SineGenerator(torch.nn.Module):
72
- def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0):
73
- super(SineGenerator, self).__init__()
74
- self.sine_amp = sine_amp
75
- self.noise_std = noise_std
76
- self.harmonic_num = harmonic_num
77
- self.dim = self.harmonic_num + 1
78
- self.sampling_rate = samp_rate
79
- self.voiced_threshold = voiced_threshold
80
- self.merge = torch.nn.Sequential(torch.nn.Linear(self.dim, 1, bias=False), torch.nn.Tanh())
81
-
82
- def _f02uv(self, f0):
83
- return torch.ones_like(f0) * (f0 > self.voiced_threshold)
84
-
85
- def _f02sine(self, f0_values):
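- # Frequency-to-phase conversion: accumulate f0/sr with a random initial phase, correct the cumulative sum at wrap-around points, then take sin(2*pi*phase).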
86
- rad_values = (f0_values / self.sampling_rate) % 1
87
- rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
88
-
89
- rand_ini[:, 0] = 0
90
- rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
91
-
92
- tmp_over_one = torch.cumsum(rad_values, 1) % 1
93
- tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
94
-
95
- cumsum_shift = torch.zeros_like(rad_values)
96
- cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
97
-
98
- return torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)
99
-
100
- def forward(self, f0):
101
- with torch.no_grad():
102
- f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
103
- f0_buf[:, :, 0] = f0[:, :, 0]
104
-
105
- for idx in np.arange(self.harmonic_num):
106
- f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
107
-
108
- sine_waves = self._f02sine(f0_buf) * self.sine_amp
109
- uv = self._f02uv(f0)
110
-
111
- sine_waves = sine_waves * uv + (uv * self.noise_std + (1 - uv) * self.sine_amp / 3) * torch.randn_like(sine_waves)
112
- sine_waves = sine_waves - sine_waves.mean(dim=1, keepdim=True)
113
-
114
- return self.merge(sine_waves)
115
-
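- # RefineGAN decoder: each stage linearly upsamples, applies a fixed 15-tap moving-average low-pass, concatenates a strided-downsampled copy of the sine excitation, and runs a ParallelResBlock (optionally gradient-checkpointed).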
116
- class RefineGANGenerator(torch.nn.Module):
117
- def __init__(self, *, sample_rate = 44100, upsample_rates = (8, 8, 2, 2), leaky_relu_slope = 0.2, num_mels = 128, gin_channels = 256, checkpointing = False, upsample_initial_channel = 512):
118
- super().__init__()
119
- self.upsample_rates = upsample_rates
120
- self.checkpointing = checkpointing
121
- self.leaky_relu_slope = leaky_relu_slope
122
- self.upp = np.prod(upsample_rates)
123
- self.m_source = SineGenerator(sample_rate)
124
- self.pre_conv = weight_norm(torch.nn.Conv1d(in_channels=1, out_channels=upsample_initial_channel // 2, kernel_size=7, stride=1, padding=3, bias=False))
125
- channels = upsample_initial_channel
126
- self.downsample_blocks = torch.nn.ModuleList([])
127
-
128
- stride_f0s = [math.prod(upsample_rates[i + 1 :]) if i + 1 < len(upsample_rates) else 1 for i in range(len(upsample_rates))]
129
-
130
- for i, _ in enumerate(upsample_rates):
131
- stride = stride_f0s[i]
132
- kernel = 1 if stride == 1 else stride * 2 - stride % 2
133
-
134
- self.downsample_blocks.append(torch.nn.Conv1d(in_channels=1, out_channels=channels // 2 ** (i + 2), kernel_size=kernel, stride=stride, padding=0 if stride == 1 else (kernel - stride) // 2))
135
-
136
- self.mel_conv = weight_norm(torch.nn.Conv1d(in_channels=num_mels, out_channels=channels // 2, kernel_size=7, stride=1, padding=3))
137
- if gin_channels != 0: self.cond = torch.nn.Conv1d(256, channels // 2, 1)
138
-
139
- self.upsample_blocks = torch.nn.ModuleList([])
140
- self.upsample_conv_blocks = torch.nn.ModuleList([])
141
- self.filters = torch.nn.ModuleList([])
142
-
143
- for rate in upsample_rates:
144
- new_channels = channels // 2
145
- self.upsample_blocks.append(torch.nn.Upsample(scale_factor=rate, mode="linear"))
146
-
147
- low_pass = torch.nn.Conv1d(channels, channels, kernel_size=15, padding=7, groups=channels, bias=False)
148
- low_pass.weight.data.fill_(1.0 / 15)
149
-
150
- self.filters.append(low_pass)
151
-
152
- self.upsample_conv_blocks.append(ParallelResBlock(in_channels=channels + channels // 4, out_channels=new_channels, kernel_sizes=(3, 7, 11), dilation=(1, 3, 5), leaky_relu_slope=leaky_relu_slope))
153
- channels = new_channels
154
-
155
- self.conv_post = weight_norm(torch.nn.Conv1d(in_channels=channels, out_channels=1, kernel_size=7, stride=1, padding=3))
156
-
157
- def forward(self, mel, f0, g = None):
158
- har_source = self.m_source(f0.transpose(1, 2)).transpose(1, 2)
159
- x = F.interpolate(self.pre_conv(har_source), size=mel.shape[-1], mode="linear")
160
-
161
- mel = self.mel_conv(mel)
162
- if g is not None: mel += self.cond(g)
163
-
164
- x = torch.cat([mel, x], dim=1)
165
-
166
- for ups, res, down, flt in zip(self.upsample_blocks, self.upsample_conv_blocks, self.downsample_blocks, self.filters):
167
- x = checkpoint.checkpoint(res, torch.cat([checkpoint.checkpoint(flt, checkpoint.checkpoint(ups, x, use_reentrant=False), use_reentrant=False), down(har_source)], dim=1), use_reentrant=False) if self.training and self.checkpointing else res(torch.cat([flt(ups(x)), down(har_source)], dim=1))
168
-
169
- return torch.tanh_(self.conv_post(F.leaky_relu_(x, self.leaky_relu_slope)))
170
-
171
- def remove_parametrizations(self):
172
- remove_parametrizations(self.pre_conv, "weight")
173
- remove_parametrizations(self.mel_conv, "weight")
174
- remove_parametrizations(self.conv_post, "weight")
175
-
176
- # downsample_blocks hold plain Conv1d layers created without weight_norm,
177
- # so there are no parametrizations to remove from them
178
-
179
- for block in self.upsample_conv_blocks:
180
- block.remove_parametrizations()
main/library/algorithm/residuals.py DELETED
@@ -1,140 +0,0 @@
1
- import os
2
- import sys
3
- import torch
4
-
5
- from torch.nn.utils import remove_weight_norm
6
- from torch.nn.utils.parametrizations import weight_norm
7
-
8
- sys.path.append(os.getcwd())
9
-
10
- from .modules import WaveNet
11
- from .commons import get_padding, init_weights
12
-
13
-
14
- LRELU_SLOPE = 0.1
15
-
16
- def create_conv1d_layer(channels, kernel_size, dilation):
17
- return weight_norm(torch.nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation, padding=get_padding(kernel_size, dilation)))
18
-
19
- def apply_mask(tensor, mask):
20
- return tensor * mask if mask is not None else tensor
21
-
22
- class ResBlockBase(torch.nn.Module):
23
- def __init__(self, channels, kernel_size, dilations):
24
- super(ResBlockBase, self).__init__()
25
-
26
- self.convs1 = torch.nn.ModuleList([create_conv1d_layer(channels, kernel_size, d) for d in dilations])
27
- self.convs1.apply(init_weights)
28
-
29
- self.convs2 = torch.nn.ModuleList([create_conv1d_layer(channels, kernel_size, 1) for _ in dilations])
30
- self.convs2.apply(init_weights)
31
-
32
- def forward(self, x, x_mask=None):
33
- for c1, c2 in zip(self.convs1, self.convs2):
34
- x = c2(apply_mask(torch.nn.functional.leaky_relu(c1(apply_mask(torch.nn.functional.leaky_relu(x, LRELU_SLOPE), x_mask)), LRELU_SLOPE), x_mask)) + x
35
-
36
- return apply_mask(x, x_mask)
37
-
38
- def remove_weight_norm(self):
39
- for conv in self.convs1 + self.convs2:
40
- remove_weight_norm(conv)
41
-
42
- class ResBlock(ResBlockBase):
43
- def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
44
- super(ResBlock, self).__init__(channels, kernel_size, dilation)
45
-
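- # Log, Flip and ElementwiseAffine are small invertible transforms for normalizing flows; in the forward direction they return the transformed tensor together with its log-determinant term.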
46
- class Log(torch.nn.Module):
47
- def forward(self, x, x_mask, reverse=False, **kwargs):
48
- if not reverse:
49
- y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
50
- return y, torch.sum(-y, [1, 2])
51
- else: return torch.exp(x) * x_mask
52
-
53
- class Flip(torch.nn.Module):
54
- def forward(self, x, *args, reverse=False, **kwargs):
55
- x = torch.flip(x, [1])
56
-
57
- if not reverse: return x, torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
58
- else: return x
59
-
60
- class ElementwiseAffine(torch.nn.Module):
61
- def __init__(self, channels):
62
- super().__init__()
63
- self.channels = channels
64
- self.m = torch.nn.Parameter(torch.zeros(channels, 1))
65
- self.logs = torch.nn.Parameter(torch.zeros(channels, 1))
66
-
67
- def forward(self, x, x_mask, reverse=False, **kwargs):
68
- if not reverse: return ((self.m + torch.exp(self.logs) * x) * x_mask), torch.sum(self.logs * x_mask, [1, 2])
69
- else: return (x - self.m) * torch.exp(-self.logs) * x_mask
70
-
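- # Flow used by the synthesizer: n_flows coupling layers, each followed by a channel Flip; forward applies them in order, reverse (inference) walks them backwards.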
71
- class ResidualCouplingBlock(torch.nn.Module):
72
- def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0):
73
- super(ResidualCouplingBlock, self).__init__()
74
- self.channels = channels
75
- self.hidden_channels = hidden_channels
76
- self.kernel_size = kernel_size
77
- self.dilation_rate = dilation_rate
78
- self.n_layers = n_layers
79
- self.n_flows = n_flows
80
- self.gin_channels = gin_channels
81
- self.flows = torch.nn.ModuleList()
82
-
83
- for _ in range(n_flows):
84
- self.flows.append(ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
85
- self.flows.append(Flip())
86
-
87
- def forward(self, x, x_mask, g = None, reverse = False):
88
- if not reverse:
89
- for flow in self.flows:
90
- x, _ = flow(x, x_mask, g=g, reverse=reverse)
91
- else:
92
- for flow in reversed(self.flows):
93
- x = flow.forward(x, x_mask, g=g, reverse=reverse)
94
-
95
- return x
96
-
97
- def remove_weight_norm(self):
98
- for i in range(self.n_flows):
99
- self.flows[i * 2].remove_weight_norm()
100
-
101
- def __prepare_scriptable__(self):
102
- for i in range(self.n_flows):
103
- for hook in self.flows[i * 2]._forward_pre_hooks.values():
104
- if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(self.flows[i * 2])
105
-
106
- return self
107
-
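- # One coupling step: the first half of the channels drives a WaveNet that predicts a mean (and optionally a log-scale) applied to the second half, keeping the transform invertible.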
108
- class ResidualCouplingLayer(torch.nn.Module):
109
- def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=0, mean_only=False):
110
- assert channels % 2 == 0, "Channels/2"
111
- super().__init__()
112
- self.channels = channels
113
- self.hidden_channels = hidden_channels
114
- self.kernel_size = kernel_size
115
- self.dilation_rate = dilation_rate
116
- self.n_layers = n_layers
117
- self.half_channels = channels // 2
118
- self.mean_only = mean_only
119
-
120
- self.pre = torch.nn.Conv1d(self.half_channels, hidden_channels, 1)
121
- self.enc = WaveNet(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
122
- self.post = torch.nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
123
-
124
- self.post.weight.data.zero_()
125
- self.post.bias.data.zero_()
126
-
127
- def forward(self, x, x_mask, g=None, reverse=False):
128
- x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
129
- stats = self.post(self.enc((self.pre(x0) * x_mask), x_mask, g=g)) * x_mask
130
-
131
- if not self.mean_only: m, logs = torch.split(stats, [self.half_channels] * 2, 1)
132
- else:
133
- m = stats
134
- logs = torch.zeros_like(m)
135
-
136
- if not reverse: return torch.cat([x0, (m + x1 * torch.exp(logs) * x_mask)], 1), torch.sum(logs, [1, 2])
137
- else: return torch.cat([x0, ((x1 - m) * torch.exp(-logs) * x_mask)], 1)
138
-
139
- def remove_weight_norm(self):
140
- self.enc.remove_weight_norm()
main/library/algorithm/separator.py DELETED
@@ -1,330 +0,0 @@
1
- import os
2
- import sys
3
- import time
4
- import yaml
5
- import torch
6
- import codecs
7
- import hashlib
8
- import logging
9
- import platform
10
- import warnings
11
- import requests
12
- import onnxruntime
13
-
14
- from importlib import metadata, import_module
15
-
16
- now_dir = os.getcwd()
17
- sys.path.append(now_dir)
18
-
19
- from main.configs.config import Config
20
- translations = Config().translations
21
-
22
- class Separator:
23
- def __init__(self, logger=logging.getLogger(__name__), log_level=logging.INFO, log_formatter=None, model_file_dir="assets/models/uvr5", output_dir=None, output_format="wav", output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=44100, mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False}, demucs_params={"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True}):
24
- self.logger = logger
25
- self.log_level = log_level
26
- self.log_formatter = log_formatter
27
- self.log_handler = logging.StreamHandler()
28
-
29
- if self.log_formatter is None: self.log_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(module)s - %(message)s")
30
- self.log_handler.setFormatter(self.log_formatter)
31
-
32
- if not self.logger.hasHandlers(): self.logger.addHandler(self.log_handler)
33
- if log_level > logging.DEBUG: warnings.filterwarnings("ignore")
34
-
35
- self.logger.info(translations["separator_info"].format(output_dir=output_dir, output_format=output_format))
36
- self.model_file_dir = model_file_dir
37
-
38
- if output_dir is None:
39
- output_dir = now_dir
40
- self.logger.info(translations["output_dir_is_none"])
41
-
42
- self.output_dir = output_dir
43
-
44
- os.makedirs(self.model_file_dir, exist_ok=True)
45
- os.makedirs(self.output_dir, exist_ok=True)
46
-
47
- self.output_format = output_format
48
- self.output_bitrate = output_bitrate
49
-
50
- if self.output_format is None: self.output_format = "wav"
51
- self.normalization_threshold = normalization_threshold
52
- if normalization_threshold <= 0 or normalization_threshold > 1: raise ValueError(translations[">0or=1"])
53
-
54
- self.output_single_stem = output_single_stem
55
- if output_single_stem is not None: self.logger.debug(translations["output_single"].format(output_single_stem=output_single_stem))
56
-
57
- self.invert_using_spec = invert_using_spec
58
- if self.invert_using_spec: self.logger.debug(translations["step2"])
59
-
60
- self.sample_rate = int(sample_rate)
61
- self.arch_specific_params = {"MDX": mdx_params, "Demucs": demucs_params}
62
- self.torch_device = None
63
- self.torch_device_cpu = None
64
- self.torch_device_mps = None
65
- self.onnx_execution_provider = None
66
- self.model_instance = None
67
- self.model_is_uvr_vip = False
68
- self.model_friendly_name = None
69
- self.setup_accelerated_inferencing_device()
70
-
71
- def setup_accelerated_inferencing_device(self):
72
- system_info = self.get_system_info()
73
- self.log_onnxruntime_packages()
74
- self.setup_torch_device(system_info)
75
-
76
- def get_system_info(self):
77
- os_name = platform.system()
78
- os_version = platform.version()
79
- self.logger.info(f"{translations['os']}: {os_name} {os_version}")
80
- system_info = platform.uname()
81
- self.logger.info(translations["platform_info"].format(system_info=system_info, node=system_info.node, release=system_info.release, machine=system_info.machine, processor=system_info.processor))
82
- python_version = platform.python_version()
83
- self.logger.info(f"{translations['name_ver'].format(name='python')}: {python_version}")
84
- pytorch_version = torch.__version__
85
- self.logger.info(f"{translations['name_ver'].format(name='pytorch')}: {pytorch_version}")
86
-
87
- return system_info
88
-
89
- def log_onnxruntime_packages(self):
90
- onnxruntime_gpu_package = self.get_package_distribution("onnxruntime-gpu")
91
- onnxruntime_cpu_package = self.get_package_distribution("onnxruntime")
92
-
93
- if onnxruntime_gpu_package is not None: self.logger.info(f"{translations['install_onnx'].format(pu='GPU')}: {onnxruntime_gpu_package.version}")
94
- if onnxruntime_cpu_package is not None: self.logger.info(f"{translations['install_onnx'].format(pu='CPU')}: {onnxruntime_cpu_package.version}")
95
-
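- # Device selection: prefer CUDA, then Apple MPS on arm processors, otherwise CPU, and pick the matching ONNX Runtime execution provider when it is available.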
96
- def setup_torch_device(self, system_info):
97
- hardware_acceleration_enabled = False
98
- ort_providers = onnxruntime.get_available_providers()
99
- self.torch_device_cpu = torch.device("cpu")
100
-
101
- if torch.cuda.is_available():
102
- self.configure_cuda(ort_providers)
103
- hardware_acceleration_enabled = True
104
- elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and system_info.processor == "arm":
105
- self.configure_mps(ort_providers)
106
- hardware_acceleration_enabled = True
107
-
108
- if not hardware_acceleration_enabled:
109
- self.logger.info(translations["running_in_cpu"])
110
- self.torch_device = self.torch_device_cpu
111
- self.onnx_execution_provider = ["CPUExecutionProvider"]
112
-
113
- def configure_cuda(self, ort_providers):
114
- self.logger.info(translations["running_in_cuda"])
115
- self.torch_device = torch.device("cuda")
116
-
117
- if "CUDAExecutionProvider" in ort_providers:
118
- self.logger.info(translations["onnx_have"].format(have='CUDAExecutionProvider'))
119
- self.onnx_execution_provider = ["CUDAExecutionProvider"]
120
- else: self.logger.warning(translations["onnx_not_have"].format(have='CUDAExecutionProvider'))
121
-
122
- def configure_mps(self, ort_providers):
123
- self.logger.info(translations["set_torch_mps"])
124
- self.torch_device_mps = torch.device("mps")
125
- self.torch_device = self.torch_device_mps
126
-
127
- if "CoreMLExecutionProvider" in ort_providers:
128
- self.logger.info(translations["onnx_have"].format(have='CoreMLExecutionProvider'))
129
- self.onnx_execution_provider = ["CoreMLExecutionProvider"]
130
- else: self.logger.warning(translations["onnx_not_have"].format(have='CoreMLExecutionProvider'))
131
-
132
- def get_package_distribution(self, package_name):
133
- try:
134
- return metadata.distribution(package_name)
135
- except metadata.PackageNotFoundError:
136
- self.logger.debug(translations["python_not_install"].format(package_name=package_name))
137
- return None
138
-
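- # Models are identified by an MD5 of roughly the last 10 MB of the file (whole file as a fallback); the hash is later looked up in the model-data JSON.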
139
- def get_model_hash(self, model_path):
140
- self.logger.debug(translations["hash"].format(model_path=model_path))
141
-
142
- try:
143
- with open(model_path, "rb") as f:
144
- f.seek(-10000 * 1024, 2)
145
- return hashlib.md5(f.read()).hexdigest()
146
- except IOError as e:
147
- self.logger.error(translations["ioerror"].format(e=e))
148
- return hashlib.md5(open(model_path, "rb").read()).hexdigest()
149
-
150
- def download_file_if_not_exists(self, url, output_path):
151
- if os.path.isfile(output_path):
152
- self.logger.debug(translations["cancel_download"].format(output_path=output_path))
153
- return
154
-
155
- self.logger.debug(translations["download_model"].format(url=url, output_path=output_path))
156
- response = requests.get(url, stream=True, timeout=300)
157
-
158
- if response.status_code == 200:
159
- from tqdm import tqdm
160
-
161
- progress_bar = tqdm(total=int(response.headers.get("content-length", 0)), ncols=100, unit="byte")
162
-
163
- with open(output_path, "wb") as f:
164
- for chunk in response.iter_content(chunk_size=8192):
165
- progress_bar.update(len(chunk))
166
- f.write(chunk)
167
-
168
- progress_bar.close()
169
- else: raise RuntimeError(translations["download_error"].format(url=url, status_code=response.status_code))
170
-
171
- def print_uvr_vip_message(self):
172
- if self.model_is_uvr_vip:
173
- self.logger.warning(translations["vip_model"].format(model_friendly_name=self.model_friendly_name))
174
- self.logger.warning(translations["vip_print"])
175
-
176
- def list_supported_model_files(self):
177
- response = requests.get(codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/enj/znva/wfba/hie_zbqryf.wfba", "rot13"))
178
- response.raise_for_status()
179
- model_downloads_list = response.json()
180
- self.logger.debug(translations["load_download_json"])
181
-
182
- return {"MDX": {**model_downloads_list["mdx_download_list"], **model_downloads_list["mdx_download_vip_list"]}, "Demucs": {key: value for key, value in model_downloads_list["demucs_download_list"].items() if key.startswith("Demucs v4")}}
183
-
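- # Resolves a model filename against the supported-model lists: plain string entries are downloaded directly, dict entries (multi-file models) also fetch their extra files and optional YAML config.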
184
- def download_model_files(self, model_filename):
185
- model_path = os.path.join(self.model_file_dir, model_filename)
186
- supported_model_files_grouped = self.list_supported_model_files()
187
-
188
- yaml_config_filename = None
189
- self.logger.debug(translations["search_model"].format(model_filename=model_filename))
190
-
191
- for model_type, model_list in supported_model_files_grouped.items():
192
- for model_friendly_name, model_download_list in model_list.items():
193
- self.model_is_uvr_vip = "VIP" in model_friendly_name
194
- model_repo_url_prefix = codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/hie5_zbqryf", "rot13")
195
-
196
- if isinstance(model_download_list, str) and model_download_list == model_filename:
197
- self.logger.debug(translations["single_model"].format(model_friendly_name=model_friendly_name))
198
- self.model_friendly_name = model_friendly_name
199
-
200
- try:
201
- self.download_file_if_not_exists(f"{model_repo_url_prefix}/MDX/{model_filename}", model_path)
202
- except RuntimeError:
203
- self.logger.warning(translations["not_found_model"])
204
- self.download_file_if_not_exists(f"{model_repo_url_prefix}/Demucs/{model_filename}", model_path)
205
-
206
- self.print_uvr_vip_message()
207
- self.logger.debug(translations["single_model_path"].format(model_path=model_path))
208
-
209
- return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
210
- elif isinstance(model_download_list, dict):
211
- this_model_matches_input_filename = False
212
-
213
- for file_name, file_url in model_download_list.items():
214
- if file_name == model_filename or file_url == model_filename:
215
- self.logger.debug(translations["find_model"].format(model_filename=model_filename, model_friendly_name=model_friendly_name))
216
- this_model_matches_input_filename = True
217
-
218
- if this_model_matches_input_filename:
219
- self.logger.debug(translations["find_models"].format(model_friendly_name=model_friendly_name))
220
- self.model_friendly_name = model_friendly_name
221
- self.print_uvr_vip_message()
222
-
223
- for config_key, config_value in model_download_list.items():
224
- self.logger.debug(f"{translations['find_path']}: {config_key} -> {config_value}")
225
-
226
- if config_value.startswith("http"): self.download_file_if_not_exists(config_value, os.path.join(self.model_file_dir, config_key))
227
- elif config_key.endswith(".ckpt"):
228
- try:
229
- self.download_file_if_not_exists(f"{model_repo_url_prefix}/Demucs/{config_key}", os.path.join(self.model_file_dir, config_key))
230
- except RuntimeError:
231
- self.logger.warning(translations["not_found_model_warehouse"])
232
-
233
- if model_filename.endswith(".yaml"):
234
- self.logger.warning(translations["yaml_warning"].format(model_filename=model_filename))
235
- self.logger.warning(translations["yaml_warning_2"].format(config_key=config_key))
236
- self.logger.warning(translations["yaml_warning_3"])
237
-
238
- model_filename = config_key
239
- model_path = os.path.join(self.model_file_dir, f"{model_filename}")
240
-
241
- yaml_config_filename = config_value
242
- yaml_config_filepath = os.path.join(self.model_file_dir, yaml_config_filename)
243
-
244
- try:
245
- self.download_file_if_not_exists(f"{model_repo_url_prefix}/mdx_c_configs/{yaml_config_filename}", yaml_config_filepath)
246
- except RuntimeError:
247
- self.logger.debug(translations["yaml_debug"])
248
- else: self.download_file_if_not_exists(f"{model_repo_url_prefix}/Demucs/{config_value}", os.path.join(self.model_file_dir, config_value))
249
-
250
- self.logger.debug(translations["download_model_friendly"].format(model_friendly_name=model_friendly_name, model_path=model_path))
251
- return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
252
-
253
- raise ValueError(translations["not_found_model_2"].format(model_filename=model_filename))
254
-
255
- def load_model_data_from_yaml(self, yaml_config_filename):
256
- model_data_yaml_filepath = os.path.join(self.model_file_dir, yaml_config_filename) if not os.path.exists(yaml_config_filename) else yaml_config_filename
257
- self.logger.debug(translations["load_yaml"].format(model_data_yaml_filepath=model_data_yaml_filepath))
258
-
259
- model_data = yaml.load(open(model_data_yaml_filepath, encoding="utf-8"), Loader=yaml.FullLoader)
260
- self.logger.debug(translations["load_yaml_2"].format(model_data=model_data))
261
-
262
- if "roformer" in model_data_yaml_filepath: model_data["is_roformer"] = True
263
- return model_data
264
-
265
- def load_model_data_using_hash(self, model_path):
266
- self.logger.debug(translations["hash_md5"])
267
- model_hash = self.get_model_hash(model_path)
268
-
269
- self.logger.debug(translations["model_hash"].format(model_path=model_path, model_hash=model_hash))
270
- mdx_model_data_path = codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/enj/znva/wfba/zbqry_qngn.wfba", "rot13")
271
- self.logger.debug(translations["mdx_data"].format(mdx_model_data_path=mdx_model_data_path))
272
-
273
- response = requests.get(mdx_model_data_path)
274
- response.raise_for_status()
275
-
276
- mdx_model_data_object = response.json()
277
- self.logger.debug(translations["load_mdx"])
278
-
279
- if model_hash in mdx_model_data_object: model_data = mdx_model_data_object[model_hash]
280
- else: raise ValueError(translations["model_not_support"].format(model_hash=model_hash))
281
-
282
- self.logger.debug(translations["uvr_json"].format(model_hash=model_hash, model_data=model_data))
283
- return model_data
284
-
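- # End-to-end model loading: download the files, read the model data (YAML config or hash lookup), then import and instantiate the matching MDX or Demucs separator class.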
285
- def load_model(self, model_filename):
286
- self.logger.info(translations["loading_model"].format(model_filename=model_filename))
287
- load_model_start_time = time.perf_counter()
288
- model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
289
- self.logger.debug(translations["download_model_friendly_2"].format(model_friendly_name=model_friendly_name, model_path=model_path))
290
-
291
- if model_path.lower().endswith(".yaml"): yaml_config_filename = model_path
292
-
293
- common_params = {"logger": self.logger, "log_level": self.log_level, "torch_device": self.torch_device, "torch_device_cpu": self.torch_device_cpu, "torch_device_mps": self.torch_device_mps, "onnx_execution_provider": self.onnx_execution_provider, "model_name": model_filename.split(".")[0], "model_path": model_path, "model_data": self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename is not None else self.load_model_data_using_hash(model_path), "output_format": self.output_format, "output_bitrate": self.output_bitrate, "output_dir": self.output_dir, "normalization_threshold": self.normalization_threshold, "output_single_stem": self.output_single_stem, "invert_using_spec": self.invert_using_spec, "sample_rate": self.sample_rate}
294
- separator_classes = {"MDX": "mdx_separator.MDXSeparator", "Demucs": "demucs_separator.DemucsSeparator"}
295
-
296
- if model_type not in self.arch_specific_params or model_type not in separator_classes: raise ValueError(translations["model_type_not_support"].format(model_type=model_type))
297
- if model_type == "Demucs" and sys.version_info < (3, 10): raise Exception(translations["demucs_not_support_python<3.10"])
298
-
299
- self.logger.debug(f"{translations['import_module']} {model_type}: {separator_classes[model_type]}")
300
- module_name, class_name = separator_classes[model_type].split(".")
301
- separator_class = getattr(import_module(f"main.library.architectures.{module_name}"), class_name)
302
-
303
- self.logger.debug(f"{translations['initialization']} {model_type}: {separator_class}")
304
- self.model_instance = separator_class(common_config=common_params, arch_config=self.arch_specific_params[model_type])
305
-
306
- self.logger.debug(translations["loading_model_success"])
307
- self.logger.info(f"{translations['loading_model_duration']}: {time.strftime('%H:%M:%S', time.gmtime(int(time.perf_counter() - load_model_start_time)))}")
308
-
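- # Runs the loaded separator on one file, clears GPU cache and per-file state afterwards, and returns the written stem files.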
309
- def separate(self, audio_file_path):
310
- self.logger.info(f"{translations['starting_separator']}: {audio_file_path}")
311
- separate_start_time = time.perf_counter()
312
-
313
- self.logger.debug(translations["normalization"].format(normalization_threshold=self.normalization_threshold))
314
- output_files = self.model_instance.separate(audio_file_path)
315
-
316
- self.model_instance.clear_gpu_cache()
317
- self.model_instance.clear_file_specific_paths()
318
-
319
- self.print_uvr_vip_message()
320
-
321
- self.logger.debug(translations["separator_success_3"])
322
- self.logger.info(f"{translations['separator_duration']}: {time.strftime('%H:%M:%S', time.gmtime(int(time.perf_counter() - separate_start_time)))}")
323
- return output_files
324
-
325
- def download_model_and_data(self, model_filename):
326
- self.logger.info(translations["loading_separator_model"].format(model_filename=model_filename))
327
- model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
328
-
329
- if model_path.lower().endswith(".yaml"): yaml_config_filename = model_path
330
- self.logger.info(translations["downloading_model"].format(model_type=model_type, model_friendly_name=model_friendly_name, model_path=model_path, model_data_dict_size=len(self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename is not None else self.load_model_data_using_hash(model_path))))
main/library/algorithm/synthesizers.py DELETED
@@ -1,450 +0,0 @@
1
- import os
2
- import sys
3
- import math
4
- import torch
5
-
6
- import torch.nn.functional as F
7
- import torch.utils.checkpoint as checkpoint
8
-
9
- from torch.nn.utils import remove_weight_norm
10
- from torch.nn.utils.parametrizations import weight_norm
11
-
12
- sys.path.append(os.getcwd())
13
-
14
- from .modules import WaveNet
15
- from .refinegan import RefineGANGenerator
16
- from .mrf_hifigan import HiFiGANMRFGenerator
17
- from .residuals import ResidualCouplingBlock, ResBlock, LRELU_SLOPE
18
- from .commons import init_weights, slice_segments, rand_slice_segments, sequence_mask, convert_pad_shape
19
-
20
- class Generator(torch.nn.Module):
21
- def __init__(self, initial_channel, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
22
- super(Generator, self).__init__()
23
- self.num_kernels = len(resblock_kernel_sizes)
24
- self.num_upsamples = len(upsample_rates)
25
- self.conv_pre = torch.nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
26
- self.ups_and_resblocks = torch.nn.ModuleList()
27
-
28
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
29
- self.ups_and_resblocks.append(weight_norm(torch.nn.ConvTranspose1d(upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), k, u, padding=(k - u) // 2)))
30
- ch = upsample_initial_channel // (2 ** (i + 1))
31
-
32
- for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
33
- self.ups_and_resblocks.append(ResBlock(ch, k, d))
34
-
35
- self.conv_post = torch.nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)
36
- self.ups_and_resblocks.apply(init_weights)
37
- if gin_channels != 0: self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)
38
-
39
- def forward(self, x, g = None):
40
- x = self.conv_pre(x)
41
- if g is not None: x = x + self.cond(g)
42
- resblock_idx = 0
43
-
44
- for _ in range(self.num_upsamples):
45
- x = self.ups_and_resblocks[resblock_idx](F.leaky_relu(x, LRELU_SLOPE))
46
- resblock_idx += 1
47
- xs = 0
48
-
49
- for _ in range(self.num_kernels):
50
- xs += self.ups_and_resblocks[resblock_idx](x)
51
- resblock_idx += 1
52
-
53
- x = xs / self.num_kernels
54
-
55
- return torch.tanh(self.conv_post(F.leaky_relu(x)))
56
-
57
- def __prepare_scriptable__(self):
58
- for l in self.ups_and_resblocks:
59
- for hook in l._forward_pre_hooks.values():
60
- if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(l)
61
-
62
- return self
63
-
64
- def remove_weight_norm(self):
65
- for l in self.ups_and_resblocks:
66
- remove_weight_norm(l)
67
-
68
- class SineGen(torch.nn.Module):
69
- def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0, flag_for_pulse=False):
70
- super(SineGen, self).__init__()
71
- self.sine_amp = sine_amp
72
- self.noise_std = noise_std
73
- self.harmonic_num = harmonic_num
74
- self.dim = self.harmonic_num + 1
75
- self.sample_rate = samp_rate
76
- self.voiced_threshold = voiced_threshold
77
-
78
- def _f02uv(self, f0):
79
- return torch.ones_like(f0) * (f0 > self.voiced_threshold)
80
-
81
- def forward(self, f0, upp):
82
- with torch.no_grad():
83
- f0 = f0[:, None].transpose(1, 2)
84
-
85
- f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
86
- f0_buf[:, :, 0] = f0[:, :, 0]
87
- f0_buf[:, :, 1:] = (f0_buf[:, :, 0:1] * torch.arange(2, self.harmonic_num + 2, device=f0.device)[None, None, :])
88
-
89
- rad_values = (f0_buf / float(self.sample_rate)) % 1
90
- rand_ini = torch.rand(f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device)
91
- rand_ini[:, 0] = 0
92
- rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
93
-
94
- tmp_over_one = torch.cumsum(rad_values, 1)
95
- tmp_over_one *= upp
96
- tmp_over_one = F.interpolate(tmp_over_one.transpose(2, 1), scale_factor=float(upp), mode="linear", align_corners=True).transpose(2, 1)
97
-
98
- rad_values = F.interpolate(rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest").transpose(2, 1)
99
- tmp_over_one %= 1
100
-
101
- cumsum_shift = torch.zeros_like(rad_values)
102
- cumsum_shift[:, 1:, :] = ((tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0) * -1.0
103
-
104
- uv = F.interpolate(self._f02uv(f0).transpose(2, 1), scale_factor=float(upp), mode="nearest").transpose(2, 1)
105
- sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi) * self.sine_amp
106
- sine_waves = sine_waves * uv + ((uv * self.noise_std + (1 - uv) * self.sine_amp / 3) * torch.randn_like(sine_waves))
107
-
108
- return sine_waves
109
-
110
- class SourceModuleHnNSF(torch.nn.Module):
111
- def __init__(self, sample_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0):
112
- super(SourceModuleHnNSF, self).__init__()
113
- self.sine_amp = sine_amp
114
- self.noise_std = add_noise_std
115
- self.l_sin_gen = SineGen(sample_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod)
116
- self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
117
- self.l_tanh = torch.nn.Tanh()
118
-
119
- def forward(self, x, upsample_factor = 1):
120
- return self.l_tanh(self.l_linear(self.l_sin_gen(x, upsample_factor).to(dtype=self.l_linear.weight.dtype)))
121
-
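- # NSF variant of the HiFi-GAN generator: a SourceModuleHnNSF sine excitation is produced at the output rate and re-injected after every upsampling stage through strided noise_convs.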
122
- class GeneratorNSF(torch.nn.Module):
123
- def __init__(self, initial_channel, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels, sr, checkpointing = False):
124
- super(GeneratorNSF, self).__init__()
125
- self.num_kernels = len(resblock_kernel_sizes)
126
- self.num_upsamples = len(upsample_rates)
127
- self.f0_upsamp = torch.nn.Upsample(scale_factor=math.prod(upsample_rates))
128
- self.m_source = SourceModuleHnNSF(sample_rate=sr, harmonic_num=0)
129
- self.conv_pre = torch.nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
130
- self.checkpointing = checkpointing
131
- self.ups = torch.nn.ModuleList()
132
- self.noise_convs = torch.nn.ModuleList()
133
- channels = [upsample_initial_channel // (2 ** (i + 1)) for i in range(len(upsample_rates))]
134
- stride_f0s = [math.prod(upsample_rates[i + 1 :]) if i + 1 < len(upsample_rates) else 1 for i in range(len(upsample_rates))]
135
-
136
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
137
- self.ups.append(weight_norm(torch.nn.ConvTranspose1d(upsample_initial_channel // (2**i), channels[i], k, u, padding=(k - u) // 2)))
138
- self.noise_convs.append(torch.nn.Conv1d(1, channels[i], kernel_size=(stride_f0s[i] * 2 if stride_f0s[i] > 1 else 1), stride=stride_f0s[i], padding=(stride_f0s[i] // 2 if stride_f0s[i] > 1 else 0)))
139
-
140
- self.resblocks = torch.nn.ModuleList([ResBlock(channels[i], k, d) for i in range(len(self.ups)) for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes)])
141
- self.conv_post = torch.nn.Conv1d(channels[-1], 1, 7, 1, padding=3, bias=False)
142
- self.ups.apply(init_weights)
143
-
144
- if gin_channels != 0: self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)
145
-
146
- self.upp = math.prod(upsample_rates)
147
- self.lrelu_slope = LRELU_SLOPE
148
-
149
- def forward(self, x, f0, g = None):
150
- har_source = self.m_source(f0, self.upp).transpose(1, 2)
151
- x = self.conv_pre(x)
152
- if g is not None: x = x + self.cond(g)
153
-
154
- for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
155
- x = F.leaky_relu(x, self.lrelu_slope)
156
- x = checkpoint.checkpoint(ups, x, use_reentrant=False) if self.training and self.checkpointing else ups(x)
157
- x += noise_convs(har_source)
158
-
159
- def resblock_forward(x, blocks):
160
- return sum(block(x) for block in blocks) / len(blocks)
161
-
162
- blocks = self.resblocks[i * self.num_kernels:(i + 1) * self.num_kernels]
163
- x = checkpoint.checkpoint(resblock_forward, x, blocks, use_reentrant=False) if self.training and self.checkpointing else resblock_forward(x, blocks)
164
-
165
- return torch.tanh(self.conv_post(F.leaky_relu(x)))
166
-
167
- def remove_weight_norm(self):
168
- for l in self.ups:
169
- remove_weight_norm(l)
170
-
171
- for l in self.resblocks:
172
- l.remove_weight_norm()
173
-
174
- class LayerNorm(torch.nn.Module):
175
- def __init__(self, channels, eps=1e-5):
176
- super().__init__()
177
- self.eps = eps
178
- self.gamma = torch.nn.Parameter(torch.ones(channels))
179
- self.beta = torch.nn.Parameter(torch.zeros(channels))
180
-
181
- def forward(self, x):
182
- x = x.transpose(1, -1)
183
- return F.layer_norm(x, (x.size(-1),), self.gamma, self.beta, self.eps).transpose(1, -1)
184
-
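- # Multi-head attention with optional relative positional embeddings (window_size), proximal bias and block masking, used by the Encoder below.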
185
- class MultiHeadAttention(torch.nn.Module):
186
- def __init__(self, channels, out_channels, n_heads, p_dropout=0.0, window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
187
- super().__init__()
188
- assert channels % n_heads == 0
189
- self.channels = channels
190
- self.out_channels = out_channels
191
- self.n_heads = n_heads
192
- self.p_dropout = p_dropout
193
- self.window_size = window_size
194
- self.heads_share = heads_share
195
- self.block_length = block_length
196
- self.proximal_bias = proximal_bias
197
- self.proximal_init = proximal_init
198
- self.attn = None
199
- self.k_channels = channels // n_heads
200
- self.conv_q = torch.nn.Conv1d(channels, channels, 1)
201
- self.conv_k = torch.nn.Conv1d(channels, channels, 1)
202
- self.conv_v = torch.nn.Conv1d(channels, channels, 1)
203
- self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
204
- self.drop = torch.nn.Dropout(p_dropout)
205
-
206
- if window_size is not None:
207
- n_heads_rel = 1 if heads_share else n_heads
208
- rel_stddev = self.k_channels**-0.5
209
- self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
210
- self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
211
-
212
- torch.nn.init.xavier_uniform_(self.conv_q.weight)
213
- torch.nn.init.xavier_uniform_(self.conv_k.weight)
214
- torch.nn.init.xavier_uniform_(self.conv_v.weight)
215
-
216
- if proximal_init:
217
- with torch.no_grad():
218
- self.conv_k.weight.copy_(self.conv_q.weight)
219
- self.conv_k.bias.copy_(self.conv_q.bias)
220
-
221
- def forward(self, x, c, attn_mask=None):
222
- q, k, v = self.conv_q(x), self.conv_k(c), self.conv_v(c)
223
- x, self.attn = self.attention(q, k, v, mask=attn_mask)
224
- return self.conv_o(x)
225
-
226
- def attention(self, query, key, value, mask=None):
227
- b, d, t_s, t_t = (*key.size(), query.size(2))
228
- query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
229
- key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
230
-
231
- scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
232
- if self.window_size is not None:
233
- assert (t_s == t_t), "(t_s == t_t)"
234
- scores = scores + self._relative_position_to_absolute_position(self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), self._get_relative_embeddings(self.emb_rel_k, t_s)))
235
-
236
- if self.proximal_bias:
237
- assert t_s == t_t, "t_s == t_t"
238
- scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
239
-
240
- if mask is not None:
241
- scores = scores.masked_fill(mask == 0, -1e4)
242
- if self.block_length is not None:
243
- assert (t_s == t_t), "(t_s == t_t)"
244
- scores = scores.masked_fill((torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)) == 0, -1e4)
245
-
246
- p_attn = self.drop(F.softmax(scores, dim=-1))
247
- output = torch.matmul(p_attn, value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3))
248
-
249
- if self.window_size is not None: output = output + self._matmul_with_relative_values(self._absolute_position_to_relative_position(p_attn), self._get_relative_embeddings(self.emb_rel_v, t_s))
250
- return (output.transpose(2, 3).contiguous().view(b, d, t_t)), p_attn
251
-
252
- def _matmul_with_relative_values(self, x, y):
253
- return torch.matmul(x, y.unsqueeze(0))
254
-
255
- def _matmul_with_relative_keys(self, x, y):
256
- return torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
257
-
258
- def _get_relative_embeddings(self, relative_embeddings, length):
259
- pad_length = max(length - (self.window_size + 1), 0)
260
- slice_start_position = max((self.window_size + 1) - length, 0)
261
- return (F.pad(relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) if pad_length > 0 else relative_embeddings)[:, slice_start_position:(slice_start_position + 2 * length - 1)]
262
-
263
- def _relative_position_to_absolute_position(self, x):
264
- batch, heads, length, _ = x.size()
265
- return F.pad(F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])).view([batch, heads, length * 2 * length]), convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])).view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
266
-
267
- def _absolute_position_to_relative_position(self, x):
268
- batch, heads, length, _ = x.size()
269
- return F.pad(F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])).view([batch, heads, length**2 + length * (length - 1)]), convert_pad_shape([[0, 0], [0, 0], [length, 0]])).view([batch, heads, length, 2 * length])[:, :, :, 1:]
270
-
271
- def _attention_bias_proximal(self, length):
272
- r = torch.arange(length, dtype=torch.float32)
273
- return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs((torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)))), 0), 0)
274
-
275
- class FFN(torch.nn.Module):
276
- def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0, activation=None, causal=False):
277
- super().__init__()
278
- self.in_channels = in_channels
279
- self.out_channels = out_channels
280
- self.filter_channels = filter_channels
281
- self.kernel_size = kernel_size
282
- self.p_dropout = p_dropout
283
- self.activation = activation
284
- self.causal = causal
285
- self.padding = self._causal_padding if causal else self._same_padding
286
- self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size)
287
- self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size)
288
- self.drop = torch.nn.Dropout(p_dropout)
289
-
290
- def forward(self, x, x_mask):
291
- x = self.conv_1(self.padding(x * x_mask))
292
- return self.conv_2(self.padding(self.drop(((x * torch.sigmoid(1.702 * x)) if self.activation == "gelu" else torch.relu(x))) * x_mask)) * x_mask
293
-
294
- def _causal_padding(self, x):
295
- if self.kernel_size == 1: return x
296
-
297
- return F.pad(x, convert_pad_shape([[0, 0], [0, 0], [(self.kernel_size - 1), 0]]))
298
-
299
- def _same_padding(self, x):
300
- if self.kernel_size == 1: return x
301
-
302
- return F.pad(x, convert_pad_shape([[0, 0], [0, 0], [((self.kernel_size - 1) // 2), (self.kernel_size // 2)]]))
303
-
304
- class Encoder(torch.nn.Module):
305
- def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.0, window_size=10, **kwargs):
306
- super().__init__()
307
- self.hidden_channels = hidden_channels
308
- self.filter_channels = filter_channels
309
- self.n_heads = n_heads
310
- self.n_layers = n_layers
311
- self.kernel_size = kernel_size
312
- self.p_dropout = p_dropout
313
- self.window_size = window_size
314
- self.drop = torch.nn.Dropout(p_dropout)
315
- self.attn_layers = torch.nn.ModuleList()
316
- self.norm_layers_1 = torch.nn.ModuleList()
317
- self.ffn_layers = torch.nn.ModuleList()
318
- self.norm_layers_2 = torch.nn.ModuleList()
319
-
320
- for _ in range(self.n_layers):
321
- self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
322
- self.norm_layers_1.append(LayerNorm(hidden_channels))
323
-
324
- self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
325
- self.norm_layers_2.append(LayerNorm(hidden_channels))
326
-
327
- def forward(self, x, x_mask):
328
- attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
329
- x = x * x_mask
330
-
331
- for i in range(self.n_layers):
332
- x = self.norm_layers_1[i](x + self.drop(self.attn_layers[i](x, x, attn_mask)))
333
- x = self.norm_layers_2[i](x + self.drop(self.ffn_layers[i](x, x_mask)))
334
-
335
- return x * x_mask
336
-
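- # Embeds the content/phone features (plus a 256-bin coarse pitch embedding when f0 is used), runs the attention Encoder, and projects to the prior mean and log-variance.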
337
- class TextEncoder(torch.nn.Module):
338
- def __init__(self, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, embedding_dim, f0=True):
339
- super(TextEncoder, self).__init__()
340
- self.out_channels = out_channels
341
- self.hidden_channels = hidden_channels
342
- self.filter_channels = filter_channels
343
- self.n_heads = n_heads
344
- self.n_layers = n_layers
345
- self.kernel_size = kernel_size
346
- self.p_dropout = float(p_dropout)
347
- self.emb_phone = torch.nn.Linear(embedding_dim, hidden_channels)
348
- self.lrelu = torch.nn.LeakyReLU(0.1, inplace=True)
349
- if f0: self.emb_pitch = torch.nn.Embedding(256, hidden_channels)
350
- self.encoder = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout))
351
- self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, 1)
352
-
353
- def forward(self, phone, pitch, lengths):
354
- x = self.emb_phone(phone) if pitch is None else (self.emb_phone(phone) + self.emb_pitch(pitch))
355
- x = torch.transpose(self.lrelu((x * math.sqrt(self.hidden_channels))), 1, -1)
356
- x_mask = torch.unsqueeze(sequence_mask(lengths, x.size(2)), 1).to(x.dtype)
357
- m, logs = torch.split((self.proj(self.encoder(x * x_mask, x_mask)) * x_mask), self.out_channels, dim=1)
358
- return m, logs, x_mask
359
-
360
- class PosteriorEncoder(torch.nn.Module):
361
- def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0):
362
- super(PosteriorEncoder, self).__init__()
363
- self.in_channels = in_channels
364
- self.out_channels = out_channels
365
- self.hidden_channels = hidden_channels
366
- self.kernel_size = kernel_size
367
- self.dilation_rate = dilation_rate
368
- self.n_layers = n_layers
369
- self.gin_channels = gin_channels
370
- self.pre = torch.nn.Conv1d(in_channels, hidden_channels, 1)
371
- self.enc = WaveNet(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
372
- self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, 1)
373
-
374
- def forward(self, x, x_lengths, g = None):
375
- x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
376
- m, logs = torch.split((self.proj(self.enc((self.pre(x) * x_mask), x_mask, g=g)) * x_mask), self.out_channels, dim=1)
377
- return ((m + torch.randn_like(m) * torch.exp(logs)) * x_mask), m, logs, x_mask
378
-
379
- def remove_weight_norm(self):
380
- self.enc.remove_weight_norm()
381
-
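- # Top-level VITS-style model: TextEncoder prior, PosteriorEncoder, ResidualCouplingBlock flow, speaker embedding, and a decoder picked by `vocoder` (RefineGAN, MRF HiFi-GAN, NSF HiFi-GAN, or the plain Generator when f0 is unused).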
382
- class Synthesizer(torch.nn.Module):
383
- def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, use_f0, text_enc_hidden_dim=768, vocoder="Default", checkpointing = False, **kwargs):
384
- super(Synthesizer, self).__init__()
385
- self.spec_channels = spec_channels
386
- self.inter_channels = inter_channels
387
- self.hidden_channels = hidden_channels
388
- self.filter_channels = filter_channels
389
- self.n_heads = n_heads
390
- self.n_layers = n_layers
391
- self.kernel_size = kernel_size
392
- self.p_dropout = float(p_dropout)
393
- self.resblock_kernel_sizes = resblock_kernel_sizes
394
- self.resblock_dilation_sizes = resblock_dilation_sizes
395
- self.upsample_rates = upsample_rates
396
- self.upsample_initial_channel = upsample_initial_channel
397
- self.upsample_kernel_sizes = upsample_kernel_sizes
398
- self.segment_size = segment_size
399
- self.gin_channels = gin_channels
400
- self.spk_embed_dim = spk_embed_dim
401
- self.use_f0 = use_f0
402
- self.enc_p = TextEncoder(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout), text_enc_hidden_dim, f0=use_f0)
403
-
404
- if use_f0:
405
- if vocoder == "RefineGAN": self.dec = RefineGANGenerator(sample_rate=sr, upsample_rates=upsample_rates, num_mels=inter_channels, checkpointing=checkpointing)
406
- elif vocoder == "MRF HiFi-GAN": self.dec = HiFiGANMRFGenerator(in_channel=inter_channels, upsample_initial_channel=upsample_initial_channel, upsample_rates=upsample_rates, upsample_kernel_sizes=upsample_kernel_sizes, resblock_kernel_sizes=resblock_kernel_sizes, resblock_dilations=resblock_dilation_sizes, gin_channels=gin_channels, sample_rate=sr, harmonic_num=8, checkpointing=checkpointing)
407
- else: self.dec = GeneratorNSF(inter_channels, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels, sr=sr, checkpointing=checkpointing)
408
- else: self.dec = Generator(inter_channels, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
409
-
410
- self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
411
- self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels)
412
- self.emb_g = torch.nn.Embedding(self.spk_embed_dim, gin_channels)
413
-
414
- def remove_weight_norm(self):
415
- self.dec.remove_weight_norm()
416
- self.flow.remove_weight_norm()
417
- self.enc_q.remove_weight_norm()
418
-
419
- @torch.jit.ignore
420
- def forward(self, phone, phone_lengths, pitch = None, pitchf = None, y = None, y_lengths = None, ds = None):
421
- g = self.emb_g(ds).unsqueeze(-1)
422
- m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
423
-
424
- if y is not None:
425
- z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
426
- z_slice, ids_slice = rand_slice_segments(z, y_lengths, self.segment_size)
427
- return (self.dec(z_slice, slice_segments(pitchf, ids_slice, self.segment_size, 2), g=g) if self.use_f0 else self.dec(z_slice, g=g)), ids_slice, x_mask, y_mask, (z, self.flow(z, y_mask, g=g), m_p, logs_p, m_q, logs_q)
428
- else: return None, None, x_mask, None, (None, None, m_p, logs_p, None, None)
429
-
430
- @torch.jit.export
431
- def infer(self, phone, phone_lengths, pitch = None, nsff0 = None, sid = None, rate = None):
432
- g = self.emb_g(sid).unsqueeze(-1)
433
- m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
434
- z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
435
-
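- # A non-None `rate` drops the leading part of the latent (and of nsff0) so only the trailing fraction of the utterance is decoded.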
436
- if rate is not None:
437
- assert isinstance(rate, torch.Tensor)
438
- head = int(z_p.shape[2] * (1.0 - rate.item()))
439
- z_p = z_p[:, :, head:]
440
- x_mask = x_mask[:, :, head:]
441
- if self.use_f0: nsff0 = nsff0[:, head:]
442
-
443
- if self.use_f0:
444
- z = self.flow(z_p, x_mask, g=g, reverse=True)
445
- o = self.dec(z * x_mask, nsff0, g=g)
446
- else:
447
- z = self.flow(z_p, x_mask, g=g, reverse=True)
448
- o = self.dec(z * x_mask, g=g)
449
-
450
- return o, x_mask, (z, z_p, m_p, logs_p)
main/library/architectures/demucs_separator.py DELETED
@@ -1,160 +0,0 @@
1
- import os
2
- import sys
3
- import yaml
4
- import torch
5
-
6
- import numpy as np
7
-
8
- from pathlib import Path
9
- from hashlib import sha256
10
-
11
- sys.path.append(os.getcwd())
12
-
13
- from main.configs.config import Config
14
- from main.library.uvr5_separator import spec_utils, common_separator
15
- from main.library.uvr5_separator.demucs import hdemucs, states, apply
16
-
17
- translations = Config().translations
18
- sys.path.insert(0, os.path.join(os.getcwd(), "main", "library", "uvr5_separator"))
19
- DEMUCS_4_SOURCE_MAPPER = {common_separator.CommonSeparator.BASS_STEM: 0, common_separator.CommonSeparator.DRUM_STEM: 1, common_separator.CommonSeparator.OTHER_STEM: 2, common_separator.CommonSeparator.VOCAL_STEM: 3}
20
-
21
- class DemucsSeparator(common_separator.CommonSeparator):
22
- def __init__(self, common_config, arch_config):
23
- super().__init__(config=common_config)
24
- self.segment_size = arch_config.get("segment_size", "Default")
25
- self.shifts = arch_config.get("shifts", 2)
26
- self.overlap = arch_config.get("overlap", 0.25)
27
- self.segments_enabled = arch_config.get("segments_enabled", True)
28
- self.logger.debug(translations["demucs_info"].format(segment_size=self.segment_size, segments_enabled=self.segments_enabled))
29
- self.logger.debug(translations["demucs_info_2"].format(shifts=self.shifts, overlap=self.overlap))
30
- self.demucs_source_map = DEMUCS_4_SOURCE_MAPPER
31
- self.audio_file_path = None
32
- self.audio_file_base = None
33
- self.demucs_model_instance = None
34
- self.logger.info(translations["start_demucs"])
35
-
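- # Full separation pass: load the Demucs v4 checkpoint, demix the prepared mix, pick a stem map matching the number of returned sources, and write one file per stem (honoring output_single_stem).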
36
- def separate(self, audio_file_path):
37
- self.logger.debug(translations["start_separator"])
38
- source = None
39
- inst_source = {}
40
- self.audio_file_path = audio_file_path
41
- self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
42
- self.logger.debug(translations["prepare_mix"])
43
- mix = self.prepare_mix(self.audio_file_path)
44
- self.logger.debug(translations["demix"].format(shape=mix.shape))
45
- self.logger.debug(translations["cancel_mix"])
46
- self.demucs_model_instance = hdemucs.HDemucs(sources=["drums", "bass", "other", "vocals"])
47
- self.demucs_model_instance = get_demucs_model(name=os.path.splitext(os.path.basename(self.model_path))[0], repo=Path(os.path.dirname(self.model_path)))
48
- self.demucs_model_instance = apply.demucs_segments(self.segment_size, self.demucs_model_instance)
49
- self.demucs_model_instance.to(self.torch_device)
50
- self.demucs_model_instance.eval()
51
- self.logger.debug(translations["model_review"])
52
- source = self.demix_demucs(mix)
53
- del self.demucs_model_instance
54
- self.clear_gpu_cache()
55
- self.logger.debug(translations["del_gpu_cache_after_demix"])
56
- output_files = []
57
- self.logger.debug(translations["process_output_file"])
58
-
59
- if isinstance(inst_source, np.ndarray):
60
- self.logger.debug(translations["process_ver"])
61
- inst_source[self.demucs_source_map[common_separator.CommonSeparator.VOCAL_STEM]] = spec_utils.reshape_sources(inst_source[self.demucs_source_map[common_separator.CommonSeparator.VOCAL_STEM]], source[self.demucs_source_map[common_separator.CommonSeparator.VOCAL_STEM]])
62
- source = inst_source
63
-
64
- if isinstance(source, np.ndarray):
65
- source_length = len(source)
66
- self.logger.debug(translations["source_length"].format(source_length=source_length))
67
- self.logger.debug(translations["set_map"].format(part=source_length))
68
- match source_length:
69
- case 2: self.demucs_source_map = {common_separator.CommonSeparator.INST_STEM: 0, common_separator.CommonSeparator.VOCAL_STEM: 1}
70
- case 6: self.demucs_source_map = {common_separator.CommonSeparator.BASS_STEM: 0, common_separator.CommonSeparator.DRUM_STEM: 1, common_separator.CommonSeparator.OTHER_STEM: 2, common_separator.CommonSeparator.VOCAL_STEM: 3, common_separator.CommonSeparator.GUITAR_STEM: 4, common_separator.CommonSeparator.PIANO_STEM: 5}
71
- case _: self.demucs_source_map = DEMUCS_4_SOURCE_MAPPER
72
-
73
- self.logger.debug(translations["process_all_part"])
74
- for stem_name, stem_value in self.demucs_source_map.items():
75
- if self.output_single_stem is not None:
76
- if stem_name.lower() != self.output_single_stem.lower():
77
- self.logger.debug(translations["skip_part"].format(stem_name=stem_name, output_single_stem=self.output_single_stem))
78
- continue
79
- stem_path = os.path.join(f"{self.audio_file_base}_({stem_name})_{self.model_name}.{self.output_format.lower()}")
80
- self.final_process(stem_path, source[stem_value].T, stem_name)
81
- output_files.append(stem_path)
82
- return output_files
83
-
84
- def demix_demucs(self, mix):
85
- self.logger.debug(translations["starting_demix_demucs"])
86
- processed = {}
87
- mix = torch.tensor(mix, dtype=torch.float32)
88
- ref = mix.mean(0)
89
- mix = (mix - ref.mean()) / ref.std()
90
- mix_infer = mix
91
- with torch.no_grad():
92
- self.logger.debug(translations["model_infer"])
93
- sources = apply.apply_model(model=self.demucs_model_instance, mix=mix_infer[None], shifts=self.shifts, split=self.segments_enabled, overlap=self.overlap, static_shifts=1 if self.shifts == 0 else self.shifts, set_progress_bar=None, device=self.torch_device, progress=True)[0]
94
- sources = (sources * ref.std() + ref.mean()).cpu().numpy()
95
- sources[[0, 1]] = sources[[1, 0]]
96
- processed[mix] = sources[:, :, 0:None].copy()
97
- return np.concatenate([s[:, :, 0:None] for s in list(processed.values())], axis=-1)
98
-
99
- class LocalRepo:
100
- def __init__(self, root):
101
- self.root = root
102
- self.scan()
103
-
104
- def scan(self):
105
- self._models, self._checksums = {}, {}
106
- for file in self.root.iterdir():
107
- if file.suffix == ".th":
108
- if "-" in file.stem:
109
- xp_sig, checksum = file.stem.split("-")
110
- self._checksums[xp_sig] = checksum
111
- else: xp_sig = file.stem
112
-
113
- if xp_sig in self._models: raise RuntimeError(translations["del_all_but_one"].format(xp_sig=xp_sig))
114
- self._models[xp_sig] = file
115
-
116
- def has_model(self, sig):
117
- return sig in self._models
118
-
119
- def get_model(self, sig):
120
- try:
121
- file = self._models[sig]
122
- except KeyError:
123
- raise RuntimeError(translations["not_found_model_signature"].format(sig=sig))
124
-
125
- if sig in self._checksums: check_checksum(file, self._checksums[sig])
126
- return states.load_model(file)
127
-
128
- class BagOnlyRepo:
129
- def __init__(self, root, model_repo):
130
- self.root = root
131
- self.model_repo = model_repo
132
- self.scan()
133
-
134
- def scan(self):
135
- self._bags = {}
136
- for file in self.root.iterdir():
137
- if file.suffix == ".yaml": self._bags[file.stem] = file
138
-
139
- def get_model(self, name):
140
- try:
141
- yaml_file = self._bags[name]
142
- except KeyError:
143
- raise RuntimeError(translations["name_not_pretrained"].format(name=name))
144
- bag = yaml.safe_load(open(yaml_file))
145
- return apply.BagOfModels([self.model_repo.get_model(sig) for sig in bag["models"]], bag.get("weights"), bag.get("segment"))
146
-
147
- def check_checksum(path, checksum):
148
- sha = sha256()
149
- with open(path, "rb") as file:
150
- while 1:
151
- buf = file.read(2**20)
152
- if not buf: break
153
- sha.update(buf)
154
-
155
- actual_checksum = sha.hexdigest()[: len(checksum)]
156
- if actual_checksum != checksum: raise RuntimeError(translations["invalid_checksum"].format(path=path, checksum=checksum, actual_checksum=actual_checksum))
157
-
158
- def get_demucs_model(name, repo = None):
159
- model_repo = LocalRepo(repo)
160
- return (model_repo.get_model(name) if model_repo.has_model(name) else BagOnlyRepo(repo, model_repo).get_model(name)).eval()
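The LocalRepo/check_checksum pair above relies on the Demucs weight-file naming convention "<signature>-<checksum>.th", where only the leading hex digits of the file's SHA-256 digest are compared against the checksum embedded in the name. A hedged sketch of that check follows; the file name used is hypothetical.

from hashlib import sha256
from pathlib import Path

def sha256_prefix(path, length, chunk=2 ** 20):
    # stream the file and return only the leading hex digits of its digest
    digest = sha256()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(chunk), b""):
            digest.update(block)
    return digest.hexdigest()[:length]

model_file = Path("955717e8-8726e21a.th")   # hypothetical "<sig>-<checksum>.th" checkpoint
if model_file.exists():
    sig, expected = model_file.stem.split("-")
    actual = sha256_prefix(model_file, len(expected))
    print("checksum ok" if actual == expected else f"mismatch: {actual} != {expected}")
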
main/library/architectures/mdx_separator.py DELETED
@@ -1,320 +0,0 @@
1
- import os
2
- import sys
3
- import onnx
4
- import torch
5
- import platform
6
- import onnx2torch
7
-
8
- import numpy as np
9
- import onnxruntime as ort
10
-
11
- from tqdm import tqdm
12
-
13
- sys.path.append(os.getcwd())
14
-
15
- from main.configs.config import Config
16
- from main.library.uvr5_separator import spec_utils
17
- from main.library.uvr5_separator.common_separator import CommonSeparator
18
-
19
- translations = Config().translations
20
-
21
- class MDXSeparator(CommonSeparator):
22
- def __init__(self, common_config, arch_config):
23
- super().__init__(config=common_config)
24
- self.segment_size = arch_config.get("segment_size")
25
- self.overlap = arch_config.get("overlap")
26
- self.batch_size = arch_config.get("batch_size", 1)
27
- self.hop_length = arch_config.get("hop_length")
28
- self.enable_denoise = arch_config.get("enable_denoise")
29
- self.logger.debug(translations["mdx_info"].format(batch_size=self.batch_size, segment_size=self.segment_size))
30
- self.logger.debug(translations["mdx_info_2"].format(overlap=self.overlap, hop_length=self.hop_length, enable_denoise=self.enable_denoise))
31
- self.compensate = self.model_data["compensate"]
32
- self.dim_f = self.model_data["mdx_dim_f_set"]
33
- self.dim_t = 2 ** self.model_data["mdx_dim_t_set"]
34
- self.n_fft = self.model_data["mdx_n_fft_scale_set"]
35
- self.config_yaml = self.model_data.get("config_yaml", None)
36
- self.logger.debug(f"{translations['mdx_info_3']}: compensate = {self.compensate}, dim_f = {self.dim_f}, dim_t = {self.dim_t}, n_fft = {self.n_fft}")
37
- self.logger.debug(f"{translations['mdx_info_3']}: config_yaml = {self.config_yaml}")
38
- self.load_model()
39
- self.n_bins = 0
40
- self.trim = 0
41
- self.chunk_size = 0
42
- self.gen_size = 0
43
- self.stft = None
44
- self.primary_source = None
45
- self.secondary_source = None
46
- self.audio_file_path = None
47
- self.audio_file_base = None
48
-
49
- def load_model(self):
50
- self.logger.debug(translations["load_model_onnx"])
51
-
52
- if self.segment_size == self.dim_t:
53
- ort_session_options = ort.SessionOptions()
54
- ort_session_options.log_severity_level = 3 if self.log_level > 10 else 0
55
- ort_inference_session = ort.InferenceSession(self.model_path, providers=self.onnx_execution_provider, sess_options=ort_session_options)
56
- self.model_run = lambda spek: ort_inference_session.run(None, {"input": spek.cpu().numpy()})[0]
57
- self.logger.debug(translations["load_model_onnx_success"])
58
- else:
59
- self.model_run = onnx2torch.convert(onnx.load(self.model_path)) if platform.system() == 'Windows' else onnx2torch.convert(self.model_path)
60
- self.model_run.to(self.torch_device).eval()
61
- self.logger.debug(translations["onnx_to_pytorch"])
62
-
63
- def separate(self, audio_file_path):
64
- self.audio_file_path = audio_file_path
65
- self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
66
- self.logger.debug(translations["mix"].format(audio_file_path=self.audio_file_path))
67
- mix = self.prepare_mix(self.audio_file_path)
68
- self.logger.debug(translations["normalization_demix"])
69
- mix = spec_utils.normalize(wave=mix, max_peak=self.normalization_threshold)
70
- source = self.demix(mix)
71
- self.logger.debug(translations["mix_success"])
72
- output_files = []
73
- self.logger.debug(translations["process_output_file"])
74
-
75
- if not isinstance(self.primary_source, np.ndarray):
76
- self.logger.debug(translations["primary_source"])
77
- self.primary_source = spec_utils.normalize(wave=source, max_peak=self.normalization_threshold).T
78
-
79
- if not isinstance(self.secondary_source, np.ndarray):
80
- self.logger.debug(translations["secondary_source"])
81
- raw_mix = self.demix(mix, is_match_mix=True)
82
-
83
- if self.invert_using_spec:
84
- self.logger.debug(translations["invert_using_spec"])
85
- self.secondary_source = spec_utils.invert_stem(raw_mix, source)
86
- else:
87
- self.logger.debug(translations["invert_using_spec_2"])
88
- self.secondary_source = mix.T - source.T
89
-
90
- if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
91
- self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")
92
- self.logger.info(translations["save_secondary_stem_output_path"].format(stem_name=self.secondary_stem_name, stem_output_path=self.secondary_stem_output_path))
93
- self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
94
- output_files.append(self.secondary_stem_output_path)
95
-
96
- if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
97
- self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")
98
- if not isinstance(self.primary_source, np.ndarray): self.primary_source = source.T
99
-
100
- self.logger.info(translations["save_secondary_stem_output_path"].format(stem_name=self.primary_stem_name, stem_output_path=self.primary_stem_output_path))
101
- self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
102
- output_files.append(self.primary_stem_output_path)
103
-
104
- return output_files
105
-
106
- def initialize_model_settings(self):
107
- self.logger.debug(translations["starting_model"])
108
-
109
- self.n_bins = self.n_fft // 2 + 1
110
- self.trim = self.n_fft // 2
111
-
112
- self.chunk_size = self.hop_length * (self.segment_size - 1)
113
- self.gen_size = self.chunk_size - 2 * self.trim
114
-
115
- self.stft = STFT(self.logger, self.n_fft, self.hop_length, self.dim_f, self.torch_device)
116
-
117
- self.logger.debug(f"{translations['input_info']}: n_fft = {self.n_fft} hop_length = {self.hop_length} dim_f = {self.dim_f}")
118
- self.logger.debug(f"{translations['model_settings']}: n_bins = {self.n_bins}, Trim = {self.trim}, chunk_size = {self.chunk_size}, gen_size = {self.gen_size}")
119
-
120
- def initialize_mix(self, mix, is_ckpt=False):
121
- self.logger.debug(translations["initialize_mix"].format(is_ckpt=is_ckpt, shape=mix.shape))
122
-
123
- if mix.shape[0] != 2:
124
- error_message = translations["!=2"].format(shape=mix.shape[0])
125
- self.logger.error(error_message)
126
- raise ValueError(error_message)
127
-
128
- if is_ckpt:
129
- self.logger.debug(translations["process_check"])
130
- pad = self.gen_size + self.trim - (mix.shape[-1] % self.gen_size)
131
- self.logger.debug(f"{translations['cache']}: {pad}")
132
-
133
- mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)
134
-
135
- num_chunks = mixture.shape[-1] // self.gen_size
136
- self.logger.debug(translations["shape"].format(shape=mixture.shape, num_chunks=num_chunks))
137
-
138
- mix_waves = [mixture[:, i * self.gen_size : i * self.gen_size + self.chunk_size] for i in range(num_chunks)]
139
- else:
140
- self.logger.debug(translations["process_no_check"])
141
- mix_waves = []
142
- n_sample = mix.shape[1]
143
-
144
- pad = self.gen_size - n_sample % self.gen_size
145
- self.logger.debug(translations["n_sample_or_pad"].format(n_sample=n_sample, pad=pad))
146
-
147
- mix_p = np.concatenate((np.zeros((2, self.trim)), mix, np.zeros((2, pad)), np.zeros((2, self.trim))), 1)
148
- self.logger.debug(f"{translations['shape_2']}: {mix_p.shape}")
149
-
150
- i = 0
151
- while i < n_sample + pad:
152
- mix_waves.append(np.array(mix_p[:, i : i + self.chunk_size]))
153
-
154
- self.logger.debug(translations["process_part"].format(mix_waves=len(mix_waves), i=i, ii=i + self.chunk_size))
155
- i += self.gen_size
156
-
157
- mix_waves_tensor = torch.tensor(mix_waves, dtype=torch.float32).to(self.torch_device)
158
- self.logger.debug(translations["mix_waves_to_tensor"].format(shape=mix_waves_tensor.shape))
159
-
160
- return mix_waves_tensor, pad
161
-
162
- def demix(self, mix, is_match_mix=False):
163
- self.logger.debug(f"{translations['demix_is_match_mix']}: {is_match_mix}...")
164
- self.initialize_model_settings()
165
- self.logger.debug(f"{translations['mix_shape']}: {mix.shape}")
166
- tar_waves_ = []
167
-
168
- if is_match_mix:
169
- chunk_size = self.hop_length * (self.segment_size - 1)
170
- overlap = 0.02
171
- self.logger.debug(translations["chunk_size_or_overlap"].format(chunk_size=chunk_size, overlap=overlap))
172
- else:
173
- chunk_size = self.chunk_size
174
- overlap = self.overlap
175
- self.logger.debug(translations["chunk_size_or_overlap_standard"].format(chunk_size=chunk_size, overlap=overlap))
176
-
177
- gen_size = chunk_size - 2 * self.trim
178
- self.logger.debug(f"{translations['calc_size']}: {gen_size}")
179
-
180
- mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, gen_size + self.trim - ((mix.shape[-1]) % gen_size)), dtype="float32")), 1)
181
- self.logger.debug(f"{translations['mix_cache']}: {mixture.shape}")
182
-
183
- step = int((1 - overlap) * chunk_size)
184
- self.logger.debug(translations["step_or_overlap"].format(step=step, overlap=overlap))
185
-
186
- result = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
187
- divider = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
188
-
189
- total = 0
190
- total_chunks = (mixture.shape[-1] + step - 1) // step
191
- self.logger.debug(f"{translations['all_process_part']}: {total_chunks}")
192
-
193
- for i in tqdm(range(0, mixture.shape[-1], step), ncols=100, unit="f"):
194
- total += 1
195
- start = i
196
- end = min(i + chunk_size, mixture.shape[-1])
197
- self.logger.debug(translations["process_part_2"].format(total=total, total_chunks=total_chunks, start=start, end=end))
198
-
199
- chunk_size_actual = end - start
200
- window = None
201
-
202
- if overlap != 0:
203
- window = np.hanning(chunk_size_actual)
204
- window = np.tile(window[None, None, :], (1, 2, 1))
205
- self.logger.debug(translations["window"])
206
-
207
- mix_part_ = mixture[:, start:end]
208
-
209
- if end != i + chunk_size:
210
- pad_size = (i + chunk_size) - end
211
- mix_part_ = np.concatenate((mix_part_, np.zeros((2, pad_size), dtype="float32")), axis=-1)
212
-
213
- mix_waves = torch.tensor([mix_part_], dtype=torch.float32).to(self.torch_device).split(self.batch_size)
214
-
215
- total_batches = len(mix_waves)
216
- self.logger.debug(f"{translations['mix_or_batch']}: {total_batches}")
217
-
218
- with torch.no_grad():
219
- batches_processed = 0
220
-
221
- for mix_wave in mix_waves:
222
- batches_processed += 1
223
- self.logger.debug(f"{translations['mix_wave']} {batches_processed}/{total_batches}")
224
-
225
- tar_waves = self.run_model(mix_wave, is_match_mix=is_match_mix)
226
-
227
- if window is not None:
228
- tar_waves[..., :chunk_size_actual] *= window
229
- divider[..., start:end] += window
230
- else: divider[..., start:end] += 1
231
-
232
- result[..., start:end] += tar_waves[..., : end - start]
233
-
234
-
235
- self.logger.debug(translations["normalization_2"])
236
- tar_waves = result / divider
237
- tar_waves_.append(tar_waves)
238
-
239
- tar_waves = np.concatenate(np.vstack(tar_waves_)[:, :, self.trim : -self.trim], axis=-1)[:, : mix.shape[-1]]
240
-
241
- source = tar_waves[:, 0:None]
242
- self.logger.debug(f"{translations['tar_waves']}: {tar_waves.shape}")
243
-
244
- if not is_match_mix:
245
- source *= self.compensate
246
- self.logger.debug(translations["mix_match"])
247
-
248
- self.logger.debug(translations["mix_success"])
249
- return source
250
-
251
- def run_model(self, mix, is_match_mix=False):
252
- spek = self.stft(mix.to(self.torch_device))
253
- self.logger.debug(translations["stft_2"].format(shape=spek.shape))
254
-
255
- spek[:, :, :3, :] *= 0
256
-
257
- if is_match_mix:
258
- spec_pred = spek.cpu().numpy()
259
- self.logger.debug(translations["is_match_mix"])
260
- else:
261
- if self.enable_denoise:
262
- spec_pred_neg = self.model_run(-spek)
263
- spec_pred_pos = self.model_run(spek)
264
- spec_pred = (spec_pred_neg * -0.5) + (spec_pred_pos * 0.5)
265
- self.logger.debug(translations["enable_denoise"])
266
- else:
267
- spec_pred = self.model_run(spek)
268
- self.logger.debug(translations["no_denoise"])
269
-
270
- result = self.stft.inverse(torch.tensor(spec_pred).to(self.torch_device)).cpu().detach().numpy()
271
- self.logger.debug(f"{translations['stft']}: {result.shape}")
272
-
273
- return result
274
-
275
- class STFT:
276
- def __init__(self, logger, n_fft, hop_length, dim_f, device):
277
- self.logger = logger
278
- self.n_fft = n_fft
279
- self.hop_length = hop_length
280
- self.dim_f = dim_f
281
- self.device = device
282
- self.hann_window = torch.hann_window(window_length=self.n_fft, periodic=True)
283
-
284
- def __call__(self, input_tensor):
285
- is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"]
286
-
287
- if is_non_standard_device: input_tensor = input_tensor.cpu()
288
-
289
- batch_dimensions = input_tensor.shape[:-2]
290
- channel_dim, time_dim = input_tensor.shape[-2:]
291
-
292
- permuted_stft_output = torch.stft(input_tensor.reshape([-1, time_dim]), n_fft=self.n_fft, hop_length=self.hop_length, window=self.hann_window.to(input_tensor.device), center=True, return_complex=False).permute([0, 3, 1, 2])
293
- final_output = permuted_stft_output.reshape([*batch_dimensions, channel_dim, 2, -1, permuted_stft_output.shape[-1]]).reshape([*batch_dimensions, channel_dim * 2, -1, permuted_stft_output.shape[-1]])
294
-
295
- if is_non_standard_device: final_output = final_output.to(self.device)
296
- return final_output[..., : self.dim_f, :]
297
-
298
- def pad_frequency_dimension(self, input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins):
299
- return torch.cat([input_tensor, torch.zeros([*batch_dimensions, channel_dim, num_freq_bins - freq_dim, time_dim]).to(input_tensor.device)], -2)
300
-
301
- def calculate_inverse_dimensions(self, input_tensor):
302
- channel_dim, freq_dim, time_dim = input_tensor.shape[-3:]
303
-
304
- return input_tensor.shape[:-3], channel_dim, freq_dim, time_dim, self.n_fft // 2 + 1
305
-
306
- def prepare_for_istft(self, padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim):
307
- permuted_tensor = padded_tensor.reshape([*batch_dimensions, channel_dim // 2, 2, num_freq_bins, time_dim]).reshape([-1, 2, num_freq_bins, time_dim]).permute([0, 2, 3, 1])
308
-
309
- return permuted_tensor[..., 0] + permuted_tensor[..., 1] * 1.0j
310
-
311
- def inverse(self, input_tensor):
312
- is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"]
313
- if is_non_standard_device: input_tensor = input_tensor.cpu()
314
-
315
- batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins = self.calculate_inverse_dimensions(input_tensor)
316
- final_output = torch.istft(self.prepare_for_istft(self.pad_frequency_dimension(input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins), batch_dimensions, channel_dim, num_freq_bins, time_dim), n_fft=self.n_fft, hop_length=self.hop_length, window=self.hann_window.to(input_tensor.device), center=True).reshape([*batch_dimensions, 2, -1])
317
-
318
- if is_non_standard_device: final_output = final_output.to(self.device)
319
-
320
- return final_output
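The demix() loop above is a windowed overlap-add: each chunk of model output is weighted by a Hann window and accumulated into result, while divider accumulates the window weights so the overlap can be normalised out at the end. A minimal sketch of that scheme on a dummy signal, with the separation model replaced by the identity, is shown below.

import numpy as np

signal = np.random.randn(2, 10_000).astype(np.float32)   # stand-in stereo mixture
chunk_size, overlap = 4096, 0.25
step = int((1 - overlap) * chunk_size)

result = np.zeros_like(signal)
divider = np.zeros_like(signal)

for start in range(0, signal.shape[-1], step):
    end = min(start + chunk_size, signal.shape[-1])
    window = np.hanning(end - start)[None, :]
    chunk = signal[:, start:end]            # here the separation model would run
    result[:, start:end] += chunk * window  # window-weighted accumulation
    divider[:, start:end] += window         # track how much weight each sample received

reconstructed = result / np.maximum(divider, 1e-8)
# recovers the input (up to float rounding) except at the very first/last sample,
# where the Hann window weight is zero
print(np.abs(reconstructed - signal)[:, 1:-1].max())
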
main/library/predictors/CREPE.py DELETED
@@ -1,210 +0,0 @@
1
- import os
2
- import torch
3
- import librosa
4
- import functools
5
- import scipy.stats
6
-
7
- import numpy as np
8
-
9
- CENTS_PER_BIN, MAX_FMAX, PITCH_BINS, SAMPLE_RATE, WINDOW_SIZE = 20, 2006, 360, 16000, 1024
10
-
11
- class Crepe(torch.nn.Module):
12
- def __init__(self, model='full'):
13
- super().__init__()
14
- if model == 'full':
15
- in_channels = [1, 1024, 128, 128, 128, 256]
16
- out_channels = [1024, 128, 128, 128, 256, 512]
17
- self.in_features = 2048
18
- elif model == 'large':
19
- in_channels = [1, 768, 96, 96, 96, 192]
20
- out_channels = [768, 96, 96, 96, 192, 384]
21
- self.in_features = 1536
22
- elif model == 'medium':
23
- in_channels = [1, 512, 64, 64, 64, 128]
24
- out_channels = [512, 64, 64, 64, 128, 256]
25
- self.in_features = 1024
26
- elif model == 'small':
27
- in_channels = [1, 256, 32, 32, 32, 64]
28
- out_channels = [256, 32, 32, 32, 64, 128]
29
- self.in_features = 512
30
- elif model == 'tiny':
31
- in_channels = [1, 128, 16, 16, 16, 32]
32
- out_channels = [128, 16, 16, 16, 32, 64]
33
- self.in_features = 256
34
-
35
- kernel_sizes = [(512, 1)] + 5 * [(64, 1)]
36
- strides = [(4, 1)] + 5 * [(1, 1)]
37
-
38
- batch_norm_fn = functools.partial(torch.nn.BatchNorm2d, eps=0.0010000000474974513, momentum=0.0)
39
-
40
- self.conv1 = torch.nn.Conv2d(in_channels=in_channels[0], out_channels=out_channels[0], kernel_size=kernel_sizes[0], stride=strides[0])
41
- self.conv1_BN = batch_norm_fn(num_features=out_channels[0])
42
- self.conv2 = torch.nn.Conv2d(in_channels=in_channels[1], out_channels=out_channels[1], kernel_size=kernel_sizes[1], stride=strides[1])
43
- self.conv2_BN = batch_norm_fn(num_features=out_channels[1])
44
-
45
- self.conv3 = torch.nn.Conv2d(in_channels=in_channels[2], out_channels=out_channels[2], kernel_size=kernel_sizes[2], stride=strides[2])
46
- self.conv3_BN = batch_norm_fn(num_features=out_channels[2])
47
- self.conv4 = torch.nn.Conv2d(in_channels=in_channels[3], out_channels=out_channels[3], kernel_size=kernel_sizes[3], stride=strides[3])
48
- self.conv4_BN = batch_norm_fn(num_features=out_channels[3])
49
-
50
- self.conv5 = torch.nn.Conv2d(in_channels=in_channels[4], out_channels=out_channels[4], kernel_size=kernel_sizes[4], stride=strides[4])
51
- self.conv5_BN = batch_norm_fn(num_features=out_channels[4])
52
- self.conv6 = torch.nn.Conv2d(in_channels=in_channels[5], out_channels=out_channels[5], kernel_size=kernel_sizes[5], stride=strides[5])
53
- self.conv6_BN = batch_norm_fn(num_features=out_channels[5])
54
-
55
- self.classifier = torch.nn.Linear(in_features=self.in_features, out_features=PITCH_BINS)
56
-
57
- def forward(self, x, embed=False):
58
- x = self.embed(x)
59
- if embed: return x
60
-
61
- return torch.sigmoid(self.classifier(self.layer(x, self.conv6, self.conv6_BN).permute(0, 2, 1, 3).reshape(-1, self.in_features)))
62
-
63
- def embed(self, x):
64
- x = x[:, None, :, None]
65
-
66
- return self.layer(self.layer(self.layer(self.layer(self.layer(x, self.conv1, self.conv1_BN, (0, 0, 254, 254)), self.conv2, self.conv2_BN), self.conv3, self.conv3_BN), self.conv4, self.conv4_BN), self.conv5, self.conv5_BN)
67
-
68
- def layer(self, x, conv, batch_norm, padding=(0, 0, 31, 32)):
69
- return torch.nn.functional.max_pool2d(batch_norm(torch.nn.functional.relu(conv(torch.nn.functional.pad(x, padding)))), (2, 1), (2, 1))
70
-
71
- def viterbi(logits):
72
- if not hasattr(viterbi, 'transition'):
73
- xx, yy = np.meshgrid(range(360), range(360))
74
- transition = np.maximum(12 - abs(xx - yy), 0)
75
- viterbi.transition = transition / transition.sum(axis=1, keepdims=True)
76
-
77
- with torch.no_grad():
78
- probs = torch.nn.functional.softmax(logits, dim=1)
79
-
80
- bins = torch.tensor(np.array([librosa.sequence.viterbi(sequence, viterbi.transition).astype(np.int64) for sequence in probs.cpu().numpy()]), device=probs.device)
81
- return bins, bins_to_frequency(bins)
82
-
83
- def predict(audio, sample_rate, hop_length=None, fmin=50, fmax=MAX_FMAX, model='full', return_periodicity=False, batch_size=None, device='cpu', pad=True, providers=None, onnx=False):
84
- results = []
85
-
86
- if onnx:
87
- import onnxruntime as ort
88
-
89
- sess_options = ort.SessionOptions()
90
- sess_options.log_severity_level = 3
91
-
92
- session = ort.InferenceSession(os.path.join("assets", "models", "predictors", f"crepe_{model}.onnx"), sess_options=sess_options, providers=providers)
93
-
94
- for frames in preprocess(audio, sample_rate, hop_length, batch_size, device, pad):
95
- result = postprocess(torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: frames.cpu().numpy()})[0].transpose(1, 0)[None]), fmin, fmax, return_periodicity)
96
- results.append((result[0], result[1]) if isinstance(result, tuple) else result)
97
-
98
- del session
99
-
100
- if return_periodicity:
101
- pitch, periodicity = zip(*results)
102
- return torch.cat(pitch, 1), torch.cat(periodicity, 1)
103
-
104
- return torch.cat(results, 1)
105
- else:
106
- with torch.no_grad():
107
- for frames in preprocess(audio, sample_rate, hop_length, batch_size, device, pad):
108
- result = postprocess(infer(frames, model, device, embed=False).reshape(audio.size(0), -1, PITCH_BINS).transpose(1, 2), fmin, fmax, return_periodicity)
109
- results.append((result[0].to(audio.device), result[1].to(audio.device)) if isinstance(result, tuple) else result.to(audio.device))
110
-
111
- if return_periodicity:
112
- pitch, periodicity = zip(*results)
113
- return torch.cat(pitch, 1), torch.cat(periodicity, 1)
114
-
115
- return torch.cat(results, 1)
116
-
117
- def bins_to_frequency(bins):
118
- cents = CENTS_PER_BIN * bins + 1997.3794084376191
119
- return 10 * 2 ** ((cents + cents.new_tensor(scipy.stats.triang.rvs(c=0.5, loc=-CENTS_PER_BIN, scale=2 * CENTS_PER_BIN, size=cents.size()))) / 1200)
120
-
121
- def frequency_to_bins(frequency, quantize_fn=torch.floor):
122
- return quantize_fn(((1200 * torch.log2(frequency / 10)) - 1997.3794084376191) / CENTS_PER_BIN).int()
123
-
124
- def infer(frames, model='full', device='cpu', embed=False):
125
- if not hasattr(infer, 'model') or not hasattr(infer, 'capacity') or (hasattr(infer, 'capacity') and infer.capacity != model): load_model(device, model)
126
- infer.model = infer.model.to(device)
127
-
128
- return infer.model(frames, embed=embed)
129
-
130
- def load_model(device, capacity='full'):
131
- infer.capacity = capacity
132
- infer.model = Crepe(capacity)
133
- infer.model.load_state_dict(torch.load(os.path.join("assets", "models", "predictors", f"crepe_{capacity}.pth"), map_location=device))
134
- infer.model = infer.model.to(torch.device(device))
135
- infer.model.eval()
136
-
137
- def postprocess(probabilities, fmin=0, fmax=MAX_FMAX, return_periodicity=False):
138
- probabilities = probabilities.detach()
139
-
140
- probabilities[:, :frequency_to_bins(torch.tensor(fmin))] = -float('inf')
141
- probabilities[:, frequency_to_bins(torch.tensor(fmax), torch.ceil):] = -float('inf')
142
-
143
- bins, pitch = viterbi(probabilities)
144
-
145
- if not return_periodicity: return pitch
146
- return pitch, periodicity(probabilities, bins)
147
-
148
- def preprocess(audio, sample_rate, hop_length=None, batch_size=None, device='cpu', pad=True):
149
- hop_length = sample_rate // 100 if hop_length is None else hop_length
150
-
151
- if sample_rate != SAMPLE_RATE:
152
- audio = torch.tensor(librosa.resample(audio.detach().cpu().numpy().squeeze(0), orig_sr=sample_rate, target_sr=SAMPLE_RATE, res_type="soxr_vhq"), device=audio.device).unsqueeze(0)
153
- hop_length = int(hop_length * SAMPLE_RATE / sample_rate)
154
-
155
- if pad:
156
- total_frames = 1 + int(audio.size(1) // hop_length)
157
- audio = torch.nn.functional.pad(audio, (WINDOW_SIZE // 2, WINDOW_SIZE // 2))
158
- else: total_frames = 1 + int((audio.size(1) - WINDOW_SIZE) // hop_length)
159
-
160
- batch_size = total_frames if batch_size is None else batch_size
161
-
162
- for i in range(0, total_frames, batch_size):
163
- frames = torch.nn.functional.unfold(audio[:, None, None, max(0, i * hop_length):min(audio.size(1), (i + batch_size - 1) * hop_length + WINDOW_SIZE)], kernel_size=(1, WINDOW_SIZE), stride=(1, hop_length))
164
- frames = frames.transpose(1, 2).reshape(-1, WINDOW_SIZE).to(device)
165
- frames -= frames.mean(dim=1, keepdim=True)
166
- frames /= torch.max(torch.tensor(1e-10, device=frames.device), frames.std(dim=1, keepdim=True))
167
-
168
- yield frames
169
-
170
- def periodicity(probabilities, bins):
171
- probs_stacked = probabilities.transpose(1, 2).reshape(-1, PITCH_BINS)
172
- periodicity = probs_stacked.gather(1, bins.reshape(-1, 1).to(torch.int64))
173
-
174
- return periodicity.reshape(probabilities.size(0), probabilities.size(2))
175
-
176
- def mean(signals, win_length=9):
177
- assert signals.dim() == 2
178
-
179
- signals = signals.unsqueeze(1)
180
- mask = ~torch.isnan(signals)
181
- padding = win_length // 2
182
-
183
- ones_kernel = torch.ones(signals.size(1), 1, win_length, device=signals.device)
184
- avg_pooled = torch.nn.functional.conv1d(torch.where(mask, signals, torch.zeros_like(signals)), ones_kernel, stride=1, padding=padding) / torch.nn.functional.conv1d(mask.float(), ones_kernel, stride=1, padding=padding).clamp(min=1)
185
- avg_pooled[avg_pooled == 0] = float("nan")
186
-
187
- return avg_pooled.squeeze(1)
188
-
189
- def median(signals, win_length):
190
- assert signals.dim() == 2
191
-
192
- signals = signals.unsqueeze(1)
193
- mask = ~torch.isnan(signals)
194
- padding = win_length // 2
195
-
196
- x = torch.nn.functional.pad(torch.where(mask, signals, torch.zeros_like(signals)), (padding, padding), mode="reflect")
197
- mask = torch.nn.functional.pad(mask.float(), (padding, padding), mode="constant", value=0)
198
-
199
- x = x.unfold(2, win_length, 1)
200
- mask = mask.unfold(2, win_length, 1)
201
-
202
- x = x.contiguous().view(x.size()[:3] + (-1,))
203
- mask = mask.contiguous().view(mask.size()[:3] + (-1,))
204
-
205
- x_sorted, _ = torch.sort(torch.where(mask.bool(), x.float(), float("inf")).to(x), dim=-1)
206
-
207
- median_pooled = x_sorted.gather(-1, ((mask.sum(dim=-1) - 1) // 2).clamp(min=0).unsqueeze(-1).long()).squeeze(-1)
208
- median_pooled[torch.isinf(median_pooled)] = float("nan")
209
-
210
- return median_pooled.squeeze(1)
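The bins_to_frequency/frequency_to_bins helpers above implement CREPE's pitch grid: 360 bins of 20 cents each, offset so that bin 0 sits just below C1 (about 31.7 Hz), with frequency = 10 * 2**(cents / 1200). A small sketch of the round trip is given below; it omits the triangular dither the deleted code adds when converting bins back to Hz.

import torch

CENTS_PER_BIN = 20
CENTS_OFFSET = 1997.3794084376191   # cents of bin 0 (roughly 31.7 Hz)

def bins_to_hz(bins):
    cents = CENTS_PER_BIN * bins + CENTS_OFFSET
    return 10 * 2 ** (cents / 1200)

def hz_to_bins(freq):
    cents = 1200 * torch.log2(freq / 10) - CENTS_OFFSET
    return torch.floor(cents / CENTS_PER_BIN).int()

freqs = torch.tensor([55.0, 110.0, 220.0, 440.0])
bins = hz_to_bins(freqs)
# the round trip lands within one bin, i.e. within 20 cents of the original pitch
print(bins, bins_to_hz(bins.float()))
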
main/library/predictors/FCPE.py DELETED
@@ -1,670 +0,0 @@
1
- import os
2
- import math
3
- import torch
4
- import librosa
5
-
6
- import numpy as np
7
- import soundfile as sf
8
- import torch.nn.functional as F
9
-
10
- from torch import nn, einsum
11
- from functools import partial
12
- from einops import rearrange, repeat, pack, unpack
13
- from torch.nn.utils.parametrizations import weight_norm
14
-
15
- os.environ["LRU_CACHE_CAPACITY"] = "3"
16
-
17
- def exists(val):
18
- return val is not None
19
-
20
- def default(value, d):
21
- return value if exists(value) else d
22
-
23
- def max_neg_value(tensor):
24
- return -torch.finfo(tensor.dtype).max
25
-
26
- def l2norm(tensor):
27
- return F.normalize(tensor, dim = -1).type(tensor.dtype)
28
-
29
- def pad_to_multiple(tensor, multiple, dim=-1, value=0):
30
- seqlen = tensor.shape[dim]
31
- m = seqlen / multiple
32
- if m.is_integer(): return False, tensor
33
- return True, F.pad(tensor, (*((0,) * (-1 - dim) * 2), 0, (math.ceil(m) * multiple - seqlen)), value = value)
34
-
35
- def look_around(x, backward = 1, forward = 0, pad_value = -1, dim = 2):
36
- t = x.shape[1]
37
- dims = (len(x.shape) - dim) * (0, 0)
38
- padded_x = F.pad(x, (*dims, backward, forward), value = pad_value)
39
- return torch.cat([padded_x[:, ind:(ind + t), ...] for ind in range(forward + backward + 1)], dim = dim)
40
-
41
- def rotate_half(x):
42
- x1, x2 = rearrange(x, 'b ... (r d) -> b ... r d', r = 2).unbind(dim = -2)
43
- return torch.cat((-x2, x1), dim = -1)
44
-
45
- def apply_rotary_pos_emb(q, k, freqs, scale = 1):
46
- q_len = q.shape[-2]
47
- q_freqs = freqs[..., -q_len:, :]
48
- inv_scale = scale ** -1
49
- if scale.ndim == 2: scale = scale[-q_len:, :]
50
- q = (q * q_freqs.cos() * scale) + (rotate_half(q) * q_freqs.sin() * scale)
51
- k = (k * freqs.cos() * inv_scale) + (rotate_half(k) * freqs.sin() * inv_scale)
52
- return q, k
53
-
54
- class LocalAttention(nn.Module):
55
- def __init__(self, window_size, causal = False, look_backward = 1, look_forward = None, dropout = 0., shared_qk = False, rel_pos_emb_config = None, dim = None, autopad = False, exact_windowsize = False, scale = None, use_rotary_pos_emb = True, use_xpos = False, xpos_scale_base = None):
56
- super().__init__()
57
- look_forward = default(look_forward, 0 if causal else 1)
58
- assert not (causal and look_forward > 0)
59
- self.scale = scale
60
- self.window_size = window_size
61
- self.autopad = autopad
62
- self.exact_windowsize = exact_windowsize
63
- self.causal = causal
64
- self.look_backward = look_backward
65
- self.look_forward = look_forward
66
- self.dropout = nn.Dropout(dropout)
67
- self.shared_qk = shared_qk
68
- self.rel_pos = None
69
- self.use_xpos = use_xpos
70
-
71
- if use_rotary_pos_emb and (exists(rel_pos_emb_config) or exists(dim)):
72
- if exists(rel_pos_emb_config): dim = rel_pos_emb_config[0]
73
- self.rel_pos = SinusoidalEmbeddings(dim, use_xpos = use_xpos, scale_base = default(xpos_scale_base, window_size // 2))
74
-
75
- def forward(self, q, k, v, mask = None, input_mask = None, attn_bias = None, window_size = None):
76
- mask = default(mask, input_mask)
77
- assert not (exists(window_size) and not self.use_xpos)
78
-
79
- _, autopad, pad_value, window_size, causal, look_backward, look_forward, shared_qk = q.shape, self.autopad, -1, default(window_size, self.window_size), self.causal, self.look_backward, self.look_forward, self.shared_qk
80
- (q, packed_shape), (k, _), (v, _) = map(lambda t: pack([t], '* n d'), (q, k, v))
81
-
82
- if autopad:
83
- orig_seq_len = q.shape[1]
84
- (_, q), (_, k), (_, v) = map(lambda t: pad_to_multiple(t, self.window_size, dim = -2), (q, k, v))
85
-
86
- b, n, dim_head, device, dtype = *q.shape, q.device, q.dtype
87
- scale = default(self.scale, dim_head ** -0.5)
88
- assert (n % window_size) == 0
89
- windows = n // window_size
90
- if shared_qk: k = l2norm(k)
91
-
92
- seq = torch.arange(n, device = device)
93
- b_t = rearrange(seq, '(w n) -> 1 w n', w = windows, n = window_size)
94
- bq, bk, bv = map(lambda t: rearrange(t, 'b (w n) d -> b w n d', w = windows), (q, k, v))
95
- bq = bq * scale
96
- look_around_kwargs = dict(backward = look_backward, forward = look_forward, pad_value = pad_value)
97
- bk = look_around(bk, **look_around_kwargs)
98
- bv = look_around(bv, **look_around_kwargs)
99
-
100
- if exists(self.rel_pos):
101
- pos_emb, xpos_scale = self.rel_pos(bk)
102
- bq, bk = apply_rotary_pos_emb(bq, bk, pos_emb, scale = xpos_scale)
103
-
104
- bq_t = b_t
105
- bq_k = look_around(b_t, **look_around_kwargs)
106
- bq_t = rearrange(bq_t, '... i -> ... i 1')
107
- bq_k = rearrange(bq_k, '... j -> ... 1 j')
108
- pad_mask = bq_k == pad_value
109
- sim = einsum('b h i e, b h j e -> b h i j', bq, bk)
110
-
111
- if exists(attn_bias):
112
- heads = attn_bias.shape[0]
113
- assert (b % heads) == 0
114
- attn_bias = repeat(attn_bias, 'h i j -> (b h) 1 i j', b = b // heads)
115
- sim = sim + attn_bias
116
-
117
- mask_value = max_neg_value(sim)
118
-
119
- if shared_qk:
120
- self_mask = bq_t == bq_k
121
- sim = sim.masked_fill(self_mask, -5e4)
122
- del self_mask
123
-
124
- if causal:
125
- causal_mask = bq_t < bq_k
126
- if self.exact_windowsize: causal_mask = causal_mask | (bq_t > (bq_k + (self.window_size * self.look_backward)))
127
- sim = sim.masked_fill(causal_mask, mask_value)
128
- del causal_mask
129
-
130
- sim = sim.masked_fill(((bq_k - (self.window_size * self.look_forward)) > bq_t) | (bq_t > (bq_k + (self.window_size * self.look_backward))) | pad_mask, mask_value) if not causal and self.exact_windowsize else sim.masked_fill(pad_mask, mask_value)
131
-
132
- if exists(mask):
133
- batch = mask.shape[0]
134
- assert (b % batch) == 0
135
- h = b // mask.shape[0]
136
- if autopad: _, mask = pad_to_multiple(mask, window_size, dim = -1, value = False)
137
- mask = repeat(rearrange(look_around(rearrange(mask, '... (w n) -> (...) w n', w = windows, n = window_size), **{**look_around_kwargs, 'pad_value': False}), '... j -> ... 1 j'), 'b ... -> (b h) ...', h = h)
138
- sim = sim.masked_fill(~mask, mask_value)
139
- del mask
140
-
141
- out = rearrange(einsum('b h i j, b h j e -> b h i e', self.dropout(sim.softmax(dim = -1)), bv), 'b w n d -> b (w n) d')
142
- if autopad: out = out[:, :orig_seq_len, :]
143
- out, *_ = unpack(out, packed_shape, '* n d')
144
- return out
145
-
146
- class SinusoidalEmbeddings(nn.Module):
147
- def __init__(self, dim, scale_base = None, use_xpos = False, theta = 10000):
148
- super().__init__()
149
- inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
150
- self.register_buffer('inv_freq', inv_freq)
151
- self.use_xpos = use_xpos
152
- self.scale_base = scale_base
153
- assert not (use_xpos and not exists(scale_base))
154
- scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
155
- self.register_buffer('scale', scale, persistent = False)
156
-
157
- def forward(self, x):
158
- seq_len, device = x.shape[-2], x.device
159
- t = torch.arange(seq_len, device = x.device).type_as(self.inv_freq)
160
- freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
161
- freqs = torch.cat((freqs, freqs), dim = -1)
162
-
163
- if not self.use_xpos: return freqs, torch.ones(1, device = device)
164
-
165
- power = (t - (seq_len // 2)) / self.scale_base
166
- scale = self.scale ** rearrange(power, 'n -> n 1')
167
- return freqs, torch.cat((scale, scale), dim = -1)
168
-
169
- def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False):
170
- try:
171
- data, sample_rate = sf.read(full_path, always_2d=True)
172
- except Exception as e:
173
- print(f"{full_path}: {e}")
174
-
175
- if return_empty_on_exception: return [], sample_rate or target_sr or 48000
176
- else: raise
177
-
178
- data = data[:, 0] if len(data.shape) > 1 else data
179
- assert len(data) > 2
180
-
181
- max_mag = (-np.iinfo(data.dtype).min if np.issubdtype(data.dtype, np.integer) else max(np.amax(data), -np.amin(data)))
182
- data = torch.FloatTensor(data.astype(np.float32)) / ((2**31) + 1 if max_mag > (2**15) else ((2**15) + 1 if max_mag > 1.01 else 1.0))
183
-
184
- if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception: return [], sample_rate or target_sr or 48000
185
-
186
- if target_sr is not None and sample_rate != target_sr:
187
- data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sample_rate, target_sr=target_sr))
188
- sample_rate = target_sr
189
-
190
- return data, sample_rate
191
-
192
- def dynamic_range_compression(x, C=1, clip_val=1e-5):
193
- return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
194
-
195
- def dynamic_range_decompression(x, C=1):
196
- return np.exp(x) / C
197
-
198
- def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
199
- return torch.log(torch.clamp(x, min=clip_val) * C)
200
-
201
- def dynamic_range_decompression_torch(x, C=1):
202
- return torch.exp(x) / C
203
-
204
- class STFT:
205
- def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
206
- self.target_sr = sr
207
- self.n_mels = n_mels
208
- self.n_fft = n_fft
209
- self.win_size = win_size
210
- self.hop_length = hop_length
211
- self.fmin = fmin
212
- self.fmax = fmax
213
- self.clip_val = clip_val
214
- self.mel_basis = {}
215
- self.hann_window = {}
216
-
217
- def get_mel(self, y, keyshift=0, speed=1, center=False, train=False):
218
- n_fft = self.n_fft
219
- win_size = self.win_size
220
- hop_length = self.hop_length
221
- fmax = self.fmax
222
- factor = 2 ** (keyshift / 12)
223
- win_size_new = int(np.round(win_size * factor))
224
- hop_length_new = int(np.round(hop_length * speed))
225
- mel_basis = self.mel_basis if not train else {}
226
- hann_window = self.hann_window if not train else {}
227
- mel_basis_key = str(fmax) + "_" + str(y.device)
228
-
229
- if mel_basis_key not in mel_basis:
230
- from librosa.filters import mel as librosa_mel_fn
231
- mel_basis[mel_basis_key] = torch.from_numpy(librosa_mel_fn(sr=self.target_sr, n_fft=n_fft, n_mels=self.n_mels, fmin=self.fmin, fmax=fmax)).float().to(y.device)
232
-
233
- keyshift_key = str(keyshift) + "_" + str(y.device)
234
- if keyshift_key not in hann_window: hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device)
235
-
236
- pad_left = (win_size_new - hop_length_new) // 2
237
- pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left)
238
- spec = torch.stft(torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode="reflect" if pad_right < y.size(-1) else "constant").squeeze(1), int(np.round(n_fft * factor)), hop_length=hop_length_new, win_length=win_size_new, window=hann_window[keyshift_key], center=center, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
239
- spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + (1e-9))
240
-
241
- if keyshift != 0:
242
- size = n_fft // 2 + 1
243
- resize = spec.size(1)
244
- spec = (F.pad(spec, (0, 0, 0, size - resize)) if resize < size else spec[:, :size, :]) * win_size / win_size_new
245
-
246
- return dynamic_range_compression_torch(torch.matmul(mel_basis[mel_basis_key], spec), clip_val=self.clip_val)
247
-
248
- def __call__(self, audiopath):
249
- audio, _ = load_wav_to_torch(audiopath, target_sr=self.target_sr)
250
- return self.get_mel(audio.unsqueeze(0)).squeeze(0)
251
-
252
- stft = STFT()
253
-
254
- def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device=None):
255
- b, h, *_ = data.shape
256
-
257
- data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.0
258
- ratio = projection_matrix.shape[0] ** -0.5
259
-
260
- data_dash = torch.einsum("...id,...jd->...ij", (data_normalizer * data), repeat(projection_matrix, "j d -> b h j d", b=b, h=h).type_as(data))
261
- diag_data = ((torch.sum(data**2, dim=-1) / 2.0) * (data_normalizer**2)).unsqueeze(dim=-1)
262
-
263
- return (ratio * (torch.exp(data_dash - diag_data - torch.max(data_dash, dim=-1, keepdim=True).values) + eps) if is_query else ratio * (torch.exp(data_dash - diag_data + eps))).type_as(data)
264
-
265
- def orthogonal_matrix_chunk(cols, qr_uniform_q=False, device=None):
266
- unstructured_block = torch.randn((cols, cols), device=device)
267
-
268
- q, r = torch.linalg.qr(unstructured_block.cpu(), mode="reduced")
269
- q, r = map(lambda t: t.to(device), (q, r))
270
-
271
- if qr_uniform_q:
272
- d = torch.diag(r, 0)
273
- q *= d.sign()
274
-
275
- return q.t()
276
-
277
- def empty(tensor):
278
- return tensor.numel() == 0
279
-
280
- def cast_tuple(val):
281
- return (val,) if not isinstance(val, tuple) else val
282
-
283
- class PCmer(nn.Module):
284
- def __init__(self, num_layers, num_heads, dim_model, dim_keys, dim_values, residual_dropout, attention_dropout):
285
- super().__init__()
286
- self.num_layers = num_layers
287
- self.num_heads = num_heads
288
- self.dim_model = dim_model
289
- self.dim_values = dim_values
290
- self.dim_keys = dim_keys
291
- self.residual_dropout = residual_dropout
292
- self.attention_dropout = attention_dropout
293
- self._layers = nn.ModuleList([_EncoderLayer(self) for _ in range(num_layers)])
294
-
295
- def forward(self, phone, mask=None):
296
- for layer in self._layers:
297
- phone = layer(phone, mask)
298
-
299
- return phone
300
-
301
- class _EncoderLayer(nn.Module):
302
- def __init__(self, parent: PCmer):
303
- super().__init__()
304
- self.conformer = ConformerConvModule(parent.dim_model)
305
- self.norm = nn.LayerNorm(parent.dim_model)
306
- self.dropout = nn.Dropout(parent.residual_dropout)
307
- self.attn = SelfAttention(dim=parent.dim_model, heads=parent.num_heads, causal=False)
308
-
309
- def forward(self, phone, mask=None):
310
- phone = phone + (self.attn(self.norm(phone), mask=mask))
311
- return phone + (self.conformer(phone))
312
-
313
- def calc_same_padding(kernel_size):
314
- pad = kernel_size // 2
315
- return (pad, pad - (kernel_size + 1) % 2)
316
-
317
- class Swish(nn.Module):
318
- def forward(self, x):
319
- return x * x.sigmoid()
320
-
321
- class Transpose(nn.Module):
322
- def __init__(self, dims):
323
- super().__init__()
324
- assert len(dims) == 2, "dims == 2"
325
-
326
- self.dims = dims
327
-
328
- def forward(self, x):
329
- return x.transpose(*self.dims)
330
-
331
- class GLU(nn.Module):
332
- def __init__(self, dim):
333
- super().__init__()
334
- self.dim = dim
335
-
336
- def forward(self, x):
337
- out, gate = x.chunk(2, dim=self.dim)
338
- return out * gate.sigmoid()
339
-
340
- class DepthWiseConv1d(nn.Module):
341
- def __init__(self, chan_in, chan_out, kernel_size, padding):
342
- super().__init__()
343
- self.padding = padding
344
- self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)
345
-
346
- def forward(self, x):
347
- return self.conv(F.pad(x, self.padding))
348
-
349
- class ConformerConvModule(nn.Module):
350
- def __init__(self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.0):
351
- super().__init__()
352
- inner_dim = dim * expansion_factor
353
- self.net = nn.Sequential(nn.LayerNorm(dim), Transpose((1, 2)), nn.Conv1d(dim, inner_dim * 2, 1), GLU(dim=1), DepthWiseConv1d(inner_dim, inner_dim, kernel_size=kernel_size, padding=(calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0))), Swish(), nn.Conv1d(inner_dim, dim, 1), Transpose((1, 2)), nn.Dropout(dropout))
354
-
355
- def forward(self, x):
356
- return self.net(x)
357
-
358
- def linear_attention(q, k, v):
359
- return torch.einsum("...ed,...nd->...ne", k, q) if v is None else torch.einsum("...de,...nd,...n->...ne", torch.einsum("...nd,...ne->...de", k, v), q, 1.0 / (torch.einsum("...nd,...d->...n", q, k.sum(dim=-2).type_as(q)) + 1e-8))
360
-
361
- def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling=0, qr_uniform_q=False, device=None):
362
- nb_full_blocks = int(nb_rows / nb_columns)
363
- block_list = []
364
-
365
- for _ in range(nb_full_blocks):
366
- block_list.append(orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device))
367
-
368
- remaining_rows = nb_rows - nb_full_blocks * nb_columns
369
- if remaining_rows > 0: block_list.append(orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device)[:remaining_rows])
370
-
371
- if scaling == 0: multiplier = torch.randn((nb_rows, nb_columns), device=device).norm(dim=1)
372
- elif scaling == 1: multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device=device)
373
- else: raise ValueError(f"{scaling} != 0, 1")
374
-
375
- return torch.diag(multiplier) @ torch.cat(block_list)
376
-
377
- class FastAttention(nn.Module):
378
- def __init__(self, dim_heads, nb_features=None, ortho_scaling=0, causal=False, generalized_attention=False, kernel_fn=nn.ReLU(), qr_uniform_q=False, no_projection=False):
379
- super().__init__()
380
- nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))
381
- self.dim_heads = dim_heads
382
- self.nb_features = nb_features
383
- self.ortho_scaling = ortho_scaling
384
- self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows=self.nb_features, nb_columns=dim_heads, scaling=ortho_scaling, qr_uniform_q=qr_uniform_q)
385
- projection_matrix = self.create_projection()
386
- self.register_buffer("projection_matrix", projection_matrix)
387
- self.generalized_attention = generalized_attention
388
- self.kernel_fn = kernel_fn
389
- self.no_projection = no_projection
390
- self.causal = causal
391
-
392
- @torch.no_grad()
393
- def redraw_projection_matrix(self):
394
- projections = self.create_projection()
395
- self.projection_matrix.copy_(projections)
396
-
397
- del projections
398
-
399
- def forward(self, q, k, v):
400
- if self.no_projection: q, k = q.softmax(dim=-1), (torch.exp(k) if self.causal else k.softmax(dim=-2))
401
- else:
402
- create_kernel = partial(softmax_kernel, projection_matrix=self.projection_matrix, device=q.device)
403
- q, k = create_kernel(q, is_query=True), create_kernel(k, is_query=False)
404
-
405
- attn_fn = linear_attention if not self.causal else self.causal_linear_fn
406
- return attn_fn(q, k, None) if v is None else attn_fn(q, k, v)
407
-
408
- class SelfAttention(nn.Module):
409
- def __init__(self, dim, causal=False, heads=8, dim_head=64, local_heads=0, local_window_size=256, nb_features=None, feature_redraw_interval=1000, generalized_attention=False, kernel_fn=nn.ReLU(), qr_uniform_q=False, dropout=0.0, no_projection=False):
410
- super().__init__()
411
- assert dim % heads == 0
412
- dim_head = default(dim_head, dim // heads)
413
- inner_dim = dim_head * heads
414
- self.fast_attention = FastAttention(dim_head, nb_features, causal=causal, generalized_attention=generalized_attention, kernel_fn=kernel_fn, qr_uniform_q=qr_uniform_q, no_projection=no_projection)
415
- self.heads = heads
416
- self.global_heads = heads - local_heads
417
- self.local_attn = (LocalAttention(window_size=local_window_size, causal=causal, autopad=True, dropout=dropout, look_forward=int(not causal), rel_pos_emb_config=(dim_head, local_heads)) if local_heads > 0 else None)
418
- self.to_q = nn.Linear(dim, inner_dim)
419
- self.to_k = nn.Linear(dim, inner_dim)
420
- self.to_v = nn.Linear(dim, inner_dim)
421
- self.to_out = nn.Linear(inner_dim, dim)
422
- self.dropout = nn.Dropout(dropout)
423
-
424
- @torch.no_grad()
425
- def redraw_projection_matrix(self):
426
- self.fast_attention.redraw_projection_matrix()
427
-
428
- def forward(self, x, context=None, mask=None, context_mask=None, name=None, inference=False, **kwargs):
429
- _, _, _, h, gh = *x.shape, self.heads, self.global_heads
430
- cross_attend = exists(context)
431
-
432
- context = default(context, x)
433
- context_mask = default(context_mask, mask) if not cross_attend else context_mask
434
-
435
- q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (self.to_q(x), self.to_k(context), self.to_v(context)))
436
- (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v))
437
-
438
- attn_outs = []
439
-
440
- if not empty(q):
441
- if exists(context_mask): v.masked_fill_(~context_mask[:, None, :, None], 0.0)
442
-
443
- if cross_attend: pass
444
- else: out = self.fast_attention(q, k, v)
445
-
446
- attn_outs.append(out)
447
-
448
- if not empty(lq):
449
- assert (not cross_attend), "not cross_attend"
450
-
451
- out = self.local_attn(lq, lk, lv, input_mask=mask)
452
- attn_outs.append(out)
453
-
454
- return self.dropout(self.to_out(rearrange(torch.cat(attn_outs, dim=1), "b h n d -> b n (h d)")))
455
-
456
- def l2_regularization(model, l2_alpha):
457
- l2_loss = []
458
-
459
- for module in model.modules():
460
- if type(module) is nn.Conv2d: l2_loss.append((module.weight**2).sum() / 2.0)
461
-
462
- return l2_alpha * sum(l2_loss)
463
-
464
- class _FCPE(nn.Module):
465
- def __init__(self, input_channel=128, out_dims=360, n_layers=12, n_chans=512, use_siren=False, use_full=False, loss_mse_scale=10, loss_l2_regularization=False, loss_l2_regularization_scale=1, loss_grad1_mse=False, loss_grad1_mse_scale=1, f0_max=1975.5, f0_min=32.70, confidence=False, threshold=0.05, use_input_conv=True):
466
- super().__init__()
467
- if use_siren: raise ValueError("Siren not support")
468
- if use_full: raise ValueError("Model full not support")
469
- self.loss_mse_scale = loss_mse_scale if (loss_mse_scale is not None) else 10
470
- self.loss_l2_regularization = (loss_l2_regularization if (loss_l2_regularization is not None) else False)
471
- self.loss_l2_regularization_scale = (loss_l2_regularization_scale if (loss_l2_regularization_scale is not None) else 1)
472
- self.loss_grad1_mse = loss_grad1_mse if (loss_grad1_mse is not None) else False
473
- self.loss_grad1_mse_scale = (loss_grad1_mse_scale if (loss_grad1_mse_scale is not None) else 1)
474
- self.f0_max = f0_max if (f0_max is not None) else 1975.5
475
- self.f0_min = f0_min if (f0_min is not None) else 32.70
476
- self.confidence = confidence if (confidence is not None) else False
477
- self.threshold = threshold if (threshold is not None) else 0.05
478
- self.use_input_conv = use_input_conv if (use_input_conv is not None) else True
479
- self.cent_table_b = torch.Tensor(np.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0], out_dims))
480
- self.register_buffer("cent_table", self.cent_table_b)
481
- _leaky = nn.LeakyReLU()
482
- self.stack = nn.Sequential(nn.Conv1d(input_channel, n_chans, 3, 1, 1), nn.GroupNorm(4, n_chans), _leaky, nn.Conv1d(n_chans, n_chans, 3, 1, 1))
483
- self.decoder = PCmer(num_layers=n_layers, num_heads=8, dim_model=n_chans, dim_keys=n_chans, dim_values=n_chans, residual_dropout=0.1, attention_dropout=0.1)
484
- self.norm = nn.LayerNorm(n_chans)
485
- self.n_out = out_dims
486
- self.dense_out = weight_norm(nn.Linear(n_chans, self.n_out))
487
-
488
- def forward(self, mel, infer=True, gt_f0=None, return_hz_f0=False, cdecoder="local_argmax"):
489
- if cdecoder == "argmax": self.cdecoder = self.cents_decoder
490
- elif cdecoder == "local_argmax": self.cdecoder = self.cents_local_decoder
491
-
492
- x = torch.sigmoid(self.dense_out(self.norm(self.decoder((self.stack(mel.transpose(1, 2)).transpose(1, 2) if self.use_input_conv else mel)))))
493
-
494
- if not infer:
495
- loss_all = self.loss_mse_scale * F.binary_cross_entropy(x, self.gaussian_blurred_cent(self.f0_to_cent(gt_f0)))
496
- if self.loss_l2_regularization: loss_all = loss_all + l2_regularization(model=self, l2_alpha=self.loss_l2_regularization_scale)
497
- x = loss_all
498
-
499
- if infer:
500
- x = self.cent_to_f0(self.cdecoder(x))
501
- x = (1 + x / 700).log() if not return_hz_f0 else x
502
-
503
- return x
504
-
505
- def cents_decoder(self, y, mask=True):
506
- B, N, _ = y.size()
507
- rtn = torch.sum(self.cent_table[None, None, :].expand(B, N, -1) * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True)
508
-
509
- if mask:
510
- confident = torch.max(y, dim=-1, keepdim=True)[0]
511
- confident_mask = torch.ones_like(confident)
512
-
513
- confident_mask[confident <= self.threshold] = float("-INF")
514
- rtn = rtn * confident_mask
515
-
516
- return (rtn, confident) if self.confidence else rtn
517
-
518
- def cents_local_decoder(self, y, mask=True):
519
- B, N, _ = y.size()
520
-
521
- confident, max_index = torch.max(y, dim=-1, keepdim=True)
522
- local_argmax_index = torch.clamp(torch.arange(0, 9).to(max_index.device) + (max_index - 4), 0, self.n_out - 1)
523
-
524
- y_l = torch.gather(y, -1, local_argmax_index)
525
- rtn = torch.sum(torch.gather(self.cent_table[None, None, :].expand(B, N, -1), -1, local_argmax_index) * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True)
526
-
527
- if mask:
528
- confident_mask = torch.ones_like(confident)
529
- confident_mask[confident <= self.threshold] = float("-INF")
530
-
531
- rtn = rtn * confident_mask
532
-
533
- return (rtn, confident) if self.confidence else rtn
534
-
535
- def cent_to_f0(self, cent):
536
- return 10.0 * 2 ** (cent / 1200.0)
537
-
538
- def f0_to_cent(self, f0):
539
- return 1200.0 * torch.log2(f0 / 10.0)
540
-
541
- def gaussian_blurred_cent(self, cents):
542
- B, N, _ = cents.size()
543
- return torch.exp(-torch.square(self.cent_table[None, None, :].expand(B, N, -1) - cents) / 1250) * ((cents > 0.1) & (cents < (1200.0 * np.log2(self.f0_max / 10.0)))).float()
544
-
545
- class FCPEInfer:
546
- def __init__(self, model_path, device=None, dtype=torch.float32, providers=None, onnx=False):
547
- if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
548
- self.wav2mel = Wav2Mel(device=device, dtype=dtype)
549
- self.device = device
550
- self.dtype = dtype
551
- self.onnx = onnx
552
-
553
- if self.onnx:
554
- import onnxruntime as ort
555
-
556
- sess_options = ort.SessionOptions()
557
- sess_options.log_severity_level = 3
558
-
559
- self.model = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
560
- else:
561
- ckpt = torch.load(model_path, map_location=torch.device(self.device))
562
- self.args = DotDict(ckpt["config"])
563
- model = _FCPE(input_channel=self.args.model.input_channel, out_dims=self.args.model.out_dims, n_layers=self.args.model.n_layers, n_chans=self.args.model.n_chans, use_siren=self.args.model.use_siren, use_full=self.args.model.use_full, loss_mse_scale=self.args.loss.loss_mse_scale, loss_l2_regularization=self.args.loss.loss_l2_regularization, loss_l2_regularization_scale=self.args.loss.loss_l2_regularization_scale, loss_grad1_mse=self.args.loss.loss_grad1_mse, loss_grad1_mse_scale=self.args.loss.loss_grad1_mse_scale, f0_max=self.args.model.f0_max, f0_min=self.args.model.f0_min, confidence=self.args.model.confidence)
564
-
565
- model.to(self.device).to(self.dtype)
566
- model.load_state_dict(ckpt["model"])
567
-
568
- model.eval()
569
- self.model = model
570
-
571
- @torch.no_grad()
572
- def __call__(self, audio, sr, threshold=0.05):
573
- if not self.onnx: self.model.threshold = threshold
574
- mel = self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype)
575
-
576
- return torch.as_tensor(self.model.run(["pitchf"], {"mel": mel.detach().cpu().numpy(), "threshold": np.array(threshold, dtype=np.float32)})[0], dtype=self.dtype, device=self.device).squeeze() if self.onnx else self.model(mel=mel, infer=True, return_hz_f0=True)
577
-
578
- class Wav2Mel:
579
- def __init__(self, device=None, dtype=torch.float32):
580
- self.sample_rate = 16000
581
- self.hop_size = 160
582
- if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
583
- self.device = device
584
- self.dtype = dtype
585
- self.stft = STFT(16000, 128, 1024, 1024, 160, 0, 8000)
586
- self.resample_kernel = {}
587
-
588
- def extract_nvstft(self, audio, keyshift=0, train=False):
589
- return self.stft.get_mel(audio, keyshift=keyshift, train=train).transpose(1, 2)
590
-
591
- def extract_mel(self, audio, sample_rate, keyshift=0, train=False):
592
- audio = audio.to(self.dtype).to(self.device)
593
-
594
- if sample_rate == self.sample_rate: audio_res = audio
595
- else:
596
- key_str = str(sample_rate)
597
-
598
- if key_str not in self.resample_kernel:
599
- from torchaudio.transforms import Resample
600
- self.resample_kernel[key_str] = Resample(sample_rate, self.sample_rate, lowpass_filter_width=128)
601
-
602
- self.resample_kernel[key_str] = (self.resample_kernel[key_str].to(self.dtype).to(self.device))
603
- audio_res = self.resample_kernel[key_str](audio)
604
-
605
- mel = self.extract_nvstft(audio_res, keyshift=keyshift, train=train)
606
- n_frames = int(audio.shape[1] // self.hop_size) + 1
607
- mel = (torch.cat((mel, mel[:, -1:, :]), 1) if n_frames > int(mel.shape[1]) else mel)
608
-
609
- return mel[:, :n_frames, :] if n_frames < int(mel.shape[1]) else mel
610
-
611
- def __call__(self, audio, sample_rate, keyshift=0, train=False):
612
- return self.extract_mel(audio, sample_rate, keyshift=keyshift, train=train)
613
-
614
- class DotDict(dict):
615
- def __getattr__(*args):
616
- val = dict.get(*args)
617
- return DotDict(val) if type(val) is dict else val
618
-
619
- __setattr__ = dict.__setitem__
620
- __delattr__ = dict.__delitem__
621
-
622
- class FCPE:
623
- def __init__(self, model_path, hop_length=512, f0_min=50, f0_max=1100, dtype=torch.float32, device=None, sample_rate=44100, threshold=0.05, providers=None, onnx=False):
624
- self.fcpe = FCPEInfer(model_path, device=device, dtype=dtype, providers=providers, onnx=onnx)
625
- self.hop_length = hop_length
626
- self.f0_min = f0_min
627
- self.f0_max = f0_max
628
- self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
629
- self.threshold = threshold
630
- self.sample_rate = sample_rate
631
- self.dtype = dtype
632
- self.name = "fcpe"
633
-
634
- def repeat_expand(self, content, target_len, mode = "nearest"):
635
- ndim = content.ndim
636
- content = (content[None, None] if ndim == 1 else content[None] if ndim == 2 else content)
637
-
638
- assert content.ndim == 3
639
- is_np = isinstance(content, np.ndarray)
640
-
641
- results = torch.nn.functional.interpolate(torch.from_numpy(content) if is_np else content, size=target_len, mode=mode)
642
- results = results.numpy() if is_np else results
643
- return results[0, 0] if ndim == 1 else results[0] if ndim == 2 else results
644
-
645
- def post_process(self, x, sample_rate, f0, pad_to):
646
- f0 = (torch.from_numpy(f0).float().to(x.device) if isinstance(f0, np.ndarray) else f0)
647
- f0 = self.repeat_expand(f0, pad_to) if pad_to is not None else f0
648
-
649
- vuv_vector = torch.zeros_like(f0)
650
- vuv_vector[f0 > 0.0] = 1.0
651
- vuv_vector[f0 <= 0.0] = 0.0
652
-
653
- nzindex = torch.nonzero(f0).squeeze()
654
- f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy()
655
- vuv_vector = F.interpolate(vuv_vector[None, None, :], size=pad_to)[0][0]
656
-
657
- if f0.shape[0] <= 0: return np.zeros(pad_to), vuv_vector.cpu().numpy()
658
- if f0.shape[0] == 1: return np.ones(pad_to) * f0[0], vuv_vector.cpu().numpy()
659
-
660
- return np.interp(np.arange(pad_to) * self.hop_length / sample_rate, self.hop_length / sample_rate * nzindex.cpu().numpy(), f0, left=f0[0], right=f0[-1]), vuv_vector.cpu().numpy()
661
-
662
- def compute_f0(self, wav, p_len=None):
663
- x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
664
- p_len = x.shape[0] // self.hop_length if p_len is None else p_len
665
-
666
- f0 = self.fcpe(x, sr=self.sample_rate, threshold=self.threshold)
667
- f0 = f0[:] if f0.dim() == 1 else f0[0, :, 0]
668
-
669
- if torch.all(f0 == 0): return f0.cpu().numpy() if p_len is None else np.zeros(p_len), (f0.cpu().numpy() if p_len is None else np.zeros(p_len))
670
- return self.post_process(x, self.sample_rate, f0, p_len)[0]
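
For reference, a minimal sketch of how the FCPE wrapper above was typically driven. The import path, checkpoint location and input file are assumptions for illustration; the 16 kHz / 160-sample hop pair matches the Wav2Mel settings hard-wired above.

import torch
import librosa

from main.library.predictors.FCPE import FCPE  # assumed path of this (deleted) module

# check_predictors() in main/library/utils.py downloads "fcpe.pt" into this folder.
model_path = "assets/models/predictors/fcpe.pt"
device = "cuda" if torch.cuda.is_available() else "cpu"

wav, sr = librosa.load("voice.wav", sr=16000, mono=True)  # mono float waveform at 16 kHz

f0_extractor = FCPE(model_path, hop_length=160, f0_min=50, f0_max=1100, device=device, sample_rate=sr, threshold=0.03)
f0 = f0_extractor.compute_f0(wav)  # one F0 value (Hz) per 160-sample hop
print(f0.shape)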
 
main/library/predictors/RMVPE.py DELETED
@@ -1,260 +0,0 @@
1
- import torch
2
-
3
- import numpy as np
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
-
7
- from librosa.filters import mel
8
-
9
- N_MELS, N_CLASS = 128, 360
10
-
11
- class ConvBlockRes(nn.Module):
12
- def __init__(self, in_channels, out_channels, momentum=0.01):
13
- super(ConvBlockRes, self).__init__()
14
- self.conv = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU(), nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU())
15
-
16
- if in_channels != out_channels:
17
- self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
18
- self.is_shortcut = True
19
- else: self.is_shortcut = False
20
-
21
- def forward(self, x):
22
- return self.conv(x) + self.shortcut(x) if self.is_shortcut else self.conv(x) + x
23
-
24
- class ResEncoderBlock(nn.Module):
25
- def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
26
- super(ResEncoderBlock, self).__init__()
27
- self.n_blocks = n_blocks
28
- self.conv = nn.ModuleList()
29
- self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
30
-
31
- for _ in range(n_blocks - 1):
32
- self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
33
-
34
- self.kernel_size = kernel_size
35
- if self.kernel_size is not None: self.pool = nn.AvgPool2d(kernel_size=kernel_size)
36
-
37
- def forward(self, x):
38
- for i in range(self.n_blocks):
39
- x = self.conv[i](x)
40
-
41
- if self.kernel_size is not None: return x, self.pool(x)
42
- else: return x
43
-
44
- class Encoder(nn.Module):
45
- def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
46
- super(Encoder, self).__init__()
47
- self.n_encoders = n_encoders
48
- self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
49
-
50
- self.layers = nn.ModuleList()
51
- self.latent_channels = []
52
-
53
- for _ in range(self.n_encoders):
54
- self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum))
55
- self.latent_channels.append([out_channels, in_size])
56
- in_channels = out_channels
57
- out_channels *= 2
58
- in_size //= 2
59
-
60
- self.out_size = in_size
61
- self.out_channel = out_channels
62
-
63
- def forward(self, x):
64
- concat_tensors = []
65
- x = self.bn(x)
66
-
67
- for i in range(self.n_encoders):
68
- t, x = self.layers[i](x)
69
- concat_tensors.append(t)
70
-
71
- return x, concat_tensors
72
-
73
- class Intermediate(nn.Module):
74
- def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
75
- super(Intermediate, self).__init__()
76
- self.n_inters = n_inters
77
- self.layers = nn.ModuleList()
78
- self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))
79
-
80
- for _ in range(self.n_inters - 1):
81
- self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))
82
-
83
- def forward(self, x):
84
- for i in range(self.n_inters):
85
- x = self.layers[i](x)
86
-
87
- return x
88
-
89
- class ResDecoderBlock(nn.Module):
90
- def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
91
- super(ResDecoderBlock, self).__init__()
92
- out_padding = (0, 1) if stride == (1, 2) else (1, 1)
93
- self.n_blocks = n_blocks
94
-
95
- self.conv1 = nn.Sequential(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=stride, padding=(1, 1), output_padding=out_padding, bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU())
96
- self.conv2 = nn.ModuleList()
97
- self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
98
-
99
- for _ in range(n_blocks - 1):
100
- self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
101
-
102
- def forward(self, x, concat_tensor):
103
- x = torch.cat((self.conv1(x), concat_tensor), dim=1)
104
-
105
- for i in range(self.n_blocks):
106
- x = self.conv2[i](x)
107
-
108
- return x
109
-
110
- class Decoder(nn.Module):
111
- def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
112
- super(Decoder, self).__init__()
113
- self.layers = nn.ModuleList()
114
- self.n_decoders = n_decoders
115
-
116
- for _ in range(self.n_decoders):
117
- out_channels = in_channels // 2
118
- self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum))
119
- in_channels = out_channels
120
-
121
- def forward(self, x, concat_tensors):
122
- for i in range(self.n_decoders):
123
- x = self.layers[i](x, concat_tensors[-1 - i])
124
-
125
- return x
126
-
127
- class DeepUnet(nn.Module):
128
- def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
129
- super(DeepUnet, self).__init__()
130
- self.encoder = Encoder(in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels)
131
- self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks)
132
- self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
133
-
134
- def forward(self, x):
135
- x, concat_tensors = self.encoder(x)
136
- return self.decoder(self.intermediate(x), concat_tensors)
137
-
138
- class E2E(nn.Module):
139
- def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
140
- super(E2E, self).__init__()
141
- self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
142
- self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
143
- self.fc = nn.Sequential(BiGRU(3 * 128, 256, n_gru), nn.Linear(512, N_CLASS), nn.Dropout(0.25), nn.Sigmoid()) if n_gru else nn.Sequential(nn.Linear(3 * N_MELS, N_CLASS), nn.Dropout(0.25), nn.Sigmoid())
144
-
145
- def forward(self, mel):
146
- return self.fc(self.cnn(self.unet(mel.transpose(-1, -2).unsqueeze(1))).transpose(1, 2).flatten(-2))
147
-
148
- class MelSpectrogram(torch.nn.Module):
149
- def __init__(self, n_mel_channels, sample_rate, win_length, hop_length, n_fft=None, mel_fmin=0, mel_fmax=None, clamp=1e-5):
150
- super().__init__()
151
- n_fft = win_length if n_fft is None else n_fft
152
- self.hann_window = {}
153
- mel_basis = mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax, htk=True)
154
- mel_basis = torch.from_numpy(mel_basis).float()
155
- self.register_buffer("mel_basis", mel_basis)
156
- self.n_fft = win_length if n_fft is None else n_fft
157
- self.hop_length = hop_length
158
- self.win_length = win_length
159
- self.sample_rate = sample_rate
160
- self.n_mel_channels = n_mel_channels
161
- self.clamp = clamp
162
-
163
- def forward(self, audio, keyshift=0, speed=1, center=True):
164
- factor = 2 ** (keyshift / 12)
165
- win_length_new = int(np.round(self.win_length * factor))
166
- keyshift_key = str(keyshift) + "_" + str(audio.device)
167
- if keyshift_key not in self.hann_window: self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)
168
-
169
- fft = torch.stft(audio, n_fft=int(np.round(self.n_fft * factor)), hop_length=int(np.round(self.hop_length * speed)), win_length=win_length_new, window=self.hann_window[keyshift_key], center=center, return_complex=True)
170
- magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
171
-
172
- if keyshift != 0:
173
- size = self.n_fft // 2 + 1
174
- resize = magnitude.size(1)
175
- if resize < size: magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
176
- magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
177
-
178
- mel_output = torch.matmul(self.mel_basis, magnitude)
179
- return torch.log(torch.clamp(mel_output, min=self.clamp))
180
-
181
- class RMVPE:
182
- def __init__(self, model_path, device=None, providers=None, onnx=False):
183
- self.resample_kernel = {}
184
- self.onnx = onnx
185
-
186
- if self.onnx:
187
- import onnxruntime as ort
188
-
189
- sess_options = ort.SessionOptions()
190
- sess_options.log_severity_level = 3
191
-
192
- self.model = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
193
- else:
194
- model = E2E(4, 1, (2, 2))
195
- ckpt = torch.load(model_path, map_location="cpu")
196
- model.load_state_dict(ckpt)
197
- model.eval()
198
- self.model = model.to(device)
199
-
200
- self.resample_kernel = {}
201
- self.device = device
202
- self.mel_extractor = MelSpectrogram(N_MELS, 16000, 1024, 160, None, 30, 8000).to(device)
203
- cents_mapping = 20 * np.arange(N_CLASS) + 1997.3794084376191
204
- self.cents_mapping = np.pad(cents_mapping, (4, 4))
205
-
206
- def mel2hidden(self, mel):
207
- with torch.no_grad():
208
- n_frames = mel.shape[-1]
209
- mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect")
210
-
211
- hidden = self.model.run([self.model.get_outputs()[0].name], input_feed={self.model.get_inputs()[0].name: mel.cpu().numpy()})[0] if self.onnx else self.model(mel.float())
212
- return hidden[:, :n_frames]
213
-
214
- def decode(self, hidden, thred=0.03):
215
- f0 = 10 * (2 ** (self.to_local_average_cents(hidden, thred=thred) / 1200))
216
- f0[f0 == 10] = 0
217
-
218
- return f0
219
-
220
- def infer_from_audio(self, audio, thred=0.03):
221
- hidden = self.mel2hidden(self.mel_extractor(torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True))
222
-
223
- return self.decode(hidden.squeeze(0).cpu().numpy() if not self.onnx else hidden[0], thred=thred)
224
-
225
- def infer_from_audio_with_pitch(self, audio, thred=0.03, f0_min=50, f0_max=1100):
226
- hidden = self.mel2hidden(self.mel_extractor(torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True))
227
-
228
- f0 = self.decode(hidden.squeeze(0).cpu().numpy() if not self.onnx else hidden[0], thred=thred)
229
- f0[(f0 < f0_min) | (f0 > f0_max)] = 0
230
-
231
- return f0
232
-
233
- def to_local_average_cents(self, salience, thred=0.05):
234
- center = np.argmax(salience, axis=1)
235
- salience = np.pad(salience, ((0, 0), (4, 4)))
236
-
237
- center += 4
238
- todo_salience, todo_cents_mapping = [], []
239
-
240
- starts = center - 4
241
- ends = center + 5
242
-
243
- for idx in range(salience.shape[0]):
244
- todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
245
- todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
246
-
247
- todo_salience = np.array(todo_salience)
248
-
249
- devided = np.sum(todo_salience * np.array(todo_cents_mapping), 1) / np.sum(todo_salience, 1)
250
- devided[np.max(salience, axis=1) <= thred] = 0
251
-
252
- return devided
253
-
254
- class BiGRU(nn.Module):
255
- def __init__(self, input_features, hidden_features, num_layers):
256
- super(BiGRU, self).__init__()
257
- self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)
258
-
259
- def forward(self, x):
260
- return self.gru(x)[0]
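
Likewise, a minimal sketch of the RMVPE wrapper above (module path and file names are assumptions; MelSpectrogram is hard-wired to 16 kHz with a 160-sample hop, so the waveform is resampled to 16 kHz first).

import torch
import librosa

from main.library.predictors.RMVPE import RMVPE  # assumed path of this (deleted) module

device = "cuda" if torch.cuda.is_available() else "cpu"
rmvpe = RMVPE("assets/models/predictors/rmvpe.pt", device=device)

audio, _ = librosa.load("voice.wav", sr=16000, mono=True)
f0 = rmvpe.infer_from_audio_with_pitch(audio, thred=0.03, f0_min=50, f0_max=1100)
print(f0.shape)  # one F0 estimate (Hz) per ~10 ms frame, zeroed outside [f0_min, f0_max]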
 
main/library/predictors/WORLD.py DELETED
@@ -1,90 +0,0 @@
1
- import os
2
- import torch
3
- import shutil
4
- import ctypes
5
- import platform
6
-
7
- import numpy as np
8
- import tempfile as tf
9
-
10
- class DioOption(ctypes.Structure):
11
- _fields_ = [("F0Floor", ctypes.c_double), ("F0Ceil", ctypes.c_double), ("ChannelsInOctave", ctypes.c_double), ("FramePeriod", ctypes.c_double), ("Speed", ctypes.c_int), ("AllowedRange", ctypes.c_double)]
12
-
13
- class HarvestOption(ctypes.Structure):
14
- _fields_ = [("F0Floor", ctypes.c_double), ("F0Ceil", ctypes.c_double), ("FramePeriod", ctypes.c_double)]
15
-
16
- class PYWORLD:
17
- def __init__(self):
18
- model = torch.load(os.path.join("assets", "models", "predictors", "world.pth"), map_location="cpu")
19
-
20
- model_type, suffix = (("world_64" if platform.architecture()[0] == "64bit" else "world_86"), ".dll") if platform.system() == "Windows" else ("world_linux", ".so")
21
-
22
- self.temp_folder = os.path.join("assets", "models", "predictors", "temp")
23
- os.makedirs(self.temp_folder, exist_ok=True)
24
-
25
- with tf.NamedTemporaryFile(delete=False, suffix=suffix, dir=self.temp_folder) as temp_file:
26
- temp_file.write(model[model_type])
27
- temp_path = temp_file.name
28
-
29
- self.world_dll = ctypes.CDLL(temp_path)
30
-
31
- def harvest(self, x, fs, f0_floor=50, f0_ceil=1100, frame_period=10):
32
- self.world_dll.Harvest.argtypes = [ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.c_int, ctypes.POINTER(HarvestOption), ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double)]
33
- self.world_dll.Harvest.restype = None
34
-
35
- self.world_dll.InitializeHarvestOption.argtypes = [ctypes.POINTER(HarvestOption)]
36
- self.world_dll.InitializeHarvestOption.restype = None
37
-
38
- self.world_dll.GetSamplesForHarvest.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_double]
39
- self.world_dll.GetSamplesForHarvest.restype = ctypes.c_int
40
-
41
- option = HarvestOption()
42
- self.world_dll.InitializeHarvestOption(ctypes.byref(option))
43
-
44
- option.F0Floor = f0_floor
45
- option.F0Ceil = f0_ceil
46
- option.FramePeriod = frame_period
47
-
48
- f0_length = self.world_dll.GetSamplesForHarvest(fs, len(x), option.FramePeriod)
49
- f0 = (ctypes.c_double * f0_length)()
50
- tpos = (ctypes.c_double * f0_length)()
51
-
52
- self.world_dll.Harvest((ctypes.c_double * len(x))(*x), len(x), fs, ctypes.byref(option), tpos, f0)
53
- return np.array(f0, dtype=np.float32), np.array(tpos, dtype=np.float32)
54
-
55
- def dio(self, x, fs, f0_floor=50, f0_ceil=1100, channels_in_octave=2, frame_period=10, speed=1, allowed_range=0.1):
56
- self.world_dll.Dio.argtypes = [ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.c_int, ctypes.POINTER(DioOption), ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double)]
57
- self.world_dll.Dio.restype = None
58
-
59
- self.world_dll.InitializeDioOption.argtypes = [ctypes.POINTER(DioOption)]
60
- self.world_dll.InitializeDioOption.restype = None
61
-
62
- self.world_dll.GetSamplesForDIO.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_double]
63
- self.world_dll.GetSamplesForDIO.restype = ctypes.c_int
64
-
65
- option = DioOption()
66
- self.world_dll.InitializeDioOption(ctypes.byref(option))
67
-
68
- option.F0Floor = f0_floor
69
- option.F0Ceil = f0_ceil
70
- option.ChannelsInOctave = channels_in_octave
71
- option.FramePeriod = frame_period
72
- option.Speed = speed
73
- option.AllowedRange = allowed_range
74
-
75
- f0_length = self.world_dll.GetSamplesForDIO(fs, len(x), option.FramePeriod)
76
- f0 = (ctypes.c_double * f0_length)()
77
- tpos = (ctypes.c_double * f0_length)()
78
-
79
- self.world_dll.Dio((ctypes.c_double * len(x))(*x), len(x), fs, ctypes.byref(option), tpos, f0)
80
- return np.array(f0, dtype=np.float32), np.array(tpos, dtype=np.float32)
81
-
82
- def stonemask(self, x, fs, tpos, f0):
83
- self.world_dll.StoneMask.argtypes = [ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.POINTER(ctypes.c_double)]
84
- self.world_dll.StoneMask.restype = None
85
-
86
- out_f0 = (ctypes.c_double * len(f0))()
87
- self.world_dll.StoneMask((ctypes.c_double * len(x))(*x), len(x), fs, (ctypes.c_double * len(tpos))(*tpos), (ctypes.c_double * len(f0))(*f0), len(f0), out_f0)
88
-
89
- if os.path.exists(self.temp_folder): shutil.rmtree(self.temp_folder, ignore_errors=True)
90
- return np.array(out_f0, dtype=np.float32)
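
A minimal sketch of the Harvest + StoneMask path exposed by the PYWORLD wrapper above (module path and file name are assumptions; the wrapper unpacks the platform-specific WORLD shared library from assets/models/predictors/world.pth at construction time).

import numpy as np
import soundfile as sf

from main.library.predictors.WORLD import PYWORLD  # assumed path of this (deleted) module

world = PYWORLD()  # extracts the bundled WORLD .dll/.so into a temp folder

x, fs = sf.read("voice.wav")
if x.ndim > 1: x = x.mean(axis=1)  # the ctypes bindings expect a mono double array
x = x.astype(np.float64)

f0, tpos = world.harvest(x, fs, f0_floor=50, f0_ceil=1100, frame_period=10)  # coarse F0 + frame times
f0 = world.stonemask(x, fs, tpos, f0)  # StoneMask refinement (also cleans up the temp folder)
print(f0.shape)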
 
main/library/utils.py DELETED
@@ -1,100 +0,0 @@
1
- import os
2
- import re
3
- import sys
4
- import codecs
5
- import librosa
6
- import logging
7
-
8
- import numpy as np
9
- import soundfile as sf
10
-
11
- from pydub import AudioSegment, silence
12
-
13
- sys.path.append(os.getcwd())
14
-
15
- from main.tools import huggingface
16
- from main.configs.config import Config
17
-
18
- for l in ["httpx", "httpcore"]:
19
- logging.getLogger(l).setLevel(logging.ERROR)
20
-
21
- translations = Config().translations
22
-
23
-
24
- def check_predictors(method):
25
- def download(predictors):
26
- if not os.path.exists(os.path.join("assets", "models", "predictors", predictors)): huggingface.HF_download_file(codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/cerqvpgbef/", "rot13") + predictors, os.path.join("assets", "models", "predictors", predictors))
27
-
28
- model_dict = {**dict.fromkeys(["rmvpe", "rmvpe-legacy"], "rmvpe.pt"), **dict.fromkeys(["rmvpe-onnx", "rmvpe-legacy-onnx"], "rmvpe.onnx"), **dict.fromkeys(["fcpe", "fcpe-legacy"], "fcpe.pt"), **dict.fromkeys(["fcpe-onnx", "fcpe-legacy-onnx"], "fcpe.onnx"), **dict.fromkeys(["crepe-full", "mangio-crepe-full"], "crepe_full.pth"), **dict.fromkeys(["crepe-full-onnx", "mangio-crepe-full-onnx"], "crepe_full.onnx"), **dict.fromkeys(["crepe-large", "mangio-crepe-large"], "crepe_large.pth"), **dict.fromkeys(["crepe-large-onnx", "mangio-crepe-large-onnx"], "crepe_large.onnx"), **dict.fromkeys(["crepe-medium", "mangio-crepe-medium"], "crepe_medium.pth"), **dict.fromkeys(["crepe-medium-onnx", "mangio-crepe-medium-onnx"], "crepe_medium.onnx"), **dict.fromkeys(["crepe-small", "mangio-crepe-small"], "crepe_small.pth"), **dict.fromkeys(["crepe-small-onnx", "mangio-crepe-small-onnx"], "crepe_small.onnx"), **dict.fromkeys(["crepe-tiny", "mangio-crepe-tiny"], "crepe_tiny.pth"), **dict.fromkeys(["crepe-tiny-onnx", "mangio-crepe-tiny-onnx"], "crepe_tiny.onnx"), **dict.fromkeys(["harvest", "dio"], "world.pth")}
29
-
30
- if "hybrid" in method:
31
- methods_str = re.search(r"hybrid\[(.+)\]", method)
32
- if methods_str: methods = [method.strip() for method in methods_str.group(1).split("+")]
33
- for method in methods:
34
- if method in model_dict: download(model_dict[method])
35
- elif method in model_dict: download(model_dict[method])
36
-
37
- def check_embedders(hubert):
38
- if hubert in ["contentvec_base", "hubert_base", "japanese_hubert_base", "korean_hubert_base", "chinese_hubert_base", "Hidden_Rabbit_last", "portuguese_hubert_base"]:
39
- model_path = os.path.join("assets", "models", "embedders", hubert + '.pt')
40
- if not os.path.exists(model_path): huggingface.HF_download_file(codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/rzorqqref/", "rot13") + f"{hubert}.pt", model_path)
41
-
42
- def load_audio(file):
43
- try:
44
- file = file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
45
- if not os.path.isfile(file): raise FileNotFoundError(translations["not_found"].format(name=file))
46
-
47
- audio, sr = sf.read(file)
48
-
49
- if len(audio.shape) > 1: audio = librosa.to_mono(audio.T)
50
- if sr != 16000: audio = librosa.resample(audio, orig_sr=sr, target_sr=16000, res_type="soxr_vhq")
51
- except Exception as e:
52
- raise RuntimeError(f"{translations['errors_loading_audio']}: {e}")
53
- return audio.flatten()
54
-
55
- def process_audio(logger, file_path, output_path):
56
- try:
57
- song = pydub_convert(AudioSegment.from_file(file_path))
58
- cut_files, time_stamps = [], []
59
-
60
- for i, (start_i, end_i) in enumerate(silence.detect_nonsilent(song, min_silence_len=750, silence_thresh=-70)):
61
- chunk = song[start_i:end_i]
62
- if len(chunk) > 10:
63
- chunk_file_path = os.path.join(output_path, f"chunk{i}.wav")
64
- if os.path.exists(chunk_file_path): os.remove(chunk_file_path)
65
- chunk.export(chunk_file_path, format="wav")
66
- cut_files.append(chunk_file_path)
67
- time_stamps.append((start_i, end_i))
68
- else: logger.debug(translations["skip_file"].format(i=i, chunk=len(chunk)))
69
- logger.info(f"{translations['split_total']}: {len(cut_files)}")
70
- return cut_files, time_stamps
71
- except Exception as e:
72
- raise RuntimeError(f"{translations['process_audio_error']}: {e}")
73
-
74
- def merge_audio(files_list, time_stamps, original_file_path, output_path, format):
75
- try:
76
- def extract_number(filename):
77
- match = re.search(r'_(\d+)', filename)
78
- return int(match.group(1)) if match else 0
79
-
80
- total_duration = len(AudioSegment.from_file(original_file_path))
81
- combined = AudioSegment.empty()
82
- current_position = 0
83
-
84
- for file, (start_i, end_i) in zip(sorted(files_list, key=extract_number), time_stamps):
85
- if start_i > current_position: combined += AudioSegment.silent(duration=start_i - current_position)
86
- combined += AudioSegment.from_file(file)
87
- current_position = end_i
88
-
89
- if current_position < total_duration: combined += AudioSegment.silent(duration=total_duration - current_position)
90
- combined.export(output_path, format=format)
91
- return output_path
92
- except Exception as e:
93
- raise RuntimeError(f"{translations['merge_error']}: {e}")
94
-
95
- def pydub_convert(audio):
96
- samples = np.frombuffer(audio.raw_data, dtype=np.int16)
97
-
98
- if samples.dtype != np.int16: samples = (samples * 32767).astype(np.int16)
99
-
100
- return AudioSegment(samples.tobytes(), frame_rate=audio.frame_rate, sample_width=samples.dtype.itemsize, channels=audio.channels)
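
For the silence-based split/merge helpers above, a minimal round-trip sketch (logger, paths and format are assumptions; process_audio cuts the input on silences and merge_audio puts the processed chunks back on the original timeline, padding the gaps with silence).

import os
import logging

from main.library.utils import process_audio, merge_audio  # assumed path of this (deleted) module

logger = logging.getLogger("demo")
os.makedirs("output/chunks", exist_ok=True)

# Split on silences of at least 750 ms below -70 dBFS, keeping the original time stamps.
chunks, stamps = process_audio(logger, "input.wav", "output/chunks")

# ... run the per-chunk processing (e.g. voice conversion) on every file in `chunks` here ...

merge_audio(chunks, stamps, "input.wav", "output/merged.wav", format="wav")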
 
main/library/uvr5_separator/common_separator.py DELETED
@@ -1,250 +0,0 @@
1
- import os
2
- import gc
3
- import sys
4
- import torch
5
- import librosa
6
-
7
- import numpy as np
8
- import soundfile as sf
9
-
10
- from pydub import AudioSegment
11
-
12
- sys.path.append(os.getcwd())
13
-
14
- from .spec_utils import normalize
15
- from main.configs.config import Config
16
-
17
- translations = Config().translations
18
-
19
- class CommonSeparator:
20
- ALL_STEMS = "All Stems"
21
- VOCAL_STEM = "Vocals"
22
- INST_STEM = "Instrumental"
23
- OTHER_STEM = "Other"
24
- BASS_STEM = "Bass"
25
- DRUM_STEM = "Drums"
26
- GUITAR_STEM = "Guitar"
27
- PIANO_STEM = "Piano"
28
- SYNTH_STEM = "Synthesizer"
29
- STRINGS_STEM = "Strings"
30
- WOODWINDS_STEM = "Woodwinds"
31
- BRASS_STEM = "Brass"
32
- WIND_INST_STEM = "Wind Inst"
33
- NO_OTHER_STEM = "No Other"
34
- NO_BASS_STEM = "No Bass"
35
- NO_DRUM_STEM = "No Drums"
36
- NO_GUITAR_STEM = "No Guitar"
37
- NO_PIANO_STEM = "No Piano"
38
- NO_SYNTH_STEM = "No Synthesizer"
39
- NO_STRINGS_STEM = "No Strings"
40
- NO_WOODWINDS_STEM = "No Woodwinds"
41
- NO_WIND_INST_STEM = "No Wind Inst"
42
- NO_BRASS_STEM = "No Brass"
43
- PRIMARY_STEM = "Primary Stem"
44
- SECONDARY_STEM = "Secondary Stem"
45
- LEAD_VOCAL_STEM = "lead_only"
46
- BV_VOCAL_STEM = "backing_only"
47
- LEAD_VOCAL_STEM_I = "with_lead_vocals"
48
- BV_VOCAL_STEM_I = "with_backing_vocals"
49
- LEAD_VOCAL_STEM_LABEL = "Lead Vocals"
50
- BV_VOCAL_STEM_LABEL = "Backing Vocals"
51
- NO_STEM = "No "
52
- STEM_PAIR_MAPPER = {VOCAL_STEM: INST_STEM, INST_STEM: VOCAL_STEM, LEAD_VOCAL_STEM: BV_VOCAL_STEM, BV_VOCAL_STEM: LEAD_VOCAL_STEM, PRIMARY_STEM: SECONDARY_STEM}
53
- NON_ACCOM_STEMS = (VOCAL_STEM, OTHER_STEM, BASS_STEM, DRUM_STEM, GUITAR_STEM, PIANO_STEM, SYNTH_STEM, STRINGS_STEM, WOODWINDS_STEM, BRASS_STEM, WIND_INST_STEM)
54
-
55
- def __init__(self, config):
56
- self.logger = config.get("logger")
57
- self.log_level = config.get("log_level")
58
- self.torch_device = config.get("torch_device")
59
- self.torch_device_cpu = config.get("torch_device_cpu")
60
- self.torch_device_mps = config.get("torch_device_mps")
61
- self.onnx_execution_provider = config.get("onnx_execution_provider")
62
- self.model_name = config.get("model_name")
63
- self.model_path = config.get("model_path")
64
- self.model_data = config.get("model_data")
65
- self.output_dir = config.get("output_dir")
66
- self.output_format = config.get("output_format")
67
- self.output_bitrate = config.get("output_bitrate")
68
- self.normalization_threshold = config.get("normalization_threshold")
69
- self.enable_denoise = config.get("enable_denoise")
70
- self.output_single_stem = config.get("output_single_stem")
71
- self.invert_using_spec = config.get("invert_using_spec")
72
- self.sample_rate = config.get("sample_rate")
73
- self.primary_stem_name = None
74
- self.secondary_stem_name = None
75
-
76
- if "training" in self.model_data and "instruments" in self.model_data["training"]:
77
- instruments = self.model_data["training"]["instruments"]
78
- if instruments:
79
- self.primary_stem_name = instruments[0]
80
- self.secondary_stem_name = instruments[1] if len(instruments) > 1 else self.secondary_stem(self.primary_stem_name)
81
-
82
- if self.primary_stem_name is None:
83
- self.primary_stem_name = self.model_data.get("primary_stem", "Vocals")
84
- self.secondary_stem_name = self.secondary_stem(self.primary_stem_name)
85
-
86
- self.is_karaoke = self.model_data.get("is_karaoke", False)
87
- self.is_bv_model = self.model_data.get("is_bv_model", False)
88
- self.bv_model_rebalance = self.model_data.get("is_bv_model_rebalanced", 0)
89
- self.logger.debug(translations["info"].format(model_name=self.model_name, model_path=self.model_path))
90
- self.logger.debug(translations["info_2"].format(output_dir=self.output_dir, output_format=self.output_format))
91
- self.logger.debug(translations["info_3"].format(normalization_threshold=self.normalization_threshold))
92
- self.logger.debug(translations["info_4"].format(enable_denoise=self.enable_denoise, output_single_stem=self.output_single_stem))
93
- self.logger.debug(translations["info_5"].format(invert_using_spec=self.invert_using_spec, sample_rate=self.sample_rate))
94
- self.logger.debug(translations["info_6"].format(primary_stem_name=self.primary_stem_name, secondary_stem_name=self.secondary_stem_name))
95
- self.logger.debug(translations["info_7"].format(is_karaoke=self.is_karaoke, is_bv_model=self.is_bv_model, bv_model_rebalance=self.bv_model_rebalance))
96
- self.audio_file_path = None
97
- self.audio_file_base = None
98
- self.primary_source = None
99
- self.secondary_source = None
100
- self.primary_stem_output_path = None
101
- self.secondary_stem_output_path = None
102
- self.cached_sources_map = {}
103
-
104
- def secondary_stem(self, primary_stem):
105
- primary_stem = primary_stem if primary_stem else self.NO_STEM
106
- return self.STEM_PAIR_MAPPER[primary_stem] if primary_stem in self.STEM_PAIR_MAPPER else primary_stem.replace(self.NO_STEM, "") if self.NO_STEM in primary_stem else f"{self.NO_STEM}{primary_stem}"
107
-
108
- def separate(self, audio_file_path):
109
- pass
110
-
111
- def final_process(self, stem_path, source, stem_name):
112
- self.logger.debug(translations["success_process"].format(stem_name=stem_name))
113
- self.write_audio(stem_path, source)
114
- return {stem_name: source}
115
-
116
- def cached_sources_clear(self):
117
- self.cached_sources_map = {}
118
-
119
- def cached_source_callback(self, model_architecture, model_name=None):
120
- model, sources = None, None
121
- mapper = self.cached_sources_map[model_architecture]
122
- for key, value in mapper.items():
123
- if model_name in key:
124
- model = key
125
- sources = value
126
-
127
- return model, sources
128
-
129
- def cached_model_source_holder(self, model_architecture, sources, model_name=None):
130
- self.cached_sources_map[model_architecture] = {**self.cached_sources_map.get(model_architecture, {}), **{model_name: sources}}
131
-
132
- def prepare_mix(self, mix):
133
- audio_path = mix
134
- if not isinstance(mix, np.ndarray):
135
- self.logger.debug(f"{translations['load_audio']}: {mix}")
136
- mix, sr = librosa.load(mix, mono=False, sr=self.sample_rate)
137
- self.logger.debug(translations["load_audio_success"].format(sr=sr, shape=mix.shape))
138
- else:
139
- self.logger.debug(translations["convert_mix"])
140
- mix = mix.T
141
- self.logger.debug(translations["convert_shape"].format(shape=mix.shape))
142
-
143
- if isinstance(audio_path, str):
144
- if not np.any(mix):
145
- error_msg = translations["audio_not_valid"].format(audio_path=audio_path)
146
- self.logger.error(error_msg)
147
- raise ValueError(error_msg)
148
- else: self.logger.debug(translations["audio_valid"])
149
-
150
- if mix.ndim == 1:
151
- self.logger.debug(translations["mix_single"])
152
- mix = np.asfortranarray([mix, mix])
153
- self.logger.debug(translations["convert_mix_audio"])
154
-
155
- self.logger.debug(translations["mix_success_2"])
156
- return mix
157
-
158
- def write_audio(self, stem_path, stem_source):
159
- duration_seconds = librosa.get_duration(filename=self.audio_file_path)
160
- duration_hours = duration_seconds / 3600
161
- self.logger.info(translations["duration"].format(duration_hours=f"{duration_hours:.2f}", duration_seconds=f"{duration_seconds:.2f}"))
162
-
163
- if duration_hours >= 1:
164
- self.logger.debug(translations["write"].format(name="soundfile"))
165
- self.write_audio_soundfile(stem_path, stem_source)
166
- else:
167
- self.logger.info(translations["write"].format(name="pydub"))
168
- self.write_audio_pydub(stem_path, stem_source)
169
-
170
- def write_audio_pydub(self, stem_path, stem_source):
171
- self.logger.debug(f"{translations['write_audio'].format(name='write_audio_pydub')} {stem_path}")
172
- stem_source = normalize(wave=stem_source, max_peak=self.normalization_threshold)
173
-
174
- if np.max(np.abs(stem_source)) < 1e-6:
175
- self.logger.warning(translations["original_not_valid"])
176
- return
177
-
178
- if self.output_dir:
179
- os.makedirs(self.output_dir, exist_ok=True)
180
- stem_path = os.path.join(self.output_dir, stem_path)
181
-
182
- self.logger.debug(f"{translations['shape_audio']}: {stem_source.shape}")
183
- self.logger.debug(f"{translations['convert_data']}: {stem_source.dtype}")
184
-
185
- if stem_source.dtype != np.int16:
186
- stem_source = (stem_source * 32767).astype(np.int16)
187
- self.logger.debug(translations["original_source_to_int16"])
188
-
189
- stem_source_interleaved = np.empty((2 * stem_source.shape[0],), dtype=np.int16)
190
- stem_source_interleaved[0::2] = stem_source[:, 0]
191
- stem_source_interleaved[1::2] = stem_source[:, 1]
192
- self.logger.debug(f"{translations['shape_audio_2']}: {stem_source_interleaved.shape}")
193
-
194
- try:
195
- audio_segment = AudioSegment(stem_source_interleaved.tobytes(), frame_rate=self.sample_rate, sample_width=stem_source.dtype.itemsize, channels=2)
196
- self.logger.debug(translations["create_audiosegment"])
197
- except (IOError, ValueError) as e:
198
- self.logger.error(f"{translations['create_audiosegment_error']}: {e}")
199
- return
200
-
201
- file_format = stem_path.lower().split(".")[-1]
202
-
203
- if file_format == "m4a": file_format = "mp4"
204
- elif file_format == "mka": file_format = "matroska"
205
-
206
- try:
207
- audio_segment.export(stem_path, format=file_format, bitrate="320k" if file_format == "mp3" and self.output_bitrate is None else self.output_bitrate)
208
- self.logger.debug(f"{translations['export_success']} {stem_path}")
209
- except (IOError, ValueError) as e:
210
- self.logger.error(f"{translations['export_error']}: {e}")
211
-
212
- def write_audio_soundfile(self, stem_path, stem_source):
213
- self.logger.debug(f"{translations['write_audio'].format(name='write_audio_soundfile')}: {stem_path}")
214
-
215
- if stem_source.shape[1] == 2:
216
- if stem_source.flags["F_CONTIGUOUS"]: stem_source = np.ascontiguousarray(stem_source)
217
- else:
218
- stereo_interleaved = np.empty((2 * stem_source.shape[0],), dtype=np.int16)
219
- stereo_interleaved[0::2] = stem_source[:, 0]
220
- stereo_interleaved[1::2] = stem_source[:, 1]
221
- stem_source = stereo_interleaved
222
-
223
- self.logger.debug(f"{translations['shape_audio_2']}: {stem_source.shape}")
224
-
225
- try:
226
- sf.write(stem_path, stem_source, self.sample_rate)
227
- self.logger.debug(f"{translations['export_success']} {stem_path}")
228
- except Exception as e:
229
- self.logger.error(f"{translations['export_error']}: {e}")
230
-
231
- def clear_gpu_cache(self):
232
- self.logger.debug(translations["clean"])
233
- gc.collect()
234
-
235
- if self.torch_device == torch.device("mps"):
236
- self.logger.debug(translations["clean_cache"].format(name="MPS"))
237
- torch.mps.empty_cache()
238
-
239
- if self.torch_device == torch.device("cuda"):
240
- self.logger.debug(translations["clean_cache"].format(name="CUDA"))
241
- torch.cuda.empty_cache()
242
-
243
- def clear_file_specific_paths(self):
244
- self.logger.info(translations["del_path"])
245
- self.audio_file_path = None
246
- self.audio_file_base = None
247
- self.primary_source = None
248
- self.secondary_source = None
249
- self.primary_stem_output_path = None
250
- self.secondary_stem_output_path = None
 
main/library/uvr5_separator/demucs/apply.py DELETED
@@ -1,250 +0,0 @@
1
- import tqdm
2
- import torch
3
- import random
4
-
5
- from torch import nn
6
- from torch.nn import functional as F
7
- from concurrent.futures import ThreadPoolExecutor
8
-
9
- from .utils import center_trim
10
-
11
- class DummyPoolExecutor:
12
- class DummyResult:
13
- def __init__(self, func, *args, **kwargs):
14
- self.func = func
15
- self.args = args
16
- self.kwargs = kwargs
17
-
18
- def result(self):
19
- return self.func(*self.args, **self.kwargs)
20
-
21
- def __init__(self, workers=0):
22
- pass
23
-
24
- def submit(self, func, *args, **kwargs):
25
- return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
26
-
27
- def __enter__(self):
28
- return self
29
-
30
- def __exit__(self, exc_type, exc_value, exc_tb):
31
- return
32
-
33
- class BagOfModels(nn.Module):
34
- def __init__(self, models, weights = None, segment = None):
35
- super().__init__()
36
- assert len(models) > 0
37
- first = models[0]
38
-
39
- for other in models:
40
- assert other.sources == first.sources
41
- assert other.samplerate == first.samplerate
42
- assert other.audio_channels == first.audio_channels
43
-
44
- if segment is not None: other.segment = segment
45
-
46
- self.audio_channels = first.audio_channels
47
- self.samplerate = first.samplerate
48
- self.sources = first.sources
49
- self.models = nn.ModuleList(models)
50
-
51
- if weights is None: weights = [[1.0 for _ in first.sources] for _ in models]
52
- else:
53
- assert len(weights) == len(models)
54
-
55
- for weight in weights:
56
- assert len(weight) == len(first.sources)
57
-
58
- self.weights = weights
59
-
60
- def forward(self, x):
61
- pass
62
-
63
- class TensorChunk:
64
- def __init__(self, tensor, offset=0, length=None):
65
- total_length = tensor.shape[-1]
66
- assert offset >= 0
67
- assert offset < total_length
68
-
69
- length = total_length - offset if length is None else min(total_length - offset, length)
70
-
71
- if isinstance(tensor, TensorChunk):
72
- self.tensor = tensor.tensor
73
- self.offset = offset + tensor.offset
74
- else:
75
- self.tensor = tensor
76
- self.offset = offset
77
-
78
- self.length = length
79
- self.device = tensor.device
80
-
81
- @property
82
- def shape(self):
83
- shape = list(self.tensor.shape)
84
- shape[-1] = self.length
85
- return shape
86
-
87
- def padded(self, target_length):
88
- delta = target_length - self.length
89
- total_length = self.tensor.shape[-1]
90
- assert delta >= 0
91
-
92
- start = self.offset - delta // 2
93
- end = start + target_length
94
-
95
- correct_start = max(0, start)
96
- correct_end = min(total_length, end)
97
-
98
- pad_left = correct_start - start
99
- pad_right = end - correct_end
100
-
101
- out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
102
-
103
- assert out.shape[-1] == target_length
104
- return out
105
-
106
- def tensor_chunk(tensor_or_chunk):
107
- if isinstance(tensor_or_chunk, TensorChunk): return tensor_or_chunk
108
- else:
109
- assert isinstance(tensor_or_chunk, torch.Tensor)
110
- return TensorChunk(tensor_or_chunk)
111
-
112
- def apply_model(model, mix, shifts=1, split=True, overlap=0.25, transition_power=1.0, static_shifts=1, set_progress_bar=None, device=None, progress=False, num_workers=0, pool=None):
113
- global fut_length, bag_num, prog_bar
114
-
115
- device = mix.device if device is None else torch.device(device)
116
- if pool is None: pool = ThreadPoolExecutor(num_workers) if num_workers > 0 and device.type == "cpu" else DummyPoolExecutor()
117
-
118
- kwargs = {
119
- "shifts": shifts,
120
- "split": split,
121
- "overlap": overlap,
122
- "transition_power": transition_power,
123
- "progress": progress,
124
- "device": device,
125
- "pool": pool,
126
- "set_progress_bar": set_progress_bar,
127
- "static_shifts": static_shifts,
128
- }
129
-
130
- if isinstance(model, BagOfModels):
131
- estimates, fut_length, prog_bar, current_model = 0, 0, 0, 0
132
- totals = [0] * len(model.sources)
133
- bag_num = len(model.models)
134
-
135
- for sub_model, weight in zip(model.models, model.weights):
136
- original_model_device = next(iter(sub_model.parameters())).device
137
- sub_model.to(device)
138
- fut_length += fut_length
139
- current_model += 1
140
- out = apply_model(sub_model, mix, **kwargs)
141
- sub_model.to(original_model_device)
142
-
143
- for k, inst_weight in enumerate(weight):
144
- out[:, k, :, :] *= inst_weight
145
- totals[k] += inst_weight
146
-
147
- estimates += out
148
- del out
149
-
150
- for k in range(estimates.shape[1]):
151
- estimates[:, k, :, :] /= totals[k]
152
-
153
- return estimates
154
-
155
- model.to(device)
156
- model.eval()
157
- assert transition_power >= 1
158
- batch, channels, length = mix.shape
159
-
160
- if shifts:
161
- kwargs["shifts"] = 0
162
- max_shift = int(0.5 * model.samplerate)
163
- mix = tensor_chunk(mix)
164
- padded_mix = mix.padded(length + 2 * max_shift)
165
- out = 0
166
-
167
- for _ in range(shifts):
168
- offset = random.randint(0, max_shift)
169
- shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
170
- shifted_out = apply_model(model, shifted, **kwargs)
171
- out += shifted_out[..., max_shift - offset :]
172
-
173
- out /= shifts
174
- return out
175
- elif split:
176
- kwargs["split"] = False
177
- out = torch.zeros(batch, len(model.sources), channels, length, device=mix.device)
178
- sum_weight = torch.zeros(length, device=mix.device)
179
- segment = int(model.samplerate * model.segment)
180
- stride = int((1 - overlap) * segment)
181
- offsets = range(0, length, stride)
182
- weight = torch.cat([torch.arange(1, segment // 2 + 1, device=device), torch.arange(segment - segment // 2, 0, -1, device=device)])
183
- assert len(weight) == segment
184
- weight = (weight / weight.max()) ** transition_power
185
- futures = []
186
-
187
- for offset in offsets:
188
- chunk = TensorChunk(mix, offset, segment)
189
- future = pool.submit(apply_model, model, chunk, **kwargs)
190
- futures.append((future, offset))
191
- offset += segment
192
-
193
- if progress: futures = tqdm.tqdm(futures)
194
-
195
- for future, offset in futures:
196
- if set_progress_bar:
197
- fut_length = len(futures) * bag_num * static_shifts
198
- prog_bar += 1
199
- set_progress_bar(0.1, (0.8 / fut_length * prog_bar))
200
-
201
- chunk_out = future.result()
202
- chunk_length = chunk_out.shape[-1]
203
-
204
- out[..., offset : offset + segment] += (weight[:chunk_length] * chunk_out).to(mix.device)
205
- sum_weight[offset : offset + segment] += weight[:chunk_length].to(mix.device)
206
-
207
- assert sum_weight.min() > 0
208
-
209
- out /= sum_weight
210
- return out
211
- else:
212
- valid_length = model.valid_length(length) if hasattr(model, "valid_length") else length
213
- mix = tensor_chunk(mix)
214
- padded_mix = mix.padded(valid_length).to(device)
215
-
216
- with torch.no_grad():
217
- out = model(padded_mix)
218
-
219
- return center_trim(out, length)
220
-
221
- def demucs_segments(demucs_segment, demucs_model):
222
- if demucs_segment == "Default":
223
- segment = None
224
-
225
- if isinstance(demucs_model, BagOfModels):
226
- if segment is not None:
227
- for sub in demucs_model.models:
228
- sub.segment = segment
229
- else:
230
- if segment is not None: demucs_model.segment = segment
231
- else:
232
- try:
233
- segment = int(demucs_segment)
234
- if isinstance(demucs_model, BagOfModels):
235
- if segment is not None:
236
- for sub in demucs_model.models:
237
- sub.segment = segment
238
- else:
239
- if segment is not None: demucs_model.segment = segment
240
- except:
241
- segment = None
242
-
243
- if isinstance(demucs_model, BagOfModels):
244
- if segment is not None:
245
- for sub in demucs_model.models:
246
- sub.segment = segment
247
- else:
248
- if segment is not None: demucs_model.segment = segment
249
-
250
- return demucs_model
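
The chunked path in apply_model above leans on TensorChunk to give every window symmetric zero padding without materialising copies of the mix; a minimal sketch of that behaviour (tensor sizes are arbitrary, module path is an assumption).

import torch

from main.library.uvr5_separator.demucs.apply import TensorChunk, tensor_chunk  # assumed path

mix = torch.randn(1, 2, 44100)                 # (batch, channels, samples)
chunk = TensorChunk(mix, offset=20000, length=10000)
print(chunk.shape)                             # [1, 2, 10000] - still a view into `mix`

padded = chunk.padded(16000)                   # zero-pad symmetrically around the window
print(padded.shape)                            # torch.Size([1, 2, 16000])

# tensor_chunk() is how apply_model accepts either a plain tensor or an existing chunk.
print(tensor_chunk(mix).shape)                 # [1, 2, 44100]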
 
main/library/uvr5_separator/demucs/demucs.py DELETED
@@ -1,370 +0,0 @@
1
- import math
2
- import torch
3
- import inspect
4
-
5
- from torch import nn
6
-
7
- from torch.nn import functional as F
8
-
9
- from .utils import center_trim
10
- from .states import capture_init
11
-
12
-
13
-
14
- def unfold(a, kernel_size, stride):
15
- *shape, length = a.shape
16
- n_frames = math.ceil(length / stride)
17
- tgt_length = (n_frames - 1) * stride + kernel_size
18
- a = F.pad(a, (0, tgt_length - length))
19
- strides = list(a.stride())
20
- assert strides[-1] == 1
21
- strides = strides[:-1] + [stride, 1]
22
- return a.as_strided([*shape, n_frames, kernel_size], strides)
23
-
24
- def rescale_conv(conv, reference):
25
- scale = (conv.weight.std().detach() / reference) ** 0.5
26
- conv.weight.data /= scale
27
- if conv.bias is not None: conv.bias.data /= scale
28
-
29
- def rescale_module(module, reference):
30
- for sub in module.modules():
31
- if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)): rescale_conv(sub, reference)
32
-
33
- class BLSTM(nn.Module):
34
- def __init__(self, dim, layers=1, max_steps=None, skip=False):
35
- super().__init__()
36
- assert max_steps is None or max_steps % 4 == 0
37
- self.max_steps = max_steps
38
- self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
39
- self.linear = nn.Linear(2 * dim, dim)
40
- self.skip = skip
41
-
42
- def forward(self, x):
43
- B, C, T = x.shape
44
- y = x
45
- framed = False
46
-
47
- if self.max_steps is not None and T > self.max_steps:
48
- width = self.max_steps
49
- stride = width // 2
50
- frames = unfold(x, width, stride)
51
- nframes = frames.shape[2]
52
- framed = True
53
- x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)
54
-
55
- x = x.permute(2, 0, 1)
56
- x = self.lstm(x)[0]
57
- x = self.linear(x)
58
- x = x.permute(1, 2, 0)
59
-
60
- if framed:
61
- out = []
62
- frames = x.reshape(B, -1, C, width)
63
- limit = stride // 2
64
-
65
- for k in range(nframes):
66
- if k == 0: out.append(frames[:, k, :, :-limit])
67
- elif k == nframes - 1: out.append(frames[:, k, :, limit:])
68
- else: out.append(frames[:, k, :, limit:-limit])
69
-
70
- out = torch.cat(out, -1)
71
- out = out[..., :T]
72
- x = out
73
-
74
- if self.skip: x = x + y
75
- return x
76
-
77
- class LayerScale(nn.Module):
78
- def __init__(self, channels, init = 0):
79
- super().__init__()
80
- self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
81
- self.scale.data[:] = init
82
-
83
- def forward(self, x):
84
- return self.scale[:, None] * x
85
-
86
- class DConv(nn.Module):
87
- def __init__(self, channels, compress = 4, depth = 2, init = 1e-4, norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True, kernel=3, dilate=True):
88
- super().__init__()
89
- assert kernel % 2 == 1
90
- self.channels = channels
91
- self.compress = compress
92
- self.depth = abs(depth)
93
- dilate = depth > 0
94
- norm_fn = lambda d: nn.Identity()
95
- if norm: norm_fn = lambda d: nn.GroupNorm(1, d)
96
- hidden = int(channels / compress)
97
- act = nn.GELU if gelu else nn.ReLU
98
- self.layers = nn.ModuleList([])
99
-
100
- for d in range(self.depth):
101
- dilation = 2**d if dilate else 1
102
- padding = dilation * (kernel // 2)
103
-
104
- mods = [nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding), norm_fn(hidden), act(), nn.Conv1d(hidden, 2 * channels, 1), norm_fn(2 * channels), nn.GLU(1), LayerScale(channels, init)]
105
-
106
- if attn: mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
107
- if lstm: mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
108
- layer = nn.Sequential(*mods)
109
- self.layers.append(layer)
110
-
111
- def forward(self, x):
112
- for layer in self.layers:
113
- x = x + layer(x)
114
-
115
- return x
116
-
117
- class LocalState(nn.Module):
118
- def __init__(self, channels, heads = 4, nfreqs = 0, ndecay = 4):
119
- super().__init__()
120
- assert channels % heads == 0, (channels, heads)
121
- self.heads = heads
122
- self.nfreqs = nfreqs
123
- self.ndecay = ndecay
124
- self.content = nn.Conv1d(channels, channels, 1)
125
- self.query = nn.Conv1d(channels, channels, 1)
126
- self.key = nn.Conv1d(channels, channels, 1)
127
-
128
- if nfreqs: self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
129
-
130
- if ndecay:
131
- self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
132
- self.query_decay.weight.data *= 0.01
133
- assert self.query_decay.bias is not None
134
- self.query_decay.bias.data[:] = -2
135
-
136
- self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)
137
-
138
- def forward(self, x):
139
- B, C, T = x.shape
140
- heads = self.heads
141
- indexes = torch.arange(T, device=x.device, dtype=x.dtype)
142
- delta = indexes[:, None] - indexes[None, :]
143
- queries = self.query(x).view(B, heads, -1, T)
144
- keys = self.key(x).view(B, heads, -1, T)
145
- dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
146
- dots /= keys.shape[2] ** 0.5
147
-
148
- if self.nfreqs:
149
- periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
150
- freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
151
- freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs**0.5
152
- dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
153
-
154
- if self.ndecay:
155
- decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
156
- decay_q = self.query_decay(x).view(B, heads, -1, T)
157
- decay_q = torch.sigmoid(decay_q) / 2
158
- decay_kernel = -decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5
159
- dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)
160
-
161
- dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
162
- weights = torch.softmax(dots, dim=2)
163
- content = self.content(x).view(B, heads, -1, T)
164
- result = torch.einsum("bhts,bhct->bhcs", weights, content)
165
-
166
- if self.nfreqs:
167
- time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
168
- result = torch.cat([result, time_sig], 2)
169
-
170
- result = result.reshape(B, -1, T)
171
- return x + self.proj(result)
172
-
173
- class Demucs(nn.Module):
174
- @capture_init
175
- def __init__(self, sources, audio_channels=2, channels=64, growth=2.0, depth=6, rewrite=True, lstm_layers=0, kernel_size=8, stride=4, context=1, gelu=True, glu=True, norm_starts=4, norm_groups=4, dconv_mode=1, dconv_depth=2, dconv_comp=4, dconv_attn=4, dconv_lstm=4, dconv_init=1e-4, normalize=True, resample=True, rescale=0.1, samplerate=44100, segment=4 * 10):
176
- super().__init__()
177
- self.audio_channels = audio_channels
178
- self.sources = sources
179
- self.kernel_size = kernel_size
180
- self.context = context
181
- self.stride = stride
182
- self.depth = depth
183
- self.resample = resample
184
- self.channels = channels
185
- self.normalize = normalize
186
- self.samplerate = samplerate
187
- self.segment = segment
188
- self.encoder = nn.ModuleList()
189
- self.decoder = nn.ModuleList()
190
- self.skip_scales = nn.ModuleList()
191
-
192
- if glu:
193
- activation = nn.GLU(dim=1)
194
- ch_scale = 2
195
- else:
196
- activation = nn.ReLU()
197
- ch_scale = 1
198
-
199
- act2 = nn.GELU if gelu else nn.ReLU
200
-
201
- in_channels = audio_channels
202
- padding = 0
203
-
204
- for index in range(depth):
205
- norm_fn = lambda d: nn.Identity()
206
- if index >= norm_starts: norm_fn = lambda d: nn.GroupNorm(norm_groups, d)
207
-
208
- encode = []
209
- encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), norm_fn(channels), act2()]
210
- attn = index >= dconv_attn
211
- lstm = index >= dconv_lstm
212
-
213
- if dconv_mode & 1: encode += [DConv(channels, depth=dconv_depth, init=dconv_init, compress=dconv_comp, attn=attn, lstm=lstm)]
214
- if rewrite: encode += [nn.Conv1d(channels, ch_scale * channels, 1), norm_fn(ch_scale * channels), activation]
215
- self.encoder.append(nn.Sequential(*encode))
216
-
217
- decode = []
218
- out_channels = in_channels if index > 0 else len(self.sources) * audio_channels
219
-
220
- if rewrite: decode += [nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context), norm_fn(ch_scale * channels), activation]
221
- if dconv_mode & 2: decode += [DConv(channels, depth=dconv_depth, init=dconv_init, compress=dconv_comp, attn=attn, lstm=lstm)]
222
- decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride, padding=padding)]
223
-
224
- if index > 0: decode += [norm_fn(out_channels), act2()]
225
- self.decoder.insert(0, nn.Sequential(*decode))
226
- in_channels = channels
227
- channels = int(growth * channels)
228
-
229
- channels = in_channels
230
- self.lstm = BLSTM(channels, lstm_layers) if lstm_layers else None
231
- if rescale: rescale_module(self, reference=rescale)
232
-
233
- def valid_length(self, length):
234
- if self.resample: length *= 2
235
-
236
- for _ in range(self.depth):
237
- length = math.ceil((length - self.kernel_size) / self.stride) + 1
238
- length = max(1, length)
239
-
240
- for _ in range(self.depth):
241
- length = (length - 1) * self.stride + self.kernel_size
242
-
243
- if self.resample: length = math.ceil(length / 2)
244
- return int(length)
245
-
246
- def forward(self, mix):
247
- x = mix
248
- length = x.shape[-1]
249
-
250
- if self.normalize:
251
- mono = mix.mean(dim=1, keepdim=True)
252
- mean = mono.mean(dim=-1, keepdim=True)
253
- std = mono.std(dim=-1, keepdim=True)
254
- x = (x - mean) / (1e-5 + std)
255
- else:
256
- mean = 0
257
- std = 1
258
-
259
- delta = self.valid_length(length) - length
260
- x = F.pad(x, (delta // 2, delta - delta // 2))
261
-
262
- if self.resample: x = resample_frac(x, 1, 2)
263
- saved = []
264
-
265
- for encode in self.encoder:
266
- x = encode(x)
267
- saved.append(x)
268
-
269
- if self.lstm: x = self.lstm(x)
270
-
271
- for decode in self.decoder:
272
- skip = saved.pop(-1)
273
- skip = center_trim(skip, x)
274
- x = decode(x + skip)
275
-
276
- if self.resample: x = resample_frac(x, 2, 1)
277
-
278
- x = x * std + mean
279
- x = center_trim(x, length)
280
- x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
281
- return x
282
-
283
- def load_state_dict(self, state, strict=True):
284
- for idx in range(self.depth):
285
- for a in ["encoder", "decoder"]:
286
- for b in ["bias", "weight"]:
287
- new = f"{a}.{idx}.3.{b}"
288
- old = f"{a}.{idx}.2.{b}"
289
-
290
- if old in state and new not in state: state[new] = state.pop(old)
291
- super().load_state_dict(state, strict=strict)
292
-
293
- class ResampleFrac(torch.nn.Module):
294
- def __init__(self, old_sr, new_sr, zeros = 24, rolloff = 0.945):
295
- super().__init__()
296
- gcd = math.gcd(old_sr, new_sr)
297
- self.old_sr = old_sr // gcd
298
- self.new_sr = new_sr // gcd
299
- self.zeros = zeros
300
- self.rolloff = rolloff
301
- self._init_kernels()
302
-
303
- def _init_kernels(self):
304
- if self.old_sr == self.new_sr: return
305
-
306
- kernels = []
307
- sr = min(self.new_sr, self.old_sr)
308
- sr *= self.rolloff
309
-
310
- self._width = math.ceil(self.zeros * self.old_sr / sr)
311
- idx = torch.arange(-self._width, self._width + self.old_sr).float()
312
-
313
- for i in range(self.new_sr):
314
- t = ((-i/self.new_sr + idx/self.old_sr) * sr).clamp_(-self.zeros, self.zeros)
315
- t *= math.pi
316
-
317
- kernel = sinc(t) * (torch.cos(t/self.zeros/2)**2)
318
- kernel.div_(kernel.sum())
319
- kernels.append(kernel)
320
-
321
- self.register_buffer("kernel", torch.stack(kernels).view(self.new_sr, 1, -1))
322
-
323
- def forward(self, x, output_length = None, full = False):
324
- if self.old_sr == self.new_sr: return x
325
- shape = x.shape
326
- length = x.shape[-1]
327
-
328
- x = x.reshape(-1, length)
329
- y = F.conv1d(F.pad(x[:, None], (self._width, self._width + self.old_sr), mode='replicate'), self.kernel, stride=self.old_sr).transpose(1, 2).reshape(list(shape[:-1]) + [-1])
330
-
331
- float_output_length = torch.as_tensor(self.new_sr * length / self.old_sr)
332
- max_output_length = torch.ceil(float_output_length).long()
333
- default_output_length = torch.floor(float_output_length).long()
334
-
335
- if output_length is None: applied_output_length = max_output_length if full else default_output_length
336
- elif output_length < 0 or output_length > max_output_length: raise ValueError("output_length must be between 0 and the maximum reachable output length")
337
- else:
338
- applied_output_length = torch.tensor(output_length)
339
- if full: raise ValueError("Cannot pass both full=True and an explicit output_length")
340
-
341
- return y[..., :applied_output_length]
342
-
343
- def __repr__(self):
344
- return simple_repr(self)
345
-
346
- def sinc(x):
347
- return torch.where(x == 0, torch.tensor(1., device=x.device, dtype=x.dtype), torch.sin(x) / x)
348
-
349
- def simple_repr(obj, attrs = None, overrides = {}):
350
- params = inspect.signature(obj.__class__).parameters
351
- attrs_repr = []
352
-
353
- if attrs is None: attrs = list(params.keys())
354
- for attr in attrs:
355
- display = False
356
-
357
- if attr in overrides: value = overrides[attr]
358
- elif hasattr(obj, attr): value = getattr(obj, attr)
359
- else: continue
360
-
361
- if attr in params:
362
- param = params[attr]
363
- if param.default is inspect._empty or value != param.default: display = True
364
- else: display = True
365
-
366
- if display: attrs_repr.append(f"{attr}={value}")
367
- return f"{obj.__class__.__name__}({','.join(attrs_repr)})"
368
-
369
- def resample_frac(x, old_sr, new_sr, zeros = 24, rolloff = 0.945, output_length = None, full = False):
370
- return ResampleFrac(old_sr, new_sr, zeros, rolloff).to(x)(x, output_length, full)
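The ResampleFrac module above performs band-limited sinc resampling between two integer sample rates, and resample_frac is the functional wrapper that Demucs.forward uses to upsample the mixture by 2x before encoding and downsample it back after decoding. A minimal usage sketch, assuming the definitions above are in scope; the tensor shapes and variable names are illustrative only:

import torch

# Assumes resample_frac from the deleted demucs.py above is available in scope.
wave = torch.randn(1, 2, 44100)   # (batch, channels, time): one second of stereo audio
up = resample_frac(wave, 1, 2)    # upsample by 2x, as Demucs.forward does when resample=True
down = resample_frac(up, 2, 1)    # downsample back to the original rate after decoding
print(up.shape, down.shape)       # torch.Size([1, 2, 88200]) torch.Size([1, 2, 44100])
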
main/library/uvr5_separator/demucs/hdemucs.py DELETED
@@ -1,760 +0,0 @@
1
- import math
2
- import torch
3
-
4
- from torch import nn
5
- from copy import deepcopy
6
-
7
- from torch.nn import functional as F
8
-
9
- from .states import capture_init
10
- from .demucs import DConv, rescale_module
11
-
12
-
13
- def spectro(x, n_fft=512, hop_length=None, pad=0):
14
- *other, length = x.shape
15
- x = x.reshape(-1, length)
16
- device_type = x.device.type
17
- is_other_gpu = not device_type in ["cuda", "cpu"]
18
- if is_other_gpu: x = x.cpu()
19
- z = torch.stft(x, n_fft * (1 + pad), hop_length or n_fft // 4, window=torch.hann_window(n_fft).to(x), win_length=n_fft, normalized=True, center=True, return_complex=True, pad_mode="reflect")
20
- _, freqs, frame = z.shape
21
- return z.view(*other, freqs, frame)
22
-
23
- def ispectro(z, hop_length=None, length=None, pad=0):
24
- *other, freqs, frames = z.shape
25
- n_fft = 2 * freqs - 2
26
- z = z.view(-1, freqs, frames)
27
- win_length = n_fft // (1 + pad)
28
- device_type = z.device.type
29
- is_other_gpu = not device_type in ["cuda", "cpu"]
30
- if is_other_gpu: z = z.cpu()
31
- x = torch.istft(z, n_fft, hop_length, window=torch.hann_window(win_length).to(z.real), win_length=win_length, normalized=True, length=length, center=True)
32
- _, length = x.shape
33
- return x.view(*other, length)
34
-
35
- def atan2(y, x):
36
- pi = 2 * torch.asin(torch.tensor(1.0))
37
- x += ((x == 0) & (y == 0)) * 1.0
38
- out = torch.atan(y / x)
39
- out += ((y >= 0) & (x < 0)) * pi
40
- out -= ((y < 0) & (x < 0)) * pi
41
- out *= 1 - ((y > 0) & (x == 0)) * 1.0
42
- out += ((y > 0) & (x == 0)) * (pi / 2)
43
- out *= 1 - ((y < 0) & (x == 0)) * 1.0
44
- out += ((y < 0) & (x == 0)) * (-pi / 2)
45
- return out
46
-
47
- def _norm(x):
48
- return torch.abs(x[..., 0]) ** 2 + torch.abs(x[..., 1]) ** 2
49
-
50
- def _mul_add(a, b, out = None):
51
- target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
52
- if out is None or out.shape != target_shape: out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
53
-
54
- if out is a:
55
- real_a = a[..., 0]
56
- out[..., 0] = out[..., 0] + (real_a * b[..., 0] - a[..., 1] * b[..., 1])
57
- out[..., 1] = out[..., 1] + (real_a * b[..., 1] + a[..., 1] * b[..., 0])
58
- else:
59
- out[..., 0] = out[..., 0] + (a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1])
60
- out[..., 1] = out[..., 1] + (a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0])
61
-
62
- return out
63
-
64
- def _mul(a, b, out = None):
65
- target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
66
- if out is None or out.shape != target_shape: out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
67
-
68
- if out is a:
69
- real_a = a[..., 0]
70
- out[..., 0] = real_a * b[..., 0] - a[..., 1] * b[..., 1]
71
- out[..., 1] = real_a * b[..., 1] + a[..., 1] * b[..., 0]
72
- else:
73
- out[..., 0] = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]
74
- out[..., 1] = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0]
75
-
76
- return out
77
-
78
- def _inv(z, out = None):
79
- ez = _norm(z)
80
- if out is None or out.shape != z.shape: out = torch.zeros_like(z)
81
-
82
- out[..., 0] = z[..., 0] / ez
83
- out[..., 1] = -z[..., 1] / ez
84
-
85
- return out
86
-
87
- def _conj(z, out = None):
88
- if out is None or out.shape != z.shape: out = torch.zeros_like(z)
89
-
90
- out[..., 0] = z[..., 0]
91
- out[..., 1] = -z[..., 1]
92
-
93
- return out
94
-
95
- def _invert(M, out = None):
96
- nb_channels = M.shape[-2]
97
- if out is None or out.shape != M.shape: out = torch.empty_like(M)
98
-
99
- if nb_channels == 1: out = _inv(M, out)
100
- elif nb_channels == 2:
101
- det = _mul(M[..., 0, 0, :], M[..., 1, 1, :])
102
- det = det - _mul(M[..., 0, 1, :], M[..., 1, 0, :])
103
- invDet = _inv(det)
104
- out[..., 0, 0, :] = _mul(invDet, M[..., 1, 1, :], out[..., 0, 0, :])
105
- out[..., 1, 0, :] = _mul(-invDet, M[..., 1, 0, :], out[..., 1, 0, :])
106
- out[..., 0, 1, :] = _mul(-invDet, M[..., 0, 1, :], out[..., 0, 1, :])
107
- out[..., 1, 1, :] = _mul(invDet, M[..., 0, 0, :], out[..., 1, 1, :])
108
- else: raise Exception("Only 1 or 2 channels are supported")
109
- return out
110
-
111
- def expectation_maximization(y, x, iterations = 2, eps = 1e-10, batch_size = 200):
112
- (nb_frames, nb_bins, nb_channels) = x.shape[:-1]
113
- nb_sources = y.shape[-1]
114
- regularization = torch.cat((torch.eye(nb_channels, dtype=x.dtype, device=x.device)[..., None], torch.zeros((nb_channels, nb_channels, 1), dtype=x.dtype, device=x.device)), dim=2)
115
- regularization = torch.sqrt(torch.as_tensor(eps)) * (regularization[None, None, ...].expand((-1, nb_bins, -1, -1, -1)))
116
- R = [torch.zeros((nb_bins, nb_channels, nb_channels, 2), dtype=x.dtype, device=x.device) for j in range(nb_sources)]
117
- weight = torch.zeros((nb_bins,), dtype=x.dtype, device=x.device)
118
- v = torch.zeros((nb_frames, nb_bins, nb_sources), dtype=x.dtype, device=x.device)
119
-
120
- for _ in range(iterations):
121
- v = torch.mean(torch.abs(y[..., 0, :]) ** 2 + torch.abs(y[..., 1, :]) ** 2, dim=-2)
122
- for j in range(nb_sources):
123
- R[j] = torch.tensor(0.0, device=x.device)
124
-
125
- weight = torch.tensor(eps, device=x.device)
126
- pos = 0
127
- batch_size = batch_size if batch_size else nb_frames
128
-
129
- while pos < nb_frames:
130
- t = torch.arange(pos, min(nb_frames, pos + batch_size))
131
- pos = int(t[-1]) + 1
132
-
133
- R[j] = R[j] + torch.sum(_covariance(y[t, ..., j]), dim=0)
134
- weight = weight + torch.sum(v[t, ..., j], dim=0)
135
-
136
- R[j] = R[j] / weight[..., None, None, None]
137
- weight = torch.zeros_like(weight)
138
-
139
- if y.requires_grad: y = y.clone()
140
-
141
- pos = 0
142
-
143
- while pos < nb_frames:
144
- t = torch.arange(pos, min(nb_frames, pos + batch_size))
145
- pos = int(t[-1]) + 1
146
-
147
- y[t, ...] = torch.tensor(0.0, device=x.device, dtype=x.dtype)
148
-
149
- Cxx = regularization
150
-
151
- for j in range(nb_sources):
152
- Cxx = Cxx + (v[t, ..., j, None, None, None] * R[j][None, ...].clone())
153
-
154
- inv_Cxx = _invert(Cxx)
155
-
156
- for j in range(nb_sources):
157
- gain = torch.zeros_like(inv_Cxx)
158
- indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels), torch.arange(nb_channels))
159
-
160
- for index in indices:
161
- gain[:, :, index[0], index[1], :] = _mul_add(R[j][None, :, index[0], index[2], :].clone(), inv_Cxx[:, :, index[2], index[1], :], gain[:, :, index[0], index[1], :])
162
-
163
- gain = gain * v[t, ..., None, None, None, j]
164
-
165
- for i in range(nb_channels):
166
- y[t, ..., j] = _mul_add(gain[..., i, :], x[t, ..., i, None, :], y[t, ..., j])
167
-
168
- return y, v, R
169
-
170
- def wiener(targets_spectrograms, mix_stft, iterations = 1, softmask = False, residual = False, scale_factor = 10.0, eps = 1e-10):
171
- if softmask: y = mix_stft[..., None] * (targets_spectrograms / (eps + torch.sum(targets_spectrograms, dim=-1, keepdim=True).to(mix_stft.dtype)))[..., None, :]
172
- else:
173
- angle = atan2(mix_stft[..., 1], mix_stft[..., 0])[..., None]
174
- nb_sources = targets_spectrograms.shape[-1]
175
- y = torch.zeros(mix_stft.shape + (nb_sources,), dtype=mix_stft.dtype, device=mix_stft.device)
176
- y[..., 0, :] = targets_spectrograms * torch.cos(angle)
177
- y[..., 1, :] = targets_spectrograms * torch.sin(angle)
178
-
179
- if residual: y = torch.cat([y, mix_stft[..., None] - y.sum(dim=-1, keepdim=True)], dim=-1)
180
- if iterations == 0: return y
181
-
182
- max_abs = torch.max(torch.as_tensor(1.0, dtype=mix_stft.dtype, device=mix_stft.device), torch.sqrt(_norm(mix_stft)).max() / scale_factor)
183
- mix_stft = mix_stft / max_abs
184
- y = y / max_abs
185
- y = expectation_maximization(y, mix_stft, iterations, eps=eps)[0]
186
- y = y * max_abs
187
-
188
- return y
189
-
190
- def _covariance(y_j):
191
- (nb_frames, nb_bins, nb_channels) = y_j.shape[:-1]
192
-
193
- Cj = torch.zeros((nb_frames, nb_bins, nb_channels, nb_channels, 2), dtype=y_j.dtype, device=y_j.device)
194
- indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels))
195
-
196
- for index in indices:
197
- Cj[:, :, index[0], index[1], :] = _mul_add(y_j[:, :, index[0], :], _conj(y_j[:, :, index[1], :]), Cj[:, :, index[0], index[1], :])
198
-
199
- return Cj
200
-
201
- def pad1d(x, paddings, mode = "constant", value = 0.0):
202
- x0 = x
203
- length = x.shape[-1]
204
- padding_left, padding_right = paddings
205
-
206
- if mode == "reflect":
207
- max_pad = max(padding_left, padding_right)
208
-
209
- if length <= max_pad:
210
- extra_pad = max_pad - length + 1
211
- extra_pad_right = min(padding_right, extra_pad)
212
- extra_pad_left = extra_pad - extra_pad_right
213
- paddings = (padding_left - extra_pad_left, padding_right - extra_pad_right)
214
- x = F.pad(x, (extra_pad_left, extra_pad_right))
215
-
216
- out = F.pad(x, paddings, mode, value)
217
-
218
- assert out.shape[-1] == length + padding_left + padding_right
219
- assert (out[..., padding_left : padding_left + length] == x0).all()
220
- return out
221
-
222
- class ScaledEmbedding(nn.Module):
223
- def __init__(self, num_embeddings, embedding_dim, scale = 10.0, smooth=False):
224
- super().__init__()
225
- self.embedding = nn.Embedding(num_embeddings, embedding_dim)
226
-
227
- if smooth:
228
- weight = torch.cumsum(self.embedding.weight.data, dim=0)
229
- weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None]
230
- self.embedding.weight.data[:] = weight
231
-
232
- self.embedding.weight.data /= scale
233
- self.scale = scale
234
-
235
- @property
236
- def weight(self):
237
- return self.embedding.weight * self.scale
238
-
239
- def forward(self, x):
240
- return self.embedding(x) * self.scale
241
-
242
- class HEncLayer(nn.Module):
243
- def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False, freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True, rewrite=True):
244
- super().__init__()
245
- norm_fn = lambda d: nn.Identity()
246
- if norm: norm_fn = lambda d: nn.GroupNorm(norm_groups, d)
247
- pad = kernel_size // 4 if pad else 0
248
-
249
- klass = nn.Conv1d
250
- self.freq = freq
251
- self.kernel_size = kernel_size
252
- self.stride = stride
253
- self.empty = empty
254
- self.norm = norm
255
- self.pad = pad
256
-
257
- if freq:
258
- kernel_size = [kernel_size, 1]
259
- stride = [stride, 1]
260
- pad = [pad, 0]
261
- klass = nn.Conv2d
262
-
263
- self.conv = klass(chin, chout, kernel_size, stride, pad)
264
- if self.empty: return
265
-
266
- self.norm1 = norm_fn(chout)
267
- self.rewrite = None
268
-
269
- if rewrite:
270
- self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
271
- self.norm2 = norm_fn(2 * chout)
272
-
273
- self.dconv = None
274
- if dconv: self.dconv = DConv(chout, **dconv_kw)
275
-
276
- def forward(self, x, inject=None):
277
- if not self.freq and x.dim() == 4:
278
- B, C, Fr, T = x.shape
279
- x = x.view(B, -1, T)
280
-
281
- if not self.freq:
282
- le = x.shape[-1]
283
- if not le % self.stride == 0: x = F.pad(x, (0, self.stride - (le % self.stride)))
284
-
285
- y = self.conv(x)
286
- if self.empty: return y
287
-
288
- if inject is not None:
289
- assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
290
-
291
- if inject.dim() == 3 and y.dim() == 4: inject = inject[:, :, None]
292
- y = y + inject
293
-
294
- y = F.gelu(self.norm1(y))
295
-
296
- if self.dconv:
297
- if self.freq:
298
- B, C, Fr, T = y.shape
299
- y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
300
-
301
- y = self.dconv(y)
302
- if self.freq: y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
303
-
304
- if self.rewrite:
305
- z = self.norm2(self.rewrite(y))
306
- z = F.glu(z, dim=1)
307
- else: z = y
308
-
309
- return z
310
-
311
- class MultiWrap(nn.Module):
312
- def __init__(self, layer, split_ratios):
313
- super().__init__()
314
- self.split_ratios = split_ratios
315
- self.layers = nn.ModuleList()
316
- self.conv = isinstance(layer, HEncLayer)
317
- assert not layer.norm
318
- assert layer.freq
319
- assert layer.pad
320
-
321
- if not self.conv: assert not layer.context_freq
322
-
323
- for _ in range(len(split_ratios) + 1):
324
- lay = deepcopy(layer)
325
-
326
- if self.conv: lay.conv.padding = (0, 0)
327
- else: lay.pad = False
328
-
329
- for m in lay.modules():
330
- if hasattr(m, "reset_parameters"): m.reset_parameters()
331
-
332
- self.layers.append(lay)
333
-
334
- def forward(self, x, skip=None, length=None):
335
- B, C, Fr, T = x.shape
336
- ratios = list(self.split_ratios) + [1]
337
- start = 0
338
- outs = []
339
-
340
- for ratio, layer in zip(ratios, self.layers):
341
- if self.conv:
342
- pad = layer.kernel_size // 4
343
-
344
- if ratio == 1:
345
- limit = Fr
346
- frames = -1
347
- else:
348
- limit = int(round(Fr * ratio))
349
- le = limit - start
350
-
351
- if start == 0: le += pad
352
-
353
- frames = round((le - layer.kernel_size) / layer.stride + 1)
354
- limit = start + (frames - 1) * layer.stride + layer.kernel_size
355
-
356
- if start == 0: limit -= pad
357
-
358
- assert limit - start > 0, (limit, start)
359
- assert limit <= Fr, (limit, Fr)
360
-
361
- y = x[:, :, start:limit, :]
362
-
363
- if start == 0: y = F.pad(y, (0, 0, pad, 0))
364
- if ratio == 1: y = F.pad(y, (0, 0, 0, pad))
365
-
366
- outs.append(layer(y))
367
- start = limit - layer.kernel_size + layer.stride
368
- else:
369
- limit = Fr if ratio == 1 else int(round(Fr * ratio))
370
-
371
- last = layer.last
372
- layer.last = True
373
-
374
- y = x[:, :, start:limit]
375
- s = skip[:, :, start:limit]
376
- out, _ = layer(y, s, None)
377
-
378
- if outs:
379
- outs[-1][:, :, -layer.stride :] += out[:, :, : layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1)
380
- out = out[:, :, layer.stride :]
381
-
382
- if ratio == 1: out = out[:, :, : -layer.stride // 2, :]
383
- if start == 0: out = out[:, :, layer.stride // 2 :, :]
384
-
385
- outs.append(out)
386
- layer.last = last
387
- start = limit
388
-
389
- out = torch.cat(outs, dim=2)
390
- if not self.conv and not last: out = F.gelu(out)
391
-
392
- if self.conv: return out
393
- else: return out, None
394
-
395
- class HDecLayer(nn.Module):
396
- def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False, freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True, context_freq=True, rewrite=True):
397
- super().__init__()
398
- norm_fn = lambda d: nn.Identity()
399
-
400
- if norm: norm_fn = lambda d: nn.GroupNorm(norm_groups, d)
401
- pad = kernel_size // 4 if pad else 0
402
-
403
- self.pad = pad
404
- self.last = last
405
- self.freq = freq
406
- self.chin = chin
407
- self.empty = empty
408
- self.stride = stride
409
- self.kernel_size = kernel_size
410
- self.norm = norm
411
- self.context_freq = context_freq
412
- klass = nn.Conv1d
413
- klass_tr = nn.ConvTranspose1d
414
-
415
- if freq:
416
- kernel_size = [kernel_size, 1]
417
- stride = [stride, 1]
418
- klass = nn.Conv2d
419
- klass_tr = nn.ConvTranspose2d
420
-
421
- self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
422
- self.norm2 = norm_fn(chout)
423
-
424
- if self.empty: return
425
- self.rewrite = None
426
-
427
- if rewrite:
428
- if context_freq: self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
429
- else: self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1, [0, context])
430
-
431
- self.norm1 = norm_fn(2 * chin)
432
-
433
- self.dconv = None
434
- if dconv: self.dconv = DConv(chin, **dconv_kw)
435
-
436
- def forward(self, x, skip, length):
437
- if self.freq and x.dim() == 3:
438
- B, C, T = x.shape
439
- x = x.view(B, self.chin, -1, T)
440
-
441
- if not self.empty:
442
- x = x + skip
443
-
444
- y = F.glu(self.norm1(self.rewrite(x)), dim=1) if self.rewrite else x
445
-
446
- if self.dconv:
447
- if self.freq:
448
- B, C, Fr, T = y.shape
449
- y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
450
-
451
- y = self.dconv(y)
452
-
453
- if self.freq: y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
454
- else:
455
- y = x
456
- assert skip is None
457
-
458
- z = self.norm2(self.conv_tr(y))
459
-
460
- if self.freq:
461
- if self.pad: z = z[..., self.pad : -self.pad, :]
462
- else:
463
- z = z[..., self.pad : self.pad + length]
464
- assert z.shape[-1] == length, (z.shape[-1], length)
465
-
466
- if not self.last: z = F.gelu(z)
467
- return z, y
468
-
469
- class HDemucs(nn.Module):
470
- @capture_init
471
- def __init__(self, sources, audio_channels=2, channels=48, channels_time=None, growth=2, nfft=4096, wiener_iters=0, end_iters=0, wiener_residual=False, cac=True, depth=6, rewrite=True, hybrid=True, hybrid_old=False, multi_freqs=None, multi_freqs_depth=2, freq_emb=0.2, emb_scale=10, emb_smooth=True, kernel_size=8, time_stride=2, stride=4, context=1, context_enc=0, norm_starts=4, norm_groups=4, dconv_mode=1, dconv_depth=2, dconv_comp=4, dconv_attn=4, dconv_lstm=4, dconv_init=1e-4, rescale=0.1, samplerate=44100, segment=4 * 10):
472
- super().__init__()
473
- self.cac = cac
474
- self.wiener_residual = wiener_residual
475
- self.audio_channels = audio_channels
476
- self.sources = sources
477
- self.kernel_size = kernel_size
478
- self.context = context
479
- self.stride = stride
480
- self.depth = depth
481
- self.channels = channels
482
- self.samplerate = samplerate
483
- self.segment = segment
484
- self.nfft = nfft
485
- self.hop_length = nfft // 4
486
- self.wiener_iters = wiener_iters
487
- self.end_iters = end_iters
488
- self.freq_emb = None
489
- self.hybrid = hybrid
490
- self.hybrid_old = hybrid_old
491
- if hybrid_old: assert hybrid
492
- if hybrid: assert wiener_iters == end_iters
493
- self.encoder = nn.ModuleList()
494
- self.decoder = nn.ModuleList()
495
-
496
- if hybrid:
497
- self.tencoder = nn.ModuleList()
498
- self.tdecoder = nn.ModuleList()
499
-
500
- chin = audio_channels
501
- chin_z = chin
502
-
503
- if self.cac: chin_z *= 2
504
-
505
- chout = channels_time or channels
506
- chout_z = channels
507
- freqs = nfft // 2
508
-
509
- for index in range(depth):
510
- lstm = index >= dconv_lstm
511
- attn = index >= dconv_attn
512
- norm = index >= norm_starts
513
- freq = freqs > 1
514
- stri = stride
515
- ker = kernel_size
516
-
517
- if not freq:
518
- assert freqs == 1
519
-
520
- ker = time_stride * 2
521
- stri = time_stride
522
-
523
- pad = True
524
- last_freq = False
525
-
526
- if freq and freqs <= kernel_size:
527
- ker = freqs
528
- pad = False
529
- last_freq = True
530
-
531
- kw = {
532
- "kernel_size": ker,
533
- "stride": stri,
534
- "freq": freq,
535
- "pad": pad,
536
- "norm": norm,
537
- "rewrite": rewrite,
538
- "norm_groups": norm_groups,
539
- "dconv_kw": {"lstm": lstm, "attn": attn, "depth": dconv_depth, "compress": dconv_comp, "init": dconv_init, "gelu": True},
540
- }
541
-
542
- kwt = dict(kw)
543
- kwt["freq"] = 0
544
- kwt["kernel_size"] = kernel_size
545
- kwt["stride"] = stride
546
- kwt["pad"] = True
547
- kw_dec = dict(kw)
548
-
549
- multi = False
550
-
551
- if multi_freqs and index < multi_freqs_depth:
552
- multi = True
553
- kw_dec["context_freq"] = False
554
-
555
- if last_freq:
556
- chout_z = max(chout, chout_z)
557
- chout = chout_z
558
-
559
- enc = HEncLayer(chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw)
560
- if hybrid and freq:
561
- tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc, empty=last_freq, **kwt)
562
- self.tencoder.append(tenc)
563
-
564
- if multi: enc = MultiWrap(enc, multi_freqs)
565
-
566
- self.encoder.append(enc)
567
- if index == 0:
568
- chin = self.audio_channels * len(self.sources)
569
- chin_z = chin
570
-
571
- if self.cac: chin_z *= 2
572
-
573
- dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2, last=index == 0, context=context, **kw_dec)
574
- if multi: dec = MultiWrap(dec, multi_freqs)
575
-
576
- if hybrid and freq:
577
- tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq, last=index == 0, context=context, **kwt)
578
- self.tdecoder.insert(0, tdec)
579
-
580
- self.decoder.insert(0, dec)
581
- chin = chout
582
- chin_z = chout_z
583
- chout = int(growth * chout)
584
- chout_z = int(growth * chout_z)
585
-
586
- if freq:
587
- if freqs <= kernel_size: freqs = 1
588
- else: freqs //= stride
589
-
590
- if index == 0 and freq_emb:
591
- self.freq_emb = ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
592
- self.freq_emb_scale = freq_emb
593
-
594
- if rescale: rescale_module(self, reference=rescale)
595
-
596
- def _spec(self, x):
597
- hl = self.hop_length
598
- nfft = self.nfft
599
-
600
- if self.hybrid:
601
- assert hl == nfft // 4
602
- le = int(math.ceil(x.shape[-1] / hl))
603
- pad = hl // 2 * 3
604
- x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect") if not self.hybrid_old else pad1d(x, (pad, pad + le * hl - x.shape[-1]))
605
-
606
- z = spectro(x, nfft, hl)[..., :-1, :]
607
- if self.hybrid:
608
- assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
609
- z = z[..., 2 : 2 + le]
610
-
611
- return z
612
-
613
- def _ispec(self, z, length=None, scale=0):
614
- hl = self.hop_length // (4**scale)
615
- z = F.pad(z, (0, 0, 0, 1))
616
-
617
- if self.hybrid:
618
- z = F.pad(z, (2, 2))
619
- pad = hl // 2 * 3
620
- le = hl * int(math.ceil(length / hl)) + 2 * pad if not self.hybrid_old else hl * int(math.ceil(length / hl))
621
- x = ispectro(z, hl, length=le)
622
- x = x[..., pad : pad + length] if not self.hybrid_old else x[..., :length]
623
- else: x = ispectro(z, hl, length)
624
-
625
- return x
626
-
627
- def _magnitude(self, z):
628
- if self.cac:
629
- B, C, Fr, T = z.shape
630
- m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
631
- m = m.reshape(B, C * 2, Fr, T)
632
- else: m = z.abs()
633
-
634
- return m
635
-
636
- def _mask(self, z, m):
637
- niters = self.wiener_iters
638
- if self.cac:
639
- B, S, C, Fr, T = m.shape
640
- out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
641
- out = torch.view_as_complex(out.contiguous())
642
- return out
643
-
644
- if self.training: niters = self.end_iters
645
-
646
- if niters < 0:
647
- z = z[:, None]
648
- return z / (1e-8 + z.abs()) * m
649
- else: return self._wiener(m, z, niters)
650
-
651
- def _wiener(self, mag_out, mix_stft, niters):
652
- init = mix_stft.dtype
653
- wiener_win_len = 300
654
- residual = self.wiener_residual
655
- B, S, C, Fq, T = mag_out.shape
656
- mag_out = mag_out.permute(0, 4, 3, 2, 1)
657
- mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
658
- outs = []
659
-
660
- for sample in range(B):
661
- pos = 0
662
- out = []
663
-
664
- for pos in range(0, T, wiener_win_len):
665
- frame = slice(pos, pos + wiener_win_len)
666
- z_out = wiener(mag_out[sample, frame], mix_stft[sample, frame], niters, residual=residual)
667
- out.append(z_out.transpose(-1, -2))
668
-
669
- outs.append(torch.cat(out, dim=0))
670
-
671
- out = torch.view_as_complex(torch.stack(outs, 0))
672
- out = out.permute(0, 4, 3, 2, 1).contiguous()
673
-
674
- if residual: out = out[:, :-1]
675
- assert list(out.shape) == [B, S, C, Fq, T]
676
- return out.to(init)
677
-
678
- def forward(self, mix):
679
- x = mix
680
- length = x.shape[-1]
681
- z = self._spec(mix)
682
- mag = self._magnitude(z).to(mix.device)
683
- x = mag
684
- B, C, Fq, T = x.shape
685
- mean = x.mean(dim=(1, 2, 3), keepdim=True)
686
- std = x.std(dim=(1, 2, 3), keepdim=True)
687
- x = (x - mean) / (1e-5 + std)
688
-
689
- if self.hybrid:
690
- xt = mix
691
- meant = xt.mean(dim=(1, 2), keepdim=True)
692
- stdt = xt.std(dim=(1, 2), keepdim=True)
693
- xt = (xt - meant) / (1e-5 + stdt)
694
-
695
- saved, saved_t, lengths, lengths_t = [], [], [], []
696
-
697
- for idx, encode in enumerate(self.encoder):
698
- lengths.append(x.shape[-1])
699
- inject = None
700
-
701
- if self.hybrid and idx < len(self.tencoder):
702
- lengths_t.append(xt.shape[-1])
703
- tenc = self.tencoder[idx]
704
- xt = tenc(xt)
705
-
706
- if not tenc.empty: saved_t.append(xt)
707
- else: inject = xt
708
-
709
- x = encode(x, inject)
710
-
711
- if idx == 0 and self.freq_emb is not None:
712
- frs = torch.arange(x.shape[-2], device=x.device)
713
- emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
714
- x = x + self.freq_emb_scale * emb
715
-
716
- saved.append(x)
717
-
718
- x = torch.zeros_like(x)
719
- if self.hybrid: xt = torch.zeros_like(x)
720
-
721
- for idx, decode in enumerate(self.decoder):
722
- skip = saved.pop(-1)
723
- x, pre = decode(x, skip, lengths.pop(-1))
724
-
725
- if self.hybrid: offset = self.depth - len(self.tdecoder)
726
-
727
- if self.hybrid and idx >= offset:
728
- tdec = self.tdecoder[idx - offset]
729
- length_t = lengths_t.pop(-1)
730
-
731
- if tdec.empty:
732
- assert pre.shape[2] == 1, pre.shape
733
-
734
- pre = pre[:, :, 0]
735
- xt, _ = tdec(pre, None, length_t)
736
- else:
737
- skip = saved_t.pop(-1)
738
- xt, _ = tdec(xt, skip, length_t)
739
-
740
- assert len(saved) == 0
741
- assert len(lengths_t) == 0
742
- assert len(saved_t) == 0
743
-
744
- S = len(self.sources)
745
- x = x.view(B, S, -1, Fq, T)
746
- x = x * std[:, None] + mean[:, None]
747
- device_type = x.device.type
748
- device_load = f"{device_type}:{x.device.index}" if not device_type == "mps" else device_type
749
- x_is_other_gpu = not device_type in ["cuda", "cpu"]
750
- if x_is_other_gpu: x = x.cpu()
751
- zout = self._mask(z, x)
752
- x = self._ispec(zout, length)
753
- if x_is_other_gpu: x = x.to(device_load)
754
-
755
- if self.hybrid:
756
- xt = xt.view(B, S, -1, length)
757
- xt = xt * stdt[:, None] + meant[:, None]
758
- x = xt + x
759
-
760
- return x
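The spectro and ispectro helpers at the top of this file wrap torch.stft / torch.istft with a normalized Hann window, and HDemucs.forward relies on them being near-inverses of each other when it masks the mixture spectrogram and resynthesizes each source. A small round-trip sketch, assuming both helpers from the file above are in scope; the signal length and tolerance are illustrative:

import torch

# Assumes spectro() and ispectro() from the deleted hdemucs.py above are available in scope.
x = torch.randn(2, 4096 * 8)                   # (channels, time) dummy signal
z = spectro(x, n_fft=4096)                     # complex STFT, shape (2, 2049, frames)
y = ispectro(z, hop_length=4096 // 4, length=x.shape[-1])
print(torch.allclose(x, y, atol=1e-4))         # expected: True (near-perfect reconstruction)
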
main/library/uvr5_separator/demucs/htdemucs.py DELETED
@@ -1,600 +0,0 @@
1
- import os
2
- import sys
3
- import math
4
- import torch
5
- import random
6
-
7
- import numpy as np
8
-
9
- from torch import nn
10
- from einops import rearrange
11
- from fractions import Fraction
12
- from torch.nn import functional as F
13
-
14
- sys.path.append(os.getcwd())
15
-
16
- from .states import capture_init
17
- from .demucs import rescale_module
18
- from main.configs.config import Config
19
- from .hdemucs import pad1d, spectro, ispectro, wiener, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer
20
-
21
- translations = Config().translations
22
-
23
- def create_sin_embedding(length, dim, shift = 0, device="cpu", max_period=10000):
24
- assert dim % 2 == 0
25
- pos = shift + torch.arange(length, device=device).view(-1, 1, 1)
26
- half_dim = dim // 2
27
- adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
28
- phase = pos / (max_period ** (adim / (half_dim - 1)))
29
- return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
30
-
31
- def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000):
32
- if d_model % 4 != 0: raise ValueError(translations["dims"].format(dims=d_model))
33
- pe = torch.zeros(d_model, height, width)
34
- d_model = int(d_model / 2)
35
- div_term = torch.exp(torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model))
36
- pos_w = torch.arange(0.0, width).unsqueeze(1)
37
- pos_h = torch.arange(0.0, height).unsqueeze(1)
38
- pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
39
- pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
40
- pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
41
- pe[d_model + 1 :: 2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
42
-
43
- return pe[None, :].to(device)
44
-
45
- def create_sin_embedding_cape(length, dim, batch_size, mean_normalize, augment, max_global_shift = 0.0, max_local_shift = 0.0, max_scale = 1.0, device = "cpu", max_period = 10000.0):
46
- assert dim % 2 == 0
47
- pos = 1.0 * torch.arange(length).view(-1, 1, 1)
48
- pos = pos.repeat(1, batch_size, 1)
49
- if mean_normalize: pos -= torch.nanmean(pos, dim=0, keepdim=True)
50
-
51
- if augment:
52
- delta = np.random.uniform(-max_global_shift, +max_global_shift, size=[1, batch_size, 1])
53
- delta_local = np.random.uniform(-max_local_shift, +max_local_shift, size=[length, batch_size, 1])
54
- log_lambdas = np.random.uniform(-np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1])
55
- pos = (pos + delta + delta_local) * np.exp(log_lambdas)
56
-
57
- pos = pos.to(device)
58
- half_dim = dim // 2
59
- adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
60
- phase = pos / (max_period ** (adim / (half_dim - 1)))
61
- return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1).float()
62
-
63
- class MyGroupNorm(nn.GroupNorm):
64
- def __init__(self, *args, **kwargs):
65
- super().__init__(*args, **kwargs)
66
-
67
- def forward(self, x):
68
- x = x.transpose(1, 2)
69
- return super().forward(x).transpose(1, 2)
70
-
71
- class LayerScale(nn.Module):
72
- def __init__(self, channels, init = 0, channel_last=False):
73
- super().__init__()
74
- self.channel_last = channel_last
75
- self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
76
- self.scale.data[:] = init
77
-
78
- def forward(self, x):
79
- if self.channel_last: return self.scale * x
80
- else: return self.scale[:, None] * x
81
-
82
- class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
83
- def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=F.relu, group_norm=0, norm_first=False, norm_out=False, layer_norm_eps=1e-5, layer_scale=False, init_values=1e-4, device=None, dtype=None, sparse=False, mask_type="diag", mask_random_seed=42, sparse_attn_window=500, global_window=50, auto_sparsity=False, sparsity=0.95, batch_first=False):
84
- factory_kwargs = {"device": device, "dtype": dtype}
85
- super().__init__(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation=activation, layer_norm_eps=layer_norm_eps, batch_first=batch_first, norm_first=norm_first, device=device, dtype=dtype)
86
- self.auto_sparsity = auto_sparsity
87
-
88
- if group_norm:
89
- self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
90
- self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
91
-
92
- self.norm_out = None
93
- if self.norm_first & norm_out: self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
94
-
95
- self.gamma_1 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
96
- self.gamma_2 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
97
-
98
- def forward(self, src, src_mask=None, src_key_padding_mask=None):
99
- x = src
100
- T, B, C = x.shape
101
-
102
- if self.norm_first:
103
- x = x + self.gamma_1(self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
104
- x = x + self.gamma_2(self._ff_block(self.norm2(x)))
105
- if self.norm_out: x = self.norm_out(x)
106
- else:
107
- x = self.norm1(x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask)))
108
- x = self.norm2(x + self.gamma_2(self._ff_block(x)))
109
-
110
- return x
111
-
112
- class CrossTransformerEncoder(nn.Module):
113
- def __init__(self, dim, emb = "sin", hidden_scale = 4.0, num_heads = 8, num_layers = 6, cross_first = False, dropout = 0.0, max_positions = 1000, norm_in = True, norm_in_group = False, group_norm = False, norm_first = False, norm_out = False, max_period = 10000.0, weight_decay = 0.0, lr = None, layer_scale = False, gelu = True, sin_random_shift = 0, weight_pos_embed = 1.0, cape_mean_normalize = True, cape_augment = True, cape_glob_loc_scale = [5000.0, 1.0, 1.4], sparse_self_attn = False, sparse_cross_attn = False, mask_type = "diag", mask_random_seed = 42, sparse_attn_window = 500, global_window = 50, auto_sparsity = False, sparsity = 0.95):
114
- super().__init__()
115
- assert dim % num_heads == 0
116
- hidden_dim = int(dim * hidden_scale)
117
- self.num_layers = num_layers
118
- self.classic_parity = 1 if cross_first else 0
119
- self.emb = emb
120
- self.max_period = max_period
121
- self.weight_decay = weight_decay
122
- self.weight_pos_embed = weight_pos_embed
123
- self.sin_random_shift = sin_random_shift
124
-
125
- if emb == "cape":
126
- self.cape_mean_normalize = cape_mean_normalize
127
- self.cape_augment = cape_augment
128
- self.cape_glob_loc_scale = cape_glob_loc_scale
129
-
130
- if emb == "scaled": self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)
131
-
132
- self.lr = lr
133
- activation = F.gelu if gelu else F.relu
134
-
135
- if norm_in:
136
- self.norm_in = nn.LayerNorm(dim)
137
- self.norm_in_t = nn.LayerNorm(dim)
138
- elif norm_in_group:
139
- self.norm_in = MyGroupNorm(int(norm_in_group), dim)
140
- self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
141
- else:
142
- self.norm_in = nn.Identity()
143
- self.norm_in_t = nn.Identity()
144
-
145
- self.layers = nn.ModuleList()
146
- self.layers_t = nn.ModuleList()
147
-
148
- kwargs_common = {
149
- "d_model": dim,
150
- "nhead": num_heads,
151
- "dim_feedforward": hidden_dim,
152
- "dropout": dropout,
153
- "activation": activation,
154
- "group_norm": group_norm,
155
- "norm_first": norm_first,
156
- "norm_out": norm_out,
157
- "layer_scale": layer_scale,
158
- "mask_type": mask_type,
159
- "mask_random_seed": mask_random_seed,
160
- "sparse_attn_window": sparse_attn_window,
161
- "global_window": global_window,
162
- "sparsity": sparsity,
163
- "auto_sparsity": auto_sparsity,
164
- "batch_first": True,
165
- }
166
-
167
- kwargs_classic_encoder = dict(kwargs_common)
168
- kwargs_classic_encoder.update({"sparse": sparse_self_attn})
169
- kwargs_cross_encoder = dict(kwargs_common)
170
- kwargs_cross_encoder.update({"sparse": sparse_cross_attn})
171
-
172
- for idx in range(num_layers):
173
- if idx % 2 == self.classic_parity:
174
- self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
175
- self.layers_t.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
176
- else:
177
- self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
178
- self.layers_t.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
179
-
180
- def forward(self, x, xt):
181
- B, C, Fr, T1 = x.shape
182
-
183
- pos_emb_2d = create_2d_sin_embedding(C, Fr, T1, x.device, self.max_period)
184
- pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c")
185
-
186
- x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
187
- x = self.norm_in(x)
188
- x = x + self.weight_pos_embed * pos_emb_2d
189
-
190
- B, C, T2 = xt.shape
191
- xt = rearrange(xt, "b c t2 -> b t2 c")
192
-
193
- pos_emb = self._get_pos_embedding(T2, B, C, x.device)
194
- pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c")
195
-
196
- xt = self.norm_in_t(xt)
197
- xt = xt + self.weight_pos_embed * pos_emb
198
-
199
- for idx in range(self.num_layers):
200
- if idx % 2 == self.classic_parity:
201
- x = self.layers[idx](x)
202
- xt = self.layers_t[idx](xt)
203
- else:
204
- old_x = x
205
- x = self.layers[idx](x, xt)
206
- xt = self.layers_t[idx](xt, old_x)
207
-
208
- x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1)
209
- xt = rearrange(xt, "b t2 c -> b c t2")
210
- return x, xt
211
-
212
- def _get_pos_embedding(self, T, B, C, device):
213
- if self.emb == "sin":
214
- shift = random.randrange(self.sin_random_shift + 1)
215
- pos_emb = create_sin_embedding(T, C, shift=shift, device=device, max_period=self.max_period)
216
- elif self.emb == "cape":
217
- if self.training: pos_emb = create_sin_embedding_cape(T, C, B, device=device, max_period=self.max_period, mean_normalize=self.cape_mean_normalize, augment=self.cape_augment, max_global_shift=self.cape_glob_loc_scale[0], max_local_shift=self.cape_glob_loc_scale[1], max_scale=self.cape_glob_loc_scale[2])
218
- else: pos_emb = create_sin_embedding_cape(T, C, B, device=device, max_period=self.max_period, mean_normalize=self.cape_mean_normalize, augment=False)
219
- elif self.emb == "scaled":
220
- pos = torch.arange(T, device=device)
221
- pos_emb = self.position_embeddings(pos)[:, None]
222
-
223
- return pos_emb
224
-
225
- def make_optim_group(self):
226
- group = {"params": list(self.parameters()), "weight_decay": self.weight_decay}
227
- if self.lr is not None: group["lr"] = self.lr
228
- return group
229
-
230
- class CrossTransformerEncoderLayer(nn.Module):
231
- def __init__(self, d_model, nhead, dim_feedforward = 2048, dropout = 0.1, activation=F.relu, layer_norm_eps = 1e-5, layer_scale = False, init_values = 1e-4, norm_first = False, group_norm = False, norm_out = False, sparse=False, mask_type="diag", mask_random_seed=42, sparse_attn_window=500, global_window=50, sparsity=0.95, auto_sparsity=None, device=None, dtype=None, batch_first=False):
232
- factory_kwargs = {"device": device, "dtype": dtype}
233
- super().__init__()
234
- self.auto_sparsity = auto_sparsity
235
- self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
236
- self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
237
- self.dropout = nn.Dropout(dropout)
238
- self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
239
- self.norm_first = norm_first
240
-
241
- if group_norm:
242
- self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
243
- self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
244
- self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
245
- else:
246
- self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
247
- self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
248
- self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
249
-
250
- self.norm_out = None
251
- if self.norm_first & norm_out:
252
- self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
253
-
254
- self.gamma_1 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
255
- self.gamma_2 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
256
- self.dropout1 = nn.Dropout(dropout)
257
- self.dropout2 = nn.Dropout(dropout)
258
-
259
- if isinstance(activation, str): self.activation = self._get_activation_fn(activation)
260
- else: self.activation = activation
261
-
262
- def forward(self, q, k, mask=None):
263
- if self.norm_first:
264
- x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask))
265
- x = x + self.gamma_2(self._ff_block(self.norm3(x)))
266
-
267
- if self.norm_out: x = self.norm_out(x)
268
- else:
269
- x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask)))
270
- x = self.norm2(x + self.gamma_2(self._ff_block(x)))
271
-
272
- return x
273
-
274
- def _ca_block(self, q, k, attn_mask=None):
275
- x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0]
276
- return self.dropout1(x)
277
-
278
- def _ff_block(self, x):
279
- x = self.linear2(self.dropout(self.activation(self.linear1(x))))
280
- return self.dropout2(x)
281
-
282
- def _get_activation_fn(self, activation):
283
- if activation == "relu": return F.relu
284
- elif activation == "gelu": return F.gelu
285
- raise RuntimeError(translations["activation"].format(activation=activation))
286
-
287
- class HTDemucs(nn.Module):
288
- @capture_init
289
- def __init__(self, sources, audio_channels=2, channels=48, channels_time=None, growth=2, nfft=4096, wiener_iters=0, end_iters=0, wiener_residual=False, cac=True, depth=4, rewrite=True, multi_freqs=None, multi_freqs_depth=3, freq_emb=0.2, emb_scale=10, emb_smooth=True, kernel_size=8, time_stride=2, stride=4, context=1, context_enc=0, norm_starts=4, norm_groups=4, dconv_mode=1, dconv_depth=2, dconv_comp=8, dconv_init=1e-3, bottom_channels=0, t_layers=5, t_emb="sin", t_hidden_scale=4.0, t_heads=8, t_dropout=0.0, t_max_positions=10000, t_norm_in=True, t_norm_in_group=False, t_group_norm=False, t_norm_first=True, t_norm_out=True, t_max_period=10000.0, t_weight_decay=0.0, t_lr=None, t_layer_scale=True, t_gelu=True, t_weight_pos_embed=1.0, t_sin_random_shift=0, t_cape_mean_normalize=True, t_cape_augment=True, t_cape_glob_loc_scale=[5000.0, 1.0, 1.4], t_sparse_self_attn=False, t_sparse_cross_attn=False, t_mask_type="diag", t_mask_random_seed=42, t_sparse_attn_window=500, t_global_window=100, t_sparsity=0.95, t_auto_sparsity=False, t_cross_first=False, rescale=0.1, samplerate=44100, segment=4 * 10, use_train_segment=True):
290
- super().__init__()
291
- self.cac = cac
292
- self.wiener_residual = wiener_residual
293
- self.audio_channels = audio_channels
294
- self.sources = sources
295
- self.kernel_size = kernel_size
296
- self.context = context
297
- self.stride = stride
298
- self.depth = depth
299
- self.bottom_channels = bottom_channels
300
- self.channels = channels
301
- self.samplerate = samplerate
302
- self.segment = segment
303
- self.use_train_segment = use_train_segment
304
- self.nfft = nfft
305
- self.hop_length = nfft // 4
306
- self.wiener_iters = wiener_iters
307
- self.end_iters = end_iters
308
- self.freq_emb = None
309
- assert wiener_iters == end_iters
310
- self.encoder = nn.ModuleList()
311
- self.decoder = nn.ModuleList()
312
- self.tencoder = nn.ModuleList()
313
- self.tdecoder = nn.ModuleList()
314
- chin = audio_channels
315
- chin_z = chin
316
- if self.cac: chin_z *= 2
317
- chout = channels_time or channels
318
- chout_z = channels
319
- freqs = nfft // 2
320
-
321
- for index in range(depth):
322
- norm = index >= norm_starts
323
- freq = freqs > 1
324
- stri = stride
325
- ker = kernel_size
326
-
327
- if not freq:
328
- assert freqs == 1
329
- ker = time_stride * 2
330
- stri = time_stride
331
-
332
- pad = True
333
- last_freq = False
334
-
335
- if freq and freqs <= kernel_size:
336
- ker = freqs
337
- pad = False
338
- last_freq = True
339
-
340
- kw = {
341
- "kernel_size": ker,
342
- "stride": stri,
343
- "freq": freq,
344
- "pad": pad,
345
- "norm": norm,
346
- "rewrite": rewrite,
347
- "norm_groups": norm_groups,
348
- "dconv_kw": {"depth": dconv_depth, "compress": dconv_comp, "init": dconv_init, "gelu": True},
349
- }
350
-
351
- kwt = dict(kw)
352
- kwt["freq"] = 0
353
- kwt["kernel_size"] = kernel_size
354
- kwt["stride"] = stride
355
- kwt["pad"] = True
356
- kw_dec = dict(kw)
357
- multi = False
358
-
359
- if multi_freqs and index < multi_freqs_depth:
360
- multi = True
361
- kw_dec["context_freq"] = False
362
-
363
- if last_freq:
364
- chout_z = max(chout, chout_z)
365
- chout = chout_z
366
-
367
- enc = HEncLayer(chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw)
368
- if freq:
369
- tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc, empty=last_freq, **kwt)
370
- self.tencoder.append(tenc)
371
-
372
- if multi: enc = MultiWrap(enc, multi_freqs)
373
-
374
- self.encoder.append(enc)
375
- if index == 0:
376
- chin = self.audio_channels * len(self.sources)
377
- chin_z = chin
378
- if self.cac: chin_z *= 2
379
-
380
- dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2, last=index == 0, context=context, **kw_dec)
381
- if multi: dec = MultiWrap(dec, multi_freqs)
382
-
383
- if freq:
384
- tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq, last=index == 0, context=context, **kwt)
385
- self.tdecoder.insert(0, tdec)
386
-
387
- self.decoder.insert(0, dec)
388
- chin = chout
389
- chin_z = chout_z
390
- chout = int(growth * chout)
391
- chout_z = int(growth * chout_z)
392
-
393
- if freq:
394
- if freqs <= kernel_size: freqs = 1
395
- else: freqs //= stride
396
-
397
- if index == 0 and freq_emb:
398
- self.freq_emb = ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
399
- self.freq_emb_scale = freq_emb
400
-
401
- if rescale: rescale_module(self, reference=rescale)
402
- transformer_channels = channels * growth ** (depth - 1)
403
-
404
- if bottom_channels:
405
- self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
406
- self.channel_downsampler = nn.Conv1d(bottom_channels, transformer_channels, 1)
407
- self.channel_upsampler_t = nn.Conv1d(transformer_channels, bottom_channels, 1)
408
- self.channel_downsampler_t = nn.Conv1d(bottom_channels, transformer_channels, 1)
409
- transformer_channels = bottom_channels
410
-
411
- if t_layers > 0: self.crosstransformer = CrossTransformerEncoder(dim=transformer_channels, emb=t_emb, hidden_scale=t_hidden_scale, num_heads=t_heads, num_layers=t_layers, cross_first=t_cross_first, dropout=t_dropout, max_positions=t_max_positions, norm_in=t_norm_in, norm_in_group=t_norm_in_group, group_norm=t_group_norm, norm_first=t_norm_first, norm_out=t_norm_out, max_period=t_max_period, weight_decay=t_weight_decay, lr=t_lr, layer_scale=t_layer_scale, gelu=t_gelu, sin_random_shift=t_sin_random_shift, weight_pos_embed=t_weight_pos_embed, cape_mean_normalize=t_cape_mean_normalize, cape_augment=t_cape_augment, cape_glob_loc_scale=t_cape_glob_loc_scale, sparse_self_attn=t_sparse_self_attn, sparse_cross_attn=t_sparse_cross_attn, mask_type=t_mask_type, mask_random_seed=t_mask_random_seed, sparse_attn_window=t_sparse_attn_window, global_window=t_global_window, sparsity=t_sparsity, auto_sparsity=t_auto_sparsity)
412
- else: self.crosstransformer = None
413
-
414
- def _spec(self, x):
415
- hl = self.hop_length
416
- nfft = self.nfft
417
- assert hl == nfft // 4
418
- le = int(math.ceil(x.shape[-1] / hl))
419
- pad = hl // 2 * 3
420
- x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")
421
- z = spectro(x, nfft, hl)[..., :-1, :]
422
- assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
423
- z = z[..., 2 : 2 + le]
424
- return z
425
-
426
- def _ispec(self, z, length=None, scale=0):
427
- hl = self.hop_length // (4**scale)
428
- z = F.pad(z, (0, 0, 0, 1))
429
- z = F.pad(z, (2, 2))
430
- pad = hl // 2 * 3
431
- le = hl * int(math.ceil(length / hl)) + 2 * pad
432
- x = ispectro(z, hl, length=le)
433
- x = x[..., pad : pad + length]
434
- return x
435
-
436
- def _magnitude(self, z):
437
- if self.cac:
438
- B, C, Fr, T = z.shape
439
- m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
440
- m = m.reshape(B, C * 2, Fr, T)
441
- else: m = z.abs()
442
- return m
443
-
444
- def _mask(self, z, m):
445
- niters = self.wiener_iters
446
- if self.cac:
447
- B, S, C, Fr, T = m.shape
448
- out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
449
- out = torch.view_as_complex(out.contiguous())
450
- return out
451
-
452
- if self.training: niters = self.end_iters
453
-
454
- if niters < 0:
455
- z = z[:, None]
456
- return z / (1e-8 + z.abs()) * m
457
- else: return self._wiener(m, z, niters)
458
-
459
- def _wiener(self, mag_out, mix_stft, niters):
460
- init = mix_stft.dtype
461
- wiener_win_len = 300
462
- residual = self.wiener_residual
463
- B, S, C, Fq, T = mag_out.shape
464
- mag_out = mag_out.permute(0, 4, 3, 2, 1)
465
- mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
466
-
467
- outs = []
468
-
469
- for sample in range(B):
470
- pos = 0
471
- out = []
472
-
473
- for pos in range(0, T, wiener_win_len):
474
- frame = slice(pos, pos + wiener_win_len)
475
- z_out = wiener(mag_out[sample, frame], mix_stft[sample, frame], niters, residual=residual)
476
- out.append(z_out.transpose(-1, -2))
477
-
478
- outs.append(torch.cat(out, dim=0))
479
-
480
- out = torch.view_as_complex(torch.stack(outs, 0))
481
- out = out.permute(0, 4, 3, 2, 1).contiguous()
482
-
483
- if residual: out = out[:, :-1]
484
- assert list(out.shape) == [B, S, C, Fq, T]
485
- return out.to(init)
486
-
487
- def valid_length(self, length):
488
- if not self.use_train_segment: return length
489
-
490
- training_length = int(self.segment * self.samplerate)
491
- if training_length < length: raise ValueError(translations["length_or_training_length"].format(length=length, training_length=training_length))
492
-
493
- return training_length
494
-
495
- def forward(self, mix):
496
- length = mix.shape[-1]
497
- length_pre_pad = None
498
-
499
- if self.use_train_segment:
500
- if self.training: self.segment = Fraction(mix.shape[-1], self.samplerate)
501
- else:
502
- training_length = int(self.segment * self.samplerate)
503
-
504
- if mix.shape[-1] < training_length:
505
- length_pre_pad = mix.shape[-1]
506
- mix = F.pad(mix, (0, training_length - length_pre_pad))
507
-
508
- z = self._spec(mix)
509
- mag = self._magnitude(z).to(mix.device)
510
- x = mag
511
- B, C, Fq, T = x.shape
512
- mean = x.mean(dim=(1, 2, 3), keepdim=True)
513
- std = x.std(dim=(1, 2, 3), keepdim=True)
514
- x = (x - mean) / (1e-5 + std)
515
- xt = mix
516
- meant = xt.mean(dim=(1, 2), keepdim=True)
517
- stdt = xt.std(dim=(1, 2), keepdim=True)
518
- xt = (xt - meant) / (1e-5 + stdt)
519
-
520
- saved, saved_t, lengths, lengths_t = [], [], [], []
521
-
522
- for idx, encode in enumerate(self.encoder):
523
- lengths.append(x.shape[-1])
524
- inject = None
525
-
526
- if idx < len(self.tencoder):
527
- lengths_t.append(xt.shape[-1])
528
- tenc = self.tencoder[idx]
529
- xt = tenc(xt)
530
-
531
- if not tenc.empty: saved_t.append(xt)
532
- else: inject = xt
533
-
534
- x = encode(x, inject)
535
- if idx == 0 and self.freq_emb is not None:
536
- frs = torch.arange(x.shape[-2], device=x.device)
537
- emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
538
- x = x + self.freq_emb_scale * emb
539
-
540
- saved.append(x)
541
-
542
- if self.crosstransformer:
543
- if self.bottom_channels:
544
- b, c, f, t = x.shape
545
- x = rearrange(x, "b c f t-> b c (f t)")
546
- x = self.channel_upsampler(x)
547
- x = rearrange(x, "b c (f t)-> b c f t", f=f)
548
- xt = self.channel_upsampler_t(xt)
549
-
550
- x, xt = self.crosstransformer(x, xt)
551
-
552
- if self.bottom_channels:
553
- x = rearrange(x, "b c f t-> b c (f t)")
554
- x = self.channel_downsampler(x)
555
- x = rearrange(x, "b c (f t)-> b c f t", f=f)
556
- xt = self.channel_downsampler_t(xt)
557
-
558
- for idx, decode in enumerate(self.decoder):
559
- skip = saved.pop(-1)
560
- x, pre = decode(x, skip, lengths.pop(-1))
561
- offset = self.depth - len(self.tdecoder)
562
-
563
- if idx >= offset:
564
- tdec = self.tdecoder[idx - offset]
565
- length_t = lengths_t.pop(-1)
566
-
567
- if tdec.empty:
568
- assert pre.shape[2] == 1, pre.shape
569
- pre = pre[:, :, 0]
570
- xt, _ = tdec(pre, None, length_t)
571
- else:
572
- skip = saved_t.pop(-1)
573
- xt, _ = tdec(xt, skip, length_t)
574
-
575
- assert len(saved) == 0
576
- assert len(lengths_t) == 0
577
- assert len(saved_t) == 0
578
-
579
- S = len(self.sources)
580
- x = x.view(B, S, -1, Fq, T)
581
- x = x * std[:, None] + mean[:, None]
582
- device_type = x.device.type
583
- device_load = f"{device_type}:{x.device.index}" if not device_type == "mps" else device_type
584
- x_is_other_gpu = not device_type in ["cuda", "cpu"]
585
- if x_is_other_gpu: x = x.cpu()
586
- zout = self._mask(z, x)
587
-
588
- if self.use_train_segment: x = self._ispec(zout, length) if self.training else self._ispec(zout, training_length)
589
- else: x = self._ispec(zout, length)
590
-
591
- if x_is_other_gpu: x = x.to(device_load)
592
-
593
- if self.use_train_segment: xt = xt.view(B, S, -1, length) if self.training else xt.view(B, S, -1, training_length)
594
- else: xt = xt.view(B, S, -1, length)
595
-
596
- xt = xt * stdt[:, None] + meant[:, None]
597
- x = xt + x
598
-
599
- if length_pre_pad: x = x[..., :length_pre_pad]
600
- return x
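For reference, a minimal sketch of the frame-count arithmetic behind the `assert z.shape[-1] == le + 4` check in `_spec` above, assuming the `spectro` helper (not shown in this diff) wraps a centered `torch.stft`, so a padded signal of P samples yields `1 + P // hop` frames:

```python
import math
import torch
import torch.nn.functional as F

# Illustrative check of the padding used by HTDemucs._spec.
nfft, length = 4096, 44100
hl = nfft // 4                      # hop length; _spec asserts hl == nfft // 4
le = int(math.ceil(length / hl))    # number of frames kept after cropping
pad = hl // 2 * 3                   # reflect-pad 1.5 hops on each side

x = torch.randn(1, 1, length)
x = F.pad(x, (pad, pad + le * hl - length), mode="reflect")
frames = 1 + x.shape[-1] // hl      # frames a centered STFT would produce
assert frames == le + 4             # matches the assert; z[..., 2:2+le] then crops the extras
print(frames, le + 4)
```
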
main/library/uvr5_separator/demucs/states.py DELETED
@@ -1,55 +0,0 @@
1
- import os
2
- import sys
3
- import torch
4
- import inspect
5
- import warnings
6
- import functools
7
-
8
- from pathlib import Path
9
-
10
- sys.path.append(os.getcwd())
11
-
12
- from main.configs.config import Config
13
- translations = Config().translations
14
-
15
- def load_model(path_or_package, strict=False):
16
- if isinstance(path_or_package, dict): package = path_or_package
17
- elif isinstance(path_or_package, (str, Path)):
18
- with warnings.catch_warnings():
19
- warnings.simplefilter("ignore")
20
- package = torch.load(path_or_package, map_location="cpu")
21
- else: raise ValueError(f"{translations['type_not_valid']} {path_or_package}.")
22
- klass = package["klass"]
23
- args = package["args"]
24
- kwargs = package["kwargs"]
25
- if strict: model = klass(*args, **kwargs)
26
- else:
27
- sig = inspect.signature(klass)
28
- for key in list(kwargs):
29
- if key not in sig.parameters:
30
- warnings.warn(translations["del_parameter"] + key)
31
- del kwargs[key]
32
- model = klass(*args, **kwargs)
33
- state = package["state"]
34
- set_state(model, state)
35
- return model
36
-
37
- def restore_quantized_state(model, state):
38
- assert "meta" in state
39
- quantizer = state["meta"]["klass"](model, **state["meta"]["init_kwargs"])
40
- quantizer.restore_quantized_state(state)
41
- quantizer.detach()
42
-
43
- def set_state(model, state, quantizer=None):
44
- if state.get("__quantized"):
45
- if quantizer is not None: quantizer.restore_quantized_state(model, state["quantized"])
46
- else: restore_quantized_state(model, state)
47
- else: model.load_state_dict(state)
48
- return state
49
-
50
- def capture_init(init):
51
- @functools.wraps(init)
52
- def __init__(self, *args, **kwargs):
53
- self._init_args_kwargs = (args, kwargs)
54
- init(self, *args, **kwargs)
55
- return __init__
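The non-strict branch of `load_model` above drops saved constructor arguments that the target class no longer accepts instead of failing. A self-contained sketch of that filtering, with a made-up `Net` class standing in for the real model:

```python
import inspect
import warnings

class Net:
    def __init__(self, channels, depth=4):
        self.channels, self.depth = channels, depth

# Saved kwargs may contain options from an older version of the class.
saved_kwargs = {"channels": 48, "depth": 6, "legacy_option": True}
sig = inspect.signature(Net)
for key in list(saved_kwargs):
    if key not in sig.parameters:          # same check as load_model(strict=False)
        warnings.warn("dropping unknown parameter: " + key)
        del saved_kwargs[key]

model = Net(**saved_kwargs)                # "legacy_option" was removed, so this succeeds
print(model.channels, model.depth)
```
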
main/library/uvr5_separator/demucs/utils.py DELETED
@@ -1,8 +0,0 @@
1
- import torch
2
-
3
- def center_trim(tensor, reference):
4
- ref_size = reference.size(-1) if isinstance(reference, torch.Tensor) else reference
5
- delta = tensor.size(-1) - ref_size
6
- if delta < 0: raise ValueError(f"tensor > parameter: {delta}.")
7
- if delta: tensor = tensor[..., delta // 2 : -(delta - delta // 2)]
8
- return tensor
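A toy call to `center_trim`, with the helper copied verbatim so the snippet runs on its own; it trims an equal margin from both ends to match the reference length:

```python
import torch

def center_trim(tensor, reference):
    ref_size = reference.size(-1) if isinstance(reference, torch.Tensor) else reference
    delta = tensor.size(-1) - ref_size
    if delta < 0: raise ValueError(f"tensor > parameter: {delta}.")
    if delta: tensor = tensor[..., delta // 2 : -(delta - delta // 2)]
    return tensor

x = torch.arange(10).view(1, 10)
print(center_trim(x, 6))   # keeps the 6 centre samples: tensor([[2, 3, 4, 5, 6, 7]])
```
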
main/library/uvr5_separator/spec_utils.py DELETED
@@ -1,900 +0,0 @@
1
- import os
2
- import six
3
- import sys
4
- import librosa
5
- import tempfile
6
- import platform
7
- import subprocess
8
-
9
- import numpy as np
10
- import soundfile as sf
11
-
12
- from scipy.signal import correlate, hilbert
13
-
14
- sys.path.append(os.getcwd())
15
-
16
- from main.configs.config import Config
17
- translations = Config().translations
18
-
19
- OPERATING_SYSTEM = platform.system()
20
- SYSTEM_ARCH = platform.platform()
21
- SYSTEM_PROC = platform.processor()
22
- ARM = "arm"
23
- AUTO_PHASE = "Automatic"
24
- POSITIVE_PHASE = "Positive Phase"
25
- NEGATIVE_PHASE = "Negative Phase"
26
- NONE_P = ("None",)
27
- LOW_P = ("Shifts: Low",)
28
- MED_P = ("Shifts: Medium",)
29
- HIGH_P = ("Shifts: High",)
30
- VHIGH_P = "Shifts: Very High"
31
- MAXIMUM_P = "Shifts: Maximum"
32
- BASE_PATH_RUB = sys._MEIPASS if getattr(sys, 'frozen', False) else os.path.dirname(os.path.abspath(__file__))
33
- DEVNULL = open(os.devnull, 'w') if six.PY2 else subprocess.DEVNULL
34
- MAX_SPEC = "Max Spec"
35
- MIN_SPEC = "Min Spec"
36
- LIN_ENSE = "Linear Ensemble"
37
- MAX_WAV = MAX_SPEC
38
- MIN_WAV = MIN_SPEC
39
- AVERAGE = "Average"
40
-
41
- progress_value, last_update_time = 0, 0
42
- wav_resolution = "sinc_fastest"
43
- wav_resolution_float_resampling = wav_resolution
44
-
45
- def crop_center(h1, h2):
46
- h1_shape = h1.size()
47
- h2_shape = h2.size()
48
-
49
- if h1_shape[3] == h2_shape[3]: return h1
50
- elif h1_shape[3] < h2_shape[3]: raise ValueError("h1_shape[3] > h2_shape[3]")
51
-
52
- s_time = (h1_shape[3] - h2_shape[3]) // 2
53
-
54
- h1 = h1[:, :, :, s_time:s_time + h2_shape[3]]
55
- return h1
56
-
57
- def preprocess(X_spec):
58
- return np.abs(X_spec), np.angle(X_spec)
59
-
60
- def make_padding(width, cropsize, offset):
61
- roi_size = cropsize - offset * 2
62
-
63
- if roi_size == 0: roi_size = cropsize
64
- return offset, roi_size - (width % roi_size) + offset, roi_size
65
-
66
- def normalize(wave, max_peak=1.0):
67
- maxv = np.abs(wave).max()
68
-
69
- if maxv > max_peak: wave *= max_peak / maxv
70
- return wave
71
-
72
- def auto_transpose(audio_array):
73
- if audio_array.shape[1] == 2: return audio_array.T
74
- return audio_array
75
-
76
- def write_array_to_mem(audio_data, subtype):
77
- if isinstance(audio_data, np.ndarray):
78
- import io
79
-
80
- audio_buffer = io.BytesIO()
81
- sf.write(audio_buffer, audio_data, 44100, subtype=subtype, format="WAV")
82
-
83
- audio_buffer.seek(0)
84
- return audio_buffer
85
- else: return audio_data
86
-
87
- def spectrogram_to_image(spec, mode="magnitude"):
88
- if mode == "magnitude": y = np.log10((np.abs(spec) if np.iscomplexobj(spec) else spec)**2 + 1e-8)
89
- elif mode == "phase": y = np.angle(spec) if np.iscomplexobj(spec) else spec
90
-
91
- y -= y.min()
92
- y *= 255 / y.max()
93
- img = np.uint8(y)
94
-
95
- if y.ndim == 3:
96
- img = img.transpose(1, 2, 0)
97
- img = np.concatenate([np.max(img, axis=2, keepdims=True), img], axis=2)
98
-
99
- return img
100
-
101
- def reduce_vocal_aggressively(X, y, softmask):
102
- y_mag_tmp = np.abs(y)
103
- v_mag_tmp = np.abs(X - y)
104
-
105
- return np.clip(y_mag_tmp - v_mag_tmp * (v_mag_tmp > y_mag_tmp) * softmask, 0, np.inf) * np.exp(1.0j * np.angle(y))
106
-
107
- def merge_artifacts(y_mask, thres=0.01, min_range=64, fade_size=32):
108
- mask = y_mask
109
-
110
- try:
111
- if min_range < fade_size * 2: raise ValueError("min_range >= fade_size * 2")
112
-
113
- idx = np.where(y_mask.min(axis=(0, 1)) > thres)[0]
114
- start_idx = np.insert(idx[np.where(np.diff(idx) != 1)[0] + 1], 0, idx[0])
115
- end_idx = np.append(idx[np.where(np.diff(idx) != 1)[0]], idx[-1])
116
- artifact_idx = np.where(end_idx - start_idx > min_range)[0]
117
- weight = np.zeros_like(y_mask)
118
-
119
- if len(artifact_idx) > 0:
120
- start_idx = start_idx[artifact_idx]
121
- end_idx = end_idx[artifact_idx]
122
- old_e = None
123
-
124
- for s, e in zip(start_idx, end_idx):
125
- if old_e is not None and s - old_e < fade_size: s = old_e - fade_size * 2
126
-
127
- if s != 0: weight[:, :, s : s + fade_size] = np.linspace(0, 1, fade_size)
128
- else: s -= fade_size
129
-
130
- if e != y_mask.shape[2]: weight[:, :, e - fade_size : e] = np.linspace(1, 0, fade_size)
131
- else: e += fade_size
132
-
133
- weight[:, :, s + fade_size : e - fade_size] = 1
134
- old_e = e
135
-
136
- v_mask = 1 - y_mask
137
- y_mask += weight * v_mask
138
- mask = y_mask
139
- except Exception as e:
140
- import traceback
141
- print(translations["not_success"], f'{type(e).__name__}: "{e}"\n{traceback.format_exc()}"')
142
-
143
- return mask
144
-
145
- def align_wave_head_and_tail(a, b):
146
- l = min([a[0].size, b[0].size])
147
- return a[:l, :l], b[:l, :l]
148
-
149
- def convert_channels(spec, mp, band):
150
- cc = mp.param["band"][band].get("convert_channels")
151
-
152
- if "mid_side_c" == cc:
153
- spec_left = np.add(spec[0], spec[1] * 0.25)
154
- spec_right = np.subtract(spec[1], spec[0] * 0.25)
155
- elif "mid_side" == cc:
156
- spec_left = np.add(spec[0], spec[1]) / 2
157
- spec_right = np.subtract(spec[0], spec[1])
158
- elif "stereo_n" == cc:
159
- spec_left = np.add(spec[0], spec[1] * 0.25) / 0.9375
160
- spec_right = np.add(spec[1], spec[0] * 0.25) / 0.9375
161
- else: return spec
162
-
163
- return np.asfortranarray([spec_left, spec_right])
164
-
165
- def combine_spectrograms(specs, mp, is_v51_model=False):
166
- l = min([specs[i].shape[2] for i in specs])
167
- spec_c = np.zeros(shape=(2, mp.param["bins"] + 1, l), dtype=np.complex64)
168
- offset = 0
169
- bands_n = len(mp.param["band"])
170
-
171
- for d in range(1, bands_n + 1):
172
- h = mp.param["band"][d]["crop_stop"] - mp.param["band"][d]["crop_start"]
173
- spec_c[:, offset : offset + h, :l] = specs[d][:, mp.param["band"][d]["crop_start"] : mp.param["band"][d]["crop_stop"], :l]
174
- offset += h
175
-
176
- if offset > mp.param["bins"]: raise ValueError("offset > mp.param['bins']")
177
-
178
- if mp.param["pre_filter_start"] > 0:
179
- if is_v51_model: spec_c *= get_lp_filter_mask(spec_c.shape[1], mp.param["pre_filter_start"], mp.param["pre_filter_stop"])
180
- else:
181
- if bands_n == 1: spec_c = fft_lp_filter(spec_c, mp.param["pre_filter_start"], mp.param["pre_filter_stop"])
182
- else:
183
- import math
184
- gp = 1
185
-
186
- for b in range(mp.param["pre_filter_start"] + 1, mp.param["pre_filter_stop"]):
187
- g = math.pow(10, -(b - mp.param["pre_filter_start"]) * (3.5 - gp) / 20.0)
188
- gp = g
189
- spec_c[:, b, :] *= g
190
-
191
- return np.asfortranarray(spec_c)
192
-
193
- def wave_to_spectrogram(wave, hop_length, n_fft, mp, band, is_v51_model=False):
194
- if wave.ndim == 1: wave = np.asfortranarray([wave, wave])
195
-
196
- if not is_v51_model:
197
- if mp.param["reverse"]:
198
- wave_left = np.flip(np.asfortranarray(wave[0]))
199
- wave_right = np.flip(np.asfortranarray(wave[1]))
200
- elif mp.param["mid_side"]:
201
- wave_left = np.asfortranarray(np.add(wave[0], wave[1]) / 2)
202
- wave_right = np.asfortranarray(np.subtract(wave[0], wave[1]))
203
- elif mp.param["mid_side_b2"]:
204
- wave_left = np.asfortranarray(np.add(wave[1], wave[0] * 0.5))
205
- wave_right = np.asfortranarray(np.subtract(wave[0], wave[1] * 0.5))
206
- else:
207
- wave_left = np.asfortranarray(wave[0])
208
- wave_right = np.asfortranarray(wave[1])
209
- else:
210
- wave_left = np.asfortranarray(wave[0])
211
- wave_right = np.asfortranarray(wave[1])
212
-
213
- spec_left = librosa.stft(wave_left, n_fft=n_fft, hop_length=hop_length)
214
- spec_right = librosa.stft(wave_right, n_fft=n_fft, hop_length=hop_length)
215
-
216
- spec = np.asfortranarray([spec_left, spec_right])
217
-
218
- if is_v51_model: spec = convert_channels(spec, mp, band)
219
- return spec
220
-
221
- def spectrogram_to_wave(spec, hop_length=1024, mp={}, band=0, is_v51_model=True):
222
- spec_left = np.asfortranarray(spec[0])
223
- spec_right = np.asfortranarray(spec[1])
224
-
225
- wave_left = librosa.istft(spec_left, hop_length=hop_length)
226
- wave_right = librosa.istft(spec_right, hop_length=hop_length)
227
-
228
- if is_v51_model:
229
- cc = mp.param["band"][band].get("convert_channels")
230
-
231
- if "mid_side_c" == cc: return np.asfortranarray([np.subtract(wave_left / 1.0625, wave_right / 4.25), np.add(wave_right / 1.0625, wave_left / 4.25)])
232
- elif "mid_side" == cc: return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)])
233
- elif "stereo_n" == cc: return np.asfortranarray([np.subtract(wave_left, wave_right * 0.25), np.subtract(wave_right, wave_left * 0.25)])
234
- else:
235
- if mp.param["reverse"]: return np.asfortranarray([np.flip(wave_left), np.flip(wave_right)])
236
- elif mp.param["mid_side"]: return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)])
237
- elif mp.param["mid_side_b2"]: return np.asfortranarray([np.add(wave_right / 1.25, 0.4 * wave_left), np.subtract(wave_left / 1.25, 0.4 * wave_right)])
238
-
239
- return np.asfortranarray([wave_left, wave_right])
240
-
241
- def cmb_spectrogram_to_wave(spec_m, mp, extra_bins_h=None, extra_bins=None, is_v51_model=False):
242
- bands_n = len(mp.param["band"])
243
- offset = 0
244
-
245
- for d in range(1, bands_n + 1):
246
- bp = mp.param["band"][d]
247
- spec_s = np.zeros(shape=(2, bp["n_fft"] // 2 + 1, spec_m.shape[2]), dtype=complex)
248
- h = bp["crop_stop"] - bp["crop_start"]
249
- spec_s[:, bp["crop_start"] : bp["crop_stop"], :] = spec_m[:, offset : offset + h, :]
250
- offset += h
251
-
252
- if d == bands_n:
253
- if extra_bins_h:
254
- max_bin = bp["n_fft"] // 2
255
- spec_s[:, max_bin - extra_bins_h : max_bin, :] = extra_bins[:, :extra_bins_h, :]
256
-
257
- if bp["hpf_start"] > 0:
258
- if is_v51_model: spec_s *= get_hp_filter_mask(spec_s.shape[1], bp["hpf_start"], bp["hpf_stop"] - 1)
259
- else: spec_s = fft_hp_filter(spec_s, bp["hpf_start"], bp["hpf_stop"] - 1)
260
-
261
- wave = spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model) if bands_n == 1 else np.add(wave, spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model))
262
- else:
263
- sr = mp.param["band"][d + 1]["sr"]
264
- if d == 1:
265
- if is_v51_model: spec_s *= get_lp_filter_mask(spec_s.shape[1], bp["lpf_start"], bp["lpf_stop"])
266
- else: spec_s = fft_lp_filter(spec_s, bp["lpf_start"], bp["lpf_stop"])
267
-
268
- try:
269
- wave = librosa.resample(spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model), orig_sr=bp["sr"], target_sr=sr, res_type=wav_resolution)
270
- except ValueError as e:
271
- print(f"{translations['resample_error']}: {e}")
272
- print(f"{translations['shapes']} Spec_s: {spec_s.shape}, SR: {sr}, {translations['wav_resolution']}: {wav_resolution}")
273
- else:
274
- if is_v51_model:
275
- spec_s *= get_hp_filter_mask(spec_s.shape[1], bp["hpf_start"], bp["hpf_stop"] - 1)
276
- spec_s *= get_lp_filter_mask(spec_s.shape[1], bp["lpf_start"], bp["lpf_stop"])
277
- else:
278
- spec_s = fft_hp_filter(spec_s, bp["hpf_start"], bp["hpf_stop"] - 1)
279
- spec_s = fft_lp_filter(spec_s, bp["lpf_start"], bp["lpf_stop"])
280
-
281
- try:
282
- wave = librosa.resample(np.add(wave, spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model)), orig_sr=bp["sr"], target_sr=sr, res_type=wav_resolution)
283
- except ValueError as e:
284
- print(f"{translations['resample_error']}: {e}")
285
- print(f"{translations['shapes']} Spec_s: {spec_s.shape}, SR: {sr}, {translations['wav_resolution']}: {wav_resolution}")
286
-
287
- return wave
288
-
289
- def get_lp_filter_mask(n_bins, bin_start, bin_stop):
290
- return np.concatenate([np.ones((bin_start - 1, 1)), np.linspace(1, 0, bin_stop - bin_start + 1)[:, None], np.zeros((n_bins - bin_stop, 1))], axis=0)
291
-
292
- def get_hp_filter_mask(n_bins, bin_start, bin_stop):
293
- return np.concatenate([np.zeros((bin_stop + 1, 1)), np.linspace(0, 1, 1 + bin_start - bin_stop)[:, None], np.ones((n_bins - bin_start - 2, 1))], axis=0)
294
-
295
- def fft_lp_filter(spec, bin_start, bin_stop):
296
- g = 1.0
297
-
298
- for b in range(bin_start, bin_stop):
299
- g -= 1 / (bin_stop - bin_start)
300
- spec[:, b, :] = g * spec[:, b, :]
301
-
302
- spec[:, bin_stop:, :] *= 0
303
- return spec
304
-
305
- def fft_hp_filter(spec, bin_start, bin_stop):
306
- g = 1.0
307
-
308
- for b in range(bin_start, bin_stop, -1):
309
- g -= 1 / (bin_start - bin_stop)
310
- spec[:, b, :] = g * spec[:, b, :]
311
-
312
- spec[:, 0 : bin_stop + 1, :] *= 0
313
- return spec
314
-
315
- def spectrogram_to_wave_old(spec, hop_length=1024):
316
- if spec.ndim == 2: wave = librosa.istft(spec, hop_length=hop_length)
317
- elif spec.ndim == 3: wave = np.asfortranarray([librosa.istft(np.asfortranarray(spec[0]), hop_length=hop_length), librosa.istft(np.asfortranarray(spec[1]), hop_length=hop_length)])
318
-
319
- return wave
320
-
321
- def wave_to_spectrogram_old(wave, hop_length, n_fft):
322
- return np.asfortranarray([librosa.stft(np.asfortranarray(wave[0]), n_fft=n_fft, hop_length=hop_length), librosa.stft(np.asfortranarray(wave[1]), n_fft=n_fft, hop_length=hop_length)])
323
-
324
- def mirroring(a, spec_m, input_high_end, mp):
325
- if "mirroring" == a:
326
- mirror = np.flip(np.abs(spec_m[:, mp.param["pre_filter_start"] - 10 - input_high_end.shape[1] : mp.param["pre_filter_start"] - 10, :]), 1) * np.exp(1.0j * np.angle(input_high_end))
327
-
328
- return np.where(np.abs(input_high_end) <= np.abs(mirror), input_high_end, mirror)
329
-
330
- if "mirroring2" == a:
331
- mi = np.multiply(np.flip(np.abs(spec_m[:, mp.param["pre_filter_start"] - 10 - input_high_end.shape[1] : mp.param["pre_filter_start"] - 10, :]), 1), input_high_end * 1.7)
332
-
333
- return np.where(np.abs(input_high_end) <= np.abs(mi), input_high_end, mi)
334
-
335
- def adjust_aggr(mask, is_non_accom_stem, aggressiveness):
336
- aggr = aggressiveness["value"] * 2
337
-
338
- if aggr != 0:
339
- if is_non_accom_stem:
340
- aggr = 1 - aggr
341
-
342
- if np.any(aggr > 10) or np.any(aggr < -10): print(f"{translations['warnings']}: {aggr}")
343
-
344
- aggr = [aggr, aggr]
345
-
346
- if aggressiveness["aggr_correction"] is not None:
347
- aggr[0] += aggressiveness["aggr_correction"]["left"]
348
- aggr[1] += aggressiveness["aggr_correction"]["right"]
349
-
350
- for ch in range(2):
351
- mask[ch, : aggressiveness["split_bin"]] = np.power(mask[ch, : aggressiveness["split_bin"]], 1 + aggr[ch] / 3)
352
- mask[ch, aggressiveness["split_bin"] :] = np.power(mask[ch, aggressiveness["split_bin"] :], 1 + aggr[ch])
353
-
354
- return mask
355
-
356
- def stft(wave, nfft, hl):
357
- return np.asfortranarray([librosa.stft(np.asfortranarray(wave[0]), n_fft=nfft, hop_length=hl), librosa.stft(np.asfortranarray(wave[1]), n_fft=nfft, hop_length=hl)])
358
-
359
- def istft(spec, hl):
360
- return np.asfortranarray([librosa.istft(np.asfortranarray(spec[0]), hop_length=hl), librosa.istft(np.asfortranarray(spec[1]), hop_length=hl)])
361
-
362
- def spec_effects(wave, algorithm="Default", value=None):
363
- if np.isnan(wave).any() or np.isinf(wave).any(): print(f"{translations['warnings_2']}: {wave.shape}")
364
- spec = [stft(wave[0], 2048, 1024), stft(wave[1], 2048, 1024)]
365
-
366
- if algorithm == "Min_Mag": wave = istft(np.where(np.abs(spec[1]) <= np.abs(spec[0]), spec[1], spec[0]), 1024)
367
- elif algorithm == "Max_Mag": wave = istft(np.where(np.abs(spec[1]) >= np.abs(spec[0]), spec[1], spec[0]), 1024)
368
- elif algorithm == "Default": wave = (wave[1] * value) + (wave[0] * (1 - value))
369
- elif algorithm == "Invert_p":
370
- X_mag, y_mag = np.abs(spec[0]), np.abs(spec[1])
371
- wave = istft(spec[1] - np.where(X_mag >= y_mag, X_mag, y_mag) * np.exp(1.0j * np.angle(spec[0])), 1024)
372
-
373
- return wave
374
-
375
- def spectrogram_to_wave_no_mp(spec, n_fft=2048, hop_length=1024):
376
- wave = librosa.istft(spec, n_fft=n_fft, hop_length=hop_length)
377
- if wave.ndim == 1: wave = np.asfortranarray([wave, wave])
378
-
379
- return wave
380
-
381
- def wave_to_spectrogram_no_mp(wave):
382
- spec = librosa.stft(wave, n_fft=2048, hop_length=1024)
383
-
384
- if spec.ndim == 1: spec = np.asfortranarray([spec, spec])
385
- return spec
386
-
387
- def invert_audio(specs, invert_p=True):
388
- ln = min([specs[0].shape[2], specs[1].shape[2]])
389
- specs[0] = specs[0][:, :, :ln]
390
- specs[1] = specs[1][:, :, :ln]
391
-
392
- if invert_p:
393
- X_mag, y_mag = np.abs(specs[0]), np.abs(specs[1])
394
- v_spec = specs[1] - np.where(X_mag >= y_mag, X_mag, y_mag) * np.exp(1.0j * np.angle(specs[0]))
395
- else:
396
- specs[1] = reduce_vocal_aggressively(specs[0], specs[1], 0.2)
397
- v_spec = specs[0] - specs[1]
398
-
399
- return v_spec
400
-
401
- def invert_stem(mixture, stem):
402
- return -spectrogram_to_wave_no_mp(invert_audio([wave_to_spectrogram_no_mp(mixture), wave_to_spectrogram_no_mp(stem)])).T
403
-
404
- def ensembling(a, inputs, is_wavs=False):
405
- for i in range(1, len(inputs)):
406
- if i == 1: input = inputs[0]
407
-
408
- if is_wavs:
409
- ln = min([input.shape[1], inputs[i].shape[1]])
410
- input = input[:, :ln]
411
- inputs[i] = inputs[i][:, :ln]
412
- else:
413
- ln = min([input.shape[2], inputs[i].shape[2]])
414
- input = input[:, :, :ln]
415
- inputs[i] = inputs[i][:, :, :ln]
416
-
417
- if MIN_SPEC == a: input = np.where(np.abs(inputs[i]) <= np.abs(input), inputs[i], input)
418
- if MAX_SPEC == a: input = np.where(np.abs(inputs[i]) >= np.abs(input), inputs[i], input)
419
-
420
- return input
421
-
422
- def ensemble_for_align(waves):
423
- specs = []
424
-
425
- for wav in waves:
426
- spec = wave_to_spectrogram_no_mp(wav.T)
427
- specs.append(spec)
428
-
429
- wav_aligned = spectrogram_to_wave_no_mp(ensembling(MIN_SPEC, specs)).T
430
- wav_aligned = match_array_shapes(wav_aligned, waves[1], is_swap=True)
431
-
432
- return wav_aligned
433
-
434
- def ensemble_inputs(audio_input, algorithm, is_normalization, wav_type_set, save_path, is_wave=False, is_array=False):
435
- wavs_ = []
436
-
437
- if algorithm == AVERAGE:
438
- output = average_audio(audio_input)
439
- samplerate = 44100
440
- else:
441
- specs = []
442
-
443
- for i in range(len(audio_input)):
444
- wave, samplerate = librosa.load(audio_input[i], mono=False, sr=44100)
445
- wavs_.append(wave)
446
- specs.append( wave if is_wave else wave_to_spectrogram_no_mp(wave))
447
-
448
- wave_shapes = [w.shape[1] for w in wavs_]
449
- target_shape = wavs_[wave_shapes.index(max(wave_shapes))]
450
-
451
- output = ensembling(algorithm, specs, is_wavs=True) if is_wave else spectrogram_to_wave_no_mp(ensembling(algorithm, specs))
452
- output = to_shape(output, target_shape.shape)
453
-
454
- sf.write(save_path, normalize(output.T, is_normalization), samplerate, subtype=wav_type_set)
455
-
456
- def to_shape(x, target_shape):
457
- padding_list = []
458
-
459
- for x_dim, target_dim in zip(x.shape, target_shape):
460
- padding_list.append((0, target_dim - x_dim))
461
-
462
- return np.pad(x, tuple(padding_list), mode="constant")
463
-
464
- def to_shape_minimize(x, target_shape):
465
- padding_list = []
466
-
467
- for x_dim, target_dim in zip(x.shape, target_shape):
468
- padding_list.append((0, target_dim - x_dim))
469
-
470
- return np.pad(x, tuple(padding_list), mode="constant")
471
-
472
- def detect_leading_silence(audio, sr, silence_threshold=0.007, frame_length=1024):
473
- if len(audio.shape) == 2:
474
- channel = np.argmax(np.sum(np.abs(audio), axis=1))
475
- audio = audio[channel]
476
-
477
- for i in range(0, len(audio), frame_length):
478
- if np.max(np.abs(audio[i : i + frame_length])) > silence_threshold: return (i / sr) * 1000
479
-
480
- return (len(audio) / sr) * 1000
481
-
482
- def adjust_leading_silence(target_audio, reference_audio, silence_threshold=0.01, frame_length=1024):
483
- def find_silence_end(audio):
484
- if len(audio.shape) == 2:
485
- channel = np.argmax(np.sum(np.abs(audio), axis=1))
486
- audio_mono = audio[channel]
487
- else: audio_mono = audio
488
-
489
- for i in range(0, len(audio_mono), frame_length):
490
- if np.max(np.abs(audio_mono[i : i + frame_length])) > silence_threshold: return i
491
-
492
- return len(audio_mono)
493
-
494
- ref_silence_end = find_silence_end(reference_audio)
495
- target_silence_end = find_silence_end(target_audio)
496
- silence_difference = ref_silence_end - target_silence_end
497
-
498
- try:
499
- silence_difference_p = ((ref_silence_end / 44100) * 1000) - ((target_silence_end / 44100) * 1000)
500
- except Exception as e:
501
- pass
502
-
503
- if silence_difference > 0: return np.hstack((np.zeros((target_audio.shape[0], silence_difference))if len(target_audio.shape) == 2 else np.zeros(silence_difference), target_audio))
504
- elif silence_difference < 0: return target_audio[:, -silence_difference:]if len(target_audio.shape) == 2 else target_audio[-silence_difference:]
505
- else: return target_audio
506
-
507
- def match_array_shapes(array_1, array_2, is_swap=False):
508
-
509
- if is_swap: array_1, array_2 = array_1.T, array_2.T
510
-
511
- if array_1.shape[1] > array_2.shape[1]: array_1 = array_1[:, : array_2.shape[1]]
512
- elif array_1.shape[1] < array_2.shape[1]:
513
- padding = array_2.shape[1] - array_1.shape[1]
514
- array_1 = np.pad(array_1, ((0, 0), (0, padding)), "constant", constant_values=0)
515
-
516
- if is_swap: array_1, array_2 = array_1.T, array_2.T
517
-
518
- return array_1
519
-
520
- def match_mono_array_shapes(array_1, array_2):
521
- if len(array_1) > len(array_2): array_1 = array_1[: len(array_2)]
522
- elif len(array_1) < len(array_2):
523
- padding = len(array_2) - len(array_1)
524
- array_1 = np.pad(array_1, (0, padding), "constant", constant_values=0)
525
-
526
- return array_1
527
-
528
- def change_pitch_semitones(y, sr, semitone_shift):
529
- factor = 2 ** (semitone_shift / 12)
530
- y_pitch_tuned = []
531
-
532
- for y_channel in y:
533
- y_pitch_tuned.append(librosa.resample(y_channel, orig_sr=sr, target_sr=sr * factor, res_type=wav_resolution_float_resampling))
534
-
535
- y_pitch_tuned = np.array(y_pitch_tuned)
536
- new_sr = sr * factor
537
-
538
- return y_pitch_tuned, new_sr
539
-
540
- def augment_audio(export_path, audio_file, rate, is_normalization, wav_type_set, save_format=None, is_pitch=False, is_time_correction=True):
541
- wav, sr = librosa.load(audio_file, sr=44100, mono=False)
542
- if wav.ndim == 1: wav = np.asfortranarray([wav, wav])
543
-
544
- if not is_time_correction: wav_mix = change_pitch_semitones(wav, 44100, semitone_shift=-rate)[0]
545
- else:
546
- if is_pitch: wav_1, wav_2 = pitch_shift(wav[0], sr, rate, rbargs=None), pitch_shift(wav[1], sr, rate, rbargs=None)
547
- else: wav_1, wav_2 = time_stretch(wav[0], sr, rate, rbargs=None), time_stretch(wav[1], sr, rate, rbargs=None)
548
-
549
- if wav_1.shape > wav_2.shape: wav_2 = to_shape(wav_2, wav_1.shape)
550
- if wav_1.shape < wav_2.shape: wav_1 = to_shape(wav_1, wav_2.shape)
551
-
552
- wav_mix = np.asfortranarray([wav_1, wav_2])
553
-
554
- sf.write(export_path, normalize(wav_mix.T, is_normalization), sr, subtype=wav_type_set)
555
- save_format(export_path)
556
-
557
-
558
- def average_audio(audio):
559
- waves, wave_shapes, final_waves = [], [], []
560
-
561
- for i in range(len(audio)):
562
- wave = librosa.load(audio[i], sr=44100, mono=False)
563
- waves.append(wave[0])
564
- wave_shapes.append(wave[0].shape[1])
565
-
566
- wave_shapes_index = wave_shapes.index(max(wave_shapes))
567
- target_shape = waves[wave_shapes_index]
568
-
569
- waves.pop(wave_shapes_index)
570
- final_waves.append(target_shape)
571
-
572
- for n_array in waves:
573
- wav_target = to_shape(n_array, target_shape.shape)
574
- final_waves.append(wav_target)
575
-
576
- waves = sum(final_waves)
577
- return waves / len(audio)
578
-
579
- def average_dual_sources(wav_1, wav_2, value):
580
- if wav_1.shape > wav_2.shape: wav_2 = to_shape(wav_2, wav_1.shape)
581
- if wav_1.shape < wav_2.shape: wav_1 = to_shape(wav_1, wav_2.shape)
582
-
583
- return (wav_1 * value) + (wav_2 * (1 - value))
584
-
585
- def reshape_sources(wav_1, wav_2):
586
- if wav_1.shape > wav_2.shape: wav_2 = to_shape(wav_2, wav_1.shape)
587
-
588
- if wav_1.shape < wav_2.shape:
589
- ln = min([wav_1.shape[1], wav_2.shape[1]])
590
- wav_2 = wav_2[:, :ln]
591
-
592
- ln = min([wav_1.shape[1], wav_2.shape[1]])
593
- wav_1 = wav_1[:, :ln]
594
- wav_2 = wav_2[:, :ln]
595
-
596
- return wav_2
597
-
598
- def reshape_sources_ref(wav_1_shape, wav_2):
599
- if wav_1_shape > wav_2.shape: wav_2 = to_shape(wav_2, wav_1_shape)
600
- return wav_2
601
-
602
- def combine_arrarys(audio_sources, is_swap=False):
603
- source = np.zeros_like(max(audio_sources, key=np.size))
604
-
605
- for v in audio_sources:
606
- v = match_array_shapes(v, source, is_swap=is_swap)
607
- source += v
608
-
609
- return source
610
-
611
- def combine_audio(paths, audio_file_base=None, wav_type_set="FLOAT", save_format=None):
612
- source = combine_arrarys([load_audio(i) for i in paths])
613
- save_path = f"{audio_file_base}_combined.wav"
614
- sf.write(save_path, source.T, 44100, subtype=wav_type_set)
615
- save_format(save_path)
616
-
617
- def reduce_mix_bv(inst_source, voc_source, reduction_rate=0.9):
618
- return combine_arrarys([inst_source * (1 - reduction_rate), voc_source], is_swap=True)
619
-
620
- def organize_inputs(inputs):
621
- input_list = {"target": None, "reference": None, "reverb": None, "inst": None}
622
-
623
- for i in inputs:
624
- if i.endswith("_(Vocals).wav"): input_list["reference"] = i
625
- elif "_RVC_" in i: input_list["target"] = i
626
- elif i.endswith("reverbed_stem.wav"): input_list["reverb"] = i
627
- elif i.endswith("_(Instrumental).wav"): input_list["inst"] = i
628
-
629
- return input_list
630
-
631
- def check_if_phase_inverted(wav1, wav2, is_mono=False):
632
- if not is_mono:
633
- wav1 = np.mean(wav1, axis=0)
634
- wav2 = np.mean(wav2, axis=0)
635
-
636
- return np.corrcoef(wav1[:1000], wav2[:1000])[0, 1] < 0
637
-
638
- def align_audio(file1, file2, file2_aligned, file_subtracted, wav_type_set, is_save_aligned, command_Text, save_format, align_window, align_intro_val, db_analysis, set_progress_bar, phase_option, phase_shifts, is_match_silence, is_spec_match):
639
- global progress_value
640
- progress_value = 0
641
- is_mono = False
642
-
643
- def get_diff(a, b):
644
- return np.correlate(a, b, "full").argmax() - (b.shape[0] - 1)
645
-
646
- def progress_bar(length):
647
- global progress_value
648
- progress_value += 1
649
-
650
- if (0.90 / length * progress_value) >= 0.9: length = progress_value + 1
651
- set_progress_bar(0.1, (0.9 / length * progress_value))
652
-
653
- wav1, sr1 = librosa.load(file1, sr=44100, mono=False)
654
- wav2, sr2 = librosa.load(file2, sr=44100, mono=False)
655
-
656
- if wav1.ndim == 1 and wav2.ndim == 1: is_mono = True
657
- elif wav1.ndim == 1: wav1 = np.asfortranarray([wav1, wav1])
658
- elif wav2.ndim == 1: wav2 = np.asfortranarray([wav2, wav2])
659
-
660
- if phase_option == AUTO_PHASE:
661
- if check_if_phase_inverted(wav1, wav2, is_mono=is_mono): wav2 = -wav2
662
- elif phase_option == POSITIVE_PHASE: wav2 = +wav2
663
- elif phase_option == NEGATIVE_PHASE: wav2 = -wav2
664
-
665
- if is_match_silence: wav2 = adjust_leading_silence(wav2, wav1)
666
-
667
- wav1_length = int(librosa.get_duration(y=wav1, sr=44100))
668
- wav2_length = int(librosa.get_duration(y=wav2, sr=44100))
669
-
670
- if not is_mono:
671
- wav1 = wav1.transpose()
672
- wav2 = wav2.transpose()
673
-
674
- wav2_org = wav2.copy()
675
-
676
- command_Text(translations["process_file"])
677
- seconds_length = min(wav1_length, wav2_length)
678
- wav2_aligned_sources = []
679
-
680
- for sec_len in align_intro_val:
681
- sec_seg = 1 if sec_len == 1 else int(seconds_length // sec_len)
682
- index = sr1 * sec_seg
683
-
684
- if is_mono:
685
- samp1, samp2 = wav1[index : index + sr1], wav2[index : index + sr1]
686
- diff = get_diff(samp1, samp2)
687
- else:
688
- index = sr1 * sec_seg
689
- samp1, samp2 = wav1[index : index + sr1, 0], wav2[index : index + sr1, 0]
690
- samp1_r, samp2_r = wav1[index : index + sr1, 1], wav2[index : index + sr1, 1]
691
- diff, _ = get_diff(samp1, samp2), get_diff(samp1_r, samp2_r)
692
-
693
- if diff > 0: wav2_aligned = np.append(np.zeros(diff) if is_mono else np.zeros((diff, 2)), wav2_org, axis=0)
694
- elif diff < 0: wav2_aligned = wav2_org[-diff:]
695
- else: wav2_aligned = wav2_org
696
-
697
- if not any(np.array_equal(wav2_aligned, source) for source in wav2_aligned_sources): wav2_aligned_sources.append(wav2_aligned)
698
-
699
- unique_sources = len(wav2_aligned_sources)
700
- sub_mapper_big_mapper = {}
701
-
702
- for s in wav2_aligned_sources:
703
- wav2_aligned = match_mono_array_shapes(s, wav1) if is_mono else match_array_shapes(s, wav1, is_swap=True)
704
-
705
- if align_window:
706
- wav_sub = time_correction(wav1, wav2_aligned, seconds_length, align_window=align_window, db_analysis=db_analysis, progress_bar=progress_bar, unique_sources=unique_sources, phase_shifts=phase_shifts)
707
- sub_mapper_big_mapper = {**sub_mapper_big_mapper, **{np.abs(wav_sub).mean(): wav_sub}}
708
- else:
709
- wav2_aligned = wav2_aligned * np.power(10, db_analysis[0] / 20)
710
-
711
- for db_adjustment in db_analysis[1]:
712
- sub_mapper_big_mapper = {**sub_mapper_big_mapper, **{np.abs(wav_sub).mean(): wav1 - (wav2_aligned * (10 ** (db_adjustment / 20)))}}
713
-
714
- wav_sub = ensemble_for_align(list(sub_mapper_big_mapper.values())) if is_spec_match and len(list(sub_mapper_big_mapper.values())) >= 2 else ensemble_wav(list(sub_mapper_big_mapper.values()))
715
- wav_sub = np.clip(wav_sub, -1, +1)
716
-
717
- command_Text(translations["save_instruments"])
718
-
719
- if is_save_aligned or is_spec_match:
720
- wav1 = match_mono_array_shapes(wav1, wav_sub) if is_mono else match_array_shapes(wav1, wav_sub, is_swap=True)
721
- wav2_aligned = wav1 - wav_sub
722
-
723
- if is_spec_match:
724
- if wav1.ndim == 1 and wav2.ndim == 1:
725
- wav2_aligned = np.asfortranarray([wav2_aligned, wav2_aligned]).T
726
- wav1 = np.asfortranarray([wav1, wav1]).T
727
-
728
- wav2_aligned = ensemble_for_align([wav2_aligned, wav1])
729
- wav_sub = wav1 - wav2_aligned
730
-
731
- if is_save_aligned:
732
- sf.write(file2_aligned, wav2_aligned, sr1, subtype=wav_type_set)
733
- save_format(file2_aligned)
734
-
735
- sf.write(file_subtracted, wav_sub, sr1, subtype=wav_type_set)
736
- save_format(file_subtracted)
737
-
738
- def phase_shift_hilbert(signal, degree):
739
- analytic_signal = hilbert(signal)
740
- return np.cos(np.radians(degree)) * analytic_signal.real - np.sin(np.radians(degree)) * analytic_signal.imag
741
-
742
- def get_phase_shifted_tracks(track, phase_shift):
743
- if phase_shift == 180: return [track, -track]
744
-
745
- step = phase_shift
746
- end = 180 - (180 % step) if 180 % step == 0 else 181
747
- phase_range = range(step, end, step)
748
- flipped_list = [track, -track]
749
-
750
- for i in phase_range:
751
- flipped_list.extend([phase_shift_hilbert(track, i), phase_shift_hilbert(track, -i)])
752
-
753
- return flipped_list
754
-
755
- def time_correction(mix, instrumental, seconds_length, align_window, db_analysis, sr=44100, progress_bar=None, unique_sources=None, phase_shifts=NONE_P):
756
- def align_tracks(track1, track2):
757
- shifted_tracks = {}
758
- track2 = track2 * np.power(10, db_analysis[0] / 20)
759
- track2_flipped = [track2] if phase_shifts == 190 else get_phase_shifted_tracks(track2, phase_shifts)
760
-
761
- for db_adjustment in db_analysis[1]:
762
- for t in track2_flipped:
763
- track2_adjusted = t * (10 ** (db_adjustment / 20))
764
- track2_shifted = np.roll(track2_adjusted, shift=np.argmax(np.abs(correlate(track1, track2_adjusted))) - (len(track1) - 1))
765
- shifted_tracks[np.abs(track1 - track2_shifted).mean()] = track2_shifted
766
-
767
- return shifted_tracks[min(shifted_tracks.keys())]
768
-
769
- assert mix.shape == instrumental.shape, translations["assert"].format(mixshape=mix.shape, instrumentalshape=instrumental.shape)
770
- seconds_length = seconds_length // 2
771
-
772
- sub_mapper = {}
773
- progress_update_interval, total_iterations = 120, 0
774
-
775
- if len(align_window) > 2: progress_update_interval = 320
776
-
777
- for secs in align_window:
778
- step = secs / 2
779
- window_size = int(sr * secs)
780
- step_size = int(sr * step)
781
-
782
- if len(mix.shape) == 1: total_iterations += ((len(range(0, len(mix) - window_size, step_size)) // progress_update_interval) * unique_sources)
783
- else: total_iterations += ((len(range(0, len(mix[:, 0]) - window_size, step_size)) * 2 // progress_update_interval) * unique_sources)
784
-
785
- for secs in align_window:
786
- sub = np.zeros_like(mix)
787
- divider = np.zeros_like(mix)
788
- window_size = int(sr * secs)
789
- step_size = int(sr * secs / 2)
790
- window = np.hanning(window_size)
791
-
792
- if len(mix.shape) == 1:
793
- counter = 0
794
-
795
- for i in range(0, len(mix) - window_size, step_size):
796
- counter += 1
797
- if counter % progress_update_interval == 0: progress_bar(total_iterations)
798
-
799
- window_mix = mix[i : i + window_size] * window
800
- window_instrumental = instrumental[i : i + window_size] * window
801
- window_instrumental_aligned = align_tracks(window_mix, window_instrumental)
802
- sub[i : i + window_size] += window_mix - window_instrumental_aligned
803
- divider[i : i + window_size] += window
804
- else:
805
- counter = 0
806
-
807
- for ch in range(mix.shape[1]):
808
- for i in range(0, len(mix[:, ch]) - window_size, step_size):
809
- counter += 1
810
-
811
- if counter % progress_update_interval == 0: progress_bar(total_iterations)
812
-
813
- window_mix = mix[i : i + window_size, ch] * window
814
- window_instrumental = instrumental[i : i + window_size, ch] * window
815
- window_instrumental_aligned = align_tracks(window_mix, window_instrumental)
816
- sub[i : i + window_size, ch] += window_mix - window_instrumental_aligned
817
- divider[i : i + window_size, ch] += window
818
-
819
- return ensemble_wav(list({**sub_mapper, **{np.abs(sub).mean(): np.where(divider > 1e-6, sub / divider, sub)}}.values()), split_size=12)
820
-
821
- def ensemble_wav(waveforms, split_size=240):
822
- waveform_thirds = {i: np.array_split(waveform, split_size) for i, waveform in enumerate(waveforms)}
823
- final_waveform = []
824
- for third_idx in range(split_size):
825
- final_waveform.append(waveform_thirds[np.argmin([np.abs(waveform_thirds[i][third_idx]).mean() for i in range(len(waveforms))])][third_idx])
826
-
827
- return np.concatenate(final_waveform)
828
-
829
- def ensemble_wav_min(waveforms):
830
- for i in range(1, len(waveforms)):
831
- if i == 1: wave = waveforms[0]
832
- ln = min(len(wave), len(waveforms[i]))
833
- wave = wave[:ln]
834
- waveforms[i] = waveforms[i][:ln]
835
- wave = np.where(np.abs(waveforms[i]) <= np.abs(wave), waveforms[i], wave)
836
-
837
- return wave
838
-
839
- def align_audio_test(wav1, wav2, sr1=44100):
840
- def get_diff(a, b):
841
- return np.correlate(a, b, "full").argmax() - (b.shape[0] - 1)
842
-
843
- wav1 = wav1.transpose()
844
- wav2 = wav2.transpose()
845
- wav2_org = wav2.copy()
846
- index = sr1
847
- diff = get_diff(wav1[index : index + sr1, 0], wav2[index : index + sr1, 0])
848
-
849
- if diff > 0: wav2_aligned = np.append(np.zeros((diff, 1)), wav2_org, axis=0)
850
- elif diff < 0: wav2_aligned = wav2_org[-diff:]
851
- else: wav2_aligned = wav2_org
852
- return wav2_aligned
853
-
854
- def load_audio(audio_file):
855
- wav, _ = librosa.load(audio_file, sr=44100, mono=False)
856
- if wav.ndim == 1: wav = np.asfortranarray([wav, wav])
857
- return wav
858
-
859
- def __rubberband(y, sr, **kwargs):
860
- assert sr > 0
861
- fd, infile = tempfile.mkstemp(suffix='.wav')
862
- os.close(fd)
863
- fd, outfile = tempfile.mkstemp(suffix='.wav')
864
- os.close(fd)
865
-
866
- sf.write(infile, y, sr)
867
-
868
- try:
869
- arguments = [os.path.join(BASE_PATH_RUB, 'rubberband'), '-q']
870
- for key, value in six.iteritems(kwargs):
871
- arguments.append(str(key))
872
- arguments.append(str(value))
873
-
874
- arguments.extend([infile, outfile])
875
- subprocess.check_call(arguments, stdout=DEVNULL, stderr=DEVNULL)
876
-
877
- y_out, _ = sf.read(outfile, always_2d=True)
878
- if y.ndim == 1: y_out = np.squeeze(y_out)
879
- except OSError as exc:
880
- six.raise_from(RuntimeError(translations["rubberband"]), exc)
881
- finally:
882
- os.unlink(infile)
883
- os.unlink(outfile)
884
-
885
- return y_out
886
-
887
- def time_stretch(y, sr, rate, rbargs=None):
888
- if rate <= 0: raise ValueError(translations["rate"])
889
- if rate == 1.0: return y
890
- if rbargs is None: rbargs = dict()
891
-
892
- rbargs.setdefault('--tempo', rate)
893
- return __rubberband(y, sr, **rbargs)
894
-
895
- def pitch_shift(y, sr, n_steps, rbargs=None):
896
- if n_steps == 0: return y
897
- if rbargs is None: rbargs = dict()
898
-
899
- rbargs.setdefault('--pitch', n_steps)
900
- return __rubberband(y, sr, **rbargs)
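A self-contained sketch of the "Min Spec" rule that `ensembling` applies above: at each time-frequency bin, keep whichever input spectrogram has the smaller magnitude. The tiny complex arrays below are placeholders, not real spectrograms:

```python
import numpy as np

a = np.array([[0.2 + 0.1j, 0.9 + 0.0j]])
b = np.array([[0.5 + 0.0j, 0.3 + 0.2j]])

# Same comparison as the MIN_SPEC branch of ensembling().
min_spec = np.where(np.abs(b) <= np.abs(a), b, a)
print(min_spec)   # first bin comes from a (|a| smaller), second bin from b
```
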
main/tools/edge_tts.py DELETED
@@ -1,180 +0,0 @@
1
- import re
2
- import ssl
3
- import json
4
- import time
5
- import uuid
6
- import codecs
7
- import certifi
8
- import aiohttp
9
-
10
- from io import TextIOWrapper
11
- from dataclasses import dataclass
12
- from contextlib import nullcontext
13
- from xml.sax.saxutils import escape
14
-
15
-
16
- @dataclass
17
- class TTSConfig:
18
- def __init__(self, voice, rate, volume, pitch):
19
- self.voice = voice
20
- self.rate = rate
21
- self.volume = volume
22
- self.pitch = pitch
23
-
24
- @staticmethod
25
- def validate_string_param(param_name, param_value, pattern):
26
- if re.match(pattern, param_value) is None: raise ValueError(f"{param_name} '{param_value}'.")
27
- return param_value
28
-
29
- def __post_init__(self):
30
- match = re.match(r"^([a-z]{2,})-([A-Z]{2,})-(.+Neural)$", self.voice)
31
- if match is not None:
32
- region = match.group(2)
33
- name = match.group(3)
34
-
35
- if name.find("-") != -1:
36
- region = region + "-" + name[: name.find("-")]
37
- name = name[name.find("-") + 1 :]
38
-
39
- self.voice = ("Microsoft Server Speech Text to Speech Voice" + f" ({match.group(1)}-{region}, {name})")
40
-
41
- self.validate_string_param("voice", self.voice, r"^Microsoft Server Speech Text to Speech Voice \(.+,.+\)$")
42
- self.validate_string_param("rate", self.rate, r"^[+-]\d+%$")
43
- self.validate_string_param("volume", self.volume, r"^[+-]\d+%$")
44
- self.validate_string_param("pitch", self.pitch, r"^[+-]\d+Hz$")
45
-
46
- def get_headers_and_data(data, header_length):
47
- headers = {}
48
-
49
- for line in data[:header_length].split(b"\r\n"):
50
- key, value = line.split(b":", 1)
51
- headers[key] = value
52
-
53
- return headers, data[header_length + 2 :]
54
-
55
- def date_to_string():
56
- return time.strftime("%a %b %d %Y %H:%M:%S GMT+0000 (Coordinated Universal Time)", time.gmtime())
57
-
58
- def mkssml(tc, escaped_text):
59
- if isinstance(escaped_text, bytes): escaped_text = escaped_text.decode("utf-8")
60
- return (f"<speak version='1.0' xmlns='{codecs.decode('uggc://jjj.j3.bet/2001/10/flagurfvf', 'rot13')}' xml:lang='en-US'>" f"<voice name='{tc.voice}'>" f"<prosody pitch='{tc.pitch}' rate='{tc.rate}' volume='{tc.volume}'>" f"{escaped_text}" "</prosody>" "</voice>" "</speak>")
61
-
62
- def connect_id():
63
- return str(uuid.uuid4()).replace("-", "")
64
-
65
- def ssml_headers_plus_data(request_id, timestamp, ssml):
66
- return (f"X-RequestId:{request_id}\r\n" "Content-Type:application/ssml+xml\r\n" f"X-Timestamp:{timestamp}Z\r\n" "Path:ssml\r\n\r\n" f"{ssml}")
67
-
68
- def remove_incompatible_characters(string):
69
- if isinstance(string, bytes): string = string.decode("utf-8")
70
- chars = list(string)
71
-
72
- for idx, char in enumerate(chars):
73
- code = ord(char)
74
- if (0 <= code <= 8) or (11 <= code <= 12) or (14 <= code <= 31): chars[idx] = " "
75
-
76
- return "".join(chars)
77
-
78
- def split_text_by_byte_length(text, byte_length):
79
- if isinstance(text, str): text = text.encode("utf-8")
80
- if byte_length <= 0: raise ValueError("byte_length > 0")
81
-
82
- while len(text) > byte_length:
83
- split_at = text.rfind(b" ", 0, byte_length)
84
- split_at = split_at if split_at != -1 else byte_length
85
-
86
- while b"&" in text[:split_at]:
87
- ampersand_index = text.rindex(b"&", 0, split_at)
88
- if text.find(b";", ampersand_index, split_at) != -1: break
89
-
90
- split_at = ampersand_index - 1
91
- if split_at == 0: break
92
-
93
- new_text = text[:split_at].strip()
94
-
95
- if new_text: yield new_text
96
- if split_at == 0: split_at = 1
97
-
98
- text = text[split_at:]
99
-
100
- new_text = text.strip()
101
- if new_text: yield new_text
102
-
103
- class Communicate:
104
- def __init__(self, text, voice, *, rate="+0%", volume="+0%", pitch="+0Hz", proxy=None, connect_timeout=10, receive_timeout=60):
105
- self.tts_config = TTSConfig(voice, rate, volume, pitch)
106
- self.texts = split_text_by_byte_length(escape(remove_incompatible_characters(text)), 2**16 - (len(ssml_headers_plus_data(connect_id(), date_to_string(), mkssml(self.tts_config, ""))) + 50))
107
- self.proxy = proxy
108
- self.session_timeout = aiohttp.ClientTimeout(total=None, connect=None, sock_connect=connect_timeout, sock_read=receive_timeout)
109
- self.state = {"partial_text": None, "offset_compensation": 0, "last_duration_offset": 0, "stream_was_called": False}
110
-
111
- def __parse_metadata(self, data):
112
- for meta_obj in json.loads(data)["Metadata"]:
113
- meta_type = meta_obj["Type"]
114
- if meta_type == "WordBoundary": return {"type": meta_type, "offset": (meta_obj["Data"]["Offset"] + self.state["offset_compensation"]), "duration": meta_obj["Data"]["Duration"], "text": meta_obj["Data"]["text"]["Text"]}
115
- if meta_type in ("SessionEnd",): continue
116
-
117
- async def __stream(self):
118
- async def send_command_request():
119
- await websocket.send_str(f"X-Timestamp:{date_to_string()}\r\n" "Content-Type:application/json; charset=utf-8\r\n" "Path:speech.config\r\n\r\n" '{"context":{"synthesis":{"audio":{"metadataoptions":{' '"sentenceBoundaryEnabled":false,"wordBoundaryEnabled":true},' '"outputFormat":"audio-24khz-48kbitrate-mono-mp3"' "}}}}\r\n")
120
-
121
- async def send_ssml_request():
122
- await websocket.send_str(ssml_headers_plus_data(connect_id(), date_to_string(), mkssml(self.tts_config, self.state["partial_text"])))
123
-
124
- audio_was_received = False
125
- ssl_ctx = ssl.create_default_context(cafile=certifi.where())
126
-
127
- async with aiohttp.ClientSession(trust_env=True, timeout=self.session_timeout) as session, session.ws_connect(f"wss://speech.platform.bing.com/consumer/speech/synthesize/readaloud/edge/v1?TrustedClientToken=6A5AA1D4EAFF4E9FB37E23D68491D6F4&ConnectionId={connect_id()}", compress=15, proxy=self.proxy, headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36" " (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36" " Edg/130.0.0.0", "Accept-Encoding": "gzip, deflate, br", "Accept-Language": "en-US,en;q=0.9", "Pragma": "no-cache", "Cache-Control": "no-cache", "Origin": "chrome-extension://jdiccldimpdaibmpdkjnbmckianbfold"}, ssl=ssl_ctx) as websocket:
128
- await send_command_request()
129
- await send_ssml_request()
130
-
131
- async for received in websocket:
132
- if received.type == aiohttp.WSMsgType.TEXT:
133
- encoded_data: bytes = received.data.encode("utf-8")
134
- parameters, data = get_headers_and_data(encoded_data, encoded_data.find(b"\r\n\r\n"))
135
- path = parameters.get(b"Path", None)
136
-
137
- if path == b"audio.metadata":
138
- parsed_metadata = self.__parse_metadata(data)
139
- yield parsed_metadata
140
- self.state["last_duration_offset"] = (parsed_metadata["offset"] + parsed_metadata["duration"])
141
- elif path == b"turn.end":
142
- self.state["offset_compensation"] = self.state["last_duration_offset"]
143
- self.state["offset_compensation"] += 8_750_000
144
- break
145
- elif received.type == aiohttp.WSMsgType.BINARY:
146
- if len(received.data) < 2: raise Exception("received.data < 2")
147
-
148
- header_length = int.from_bytes(received.data[:2], "big")
149
- if header_length > len(received.data): raise Exception("header_length > received.data")
150
-
151
- parameters, data = get_headers_and_data(received.data, header_length)
152
- if parameters.get(b"Path") != b"audio": raise Exception("Path != audio")
153
-
154
- content_type = parameters.get(b"Content-Type", None)
155
- if content_type not in [b"audio/mpeg", None]: raise Exception("content_type != audio/mpeg")
156
-
157
- if content_type is None and len(data) == 0: continue
158
-
159
- if len(data) == 0: raise Exception("data = 0")
160
- audio_was_received = True
161
- yield {"type": "audio", "data": data}
162
-
163
- if not audio_was_received: raise Exception("!audio_was_received")
164
-
165
- async def stream(self):
166
- if self.state["stream_was_called"]: raise RuntimeError("stream_was_called")
167
- self.state["stream_was_called"] = True
168
-
169
- for self.state["partial_text"] in self.texts:
170
- async for message in self.__stream():
171
- yield message
172
-
173
- async def save(self, audio_fname, metadata_fname = None):
174
- metadata = (open(metadata_fname, "w", encoding="utf-8") if metadata_fname is not None else nullcontext())
175
- with metadata, open(audio_fname, "wb") as audio:
176
- async for message in self.stream():
177
- if message["type"] == "audio": audio.write(message["data"])
178
- elif (isinstance(metadata, TextIOWrapper) and message["type"] == "WordBoundary"):
179
- json.dump(message, metadata)
180
- metadata.write("\n")
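Hypothetical usage of the `Communicate` class above; the voice name, text and output filenames are placeholders only, and the call needs network access to the Edge endpoint plus the `aiohttp`/`certifi` dependencies it imports:

```python
import asyncio

async def main():
    # Synthesize a short phrase and write MP3 audio plus word-boundary metadata.
    tts = Communicate("Hello from edge_tts", "en-US-AriaNeural")
    await tts.save("hello.mp3", metadata_fname="hello.jsonl")

asyncio.run(main())
```
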
main/tools/gdown.py DELETED
@@ -1,110 +0,0 @@
- import os
- import re
- import sys
- import json
- import tqdm
- import codecs
- import tempfile
- import requests
-
- from urllib.parse import urlparse, parse_qs, unquote
-
- sys.path.append(os.getcwd())
-
- from main.configs.config import Config
- translations = Config().translations
-
- def parse_url(url):
-     parsed = urlparse(url)
-     is_download_link = parsed.path.endswith("/uc")
-     if not parsed.hostname in ("drive.google.com", "docs.google.com"): return None, is_download_link
-     file_id = parse_qs(parsed.query).get("id", [None])[0]
-
-     if file_id is None:
-         for pattern in (r"^/file/d/(.*?)/(edit|view)$", r"^/file/u/[0-9]+/d/(.*?)/(edit|view)$", r"^/document/d/(.*?)/(edit|htmlview|view)$", r"^/document/u/[0-9]+/d/(.*?)/(edit|htmlview|view)$", r"^/presentation/d/(.*?)/(edit|htmlview|view)$", r"^/presentation/u/[0-9]+/d/(.*?)/(edit|htmlview|view)$", r"^/spreadsheets/d/(.*?)/(edit|htmlview|view)$", r"^/spreadsheets/u/[0-9]+/d/(.*?)/(edit|htmlview|view)$"):
-             match = re.match(pattern, parsed.path)
-             if match:
-                 file_id = match.group(1)
-                 break
-     return file_id, is_download_link
-
- def get_url_from_gdrive_confirmation(contents):
-     for pattern in (r'href="(\/uc\?export=download[^"]+)', r'href="/open\?id=([^"]+)"', r'"downloadUrl":"([^"]+)'):
-         match = re.search(pattern, contents)
-         if match:
-             url = match.group(1)
-             if pattern == r'href="/open\?id=([^"]+)"': url = (codecs.decode("uggcf://qevir.hfrepbagrag.tbbtyr.pbz/qbjaybnq?vq=", "rot13") + url + "&confirm=t&uuid=" + re.search(r'<input\s+type="hidden"\s+name="uuid"\s+value="([^"]+)"', contents).group(1))
-             elif pattern == r'"downloadUrl":"([^"]+)': url = url.replace("\\u003d", "=").replace("\\u0026", "&")
-             else: url = codecs.decode("uggcf://qbpf.tbbtyr.pbz", "rot13") + url.replace("&amp;", "&")
-             return url
-
-     match = re.search(r'<p class="uc-error-subcaption">(.*)</p>', contents)
-     if match: raise Exception(match.group(1))
-     raise Exception(translations["gdown_error"])
-
- def _get_session(use_cookies, return_cookies_file=False):
-     sess = requests.session()
-     sess.headers.update({"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6)"})
-     cookies_file = os.path.join(os.path.expanduser("~"), ".cache/gdown/cookies.json")
-
-     if os.path.exists(cookies_file) and use_cookies:
-         with open(cookies_file) as f:
-             for k, v in json.load(f):
-                 sess.cookies[k] = v
-     return (sess, cookies_file) if return_cookies_file else sess
-
- def gdown_download(url=None, id=None, output=None):
-     if not (id is None) ^ (url is None): raise ValueError(translations["gdown_value_error"])
-     if id is not None: url = f"{codecs.decode('uggcf://qevir.tbbtyr.pbz/hp?vq=', 'rot13')}{id}"
-
-     url_origin = url
-     sess, cookies_file = _get_session(use_cookies=True, return_cookies_file=True)
-     gdrive_file_id, is_gdrive_download_link = parse_url(url)
-
-     if gdrive_file_id:
-         url = f"{codecs.decode('uggcf://qevir.tbbtyr.pbz/hp?vq=', 'rot13')}{gdrive_file_id}"
-         url_origin = url
-         is_gdrive_download_link = True
-
-     while 1:
-         res = sess.get(url, stream=True, verify=True)
-         if url == url_origin and res.status_code == 500:
-             url = f"{codecs.decode('uggcf://qevir.tbbtyr.pbz/bcra?vq=', 'rot13')}{gdrive_file_id}"
-             continue
-
-         os.makedirs(os.path.dirname(cookies_file), exist_ok=True)
-         with open(cookies_file, "w") as f:
-             json.dump([(k, v) for k, v in sess.cookies.items() if not k.startswith("download_warning_")], f, indent=2)
-
-         if "Content-Disposition" in res.headers: break
-         if not (gdrive_file_id and is_gdrive_download_link): break
-
-         try:
-             url = get_url_from_gdrive_confirmation(res.text)
-         except Exception as e:
-             raise Exception(e)
-
-     if gdrive_file_id and is_gdrive_download_link:
-         content_disposition = unquote(res.headers["Content-Disposition"])
-         filename_from_url = (re.search(r"filename\*=UTF-8''(.*)", content_disposition) or re.search(r'filename=["\']?(.*?)["\']?$', content_disposition)).group(1).replace(os.path.sep, "_")
-     else: filename_from_url = os.path.basename(url)
-
-     output = os.path.join(output or ".", filename_from_url)
-     tmp_file = tempfile.mktemp(suffix=tempfile.template, prefix=os.path.basename(output), dir=os.path.dirname(output))
-     f = open(tmp_file, "ab")
-
-     if tmp_file is not None and f.tell() != 0: res = sess.get(url, headers={"Range": f"bytes={f.tell()}-"}, stream=True, verify=True)
-     print(translations["to"], os.path.abspath(output), file=sys.stderr)
-
-     try:
-         with tqdm.tqdm(total=int(res.headers.get("Content-Length", 0)), ncols=100, unit="byte") as pbar:
-             for chunk in res.iter_content(chunk_size=512 * 1024):
-                 f.write(chunk)
-                 pbar.update(len(chunk))
-
-             pbar.close()
-             if tmp_file: f.close()
-     finally:
-         os.rename(tmp_file, output)
-         sess.close()
-     return output
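
For reference, a minimal usage sketch of the deleted gdown_download helper, based only on the signature above; the Drive file id and output directory are placeholders:

from main.tools.gdown import gdown_download
# Exactly one of url= or id= must be given; the helper returns the saved file path.
saved_path = gdown_download(id="YOUR_DRIVE_FILE_ID", output="assets/models")
print(saved_path)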
 
main/tools/google_tts.py DELETED
@@ -1,30 +0,0 @@
- import os
- import codecs
- import librosa
- import requests
-
- import soundfile as sf
-
- def google_tts(text, lang="vi", speed=1, pitch=0, output_file="output.mp3"):
-     try:
-         response = requests.get(codecs.decode("uggcf://genafyngr.tbbtyr.pbz/genafyngr_ggf", "rot13"), params={"ie": "UTF-8", "q": text, "tl": lang, "ttsspeed": speed, "client": "tw-ob"}, headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36"})
-
-         if response.status_code == 200:
-             with open(output_file, "wb") as f:
-                 f.write(response.content)
-
-             format = os.path.splitext(os.path.basename(output_file))[-1].lower().replace('.', '')
-
-             if pitch != 0: pitch_shift(input_file=output_file, output_file=output_file, pitch=pitch, export_format=format)
-             if speed != 1: change_speed(input_file=output_file, output_file=output_file, speed=speed, export_format=format)
-         else: raise ValueError(f"{response.status_code}, {response.text}")
-     except Exception as e:
-         raise RuntimeError(e)
-
- def pitch_shift(input_file, output_file, pitch, export_format):
-     y, sr = librosa.load(input_file, sr=None)
-     sf.write(file=output_file, data=librosa.effects.pitch_shift(y, sr=sr, n_steps=pitch), samplerate=sr, format=export_format)
-
- def change_speed(input_file, output_file, speed, export_format):
-     y, sr = librosa.load(input_file, sr=None)
-     sf.write(file=output_file, data=librosa.effects.time_stretch(y, rate=speed), samplerate=sr, format=export_format)
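
For reference, a minimal usage sketch of the deleted google_tts helper; the text and output path are placeholders, and pitch/speed are applied as the librosa post-processing steps shown above:

from main.tools.google_tts import google_tts
# Synthesize Vietnamese speech, shift it up two semitones and speed it up slightly.
google_tts("Xin chào", lang="vi", speed=1.2, pitch=2, output_file="hello.mp3")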
 
main/tools/huggingface.py DELETED
@@ -1,24 +0,0 @@
- import os
- import requests
- import tqdm
-
-
- def HF_download_file(url, output_path=None):
-     url = url.replace("/blob/", "/resolve/").replace("?download=true", "").strip()
-
-     if output_path is None: output_path = os.path.basename(url)
-     else: output_path = os.path.join(output_path, os.path.basename(url)) if os.path.isdir(output_path) else output_path
-
-     response = requests.get(url, stream=True, timeout=300)
-
-     if response.status_code == 200:
-         progress_bar = tqdm.tqdm(total=int(response.headers.get("content-length", 0)), ncols=100, unit="byte")
-
-         with open(output_path, "wb") as f:
-             for chunk in response.iter_content(chunk_size=8192):
-                 progress_bar.update(len(chunk))
-                 f.write(chunk)
-
-         progress_bar.close()
-         return output_path
-     else: raise ValueError(response.status_code)
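
For reference, a minimal usage sketch of the deleted HF_download_file helper; the repository URL is a placeholder, and /blob/ links are rewritten to /resolve/ links as shown above:

from main.tools.huggingface import HF_download_file
# Accepts either a /blob/ or /resolve/ URL; returns the local output path.
local_path = HF_download_file("https://huggingface.co/ORG/REPO/blob/main/model.pth", "assets/models")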
 
main/tools/mediafire.py DELETED
@@ -1,30 +0,0 @@
- import os
- import sys
- import requests
- from bs4 import BeautifulSoup
-
-
- def Mediafire_Download(url, output=None, filename=None):
-     if not filename: filename = url.split('/')[-2]
-     if not output: output = os.path.dirname(os.path.realpath(__file__))
-     output_file = os.path.join(output, filename)
-
-     sess = requests.session()
-     sess.headers.update({"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6)"})
-
-     try:
-         with requests.get(BeautifulSoup(sess.get(url).content, "html.parser").find(id="downloadButton").get("href"), stream=True) as r:
-             r.raise_for_status()
-             with open(output_file, "wb") as f:
-                 total_length = int(r.headers.get('content-length'))
-                 download_progress = 0
-
-                 for chunk in r.iter_content(chunk_size=1024):
-                     download_progress += len(chunk)
-                     f.write(chunk)
-                     sys.stdout.write(f"\r[{filename}]: {int(100 * download_progress/total_length)}% ({round(download_progress/1024/1024, 2)}mb/{round(total_length/1024/1024, 2)}mb)")
-                     sys.stdout.flush()
-                 sys.stdout.write("\n")
-         return output_file
-     except Exception as e:
-         raise RuntimeError(e)
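
For reference, a minimal usage sketch of the deleted Mediafire_Download helper; the link is a placeholder, and when filename is omitted it is inferred from the URL as shown above:

from main.tools.mediafire import Mediafire_Download
# Scrapes the download-button href and streams the file with a simple progress line.
output_file = Mediafire_Download("https://www.mediafire.com/file/EXAMPLE_ID/model.zip/file", output="assets/models")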
 
main/tools/meganz.py DELETED
@@ -1,160 +0,0 @@
- import os
- import re
- import sys
- import json
- import tqdm
- import time
- import codecs
- import random
- import base64
- import struct
- import shutil
- import requests
- import tempfile
-
- from Crypto.Cipher import AES
- from Crypto.Util import Counter
-
- sys.path.append(os.getcwd())
-
- from main.configs.config import Config
- translations = Config().translations
-
- def makebyte(x):
-     return codecs.latin_1_encode(x)[0]
-
- def a32_to_str(a):
-     return struct.pack('>%dI' % len(a), *a)
-
- def get_chunks(size):
-     p, s = 0, 0x20000
-
-     while p + s < size:
-         yield (p, s)
-         p += s
-
-         if s < 0x100000: s += 0x20000
-
-     yield (p, size - p)
-
- def decrypt_attr(attr, key):
-     attr = codecs.latin_1_decode(AES.new(a32_to_str(key), AES.MODE_CBC, makebyte('\0' * 16)).decrypt(attr))[0].rstrip('\0')
-
-     return json.loads(attr[4:]) if attr[:6] == 'MEGA{"' else False
-
- def _api_request(data):
-     sequence_num = random.randint(0, 0xFFFFFFFF)
-     params = {'id': sequence_num}
-     sequence_num += 1
-
-     if not isinstance(data, list): data = [data]
-
-     for attempt in range(60):
-         try:
-             json_resp = json.loads(requests.post(f'https://g.api.mega.co.nz/cs', params=params, data=json.dumps(data), timeout=160).text)
-
-             try:
-                 if isinstance(json_resp, list): int_resp = json_resp[0] if isinstance(json_resp[0], int) else None
-                 elif isinstance(json_resp, int): int_resp = json_resp
-             except IndexError:
-                 int_resp = None
-
-             if int_resp is not None:
-                 if int_resp == 0: return int_resp
-                 if int_resp == -3: raise RuntimeError('int_resp==-3')
-                 raise Exception(int_resp)
-
-             return json_resp[0]
-         except (RuntimeError, requests.exceptions.RequestException):
-             if attempt == 60 - 1: raise
-             delay = 2 * (2 ** attempt)
-             time.sleep(delay)
-
- def base64_url_decode(data):
-     data += '=='[(2 - len(data) * 3) % 4:]
-
-     for search, replace in (('-', '+'), ('_', '/'), (',', '')):
-         data = data.replace(search, replace)
-
-     return base64.b64decode(data)
-
- def str_to_a32(b):
-     if isinstance(b, str): b = makebyte(b)
-     if len(b) % 4: b += b'\0' * (4 - len(b) % 4)
-
-     return struct.unpack('>%dI' % (len(b) / 4), b)
-
- def mega_download_file(file_handle, file_key, dest_path=None, dest_filename=None, file=None):
-     if file is None:
-         file_key = str_to_a32(base64_url_decode(file_key))
-         file_data = _api_request({'a': 'g', 'g': 1, 'p': file_handle})
-
-         k = (file_key[0] ^ file_key[4], file_key[1] ^ file_key[5], file_key[2] ^ file_key[6], file_key[3] ^ file_key[7])
-         iv = file_key[4:6] + (0, 0)
-         meta_mac = file_key[6:8]
-     else:
-         file_data = _api_request({'a': 'g', 'g': 1, 'n': file['h']})
-         k = file['k']
-         iv = file['iv']
-         meta_mac = file['meta_mac']
-
-     if 'g' not in file_data: raise Exception(translations["file_not_access"])
-     file_size = file_data['s']
-
-     attribs = decrypt_attr(base64_url_decode(file_data['at']), k)
-
-     file_name = dest_filename if dest_filename is not None else attribs['n']
-     input_file = requests.get(file_data['g'], stream=True).raw
-
-     if dest_path is None: dest_path = ''
-     else: dest_path += '/'
-
-     temp_output_file = tempfile.NamedTemporaryFile(mode='w+b', prefix='megapy_', delete=False)
-     k_str = a32_to_str(k)
-
-     aes = AES.new(k_str, AES.MODE_CTR, counter=Counter.new(128, initial_value=((iv[0] << 32) + iv[1]) << 64))
-     mac_str = b'\0' * 16
-     mac_encryptor = AES.new(k_str, AES.MODE_CBC, mac_str)
-
-     iv_str = a32_to_str([iv[0], iv[1], iv[0], iv[1]])
-     pbar = tqdm.tqdm(total=file_size, ncols=100, unit="byte")
-
-     for _, chunk_size in get_chunks(file_size):
-         chunk = aes.decrypt(input_file.read(chunk_size))
-         temp_output_file.write(chunk)
-
-         pbar.update(len(chunk))
-         encryptor = AES.new(k_str, AES.MODE_CBC, iv_str)
-
-         for i in range(0, len(chunk)-16, 16):
-             block = chunk[i:i + 16]
-             encryptor.encrypt(block)
-
-         if file_size > 16: i += 16
-         else: i = 0
-
-         block = chunk[i:i + 16]
-         if len(block) % 16: block += b'\0' * (16 - (len(block) % 16))
-
-         mac_str = mac_encryptor.encrypt(encryptor.encrypt(block))
-
-     file_mac = str_to_a32(mac_str)
-     temp_output_file.close()
-
-     if (file_mac[0] ^ file_mac[1], file_mac[2] ^ file_mac[3]) != meta_mac: raise ValueError(translations["mac_not_match"])
-
-     file_path = os.path.join(dest_path, file_name)
-     if os.path.exists(file_path): os.remove(file_path)
-
-     shutil.move(temp_output_file.name, file_path)
-
- def mega_download_url(url, dest_path=None, dest_filename=None):
-     if '/file/' in url:
-         url = url.replace(' ', '')
-         file_id = re.findall(r'\W\w\w\w\w\w\w\w\w\W', url)[0][1:-1]
-
-         path = f'{file_id}!{url[re.search(file_id, url).end() + 1:]}'.split('!')
-     elif '!' in url: path = re.findall(r'/#!(.*)', url)[0].split('!')
-     else: raise Exception(translations["missing_url"])
-
-     return mega_download_file(file_handle=path[0], file_key=path[1], dest_path=dest_path, dest_filename=dest_filename)
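
For reference, a minimal usage sketch of the deleted mega_download_url helper; the link is a placeholder, and both the /file/<id>#<key> and legacy #! link formats handled above are accepted:

from main.tools.meganz import mega_download_url
# Downloads, AES-CTR decrypts and MAC-checks the file, then moves it into dest_path.
mega_download_url("https://mega.nz/file/EXAMPLE_ID#EXAMPLE_KEY", dest_path="assets/models")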
 
main/tools/noisereduce.py DELETED
@@ -1,200 +0,0 @@
- import torch
- import tempfile
- import numpy as np
-
- from tqdm.auto import tqdm
- from joblib import Parallel, delayed
- from torch.nn.functional import conv1d, conv2d
-
- from main.configs.config import Config
- translations = Config().translations
-
- @torch.no_grad()
- def amp_to_db(x, eps = torch.finfo(torch.float32).eps, top_db = 40):
-     x_db = 20 * torch.log10(x.abs() + eps)
-     return torch.max(x_db, (x_db.max(-1).values - top_db).unsqueeze(-1))
-
- @torch.no_grad()
- def temperature_sigmoid(x, x0, temp_coeff):
-     return torch.sigmoid((x - x0) / temp_coeff)
-
- @torch.no_grad()
- def linspace(start, stop, num = 50, endpoint = True, **kwargs):
-     return torch.linspace(start, stop, num, **kwargs) if endpoint else torch.linspace(start, stop, num + 1, **kwargs)[:-1]
-
- def _smoothing_filter(n_grad_freq, n_grad_time):
-     smoothing_filter = np.outer(np.concatenate([np.linspace(0, 1, n_grad_freq + 1, endpoint=False), np.linspace(1, 0, n_grad_freq + 2)])[1:-1], np.concatenate([np.linspace(0, 1, n_grad_time + 1, endpoint=False), np.linspace(1, 0, n_grad_time + 2)])[1:-1])
-     return smoothing_filter / np.sum(smoothing_filter)
-
- class SpectralGate:
-     def __init__(self, y, sr, prop_decrease, chunk_size, padding, n_fft, win_length, hop_length, time_constant_s, freq_mask_smooth_hz, time_mask_smooth_ms, tmp_folder, use_tqdm, n_jobs):
-         self.sr = sr
-         self.flat = False
-         y = np.array(y)
-
-         if len(y.shape) == 1:
-             self.y = np.expand_dims(y, 0)
-             self.flat = True
-         elif len(y.shape) > 2: raise ValueError(translations["waveform"])
-         else: self.y = y
-
-         self._dtype = y.dtype
-         self.n_channels, self.n_frames = self.y.shape
-         self._chunk_size = chunk_size
-         self.padding = padding
-         self.n_jobs = n_jobs
-         self.use_tqdm = use_tqdm
-         self._tmp_folder = tmp_folder
-         self._n_fft = n_fft
-         self._win_length = self._n_fft if win_length is None else win_length
-         self._hop_length = (self._win_length // 4) if hop_length is None else hop_length
-         self._time_constant_s = time_constant_s
-         self._prop_decrease = prop_decrease
-
-         if (freq_mask_smooth_hz is None) & (time_mask_smooth_ms is None): self.smooth_mask = False
-         else: self._generate_mask_smoothing_filter(freq_mask_smooth_hz, time_mask_smooth_ms)
-
-     def _generate_mask_smoothing_filter(self, freq_mask_smooth_hz, time_mask_smooth_ms):
-         if freq_mask_smooth_hz is None: n_grad_freq = 1
-         else:
-             n_grad_freq = int(freq_mask_smooth_hz / (self.sr / (self._n_fft / 2)))
-             if n_grad_freq < 1: raise ValueError(translations["freq_mask_smooth_hz"].format(hz=int((self.sr / (self._n_fft / 2)))))
-
-         if time_mask_smooth_ms is None: n_grad_time = 1
-         else:
-             n_grad_time = int(time_mask_smooth_ms / ((self._hop_length / self.sr) * 1000))
-             if n_grad_time < 1: raise ValueError(translations["time_mask_smooth_ms"].format(ms=int((self._hop_length / self.sr) * 1000)))
-
-         if (n_grad_time == 1) & (n_grad_freq == 1): self.smooth_mask = False
-         else:
-             self.smooth_mask = True
-             self._smoothing_filter = _smoothing_filter(n_grad_freq, n_grad_time)
-
-     def _read_chunk(self, i1, i2):
-         i1b = 0 if i1 < 0 else i1
-         i2b = self.n_frames if i2 > self.n_frames else i2
-         chunk = np.zeros((self.n_channels, i2 - i1))
-         chunk[:, i1b - i1: i2b - i1] = self.y[:, i1b:i2b]
-         return chunk
-
-     def filter_chunk(self, start_frame, end_frame):
-         i1 = start_frame - self.padding
-         return self._do_filter(self._read_chunk(i1, (end_frame + self.padding)))[:, start_frame - i1: end_frame - i1]
-
-     def _get_filtered_chunk(self, ind):
-         start0 = ind * self._chunk_size
-         end0 = (ind + 1) * self._chunk_size
-         return self.filter_chunk(start_frame=start0, end_frame=end0)
-
-     def _do_filter(self, chunk):
-         pass
-
-     def _iterate_chunk(self, filtered_chunk, pos, end0, start0, ich):
-         filtered_chunk[:, pos: pos + end0 - start0] = self._get_filtered_chunk(ich)[:, start0:end0]
-         pos += end0 - start0
-
-     def get_traces(self, start_frame=None, end_frame=None):
-         if start_frame is None: start_frame = 0
-         if end_frame is None: end_frame = self.n_frames
-
-         if self._chunk_size is not None:
-             if end_frame - start_frame > self._chunk_size:
-                 ich1 = int(start_frame / self._chunk_size)
-                 ich2 = int((end_frame - 1) / self._chunk_size)
-
-                 with tempfile.NamedTemporaryFile(prefix=self._tmp_folder) as fp:
-                     filtered_chunk = np.memmap(fp, dtype=self._dtype, shape=(self.n_channels, int(end_frame - start_frame)), mode="w+")
-                     pos_list, start_list, end_list = [], [], []
-                     pos = 0
-
-                     for ich in range(ich1, ich2 + 1):
-                         start0 = (start_frame - ich * self._chunk_size) if ich == ich1 else 0
-                         end0 = end_frame - ich * self._chunk_size if ich == ich2 else self._chunk_size
-                         pos_list.append(pos)
-                         start_list.append(start0)
-                         end_list.append(end0)
-                         pos += end0 - start0
-
-                     Parallel(n_jobs=self.n_jobs)(delayed(self._iterate_chunk)(filtered_chunk, pos, end0, start0, ich) for pos, start0, end0, ich in zip(tqdm(pos_list, disable=not (self.use_tqdm)), start_list, end_list, range(ich1, ich2 + 1)))
-                     return filtered_chunk.astype(self._dtype).flatten() if self.flat else filtered_chunk.astype(self._dtype)
-
-         filtered_chunk = self.filter_chunk(start_frame=0, end_frame=end_frame)
-         return filtered_chunk.astype(self._dtype).flatten() if self.flat else filtered_chunk.astype(self._dtype)
-
- class TG(torch.nn.Module):
-     @torch.no_grad()
-     def __init__(self, sr, nonstationary = False, n_std_thresh_stationary = 1.5, n_thresh_nonstationary = 1.3, temp_coeff_nonstationary = 0.1, n_movemean_nonstationary = 20, prop_decrease = 1.0, n_fft = 1024, win_length = None, hop_length = None, freq_mask_smooth_hz = 500, time_mask_smooth_ms = 50):
-         super().__init__()
-         self.sr = sr
-         self.nonstationary = nonstationary
-         assert 0.0 <= prop_decrease <= 1.0
-         self.prop_decrease = prop_decrease
-         self.n_fft = n_fft
-         self.win_length = self.n_fft if win_length is None else win_length
-         self.hop_length = self.win_length // 4 if hop_length is None else hop_length
-         self.n_std_thresh_stationary = n_std_thresh_stationary
-         self.temp_coeff_nonstationary = temp_coeff_nonstationary
-         self.n_movemean_nonstationary = n_movemean_nonstationary
-         self.n_thresh_nonstationary = n_thresh_nonstationary
-         self.freq_mask_smooth_hz = freq_mask_smooth_hz
-         self.time_mask_smooth_ms = time_mask_smooth_ms
-         self.register_buffer("smoothing_filter", self._generate_mask_smoothing_filter())
-
-     @torch.no_grad()
-     def _generate_mask_smoothing_filter(self):
-         if self.freq_mask_smooth_hz is None and self.time_mask_smooth_ms is None: return None
-         n_grad_freq = (1 if self.freq_mask_smooth_hz is None else int(self.freq_mask_smooth_hz / (self.sr / (self.n_fft / 2))))
-         if n_grad_freq < 1: raise ValueError(translations["freq_mask_smooth_hz"].format(hz=int((self.sr / (self._n_fft / 2)))))
-
-         n_grad_time = (1 if self.time_mask_smooth_ms is None else int(self.time_mask_smooth_ms / ((self.hop_length / self.sr) * 1000)))
-         if n_grad_time < 1: raise ValueError(translations["time_mask_smooth_ms"].format(ms=int((self._hop_length / self.sr) * 1000)))
-         if n_grad_time == 1 and n_grad_freq == 1: return None
-
-         smoothing_filter = torch.outer(torch.cat([linspace(0, 1, n_grad_freq + 1, endpoint=False), linspace(1, 0, n_grad_freq + 2)])[1:-1], torch.cat([linspace(0, 1, n_grad_time + 1, endpoint=False), linspace(1, 0, n_grad_time + 2)])[1:-1]).unsqueeze(0).unsqueeze(0)
-         return smoothing_filter / smoothing_filter.sum()
-
-     @torch.no_grad()
-     def _stationary_mask(self, X_db, xn = None):
-         XN_db = amp_to_db(torch.stft(xn, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, return_complex=True, pad_mode="constant", center=True, window=torch.hann_window(self.win_length).to(xn.device))).to(dtype=X_db.dtype) if xn is not None else X_db
-         std_freq_noise, mean_freq_noise = torch.std_mean(XN_db, dim=-1)
-         return torch.gt(X_db, (mean_freq_noise + std_freq_noise * self.n_std_thresh_stationary).unsqueeze(2))
-
-     @torch.no_grad()
-     def _nonstationary_mask(self, X_abs):
-         X_smoothed = (conv1d(X_abs.reshape(-1, 1, X_abs.shape[-1]), torch.ones(self.n_movemean_nonstationary, dtype=X_abs.dtype, device=X_abs.device).view(1, 1, -1), padding="same").view(X_abs.shape) / self.n_movemean_nonstationary)
-         return temperature_sigmoid(((X_abs - X_smoothed) / X_smoothed), self.n_thresh_nonstationary, self.temp_coeff_nonstationary)
-
-     def forward(self, x, xn = None):
-         assert x.ndim == 2
-         if x.shape[-1] < self.win_length * 2: raise Exception(f"{translations['x']} {self.win_length * 2}")
-         assert xn is None or xn.ndim == 1 or xn.ndim == 2
-         if xn is not None and xn.shape[-1] < self.win_length * 2: raise Exception(f"{translations['xn']} {self.win_length * 2}")
-
-         X = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, return_complex=True, pad_mode="constant", center=True, window=torch.hann_window(self.win_length).to(x.device))
-         sig_mask = self._nonstationary_mask(X.abs()) if self.nonstationary else self._stationary_mask(amp_to_db(X), xn)
-
-         sig_mask = self.prop_decrease * (sig_mask * 1.0 - 1.0) + 1.0
-         if self.smoothing_filter is not None: sig_mask = conv2d(sig_mask.unsqueeze(1), self.smoothing_filter.to(sig_mask.dtype), padding="same")
-
-         Y = X * sig_mask.squeeze(1)
-         return torch.istft(Y, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, center=True, window=torch.hann_window(self.win_length).to(Y.device)).to(dtype=x.dtype)
-
- class StreamedTorchGate(SpectralGate):
-     def __init__(self, y, sr, stationary=False, y_noise=None, prop_decrease=1.0, time_constant_s=2.0, freq_mask_smooth_hz=500, time_mask_smooth_ms=50, thresh_n_mult_nonstationary=2, sigmoid_slope_nonstationary=10, n_std_thresh_stationary=1.5, tmp_folder=None, chunk_size=600000, padding=30000, n_fft=1024, win_length=None, hop_length=None, clip_noise_stationary=True, use_tqdm=False, n_jobs=1, device="cpu"):
-         super().__init__(y=y, sr=sr, chunk_size=chunk_size, padding=padding, n_fft=n_fft, win_length=win_length, hop_length=hop_length, time_constant_s=time_constant_s, freq_mask_smooth_hz=freq_mask_smooth_hz, time_mask_smooth_ms=time_mask_smooth_ms, tmp_folder=tmp_folder, prop_decrease=prop_decrease, use_tqdm=use_tqdm, n_jobs=n_jobs)
-         self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
-
-         if y_noise is not None:
-             if y_noise.shape[-1] > y.shape[-1] and clip_noise_stationary: y_noise = y_noise[: y.shape[-1]]
-             y_noise = torch.from_numpy(y_noise).to(device)
-             if len(y_noise.shape) == 1: y_noise = y_noise.unsqueeze(0)
-
-         self.y_noise = y_noise
-         self.tg = TG(sr=sr, nonstationary=not stationary, n_std_thresh_stationary=n_std_thresh_stationary, n_thresh_nonstationary=thresh_n_mult_nonstationary, temp_coeff_nonstationary=1 / sigmoid_slope_nonstationary, n_movemean_nonstationary=int(time_constant_s / self._hop_length * sr), prop_decrease=prop_decrease, n_fft=self._n_fft, win_length=self._win_length, hop_length=self._hop_length, freq_mask_smooth_hz=freq_mask_smooth_hz, time_mask_smooth_ms=time_mask_smooth_ms).to(device)
-
-     def _do_filter(self, chunk):
-         if type(chunk) is np.ndarray: chunk = torch.from_numpy(chunk).to(self.device)
-         return self.tg(x=chunk, xn=self.y_noise).cpu().detach().numpy()
-
- def reduce_noise(y, sr, stationary=False, y_noise=None, prop_decrease=1.0, time_constant_s=2.0, freq_mask_smooth_hz=500, time_mask_smooth_ms=50, thresh_n_mult_nonstationary=2, sigmoid_slope_nonstationary=10, tmp_folder=None, chunk_size=600000, padding=30000, n_fft=1024, win_length=None, hop_length=None, clip_noise_stationary=True, use_tqdm=False, device="cpu"):
-     return StreamedTorchGate(y=y, sr=sr, stationary=stationary, y_noise=y_noise, prop_decrease=prop_decrease, time_constant_s=time_constant_s, freq_mask_smooth_hz=freq_mask_smooth_hz, time_mask_smooth_ms=time_mask_smooth_ms, thresh_n_mult_nonstationary=thresh_n_mult_nonstationary, sigmoid_slope_nonstationary=sigmoid_slope_nonstationary, tmp_folder=tmp_folder, chunk_size=chunk_size, padding=padding, n_fft=n_fft, win_length=win_length, hop_length=hop_length, clip_noise_stationary=clip_noise_stationary, use_tqdm=use_tqdm, n_jobs=1, device=device).get_traces()
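
For reference, a minimal usage sketch of the deleted reduce_noise entry point; the input path is a placeholder, and the call wraps StreamedTorchGate and returns the filtered signal via get_traces() as shown above:

import librosa
from main.tools.noisereduce import reduce_noise
y, sr = librosa.load("vocals.wav", sr=None)  # placeholder input file
cleaned = reduce_noise(y=y, sr=sr, stationary=False, prop_decrease=1.0, device="cpu")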
 
main/tools/pixeldrain.py DELETED
@@ -1,16 +0,0 @@
- import os
- import requests
-
- def pixeldrain(url, output_dir):
-     try:
-         response = requests.get(f"https://pixeldrain.com/api/file/{url.split('pixeldrain.com/u/')[1]}")
-
-         if response.status_code == 200:
-             file_path = os.path.join(output_dir, (response.headers.get("Content-Disposition").split("filename=")[-1].strip('";')))
-
-             with open(file_path, "wb") as newfile:
-                 newfile.write(response.content)
-             return file_path
-         else: return None
-     except Exception as e:
-         raise RuntimeError(e)
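
For reference, a minimal usage sketch of the deleted pixeldrain helper; the link and directory are placeholders, and the saved filename is taken from the Content-Disposition header as shown above:

from main.tools.pixeldrain import pixeldrain
file_path = pixeldrain("https://pixeldrain.com/u/EXAMPLE_ID", output_dir="assets/audios")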