import numpy as np
import torch
import gradio as gr

from zonos.model import Zonos
from zonos.conditioning import make_cond_dict, supported_language_codes

# We'll keep a global dictionary of loaded models to avoid reloading
MODELS_CACHE = {}
# Prefer CUDA, but fall back to CPU so the demo can still launch without a GPU
device = "cuda" if torch.cuda.is_available() else "cpu"

def load_model(model_name: str):
    """
    Loads or retrieves a cached Zonos model, sets it to eval and bfloat16.
    """
    global MODELS_CACHE
    if model_name not in MODELS_CACHE:
        print(f"Loading model: {model_name}")
        model = Zonos.from_pretrained(model_name, device=device)
        model = model.requires_grad_(False).eval()
        if device == "cuda" and torch.cuda.is_bf16_supported():
            model.bfloat16()  # halves memory; skipped on GPUs without bfloat16
        MODELS_CACHE[model_name] = model
        print(f"Model loaded successfully: {model_name}")
    return MODELS_CACHE[model_name]
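
# Usage sketch (assumes the weights can be fetched from the Hugging Face Hub
# on first call; subsequent calls return the cached instance):
#   model = load_model("Zyphra/Zonos-v0.1-transformer")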

def tts(text, speaker_audio, selected_language, model_choice):
    """
    text: str (Text prompt to synthesize)
    speaker_audio: (sample_rate, numpy_array) from Gradio if type="numpy"
    selected_language: str (language code)
    model_choice: str (which Zonos model to use, e.g., "Zyphra/Zonos-v0.1-hybrid")
    
    Returns (sample_rate, waveform) for Gradio audio output.
    """
    # Validate inputs before paying the cost of loading a model
    if not text or speaker_audio is None:
        return None

    model = load_model(model_choice)

    # Gradio delivers audio as (sample_rate, numpy_array); uploads are usually
    # int16 PCM, and stereo arrives with a channel axis
    sr, wav_np = speaker_audio

    wav_tensor = torch.from_numpy(wav_np).float()
    if wav_np.dtype.kind == "i":
        # Integer PCM -> normalize into [-1, 1]
        wav_tensor /= np.iinfo(wav_np.dtype).max
    if wav_tensor.dim() == 2:
        # Collapse stereo to mono by averaging over the (smaller) channel axis
        channel_dim = 0 if wav_tensor.shape[0] < wav_tensor.shape[1] else 1
        wav_tensor = wav_tensor.mean(dim=channel_dim)
    wav_tensor = wav_tensor.unsqueeze(0)  # shape (1, num_samples)

    # Get the speaker embedding and match the model's parameter dtype
    with torch.no_grad():
        spk_embedding = model.make_speaker_embedding(wav_tensor, sr)
        spk_embedding = spk_embedding.to(device, dtype=next(model.parameters()).dtype)

    # Prepare conditioning dictionary
    cond_dict = make_cond_dict(
        text=text,
        speaker=spk_embedding,
        language=selected_language,
        device=device,
    )
    conditioning = model.prepare_conditioning(cond_dict)

    # Generate codes
    with torch.no_grad():
        codes = model.generate(conditioning)

    # Decode the codes into raw audio; cast to float32 first, since NumPy
    # cannot represent bfloat16
    wav_out = model.autoencoder.decode(codes).cpu().detach().squeeze().float()
    sr_out = model.autoencoder.sampling_rate

    return (sr_out, wav_out.numpy())
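
# Direct-call sketch bypassing the UI (hypothetical file path; assumes the
# optional soundfile package for reading the reference clip as int16 PCM):
#   import soundfile as sf
#   wav, sr = sf.read("speaker_sample.wav", dtype="int16")
#   sr_out, audio = tts("Hello!", (sr, wav), "en-us", "Zyphra/Zonos-v0.1-hybrid")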

def build_demo():
    with gr.Blocks() as demo:
        gr.Markdown("# Simple Zonos TTS Demo")

        with gr.Row():
            text_input = gr.Textbox(
                label="Text Prompt",
                value="Hello from Zonos!",
                lines=3
            )
            ref_audio_input = gr.Audio(
                label="Reference Audio (Speaker Cloning)",
                type="numpy"
            )

        # Model dropdown
        model_dropdown = gr.Dropdown(
            label="Model Choice",
            choices=["Zyphra/Zonos-v0.1-transformer", "Zyphra/Zonos-v0.1-hybrid"],
            value="Zyphra/Zonos-v0.1-hybrid",
            interactive=True,
        )
        # Language dropdown (you can filter or use all from supported_language_codes)
        language_dropdown = gr.Dropdown(
            label="Language Code",
            choices=["en-us", "es-es", "fr-fr", "de-de", "it"],
            value="en-us",
            interactive=True,
        )

        generate_button = gr.Button("Generate")
        audio_output = gr.Audio(label="Synthesized Output", type="numpy")

        generate_button.click(
            fn=tts,
            inputs=[text_input, ref_audio_input, language_dropdown, model_dropdown],
            outputs=audio_output,
        )

    return demo

if __name__ == "__main__":
    demo_app = build_demo()
    demo_app.launch(server_name="0.0.0.0", server_port=7860, share=True)
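# Run with e.g. `python app.py` (hypothetical filename), then open
# http://localhost:7860; share=True also prints a temporary public Gradio URL.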