lj1995 commited on
Commit
254c0f5
·
verified ·
1 Parent(s): d0aac67

librosa.load->torchaudio.load

Browse files
Files changed (1) hide show
  1. inference_webui.py +8 -4
inference_webui.py CHANGED
@@ -479,17 +479,21 @@ def get_tts_wav(
479
  )
480
  if not ref_free:
481
  with torch.no_grad():
482
- wav16k, sr = librosa.load(ref_wav_path, sr=16000)
 
 
 
 
 
 
483
  if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
484
  gr.Warning(i18n("参考音频在3~10秒范围外,请更换!"))
485
  raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
486
- wav16k = torch.from_numpy(wav16k)
487
  zero_wav_torch = torch.from_numpy(zero_wav)
488
  if is_half == True:
489
- wav16k = wav16k.half().to(device)
490
  zero_wav_torch = zero_wav_torch.half().to(device)
491
  else:
492
- wav16k = wav16k.to(device)
493
  zero_wav_torch = zero_wav_torch.to(device)
494
  wav16k = torch.cat([wav16k, zero_wav_torch])
495
  ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
 
479
  )
480
  if not ref_free:
481
  with torch.no_grad():
482
+ wav16k, sr = torchaudio.load(url_ref_wav)
483
+ wav16k=wav16k.to(device)
484
+ if wav16k.shape[0] == 2:
485
+ wav16k = wav16k.mean(0).unsqueeze(0)
486
+ if sr!=16000:
487
+ wav16k=resample(wav16k, sr, 16000, device)
488
+ wav16k=wav16k[0]
489
  if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
490
  gr.Warning(i18n("参考音频在3~10秒范围外,请更换!"))
491
  raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
 
492
  zero_wav_torch = torch.from_numpy(zero_wav)
493
  if is_half == True:
494
+ wav16k = wav16k.half()
495
  zero_wav_torch = zero_wav_torch.half().to(device)
496
  else:
 
497
  zero_wav_torch = zero_wav_torch.to(device)
498
  wav16k = torch.cat([wav16k, zero_wav_torch])
499
  ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()