File size: 5,235 Bytes
748ecaa
 
 
a86425f
748ecaa
 
b8a3553
748ecaa
15961ae
 
e5d26e9
748ecaa
97132bd
 
 
b8a3553
 
 
 
 
 
 
 
 
15961ae
46f1390
15961ae
46f1390
15961ae
 
46f1390
15961ae
 
 
 
 
 
 
4a04525
b8a3553
46f1390
15961ae
46f1390
b8a3553
15961ae
22bde2c
 
46f1390
b8a3553
 
 
 
 
15961ae
46f1390
 
b1f1246
22bde2c
 
46f1390
 
 
22bde2c
46f1390
 
22bde2c
 
 
b8a3553
22bde2c
b8a3553
22bde2c
b8a3553
22bde2c
46f1390
 
 
 
e5d26e9
46f1390
 
b8a3553
 
 
 
 
 
d743fc1
46f1390
 
d743fc1
 
46f1390
 
 
 
 
 
 
 
97132bd
 
 
 
 
 
 
 
46f1390
97132bd
 
 
46f1390
 
 
 
 
 
 
 
 
 
15961ae
 
 
 
 
 
 
b8a3553
 
d5d8bf3
b8a3553
 
 
15961ae
d5d8bf3
46f1390
 
 
 
 
 
15961ae
46f1390
 
 
 
748ecaa
 
46f1390
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import torch
import torchaudio
import gradio as gr
import spaces

from zonos.model import Zonos
from zonos.conditioning import make_cond_dict  # Keep this; remove supported_language_codes

# Cache of already-instantiated Zonos models keyed by model id, so switching
# models in the UI does not reload weights from disk on every request.
MODELS_CACHE = {}
# Inference device. Hard-coded to CUDA; assumes a GPU is available
# (the @spaces.GPU decorator below provisions one) -- no CPU fallback here.
device = "cuda"

banner_url = "https://huggingface.co/datasets/Steveeeeeeen/random_images/resolve/main/ZonosHeader.png"
# HTML banner rendered at the top of the Gradio page.
BANNER = f'<div style="display: flex; justify-content: space-around;"><img src="{banner_url}" alt="Banner" style="width: 40vw; min-width: 150px; max-width: 300px;"> </div>'

# (Display Label, Language Code) pairs for the language dropdown.
# The label is shown to the user; the code is what the model's conditioning expects.
LANGUAGES = [
    ("English",  "en-us"),
    ("Japanese", "ja"),
    ("Chinese",  "cmn"),
    ("French",   "fr-fr"),
    ("German",   "de"),
]

def load_model(model_name: str):
    """
    Return a cached Zonos model, loading it on first use.

    The model is loaded onto the module-level `device`, put in eval mode with
    gradients disabled, and converted to bfloat16.

    Args:
        model_name: Hugging Face model id, e.g. "Zyphra/Zonos-v0.1-hybrid".

    Returns:
        The loaded (or previously cached) Zonos model instance.
    """
    # No `global` statement needed: we mutate the module-level dict in place,
    # we never rebind the name (the original declaration was redundant).
    if model_name not in MODELS_CACHE:
        print(f"Loading model: {model_name}")
        model = Zonos.from_pretrained(model_name, device=device)
        model = model.requires_grad_(False).eval()
        model.bfloat16()  # optional if GPU supports bfloat16
        MODELS_CACHE[model_name] = model
        print(f"Model loaded successfully: {model_name}")
    return MODELS_CACHE[model_name]

@spaces.GPU(duration=90)
def tts(text, speaker_audio, selected_language_label, model_choice):
    """
    Synthesize speech in the cloned voice of the reference speaker.

    Args:
        text: Text prompt to synthesize.
        speaker_audio: (sample_rate, numpy_array) pair from Gradio (type="numpy").
        selected_language_label: Display name from the dropdown, e.g. "Chinese".
        model_choice: Zonos model id, e.g. "Zyphra/Zonos-v0.1-hybrid".

    Returns:
        (sr_out, wav_out_numpy) on success, or None when the text prompt or
        the reference audio is missing.
    """
    # Validate inputs *before* paying the cost of loading a model.
    if not text:
        return None
    if speaker_audio is None:
        return None

    # Map the human-readable label back to the language code the model expects.
    label_to_code = dict(LANGUAGES)
    selected_language = label_to_code[selected_language_label]

    model = load_model(model_choice)

    # Gradio gives audio in (sample_rate, numpy_array) format.
    sr, wav_np = speaker_audio

    wav_tensor = torch.from_numpy(wav_np).float()

    # Downmix multi-channel audio to mono. Gradio delivers multi-channel
    # audio as (samples, channels); the original code averaged over dim=0,
    # which collapses the *sample* axis for any stereo clip (samples > 1
    # always). Average over whichever axis is smaller -- the channel axis --
    # which also tolerates a transposed (channels, samples) layout.
    if wav_tensor.ndim == 2:
        channel_dim = 0 if wav_tensor.shape[0] < wav_tensor.shape[1] else 1
        wav_tensor = wav_tensor.mean(dim=channel_dim)

    # Add batch dimension => (1, samples)
    wav_tensor = wav_tensor.unsqueeze(0)

    # Compute the speaker embedding from the reference clip.
    with torch.no_grad():
        spk_embedding = model.make_speaker_embedding(wav_tensor, sr)
        spk_embedding = spk_embedding.to(device, dtype=torch.bfloat16)

    # Prepare conditioning dictionary.
    cond_dict = {
        "text": text,
        "speaker": spk_embedding,
        "language": selected_language,  # language code, not the display label
        "device": device,
    }
    conditioning = model.prepare_conditioning(cond_dict)

    # Autoregressively generate audio codes, then decode them to a waveform.
    with torch.no_grad():
        codes = model.generate(conditioning)

    wav_out = model.autoencoder.decode(codes).cpu().detach().squeeze()
    sr_out = model.autoencoder.sampling_rate

    return (sr_out, wav_out.numpy())

def build_demo():
    """Construct and return the Gradio Blocks UI for the Zonos TTS demo."""
    with gr.Blocks(theme='davehornik/Tealy') as demo:
        # Header: banner image, title, and a short feature overview.
        gr.HTML(BANNER, elem_id="banner")
        gr.Markdown("## Zonos-v0.1 TTS Demo")
        gr.Markdown(
            """
> **Zero-shot TTS with Voice Cloning**: Input text and a 10–30 second speaker sample to generate high-quality text-to-speech output.

> **Audio Prefix Inputs**: Enhance speaker matching by adding an audio prefix to the text, enabling behaviors like whispering that are hard to achieve with voice cloning alone.

> **Multilingual Support**: Supports English, Japanese, Chinese, French, and German.
            """
        )

        # Main inputs: the text prompt and the reference speaker clip.
        with gr.Row():
            prompt_box = gr.Textbox(
                label="Text Prompt",
                value="Hello from Zonos!",
                lines=3
            )
            speaker_clip = gr.Audio(
                label="Reference Audio (Speaker Cloning)",
                type="numpy"
            )

        model_select = gr.Dropdown(
            label="Model Choice",
            choices=["Zyphra/Zonos-v0.1-transformer", "Zyphra/Zonos-v0.1-hybrid"],
            value="Zyphra/Zonos-v0.1-hybrid",
            interactive=True,
        )

        # Language dropdown shows only the friendly labels; tts() maps the
        # selected label back to its language code.
        lang_select = gr.Dropdown(
            label="Language",
            choices=[lbl for lbl, _ in LANGUAGES],
            value="English",  # default display
            interactive=True,
        )

        synth_button = gr.Button("Generate")
        synth_output = gr.Audio(label="Synthesized Output", type="numpy")

        # Wire the button to the synthesis function.
        synth_button.click(
            fn=tts,
            inputs=[prompt_box, speaker_clip, lang_select, model_select],
            outputs=synth_output,
        )

    return demo

if __name__ == "__main__":
    demo_app = build_demo()
    demo_app.launch(server_name="0.0.0.0", server_port=7860, share=True)