Hemant0000 committed on
Commit
d05fcdb
·
verified ·
1 Parent(s): 2a4220f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +723 -0
app.py ADDED
@@ -0,0 +1,723 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ruff: noqa: E402
2
+ # Above allows ruff to ignore E402: module level import not at top of file
3
+
4
+ import re
5
+ import tempfile
6
+
7
+ import click
8
+ import gradio as gr
9
+ import numpy as np
10
+ import soundfile as sf
11
+ import torchaudio
12
+ from cached_path import cached_path
13
+ from transformers import AutoModelForCausalLM, AutoTokenizer
14
+
15
# USING_SPACES records whether the optional `spaces` package could be imported.
try:
    import spaces
except ImportError:
    USING_SPACES = False
else:
    USING_SPACES = True


def gpu_decorator(func):
    """Return ``spaces.GPU(func)`` when ``spaces`` is importable, else ``func`` unchanged."""
    return spaces.GPU(func) if USING_SPACES else func
28
+
29
+
30
+ from f5_tts.model import DiT, UNetT
31
+ from f5_tts.infer.utils_infer import (
32
+ load_vocoder,
33
+ load_model,
34
+ preprocess_ref_audio_text,
35
+ infer_process,
36
+ remove_silence_for_generated_wav,
37
+ save_spectrogram,
38
+ )
39
+
40
# Vocoder shared by every synthesis call in this module.
vocoder = load_vocoder()


# load models
# F5-TTS: DiT backbone, checkpoint resolved through cached_path.
F5TTS_model_cfg = {
    "dim": 1024,
    "depth": 22,
    "heads": 16,
    "ff_mult": 2,
    "text_dim": 512,
    "conv_layers": 4,
}
F5TTS_ema_model = load_model(
    DiT,
    F5TTS_model_cfg,
    str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors")),
)

# E2-TTS: UNetT backbone, checkpoint resolved through cached_path.
E2TTS_model_cfg = {"dim": 1024, "depth": 24, "heads": 16, "ff_mult": 4}
E2TTS_ema_model = load_model(
    UNetT,
    E2TTS_model_cfg,
    str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors")),
)

# Chat model and tokenizer are filled in later by the voice-chat tab.
chat_model_state = None
chat_tokenizer_state = None
56
+
57
+
58
@gpu_decorator
def generate_response(messages, model, tokenizer):
    """Generate a chat reply for ``messages`` with the given model and tokenizer.

    The messages are rendered through the tokenizer's chat template, up to 512
    new tokens are generated, the prompt tokens are cut off each output
    sequence, and the first decoded sequence is returned.
    """
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    encoded = tokenizer([prompt], return_tensors="pt").to(model.device)

    outputs = model.generate(**encoded, max_new_tokens=512, temperature=0.7, top_p=0.95)

    # Each output row starts with its prompt; keep only the newly generated tail.
    completions = []
    for prompt_ids, full_ids in zip(encoded.input_ids, outputs):
        completions.append(full_ids[len(prompt_ids) :])

    decoded = tokenizer.batch_decode(completions, skip_special_tokens=True)
    return decoded[0]
79
+
80
+
81
@gpu_decorator
def infer(
    ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15, speed=1, show_info=gr.Info
):
    """Synthesize ``gen_text`` in the voice of the reference audio.

    Args:
        ref_audio_orig: Reference audio, passed to ``preprocess_ref_audio_text``.
        ref_text: Reference transcript, passed to ``preprocess_ref_audio_text``.
        gen_text: Text to synthesize.
        model: ``"F5-TTS"`` or ``"E2-TTS"``.
        remove_silence: When truthy, the result is round-tripped through a
            temporary WAV file and ``remove_silence_for_generated_wav``.
        cross_fade_duration: Forwarded to ``infer_process``.
        speed: Forwarded to ``infer_process``.
        show_info: Callable used for status messages.

    Returns:
        ``((sample_rate, wave), spectrogram_path)``.

    Raises:
        ValueError: If ``model`` is not one of the two supported names.
    """
    import os

    # Validate the model first so a bad name fails before any audio work
    # (previously this surfaced later as an UnboundLocalError).
    if model == "F5-TTS":
        ema_model = F5TTS_ema_model
    elif model == "E2-TTS":
        ema_model = E2TTS_ema_model
    else:
        raise ValueError(f"Unknown TTS model {model!r}; expected 'F5-TTS' or 'E2-TTS'.")

    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)

    final_wave, final_sample_rate, combined_spectrogram = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        ema_model,
        vocoder,
        cross_fade_duration=cross_fade_duration,
        speed=speed,
        show_info=show_info,
        progress=gr.Progress(),
    )

    # Remove silence
    if remove_silence:
        # Reserve a path, then close the handle before other code opens the file by name.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            wav_path = f.name
        try:
            sf.write(wav_path, final_wave, final_sample_rate)
            remove_silence_for_generated_wav(wav_path)
            final_wave, _ = torchaudio.load(wav_path)
        finally:
            # The audio is in memory now; do not leave the scratch file behind.
            os.remove(wav_path)
        final_wave = final_wave.squeeze().cpu().numpy()

    # Save the spectrogram (the file must outlive this call because its path is returned).
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
        spectrogram_path = tmp_spectrogram.name
        save_spectrogram(combined_spectrogram, spectrogram_path)

    return (final_sample_rate, final_wave), spectrogram_path
118
+
119
+
120
with gr.Blocks() as app_credits:
    # Attribution tab: a single static Markdown component.
    gr.Markdown(
        value="""
# Credits
* [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
* [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
* [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
"""
    )
127
with gr.Blocks() as app_tts:
    # Single-voice synthesis tab: reference audio + text in, audio + spectrogram out.
    gr.Markdown(value="# Batched TTS")
    ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
    gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
    model_choice = gr.Radio(
        choices=["F5-TTS", "E2-TTS"],
        label="Choose TTS Model",
        value="F5-TTS",
    )
    generate_btn = gr.Button(value="Synthesize", variant="primary")

    with gr.Accordion(label="Advanced Settings", open=False):
        ref_text_input = gr.Textbox(
            label="Reference Text",
            info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.",
            lines=2,
        )
        remove_silence = gr.Checkbox(
            label="Remove Silences",
            info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
            value=False,
        )
        speed_slider = gr.Slider(
            label="Speed",
            minimum=0.3,
            maximum=2.0,
            value=1.0,
            step=0.1,
            info="Adjust the speed of the audio.",
        )
        cross_fade_duration_slider = gr.Slider(
            label="Cross-Fade Duration (s)",
            minimum=0.0,
            maximum=1.0,
            value=0.15,
            step=0.01,
            info="Set the duration of the cross-fade between audio clips.",
        )

    audio_output = gr.Audio(label="Synthesized Audio")
    spectrogram_output = gr.Image(label="Spectrogram")

    # Input order follows infer()'s positional parameters.
    generate_btn.click(
        fn=infer,
        inputs=[
            ref_audio_input,
            ref_text_input,
            gen_text_input,
            model_choice,
            remove_silence,
            cross_fade_duration_slider,
            speed_slider,
        ],
        outputs=[audio_output, spectrogram_output],
    )
177
+
178
+
179
def parse_speechtypes_text(gen_text):
    """Split ``gen_text`` into ``{"style", "text"}`` segments at ``{style}`` markers.

    Text before the first marker is assigned the style ``"Regular"``. A marker
    applies to all following text until the next marker. Segments whose text is
    empty after stripping are omitted.
    """
    # Splitting on a capturing group alternates text (even index) and style (odd index).
    pieces = re.split(r"\{(.*?)\}", gen_text)

    active_style = "Regular"
    parsed = []
    for position, piece in enumerate(pieces):
        cleaned = piece.strip()
        if position % 2:
            active_style = cleaned
            continue
        if cleaned:
            parsed.append({"style": active_style, "text": cleaned})

    return parsed
202
+
203
+
204
with gr.Blocks() as app_multistyle:
    # New section for multistyle generation
    gr.Markdown(
        """
    # Multiple Speech-Type Generation
    This section allows you to generate multiple speech types or multiple people's voices. Enter your text in the format shown below, and the system will generate speech using the appropriate type. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
    """
    )

    with gr.Row():
        gr.Markdown(
            """
            **Example Input:**
            {Regular} Hello, I'd like to order a sandwich please.
            {Surprised} What do you mean you're out of bread?
            {Sad} I really wanted a sandwich though...
            {Angry} You know what, darn you and your little shop!
            {Whisper} I'll just go back home and cry now.
            {Shouting} Why me?!
            """
        )

        gr.Markdown(
            """
            **Example Input 2:**
            {Speaker1_Happy} Hello, I'd like to order a sandwich please.
            {Speaker2_Regular} Sorry, we're out of bread.
            {Speaker1_Sad} I really wanted a sandwich though...
            {Speaker2_Whisper} I'll give you the last one I was hiding.
            """
        )

    gr.Markdown(
        "Upload different audio clips for each speech type. The first speech type is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button."
    )

    # Regular speech type (mandatory)
    with gr.Row():
        with gr.Column():
            regular_name = gr.Textbox(value="Regular", label="Speech Type Name")
            regular_insert = gr.Button("Insert", variant="secondary")
        regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
        regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=2)

    # Additional speech types (up to 99 more)
    max_speech_types = 100
    speech_type_rows = []
    # NOTE: the names list starts with the regular name, so it holds one more
    # entry than the audio / reference-text lists below.
    speech_type_names = [regular_name]
    speech_type_audios = []
    speech_type_ref_texts = []
    speech_type_delete_btns = []
    speech_type_insert_btns = []
    speech_type_insert_btns.append(regular_insert)

    # All optional rows are created up front and hidden; "Add" reveals them one by one.
    for i in range(max_speech_types - 1):
        with gr.Row(visible=False) as row:
            with gr.Column():
                name_input = gr.Textbox(label="Speech Type Name")
                delete_btn = gr.Button("Delete", variant="secondary")
                insert_btn = gr.Button("Insert", variant="secondary")
            audio_input = gr.Audio(label="Reference Audio", type="filepath")
            ref_text_input = gr.Textbox(label="Reference Text", lines=2)
        speech_type_rows.append(row)
        speech_type_names.append(name_input)
        speech_type_audios.append(audio_input)
        speech_type_ref_texts.append(ref_text_input)
        speech_type_delete_btns.append(delete_btn)
        speech_type_insert_btns.append(insert_btn)

    # Button to add speech type
    add_speech_type_btn = gr.Button("Add Speech Type")

    # Keep track of current number of speech types
    speech_type_count = gr.State(value=0)

    # Function to add a speech type
    def add_speech_type_fn(speech_type_count):
        """Increment the counter (up to the row limit) and reveal the first ``count`` rows."""
        if speech_type_count < max_speech_types - 1:
            speech_type_count += 1
            # Rows below the new count become visible; the rest are left untouched.
            row_updates = [
                gr.update(visible=True) if i < speech_type_count else gr.update()
                for i in range(max_speech_types - 1)
            ]
        else:
            # Limit reached: change nothing.
            row_updates = [gr.update() for _ in range(max_speech_types - 1)]
        return [speech_type_count] + row_updates

    add_speech_type_btn.click(
        add_speech_type_fn, inputs=speech_type_count, outputs=[speech_type_count] + speech_type_rows
    )

    # Function to delete a speech type
    def make_delete_speech_type_fn(index):
        """Build the click handler that hides row ``index`` and decrements the counter."""

        def delete_speech_type_fn(speech_type_count):
            row_updates = [
                gr.update(visible=False) if i == index else gr.update() for i in range(max_speech_types - 1)
            ]
            speech_type_count = max(0, speech_type_count - 1)
            return [speech_type_count] + row_updates

        return delete_speech_type_fn

    # Update delete button clicks
    for i, delete_btn in enumerate(speech_type_delete_btns):
        delete_fn = make_delete_speech_type_fn(i)
        delete_btn.click(delete_fn, inputs=speech_type_count, outputs=[speech_type_count] + speech_type_rows)

    # Text input for the prompt
    gen_text_input_multistyle = gr.Textbox(
        label="Text to Generate",
        lines=10,
        placeholder="Enter the script with speaker names (or emotion types) at the start of each block, e.g.:\n\n{Regular} Hello, I'd like to order a sandwich please.\n{Surprised} What do you mean you're out of bread?\n{Sad} I really wanted a sandwich though...\n{Angry} You know what, darn you and your little shop!\n{Whisper} I'll just go back home and cry now.\n{Shouting} Why me?!",
    )

    def make_insert_speech_type_fn(index):
        """Build the click handler that appends a ``{name} `` marker to the script text.

        ``index`` is accepted for symmetry with the delete factory; the handler does not use it.
        """

        def insert_speech_type_fn(current_text, speech_type_name):
            current_text = current_text or ""
            speech_type_name = speech_type_name or "None"
            updated_text = current_text + f"{{{speech_type_name}}} "
            return gr.update(value=updated_text)

        return insert_speech_type_fn

    for i, insert_btn in enumerate(speech_type_insert_btns):
        insert_fn = make_insert_speech_type_fn(i)
        insert_btn.click(
            insert_fn,
            inputs=[gen_text_input_multistyle, speech_type_names[i]],
            outputs=gen_text_input_multistyle,
        )

    # Model choice
    model_choice_multistyle = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")

    with gr.Accordion("Advanced Settings", open=False):
        remove_silence_multistyle = gr.Checkbox(
            label="Remove Silences",
            value=False,
        )

    # Generate button
    generate_multistyle_btn = gr.Button("Generate Multi-Style Speech", variant="primary")

    # Output audio
    audio_output_multistyle = gr.Audio(label="Synthesized Audio")

    @gpu_decorator
    def generate_multistyle_speech(
        regular_audio,
        regular_ref_text,
        gen_text,
        *args,
    ):
        """Synthesize each ``{style}`` segment of ``gen_text`` and concatenate the audio.

        ``*args`` follows the click wiring below: all speech-type names (regular
        name first), the additional audios, the additional reference texts, then
        the model choice and the remove-silence flag.

        Returns:
            ``(sample_rate, audio)`` or ``None`` when no segment produced audio.
        """
        num_additional_speech_types = max_speech_types - 1

        # The last two values are fixed; everything before them is the three lists.
        *speech_type_values, model_choice, remove_silence = args

        names_end = max_speech_types  # names include the regular one, hence one extra
        audios_end = names_end + num_additional_speech_types
        ref_texts_end = audios_end + num_additional_speech_types

        # Skip index 0 (regular name): it has no partner in the audio / ref-text lists.
        speech_type_names_list = speech_type_values[1:names_end]
        speech_type_audios_list = speech_type_values[names_end:audios_end]
        speech_type_ref_texts_list = speech_type_values[audios_end:ref_texts_end]

        # Collect the speech types and their audios into a dict
        speech_types = {"Regular": {"audio": regular_audio, "ref_text": regular_ref_text}}

        for name_input, audio_input, ref_text_input in zip(
            speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list
        ):
            # Rows without both a name and an audio clip are ignored.
            if name_input and audio_input:
                speech_types[name_input] = {"audio": audio_input, "ref_text": ref_text_input}

        # Parse the gen_text into segments
        segments = parse_speechtypes_text(gen_text)

        # For each segment, generate speech
        generated_audio_segments = []

        for segment in segments:
            style = segment["style"]
            text = segment["text"]

            # If style not available, default to Regular
            current_style = style if style in speech_types else "Regular"

            ref_audio = speech_types[current_style]["audio"]
            ref_text = speech_types[current_style].get("ref_text", "")

            # Generate speech for this segment
            audio, _ = infer(
                ref_audio, ref_text, text, model_choice, remove_silence, 0, show_info=print
            )  # show_info=print no pull to top when generating
            sr, audio_data = audio

            generated_audio_segments.append(audio_data)

        # Concatenate all audio segments
        if generated_audio_segments:
            final_audio_data = np.concatenate(generated_audio_segments)
            return (sr, final_audio_data)
        else:
            gr.Warning("No audio generated.")
            return None

    generate_multistyle_btn.click(
        generate_multistyle_speech,
        inputs=[
            regular_audio,
            regular_ref_text,
            gen_text_input_multistyle,
        ]
        + speech_type_names
        + speech_type_audios
        + speech_type_ref_texts
        + [
            model_choice_multistyle,
            remove_silence_multistyle,
        ],
        outputs=audio_output_multistyle,
    )

    # Validation function to disable Generate button if speech types are missing
    def validate_speech_types(gen_text, regular_name, *args):
        """Enable the generate button only when every style used in the text has a name defined.

        ``*args`` holds the speech-type name values; all of them are considered.
        """
        # Collect the speech types names (empty values are skipped).
        speech_types_available = {name for name in (regular_name, *args) if name}

        # Parse the gen_text to get the speech types used
        segments = parse_speechtypes_text(gen_text or "")
        speech_types_in_text = {segment["style"] for segment in segments}

        # Check if all speech types in text are available
        missing_speech_types = speech_types_in_text - speech_types_available

        return gr.update(interactive=not missing_speech_types)

    gen_text_input_multistyle.change(
        validate_speech_types,
        inputs=[gen_text_input_multistyle, regular_name] + speech_type_names,
        outputs=generate_multistyle_btn,
    )
469
+
470
+
471
with gr.Blocks() as app_chat:
    gr.Markdown(
        """
# Voice Chat
Have a conversation with an AI using your reference voice!
1. Upload a reference audio clip and optionally its transcript.
2. Load the chat model.
3. Record your message through your microphone.
4. The AI will respond using the reference voice.
"""
    )

    # One definition of the default system prompt, shared by the textbox,
    # the initial conversation state and the "clear" handler.
    default_system_prompt_chat = "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud."

    if not USING_SPACES:
        # Off Spaces the chat model is loaded on demand; the chat UI stays hidden until then.
        load_chat_model_btn = gr.Button("Load Chat Model", variant="primary")

        chat_interface_container = gr.Column(visible=False)

        @gpu_decorator
        def load_chat_model():
            """Load the chat model/tokenizer once, then hide the button and show the chat UI."""
            global chat_model_state, chat_tokenizer_state
            if chat_model_state is None:
                show_info = gr.Info
                show_info("Loading chat model...")
                model_name = "Qwen/Qwen2.5-3B-Instruct"
                chat_model_state = AutoModelForCausalLM.from_pretrained(
                    model_name, torch_dtype="auto", device_map="auto"
                )
                chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
                show_info("Chat model loaded.")

            return gr.update(visible=False), gr.update(visible=True)

        load_chat_model_btn.click(load_chat_model, outputs=[load_chat_model_btn, chat_interface_container])

    else:
        # On Spaces the chat UI is visible immediately and the model is loaded at import time.
        chat_interface_container = gr.Column()

        if chat_model_state is None:
            model_name = "Qwen/Qwen2.5-3B-Instruct"
            chat_model_state = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
            chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)

    with chat_interface_container:
        with gr.Row():
            with gr.Column():
                ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
            with gr.Column():
                with gr.Accordion("Advanced Settings", open=False):
                    model_choice_chat = gr.Radio(
                        choices=["F5-TTS", "E2-TTS"],
                        label="TTS Model",
                        value="F5-TTS",
                    )
                    remove_silence_chat = gr.Checkbox(
                        label="Remove Silences",
                        value=True,
                    )
                    ref_text_chat = gr.Textbox(
                        label="Reference Text",
                        info="Optional: Leave blank to auto-transcribe",
                        lines=2,
                    )
                    system_prompt_chat = gr.Textbox(
                        label="System Prompt",
                        value=default_system_prompt_chat,
                        lines=2,
                    )

        chatbot_interface = gr.Chatbot(label="Conversation")

        with gr.Row():
            with gr.Column():
                audio_input_chat = gr.Microphone(
                    label="Speak your message",
                    type="filepath",
                )
                audio_output_chat = gr.Audio(autoplay=True)
            with gr.Column():
                text_input_chat = gr.Textbox(
                    label="Type your message",
                    lines=1,
                )
                send_btn_chat = gr.Button("Send")
                clear_btn_chat = gr.Button("Clear Conversation")

        conversation_state = gr.State(
            value=[
                {
                    "role": "system",
                    "content": default_system_prompt_chat,
                }
            ]
        )

        @gpu_decorator
        def process_audio_input(audio_path, text, history, conv_state):
            """Handle audio or text input from user.

            Appends the user message and the generated reply to both the
            chatbot history and the conversation state, and returns the two
            (matching the two outputs this handler is wired to).
            """
            # The textbox is reset to None after each turn, so it may yield None here.
            text = text or ""

            if not audio_path and not text.strip():
                return history, conv_state

            if audio_path:
                # Second element of the helper's result is used as the message text.
                text = preprocess_ref_audio_text(audio_path, text)[1]

            if not text.strip():
                return history, conv_state

            conv_state.append({"role": "user", "content": text})
            history.append((text, None))

            response = generate_response(conv_state, chat_model_state, chat_tokenizer_state)

            conv_state.append({"role": "assistant", "content": response})
            history[-1] = (text, response)

            return history, conv_state

        @gpu_decorator
        def generate_audio_response(history, ref_audio, ref_text, model, remove_silence):
            """Generate TTS audio for the latest AI response; ``None`` when there is nothing to speak."""
            if not history or not ref_audio:
                return None

            last_user_message, last_ai_response = history[-1]
            if not last_ai_response:
                return None

            audio_result, _ = infer(
                ref_audio,
                ref_text,
                last_ai_response,
                model,
                remove_silence,
                cross_fade_duration=0.15,
                speed=1.0,
                show_info=print,  # show_info=print no pull to top when generating
            )
            return audio_result

        def clear_conversation():
            """Reset the conversation to an empty history and the default system prompt."""
            return [], [
                {
                    "role": "system",
                    "content": default_system_prompt_chat,
                }
            ]

        def update_system_prompt(new_prompt):
            """Update the system prompt and reset the conversation"""
            new_conv_state = [{"role": "system", "content": new_prompt}]
            return [], new_conv_state

        # Handle audio input: reply, speak the reply, then clear the recorder.
        audio_input_chat.stop_recording(
            process_audio_input,
            inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
            outputs=[chatbot_interface, conversation_state],
        ).then(
            generate_audio_response,
            inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, model_choice_chat, remove_silence_chat],
            outputs=[audio_output_chat],
        ).then(
            lambda: None,
            None,
            audio_input_chat,
        )

        # Handle text input: reply, speak the reply, then clear the textbox.
        text_input_chat.submit(
            process_audio_input,
            inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
            outputs=[chatbot_interface, conversation_state],
        ).then(
            generate_audio_response,
            inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, model_choice_chat, remove_silence_chat],
            outputs=[audio_output_chat],
        ).then(
            lambda: None,
            None,
            text_input_chat,
        )

        # Handle send button (same chain as submitting the textbox).
        send_btn_chat.click(
            process_audio_input,
            inputs=[audio_input_chat, text_input_chat, chatbot_interface, conversation_state],
            outputs=[chatbot_interface, conversation_state],
        ).then(
            generate_audio_response,
            inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, model_choice_chat, remove_silence_chat],
            outputs=[audio_output_chat],
        ).then(
            lambda: None,
            None,
            text_input_chat,
        )

        # Handle clear button
        clear_btn_chat.click(
            clear_conversation,
            outputs=[chatbot_interface, conversation_state],
        )

        # Handle system prompt change and reset conversation
        system_prompt_chat.change(
            update_system_prompt,
            inputs=system_prompt_chat,
            outputs=[chatbot_interface, conversation_state],
        )
682
+
683
+
684
with gr.Blocks() as app:
    # Top-level page: an intro blurb followed by one tab per sub-app.
    gr.Markdown(
        value="""
# E2/F5 TTS
This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models:
* [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
* [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
The checkpoints support English and Chinese.
If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt.
**NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
"""
    )
    # Tab titles are listed in the same order as the Blocks they label.
    gr.TabbedInterface(
        [app_tts, app_multistyle, app_chat, app_credits],
        ["TTS", "Multi-Speech", "Voice-Chat", "Credits"],
    )
700
+
701
+
702
@click.command()
@click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
@click.option("--host", "-H", default=None, help="Host to run the app on")
@click.option(
    "--share",
    "-s",
    default=False,
    is_flag=True,
    help="Share the app via Gradio share link",
)
# A plain flag defaulting to True can never be switched off; the paired
# --no-api form makes the option usable while --api / -a keep working.
@click.option("--api/--no-api", "-a", default=True, help="Allow API access")
def main(port, host, share, api):
    """Launch the E2/F5 TTS Gradio app."""
    print("Starting app...")
    app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api)
717
+
718
+
719
if __name__ == "__main__":
    # With `spaces` available, launch directly with defaults; otherwise go through the CLI.
    if USING_SPACES:
        app.queue().launch()
    else:
        main()