CCockrum committed on
Commit
a20b9cf
·
verified ·
1 Parent(s): 6058904

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +1474 -0
app.py ADDED
@@ -0,0 +1,1474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ # os.system("pip install ./ort_nightly_gpu-1.17.0.dev20240118002-cp310-cp310-manylinux_2_28_x86_64.whl")
3
+ os.system("pip install ort-nightly-gpu --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ort-cuda-12-nightly/pypi/simple/")
4
+ import gc
5
+ import hashlib
6
+ import queue
7
+ import threading
8
+ import json
9
+ import shlex
10
+ import sys
11
+ import subprocess
12
+ import librosa
13
+ import numpy as np
14
+ import soundfile as sf
15
+ import torch
16
+ from tqdm import tqdm
17
+ from utils import (
18
+ remove_directory_contents,
19
+ create_directories,
20
+ download_manager,
21
+ )
22
+ import random
23
+ import spaces
24
+ from utils import logger
25
+ import onnxruntime as ort
26
+ import warnings
27
+ import spaces
28
+ import gradio as gr
29
+ import logging
30
+ import time
31
+ import traceback
32
+ from pedalboard import Pedalboard, Reverb, Delay, Chorus, Compressor, Gain, HighpassFilter, LowpassFilter
33
+ from pedalboard.io import AudioFile
34
+ import numpy as np
35
+ import yt_dlp
36
+
37
# Suppress every warning category for the whole process.
warnings.filterwarnings("ignore")

# Presentation strings (presumably consumed by the Gradio UI built later in
# this file -- TODO confirm).
title = "<center><strong><font size='7'>Audio🔹separator</font></strong></center>"
description = "This demo uses the MDX-Net models for vocal and background sound separation."
theme = "NoCrypt/miku"

# Name given to the inverted ("everything else") stem for a given primary
# stem; looked up by run_mdx / run_mdx_beta when no invert_suffix is passed.
stem_naming = dict(
    Vocals="Instrumental",
    Other="Instruments",
    Instrumental="Vocals",
    Drums="Drumless",
    Bass="Bassless",
)
50
+
51
+
52
class MDXModel:
    """Spectral configuration plus STFT / inverse-STFT helpers for one model.

    Attributes set in ``__init__``:
        dim_f, dim_t: number of frequency bins / time frames kept per chunk.
        dim_c: fixed at 4 (the reshapes below pack real/imag of paired rows).
        n_fft, hop: STFT size and hop length.
        n_bins: ``n_fft // 2 + 1``.
        chunk_size: ``hop * (dim_t - 1)`` samples per chunk.
        window: periodic Hann window of length ``n_fft`` on ``device``.
        freq_pad: zeros used to pad ``dim_f`` back up to ``n_bins`` in istft.
        stem_name, compensation: stored as given; not used inside this class.
    """

    def __init__(
        self,
        device,
        dim_f,
        dim_t,
        n_fft,
        hop=1024,
        stem_name=None,
        compensation=1.000,
    ):
        self.dim_f = dim_f
        self.dim_t = dim_t
        self.dim_c = 4
        self.n_fft = n_fft
        self.hop = hop
        self.stem_name = stem_name
        self.compensation = compensation

        self.n_bins = n_fft // 2 + 1
        self.chunk_size = hop * (dim_t - 1)

        hann = torch.hann_window(window_length=n_fft, periodic=True)
        self.window = hann.to(device)

        pad_shape = [1, self.dim_c, self.n_bins - dim_f, dim_t]
        self.freq_pad = torch.zeros(pad_shape).to(device)

    def stft(self, x):
        """Return the spectrogram of ``x`` as real channels.

        ``x`` is flattened to rows of ``chunk_size`` samples, transformed, and
        repacked to ``[-1, 4, n_bins, dim_t]``; only the first ``dim_f``
        frequency bins are returned.
        """
        rows = x.reshape([-1, self.chunk_size])
        spec = torch.stft(
            rows,
            n_fft=self.n_fft,
            hop_length=self.hop,
            window=self.window,
            center=True,
            return_complex=True,
        )
        # complex -> trailing (real, imag) axis, then move that axis next to batch
        spec = torch.view_as_real(spec).permute([0, 3, 1, 2])
        spec = spec.reshape([-1, 2, 2, self.n_bins, self.dim_t])
        spec = spec.reshape([-1, 4, self.n_bins, self.dim_t])
        return spec[:, :, : self.dim_f]

    def istft(self, x, freq_pad=None):
        """Invert :meth:`stft`; returns a tensor shaped ``[-1, 2, chunk_size]``.

        The frequency axis is padded back to ``n_bins`` with ``freq_pad``
        (defaults to the stored zero padding repeated over the batch).
        """
        if freq_pad is None:
            freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1])

        full = torch.cat([x, freq_pad], -2)
        full = full.reshape([-1, 2, 2, self.n_bins, self.dim_t])
        full = full.reshape([-1, 2, self.n_bins, self.dim_t])
        # view_as_complex needs the (real, imag) pair last and contiguous
        full = full.permute([0, 2, 3, 1]).contiguous()
        wave = torch.istft(
            torch.view_as_complex(full),
            n_fft=self.n_fft,
            hop_length=self.hop,
            window=self.window,
            center=True,
        )
        return wave.reshape([-1, 2, self.chunk_size])
122
+
123
+
124
class MDX:
    """Run an MDX-Net ONNX model over a waveform in overlapping chunks.

    The wave is split into batches (one per worker thread), each batch is cut
    into model-sized chunks, pushed through STFT -> ONNX model -> iSTFT, and
    the results are stitched back together.
    """

    DEFAULT_SR = 44100
    # Both sizes are in samples (seconds * DEFAULT_SR). A chunk size of 0
    # makes segment() keep the wave in a single piece.
    DEFAULT_CHUNK_SIZE = 0 * DEFAULT_SR
    DEFAULT_MARGIN_SIZE = 1 * DEFAULT_SR

    def __init__(self, model_path: str, params: "MDXModel", processor=0):
        """Load the ONNX model and warm it up.

        Args:
            model_path: path of the ``.onnx`` file.
            params: an ``MDXModel`` with the STFT configuration of the model.
            processor: CUDA device index, or a negative value for CPU.
        """
        use_cuda = processor >= 0
        self.device = (
            torch.device(f"cuda:{processor}") if use_cuda else torch.device("cpu")
        )
        self.provider = (
            ["CUDAExecutionProvider"] if use_cuda else ["CPUExecutionProvider"]
        )

        self.model = params

        # Load the ONNX model using ONNX Runtime
        self.ort = ort.InferenceSession(model_path, providers=self.provider)
        # One throw-away inference so the first real chunk is not slowed down
        self.ort.run(
            None,
            {"input": torch.rand(1, 4, params.dim_f, params.dim_t).numpy()},
        )
        self.process = lambda spec: self.ort.run(
            None, {"input": spec.cpu().numpy()}
        )[0]

        # tqdm progress bar, created per process_wave() call
        self.prog = None

    @staticmethod
    def get_hash(model_path):
        """Return the MD5 hex digest identifying a model file.

        Only the last 10000 KiB are hashed; files shorter than that are
        hashed in full.
        """
        with open(model_path, "rb") as f:
            try:
                f.seek(-10000 * 1024, 2)
            except OSError:
                # File is shorter than the tail window: hash all of it.
                f.seek(0)
            return hashlib.md5(f.read()).hexdigest()

    @staticmethod
    def segment(
        wave,
        combine=True,
        chunk_size=DEFAULT_CHUNK_SIZE,
        margin_size=DEFAULT_MARGIN_SIZE,
    ):
        """Segment a wave array, or join previously segmented pieces.

        Args:
            wave: array to split (``combine=False``) or sequence of arrays to
                join (``combine=True``); samples are on the last axis.
            combine: True joins segments, False splits the wave.
            chunk_size: segment length in samples; ``<= 0`` or larger than the
                wave means "one segment".
            margin_size: overlap in samples added on both sides of a segment
                when splitting and removed again when joining. When splitting
                it is clamped to ``chunk_size``; pass the same effective value
                to the joining call.

        Returns:
            A list of arrays when splitting; a single array when joining
            (``None`` if there was nothing to join).
        """
        if combine:
            last_index = len(wave) - 1
            pieces = []
            for index, part in enumerate(wave):
                start = 0 if index == 0 else margin_size
                keep_tail = index == last_index or margin_size == 0
                end = None if keep_tail else -margin_size
                pieces.append(part[:, start:end])
            # Single concatenate instead of re-copying the result per segment.
            return np.concatenate(pieces, axis=-1) if pieces else None

        processed_wave = []
        sample_count = wave.shape[-1]

        if chunk_size <= 0 or chunk_size > sample_count:
            chunk_size = sample_count

        if margin_size > chunk_size:
            margin_size = chunk_size

        for index, skip in enumerate(range(0, sample_count, chunk_size)):
            margin = 0 if index == 0 else margin_size
            end = min(skip + chunk_size + margin_size, sample_count)
            start = skip - margin

            processed_wave.append(wave[:, start:end].copy())

            if end == sample_count:
                break

        return processed_wave

    def pad_wave(self, wave):
        """Pad a two-row wave and cut it into overlapping model-sized chunks.

        Args:
            wave: array with 2 rows (the zero padding below is built with
                2 rows) and samples on axis 1.

        Returns:
            tuple ``(mix_waves, pad, trim)``:
                mix_waves: float32 tensor on ``self.device``, one chunk of
                    ``model.chunk_size`` samples per entry.
                pad: zeros appended so the length is a multiple of the
                    generated size (always at least 1).
                trim: ``n_fft // 2`` zeros added at both ends and cut from
                    every processed chunk later.
        """
        n_sample = wave.shape[1]
        trim = self.model.n_fft // 2
        gen_size = self.model.chunk_size - 2 * trim
        pad = gen_size - n_sample % gen_size

        wave_p = np.concatenate(
            (
                np.zeros((2, trim)),
                wave,
                np.zeros((2, pad)),
                np.zeros((2, trim)),
            ),
            1,
        )

        chunks = [
            wave_p[:, i:i + self.model.chunk_size]
            for i in range(0, n_sample + pad, gen_size)
        ]
        # Stack first: building a tensor from a list of arrays is very slow.
        mix_waves = torch.tensor(np.stack(chunks), dtype=torch.float32).to(
            self.device
        )

        return mix_waves, pad, trim

    def _process_wave(self, mix_waves, trim, pad, q: queue.Queue, _id: int):
        """Run the model over every chunk of one batch (thread target).

        Args:
            mix_waves: tensor of chunks from :meth:`pad_wave`.
            trim: samples cut from both ends of each processed chunk.
            pad: samples cut from the end of the joined result.
            q: queue receiving ``{_id: processed_signal}``.
            _id: position of this batch, used to restore order afterwards.

        Returns:
            The processed signal as a numpy array with 2 rows.
        """
        pieces = []
        with torch.no_grad():
            for mix_wave in mix_waves.split(1):
                self.prog.update()
                spec = self.model.stft(mix_wave)
                processed_spec = torch.tensor(self.process(spec))
                processed_wav = self.model.istft(processed_spec.to(self.device))
                processed_wav = (
                    processed_wav[:, :, trim:-trim]
                    .transpose(0, 1)
                    .reshape(2, -1)
                    .cpu()
                    .numpy()
                )
                pieces.append(processed_wav)
        processed_signal = np.concatenate(pieces, axis=-1)[:, :-pad]
        q.put({_id: processed_signal})
        return processed_signal

    def process_wave(self, wave: np.ndarray, mt_threads=1):
        """Process a whole wave, one thread per batch.

        Args:
            wave: array with 2 rows and samples on the last axis.
            mt_threads: number of batches / threads.

        Returns:
            The processed wave as a numpy array.

        Raises:
            RuntimeError: if a worker thread did not deliver its batch.
        """
        self.prog = tqdm(total=0)
        sample_count = wave.shape[-1]
        chunk = sample_count // mt_threads

        # segment() clamps the margin to the chunk size while splitting; the
        # joining call must use the same value or short inputs are cut wrongly.
        effective_chunk = chunk if 0 < chunk <= sample_count else sample_count
        margin = min(self.DEFAULT_MARGIN_SIZE, effective_chunk)
        waves = self.segment(wave, False, chunk, margin)

        q = queue.Queue()
        threads = []
        for batch_id, batch in enumerate(waves):
            mix_waves, pad, trim = self.pad_wave(batch)
            self.prog.total = len(mix_waves) * mt_threads
            thread = threading.Thread(
                target=self._process_wave,
                args=(mix_waves, trim, pad, q, batch_id),
            )
            thread.start()
            threads.append(thread)
        for thread in threads:
            thread.join()
        self.prog.close()

        results = {}
        while not q.empty():
            results.update(q.get())

        # An exception inside a worker leaves its batch missing.
        if len(results) != len(waves):
            raise RuntimeError(
                "Incomplete processed batches, please reduce batch size!"
            )

        ordered = [results[batch_id] for batch_id in sorted(results)]
        return self.segment(ordered, True, chunk, margin)
343
+
344
+
345
@spaces.GPU()
def run_mdx(
    model_params,
    output_dir,
    model_path,
    filename,
    exclude_main=False,
    exclude_inversion=False,
    suffix=None,
    invert_suffix=None,
    denoise=False,
    keep_orig=True,
    m_threads=2,
    device_base="cuda",
):
    """Separate one audio file with an MDX-Net model and write the stems.

    Args:
        model_params: mapping from model hash (see ``MDX.get_hash``) to the
            model's parameter dict.
        output_dir: directory the ``.wav`` results are written to.
        model_path: path of the ONNX model.
        filename: input audio file (loaded at 44100 Hz, channels kept).
        exclude_main: skip writing the model's primary stem.
        exclude_inversion: skip writing the inverted stem (input minus
            primary stem times the model's compensation factor).
        suffix: name for the primary stem; defaults to the model's stem name.
        invert_suffix: name for the inverted stem; defaults to the
            ``stem_naming`` entry, or ``"<stem>_diff"`` when there is none.
        denoise: run the model on the wave and on its negation and average.
        keep_orig: when False, delete ``filename`` afterwards.
        m_threads: overwritten below (from VRAM size on CUDA, 1 on CPU).
        device_base: ``"cuda"`` selects ``cuda:0``; anything else uses CPU.

    Returns:
        tuple ``(main_filepath, invert_filepath)``; an entry is ``None`` when
        that stem was excluded.

    Raises:
        ValueError: if ``model_params`` has no entry for the model's hash.
    """
    if device_base == "cuda":
        device = torch.device("cuda:0")
        processor_num = 0
        vram_gb = torch.cuda.get_device_properties(device).total_memory / 1024**3
        m_threads = 1 if vram_gb < 8 else (8 if vram_gb > 32 else 2)
        logger.info(f"threads: {m_threads} vram: {vram_gb}")
    else:
        device = torch.device("cpu")
        processor_num = -1
        m_threads = 1

    model_hash = MDX.get_hash(model_path)
    mp = model_params.get(model_hash)
    if mp is None:
        raise ValueError(
            f"No MDX parameters found for model '{model_path}' (hash {model_hash})"
        )
    model = MDXModel(
        device,
        dim_f=mp["mdx_dim_f_set"],
        dim_t=2 ** mp["mdx_dim_t_set"],
        n_fft=mp["mdx_n_fft_scale_set"],
        stem_name=mp["primary_stem"],
        compensation=mp["compensate"],
    )

    mdx_sess = MDX(model_path, model, processor=processor_num)
    wave, sr = librosa.load(filename, mono=False, sr=44100)

    # Normalizing the input gives better output. Work on a copy so that the
    # inversion below subtracts from the wave at its ORIGINAL scale, matching
    # wave_processed after it has been scaled back by ``peak``.
    peak = float(np.max(np.abs(wave)))
    if peak == 0.0:
        peak = 1.0  # silent input: avoid dividing by zero
    normalized = wave / peak

    if denoise:
        wave_processed = -(mdx_sess.process_wave(-normalized, m_threads)) + (
            mdx_sess.process_wave(normalized, m_threads)
        )
        wave_processed *= 0.5
    else:
        wave_processed = mdx_sess.process_wave(normalized, m_threads)
    # return to previous peak
    wave_processed *= peak

    base_name = os.path.basename(os.path.splitext(filename)[0])
    stem_name = model.stem_name if suffix is None else suffix

    main_filepath = None
    if not exclude_main:
        main_filepath = os.path.join(output_dir, f"{base_name}_{stem_name}.wav")
        sf.write(main_filepath, wave_processed.T, sr)

    invert_filepath = None
    if not exclude_inversion:
        diff_stem_name = (
            stem_naming.get(stem_name) if invert_suffix is None else invert_suffix
        )
        stem_name = (
            f"{stem_name}_diff" if diff_stem_name is None else diff_stem_name
        )
        invert_filepath = os.path.join(output_dir, f"{base_name}_{stem_name}.wav")
        sf.write(
            invert_filepath,
            (-wave_processed.T * model.compensation) + wave.T,
            sr,
        )

    if not keep_orig:
        os.remove(filename)

    # Drop the large buffers and the session before returning
    del mdx_sess, wave_processed, normalized, wave
    gc.collect()
    torch.cuda.empty_cache()
    return main_filepath, invert_filepath
435
+
436
+
437
def run_mdx_beta(
    model_params,
    output_dir,
    model_path,
    filename,
    exclude_main=False,
    exclude_inversion=False,
    suffix=None,
    invert_suffix=None,
    denoise=False,
    keep_orig=True,
    m_threads=2,
    device_base="",
):
    """CPU-only variant of ``run_mdx`` that picks threads from audio length.

    Args:
        model_params: mapping from model hash (see ``MDX.get_hash``) to the
            model's parameter dict.
        output_dir: directory the ``.wav`` results are written to.
        model_path: path of the ONNX model.
        filename: input audio file (loaded at 44100 Hz, channels kept).
        exclude_main: skip writing the model's primary stem.
        exclude_inversion: skip writing the inverted stem (input minus
            primary stem times the model's compensation factor).
        suffix: name for the primary stem; defaults to the model's stem name.
        invert_suffix: name for the inverted stem; defaults to the
            ``stem_naming`` entry, or ``"<stem>_diff"`` when there is none.
        denoise: run the model on the wave and on its negation and average.
        keep_orig: when False, delete ``filename`` afterwards.
        m_threads: ignored; 1 thread below 60 s, 8 for 60-120 s, 16 above.
        device_base: ignored; processing always runs on CPU.

    Returns:
        tuple ``(main_filepath, invert_filepath)``; an entry is ``None`` when
        that stem was excluded.

    Raises:
        ValueError: if ``model_params`` has no entry for the model's hash.
    """
    try:
        duration = librosa.get_duration(path=filename)
    except TypeError:
        # Older librosa releases only accept the ``filename`` keyword.
        duration = librosa.get_duration(filename=filename)

    if duration > 120:
        m_threads = 16
    elif duration >= 60:
        m_threads = 8
    else:
        m_threads = 1

    logger.info(f"threads: {m_threads}")

    model_hash = MDX.get_hash(model_path)
    device = torch.device("cpu")
    processor_num = -1
    mp = model_params.get(model_hash)
    if mp is None:
        raise ValueError(
            f"No MDX parameters found for model '{model_path}' (hash {model_hash})"
        )
    model = MDXModel(
        device,
        dim_f=mp["mdx_dim_f_set"],
        dim_t=2 ** mp["mdx_dim_t_set"],
        n_fft=mp["mdx_n_fft_scale_set"],
        stem_name=mp["primary_stem"],
        compensation=mp["compensate"],
    )

    mdx_sess = MDX(model_path, model, processor=processor_num)
    wave, sr = librosa.load(filename, mono=False, sr=44100)

    # Normalizing the input gives better output. Work on a copy so that the
    # inversion below subtracts from the wave at its ORIGINAL scale, matching
    # wave_processed after it has been scaled back by ``peak``.
    peak = float(np.max(np.abs(wave)))
    if peak == 0.0:
        peak = 1.0  # silent input: avoid dividing by zero
    normalized = wave / peak

    if denoise:
        wave_processed = -(mdx_sess.process_wave(-normalized, m_threads)) + (
            mdx_sess.process_wave(normalized, m_threads)
        )
        wave_processed *= 0.5
    else:
        wave_processed = mdx_sess.process_wave(normalized, m_threads)
    # return to previous peak
    wave_processed *= peak

    base_name = os.path.basename(os.path.splitext(filename)[0])
    stem_name = model.stem_name if suffix is None else suffix

    main_filepath = None
    if not exclude_main:
        main_filepath = os.path.join(output_dir, f"{base_name}_{stem_name}.wav")
        sf.write(main_filepath, wave_processed.T, sr)

    invert_filepath = None
    if not exclude_inversion:
        diff_stem_name = (
            stem_naming.get(stem_name) if invert_suffix is None else invert_suffix
        )
        stem_name = (
            f"{stem_name}_diff" if diff_stem_name is None else diff_stem_name
        )
        invert_filepath = os.path.join(output_dir, f"{base_name}_{stem_name}.wav")
        sf.write(
            invert_filepath,
            (-wave_processed.T * model.compensation) + wave.T,
            sr,
        )

    if not keep_orig:
        os.remove(filename)

    # Drop the large buffers and the session before returning
    del mdx_sess, wave_processed, normalized, wave
    gc.collect()
    torch.cuda.empty_cache()
    return main_filepath, invert_filepath
525
+
526
+
527
# Release URL prefix the model files below are published under.
MDX_DOWNLOAD_LINK = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/"

# ONNX model file names referenced by process_uvr_task.
UVR_MODELS = [
    "UVR-MDX-NET-Voc_FT.onnx",
    "UVR_MDXNET_KARA_2.onnx",
    "Reverb_HQ_By_FoxJoy.onnx",
    "UVR-MDX-NET-Inst_HQ_4.onnx",
]

# Paths are relative to the current working directory, not to this file.
BASE_DIR = "."
mdxnet_models_dir = os.path.join(BASE_DIR, "mdx_models")
output_dir = os.path.join(BASE_DIR, "clean_song_output")
537
+
538
+
539
def convert_to_stereo_and_wav(audio_path):
    """Return a path to a two-channel WAV version of ``audio_path``.

    The file is returned unchanged when it already is a ``.wav`` with exactly
    two channels; otherwise ffmpeg converts it to ``<name>_stereo.wav``.

    Args:
        audio_path: path of the input audio file.

    Returns:
        ``audio_path`` itself, or the path of the converted file.

    Raises:
        Exception: if ffmpeg fails or produces no output file.
    """
    wave, sr = librosa.load(audio_path, mono=False, sr=44100)

    # librosa returns a 1-D array for mono and (channels, samples) otherwise.
    is_stereo = wave.ndim == 2 and wave.shape[0] == 2
    is_wav = audio_path[-4:].lower() == ".wav"
    if is_stereo and is_wav:
        return audio_path

    # os.path.join discards ``output_dir`` when ``audio_path`` is absolute, in
    # which case the converted file is written next to the input.
    stereo_path = os.path.join(
        output_dir, f"{os.path.splitext(audio_path)[0]}_stereo.wav"
    )

    # Argument list instead of a shlex-split string: paths containing quotes,
    # backslashes or spaces are passed to ffmpeg unmodified.
    command = [
        "ffmpeg", "-y", "-loglevel", "error",
        "-i", audio_path,
        "-ac", "2",
        "-f", "wav",
        stereo_path,
    ]
    result = subprocess.run(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        creationflags=subprocess.CREATE_NO_WINDOW if sys.platform == "win32" else 0,
    )
    if result.returncode != 0 or not os.path.exists(stereo_path):
        if result.stderr:
            logger.error(result.stderr.decode(errors="replace"))
        raise Exception("Error processing audio to stereo wav")

    return stereo_path
565
+
566
+
567
def get_hash(filepath):
    """Return the first 18 hex characters of the file's BLAKE2b digest.

    The file is read in 8192-byte blocks so it never has to fit in memory.
    """
    digest = hashlib.blake2b()
    with open(filepath, "rb") as stream:
        for block in iter(lambda: stream.read(8192), b""):
            digest.update(block)
    return digest.hexdigest()[:18]
574
+
575
def random_sleep():
    """Block for a random duration between 5.2 and 7.9 s (0.1 s steps)."""
    pause = round(random.uniform(5.2, 7.9), 1)
    time.sleep(pause)
578
+
579
def _run_mdx_with_fallback(*args, **kwargs):
    """Call ``run_mdx``; if it raises, log the error and retry with ``run_mdx_beta``."""
    try:
        return run_mdx(*args, **kwargs)
    except Exception as error:
        # Previously swallowed silently, which hid why the fallback was taken.
        logger.error(f"run_mdx failed ({error}); retrying with run_mdx_beta")
        return run_mdx_beta(*args, **kwargs)


def process_uvr_task(
    orig_song_path: str = "aud_test.mp3",
    main_vocals: bool = False,
    dereverb: bool = True,
    song_id: str = "mdx",  # folder output name
    only_voiceless: bool = False,
    remove_files_output_dir: bool = False,
):
    """Run the separation pipeline on one song.

    Args:
        orig_song_path: input audio file; converted to stereo WAV if needed.
        main_vocals: additionally split the vocals into main and backup.
        dereverb: additionally remove reverb from the (main) vocals.
        song_id: sub-folder of ``output_dir`` receiving the results.
        only_voiceless: only produce the instrumental track and return early.
        remove_files_output_dir: empty ``output_dir`` before starting.

    Returns:
        When ``only_voiceless`` is True: the 2-tuple returned by ``run_mdx``
        (voiceless path, ``None``).
        Otherwise the 5-tuple ``(vocals_path, instrumentals_path,
        backup_vocals_path, main_vocals_path, vocals_dereverb_path)``;
        ``backup_vocals_path`` is ``None`` without ``main_vocals`` and the
        later entries fall back to the previous stage when a stage is off.
    """
    device_base = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Device: {device_base}")

    if remove_files_output_dir:
        remove_directory_contents(output_dir)

    with open(os.path.join(mdxnet_models_dir, "data.json")) as infile:
        mdx_model_params = json.load(infile)

    song_output_dir = os.path.join(output_dir, song_id)
    create_directories(song_output_dir)
    orig_song_path = convert_to_stereo_and_wav(orig_song_path)

    logger.info(f"onnxruntime device >> {ort.get_device()}")

    if only_voiceless:
        logger.info("Voiceless Track Separation...")

        return run_mdx(
            mdx_model_params,
            song_output_dir,
            os.path.join(mdxnet_models_dir, "UVR-MDX-NET-Inst_HQ_4.onnx"),
            orig_song_path,
            suffix="Voiceless",
            denoise=False,
            keep_orig=True,
            exclude_inversion=True,
            device_base=device_base,
        )

    logger.info("Vocal Track Isolation...")
    vocals_path, instrumentals_path = run_mdx(
        mdx_model_params,
        song_output_dir,
        os.path.join(mdxnet_models_dir, "UVR-MDX-NET-Voc_FT.onnx"),
        orig_song_path,
        denoise=True,
        keep_orig=True,
        device_base=device_base,
    )

    if main_vocals:
        random_sleep()
        msg_main = "Main Voice Separation from Supporting Vocals..."
        logger.info(msg_main)
        gr.Info(msg_main)
        backup_vocals_path, main_vocals_path = _run_mdx_with_fallback(
            mdx_model_params,
            song_output_dir,
            os.path.join(mdxnet_models_dir, "UVR_MDXNET_KARA_2.onnx"),
            vocals_path,
            suffix="Backup",
            invert_suffix="Main",
            denoise=True,
            device_base=device_base,
        )
    else:
        backup_vocals_path, main_vocals_path = None, vocals_path

    if dereverb:
        random_sleep()
        msg_dereverb = "Vocal Clarity Enhancement through De-Reverberation..."
        logger.info(msg_dereverb)
        gr.Info(msg_dereverb)
        _, vocals_dereverb_path = _run_mdx_with_fallback(
            mdx_model_params,
            song_output_dir,
            os.path.join(mdxnet_models_dir, "Reverb_HQ_By_FoxJoy.onnx"),
            main_vocals_path,
            invert_suffix="DeReverb",
            exclude_main=True,
            denoise=True,
            device_base=device_base,
        )
    else:
        vocals_dereverb_path = main_vocals_path

    return (
        vocals_path,
        instrumentals_path,
        backup_vocals_path,
        main_vocals_path,
        vocals_dereverb_path,
    )
698
+
699
+
700
def add_vocal_effects(input_file, output_file, reverb_room_size=0.6, vocal_reverb_dryness=0.8, reverb_damping=0.6, reverb_wet_level=0.35,
                      delay_seconds=0.4, delay_mix=0.25,
                      compressor_threshold_db=-25, compressor_ratio=3.5, compressor_attack_ms=10, compressor_release_ms=60,
                      gain_db=3):
    """Apply the vocal effect chain to ``input_file`` and write ``output_file``.

    Chain order: high-pass filter (library defaults), reverb, compressor,
    then a delay (when ``delay_seconds`` or ``delay_mix`` is positive) and a
    gain stage (when ``gain_db`` is truthy). ``vocal_reverb_dryness`` is the
    reverb's ``dry_level``. The audio is streamed one second at a time.
    """
    chain = [
        HighpassFilter(),
        Reverb(
            room_size=reverb_room_size,
            damping=reverb_damping,
            wet_level=reverb_wet_level,
            dry_level=vocal_reverb_dryness,
        ),
        Compressor(
            threshold_db=compressor_threshold_db,
            ratio=compressor_ratio,
            attack_ms=compressor_attack_ms,
            release_ms=compressor_release_ms,
        ),
    ]

    if delay_seconds > 0 or delay_mix > 0:
        chain.append(Delay(delay_seconds=delay_seconds, mix=delay_mix))
        print("delay applied")

    if gain_db:
        chain.append(Gain(gain_db=gain_db))
        print("added gain db")

    board = Pedalboard(chain)

    with AudioFile(input_file) as source:
        with AudioFile(output_file, 'w', source.samplerate, source.num_channels) as sink:
            # reset=False keeps effect state (reverb/delay tails) across blocks
            while source.tell() < source.frames:
                block = source.read(int(source.samplerate))
                sink.write(board(block, source.samplerate, reset=False))
730
+
731
+
732
def add_instrumental_effects(input_file, output_file, highpass_freq=100, lowpass_freq=12000,
                             reverb_room_size=0.5, reverb_damping=0.5, reverb_wet_level=0.25,
                             compressor_threshold_db=-20, compressor_ratio=2.5, compressor_attack_ms=15, compressor_release_ms=80,
                             gain_db=2):
    """Apply the background effect chain to ``input_file`` and write ``output_file``.

    Chain order: high-pass and low-pass filters, reverb (when any reverb
    setting is positive), compressor, then a gain stage (when ``gain_db`` is
    truthy). The audio is streamed one second at a time.
    """
    chain = [
        HighpassFilter(cutoff_frequency_hz=highpass_freq),
        LowpassFilter(cutoff_frequency_hz=lowpass_freq),
    ]

    reverb_settings = (reverb_room_size, reverb_damping, reverb_wet_level)
    if any(value > 0 for value in reverb_settings):
        chain.append(
            Reverb(
                room_size=reverb_room_size,
                damping=reverb_damping,
                wet_level=reverb_wet_level,
            )
        )

    chain.append(
        Compressor(
            threshold_db=compressor_threshold_db,
            ratio=compressor_ratio,
            attack_ms=compressor_attack_ms,
            release_ms=compressor_release_ms,
        )
    )

    if gain_db:
        chain.append(Gain(gain_db=gain_db))

    board = Pedalboard(chain)

    with AudioFile(input_file) as source:
        with AudioFile(output_file, 'w', source.samplerate, source.num_channels) as sink:
            # reset=False keeps effect state (reverb tails) across blocks
            while source.tell() < source.frames:
                block = source.read(int(source.samplerate))
                sink.write(board(block, source.samplerate, reset=False))
759
+
760
+
761
def _effects_output_path(audio_path, suffix="_effects"):
    """Return the absolute path of ``audio_path`` with ``suffix`` before the extension."""
    root, extension = os.path.splitext(os.path.abspath(audio_path))
    return root + suffix + extension


def sound_separate(media_file, stem, main, dereverb, vocal_effects=True, background_effects=True,
    vocal_reverb_room_size=0.6, vocal_reverb_damping=0.6, vocal_reverb_dryness=0.8, vocal_reverb_wet_level=0.35,
    vocal_delay_seconds=0.4, vocal_delay_mix=0.25,
    vocal_compressor_threshold_db=-25, vocal_compressor_ratio=3.5, vocal_compressor_attack_ms=10, vocal_compressor_release_ms=60,
    vocal_gain_db=4,
    background_highpass_freq=120, background_lowpass_freq=11000,
    background_reverb_room_size=0.5, background_reverb_damping=0.5, background_reverb_wet_level=0.25,
    background_compressor_threshold_db=-20, background_compressor_ratio=2.5, background_compressor_attack_ms=15, background_compressor_release_ms=80,
    background_gain_db=3,
):
    """Extract the vocal or background stem of a media file, optionally with effects.

    This single definition replaces two consecutive definitions of the same
    name; the second one (whose signature is kept here) shadowed the first.

    Args:
        media_file: path of the input audio.
        stem: ``"vocal"`` or ``"background"``.
        main: for the vocal stem, isolate the main voice from backing vocals.
        dereverb: for the vocal stem, remove reverb.
        vocal_effects / background_effects: apply the matching effect chain
            and return the processed file instead of the raw stem.
        vocal_* / background_*: settings forwarded to ``add_vocal_effects`` /
            ``add_instrumental_effects``.

    Returns:
        list with the path of the resulting audio file.

    Raises:
        ValueError: if ``media_file`` or ``stem`` is empty.
        Exception: if no output was produced (vocal-stem failures are logged
            and reported through ``gr.Info`` first), or directly from the
            background pipeline.
    """
    if not media_file:
        raise ValueError("The audio path is missing.")

    if not stem:
        raise ValueError("Please select 'vocal' or 'background' stem.")

    hash_audio = str(get_hash(media_file))

    outputs = []

    # Informational only: a failure here must not stop the separation.
    try:
        duration_base_ = librosa.get_duration(filename=media_file)
        print("Duration audio:", duration_base_)
    except Exception as e:
        print(e)

    start_time = time.time()

    if stem == "vocal":
        try:
            _, _, _, _, vocal_audio = process_uvr_task(
                orig_song_path=media_file,
                song_id=hash_audio + "mdx",
                main_vocals=main,
                dereverb=dereverb,
                remove_files_output_dir=False,
            )

            if vocal_effects:
                # Written next to the separated stem. The earlier
                # os.path.join(media_dir, <absolute path>) returned the
                # absolute path unchanged, so the location is the same.
                out_effects_path = _effects_output_path(vocal_audio)
                add_vocal_effects(
                    vocal_audio,
                    out_effects_path,
                    reverb_room_size=vocal_reverb_room_size,
                    reverb_damping=vocal_reverb_damping,
                    vocal_reverb_dryness=vocal_reverb_dryness,
                    reverb_wet_level=vocal_reverb_wet_level,
                    delay_seconds=vocal_delay_seconds,
                    delay_mix=vocal_delay_mix,
                    compressor_threshold_db=vocal_compressor_threshold_db,
                    compressor_ratio=vocal_compressor_ratio,
                    compressor_attack_ms=vocal_compressor_attack_ms,
                    compressor_release_ms=vocal_compressor_release_ms,
                    gain_db=vocal_gain_db,
                )
                vocal_audio = out_effects_path

            outputs.append(vocal_audio)
        except Exception as error:
            gr.Info(str(error))
            logger.error(str(error))
            traceback.print_exc()

    if stem == "background":
        background_audio, _ = process_uvr_task(
            orig_song_path=media_file,
            song_id=hash_audio + "voiceless",
            only_voiceless=True,
            remove_files_output_dir=False,
        )

        if background_effects:
            out_effects_path = _effects_output_path(background_audio)
            logger.info(f"Effects output: {out_effects_path}")
            add_instrumental_effects(
                background_audio,
                out_effects_path,
                highpass_freq=background_highpass_freq,
                lowpass_freq=background_lowpass_freq,
                reverb_room_size=background_reverb_room_size,
                reverb_damping=background_reverb_damping,
                reverb_wet_level=background_reverb_wet_level,
                compressor_threshold_db=background_compressor_threshold_db,
                compressor_ratio=background_compressor_ratio,
                compressor_attack_ms=background_compressor_attack_ms,
                compressor_release_ms=background_compressor_release_ms,
                gain_db=background_gain_db,
            )
            background_audio = out_effects_path

        outputs.append(background_audio)

    end_time = time.time()
    execution_time = end_time - start_time
    logger.info(f"Execution time: {execution_time} seconds")

    if not outputs:
        raise Exception("Error in sound separation.")

    return outputs
933
+
934
+
935
def audio_downloader(
    url_media,
):
    """Download the audio of *url_media* as an ``.m4a`` file in ``downloads``.

    Args:
        url_media: URL of the media to fetch; surrounding whitespace is ignored.

    Returns:
        The path of the downloaded file, or ``None`` when no URL was given.
    """
    # Treat a missing value like an empty textbox instead of crashing on .strip().
    url_media = (url_media or "").strip()

    if not url_media:
        return None

    print(url_media[:10])

    dir_output_downloads = "downloads"
    os.makedirs(dir_output_downloads, exist_ok=True)

    # First pass only reads the metadata; the context manager releases the
    # YoutubeDL instance once the title is known.
    probe_opts = {"quiet": True, "no_warnings": True, "noplaylist": True}
    with yt_dlp.YoutubeDL(probe_opts) as ydl_probe:
        media_info = ydl_probe.extract_info(url_media, download=False)

    # The title is supplied by the remote site: keep it a single path
    # component so the file cannot land outside the downloads folder.
    safe_title = media_info["title"].replace("/", "_").replace("\\", "_")
    download_path = f"{os.path.join(dir_output_downloads, safe_title)}.m4a"

    ydl_opts = {
        'format': 'm4a/bestaudio/best',
        'postprocessors': [{  # Extract audio using ffmpeg
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'm4a',
        }],
        'force_overwrites': True,
        'noplaylist': True,
        'no_warnings': True,
        'quiet': True,
        'ignore_no_formats_error': True,
        'restrictfilenames': True,
        # The path is a literal, not a template: a "%" coming from the title
        # must be doubled so it is not parsed as a template field.
        'outtmpl': download_path.replace("%", "%%"),
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl_download:
        ydl_download.download([url_media])

    return download_path
972
+
973
+
974
def downloader_conf():
    """Create the unchecked "URL-to-Audio" checkbox, rendered without a container."""
    checkbox = gr.Checkbox(value=False, label="URL-to-Audio", container=False)
    return checkbox
981
+
982
+
983
def url_media_conf():
    """Create the single-line "Enter URL" textbox; it starts empty and hidden."""
    options = {
        "value": "",
        "label": "Enter URL",
        "placeholder": "www.youtube.com/watch?v=g_9rPvbENUw",
        "visible": False,
        "lines": 1,
    }
    return gr.Textbox(**options)
991
+
992
+
993
def url_button_conf():
    """Create the secondary "Go" button; it starts hidden."""
    button = gr.Button(value="Go", variant="secondary", visible=False)
    return button
999
+
1000
+
1001
def show_components_downloader(value_active):
    """Return two visibility updates that both follow *value_active*."""
    first_update = gr.update(visible=value_active)
    second_update = gr.update(visible=value_active)
    return first_update, second_update
1007
+
1008
+
1009
def audio_conf():
    """Create the "Audio file" upload component, which yields a file path."""
    options = {
        "label": "Audio file",
        "type": "filepath",
        "container": True,
    }
    return gr.File(**options)
1016
+
1017
+
1018
def stem_conf():
    """Create the "Stem" radio group offering "vocal" (preselected) and "background"."""
    stem_choices = ["vocal", "background"]
    return gr.Radio(choices=stem_choices, value="vocal", label="Stem")
1025
+
1026
+
1027
def main_conf():
    """Create the unchecked "Main" checkbox."""
    checkbox = gr.Checkbox(value=False, label="Main")
    return checkbox
1033
+
1034
+
1035
def dereverb_conf():
    """Create the unchecked, visible "Dereverb" checkbox."""
    checkbox = gr.Checkbox(value=False, label="Dereverb", visible=True)
    return checkbox
1042
+
1043
+
1044
def vocal_effects_conf():
    """Create the unchecked, visible "Vocal Effects" checkbox."""
    checkbox = gr.Checkbox(value=False, label="Vocal Effects", visible=True)
    return checkbox
1051
+
1052
+
1053
def background_effects_conf():
    """Create the unchecked "Background Effects" checkbox; it starts hidden."""
    checkbox = gr.Checkbox(value=False, label="Background Effects", visible=False)
    return checkbox
1060
+
1061
+
1062
def vocal_reverb_room_size_conf():
    """Create the "Vocal Reverb Room Size" field (default 0.15, range 0.0-1.0, step 0.05)."""
    field = gr.Number(
        value=0.15,
        label="Vocal Reverb Room Size",
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        visible=True,
    )
    return field
1071
+
1072
+
1073
def vocal_reverb_damping_conf():
    """Create the "Vocal Reverb Damping" field (default 0.7, range 0.0-1.0, step 0.01)."""
    field = gr.Number(
        value=0.7,
        label="Vocal Reverb Damping",
        minimum=0.0,
        maximum=1.0,
        step=0.01,
        visible=True,
    )
    return field
1082
+
1083
+
1084
def vocal_reverb_wet_level_conf():
    """Create the "Vocal Reverb Wet Level" field (default 0.2, range 0.0-1.0, step 0.05)."""
    field = gr.Number(
        value=0.2,
        label="Vocal Reverb Wet Level",
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        visible=True,
    )
    return field
1093
+
1094
+
1095
def vocal_reverb_dryness_level_conf():
    """Create the "Vocal Reverb Dryness Level" field (default 0.8, range 0.0-1.0, step 0.05)."""
    field = gr.Number(
        value=0.8,
        label="Vocal Reverb Dryness Level",
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        visible=True,
    )
    return field
1104
+
1105
+
1106
def vocal_delay_seconds_conf():
    """Create the "Vocal Delay Seconds" field (default 0.0, range 0.0-1.0, step 0.01)."""
    field = gr.Number(
        value=0.,
        label="Vocal Delay Seconds",
        minimum=0.0,
        maximum=1.0,
        step=0.01,
        visible=True,
    )
    return field
1115
+
1116
+
1117
def vocal_delay_mix_conf():
    """Create the "Vocal Delay Mix" field (default 0.0, range 0.0-1.0, step 0.01)."""
    field = gr.Number(
        value=0.,
        label="Vocal Delay Mix",
        minimum=0.0,
        maximum=1.0,
        step=0.01,
        visible=True,
    )
    return field
1126
+
1127
+
1128
def vocal_compressor_threshold_db_conf():
    """Create the "Vocal Compressor Threshold (dB)" field (default -15, range -60 to 0, step 1)."""
    field = gr.Number(
        value=-15,
        label="Vocal Compressor Threshold (dB)",
        minimum=-60,
        maximum=0,
        step=1,
        visible=True,
    )
    return field
1137
+
1138
+
1139
def vocal_compressor_ratio_conf():
    """Create the "Vocal Compressor Ratio" field (default 4.0, range 0-20, step 0.1)."""
    field = gr.Number(
        value=4.,
        label="Vocal Compressor Ratio",
        minimum=0,
        maximum=20,
        step=0.1,
        visible=True,
    )
    return field
1148
+
1149
+
1150
def vocal_compressor_attack_ms_conf():
    """Create the "Vocal Compressor Attack (ms)" field (default 1.0, range 0-1000, step 1)."""
    field = gr.Number(
        value=1.0,
        label="Vocal Compressor Attack (ms)",
        minimum=0,
        maximum=1000,
        step=1,
        visible=True,
    )
    return field
1159
+
1160
+
1161
def vocal_compressor_release_ms_conf():
    """Create the "Vocal Compressor Release (ms)" field (default 100, range 0-3000, step 1)."""
    field = gr.Number(
        value=100,
        label="Vocal Compressor Release (ms)",
        minimum=0,
        maximum=3000,
        step=1,
        visible=True,
    )
    return field
1170
+
1171
+
1172
def vocal_gain_db_conf():
    """Create the "Vocal Gain (dB)" field (default 0, range -40 to 40, step 1)."""
    field = gr.Number(
        value=0,
        label="Vocal Gain (dB)",
        minimum=-40,
        maximum=40,
        step=1,
        visible=True,
    )
    return field
1181
+
1182
+
1183
def background_highpass_freq_conf():
    """Create the "Background Highpass Frequency (Hz)" field (default 120, range 0-1000, step 1)."""
    field = gr.Number(
        value=120,
        label="Background Highpass Frequency (Hz)",
        minimum=0,
        maximum=1000,
        step=1,
        visible=True,
    )
    return field
1192
+
1193
+
1194
def background_lowpass_freq_conf():
    """Create the "Background Lowpass Frequency (Hz)" field (default 11000, range 0-20000, step 1)."""
    field = gr.Number(
        value=11000,
        label="Background Lowpass Frequency (Hz)",
        minimum=0,
        maximum=20000,
        step=1,
        visible=True,
    )
    return field
1203
+
1204
+
1205
def background_reverb_room_size_conf():
    """Create the "Background Reverb Room Size" field (default 0.1, range 0.0-1.0, step 0.1)."""
    field = gr.Number(
        value=0.1,
        label="Background Reverb Room Size",
        minimum=0.0,
        maximum=1.0,
        step=0.1,
        visible=True,
    )
    return field
1214
+
1215
+
1216
def background_reverb_damping_conf():
    """Create the "Background Reverb Damping" field (default 0.5, range 0.0-1.0, step 0.1)."""
    field = gr.Number(
        value=0.5,
        label="Background Reverb Damping",
        minimum=0.0,
        maximum=1.0,
        step=0.1,
        visible=True,
    )
    return field
1225
+
1226
+
1227
def background_reverb_wet_level_conf():
    """Create the "Background Reverb Wet Level" field (default 0.25, range 0.0-1.0, step 0.05)."""
    field = gr.Number(
        value=0.25,
        label="Background Reverb Wet Level",
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        visible=True,
    )
    return field
1236
+
1237
+
1238
def background_compressor_threshold_db_conf():
    """Create the "Background Compressor Threshold (dB)" field (default -15, range -60 to 0, step 1)."""
    field = gr.Number(
        value=-15,
        label="Background Compressor Threshold (dB)",
        minimum=-60,
        maximum=0,
        step=1,
        visible=True,
    )
    return field
1247
+
1248
+
1249
def background_compressor_ratio_conf():
    """Create the "Background Compressor Ratio" field (default 4.0, range 0-20, step 0.1)."""
    field = gr.Number(
        value=4.,
        label="Background Compressor Ratio",
        minimum=0,
        maximum=20,
        step=0.1,
        visible=True,
    )
    return field
1258
+
1259
+
1260
def background_compressor_attack_ms_conf():
    """Create the "Background Compressor Attack (ms)" field (default 15, range 0-1000, step 1)."""
    field = gr.Number(
        value=15,
        label="Background Compressor Attack (ms)",
        minimum=0,
        maximum=1000,
        step=1,
        visible=True,
    )
    return field
1269
+
1270
+
1271
def background_compressor_release_ms_conf():
    """Create the "Background Compressor Release (ms)" field (default 60, range 0-3000, step 1)."""
    field = gr.Number(
        value=60,
        label="Background Compressor Release (ms)",
        minimum=0,
        maximum=3000,
        step=1,
        visible=True,
    )
    return field
1280
+
1281
+
1282
def background_gain_db_conf():
    """Create the "Background Gain (dB)" field (default 0, range -40 to 40, step 1)."""
    field = gr.Number(
        value=0,
        label="Background Gain (dB)",
        minimum=-40,
        maximum=40,
        step=1,
        visible=True,
    )
    return field
1291
+
1292
+
1293
def button_conf():
    """Create the primary "Inference" button."""
    button = gr.Button(value="Inference", variant="primary")
    return button
1298
+
1299
+
1300
def output_conf():
    """Create the read-only "Result" file component that accepts multiple files."""
    options = {
        "label": "Result",
        "file_count": "multiple",
        "interactive": False,
    }
    return gr.File(**options)
1306
+
1307
+
1308
def show_vocal_components(value_name):
    """Return four visibility updates driven by the selected stem.

    The first three updates are visible only when *value_name* is ``"vocal"``;
    the fourth is visible for every other value.
    """
    is_vocal = value_name == "vocal"
    return (
        gr.update(visible=is_vocal),
        gr.update(visible=is_vocal),
        gr.update(visible=is_vocal),
        gr.update(visible=not is_vocal),
    )
1322
+
1323
+
1324
def get_gui(theme):
    """Assemble the Gradio interface of the separation app.

    Args:
        theme: Value forwarded to ``gr.Blocks(theme=...)``.

    Returns:
        The constructed ``gr.Blocks`` application (not yet launched).
    """
    with gr.Blocks(theme=theme) as app:
        gr.Markdown(title)
        gr.Markdown(description)

        # Optional URL downloader: the checkbox toggles the URL box and button.
        downloader_gui = downloader_conf()
        with gr.Row():
            with gr.Column(scale=2):
                url_media_gui = url_media_conf()
            with gr.Column(scale=1):
                url_button_gui = url_button_conf()

        downloader_gui.change(
            show_components_downloader,
            [downloader_gui],
            [url_media_gui, url_button_gui]
        )

        aud = audio_conf()

        # A downloaded file is placed into the audio input component.
        url_button_gui.click(
            audio_downloader,
            [url_media_gui],
            [aud]
        )

        with gr.Column():
            with gr.Row():
                stem_gui = stem_conf()

        with gr.Column():
            with gr.Row():
                main_gui = main_conf()
                dereverb_gui = dereverb_conf()
                vocal_effects_gui = vocal_effects_conf()
                background_effects_gui = background_effects_conf()

        with gr.Accordion("Vocal Effects Parameters", open=False):
            with gr.Row():
                vocal_reverb_room_size_gui = vocal_reverb_room_size_conf()
                vocal_reverb_damping_gui = vocal_reverb_damping_conf()
                vocal_reverb_dryness_gui = vocal_reverb_dryness_level_conf()
                vocal_reverb_wet_level_gui = vocal_reverb_wet_level_conf()
                vocal_delay_seconds_gui = vocal_delay_seconds_conf()
                vocal_delay_mix_gui = vocal_delay_mix_conf()
                vocal_compressor_threshold_db_gui = vocal_compressor_threshold_db_conf()
                vocal_compressor_ratio_gui = vocal_compressor_ratio_conf()
                vocal_compressor_attack_ms_gui = vocal_compressor_attack_ms_conf()
                vocal_compressor_release_ms_gui = vocal_compressor_release_ms_conf()
                vocal_gain_db_gui = vocal_gain_db_conf()

        with gr.Accordion("Background Effects Parameters", open=False):
            with gr.Row():
                background_highpass_freq_gui = background_highpass_freq_conf()
                background_lowpass_freq_gui = background_lowpass_freq_conf()
                background_reverb_room_size_gui = background_reverb_room_size_conf()
                background_reverb_damping_gui = background_reverb_damping_conf()
                background_reverb_wet_level_gui = background_reverb_wet_level_conf()
                background_compressor_threshold_db_gui = background_compressor_threshold_db_conf()
                background_compressor_ratio_gui = background_compressor_ratio_conf()
                background_compressor_attack_ms_gui = background_compressor_attack_ms_conf()
                background_compressor_release_ms_gui = background_compressor_release_ms_conf()
                background_gain_db_gui = background_gain_db_conf()

        # Switching the stem shows the options that apply to it.
        stem_gui.change(
            show_vocal_components,
            [stem_gui],
            [main_gui, dereverb_gui, vocal_effects_gui, background_effects_gui],
        )

        button_base = button_conf()
        output_base = output_conf()

        # Components in the positional order of sound_separate's parameters;
        # shared by the button and the examples so the two cannot drift apart.
        separation_inputs = [
            aud,
            stem_gui,
            main_gui,
            dereverb_gui,
            vocal_effects_gui,
            background_effects_gui,
            vocal_reverb_room_size_gui, vocal_reverb_damping_gui, vocal_reverb_dryness_gui, vocal_reverb_wet_level_gui,
            vocal_delay_seconds_gui, vocal_delay_mix_gui, vocal_compressor_threshold_db_gui, vocal_compressor_ratio_gui,
            vocal_compressor_attack_ms_gui, vocal_compressor_release_ms_gui, vocal_gain_db_gui,
            background_highpass_freq_gui, background_lowpass_freq_gui, background_reverb_room_size_gui,
            background_reverb_damping_gui, background_reverb_wet_level_gui, background_compressor_threshold_db_gui,
            background_compressor_ratio_gui, background_compressor_attack_ms_gui, background_compressor_release_ms_gui,
            background_gain_db_gui,
        ]

        button_base.click(
            sound_separate,
            inputs=separation_inputs,
            outputs=[output_base],
        )

        gr.Examples(
            examples=[
                [
                    "./test.mp3",
                    "vocal",
                    False,
                    False,
                    False,
                    False,
                    0.15, 0.7, 0.8, 0.2,
                    0., 0., -15, 4., 1, 100, 0,
                    # Room size 0.1 and damping 0.5, in the order of
                    # separation_inputs and matching the component defaults.
                    120, 11000, 0.1, 0.5, 0.25, -15, 4., 15, 60, 0,
                ],
            ],
            fn=sound_separate,
            inputs=separation_inputs,
            outputs=[output_base],
            cache_examples=False,
        )

    return app
1455
+
1456
+
1457
if __name__ == "__main__":

    # Fetch every model listed in UVR_MODELS into the models directory.
    for id_model in UVR_MODELS:
        model_link = os.path.join(MDX_DOWNLOAD_LINK, id_model)
        download_manager(model_link, mdxnet_models_dir)

    app = get_gui(theme)

    app.queue(default_concurrency_limit=40)

    launch_options = {
        "max_threads": 40,
        "share": False,
        "show_error": True,
        "quiet": False,
        "debug": False,
    }
    app.launch(**launch_options)