# VoiceClone-TTS / app.py
import torch
import torchaudio
import gradio as gr
import spaces
from zonos.model import Zonos
from zonos.conditioning import make_cond_dict
# We'll keep a global dictionary of loaded models to avoid reloading
MODELS_CACHE = {}
device = "cuda"
banner_url = "https://huggingface.co/datasets/Steveeeeeeen/random_images/resolve/main/ZonosHeader.png"
BANNER = f'<div style="display: flex; justify-content: space-around;"><img src="{banner_url}" alt="Banner" style="width: 40vw; min-width: 150px; max-width: 300px;"> </div>'
# Define a list of tuples: (Display Label, Language Code)
LANGUAGES = [
("English", "en-us"),
("Japanese", "ja"),
("Chinese", "cmn"),
("French", "fr-fr"),
("German", "de"),
]
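# The codes above are espeak-style identifiers ("cmn" = Mandarin, "fr-fr" =
# France French), which is the convention the Zonos phonemizer expects.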
def load_model(model_name: str):
    """
    Loads or retrieves a cached Zonos model, sets it to eval and bfloat16.
    """
    global MODELS_CACHE
    if model_name not in MODELS_CACHE:
        print(f"Loading model: {model_name}")
        model = Zonos.from_pretrained(model_name, device=device)
        model = model.requires_grad_(False).eval()
        model.bfloat16()  # optional; requires a GPU with bfloat16 support
        MODELS_CACHE[model_name] = model
        print(f"Model loaded successfully: {model_name}")
    return MODELS_CACHE[model_name]
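# Usage sketch: repeated calls with the same model id hit the cache, e.g.
#   model = load_model("Zyphra/Zonos-v0.1-transformer")
#   model = load_model("Zyphra/Zonos-v0.1-transformer")  # cached, no reload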
@spaces.GPU(duration=90)
def tts(text, speaker_audio, selected_language_label, model_choice):
    """
    text: str (Text prompt to synthesize)
    speaker_audio: (sample_rate, numpy_array) from Gradio if type="numpy"
    selected_language_label: str (the display name from the dropdown, e.g. "Chinese")
    model_choice: str (which Zonos model to use, e.g., "Zyphra/Zonos-v0.1-hybrid")
    Returns (sr_out, wav_out_numpy).
    """
    # Skip generation entirely if either input is missing
    if not text or speaker_audio is None:
        return None

    # Map from label -> actual language code
    label_to_code = dict(LANGUAGES)
    selected_language = label_to_code[selected_language_label]

    model = load_model(model_choice)

    # Gradio gives audio in (sample_rate, numpy_array) format
    sr, wav_np = speaker_audio

    # Gradio usually delivers int16 PCM; scale to float32 in [-1, 1]
    if wav_np.dtype.kind == "i":
        wav_np = wav_np.astype("float32") / 32768.0
    wav_tensor = torch.from_numpy(wav_np).float()

    # Gradio returns multi-channel audio as (samples, channels); downmix to mono
    if wav_tensor.ndim == 2:
        wav_tensor = wav_tensor.mean(dim=1)  # => (samples,)

    # Add batch dimension => (1, samples)
    wav_tensor = wav_tensor.unsqueeze(0)

    # Get speaker embedding (make_speaker_embedding resamples from sr internally)
    with torch.no_grad():
        spk_embedding = model.make_speaker_embedding(wav_tensor, sr)
        spk_embedding = spk_embedding.to(device, dtype=torch.bfloat16)

    # Build the conditioning dictionary with the imported helper so it carries
    # the defaults Zonos expects, rather than a hand-rolled four-key dict
    cond_dict = make_cond_dict(
        text=text,
        speaker=spk_embedding,
        language=selected_language,  # use the espeak-style code here
        device=device,
    )
    conditioning = model.prepare_conditioning(cond_dict)

    # Generate codes
    with torch.no_grad():
        codes = model.generate(conditioning)

    # Decode the codes into raw audio
    wav_out = model.autoencoder.decode(codes).cpu().detach().squeeze()
    sr_out = model.autoencoder.sampling_rate
    return (sr_out, wav_out.numpy())
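# Standalone usage sketch (outside Gradio), assuming a local reference clip at
# the hypothetical path "speaker.wav":
#   wav, sr = torchaudio.load("speaker.wav")  # wav: (channels, samples), float32
#   sr_out, audio = tts("Hello from Zonos!", (sr, wav.mean(dim=0).numpy()),
#                       "English", "Zyphra/Zonos-v0.1-transformer")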
def build_demo():
    with gr.Blocks(theme='davehornik/Tealy') as demo:
        gr.HTML(BANNER, elem_id="banner")
        gr.Markdown("## Zonos-v0.1 TTS Demo")
        gr.Markdown(
            """
            > **Zero-shot TTS with Voice Cloning**: Input text and a 10–30 second speaker sample to generate high-quality text-to-speech output.
            > **Audio Prefix Inputs**: Enhance speaker matching by adding an audio prefix to the text, enabling behaviors like whispering that are hard to achieve with voice cloning alone.
            > **Multilingual Support**: Supports English, Japanese, Chinese, French, and German.
            """
        )
        with gr.Row():
            text_input = gr.Textbox(
                label="Text Prompt",
                value="Hello from Zonos!",
                lines=3,
            )
            ref_audio_input = gr.Audio(
                label="Reference Audio (Speaker Cloning)",
                type="numpy",
            )
        model_dropdown = gr.Dropdown(
            label="Model Choice",
            choices=["Zyphra/Zonos-v0.1-transformer", "Zyphra/Zonos-v0.1-hybrid"],
            value="Zyphra/Zonos-v0.1-hybrid",
            interactive=True,
        )
        # For the language dropdown, we display only the friendly label
        language_dropdown = gr.Dropdown(
            label="Language",
            choices=[label for (label, code) in LANGUAGES],
            value="English",  # default display label
            interactive=True,
        )
        generate_button = gr.Button("Generate")
        audio_output = gr.Audio(label="Synthesized Output", type="numpy")

        generate_button.click(
            fn=tts,
            inputs=[text_input, ref_audio_input, language_dropdown, model_dropdown],
            outputs=audio_output,
        )
    return demo
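# Note (assumption): ZeroGPU Spaces commonly call demo.queue() before launch
# to serialize GPU jobs, e.g. demo_app.queue().launch(...); this app launches
# without a queue.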
if __name__ == "__main__":
    demo_app = build_demo()
    demo_app.launch(server_name="0.0.0.0", server_port=7860, share=True)