# DreamVoice / dreamvoice/freevc_wrapper.py
# FreeVC plugin (uploaded by Higobeatz, commit 0dabde8, 2.24 kB).
import os
import torch
import librosa
import soundfile as sf
from pathlib import Path
from transformers import WavLMModel
from .freevc.utils import load_checkpoint, get_hparams_from_file
from .freevc.models import SynthesizerTrn
# from mel_processing import mel_spectrogram_torch
# from free_vc.speaker_encoder.voice_encoder import SpeakerEncoder
# from speaker_encoder.voice_encoder import SpeakerEncoder
def get_freevc_models(path='freevc', speaker_path='../pre_ckpts/spk_encoder/pretrained.pt', device='cuda'):
    """Build the FreeVC synthesizer and the WavLM content encoder.

    Args:
        path: Directory containing ``freevc.json`` (hyper-parameters) and
            ``freevc.pth`` (checkpoint).
        speaker_path: Currently not read by this function; no speaker
            encoder is loaded or returned here.
        device: Device both models are moved to.

    Returns:
        A 3-tuple ``(freevc, cmodel, hps)``: the synthesizer in eval mode
        with its checkpoint loaded, the ``microsoft/wavlm-large`` model in
        eval mode, and the hyper-parameter object read from ``freevc.json``.
    """
    hps = get_hparams_from_file(f"{path}/freevc.json")

    # Positional constructor arguments derived from the data/train settings.
    spec_channels = hps.data.filter_length // 2 + 1
    segment_frames = hps.train.segment_size // hps.data.hop_length

    synthesizer = SynthesizerTrn(spec_channels, segment_frames, **hps.model)
    freevc = synthesizer.to(device)
    freevc.eval()
    # Third argument is passed as None (presumably the optimizer slot;
    # confirm against .freevc.utils.load_checkpoint).
    load_checkpoint(f"{path}/freevc.pth", freevc, None)

    wavlm = WavLMModel.from_pretrained("microsoft/wavlm-large")
    cmodel = wavlm.to(device)
    cmodel.eval()

    return freevc, cmodel, hps
@torch.no_grad()
def convert(freevc, content, speaker):
    """Run FreeVC inference and return the result as a NumPy array.

    Gradient tracking is disabled for the whole call.

    Args:
        freevc: Model object exposing ``infer(content, g=...)``.
        content: Content features, forwarded to ``infer`` unchanged.
        speaker: Speaker conditioning, forwarded to ``infer`` as ``g``.

    Returns:
        A tuple ``(audio, 16000)``: element ``[0][0]`` of the model output
        moved to CPU and converted to a float32 NumPy array, plus the
        hard-coded value 16000 that callers use as the sample rate.
    """
    generated = freevc.infer(content, g=speaker)
    first_item = generated[0][0]
    waveform = first_item.data.cpu().float().numpy()
    return waveform, 16000
if __name__ == '__main__':
    # Demo: convert the utterance in `src` to the voice of the speaker in
    # `tgt` and write the result to output.wav.
    tgt = 'p226_002.wav'
    src = 'p225_001.wav'
    device = 'cuda'

    # get_freevc_models() returns exactly (freevc, cmodel, hps); it does not
    # build a speaker encoder. The previous 4-way unpack raised ValueError and
    # left `smodel` undefined, so the encoder has to be supplied here.
    # TODO(review): construct the speaker encoder (an object with an
    # `embed_utterance(wav)` method returning a NumPy array) and assign it.
    smodel = None
    if smodel is None:
        # Fail before the expensive model loading with an actionable message.
        raise SystemExit(
            "freevc_wrapper demo: no speaker encoder is configured; "
            "assign `smodel` an object providing embed_utterance(wav)."
        )

    freevc_24, cmodel, hps = get_freevc_models(device=device)

    # Target speaker: load, trim leading/trailing silence, embed.
    wav_tgt, _ = librosa.load(tgt, sr=hps.data.sampling_rate)
    wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
    g_tgt = smodel.embed_utterance(wav_tgt)
    g_tgt = torch.from_numpy(g_tgt).unsqueeze(0).to(device)

    # Source utterance: load and extract content features with WavLM.
    wav_src, _ = librosa.load(src, sr=hps.data.sampling_rate)
    wav_src = torch.from_numpy(wav_src).unsqueeze(0).to(device)
    # Inference only: skip autograd bookkeeping for the WavLM forward pass.
    with torch.no_grad():
        content = cmodel(wav_src).last_hidden_state.transpose(1, 2)

    output, sr = convert(freevc_24, content, g_tgt)
    sf.write('output.wav', output, sr)