next-playground committed
Commit 16734a7 · verified · 1 Parent(s): bf73350

Upload 2 files

Files changed (2):
  1. requirements.txt +10 -0
  2. separate.py +214 -0
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch==2.0.1
+ audioread==3.0.0
+ librosa==0.10.0.post2
+ onnx==1.14.0
+ onnxruntime==1.15.0
+ pydub==0.25.1
+ soundstretch==1.2
+ tqdm==4.65.0
+ Pillow==9.5.0
+ resampy==0.4.2
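
To set up an environment for this commit, the usual install should suffice (a sketch, assuming a Python version compatible with these pins):

    pip install -r requirements.txt

Note that separate.py also imports soundfile, which is not pinned here; librosa 0.10 declares soundfile as a dependency, so it should be pulled in transitively.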
separate.py ADDED
@@ -0,0 +1,214 @@
+ import soundfile as sf
+ import torch
+ import os
+ import librosa
+ import numpy as np
+ import onnxruntime as ort
+ from pathlib import Path
+ from argparse import ArgumentParser
+ from tqdm import tqdm
+
+
+ class ConvTDFNet:
+     def __init__(self, target_name, L, dim_f, dim_t, n_fft, hop=1024):
+         self.dim_c = 4
+         self.dim_f = dim_f
+         self.dim_t = 2**dim_t
+         self.n_fft = n_fft
+         self.hop = hop
+         self.n_bins = self.n_fft // 2 + 1
+         self.chunk_size = hop * (self.dim_t - 1)
+         self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
+         self.target_name = target_name
+
+         out_c = self.dim_c * 4 if target_name == "*" else self.dim_c
+
+         self.freq_pad = torch.zeros([1, out_c, self.n_bins - self.dim_f, self.dim_t])
+         self.n = L // 2
+
+     # Short-time Fourier transform (STFT); returns a real-valued
+     # [batch, dim_c, dim_f, dim_t] spectrogram for the ONNX model.
+     def stft(self, x):
+         x = x.reshape([-1, self.chunk_size])
+         x = torch.stft(
+             x,
+             n_fft=self.n_fft,
+             hop_length=self.hop,
+             window=self.window,
+             center=True,
+             return_complex=True,
+         )
+         x = torch.view_as_real(x)
+         x = x.permute([0, 3, 1, 2])
+         x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
+             [-1, self.dim_c, self.n_bins, self.dim_t]
+         )
+         return x[:, :, : self.dim_f]
+
+     # Inverse short-time Fourier transform (iSTFT).
+     def istft(self, x, freq_pad=None):
+         freq_pad = (
+             self.freq_pad.repeat([x.shape[0], 1, 1, 1])
+             if freq_pad is None
+             else freq_pad
+         )
+         x = torch.cat([x, freq_pad], -2)
+         c = 4 * 2 if self.target_name == "*" else 2
+         x = x.reshape([-1, c, 2, self.n_bins, self.dim_t]).reshape(
+             [-1, 2, self.n_bins, self.dim_t]
+         )
+         x = x.permute([0, 2, 3, 1])
+         x = x.contiguous()
+         x = torch.view_as_complex(x)
+         x = torch.istft(
+             x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True
+         )
+         return x.reshape([-1, c, self.chunk_size])
+
+ class Predictor:
+     def __init__(self, args):
+         self.args = args
+         self.model_ = ConvTDFNet(
+             target_name="vocals",
+             L=11,
+             dim_f=args["dim_f"],
+             dim_t=args["dim_t"],
+             n_fft=args["n_fft"],
+         )
+
+         # Prefer CUDA when available, otherwise fall back to CPU; pass the
+         # model path as a string for onnxruntime.
+         provider = "CUDAExecutionProvider" if torch.cuda.is_available() else "CPUExecutionProvider"
+         self.model = ort.InferenceSession(str(args["model_path"]), providers=[provider])
+
+     # Split the mix into overlapping chunks so long files fit in memory.
+     def demix(self, mix):
+         samples = mix.shape[-1]
+         margin = self.args["margin"]
+         chunk_size = self.args["chunks"] * 44100
+
+         assert margin != 0, "margin cannot be zero!"
+
+         if margin > chunk_size:
+             margin = chunk_size
+
+         segmented_mix = {}
+
+         if self.args["chunks"] == 0 or samples < chunk_size:
+             chunk_size = samples
+
+         counter = -1
+         for skip in range(0, samples, chunk_size):
+             counter += 1
+             s_margin = 0 if counter == 0 else margin
+             end = min(skip + chunk_size + margin, samples)
+             start = skip - s_margin
+             segmented_mix[skip] = mix[:, start:end].copy()
+             if end == samples:
+                 break
+
+         sources = self.demix_base(segmented_mix, margin_size=margin)
+         return sources
+
+     def demix_base(self, mixes, margin_size):
+         chunked_sources = []
+         progress_bar = tqdm(total=len(mixes))
+         progress_bar.set_description("Processing")
+
+         for mix in mixes:
+             cmix = mixes[mix]
+             sources = []
+             n_sample = cmix.shape[1]
+             model = self.model_
+             trim = model.n_fft // 2
+             gen_size = model.chunk_size - 2 * trim
+             pad = gen_size - n_sample % gen_size
+             # Zero-pad so the chunk divides evenly into model-sized windows.
+             mix_p = np.concatenate(
+                 (np.zeros((2, trim)), cmix, np.zeros((2, pad)), np.zeros((2, trim))), 1
+             )
+             mix_waves = []
+             i = 0
+             while i < n_sample + pad:
+                 waves = np.array(mix_p[:, i : i + model.chunk_size])
+                 mix_waves.append(waves)
+                 i += gen_size
+
+             mix_waves = torch.tensor(np.array(mix_waves), dtype=torch.float32)
+
+             with torch.no_grad():
+                 _ort = self.model
+                 spek = model.stft(mix_waves)
+                 if self.args["denoise"]:
+                     # Denoise trick: run the model on the spectrogram and on
+                     # its negation and average; the target (which flips sign
+                     # with the input) survives while noise that does not
+                     # flip sign cancels out.
+                     spec_pred = (
+                         -_ort.run(None, {"input": -spek.cpu().numpy()})[0] * 0.5
+                         + _ort.run(None, {"input": spek.cpu().numpy()})[0] * 0.5
+                     )
+                     tar_waves = model.istft(torch.tensor(spec_pred))
+                 else:
+                     tar_waves = model.istft(
+                         torch.tensor(_ort.run(None, {"input": spek.cpu().numpy()})[0])
+                     )
+                 tar_signal = (
+                     tar_waves[:, :, trim:-trim]
+                     .transpose(0, 1)
+                     .reshape(2, -1)
+                     .numpy()[:, :-pad]
+                 )
+
+                 # Drop the overlap margins so stitched chunks do not repeat audio.
+                 start = 0 if mix == 0 else margin_size
+                 end = None if mix == list(mixes.keys())[-1] else -margin_size
+
+                 if margin_size == 0:
+                     end = None
+
+                 sources.append(tar_signal[:, start:end])
+
+                 progress_bar.update(1)
+
+             chunked_sources.append(sources)
+         _sources = np.concatenate(chunked_sources, axis=-1)
+
+         progress_bar.close()
+         return _sources
+
+     def predict(self, file_path):
+         mix, rate = librosa.load(file_path, mono=False, sr=44100)
+
+         if mix.ndim == 1:
+             mix = np.asfortranarray([mix, mix])
+
+         mix = mix.T
+         sources = self.demix(mix.T)
+         opt = sources[0].T
+
+         # With a vocals-target model (as the hard-coded target_name above
+         # declares), opt is the vocal stem and the residual (mix - opt) is
+         # the accompaniment.
+         return (mix - opt, opt, rate)
+
+ def main():
+     parser = ArgumentParser()
+
+     parser.add_argument("files", nargs="+", type=Path, default=[], help="Source audio path")
+     parser.add_argument("-o", "--output", type=Path, default=Path("separated"), help="Output folder")
+     parser.add_argument("-m", "--model_path", type=Path, required=True, help="MDX Net ONNX Model path")
+
+     parser.add_argument("-d", "--no-denoise", dest="denoise", action="store_false", default=True, help="Disable denoising")
+     parser.add_argument("-M", "--margin", type=int, default=44100, help="Overlap margin between chunks, in samples")
+     parser.add_argument("-c", "--chunks", type=int, default=15, help="Chunk length in seconds (0 disables chunking)")
+     parser.add_argument("-F", "--n_fft", type=int, default=6144)
+     parser.add_argument("-t", "--dim_t", type=int, default=8)
+     parser.add_argument("-f", "--dim_f", type=int, default=2048)
+
+     args = parser.parse_args()
+     dict_args = vars(args)
+
+     os.makedirs(args.output, exist_ok=True)
+
+     # One predictor (and one ONNX session) can serve every input file.
+     predictor = Predictor(args=dict_args)
+     for file_path in args.files:
+         # predict() returns (accompaniment, vocals, rate), so unpack in that
+         # order to land each stem in the right output file.
+         no_vocals, vocals, sampling_rate = predictor.predict(file_path)
+         filename = os.path.splitext(os.path.split(file_path)[-1])[0]
+         sf.write(os.path.join(args.output, filename + "_no_vocals.wav"), no_vocals, sampling_rate)
+         sf.write(os.path.join(args.output, filename + "_vocals.wav"), vocals, sampling_rate)
+
+
+ if __name__ == "__main__":
+     main()
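
For reference, a typical invocation of the script above (the model filename is a placeholder; point -m at whichever MDX-Net vocal ONNX model you use, with -F/-t/-f matching how that model was trained):

    python separate.py -m UVR_MDXNET_Main.onnx -o separated song.mp3

This writes song_vocals.wav and song_no_vocals.wav into the separated/ folder.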