# Spaces: Running on Zero
import torch | |
import torchaudio | |
import gradio as gr | |
import spaces | |
from zonos.model import Zonos | |
from zonos.conditioning import make_cond_dict # Keep this; remove supported_language_codes | |
# Process-wide cache of loaded models, keyed by model name, so each
# checkpoint is only loaded once.
MODELS_CACHE: dict = {}

# Device string handed to the model loader and to the conditioning tensors.
device = "cuda"

# Header image shown at the top of the demo page.
banner_url = "https://huggingface.co/datasets/Steveeeeeeen/random_images/resolve/main/ZonosHeader.png"
BANNER = (
    '<div style="display: flex; justify-content: space-around;">'
    f'<img src="{banner_url}" alt="Banner" '
    'style="width: 40vw; min-width: 150px; max-width: 300px;"> </div>'
)

# (display label, language code) pairs; the label is what the user sees,
# the code is what gets passed on for synthesis.
LANGUAGES = [
    ("English", "en-us"),
    ("Japanese", "ja"),
    ("Chinese", "cmn"),
    ("French", "fr-fr"),
    ("German", "de"),
]
def load_model(model_name: str):
    """Return the Zonos model for ``model_name``, loading it on first use.

    A model that is not yet in ``MODELS_CACHE`` is fetched with
    ``Zonos.from_pretrained`` onto ``device``, has gradients disabled, is
    switched to eval mode and bfloat16, and is then stored in the cache.
    Later calls with the same name return the cached instance.
    """
    # Fast path: already loaded during an earlier request.
    if model_name in MODELS_CACHE:
        return MODELS_CACHE[model_name]

    print(f"Loading model: {model_name}")
    loaded = Zonos.from_pretrained(model_name, device=device)
    loaded = loaded.requires_grad_(False).eval()
    loaded.bfloat16()  # optional if GPU supports bfloat16
    MODELS_CACHE[model_name] = loaded
    print(f"Model loaded successfully: {model_name}")
    return MODELS_CACHE[model_name]
def _reference_audio_to_mono(wav_np):
    """Turn a Gradio numpy waveform into a float mono tensor shaped (1, samples)."""
    wav_tensor = torch.from_numpy(wav_np)
    if wav_tensor.is_floating_point():
        wav_tensor = wav_tensor.float()
    else:
        # Integer PCM (e.g. int16) must be rescaled to [-1, 1]; a bare
        # float cast would leave amplitudes in the tens of thousands.
        info = torch.iinfo(wav_tensor.dtype)
        wav_tensor = wav_tensor.float() / float(max(abs(info.min), info.max))
    if wav_tensor.ndim == 2:
        # Gradio documents multi-channel data as (samples, channels).
        # Averaging over the shorter axis downmixes the channels whichever
        # way round the array is laid out.
        channel_axis = 1 if wav_tensor.shape[1] <= wav_tensor.shape[0] else 0
        wav_tensor = wav_tensor.mean(dim=channel_axis)  # => (samples,)
    # Add batch dimension => (1, samples)
    return wav_tensor.unsqueeze(0)


def tts(text, speaker_audio, selected_language_label, model_choice):
    """Synthesize ``text`` in the voice of the reference speaker.

    Args:
        text: Text prompt to synthesize.
        speaker_audio: ``(sample_rate, numpy_array)`` as produced by a Gradio
            audio input with ``type="numpy"``, or ``None`` if nothing was
            uploaded.
        selected_language_label: Display name from the language dropdown
            (e.g. ``"Chinese"``); mapped to its code through ``LANGUAGES``.
        model_choice: Name of the Zonos model to use
            (e.g. ``"Zyphra/Zonos-v0.1-hybrid"``).

    Returns:
        ``(sr_out, wav_out_numpy)``, or ``None`` when the text is empty or no
        reference audio was provided.

    Raises:
        KeyError: If ``selected_language_label`` is not a label in
            ``LANGUAGES``.
    """
    # NOTE(review): `spaces` is imported by this file but no `@spaces.GPU`
    # decorator is visible on this function; ZeroGPU hardware normally
    # requires one around GPU work - confirm against the deployed Space.

    # Check the cheap inputs first so an empty request never triggers a
    # model load or a language lookup.
    if not text or speaker_audio is None:
        return None

    # Convert the human-readable label back to the language code.
    selected_language = dict(LANGUAGES)[selected_language_label]
    model = load_model(model_choice)

    # Gradio gives audio in (sample_rate, numpy_array) format.
    sr, wav_np = speaker_audio
    wav_tensor = _reference_audio_to_mono(wav_np)

    # Get speaker embedding
    with torch.no_grad():
        spk_embedding = model.make_speaker_embedding(wav_tensor, sr)
        spk_embedding = spk_embedding.to(device, dtype=torch.bfloat16)

    # Build the conditioning through the library helper rather than a
    # hand-written dict, so the conditioner inputs are assembled the way
    # Zonos documents.
    cond_dict = make_cond_dict(
        text=text,
        speaker=spk_embedding,
        language=selected_language,  # Use the code here
        device=device,
    )
    conditioning = model.prepare_conditioning(cond_dict)

    # Generate codes and decode them into raw audio.
    with torch.no_grad():
        codes = model.generate(conditioning)
        wav_out = model.autoencoder.decode(codes).cpu().detach().squeeze()

    sr_out = model.autoencoder.sampling_rate
    # Cast to float32 first: NumPy cannot represent bfloat16, and the cast
    # is a no-op when the decoder already returns float32.
    return (sr_out, wav_out.float().numpy())
def build_demo():
    """Assemble the Gradio Blocks UI for the demo and return it.

    The page shows the banner and intro text, then a text prompt, a
    reference-audio input, model and language dropdowns, a "Generate"
    button and an audio output. Clicking the button calls ``tts`` with the
    four inputs and routes its result to the output player.
    """
    model_names = ["Zyphra/Zonos-v0.1-transformer", "Zyphra/Zonos-v0.1-hybrid"]
    # The dropdown only shows the friendly label; `tts` maps it back to a code.
    language_labels = [label for label, _code in LANGUAGES]

    # NOTE(review): the original indentation was lost; placing only the text
    # box and the reference audio inside the row is a reconstruction - confirm.
    with gr.Blocks(theme="davehornik/Tealy") as demo:
        gr.HTML(BANNER, elem_id="banner")
        gr.Markdown("## Zonos-v0.1 TTS Demo")
        gr.Markdown(
            """
> **Zero-shot TTS with Voice Cloning**: Input text and a 10–30 second speaker sample to generate high-quality text-to-speech output.
> **Audio Prefix Inputs**: Enhance speaker matching by adding an audio prefix to the text, enabling behaviors like whispering that are hard to achieve with voice cloning alone.
> **Multilingual Support**: Supports English, Japanese, Chinese, French, and German.
"""
        )

        with gr.Row():
            prompt_box = gr.Textbox(label="Text Prompt", value="Hello from Zonos!", lines=3)
            reference_audio = gr.Audio(label="Reference Audio (Speaker Cloning)", type="numpy")

        model_picker = gr.Dropdown(
            label="Model Choice",
            choices=model_names,
            value="Zyphra/Zonos-v0.1-hybrid",
            interactive=True,
        )
        language_picker = gr.Dropdown(
            label="Language",
            choices=language_labels,
            value="English",  # default display
            interactive=True,
        )

        run_button = gr.Button("Generate")
        result_audio = gr.Audio(label="Synthesized Output", type="numpy")

        # Input order must match the positional parameters of `tts`.
        tts_inputs = [prompt_box, reference_audio, language_picker, model_picker]
        run_button.click(fn=tts, inputs=tts_inputs, outputs=result_audio)

    return demo
if __name__ == "__main__":
    # Serve the demo on every interface at port 7860 and request a share link.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "share": True}
    demo_app = build_demo()
    demo_app.launch(**launch_options)