Steveeeeeeen (HF staff) committed
Commit 15961ae · verified · 1 Parent(s): d7ca016

Update app.py

Files changed (1)
  1. app.py +37 -34
app.py CHANGED
@@ -5,44 +5,42 @@ import gradio as gr
 from zonos.model import Zonos
 from zonos.conditioning import make_cond_dict, supported_language_codes
 
-# Global cache to hold the loaded model
-MODEL = None
+# We'll keep a global dictionary of loaded models to avoid reloading
+MODELS_CACHE = {}
 device = "cuda"
 
-def load_model():
+def load_model(model_name: str):
     """
-    Loads the Zonos model once and caches it globally.
-    Adjust the model name if you want to switch from hybrid to transformer, etc.
+    Loads or retrieves a cached Zonos model, sets it to eval and bfloat16.
     """
-    global MODEL
-    if MODEL is None:
-        model_name = "Zyphra/Zonos-v0.1-hybrid"
+    global MODELS_CACHE
+    if model_name not in MODELS_CACHE:
         print(f"Loading model: {model_name}")
-        MODEL = Zonos.from_pretrained(model_name, device="cuda")
-        MODEL = MODEL.requires_grad_(False).eval()
-        MODEL.bfloat16()  # optional if your GPU supports bfloat16
-        print("Model loaded successfully!")
-    return MODEL
-
-def tts(text, speaker_audio, selected_language):
+        model = Zonos.from_pretrained(model_name, device=device)
+        model = model.requires_grad_(False).eval()
+        model.bfloat16()  # optional if GPU supports bfloat16
+        MODELS_CACHE[model_name] = model
+        print(f"Model loaded successfully: {model_name}")
+    return MODELS_CACHE[model_name]
+
+def tts(text, speaker_audio, selected_language, model_choice):
     """
-    text: str
+    text: str (Text prompt to synthesize)
     speaker_audio: (sample_rate, numpy_array) from Gradio if type="numpy"
-    selected_language: str (e.g., "en-us", "es-es", etc.)
+    selected_language: str (language code)
+    model_choice: str (which Zonos model to use, e.g., "Zyphra/Zonos-v0.1-hybrid")
 
     Returns (sample_rate, waveform) for Gradio audio output.
     """
-    model = load_model()
+    # Load the selected model
+    model = load_model(model_choice)
 
-    # If no text, return None
     if not text:
         return None
-
-    # If no reference audio, return None
     if speaker_audio is None:
         return None
 
-    # Gradio provides audio in (sample_rate, numpy_array)
+    # Gradio gives audio in the format (sample_rate, numpy_array)
     sr, wav_np = speaker_audio
 
     # Convert to Torch tensor: shape (1, num_samples)
@@ -58,9 +56,9 @@ def tts(text, speaker_audio, selected_language):
 
     # Prepare conditioning dictionary
     cond_dict = make_cond_dict(
-        text=text,                    # The text prompt
-        speaker=spk_embedding,        # Speaker embedding
-        language=selected_language,   # Language from the Dropdown
+        text=text,
+        speaker=spk_embedding,
+        language=selected_language,
         device=device,
     )
     conditioning = model.prepare_conditioning(cond_dict)
@@ -77,7 +75,7 @@ def tts(text, speaker_audio, selected_language):
 
 def build_demo():
     with gr.Blocks() as demo:
-        gr.Markdown("# Simple Zonos TTS Demo (Text + Reference Audio + Language)")
+        gr.Markdown("# Simple Zonos TTS Demo")
 
         with gr.Row():
             text_input = gr.Textbox(
@@ -89,23 +87,28 @@ def build_demo():
                 label="Reference Audio (Speaker Cloning)",
                 type="numpy"
             )
-        # Add a dropdown for language selection
+
+        # Model dropdown
+        model_dropdown = gr.Dropdown(
+            label="Model Choice",
+            choices=["Zyphra/Zonos-v0.1-transformer", "Zyphra/Zonos-v0.1-hybrid"],
+            value="Zyphra/Zonos-v0.1-hybrid",
+            interactive=True,
+        )
+        # Language dropdown (you can filter or use all from supported_language_codes)
         language_dropdown = gr.Dropdown(
-            label="Language",
-            choices=supported_language_codes,
+            label="Language Code",
+            choices=["en-us", "es-es", "fr-fr", "de-de", "it"],
             value="en-us",
-            interactive=True
+            interactive=True,
         )
 
         generate_button = gr.Button("Generate")
-
-        # The output is an audio widget that Gradio will play
         audio_output = gr.Audio(label="Synthesized Output", type="numpy")
 
-        # Bind the generate button: pass text, reference audio, and selected language
        generate_button.click(
            fn=tts,
-            inputs=[text_input, ref_audio_input, language_dropdown],
+            inputs=[text_input, ref_audio_input, language_dropdown, model_dropdown],
            outputs=audio_output,
        )
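
For reference, a minimal usage sketch of the updated functions outside the Gradio UI (not part of the commit). It assumes a CUDA GPU with the zonos and soundfile packages installed, that app.py only launches the demo under an `if __name__ == "__main__"` guard, and that "reference.wav" is an illustrative path to a short clip of the target speaker:

# sketch.py -- exercise app.py's load_model() / tts() directly (illustrative only)
import soundfile as sf

from app import load_model, tts

# First call loads "Zyphra/Zonos-v0.1-hybrid"; the second returns the cached
# instance from MODELS_CACHE instead of reloading it.
model = load_model("Zyphra/Zonos-v0.1-hybrid")
assert load_model("Zyphra/Zonos-v0.1-hybrid") is model

# tts() expects the reference audio as (sample_rate, numpy_array),
# the same format gr.Audio(type="numpy") passes in.
wav_np, sr = sf.read("reference.wav")
out = tts(
    text="Hello from Zonos!",
    speaker_audio=(sr, wav_np),
    selected_language="en-us",
    model_choice="Zyphra/Zonos-v0.1-hybrid",
)
if out is not None:
    out_sr, out_wav = out
    print(f"Synthesized audio at {out_sr} Hz, waveform shape {out_wav.shape}")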