import os
import json
import math
import torch
import torch.nn.functional as F
import librosa
import numpy as np
import soundfile as sf
import gradio as gr
from transformers import WavLMModel
from env import AttrDict
from meldataset import mel_spectrogram, MAX_WAV_VALUE
from models import Generator
from Utils.JDC.model import JDCNet
# files
hpfile = "config_v1_16k.json"
ptfile = "exp/default/g_00700000"
spk2id_path = "filelists/spk2id.json"
f0_stats_path = "filelists/f0_stats.json"
spk_stats_path = "filelists/spk_stats.json"
spk_emb_dir = "dataset/spk"
spk_wav_dir = "dataset/audio"
# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# load config
with open(hpfile) as f:
    data = f.read()
json_config = json.loads(data)
h = AttrDict(json_config)
# load models
F0_model = JDCNet(num_class=1, seq_len=192)
generator = Generator(h, F0_model).to(device)
state_dict_g = torch.load(ptfile, map_location=device)
generator.load_state_dict(state_dict_g['generator'], strict=True)
generator.remove_weight_norm()
_ = generator.eval()
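# WavLM serves as the content encoder for the source utterance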
wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base-plus")
wavlm.eval()
wavlm.to(device)
# load stats
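# spk2id: speaker name -> id; f0_stats: per-speaker f0 statistics; spk_stats: best reference embedding per speaker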
with open(spk2id_path) as f:
    spk2id = json.load(f)
with open(f0_stats_path) as f:
    f0_stats = json.load(f)
with open(spk_stats_path) as f:
    spk_stats = json.load(f)
# tune f0
threshold = 10  # f0 (Hz) above this is treated as voiced
step = (math.log(1100) - math.log(50)) / 256  # one shift step in log-f0: the 50-1100 Hz range split into 256 log-spaced bins
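# shift f0 by `i` log-scale steps; unvoiced frames (f0 <= threshold) are left unchanged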
def tune_f0(initial_f0, i):
    if i == 0:
        return initial_f0

    voiced = initial_f0 > threshold
    initial_lf0 = torch.log(initial_f0)
    lf0 = initial_lf0 + step * i
    f0 = torch.exp(lf0)
    f0 = torch.where(voiced, f0, initial_f0)
    return f0
# convert function
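# pipeline: WavLM content features + target speaker embedding/id go through the generator;
# f0 is predicted from the source mel and the target speaker's mean f0, then optionally shifted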
def convert(tgt_spk, src_wav, f0_shift=0):
    tgt_ref = spk_stats[tgt_spk]["best_spk_emb"]
    tgt_emb = f"{spk_emb_dir}/{tgt_spk}/{tgt_ref}.npy"
    with torch.no_grad():
        # tgt
        spk_id = spk2id[tgt_spk]
        spk_id = torch.LongTensor([spk_id]).unsqueeze(0).to(device)
        spk_emb = np.load(tgt_emb)
        spk_emb = torch.from_numpy(spk_emb).unsqueeze(0).to(device)
        f0_mean_tgt = f0_stats[tgt_spk]["mean"]
        f0_mean_tgt = torch.FloatTensor([f0_mean_tgt]).unsqueeze(0).to(device)
        # src
        wav, sr = librosa.load(src_wav, sr=16000)
        wav = torch.FloatTensor(wav).to(device)
        mel = mel_spectrogram(wav.unsqueeze(0), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
        x = wavlm(wav.unsqueeze(0)).last_hidden_state
        x = x.transpose(1, 2)  # (B, C, T)
        x = F.pad(x, (0, mel.size(2) - x.size(2)), 'constant')
        # cvt
        f0 = generator.get_f0(mel, f0_mean_tgt)
        f0 = tune_f0(f0, f0_shift)
        x = generator.get_x(x, spk_emb, spk_id)
        y = generator.infer(x, f0)

        audio = y.squeeze()
        audio = audio / torch.max(torch.abs(audio)) * 0.95
        audio = audio * MAX_WAV_VALUE
        audio = audio.cpu().numpy().astype('int16')

    sf.write("out.wav", audio, h.sampling_rate, "PCM_16")
    out_wav = "out.wav"
    return out_wav
# change spk
def change_spk(tgt_spk):
    tgt_ref = spk_stats[tgt_spk]["best_spk_emb"]
    tgt_wav = f"{spk_wav_dir}/{tgt_spk}/{tgt_ref}.wav"
    return tgt_wav
# interface
with gr.Blocks() as demo:
    gr.Markdown("# PitchVC")
    gr.Markdown("Gradio Demo for PitchVC. ([Github Repo](https://github.com/OlaWod/PitchVC))")

    with gr.Row():
        with gr.Column():
            tgt_spk = gr.Dropdown(choices=list(spk2id.keys()), type="value", label="Target Speaker")
            ref_audio = gr.Audio(label="Reference Audio", type='filepath')
            src_audio = gr.Audio(label="Source Audio", type='filepath')
            f0_shift = gr.Slider(minimum=-30, maximum=30, value=0, step=1, label="F0 Shift")
        with gr.Column():
            out_audio = gr.Audio(label="Output Audio", type='filepath')
            submit = gr.Button(value="Submit")

    tgt_spk.change(fn=change_spk, inputs=[tgt_spk], outputs=[ref_audio])
    submit.click(convert, [tgt_spk, src_audio, f0_shift], [out_audio])

    examples = gr.Examples(
        examples=[["p225", 'dataset/audio/p226/p226_341.wav', 0],
                  ["p226", 'dataset/audio/p225/p225_220.wav', -5]],
        inputs=[tgt_spk, src_audio, f0_shift])

demo.launch()