mpasila committed · verified
Commit d719f2b · 1 Parent(s): fca1c74

Upload app.py

Files changed (1)
app.py +187 -74
app.py CHANGED
@@ -1,7 +1,5 @@
  """
  Gradio UI for Text-to-Speech using HiggsAudioServeEngine
- Adapted: Now compatible with Jupyter, Colab, Runpod, etc,
- by adding launch_notebook() and flexible path/context handling.
  """

  import argparse
@@ -18,25 +16,12 @@ from functools import lru_cache
  import re
  import torch

- # --- Safe import or stub for 'spaces' (for Huggingface Space only) ---
- try:
-     import spaces
- except ImportError:
-     class DummySpaces:
-         def __getattr__(self, name):  # any decorator
-             return lambda *a, **k: (lambda f: f)
-     spaces = DummySpaces()
-
  # Import HiggsAudio components
  from higgs_audio.serve.serve_engine import HiggsAudioServeEngine
  from higgs_audio.data_types import ChatMLSample, AudioContent, Message

- # --- Add this for Colab/notebook path safety ---
- BASE_DIR = os.path.dirname(os.path.abspath(__file__)) if "__file__" in globals() else os.getcwd()
-
- # Global engine/voice instance
+ # Global engine instance
  engine = None
- VOICE_PRESETS = {}

  # Default model configuration
  DEFAULT_MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-base"
@@ -52,17 +37,62 @@ DEFAULT_SYSTEM_PROMPT = (

  DEFAULT_STOP_STRINGS = ["<|end_of_text|>", "<|eot_id|>"]

- # ... PREDEFINED_EXAMPLES as before ...
-
- # (copy unchanged; omitted for brevity in this answer but use your full PREDEFINED_EXAMPLES dictionary)
-
+ # Predefined examples for system and input messages
  PREDEFINED_EXAMPLES = {
- # ... Same as your long dict above ...
- # (copy full version from original)
- # (you can copy exactly as in your current app.py)
+     "voice-clone": {
+         "system_prompt": "",
+         "input_text": "Hey there! I'm your friendly voice twin in the making. Pick a voice preset below or upload your own audio - let's clone some vocals and bring your voice to life! ",
+         "description": "Voice clone to clone the reference audio. Leave the system prompt empty.",
+     },
+     "smart-voice": {
+         "system_prompt": DEFAULT_SYSTEM_PROMPT,
+         "input_text": "The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years.",
+         "description": "Smart voice to generate speech based on the context",
+     },
+     "multispeaker-voice-description": {
+         "system_prompt": "You are an AI assistant designed to convert text into speech.\n"
+         "If the user's message includes a [SPEAKER*] tag, do not read out the tag and generate speech for the following text, using the specified voice.\n"
+         "If no speaker tag is present, select a suitable voice on your own.\n\n"
+         "<|scene_desc_start|>\n"
+         "SPEAKER0: feminine\n"
+         "SPEAKER1: masculine\n"
+         "<|scene_desc_end|>",
+         "input_text": "[SPEAKER0] I can't believe you did that without even asking me first!\n"
+         "[SPEAKER1] Oh, come on! It wasn't a big deal, and I knew you would overreact like this.\n"
+         "[SPEAKER0] Overreact? You made a decision that affects both of us without even considering my opinion!\n"
+         "[SPEAKER1] Because I didn't have time to sit around waiting for you to make up your mind! Someone had to act.",
+         "description": "Multispeaker with different voice descriptions in the system prompt",
+     },
+     "single-speaker-voice-description": {
+         "system_prompt": "Generate audio following instruction.\n\n"
+         "<|scene_desc_start|>\n"
+         "SPEAKER0: He speaks with a clear British accent and a conversational, inquisitive tone. His delivery is articulate and at a moderate pace, and very clear audio.\n"
+         "<|scene_desc_end|>",
+         "input_text": "Hey, everyone! Welcome back to Tech Talk Tuesdays.\n"
+         "It's your host, Alex, and today, we're diving into a topic that's become absolutely crucial in the tech world — deep learning.\n"
+         "And let's be honest, if you've been even remotely connected to tech, AI, or machine learning lately, you know that deep learning is everywhere.\n"
+         "\n"
+         "So here's the big question: Do you want to understand how deep learning works?\n",
+         "description": "Single speaker with voice description in the system prompt",
+     },
+     "single-speaker-zh": {
+         "system_prompt": "Generate audio following instruction.\n\n"
+         "<|scene_desc_start|>\n"
+         "Audio is recorded from a quiet room.\n"
+         "<|scene_desc_end|>",
+         "input_text": "大家好, 欢迎收听本期的跟李沐学AI. 今天沐哥在忙着洗数据, 所以由我, 希格斯主播代替他讲这期视频.\n"
+         "今天我们要聊的是一个你绝对不能忽视的话题: 多模态学习.\n"
+         "那么, 问题来了, 你真的了解多模态吗? 你知道如何自己动手构建多模态大模型吗.\n"
+         "或者说, 你能察觉到我其实是个机器人吗?",
+         "description": "Single speaker speaking Chinese",
+     },
+     "single-speaker-bgm": {
+         "system_prompt": DEFAULT_SYSTEM_PROMPT,
+         "input_text": "[music start] I will remember this, thought Ender, when I am defeated. To keep dignity, and give honor where it's due, so that defeat is not disgrace. And I hope I don't have to do it often. [music end]",
+         "description": "Single speaker with BGM using music tag. This is an experimental feature and you may need to try multiple times to get the best result.",
+     },
  }

- # -- The rest of your code, but replacing path joins to use BASE_DIR instead of __file__! ---

  @lru_cache(maxsize=20)
  def encode_audio_file(file_path):
@@ -70,15 +100,17 @@ def encode_audio_file(file_path):
      with open(file_path, "rb") as audio_file:
          return base64.b64encode(audio_file.read()).decode("utf-8")

+
  def get_current_device():
      """Get the current device."""
      return "cuda" if torch.cuda.is_available() else "cpu"

+
  def load_voice_presets():
      """Load the voice presets from the voice_examples directory."""
      try:
          with open(
-             os.path.join(BASE_DIR, "voice_examples", "config.json"),
+             os.path.join(os.path.dirname(__file__), "voice_examples", "config.json"),
              "r",
          ) as f:
              voice_dict = json.load(f)
@@ -93,9 +125,10 @@ def load_voice_presets():
          logger.error(f"Error loading voice presets: {e}")
          return {"EMPTY": "No reference voice"}

+
  def get_voice_preset(voice_preset):
      """Get the voice path and text for a given voice preset."""
-     voice_path = os.path.join(BASE_DIR, "voice_examples", f"{voice_preset}.wav")
+     voice_path = os.path.join(os.path.dirname(__file__), "voice_examples", f"{voice_preset}.wav")
      if not os.path.exists(voice_path):
          logger.warning(f"Voice preset file not found: {voice_path}")
          return None, "Voice preset not found"
@@ -103,24 +136,54 @@ def get_voice_preset(voice_preset):
      text = VOICE_PRESETS.get(voice_preset, "No transcript available")
      return voice_path, text

- # -- rest of your normalization and utility code unchanged --

  def normalize_chinese_punctuation(text):
-     # ... as before ...
+     """
+     Convert Chinese (full-width) punctuation marks to English (half-width) equivalents.
+     """
+     # Mapping of Chinese punctuation to English punctuation
      chinese_to_english_punct = {
-         # ... as before ...
+         ",": ", ",  # comma
+         "。": ".",  # period
+         ":": ":",  # colon
+         ";": ";",  # semicolon
+         "?": "?",  # question mark
+         "!": "!",  # exclamation mark
+         "(": "(",  # left parenthesis
+         ")": ")",  # right parenthesis
+         "【": "[",  # left square bracket
+         "】": "]",  # right square bracket
+         "《": "<",  # left angle quote
+         "》": ">",  # right angle quote
+         "“": '"',  # left double quotation
+         "”": '"',  # right double quotation
+         "‘": "'",  # left single quotation
+         "’": "'",  # right single quotation
+         "、": ",",  # enumeration comma
+         "—": "-",  # em dash
+         "…": "...",  # ellipsis
+         "·": ".",  # middle dot
+         "「": '"',  # left corner bracket
+         "」": '"',  # right corner bracket
+         "『": '"',  # left double corner bracket
+         "』": '"',  # right double corner bracket
      }
+
+     # Replace each Chinese punctuation with its English counterpart
      for zh_punct, en_punct in chinese_to_english_punct.items():
          text = text.replace(zh_punct, en_punct)
+
      return text

+
  def normalize_text(transcript: str):
-     # ... as before, unchanged ...
      transcript = normalize_chinese_punctuation(transcript)
+     # Other normalizations (e.g., parentheses and other symbols. Will be improved in the future)
      transcript = transcript.replace("(", " ")
      transcript = transcript.replace(")", " ")
      transcript = transcript.replace("°F", " degrees Fahrenheit")
      transcript = transcript.replace("°C", " degrees Celsius")
+
      for tag, replacement in [
          ("[laugh]", "<SE>[Laughter]</SE>"),
          ("[humming start]", "<SE>[Humming]</SE>"),
@@ -135,15 +198,17 @@ def normalize_text(transcript: str):
          ("[cough]", "<SE>[Cough]</SE>"),
      ]:
          transcript = transcript.replace(tag, replacement)
- # ... rest unchanged ...
+
      lines = transcript.split("\n")
      transcript = "\n".join([" ".join(line.split()) for line in lines if line.strip()])
      transcript = transcript.strip()
+
      if not any([transcript.endswith(c) for c in [".", "!", "?", ",", ";", '"', "'", "</SE_e>", "</SE>"]]):
          transcript += "."
+
      return transcript

- @spaces.GPU
+
  def initialize_engine(model_path, audio_tokenizer_path) -> bool:
      """Initialize the HiggsAudioServeEngine."""
      global engine
@@ -160,14 +225,19 @@ def initialize_engine(model_path, audio_tokenizer_path) -> bool:
          logger.error(f"Failed to initialize engine: {e}")
          return False

+
  def check_return_audio(audio_wv: np.ndarray):
+     # check if the audio returned is all silent
      if np.all(audio_wv == 0):
          logger.warning("Audio is silent, returning None")

+
  def process_text_output(text_output: str):
+     # remove all the continuous <|AUDIO_OUT|> tokens with a single <|AUDIO_OUT|>
      text_output = re.sub(r"(<\|AUDIO_OUT\|>)+", r"<|AUDIO_OUT|>", text_output)
      return text_output

+
  def prepare_chatml_sample(
      voice_preset: str,
      text: str,
@@ -175,29 +245,45 @@
      reference_text: Optional[str] = None,
      system_prompt: str = DEFAULT_SYSTEM_PROMPT,
  ):
+     """Prepare a ChatMLSample for the HiggsAudioServeEngine."""
      messages = []
+
+     # Add system message if provided
      if len(system_prompt) > 0:
          messages.append(Message(role="system", content=system_prompt))
+
+     # Add reference audio if provided
      audio_base64 = None
      ref_text = ""
+
      if reference_audio:
+         # Custom reference audio
          audio_base64 = encode_audio_file(reference_audio)
          ref_text = reference_text or ""
      elif voice_preset != "EMPTY":
+         # Voice preset
          voice_path, ref_text = get_voice_preset(voice_preset)
          if voice_path is None:
              logger.warning(f"Voice preset {voice_preset} not found, skipping reference audio")
          else:
              audio_base64 = encode_audio_file(voice_path)
+
+     # Only add reference audio if we have it
      if audio_base64 is not None:
+         # Add user message with reference text
          messages.append(Message(role="user", content=ref_text))
+
+         # Add assistant message with audio content
          audio_content = AudioContent(raw_audio=audio_base64, audio_url="")
          messages.append(Message(role="assistant", content=[audio_content]))
+
+     # Add the main user message
      text = normalize_text(text)
      messages.append(Message(role="user", content=text))
+
      return ChatMLSample(messages=messages)

- @spaces.GPU(duration=120)
+
  def text_to_speech(
      text,
      voice_preset,
@@ -212,15 +298,22 @@
      ras_win_len=7,
      ras_win_max_num_repeat=2,
  ):
+     """Convert text to speech using HiggsAudioServeEngine."""
      global engine
+
      if engine is None:
          initialize_engine(DEFAULT_MODEL_PATH, DEFAULT_AUDIO_TOKENIZER_PATH)
+
      try:
+         # Prepare ChatML sample
          chatml_sample = prepare_chatml_sample(voice_preset, text, reference_audio, reference_text, system_prompt)
+
+         # Convert stop strings format
          if stop_strings is None:
              stop_list = DEFAULT_STOP_STRINGS
          else:
              stop_list = [s for s in stop_strings["stops"] if s.strip()]
+
          request_id = f"tts-playground-{str(uuid.uuid4())}"
          logger.info(
              f"{request_id}: Generating speech for text: {text[:100]}..., \n"
@@ -228,6 +321,8 @@
              f"ras_win_len={ras_win_len}, ras_win_max_num_repeat={ras_win_max_num_repeat}"
          )
          start_time = time.time()
+
+         # Generate using the engine
          response = engine.generate(
              chat_ml_sample=chatml_sample,
              max_new_tokens=max_completion_tokens,
@@ -238,25 +333,34 @@
              ras_win_len=ras_win_len if ras_win_len > 0 else None,
              ras_win_max_num_repeat=max(ras_win_len, ras_win_max_num_repeat),
          )
+
          generation_time = time.time() - start_time
          logger.info(f"{request_id}: Generated audio in {generation_time:.3f} seconds")
          gr.Info(f"Generated audio in {generation_time:.3f} seconds")
+
+         # Process the response
          text_output = process_text_output(response.generated_text)
+
          if response.audio is not None:
+             # Convert to int16 for Gradio
              audio_data = (response.audio * 32767).astype(np.int16)
              check_return_audio(audio_data)
              return text_output, (response.sampling_rate, audio_data)
          else:
              logger.warning("No audio generated")
              return text_output, None
+
      except Exception as e:
          error_msg = f"Error generating speech: {e}"
          logger.error(error_msg)
          gr.Error(error_msg)
          return f"❌ {error_msg}", None

+
  def create_ui():
-     my_theme = gr.Theme.load(os.path.join(BASE_DIR, "theme.json"))
+     my_theme = gr.Theme.load("theme.json")
+
+     # Add custom CSS to disable focus highlighting on textboxes
      custom_css = """
      .gradio-container input:focus,
      .gradio-container textarea:focus,
@@ -272,6 +376,8 @@ def create_ui():
          outline: none !important;
          background-color: var(--input-background-fill) !important;
      }
+
+     /* Override any hover effects as well */
      .gradio-container input:hover,
      .gradio-container textarea:hover,
      .gradio-container select:hover,
@@ -281,45 +387,59 @@
          border-color: var(--border-color-primary) !important;
          background-color: var(--input-background-fill) !important;
      }
+
+     /* Style for checked checkbox */
      .gradio-container input[type="checkbox"]:checked {
          background-color: var(--primary-500) !important;
          border-color: var(--primary-500) !important;
      }
      """
+
      default_template = "smart-voice"
+
+     """Create the Gradio UI."""
      with gr.Blocks(theme=my_theme, css=custom_css) as demo:
          gr.Markdown("# Higgs Audio Text-to-Speech Playground")
+
+         # Main UI section
          with gr.Row():
              with gr.Column(scale=2):
+                 # Template selection dropdown
                  template_dropdown = gr.Dropdown(
                      label="TTS Template",
                      choices=list(PREDEFINED_EXAMPLES.keys()),
                      value=default_template,
                      info="Select a predefined example for system and input messages.",
                  )
+
+                 # Template description display
                  template_description = gr.HTML(
                      value=f'<p style="font-size: 0.85em; color: var(--body-text-color-subdued); margin: 0; padding: 0;"> {PREDEFINED_EXAMPLES[default_template]["description"]}</p>',
                      visible=True,
                  )
+
                  system_prompt = gr.TextArea(
                      label="System Prompt",
                      placeholder="Enter system prompt to guide the model...",
                      value=PREDEFINED_EXAMPLES[default_template]["system_prompt"],
                      lines=2,
                  )
+
                  input_text = gr.TextArea(
                      label="Input Text",
                      placeholder="Type the text you want to convert to speech...",
                      value=PREDEFINED_EXAMPLES[default_template]["input_text"],
                      lines=5,
                  )
+
                  voice_preset = gr.Dropdown(
                      label="Voice Preset",
                      choices=list(VOICE_PRESETS.keys()),
                      value="EMPTY",
-                     interactive=False,
+                     interactive=False,  # Disabled by default since default template is not voice-clone
                      visible=False,
                  )
+
                  with gr.Accordion(
                      "Custom Reference (Optional)", open=False, visible=False
                  ) as custom_reference_accordion:
@@ -329,6 +449,7 @@ def create_ui():
                          placeholder="Enter the transcript of your reference audio...",
                          lines=3,
                      )
+
                  with gr.Accordion("Advanced Parameters", open=False):
                      max_completion_tokens = gr.Slider(
                          minimum=128,
@@ -362,6 +483,7 @@
                          label="RAS Max Num Repeat",
                          info="Maximum number of repetitions allowed in the window",
                      )
+                     # Add stop strings component
                      stop_strings = gr.Dataframe(
                          label="Stop Strings",
                          headers=["stops"],
@@ -370,11 +492,18 @@
                          interactive=True,
                          col_count=(1, "fixed"),
                      )
+
                  submit_btn = gr.Button("Generate Speech", variant="primary", scale=1)
+
              with gr.Column(scale=2):
                  output_text = gr.TextArea(label="Model Response", lines=2)
+
+                 # Audio output
                  output_audio = gr.Audio(label="Generated Audio", interactive=False, autoplay=True)
+
                  stop_btn = gr.Button("Stop Playback", variant="primary")
+
+         # Example voice
          with gr.Row(visible=False) as voice_samples_section:
              voice_samples_table = gr.Dataframe(
                  headers=["Voice Preset", "Sample Text"],
@@ -384,8 +513,10 @@
              )
              sample_audio = gr.Audio(label="Voice Sample")

+         # Function to play voice sample when clicking on a row
          def play_voice_sample(evt: gr.SelectData):
              try:
+                 # Get the preset name from the clicked row
                  preset_names = [preset for preset in VOICE_PRESETS.keys() if preset != "EMPTY"]
                  if evt.index[0] < len(preset_names):
                      preset = preset_names[evt.index[0]]
@@ -405,11 +536,14 @@

          voice_samples_table.select(fn=play_voice_sample, outputs=[sample_audio])

+         # Function to handle template selection
          def apply_template(template_name):
              if template_name in PREDEFINED_EXAMPLES:
                  template = PREDEFINED_EXAMPLES[template_name]
+                 # Enable voice preset and custom reference only for voice-clone template
                  is_voice_clone = template_name == "voice-clone"
                  voice_preset_value = "belinda" if is_voice_clone else "EMPTY"
+                 # Set ras_win_len to 0 for single-speaker-bgm, 7 for others
                  ras_win_len_value = 0 if template_name == "single-speaker-bgm" else 7
                  description_text = f'<p style="font-size: 0.85em; color: var(--body-text-color-subdued); margin: 0; padding: 0;"> {template["description"]}</p>'
                  return (
@@ -418,10 +552,10 @@
                      description_text,  # template_description
                      gr.update(
                          value=voice_preset_value, interactive=is_voice_clone, visible=is_voice_clone
-                     ),
-                     gr.update(visible=is_voice_clone),
-                     gr.update(visible=is_voice_clone),
-                     ras_win_len_value,
+                     ),  # voice_preset (value and interactivity)
+                     gr.update(visible=is_voice_clone),  # custom reference accordion visibility
+                     gr.update(visible=is_voice_clone),  # voice samples section visibility
+                     ras_win_len_value,  # ras_win_len
                  )
              else:
                  return (
@@ -432,8 +566,11 @@
                      gr.update(),
                      gr.update(),
                      gr.update(),
-                 )
+                 )  # No change if template not found
+
+         # Set up event handlers

+         # Connect template dropdown to handler
          template_dropdown.change(
              fn=apply_template,
              inputs=[template_dropdown],
@@ -448,6 +585,7 @@
              ],
          )

+         # Connect submit button to the TTS function
          submit_btn.click(
              fn=text_to_speech,
              inputs=[
@@ -467,50 +605,20 @@
              outputs=[output_text, output_audio],
              api_name="generate_speech",
          )
+
+         # Stop button functionality
          stop_btn.click(
              fn=lambda: None,
              inputs=[],
              outputs=[output_audio],
              js="() => {const audio = document.querySelector('audio'); if(audio) audio.pause(); return null;}",
          )
-     return demo
-
- # ------ NEW! Notebook/Colab/Runpod Launch Function ------
- def launch_notebook(
-     model_path=DEFAULT_MODEL_PATH,
-     audio_tokenizer_path=DEFAULT_AUDIO_TOKENIZER_PATH,
-     device=None,
-     host="127.0.0.1",
-     port=7860,
-     inline=True,
-     share=False,
-     **gradio_kwargs
- ):
-     """
-     Launch the Gradio UI inside a notebook, Colab or script.
-     - If inline=True (default), embeds in cell (Jupyter/Colab/Runpod, etc).
-     - If share=True, Gradio will provide a public URL for the UI.
-     """
-     global VOICE_PRESETS
-     VOICE_PRESETS = load_voice_presets()

-     # Optionally initialize engine, or let it lazy init on first use
-     # initialize_engine(model_path, audio_tokenizer_path)
+     return demo

-     demo = create_ui()
-     # Note: You can also pass other gradio launch kwargs here if desired.
-     demo.launch(
-         server_name=host,
-         server_port=port,
-         inline=inline,
-         share=share,
-         **gradio_kwargs,
-     )

  def main():
-     """
-     Main function to parse arguments and launch the UI via CLI (notebooks should use launch_notebook()).
-     """
+     """Main function to parse arguments and launch the UI."""
      global DEFAULT_MODEL_PATH, DEFAULT_AUDIO_TOKENIZER_PATH, VOICE_PRESETS

      parser = argparse.ArgumentParser(description="Gradio UI for Text-to-Speech using HiggsAudioServeEngine")
@@ -525,9 +633,14 @@ def main():
      parser.add_argument("--port", type=int, default=7860, help="Port for the Gradio interface.")

      args = parser.parse_args()
+
+     # Update default values if provided via command line
      VOICE_PRESETS = load_voice_presets()
+
+     # Create and launch the UI
      demo = create_ui()
      demo.launch(server_name=args.host, server_port=args.port)

+
  if __name__ == "__main__":
      main()
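Note: this commit removes the launch_notebook() helper (along with the spaces stub and BASE_DIR handling), so the file is CLI-only again. The UI can still be brought up from a notebook using the functions that remain; the following is a minimal sketch, not part of the commit, assuming this app.py is importable as a module named app and that the working directory is the repo root so theme.json and voice_examples/ resolve.

# Minimal notebook-launch sketch (assumptions: module name `app`, cwd = repo root).
import app

app.VOICE_PRESETS = app.load_voice_presets()  # main() normally populates this global
demo = app.create_ui()                        # builds the Blocks UI; the engine lazy-inits on first request
demo.launch(share=True)                       # share=True prints a public URL (useful on Colab/Runpod)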
 
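For reference, the ChatML structure that prepare_chatml_sample() assembles when a reference voice is present (preset or custom upload) looks like the sketch below. This is an illustrative reconstruction from the diff above, with placeholder strings standing in for real transcripts and audio.

# Illustrative message layout for voice cloning (placeholder content only).
from higgs_audio.data_types import ChatMLSample, AudioContent, Message

messages = [
    Message(role="system", content="Generate audio following instruction."),  # skipped when the prompt is empty
    Message(role="user", content="Transcript of the reference clip."),        # ref_text
    Message(role="assistant", content=[AudioContent(raw_audio="<base64 audio>", audio_url="")]),
    Message(role="user", content="Normalized text to synthesize."),
]
sample = ChatMLSample(messages=messages)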