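# Gradio demo for PlayHT's Play Voice models: an autoregressive TTS sampler plus a
# diffusion decoder / BigVGAN vocoder stage, with a tab for cloning a voice from a
# short reference clip.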
import os
import random
import gradio as gr
import numpy as np
import torch
import torchaudio
from huggingface_hub import snapshot_download
from play_voice_inference.utils.voice_tokenizer import VoiceBpeTokenizer
from play_voice_inference.models.play_voice import LanguageIdentifiers, SpeakerAttributes, SpeechAttributes, load_play_voice
from play_voice_inference.utils.play_voice_sampler import PlayVoiceSampler
from play_voice_inference.utils.pv_diff_sampler import PlayVoiceDiffusionDecoderSampler
torch.set_grad_enabled(False)
device = torch.device('cuda')
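# Inference-only demo: gradients are disabled globally and a CUDA device is required.
# HF_TOKEN must be set in the environment so the model and voice repos can be downloaded.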
HF_TOKEN = os.environ['HF_TOKEN']
print("Loading models...")
tokenizer = VoiceBpeTokenizer()
MODEL_DIR = snapshot_download('PlayHT/play-voice-v0-multi', token=HF_TOKEN)
PV_AR_PT = MODEL_DIR + '/pv-v1-ar.pth'
play_voice = load_play_voice(PV_AR_PT, device)
sampler = PlayVoiceSampler(play_voice).to(device)
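# The autoregressive model samples latents; the diffusion decoder below (paired with a
# BigVGAN vocoder checkpoint) converts those latents into a waveform.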
NUM_DIFFUSION_STEPS: int = 150
DIFFUSION_PT = MODEL_DIR + '/pv-v1-diff-xf.pth'
DIFFUSION_VOCODER_PT = MODEL_DIR + '/pv-v1-diff-bigvgan.pt'
vocoder = PlayVoiceDiffusionDecoderSampler.from_path(
DIFFUSION_PT,
DIFFUSION_VOCODER_PT,
steps=NUM_DIFFUSION_STEPS,
silent=True,
use_fp16=True,
device=device
)
print("Preparing voices...")
VOICES_DIR = snapshot_download('PlayHT/play-voice-voices', repo_type='dataset', token=HF_TOKEN)
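# Preset reference voices ship as .wav files in a separate dataset repo.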
def load_audio(path: str, sr=24000):
audio, orig_sr = torchaudio.load(path)
if orig_sr != sr:
audio = torchaudio.transforms.Resample(orig_sr, sr)(audio)
return audio
def make_pcm(audio: torch.Tensor):
# Convert float audio in [-1, 1] to 16-bit PCM for gr.Audio.
# gradio expects [samples, channels] (or a 1-D mono array) and throws very unhelpful
# errors if the shape is wrong; the audio here is mono, so squeeze() reduces it to 1-D.
gen_np = audio.squeeze().cpu().numpy()
i = np.iinfo("int16")
abs_max = 2 ** (i.bits - 1)
offset = i.min + abs_max
gen_np = (gen_np * abs_max + offset).clip(i.min, i.max).astype("int16")
return gen_np
initial_voices = []
for item in os.listdir(VOICES_DIR):
if item.endswith(".wav"):
name = os.path.splitext(item)[0]
initial_voices.append({"name": name, "audio": load_audio(os.path.join(VOICES_DIR, item))})
initial_voices.sort(key=lambda x: x["name"])
print(f"Found {len(initial_voices)} initial voices")
def get_voice_labels(voices: list[dict]):
labels = []
for voice in voices:
labels.append(voice["name"])
return labels
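# UI: a "TTS" tab for synthesis with a selected voice, and a "Clone Voice" tab that adds
# an uploaded or recorded clip to the per-session voice list.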
with gr.Blocks(analytics_enabled=False, title="Play Voice", mode="tts") as iface:
local_voices = gr.State(initial_voices)
def get_selected_voice_by_label(voices, label: str):
labels = get_voice_labels(voices)
for i, voice_label in enumerate(labels):
if voice_label == label:
return voices[i]
raise Exception("Voice not found: " + label)
def make_voice_dropdown(voices):
choices = get_voice_labels(voices)
return gr.Dropdown(
choices=choices,
value=choices[-1] if len(choices) > 0 else None,
label="Voice",
)
def make_enum_dropdown(enum, label, default=None, allow_none=False):
choices = [e.name for e in enum]
if allow_none:
choices.append("none")
return gr.Dropdown(
choices=choices,
value=default,
label=label,
)
def get_enum_value(enum, value):
if value == "none":
return None
return enum[value]
gr.Markdown("# Play Voice (pretrained)\n")
with gr.Tab("TTS"):
speak_text = gr.Textbox(lines=2, placeholder="What would you like to say?", label="Text")
speak_voice = make_voice_dropdown(initial_voices)
with gr.Accordion("Settings", open=False):
speaker_attributes = make_enum_dropdown(
SpeakerAttributes, "Speaker Attributes", "full_sentence", allow_none=True
)
speech_attributes = make_enum_dropdown(SpeechAttributes, "Speech Attributes", "none", allow_none=True)
language = make_enum_dropdown(LanguageIdentifiers, "Language", "none", allow_none=True)
temperature = gr.Slider(minimum=0, maximum=2.0, value=0.3, label="Temperature")
repetition_penalty = gr.Slider(minimum=1.0, maximum=10.0, value=1.8, label="Repetition Penalty")
filter_thresh = gr.Slider(minimum=0.1, maximum=1.0, value=0.75, label="Top-p Threshold")
voice_guidance = gr.Slider(minimum=0.0, maximum=6.0, value=0.4, label="Voice Guidance")
style_guidance = gr.Slider(minimum=0.0, maximum=6.0, value=0.1, label="Style Guidance")
text_guidance = gr.Slider(minimum=0.0, maximum=6.0, value=0.6, label="Text Guidance")
speak_submit = gr.Button("Speak!")
speak_result = gr.Audio(label="Result", interactive=False)
ref_voice = gr.Audio(label="Reference Voice", interactive=False)
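# Synthesis callback: embed the reference voice, run the AR sampler to get latents, then
# decode the latents to audio with the diffusion decoder/vocoder.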
@torch.no_grad()
def handle_speak(
text,
voices,
voice_name,
voice_guidance,
speaker_attributes,
speech_attributes,
language,
temperature,
repetition_penalty,
top_p,
style_guidance,
text_guidance,
):
if text.strip() == "":
text = "I am PlayVoice, the voice of the future. Feed me your words and I will speak them, hahahaha!"
voice = get_selected_voice_by_label(voices, voice_name)
seed = random.randint(0, 2**32 - 1)
print(f"Voice: {voice['name']} Text: {text}")
voice_emb = sampler.get_voice_embedding(voice["audio"])
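# Encode the prompt with the BPE tokenizer; pad_sequence keeps it shaped as a batch of one.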
text_tokens = []
text_tokens.append(torch.tensor(tokenizer.encode(text), dtype=torch.int, device=device))
text_tokens = torch.nn.utils.rnn.pad_sequence(text_tokens, batch_first=True, padding_value=0)
torch.manual_seed(seed)
sample_result = sampler.sample_batched(
text_tokens=text_tokens,
text_guidance=text_guidance,
voice_emb=voice_emb,
voice_guidance=voice_guidance,
speaker_attributes=get_enum_value(SpeakerAttributes, speaker_attributes),
speech_attributes=get_enum_value(SpeechAttributes, speech_attributes),
language_identifier=get_enum_value(LanguageIdentifiers, language),
style_guidance=float(style_guidance),
temperature=float(temperature),
repetition_penalty=float(repetition_penalty),
top_p=float(top_p),
)
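# Decode the sampled latents into a waveform, conditioned on the reference clip.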
latents = sample_result["latents"]
audio = vocoder.sample(text_tokens, latents, ref_wav=voice["audio"])
audio = make_pcm(audio)
return {
speak_result: (vocoder.OUTPUT_FREQUENCY, audio),
ref_voice: (24000, make_pcm(voice["audio"])),  # reference clips are at 24 kHz (see load_audio)
}
speak_submit.click(
handle_speak,
inputs=[
speak_text,
local_voices,
speak_voice,
voice_guidance,
speaker_attributes,
speech_attributes,
language,
temperature,
repetition_penalty,
filter_thresh,
style_guidance,
text_guidance,
],
outputs=[
speak_result,
ref_voice,
],
)
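# "Clone Voice" tab: capture or upload a reference clip and register it as a selectable voice.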
with gr.Tab("Clone Voice"):
new_voice_name = gr.Textbox(value="cloned-voice", label="Voice Name")
new_voice_audio = gr.Audio(label="Voice Audio (20s min, ideally 30s, anything longer will be truncated)",
sources=["upload", "microphone"],
)
new_voice_submit = gr.Button("Create!")
new_voice_result = gr.Label("")
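# Cloning callback: convert the int16 clip from gr.Audio to mono float, resample to 24 kHz,
# truncate to 30 s, and append it to the session's voice list.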
def on_new_voice_submit(voices, name, raw_audio):
assert raw_audio is not None, "Must provide audio"
sr = raw_audio[0]
torch_audio = torch.from_numpy(raw_audio[1]).float() / 32768.0
if torch_audio.ndim == 1:
torch_audio = torch_audio.unsqueeze(0)
else:
torch_audio = torch_audio.transpose(0, 1).mean(dim=0, keepdim=True)
if sr != 24000:
if sr < 16000:
raise Exception(
"Garbage in, garbage out. Please provide audio with a sample rate of at least 16kHz, ideally 24kHz."
)
torch_audio = torchaudio.transforms.Resample(sr, 24000)(torch_audio)
# trim to 30s
if torch_audio.shape[1] > 24000 * 30:
torch_audio = torch_audio[:, : 24000 * 30]
# add to local voices
voices.append({"name": name, "audio": torch_audio})
return {
speak_voice: make_voice_dropdown(voices),
new_voice_result: f"Created voice {name}",
}
new_voice_submit.click(
on_new_voice_submit,
inputs=[
local_voices,
new_voice_name,
new_voice_audio
],
outputs=[
speak_voice,
new_voice_result
]
)
iface.launch(show_error=True, share=False)