Spaces:
Sleeping
Sleeping
Create mdx_processing_script.py
Browse files- mdx_processing_script.py +153 -0
mdx_processing_script.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import requests
|
3 |
+
import subprocess
|
4 |
+
import logging
|
5 |
+
import sys
|
6 |
+
from bs4 import BeautifulSoup
|
7 |
+
import torch, pdb, os, warnings, librosa
|
8 |
+
import soundfile as sf
|
9 |
+
from tqdm import tqdm
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
# Put the working directory on the import path before importing ``mdx``
# (presumably the module lives next to this script -- confirm).
now_dir = os.getcwd()
sys.path.append(now_dir)
import mdx

branch = "https://github.com/NaJeongMo/Colab-for-MDX_B"

# Remote resources: per-model parameters, the model download base URL, and
# the listing of downloadable model files.
model_params_url = "https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/model_data.json"
_Models = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/"
_models = "https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json"

file_folder = "Colab-for-MDX_B"

# Seconds to wait on each metadata request; without a timeout a stalled
# connection would block the import of this module indefinitely.
_REQUEST_TIMEOUT = 30


def _fetch_json(url):
    """Download *url* and return its decoded JSON body.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            non-success HTTP status.
    """
    response = requests.get(url, timeout=_REQUEST_TIMEOUT)
    # Fail with the HTTP error itself rather than a confusing JSON decode error.
    response.raise_for_status()
    return response.json()


# File names of the downloadable MDX models (a dict values view).
model_ids = _fetch_json(_models)["mdx_download_list"].values()
# Per-model parameter sets; ``prepare_mdx`` looks them up by model hash.
model_params = _fetch_json(model_params_url)

# Maps a stem name to the name used for its complementary (inverted) output.
stem_naming = {
    "Vocals": "Instrumental",
    "Other": "Instruments",
    "Instrumental": "Vocals",
    "Drums": "Drumless",
    "Bass": "Bassless",
}

os.makedirs("tmp_models", exist_ok=True)

warnings.filterwarnings("ignore")
cpu = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda:0")
# Guard the attribute lookup: not every torch build exposes ``backends.mps``.
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
|
47 |
+
|
48 |
+
|
49 |
+
def get_model_list():
    """Return the module-level ``model_ids`` collection of model file names."""
    return model_ids
|
51 |
+
|
52 |
+
def id_to_ptm(mkey):
    """Resolve a model key to a local model path, downloading it if needed.

    Keys listed in ``model_ids`` are cached as ``<now_dir>/tmp_models/<mkey>``
    and fetched with ``wget`` from the ``_Models`` base URL when the file is
    not already present. Any other key is mapped to ``models/<mkey>`` without
    checking that the file exists.

    Args:
        mkey: Model file name.

    Returns:
        The path of the model file, as a string.

    Raises:
        subprocess.CalledProcessError: If ``wget`` exits with a non-zero
            status. The partial download is removed first.
    """
    if mkey not in model_ids:
        return f'models/{mkey}'

    mpath = f"{now_dir}/tmp_models/{mkey}"
    if os.path.exists(mpath):
        return mpath

    print('Downloading model...', end=' ')
    downloaded = False
    try:
        subprocess.run(["wget", _Models + mkey, "-O", mpath], check=True)
        downloaded = True
    finally:
        # ``wget -O`` creates the output file up front, so a failed or
        # interrupted transfer would otherwise leave an empty/partial file
        # that later calls mistake for a cached model.
        if not downloaded and os.path.exists(mpath):
            os.remove(mpath)
    print(f'saved to {mpath}')
    return mpath
|
68 |
+
|
69 |
+
def prepare_mdx(onnx, custom_param=False, dim_f=None, dim_t=None, n_fft=None, stem_name=None, compensation=None):
    """Build the ``mdx.MDX_Model`` settings object for a model.

    With ``custom_param`` set, the caller-supplied values are passed through
    unchanged. Otherwise the model is hashed with ``mdx.MDX.get_hash`` and its
    settings are read from the ``model_params`` table.

    Args:
        onnx: Model reference handed to ``mdx.MDX.get_hash`` (presumably the
            path of the ONNX file -- confirm against ``mdx``).
        custom_param: Use the explicit arguments instead of the table lookup.
        dim_f: Custom ``dim_f`` value (custom mode only).
        dim_t: Custom ``dim_t`` value (custom mode only). Note that table mode
            passes ``2 ** mdx_dim_t_set``, whereas this value is used as given.
        n_fft: Custom ``n_fft`` value (custom mode only).
        stem_name: Custom stem name (custom mode only).
        compensation: Compensation factor. Required in custom mode; in table
            mode it overrides the table's ``compensate`` value when not None.

    Returns:
        The constructed ``mdx.MDX_Model``.

    Raises:
        AssertionError: If custom mode is selected with missing parameters.
        ValueError: If the model hash has no entry in ``model_params``.
    """
    # Only CUDA or CPU is selected here, independent of the module-level device.
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    if custom_param:
        assert not (dim_f is None or dim_t is None or n_fft is None or compensation is None), 'Custom parameter selected, but incomplete parameters are provided.'
        return mdx.MDX_Model(
            device,
            dim_f=dim_f,
            dim_t=dim_t,
            n_fft=n_fft,
            stem_name=stem_name,
            compensation=compensation,
        )

    model_hash = mdx.MDX.get_hash(onnx)
    mp = model_params.get(model_hash)
    if mp is None:
        # Previously this path fell through to an unbound local variable.
        raise ValueError(
            f"No parameters known for model hash {model_hash!r}; "
            "pass custom_param=True with explicit settings."
        )
    return mdx.MDX_Model(
        device,
        dim_f=mp["mdx_dim_f_set"],
        dim_t=2 ** mp["mdx_dim_t_set"],
        n_fft=mp["mdx_n_fft_scale_set"],
        stem_name=mp["primary_stem"],
        compensation=compensation if compensation is not None else mp["compensate"],
    )
|
94 |
+
|
95 |
+
def run_mdx(onnx, mdx_model, filename, output_format='wav', diff=False, suffix=None, diff_suffix=None, denoise=False, m_threads=2):
    """Run an MDX model over one audio file and write the resulting stem(s).

    The input is loaded at 44100 Hz (non-WAV files are first converted to a
    stereo WAV with ``ffmpeg``), peak-normalised, processed, and scaled back
    to its original peak. The stem is written to
    ``audios/<basename>_<stem>.<output_format>``. With ``diff`` set, the
    complementary signal (input minus compensated stem) is also written to
    ``audio-others/``.

    Args:
        onnx: Model reference passed to ``mdx.MDX``.
        mdx_model: Settings object providing ``stem_name`` and ``compensation``.
        filename: Path of the input audio file.
        output_format: File extension used for the outputs.
        diff: Also write the complementary (inverted) stem.
        suffix: Output name suffix; defaults to ``mdx_model.stem_name``.
        diff_suffix: Suffix for the complementary output; defaults to the
            ``stem_naming`` entry for the stem, or ``<stem>_diff`` if absent.
        denoise: Average the result with that of the sign-flipped input.
        m_threads: Forwarded to ``process_wave``.

    Raises:
        subprocess.CalledProcessError: If the ``ffmpeg`` conversion fails.
    """
    import tempfile

    mdx_sess = mdx.MDX(onnx, mdx_model)
    print(f"Processing: {filename}")
    if filename.lower().endswith('.wav'):
        wave, sr = librosa.load(filename, mono=False, sr=44100)
    else:
        # Use a unique scratch file so concurrent or interrupted runs cannot
        # collide, and make sure it is removed even when conversion fails.
        fd, temp_wav = tempfile.mkstemp(suffix='.wav')
        os.close(fd)
        try:
            # ``-y``: the scratch file already exists, so allow overwriting it.
            subprocess.run(
                ['ffmpeg', '-y', '-i', filename, '-ar', '44100', '-ac', '2', temp_wav],
                check=True,
            )  # Convert to WAV format
            wave, sr = librosa.load(temp_wav, mono=False, sr=44100)
        finally:
            if os.path.exists(temp_wav):
                os.remove(temp_wav)

    # normalizing input wave gives better output
    peak = np.abs(wave).max()
    if peak == 0:
        # Silent input: skip scaling instead of dividing by zero.
        peak = 1.0
    wave /= peak
    if denoise:
        wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (mdx_sess.process_wave(wave, m_threads))
        wave_processed *= 0.5
    else:
        wave_processed = mdx_sess.process_wave(wave, m_threads)
    # return to previous peak
    wave_processed *= peak

    base_name = os.path.basename(os.path.splitext(filename)[0])
    stem_name = mdx_model.stem_name if suffix is None else suffix  # use suffix if provided
    os.makedirs('audios', exist_ok=True)
    save_path = os.path.join('audios', f"{base_name}_{stem_name}.{output_format}")
    sf.write(save_path, wave_processed.T, sr)

    print(f'done, saved to: {save_path}')

    if diff:
        diff_stem_name = stem_naming.get(stem_name) if diff_suffix is None else diff_suffix  # use suffix if provided
        stem_name = f"{stem_name}_diff" if diff_stem_name is None else diff_stem_name
        os.makedirs('audio-others', exist_ok=True)
        save_path = os.path.join('audio-others', f"{base_name}_{stem_name}.{output_format}")
        # ``wave`` was normalised in place above, so bring it back to the
        # original peak before subtracting the (already rescaled) stem.
        original_mix = wave.T * peak
        sf.write(
            save_path,
            (-wave_processed.T * mdx_model.compensation) + original_mix,
            sr,
        )
        print(f'invert done, saved to: {save_path}')

    # Drop the large buffers and the session promptly.
    del mdx_sess, wave_processed, wave
    gc.collect()
|
151 |
+
|
152 |
+
if __name__ == "__main__":
    # Running the file directly only emits a blank line.
    print()
|