ttettheu commited on
Commit
56b2bcf
·
verified ·
1 Parent(s): 75754ee

Create mdx_processing_script.py

Browse files
Files changed (1) hide show
  1. mdx_processing_script.py +153 -0
mdx_processing_script.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import requests
3
+ import subprocess
4
+ import logging
5
+ import sys
6
+ from bs4 import BeautifulSoup
7
+ import torch, pdb, os, warnings, librosa
8
+ import soundfile as sf
9
+ from tqdm import tqdm
10
+ import numpy as np
11
+ import torch
12
+ now_dir = os.getcwd()
13
+ sys.path.append(now_dir)
14
+ import mdx
15
+ branch = "https://github.com/NaJeongMo/Colab-for-MDX_B"
16
+
17
+ model_params = "https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/model_data.json"
18
+ _Models = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/"
19
+ # _models = "https://pastebin.com/raw/jBzYB8vz"
20
+ _models = "https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json"
21
+ #stem_naming = "https://pastebin.com/raw/mpH4hRcF"
22
+
23
+ file_folder = "Colab-for-MDX_B"
24
+ model_ids = requests.get(_models).json()
25
+ model_ids = model_ids["mdx_download_list"].values()
26
+ #print(model_ids)
27
+ model_params = requests.get(model_params).json()
28
+ #stem_naming = requests.get(stem_naming).json()
29
+ stem_naming = {
30
+ "Vocals": "Instrumental",
31
+ "Other": "Instruments",
32
+ "Instrumental": "Vocals",
33
+ "Drums": "Drumless",
34
+ "Bass": "Bassless"
35
+ }
36
+
37
+ os.makedirs("tmp_models", exist_ok=True)
38
+
39
+ warnings.filterwarnings("ignore")
40
+ cpu = torch.device("cpu")
41
+ if torch.cuda.is_available():
42
+ device = torch.device("cuda:0")
43
+ elif torch.backends.mps.is_available():
44
+ device = torch.device("mps")
45
+ else:
46
+ device = torch.device("cpu")
47
+
48
+
49
+ def get_model_list():
50
+ return model_ids
51
+
52
+ def id_to_ptm(mkey):
53
+ if mkey in model_ids:
54
+ mpath = f"{now_dir}/tmp_models/{mkey}"
55
+ if not os.path.exists(f'{now_dir}/tmp_models/{mkey}'):
56
+ print('Downloading model...',end=' ')
57
+ subprocess.run(
58
+ ["wget", _Models+mkey, "-O", mpath]
59
+ )
60
+ print(f'saved to {mpath}')
61
+ # get_ipython().system(f'gdown {model_id} -O /content/tmp_models/{mkey}')
62
+ return mpath
63
+ else:
64
+ return mpath
65
+ else:
66
+ mpath = f'models/{mkey}'
67
+ return mpath
68
+
69
+ def prepare_mdx(onnx,custom_param=False, dim_f=None, dim_t=None, n_fft=None, stem_name=None, compensation=None):
70
+ device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
71
+ if custom_param:
72
+ assert not (dim_f is None or dim_t is None or n_fft is None or compensation is None), 'Custom parameter selected, but incomplete parameters are provided.'
73
+ mdx_model = mdx.MDX_Model(
74
+ device,
75
+ dim_f = dim_f,
76
+ dim_t = dim_t,
77
+ n_fft = n_fft,
78
+ stem_name=stem_name,
79
+ compensation=compensation
80
+ )
81
+ else:
82
+ model_hash = mdx.MDX.get_hash(onnx)
83
+ if model_hash in model_params:
84
+ mp = model_params.get(model_hash)
85
+ mdx_model = mdx.MDX_Model(
86
+ device,
87
+ dim_f = mp["mdx_dim_f_set"],
88
+ dim_t = 2**mp["mdx_dim_t_set"],
89
+ n_fft = mp["mdx_n_fft_scale_set"],
90
+ stem_name=mp["primary_stem"],
91
+ compensation=compensation if not custom_param and compensation is not None else mp["compensate"]
92
+ )
93
+ return mdx_model
94
+
95
+ def run_mdx(onnx, mdx_model,filename, output_format='wav',diff=False,suffix=None,diff_suffix=None, denoise=False, m_threads=2):
96
+ mdx_sess = mdx.MDX(onnx,mdx_model)
97
+ print(f"Processing: {filename}")
98
+ if filename.lower().endswith('.wav'):
99
+ wave, sr = librosa.load(filename, mono=False, sr=44100)
100
+ else:
101
+ temp_wav = 'temp_audio.wav'
102
+ subprocess.run(['ffmpeg', '-i', filename, '-ar', '44100', '-ac', '2', temp_wav]) # Convert to WAV format
103
+ wave, sr = librosa.load(temp_wav, mono=False, sr=44100)
104
+ os.remove(temp_wav)
105
+
106
+ #wave, sr = librosa.load(filename,mono=False, sr=44100)
107
+ # normalizing input wave gives better output
108
+ peak = max(np.max(wave), abs(np.min(wave)))
109
+ wave /= peak
110
+ if denoise:
111
+ wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (mdx_sess.process_wave(wave, m_threads))
112
+ wave_processed *= 0.5
113
+ else:
114
+ wave_processed = mdx_sess.process_wave(wave, m_threads)
115
+ # return to previous peak
116
+ wave_processed *= peak
117
+
118
+ stem_name = mdx_model.stem_name if suffix is None else suffix # use suffix if provided
119
+ save_path = os.path.basename(os.path.splitext(filename)[0])
120
+ #vocals_save_path = os.path.join(vocals_folder, f"{save_path}_{stem_name}.{output_format}")
121
+ #instrumental_save_path = os.path.join(instrumental_folder, f"{save_path}_{stem_name}.{output_format}")
122
+ save_path = f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.{output_format}"
123
+ save_path = os.path.join(
124
+ 'audios',
125
+ save_path
126
+ )
127
+ sf.write(
128
+ save_path,
129
+ wave_processed.T,
130
+ sr
131
+ )
132
+
133
+ print(f'done, saved to: {save_path}')
134
+
135
+ if diff:
136
+ diff_stem_name = stem_naming.get(stem_name) if diff_suffix is None else diff_suffix # use suffix if provided
137
+ stem_name = f"{stem_name}_diff" if diff_stem_name is None else diff_stem_name
138
+ save_path = f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.{output_format}"
139
+ save_path = os.path.join(
140
+ 'audio-others',
141
+ save_path
142
+ )
143
+ sf.write(
144
+ save_path,
145
+ (-wave_processed.T*mdx_model.compensation)+wave.T,
146
+ sr
147
+ )
148
+ print(f'invert done, saved to: {save_path}')
149
+ del mdx_sess, wave_processed, wave
150
+ gc.collect()
151
+
152
+ if __name__ == "__main__":
153
+ print()