import os
import json
import math

import torch
import torch.nn.functional as F
import librosa
import numpy as np
import soundfile as sf
import gradio as gr
from transformers import WavLMModel

from env import AttrDict
from meldataset import mel_spectrogram, MAX_WAV_VALUE
from models import Generator
from Utils.JDC.model import JDCNet


hpfile = "config_v1_16k.json"
ptfile = "exp/default/g_00700000"
spk2id_path = "filelists/spk2id.json"
f0_stats_path = "filelists/f0_stats.json"
spk_stats_path = "filelists/spk_stats.json"
spk_emb_dir = "dataset/spk"
spk_wav_dir = "dataset/audio"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
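
# Read the generator hyperparameters (n_fft, num_mels, hop_size, sampling_rate,
# fmin/fmax, ...) from the JSON config into an attribute-style dict.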
with open(hpfile) as f:
    data = f.read()
json_config = json.loads(data)
h = AttrDict(json_config)
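
# Build the generator (a HiFi-GAN-style decoder conditioned on a JDCNet pitch
# network), restore the checkpoint, and put everything in inference mode.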
F0_model = JDCNet(num_class=1, seq_len=192)
generator = Generator(h, F0_model).to(device)

state_dict_g = torch.load(ptfile, map_location=device)
generator.load_state_dict(state_dict_g['generator'], strict=True)
generator.remove_weight_norm()
_ = generator.eval()

wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base-plus")
wavlm.eval()
wavlm.to(device)
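
# Speaker metadata: name -> integer id, per-speaker F0 statistics, and
# per-speaker stats including which reference utterance yields the best
# speaker embedding.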
with open(spk2id_path) as f:
    spk2id = json.load(f)
with open(f0_stats_path) as f:
    f0_stats = json.load(f)
with open(spk_stats_path) as f:
    spk_stats = json.load(f)
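
# F0 shifting: one shift step moves log-F0 by 1/256 of the log(50 Hz)..log(1100 Hz)
# range (presumably matching the model's F0 quantization grid). Frames whose F0 is
# at or below `threshold` Hz are treated as unvoiced and passed through unchanged.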
threshold = 10
step = (math.log(1100) - math.log(50)) / 256

def tune_f0(initial_f0, i):
    if i == 0:
        return initial_f0
    voiced = initial_f0 > threshold
    initial_lf0 = torch.log(initial_f0)
    lf0 = initial_lf0 + step * i
    f0 = torch.exp(lf0)
    f0 = torch.where(voiced, f0, initial_f0)
    return f0
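

# Conversion pipeline: extract WavLM content features and a mel spectrogram from
# the source audio, predict an F0 contour scaled toward the target speaker's mean,
# condition on the target speaker id/embedding, and decode a waveform.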
def convert(tgt_spk, src_wav, f0_shift=0):
    tgt_ref = spk_stats[tgt_spk]["best_spk_emb"]
    tgt_emb = f"{spk_emb_dir}/{tgt_spk}/{tgt_ref}.npy"

    with torch.no_grad():
        # Target speaker id and pre-computed speaker embedding.
        spk_id = spk2id[tgt_spk]
        spk_id = torch.LongTensor([spk_id]).unsqueeze(0).to(device)

        spk_emb = np.load(tgt_emb)
        spk_emb = torch.from_numpy(spk_emb).unsqueeze(0).to(device)

        # Target speaker's mean F0, used to rescale the predicted pitch contour.
        f0_mean_tgt = f0_stats[tgt_spk]["mean"]
        f0_mean_tgt = torch.FloatTensor([f0_mean_tgt]).unsqueeze(0).to(device)

        # Source audio: resample to 16 kHz and compute its mel spectrogram.
        wav, sr = librosa.load(src_wav, sr=16000)
        wav = torch.FloatTensor(wav).to(device)
        mel = mel_spectrogram(wav.unsqueeze(0), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)

        # WavLM content features, zero-padded along time to match the mel length.
        x = wavlm(wav.unsqueeze(0)).last_hidden_state
        x = x.transpose(1, 2)
        x = F.pad(x, (0, mel.size(2) - x.size(2)), 'constant')

        # Predict F0, apply the optional shift, condition on the speaker, decode.
        f0 = generator.get_f0(mel, f0_mean_tgt)
        f0 = tune_f0(f0, f0_shift)
        x = generator.get_x(x, spk_emb, spk_id)
        y = generator.infer(x, f0)

    # Peak-normalize to 95% full scale and write 16-bit PCM.
    audio = y.squeeze()
    audio = audio / torch.max(torch.abs(audio)) * 0.95
    audio = audio * MAX_WAV_VALUE
    audio = audio.cpu().numpy().astype('int16')

    sf.write("out.wav", audio, h.sampling_rate, "PCM_16")

    out_wav = "out.wav"
    return out_wav
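

# Dropdown callback: when the target speaker changes, surface that speaker's
# reference utterance so the user can preview the target voice.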
def change_spk(tgt_spk):
    tgt_ref = spk_stats[tgt_spk]["best_spk_emb"]
    tgt_wav = f"{spk_wav_dir}/{tgt_spk}/{tgt_ref}.wav"
    return tgt_wav
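

# Gradio UI: pick a target speaker, upload source audio, optionally shift F0,
# and listen to the converted output.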
with gr.Blocks() as demo:
    gr.Markdown("# PitchVC")
    gr.Markdown("Gradio Demo for PitchVC. ([Github Repo](https://github.com/OlaWod/PitchVC))")

    with gr.Row():
        with gr.Column():
            tgt_spk = gr.Dropdown(choices=list(spk2id.keys()), type="value", label="Target Speaker")
            ref_audio = gr.Audio(label="Reference Audio", type='filepath')
            src_audio = gr.Audio(label="Source Audio", type='filepath')
            f0_shift = gr.Slider(minimum=-30, maximum=30, value=0, step=1, label="F0 Shift")
        with gr.Column():
            out_audio = gr.Audio(label="Output Audio", type='filepath')
            submit = gr.Button(value="Submit")

    tgt_spk.change(fn=change_spk, inputs=[tgt_spk], outputs=[ref_audio])
    submit.click(convert, [tgt_spk, src_audio, f0_shift], [out_audio])

    examples = gr.Examples(
        examples=[["p225", 'dataset/audio/p226/p226_341.wav', 0],
                  ["p226", 'dataset/audio/p225/p225_220.wav', -5]],
        inputs=[tgt_spk, src_audio, f0_shift])

demo.launch()