Staticaliza committed on
Commit add1014 · verified · 1 Parent(s): 464583c

Update app.py

Files changed (1)
  1. app.py +19 -62
app.py CHANGED
@@ -180,17 +180,12 @@ footer {
 
 @torch.no_grad()
 @torch.inference_mode()
-def voice_conversion(input, reference, steps, guidance, speed, use_conditioned, use_auto_adjustment, pitch):
+def voice_conversion(input, reference, steps, guidance, speed):
 print("[INFO] | Voice conversion started.")

-inference_module = model if not use_conditioned else model_f0
-mel_fn = to_mel if not use_conditioned else to_mel_f0
-bigvgan_fn = bigvgan_model if not use_conditioned else bigvgan_44k_model
-sr_current = 22050 if not use_conditioned else 44100
-hop_length_current = 256 if not use_conditioned else 512
-max_context_window = sr_current // hop_length_current * 30
-overlap_wave_len = 16 * hop_length_current
-bitrate = "320k"
+inference_module, mel_fn, bigvgan_fn = model, to_mel, bigvgan_model
+bitrate, sampling_rate, sr_current, hop_length_current = "320k", 16000, 22050, 256
+max_context_window, overlap_wave_len = sr_current // hop_length_current * 30, 16 * hop_length_current

 # Load audio using librosa
 print("[INFO] | Loading source and reference audio.")
@@ -206,13 +201,13 @@ def voice_conversion(input, reference, steps, guidance, speed, use_conditioned,
 ref_audio_tensor = torch.tensor(ref_audio).unsqueeze(0).float().to(device)

 # Resample to 16kHz
-ref_waves_16k = torchaudio.functional.resample(ref_audio_tensor, sr_current, 16000)
-converted_waves_16k = torchaudio.functional.resample(source_audio_tensor, sr_current, 16000)
+ref_waves_16k = torchaudio.functional.resample(ref_audio_tensor, sr_current, sampling_rate)
+converted_waves_16k = torchaudio.functional.resample(source_audio_tensor, sr_current, sampling_rate)

 # Generate Whisper features
 print("[INFO] | Generating Whisper features for source audio.")
-if converted_waves_16k.size(-1) <= 16000 * 30:
-alt_inputs = whisper_feature_extractor([converted_waves_16k.squeeze(0).cpu().numpy()], return_tensors="pt", return_attention_mask=True, sampling_rate=16000)
+if converted_waves_16k.size(-1) <= sampling_rate * 30:
+alt_inputs = whisper_feature_extractor([converted_waves_16k.squeeze(0).cpu().numpy()], return_tensors="pt", return_attention_mask=True, sampling_rate=sampling_rate)
 alt_input_features = whisper_model._mask_input_features(alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(device)
 alt_outputs = whisper_model.encoder(alt_input_features.to(torch.float32), head_mask=None, output_attentions=False, output_hidden_states=False, return_dict=True)
 S_alt = alt_outputs.last_hidden_state.to(torch.float32)
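The Whisper branch above depends on globals defined earlier in app.py (whisper_feature_extractor, whisper_model) and on the private _mask_input_features helper. A minimal standalone sketch of the same feature-extraction idea using only the public transformers API, with openai/whisper-small as an assumed checkpoint:

    import torch
    from transformers import WhisperFeatureExtractor, WhisperModel

    feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")  # assumed checkpoint
    whisper = WhisperModel.from_pretrained("openai/whisper-small").eval()

    def encode_16k(wave_16k: torch.Tensor) -> torch.Tensor:
        # wave_16k: (1, samples) mono waveform already resampled to 16 kHz.
        inputs = feature_extractor([wave_16k.squeeze(0).numpy()], return_tensors="pt", sampling_rate=16000)
        with torch.no_grad():
            outputs = whisper.encoder(inputs.input_features)
        return outputs.last_hidden_state  # (1, 1500, hidden_size) for the padded 30 s window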
@@ -222,8 +217,8 @@ def voice_conversion(input, reference, steps, guidance, speed, use_conditioned,
 # Process in chunks
 print("[INFO] | Processing source audio in chunks.")
 overlapping_time = 5 # seconds
-chunk_size = 16000 * 30 # 30 seconds
-overlap_size = 16000 * overlapping_time
+chunk_size = sampling_rate * 30 # 30 seconds
+overlap_size = sampling_rate * overlapping_time
 S_alt_list = []
 buffer = None
 traversed_time = 0
@@ -234,7 +229,7 @@ def voice_conversion(input, reference, steps, guidance, speed, use_conditioned,
 chunk = converted_waves_16k[:, traversed_time:traversed_time + chunk_size]
 else:
 chunk = torch.cat([buffer, converted_waves_16k[:, traversed_time:traversed_time + chunk_size - overlap_size]], dim=-1)
-alt_inputs = whisper_feature_extractor([chunk.squeeze(0).cpu().numpy()],return_tensors="pt", return_attention_mask=True, sampling_rate=16000)
+alt_inputs = whisper_feature_extractor([chunk.squeeze(0).cpu().numpy()], return_tensors="pt", return_attention_mask=True, sampling_rate=sampling_rate)
 alt_input_features = whisper_model._mask_input_features(alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(device)
 alt_outputs = whisper_model.encoder(alt_input_features.to(torch.float32), head_mask=None, output_attentions=False, output_hidden_states=False, return_dict=True)
 S_chunk = alt_outputs.last_hidden_state.to(torch.float32)
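Only part of the chunking loop is visible in this hunk. As a rough illustration of the windowing scheme it implements (30 s windows, with the last 5 s of each window carried over as a buffer so consecutive chunks overlap), here is a self-contained sketch; the loop-control details are assumptions, since the surrounding while statement sits outside the hunk:

    import torch

    def split_with_overlap(wave_16k, sr=16000, window_s=30, overlap_s=5):
        # wave_16k: (1, samples) tensor. Yields successive windows; every window after
        # the first re-includes the last `overlap_s` seconds of the previous one.
        chunk_size, overlap_size = sr * window_s, sr * overlap_s
        buffer, traversed, total = None, 0, wave_16k.size(-1)
        while traversed < total:
            if buffer is None:
                chunk = wave_16k[:, traversed:traversed + chunk_size]
                traversed += chunk_size
            else:
                chunk = torch.cat([buffer, wave_16k[:, traversed:traversed + chunk_size - overlap_size]], dim=-1)
                traversed += chunk_size - overlap_size
            buffer = chunk[:, -overlap_size:]
            yield chunk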
@@ -255,8 +250,8 @@ def voice_conversion(input, reference, steps, guidance, speed, use_conditioned,

 # Original Whisper features
 print("[INFO] | Generating Whisper features for reference audio.")
-ori_waves_16k = torchaudio.functional.resample(ref_audio_tensor, sr_current, 16000)
-ori_inputs = whisper_feature_extractor([ori_waves_16k.squeeze(0).cpu().numpy()], return_tensors="pt", return_attention_mask=True, sampling_rate=16000)
+ori_waves_16k = torchaudio.functional.resample(ref_audio_tensor, sr_current, sampling_rate)
+ori_inputs = whisper_feature_extractor([ori_waves_16k.squeeze(0).cpu().numpy()], return_tensors="pt", return_attention_mask=True, sampling_rate=sampling_rate)
 ori_input_features = whisper_model._mask_input_features(ori_inputs.input_features, attention_mask=ori_inputs.attention_mask).to(device)
 ori_outputs = whisper_model.encoder(ori_input_features.to(torch.float32), head_mask=None, output_attentions=False, output_hidden_states=False, return_dict=True)
 S_ori = ori_outputs.last_hidden_state.to(torch.float32)
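Both the source and reference paths go through torchaudio.functional.resample before any Whisper or fbank features are computed. A quick standalone example of that call, with a synthetic tone standing in for the loaded audio:

    import torch
    import torchaudio

    sr_current, sampling_rate = 22050, 16000
    t = torch.arange(sr_current, dtype=torch.float32) / sr_current   # one second of audio
    wave = torch.sin(2 * torch.pi * 440.0 * t).unsqueeze(0)          # (1, 22050) mono 440 Hz tone
    wave_16k = torchaudio.functional.resample(wave, sr_current, sampling_rate)
    print(wave_16k.shape)                                            # -> torch.Size([1, 16000])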
@@ -276,48 +271,15 @@ def voice_conversion(input, reference, steps, guidance, speed, use_conditioned,

 # Extract style features
 print("[INFO] | Extracting style features from reference audio.")
-feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k, num_mel_bins=80, dither=0, sample_frequency=16000)
+feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k, num_mel_bins=80, dither=0, sample_frequency=sampling_rate)
 feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
 style2 = campplus_model(feat2.unsqueeze(0))
 print(f"[INFO] | Style2 shape: {style2.shape}")

-# F0 Conditioning
-if use_conditioned:
-print("[INFO] | Performing F0 conditioning.")
-F0_ori = rmvpe.infer_from_audio(ref_waves_16k[0], thred=0.5)
-F0_alt = rmvpe.infer_from_audio(converted_waves_16k[0], thred=0.5)
-
-F0_ori = torch.from_numpy(F0_ori).to(device)[None].float()
-F0_alt = torch.from_numpy(F0_alt).to(device)[None].float()
-
-voiced_F0_ori = F0_ori[F0_ori > 1]
-voiced_F0_alt = F0_alt[F0_alt > 1]
-
-log_f0_alt = torch.log(F0_alt + 1e-5)
-voiced_log_f0_ori = torch.log(voiced_F0_ori + 1e-5)
-voiced_log_f0_alt = torch.log(voiced_F0_alt + 1e-5)
-
-median_log_f0_ori = torch.median(voiced_log_f0_ori)
-median_log_f0_alt = torch.median(voiced_log_f0_alt)
-
-# Shift F0 levels
-shifted_log_f0_alt = log_f0_alt.clone()
-if auto_f0_adjust:
-shifted_log_f0_alt[F0_alt > 1] = (log_f0_alt[F0_alt > 1] - median_log_f0_alt + median_log_f0_ori)
-shifted_f0_alt = torch.exp(shifted_log_f0_alt)
-if pitch != 0:
-shifted_f0_alt[F0_alt > 1] = adjust_f0_semitones(shifted_f0_alt[F0_alt > 1], pitch)
-print("[INFO] | F0 conditioning completed.")
-else:
-F0_ori = None
-F0_alt = None
-shifted_f0_alt = None
-print("[INFO] | F0 conditioning not applied.")
-
 # Length Regulation
 print("[INFO] | Applying length regulation.")
-cond, _, _, _, _ = inference_module.length_regulator(S_alt, ylens=target_lengths, n_quantizers=3, f0=shifted_f0_alt)
-prompt_condition, _, _, _, _ = inference_module.length_regulator(S_ori, ylens=target2_lengths, n_quantizers=3, f0=F0_ori)
+cond, _, _, _, _ = inference_module.length_regulator(S_alt, ylens=target_lengths, n_quantizers=3, f0=None)
+prompt_condition, _, _, _, _ = inference_module.length_regulator(S_ori, ylens=target2_lengths, n_quantizers=3, f0=None)
 print(f"[INFO] | Cond shape: {cond.shape}, Prompt condition shape: {prompt_condition.shape}")

 # Initialize variables for audio generation
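For context on the block this commit deletes: it matched the median log-F0 of the voiced source frames to the reference and optionally applied a semitone offset before feeding F0 into the length regulator. A standalone sketch of that computation, assuming adjust_f0_semitones scales F0 by 2 ** (semitones / 12) (the usual definition; the real helper lives elsewhere in the repository):

    import torch

    def shift_f0_to_reference(f0_src, f0_ref, semitones=0):
        # Frames with F0 > 1 Hz are treated as voiced, matching the masks in the removed code.
        voiced_src, voiced_ref = f0_src[f0_src > 1], f0_ref[f0_ref > 1]
        log_src = torch.log(f0_src + 1e-5)
        shift = torch.median(torch.log(voiced_ref + 1e-5)) - torch.median(torch.log(voiced_src + 1e-5))
        log_src[f0_src > 1] += shift                        # align median pitch with the reference
        shifted = torch.exp(log_src)
        if semitones != 0:
            shifted[f0_src > 1] *= 2.0 ** (semitones / 12)  # assumed semitone scaling
        return shifted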
@@ -345,8 +307,8 @@ def voice_conversion(input, reference, steps, guidance, speed, use_conditioned,
 output_wave = vc_wave[0].cpu().numpy()
 generated_wave_chunks.append(output_wave)

-# Fix: Ensure processed_frames increments correctly to avoid infinite loop
-processed_frames += vc_target.size(2) # Changed from 'vc_target.size(2) - 16' to 'vc_target.size(2)'
+# Ensure processed_frames increments correctly to avoid infinite loop
+processed_frames += vc_target.size(2)
 print(f"[INFO] | Processed frames updated to: {processed_frames}")

 # Concatenate all generated wave chunks
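The comment rewrite above keeps the increment introduced earlier: processed_frames now advances by the full vc_target.size(2). Presumably the old vc_target.size(2) - 16 variant could stall when a final chunk was no longer than the 16-frame overlap. A toy loop that illustrates the difference (the frame counts are made up):

    def iterations_to_finish(chunk_lengths, subtract_overlap):
        # Counts loop iterations needed to walk through the chunks, giving up at 1000
        # iterations as a stand-in for an infinite loop.
        total, processed, iterations = sum(chunk_lengths), 0, 0
        while processed < total and iterations < 1000:
            chunk = chunk_lengths[min(iterations, len(chunk_lengths) - 1)]
            processed += chunk - 16 if subtract_overlap else chunk
            iterations += 1
        return iterations

    print(iterations_to_finish([256, 256, 12], subtract_overlap=True))   # 1000: the 12-frame chunk never makes progress
    print(iterations_to_finish([256, 256, 12], subtract_overlap=False))  # 3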
@@ -392,11 +354,6 @@ with gr.Blocks(css=css) as main:
 guidance = gr.Slider(label="Guidance", value=0.7, minimum=0.0, maximum=1.0, step=0.1)
 speed = gr.Slider(label="Speed", value=1.0, minimum=0.5, maximum=2.0, step=0.1)

-with gr.Column():
-use_conditioned = gr.Checkbox(label="Use 'F0 Conditioned Model'", value=False)
-use_auto_adjustment = gr.Checkbox(label="Use 'Auto F0 Adjustment' with 'F0 Conditioned Model'", value=True)
-pitch = gr.Slider(label="Pitch with 'F0 Conditioned Model'", value=0, minimum=-12, maximum=12, step=1)
-
 with gr.Column():
 submit = gr.Button("▶")
 maintain = gr.Button("☁️")
@@ -404,7 +361,7 @@ with gr.Blocks(css=css) as main:
 with gr.Column():
 output = gr.Audio(label="Output", type="filepath")

-submit.click(voice_conversion, inputs=[input, reference_input, steps, guidance, speed, use_conditioned, use_auto_adjustment, pitch], outputs=output, queue=False)
+submit.click(voice_conversion, inputs=[input, reference_input, steps, guidance, speed], outputs=output, queue=False)
 maintain.click(cloud, inputs=[], outputs=[], queue=False)

 main.launch(show_api=True)
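With the F0 controls gone, the interface reduces to the three sliders plus the buttons, and the click handler passes exactly the five remaining inputs. A minimal self-contained sketch of that wiring; the placeholder convert function, the Steps slider settings, and the input component labels are illustrative rather than taken from app.py:

    import gradio as gr

    def convert(audio_path, reference_path, steps, guidance, speed):
        # Placeholder: the real app calls voice_conversion and returns an output file path.
        return audio_path

    with gr.Blocks() as demo:
        with gr.Column():
            input_audio = gr.Audio(label="Input", type="filepath")
            reference_audio = gr.Audio(label="Reference", type="filepath")
            steps = gr.Slider(label="Steps", value=4, minimum=1, maximum=100, step=1)  # illustrative values
            guidance = gr.Slider(label="Guidance", value=0.7, minimum=0.0, maximum=1.0, step=0.1)
            speed = gr.Slider(label="Speed", value=1.0, minimum=0.5, maximum=2.0, step=0.1)
        with gr.Column():
            submit = gr.Button("▶")
        with gr.Column():
            output = gr.Audio(label="Output", type="filepath")
        submit.click(convert, inputs=[input_audio, reference_audio, steps, guidance, speed], outputs=output, queue=False)

    demo.launch()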
 