Zeph27 committed on
Commit
bce4e8a
·
1 Parent(s): 15deb82

add print debug

Browse files
Files changed (1) hide show
  1. src/mdx.py +42 -71
src/mdx.py CHANGED
@@ -12,15 +12,13 @@ import soundfile as sf
12
  import torch
13
  from tqdm import tqdm
14
 
15
- import re
16
- import random
17
-
18
  warnings.filterwarnings("ignore")
19
  stem_naming = {'Vocals': 'Instrumental', 'Other': 'Instruments', 'Instrumental': 'Vocals', 'Drums': 'Drumless', 'Bass': 'Bassless'}
20
 
21
 
22
  class MDXModel:
23
  def __init__(self, device, dim_f, dim_t, n_fft, hop=1024, stem_name=None, compensation=1.000):
 
24
  self.dim_f = dim_f
25
  self.dim_t = dim_t
26
  self.dim_c = 4
@@ -36,89 +34,80 @@ class MDXModel:
36
  out_c = self.dim_c
37
 
38
  self.freq_pad = torch.zeros([1, out_c, self.n_bins - self.dim_f, self.dim_t]).to(device)
 
39
 
40
  def stft(self, x):
 
41
  x = x.reshape([-1, self.chunk_size])
42
  x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True, return_complex=True)
43
  x = torch.view_as_real(x)
44
  x = x.permute([0, 3, 1, 2])
45
  x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, 4, self.n_bins, self.dim_t])
 
46
  return x[:, :, :self.dim_f]
47
 
48
  def istft(self, x, freq_pad=None):
 
49
  freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1]) if freq_pad is None else freq_pad
50
  x = torch.cat([x, freq_pad], -2)
51
- # c = 4*2 if self.target_name=='*' else 2
52
  x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, 2, self.n_bins, self.dim_t])
53
  x = x.permute([0, 2, 3, 1])
54
  x = x.contiguous()
55
  x = torch.view_as_complex(x)
56
  x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
 
57
  return x.reshape([-1, 2, self.chunk_size])
58
 
59
 
60
  class MDX:
61
  DEFAULT_SR = 44100
62
- # Unit: seconds
63
  DEFAULT_CHUNK_SIZE = 0 * DEFAULT_SR
64
  DEFAULT_MARGIN_SIZE = 1 * DEFAULT_SR
65
 
66
  DEFAULT_PROCESSOR = 0
67
 
68
  def __init__(self, model_path: str, params: MDXModel, processor=DEFAULT_PROCESSOR):
69
-
70
- # Set the device and the provider (CPU or CUDA)
71
  self.device = torch.device(f'cuda:{processor}') if processor >= 0 else torch.device('cpu')
72
  self.provider = ['CUDAExecutionProvider'] if processor >= 0 else ['CPUExecutionProvider']
73
 
74
  self.model = params
75
 
76
- # Load the ONNX model using ONNX Runtime
77
  self.ort = ort.InferenceSession(model_path, providers=self.provider)
78
- # Preload the model for faster performance
79
  self.ort.run(None, {'input': torch.rand(1, 4, params.dim_f, params.dim_t).numpy()})
80
  self.process = lambda spec: self.ort.run(None, {'input': spec.cpu().numpy()})[0]
81
 
82
  self.prog = None
 
83
 
84
  @staticmethod
85
  def get_hash(model_path):
 
86
  try:
87
  with open(model_path, 'rb') as f:
88
  f.seek(- 10000 * 1024, 2)
89
  model_hash = hashlib.md5(f.read()).hexdigest()
90
  except:
91
  model_hash = hashlib.md5(open(model_path, 'rb').read()).hexdigest()
92
-
93
  return model_hash
94
 
95
  @staticmethod
96
  def segment(wave, combine=True, chunk_size=DEFAULT_CHUNK_SIZE, margin_size=DEFAULT_MARGIN_SIZE):
97
- """
98
- Segment or join segmented wave array
99
-
100
- Args:
101
- wave: (np.array) Wave array to be segmented or joined
102
- combine: (bool) If True, combines segmented wave array. If False, segments wave array.
103
- chunk_size: (int) Size of each segment (in samples)
104
- margin_size: (int) Size of margin between segments (in samples)
105
-
106
- Returns:
107
- numpy array: Segmented or joined wave array
108
- """
109
-
110
  if combine:
111
- processed_wave = None # Initializing as None instead of [] for later numpy array concatenation
112
  for segment_count, segment in enumerate(wave):
113
  start = 0 if segment_count == 0 else margin_size
114
  end = None if segment_count == len(wave) - 1 else -margin_size
115
  if margin_size == 0:
116
  end = None
117
- if processed_wave is None: # Create array for first segment
118
  processed_wave = segment[:, start:end]
119
- else: # Concatenate to existing array for subsequent segments
120
  processed_wave = np.concatenate((processed_wave, segment[:, start:end]), axis=-1)
121
-
122
  else:
123
  processed_wave = []
124
  sample_count = wave.shape[-1]
@@ -130,7 +119,6 @@ class MDX:
130
  margin_size = chunk_size
131
 
132
  for segment_count, skip in enumerate(range(0, sample_count, chunk_size)):
133
-
134
  margin = 0 if segment_count == 0 else margin_size
135
  end = min(skip + chunk_size + margin_size, sample_count)
136
  start = skip - margin
@@ -140,28 +128,16 @@ class MDX:
140
 
141
  if end == sample_count:
142
  break
143
-
144
  return processed_wave
145
 
146
  def pad_wave(self, wave):
147
- """
148
- Pad the wave array to match the required chunk size
149
-
150
- Args:
151
- wave: (np.array) Wave array to be padded
152
-
153
- Returns:
154
- tuple: (padded_wave, pad, trim)
155
- - padded_wave: Padded wave array
156
- - pad: Number of samples that were padded
157
- - trim: Number of samples that were trimmed
158
- """
159
  n_sample = wave.shape[1]
160
  trim = self.model.n_fft // 2
161
  gen_size = self.model.chunk_size - 2 * trim
162
  pad = gen_size - n_sample % gen_size
163
 
164
- # Padded wave
165
  wave_p = np.concatenate((np.zeros((2, trim)), wave, np.zeros((2, pad)), np.zeros((2, trim))), 1)
166
 
167
  mix_waves = []
@@ -170,23 +146,11 @@ class MDX:
170
  mix_waves.append(waves)
171
 
172
  mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(self.device)
173
-
174
  return mix_waves, pad, trim
175
 
176
  def _process_wave(self, mix_waves, trim, pad, q: queue.Queue, _id: int):
177
- """
178
- Process each wave segment in a multi-threaded environment
179
-
180
- Args:
181
- mix_waves: (torch.Tensor) Wave segments to be processed
182
- trim: (int) Number of samples trimmed during padding
183
- pad: (int) Number of samples padded during padding
184
- q: (queue.Queue) Queue to hold the processed wave segments
185
- _id: (int) Identifier of the processed wave segment
186
-
187
- Returns:
188
- numpy array: Processed wave segment
189
- """
190
  mix_waves = mix_waves.split(1)
191
  with torch.no_grad():
192
  pw = []
@@ -199,24 +163,15 @@ class MDX:
199
  pw.append(processed_wav)
200
  processed_signal = np.concatenate(pw, axis=-1)[:, :-pad]
201
  q.put({_id: processed_signal})
 
202
  return processed_signal
203
 
204
  def process_wave(self, wave: np.array, mt_threads=1):
205
- """
206
- Process the wave array in a multi-threaded environment
207
-
208
- Args:
209
- wave: (np.array) Wave array to be processed
210
- mt_threads: (int) Number of threads to be used for processing
211
-
212
- Returns:
213
- numpy array: Processed wave array
214
- """
215
  self.prog = tqdm(total=0)
216
  chunk = wave.shape[-1] // mt_threads
217
  waves = self.segment(wave, False, chunk)
218
 
219
- # Create a queue to hold the processed wave segments
220
  q = queue.Queue()
221
  threads = []
222
  for c, batch in enumerate(waves):
@@ -235,15 +190,18 @@ class MDX:
235
  processed_batches = [list(wave.values())[0] for wave in
236
  sorted(processed_batches, key=lambda d: list(d.keys())[0])]
237
  assert len(processed_batches) == len(waves), 'Incomplete processed batches, please reduce batch size!'
 
238
  return self.segment(processed_batches, True, chunk)
239
 
240
 
241
  def run_mdx(model_params, output_dir, model_path, filename, exclude_main=False, exclude_inversion=False, suffix=None, invert_suffix=None, denoise=False, keep_orig=True, m_threads=2):
 
242
  device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
243
 
244
  device_properties = torch.cuda.get_device_properties(device)
245
  vram_gb = device_properties.total_memory / 1024**3
246
  m_threads = 1 if vram_gb < 8 else 2
 
247
 
248
  model_hash = MDX.get_hash(model_path)
249
  mp = model_params.get(model_hash)
@@ -257,22 +215,25 @@ def run_mdx(model_params, output_dir, model_path, filename, exclude_main=False,
257
  )
258
 
259
  mdx_sess = MDX(model_path, model)
 
260
  wave, sr = librosa.load(filename, mono=False, sr=44100)
261
- # normalizing input wave gives better output
262
  peak = max(np.max(wave), abs(np.min(wave)))
263
  wave /= peak
264
  if denoise:
 
265
  wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (mdx_sess.process_wave(wave, m_threads))
266
  wave_processed *= 0.5
267
  else:
 
268
  wave_processed = mdx_sess.process_wave(wave, m_threads)
269
- # return to previous peak
270
  wave_processed *= peak
271
  stem_name = model.stem_name if suffix is None else suffix
272
 
273
  main_filepath = None
274
  if not exclude_main:
275
  main_filepath = os.path.join(output_dir, f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav")
 
276
  sf.write(main_filepath, wave_processed.T, sr)
277
 
278
  invert_filepath = None
@@ -280,29 +241,35 @@ def run_mdx(model_params, output_dir, model_path, filename, exclude_main=False,
280
  diff_stem_name = stem_naming.get(stem_name) if invert_suffix is None else invert_suffix
281
  stem_name = f"{stem_name}_diff" if diff_stem_name is None else diff_stem_name
282
  invert_filepath = os.path.join(output_dir, f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav")
 
283
  sf.write(invert_filepath, (-wave_processed.T * model.compensation) + wave.T, sr)
284
 
285
  if not keep_orig:
 
286
  os.remove(filename)
287
 
 
288
  del mdx_sess, wave_processed, wave
289
  if torch.cuda.is_available():
290
  torch.cuda.empty_cache()
291
  gc.collect()
 
292
  return main_filepath, invert_filepath
293
 
294
  def run_roformer(model_params, output_dir, model_name, filename, exclude_main=False, exclude_inversion=False, suffix=None, invert_suffix=None, denoise=False, keep_orig=True, m_threads=2):
 
295
  os.makedirs(output_dir, exist_ok=True)
296
 
297
- # Load and process the audio
298
  wave, sr = librosa.load(filename, mono=False, sr=44100)
299
  base_name = os.path.splitext(os.path.basename(filename))[0]
300
 
301
  roformer_output_format = 'wav'
302
  roformer_overlap = 4
303
  roformer_segment_size = 256
304
- print(f"output_dir: {output_dir}")
305
  prompt = f'audio-separator "{filename}" --model_filename {model_name} --output_dir="{output_dir}" --output_format={roformer_output_format} --normalization=0.9 --mdxc_overlap={roformer_overlap} --mdxc_segment_size={roformer_segment_size}'
 
306
  os.system(prompt)
307
 
308
  vocals_file = f"{base_name}_Vocals.wav"
@@ -314,14 +281,18 @@ def run_roformer(model_params, output_dir, model_name, filename, exclude_main=Fa
314
  if not exclude_main:
315
  main_filepath = os.path.join(output_dir, vocals_file)
316
  if os.path.exists(os.path.join(output_dir, f"{base_name}_(Vocals)_{model_name.replace('.9755.ckpt', '')}.wav")):
 
317
  os.rename(os.path.join(output_dir, f"{base_name}_(Vocals)_{model_name.replace('.9755.ckpt', '')}.wav"), main_filepath)
318
 
319
  if not exclude_inversion:
320
  invert_filepath = os.path.join(output_dir, instrumental_file)
321
  if os.path.exists(os.path.join(output_dir, f"{base_name}_(Instrumental)_{model_name.replace('.9755.ckpt', '')}.wav")):
 
322
  os.rename(os.path.join(output_dir, f"{base_name}_(Instrumental)_{model_name.replace('.9755.ckpt', '')}.wav"), invert_filepath)
323
 
324
  if not keep_orig:
 
325
  os.remove(filename)
326
 
 
327
  return main_filepath, invert_filepath
 
12
  import torch
13
  from tqdm import tqdm
14
 
 
 
 
15
  warnings.filterwarnings("ignore")
16
  stem_naming = {'Vocals': 'Instrumental', 'Other': 'Instruments', 'Instrumental': 'Vocals', 'Drums': 'Drumless', 'Bass': 'Bassless'}
17
 
18
 
19
  class MDXModel:
20
  def __init__(self, device, dim_f, dim_t, n_fft, hop=1024, stem_name=None, compensation=1.000):
21
+ print("[~] Initializing MDXModel...")
22
  self.dim_f = dim_f
23
  self.dim_t = dim_t
24
  self.dim_c = 4
 
34
  out_c = self.dim_c
35
 
36
  self.freq_pad = torch.zeros([1, out_c, self.n_bins - self.dim_f, self.dim_t]).to(device)
37
+ print("[+] MDXModel initialized")
38
 
39
  def stft(self, x):
40
+ print("[~] Performing STFT...")
41
  x = x.reshape([-1, self.chunk_size])
42
  x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True, return_complex=True)
43
  x = torch.view_as_real(x)
44
  x = x.permute([0, 3, 1, 2])
45
  x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, 4, self.n_bins, self.dim_t])
46
+ print("[+] STFT completed")
47
  return x[:, :, :self.dim_f]
48
 
49
  def istft(self, x, freq_pad=None):
50
+ print("[~] Performing inverse STFT...")
51
  freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1]) if freq_pad is None else freq_pad
52
  x = torch.cat([x, freq_pad], -2)
 
53
  x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, 2, self.n_bins, self.dim_t])
54
  x = x.permute([0, 2, 3, 1])
55
  x = x.contiguous()
56
  x = torch.view_as_complex(x)
57
  x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
58
+ print("[+] Inverse STFT completed")
59
  return x.reshape([-1, 2, self.chunk_size])
60
 
61
 
62
  class MDX:
63
  DEFAULT_SR = 44100
 
64
  DEFAULT_CHUNK_SIZE = 0 * DEFAULT_SR
65
  DEFAULT_MARGIN_SIZE = 1 * DEFAULT_SR
66
 
67
  DEFAULT_PROCESSOR = 0
68
 
69
  def __init__(self, model_path: str, params: MDXModel, processor=DEFAULT_PROCESSOR):
70
+ print("[~] Initializing MDX...")
 
71
  self.device = torch.device(f'cuda:{processor}') if processor >= 0 else torch.device('cpu')
72
  self.provider = ['CUDAExecutionProvider'] if processor >= 0 else ['CPUExecutionProvider']
73
 
74
  self.model = params
75
 
76
+ print(f"[~] Loading ONNX model from {model_path}...")
77
  self.ort = ort.InferenceSession(model_path, providers=self.provider)
78
+ print("[~] Preloading model...")
79
  self.ort.run(None, {'input': torch.rand(1, 4, params.dim_f, params.dim_t).numpy()})
80
  self.process = lambda spec: self.ort.run(None, {'input': spec.cpu().numpy()})[0]
81
 
82
  self.prog = None
83
+ print("[+] MDX initialized")
84
 
85
  @staticmethod
86
  def get_hash(model_path):
87
+ print(f"[~] Calculating hash for model: {model_path}")
88
  try:
89
  with open(model_path, 'rb') as f:
90
  f.seek(- 10000 * 1024, 2)
91
  model_hash = hashlib.md5(f.read()).hexdigest()
92
  except:
93
  model_hash = hashlib.md5(open(model_path, 'rb').read()).hexdigest()
94
+ print(f"[+] Model hash: {model_hash}")
95
  return model_hash
96
 
97
  @staticmethod
98
  def segment(wave, combine=True, chunk_size=DEFAULT_CHUNK_SIZE, margin_size=DEFAULT_MARGIN_SIZE):
99
+ print("[~] Segmenting wave...")
 
 
 
 
 
 
 
 
 
 
 
 
100
  if combine:
101
+ processed_wave = None
102
  for segment_count, segment in enumerate(wave):
103
  start = 0 if segment_count == 0 else margin_size
104
  end = None if segment_count == len(wave) - 1 else -margin_size
105
  if margin_size == 0:
106
  end = None
107
+ if processed_wave is None:
108
  processed_wave = segment[:, start:end]
109
+ else:
110
  processed_wave = np.concatenate((processed_wave, segment[:, start:end]), axis=-1)
 
111
  else:
112
  processed_wave = []
113
  sample_count = wave.shape[-1]
 
119
  margin_size = chunk_size
120
 
121
  for segment_count, skip in enumerate(range(0, sample_count, chunk_size)):
 
122
  margin = 0 if segment_count == 0 else margin_size
123
  end = min(skip + chunk_size + margin_size, sample_count)
124
  start = skip - margin
 
128
 
129
  if end == sample_count:
130
  break
131
+ print("[+] Wave segmentation completed")
132
  return processed_wave
133
 
134
  def pad_wave(self, wave):
135
+ print("[~] Padding wave...")
 
 
 
 
 
 
 
 
 
 
 
136
  n_sample = wave.shape[1]
137
  trim = self.model.n_fft // 2
138
  gen_size = self.model.chunk_size - 2 * trim
139
  pad = gen_size - n_sample % gen_size
140
 
 
141
  wave_p = np.concatenate((np.zeros((2, trim)), wave, np.zeros((2, pad)), np.zeros((2, trim))), 1)
142
 
143
  mix_waves = []
 
146
  mix_waves.append(waves)
147
 
148
  mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(self.device)
149
+ print(f"[+] Wave padded. Shape: {mix_waves.shape}")
150
  return mix_waves, pad, trim
151
 
152
  def _process_wave(self, mix_waves, trim, pad, q: queue.Queue, _id: int):
153
+ print(f"[~] Processing wave segment {_id}...")
 
 
 
 
 
 
 
 
 
 
 
 
154
  mix_waves = mix_waves.split(1)
155
  with torch.no_grad():
156
  pw = []
 
163
  pw.append(processed_wav)
164
  processed_signal = np.concatenate(pw, axis=-1)[:, :-pad]
165
  q.put({_id: processed_signal})
166
+ print(f"[+] Wave segment {_id} processed")
167
  return processed_signal
168
 
169
  def process_wave(self, wave: np.array, mt_threads=1):
170
+ print(f"[~] Processing wave with {mt_threads} threads...")
 
 
 
 
 
 
 
 
 
171
  self.prog = tqdm(total=0)
172
  chunk = wave.shape[-1] // mt_threads
173
  waves = self.segment(wave, False, chunk)
174
 
 
175
  q = queue.Queue()
176
  threads = []
177
  for c, batch in enumerate(waves):
 
190
  processed_batches = [list(wave.values())[0] for wave in
191
  sorted(processed_batches, key=lambda d: list(d.keys())[0])]
192
  assert len(processed_batches) == len(waves), 'Incomplete processed batches, please reduce batch size!'
193
+ print("[+] Wave processing completed")
194
  return self.segment(processed_batches, True, chunk)
195
 
196
 
197
  def run_mdx(model_params, output_dir, model_path, filename, exclude_main=False, exclude_inversion=False, suffix=None, invert_suffix=None, denoise=False, keep_orig=True, m_threads=2):
198
+ print(f"[~] Running MDX on file: {filename}")
199
  device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
200
 
201
  device_properties = torch.cuda.get_device_properties(device)
202
  vram_gb = device_properties.total_memory / 1024**3
203
  m_threads = 1 if vram_gb < 8 else 2
204
+ print(f"[~] Using {m_threads} threads for processing")
205
 
206
  model_hash = MDX.get_hash(model_path)
207
  mp = model_params.get(model_hash)
 
215
  )
216
 
217
  mdx_sess = MDX(model_path, model)
218
+ print("[~] Loading audio file...")
219
  wave, sr = librosa.load(filename, mono=False, sr=44100)
220
+ print("[~] Normalizing input wave...")
221
  peak = max(np.max(wave), abs(np.min(wave)))
222
  wave /= peak
223
  if denoise:
224
+ print("[~] Denoising wave...")
225
  wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (mdx_sess.process_wave(wave, m_threads))
226
  wave_processed *= 0.5
227
  else:
228
+ print("[~] Processing wave...")
229
  wave_processed = mdx_sess.process_wave(wave, m_threads)
 
230
  wave_processed *= peak
231
  stem_name = model.stem_name if suffix is None else suffix
232
 
233
  main_filepath = None
234
  if not exclude_main:
235
  main_filepath = os.path.join(output_dir, f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav")
236
+ print(f"[~] Writing main output to: {main_filepath}")
237
  sf.write(main_filepath, wave_processed.T, sr)
238
 
239
  invert_filepath = None
 
241
  diff_stem_name = stem_naming.get(stem_name) if invert_suffix is None else invert_suffix
242
  stem_name = f"{stem_name}_diff" if diff_stem_name is None else diff_stem_name
243
  invert_filepath = os.path.join(output_dir, f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav")
244
+ print(f"[~] Writing inverted output to: {invert_filepath}")
245
  sf.write(invert_filepath, (-wave_processed.T * model.compensation) + wave.T, sr)
246
 
247
  if not keep_orig:
248
+ print(f"[~] Removing original file: {filename}")
249
  os.remove(filename)
250
 
251
+ print("[~] Cleaning up...")
252
  del mdx_sess, wave_processed, wave
253
  if torch.cuda.is_available():
254
  torch.cuda.empty_cache()
255
  gc.collect()
256
+ print("[+] MDX processing completed")
257
  return main_filepath, invert_filepath
258
 
259
  def run_roformer(model_params, output_dir, model_name, filename, exclude_main=False, exclude_inversion=False, suffix=None, invert_suffix=None, denoise=False, keep_orig=True, m_threads=2):
260
+ print(f"[~] Running RoFormer on file: {filename}")
261
  os.makedirs(output_dir, exist_ok=True)
262
 
263
+ print("[~] Loading audio file...")
264
  wave, sr = librosa.load(filename, mono=False, sr=44100)
265
  base_name = os.path.splitext(os.path.basename(filename))[0]
266
 
267
  roformer_output_format = 'wav'
268
  roformer_overlap = 4
269
  roformer_segment_size = 256
270
+ print(f"[~] Output directory: {output_dir}")
271
  prompt = f'audio-separator "{filename}" --model_filename {model_name} --output_dir="{output_dir}" --output_format={roformer_output_format} --normalization=0.9 --mdxc_overlap={roformer_overlap} --mdxc_segment_size={roformer_segment_size}'
272
+ print(f"[~] Running command: {prompt}")
273
  os.system(prompt)
274
 
275
  vocals_file = f"{base_name}_Vocals.wav"
 
281
  if not exclude_main:
282
  main_filepath = os.path.join(output_dir, vocals_file)
283
  if os.path.exists(os.path.join(output_dir, f"{base_name}_(Vocals)_{model_name.replace('.9755.ckpt', '')}.wav")):
284
+ print(f"[~] Renaming vocals file to: {main_filepath}")
285
  os.rename(os.path.join(output_dir, f"{base_name}_(Vocals)_{model_name.replace('.9755.ckpt', '')}.wav"), main_filepath)
286
 
287
  if not exclude_inversion:
288
  invert_filepath = os.path.join(output_dir, instrumental_file)
289
  if os.path.exists(os.path.join(output_dir, f"{base_name}_(Instrumental)_{model_name.replace('.9755.ckpt', '')}.wav")):
290
+ print(f"[~] Renaming instrumental file to: {invert_filepath}")
291
  os.rename(os.path.join(output_dir, f"{base_name}_(Instrumental)_{model_name.replace('.9755.ckpt', '')}.wav"), invert_filepath)
292
 
293
  if not keep_orig:
294
+ print(f"[~] Removing original file: {filename}")
295
  os.remove(filename)
296
 
297
+ print("[+] RoFormer processing completed")
298
  return main_filepath, invert_filepath