Staticaliza committed
Commit 18943e0 · verified · 1 Parent(s): 7cd73f4

Update app.py

Files changed (1)
  1. app.py +144 -43
app.py CHANGED
@@ -182,7 +182,7 @@ footer {
 
@torch.no_grad()
@torch.inference_mode()
- def voice_conversion(input, reference, steps, guidance, pitch, speed):
+ def voice_conversion(input, reference, steps, guidance, speed, pitch):
print("[INFO] | Voice conversion started.")

inference_module, mel_fn, bigvgan_fn = model, to_mel, bigvgan_model
@@ -203,15 +203,29 @@ def voice_conversion(input, reference, steps, guidance, pitch, speed):
ref_audio_tensor = torch.tensor(ref_audio).unsqueeze(0).float().to(device)

# Resample to 16kHz
- ref_waves_16k = torchaudio.functional.resample(ref_audio_tensor, sr_current, sampling_rate)
- converted_waves_16k = torchaudio.functional.resample(source_audio_tensor, sr_current, sampling_rate)
+ ref_waves_16k = torchaudio.functional.resample(ref_audio_tensor, sr_current, sampling_rate).to(device)
+ converted_waves_16k = torchaudio.functional.resample(source_audio_tensor, sr_current, sampling_rate).to(device)

- # Generate Whisper features
+ # Generate Whisper features for source audio
print("[INFO] | Generating Whisper features for source audio.")
if converted_waves_16k.size(-1) <= sampling_rate * 30:
- alt_inputs = whisper_feature_extractor([converted_waves_16k.squeeze(0).cpu().numpy()], return_tensors="pt", return_attention_mask=True, sampling_rate=sampling_rate)
- alt_input_features = whisper_model._mask_input_features(alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(device)
- alt_outputs = whisper_model.encoder(alt_input_features.to(torch.float32), head_mask=None, output_attentions=False, output_hidden_states=False, return_dict=True)
+ alt_inputs = whisper_feature_extractor(
+ [converted_waves_16k.squeeze(0).cpu().numpy()],
+ return_tensors="pt",
+ return_attention_mask=True,
+ sampling_rate=sampling_rate
+ )
+ alt_input_features = whisper_model._mask_input_features(
+ alt_inputs.input_features,
+ attention_mask=alt_inputs.attention_mask
+ ).to(device)
+ alt_outputs = whisper_model.encoder(
+ alt_input_features.to(torch.float32),
+ head_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True
+ )
S_alt = alt_outputs.last_hidden_state.to(torch.float32)
S_alt = S_alt[:, :converted_waves_16k.size(-1) // 320 + 1]
print(f"[INFO] | S_alt shape: {S_alt.shape}")
@@ -227,13 +241,29 @@ def voice_conversion(input, reference, steps, guidance, pitch, speed):
total_length = converted_waves_16k.size(-1)

while traversed_time < total_length:
- if buffer is None:
- chunk = converted_waves_16k[:, traversed_time:traversed_time + chunk_size]
- else:
- chunk = torch.cat([buffer, converted_waves_16k[:, traversed_time:traversed_time + chunk_size - overlap_size]], dim=-1)
- alt_inputs = whisper_feature_extractor([chunk.squeeze(0).cpu().numpy()], return_tensors="pt", return_attention_mask=True, sampling_rate=sampling_rate)
- alt_input_features = whisper_model._mask_input_features(alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(device)
- alt_outputs = whisper_model.encoder(alt_input_features.to(torch.float32), head_mask=None, output_attentions=False, output_hidden_states=False, return_dict=True)
+ end_time = traversed_time + chunk_size
+ if end_time > total_length:
+ end_time = total_length
+ chunk = converted_waves_16k[:, traversed_time:end_time]
+ if buffer is not None:
+ chunk = torch.cat([buffer, chunk], dim=-1)
+ alt_inputs = whisper_feature_extractor(
+ [chunk.squeeze(0).cpu().numpy()],
+ return_tensors="pt",
+ return_attention_mask=True,
+ sampling_rate=sampling_rate
+ )
+ alt_input_features = whisper_model._mask_input_features(
+ alt_inputs.input_features,
+ attention_mask=alt_inputs.attention_mask
+ ).to(device)
+ alt_outputs = whisper_model.encoder(
+ alt_input_features.to(torch.float32),
+ head_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True
+ )
S_chunk = alt_outputs.last_hidden_state.to(torch.float32)
S_chunk = S_chunk[:, :chunk.size(-1) // 320 + 1]
print(f"[INFO] | Processed chunk with S_chunk shape: {S_chunk.shape}")
@@ -250,12 +280,26 @@ def voice_conversion(input, reference, steps, guidance, pitch, speed):
S_alt = torch.cat(S_alt_list, dim=1)
print(f"[INFO] | Final S_alt shape after chunk processing: {S_alt.shape}")

- # Original Whisper features
+ # Generate Whisper features for reference audio
print("[INFO] | Generating Whisper features for reference audio.")
- ori_waves_16k = torchaudio.functional.resample(ref_audio_tensor, sr_current, sampling_rate)
- ori_inputs = whisper_feature_extractor([ori_waves_16k.squeeze(0).cpu().numpy()], return_tensors="pt", return_attention_mask=True, sampling_rate=sampling_rate)
- ori_input_features = whisper_model._mask_input_features(ori_inputs.input_features, attention_mask=ori_inputs.attention_mask).to(device)
- ori_outputs = whisper_model.encoder(ori_input_features.to(torch.float32), head_mask=None, output_attentions=False, output_hidden_states=False, return_dict=True)
+ ori_waves_16k = torchaudio.functional.resample(ref_audio_tensor, sr_current, sampling_rate).to(device)
+ ori_inputs = whisper_feature_extractor(
+ [ori_waves_16k.squeeze(0).cpu().numpy()],
+ return_tensors="pt",
+ return_attention_mask=True,
+ sampling_rate=sampling_rate
+ )
+ ori_input_features = whisper_model._mask_input_features(
+ ori_inputs.input_features,
+ attention_mask=ori_inputs.attention_mask
+ ).to(device)
+ ori_outputs = whisper_model.encoder(
+ ori_input_features.to(torch.float32),
+ head_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True
+ )
S_ori = ori_outputs.last_hidden_state.to(torch.float32)
S_ori = S_ori[:, :ori_waves_16k.size(-1) // 320 + 1]
print(f"[INFO] | S_ori shape: {S_ori.shape}")
@@ -267,21 +311,30 @@ def voice_conversion(input, reference, steps, guidance, pitch, speed):
print(f"[INFO] | Mel spectrogram shapes: mel={mel.shape}, mel2={mel2.shape}")

# Length adjustment
- target_lengths = torch.LongTensor([int(mel.size(2) / speed)]).to(mel.device)
+ target_lengths = torch.LongTensor([int(mel.size(2) * speed)]).to(mel.device)
target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
print(f"[INFO] | Target lengths: {target_lengths.item()}, {target2_lengths.item()}")

# Extract style features
print("[INFO] | Extracting style features from reference audio.")
- feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k, num_mel_bins=80, dither=0, sample_frequency=sampling_rate)
+ feat2 = torchaudio.compliance.kaldi.fbank(
+ ref_waves_16k,
+ num_mel_bins=80,
+ dither=0,
+ sample_frequency=sampling_rate
+ )
feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
- style2 = campplus_model(feat2.unsqueeze(0))
+ style2 = campplus_model(feat2.unsqueeze(0)).to(device)
print(f"[INFO] | Style2 shape: {style2.shape}")

# Length Regulation
print("[INFO] | Applying length regulation.")
- cond, _, _, _, _ = inference_module.length_regulator(S_alt, ylens=target_lengths, n_quantizers=3, f0=None)
- prompt_condition, _, _, _, _ = inference_module.length_regulator(S_ori, ylens=target2_lengths, n_quantizers=3, f0=None)
+ cond, _, _, _, _ = inference_module.length_regulator(
+ S_alt, ylens=target_lengths, n_quantizers=3, f0=None
+ )
+ prompt_condition, _, _, _, _ = inference_module.length_regulator(
+ S_ori, ylens=target2_lengths, n_quantizers=3, f0=None
+ )
print(f"[INFO] | Cond shape: {cond.shape}, Prompt condition shape: {prompt_condition.shape}")

# Initialize variables for audio generation
@@ -297,56 +350,104 @@ def voice_conversion(input, reference, steps, guidance, pitch, speed):
cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)

# Perform inference
- vc_target = inference_module.cfm.inference(cat_condition, torch.LongTensor([cat_condition.size(1)]).to(mel2.device), mel2, style2, None, steps, inference_cfg_rate=guidance)
+ vc_target = inference_module.cfm.inference(
+ cat_condition,
+ torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
+ mel2,
+ style2,
+ None,
+ steps,
+ inference_cfg_rate=guidance
+ )
vc_target = vc_target[:, :, mel2.size(2):]
print(f"[INFO] | vc_target shape: {vc_target.shape}")
-
- # TEMP
- output_wave = vc_target[0].cpu().numpy()
- generated_wave_chunks.append(output_wave)

# Generate waveform using BigVGAN
- """
vc_wave = bigvgan_fn(vc_target.float())[0]
print(f"[INFO] | vc_wave shape: {vc_wave.shape}")

# Handle the generated waveform
- output_wave = vc_wave[0].cpu().numpy()
+ output_wave = vc_wave.squeeze(0).cpu().numpy()
generated_wave_chunks.append(output_wave)
- """

# Ensure processed_frames increments correctly to avoid infinite loop
processed_frames += vc_target.size(2)
-
print(f"[INFO] | Processed frames updated to: {processed_frames}")

# Concatenate all generated wave chunks
final_audio = np.concatenate(generated_wave_chunks).astype(np.float32)

+ # Normalize the audio to ensure it's within [-1.0, 1.0]
+ max_val = np.max(np.abs(final_audio))
+ if max_val > 1.0:
+ final_audio = final_audio / max_val
+ print("[INFO] | Final audio normalized.")
+
+ # ----------------------------
+ # Audio Processing: Noise Reduction and Pitch Shifting
+ # ----------------------------
+
+ # Noise Reduction using noisereduce
+ print("[INFO] | Applying noise reduction.")
+ try:
+ # Option 1: Using a Noise Sample (first 0.5 seconds)
+ noise_duration = 0.5 # seconds
+ noise_sample = final_audio[:int(noise_duration * sr_current)]
+ final_audio = nr.reduce_noise(
+ y=final_audio,
+ sr=sr_current,
+ y_noise=noise_sample,
+ prop_decrease=1.0
+ )
+ print("[INFO] | Noise reduction applied using a noise sample.")
+ except Exception as e:
+ print(f"[ERROR] | Noise reduction with noise sample failed: {e}")
+ # Option 2: Automatic Noise Estimation
+ try:
+ final_audio = nr.reduce_noise(
+ y=final_audio,
+ sr=sr_current,
+ stationary=False
+ )
+ print("[INFO] | Noise reduction applied with automatic noise estimation.")
+ except Exception as e:
+ print(f"[ERROR] | Noise reduction with automatic estimation failed: {e}")
+
# Pitch Shifting using librosa
print("[INFO] | Applying pitch shifting.")
try:
if pitch != 0:
- final_audio = librosa.effects.pitch_shift(final_audio, sr=sr_current, n_steps=pitch)
+ final_audio = librosa.effects.pitch_shift(
+ final_audio,
+ sr=sr_current,
+ n_steps=pitch
+ )
print(f"[INFO] | Pitch shifted by {pitch} semitones.")
else:
print("[INFO] | No pitch shift applied.")
except Exception as e:
print(f"[ERROR] | Pitch shifting failed: {e}")
-
- # Normalize the audio to ensure it's within [-1.0, 1.0]
+
+ # Optional: Further Normalization after Pitch Shifting
max_val = np.max(np.abs(final_audio))
if max_val > 1.0:
final_audio = final_audio / max_val
- print("[INFO] | Final audio normalized.")
-
+ print("[INFO] | Final audio normalized after pitch shifting.")
+
+ # ----------------------------
+ # Save the Audio
+ # ----------------------------
+
# Save the audio to a temporary WAV file
print("[INFO] | Saving final audio to a temporary WAV file.")
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
- sf.write(tmp_file.name, final_audio, sr_current, format='WAV')
- temp_file_path = tmp_file.name
-
- print(f"[INFO] | Final audio saved to {temp_file_path}")
+ try:
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
+ sf.write(tmp_file.name, final_audio, sr_current, format='WAV')
+ temp_file_path = tmp_file.name
+ print(f"[INFO] | Final audio saved to {temp_file_path}")
+ except Exception as e:
+ print(f"[ERROR] | Saving audio failed: {e}")
+ return None

return temp_file_path
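
For reference, a minimal usage sketch of the updated voice_conversion signature after this commit. The file paths and parameter values below are hypothetical (not from the diff); the commit itself reorders the trailing speed/pitch arguments, changes target_lengths from int(mel.size(2) / speed) to int(mel.size(2) * speed), and returns None if saving the temporary WAV file fails.

# Hypothetical call to the updated function (paths and values are illustrative only).
output_path = voice_conversion(
    input="source.wav",        # audio to convert (hypothetical path)
    reference="reference.wav", # voice to imitate (hypothetical path)
    steps=25,                  # CFM inference steps (assumed value)
    guidance=0.7,              # passed through as inference_cfg_rate (assumed value)
    speed=1.0,                 # multiplies the source mel length (was a divisor before this commit)
    pitch=0,                   # semitone shift applied with librosa after synthesis
)
if output_path is None:
    print("Conversion failed while saving the output file.")
else:
    print(f"Converted audio written to {output_path}")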