Hemant0000 commited on
Commit
8c9588c
·
verified ·
1 Parent(s): 483afa3

Create gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +815 -0
gradio_app.py ADDED
@@ -0,0 +1,815 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import torch
4
+ import torchaudio
5
+ import gradio as gr
6
+ import numpy as np
7
+ import tempfile
8
+ from einops import rearrange
9
+ from vocos import Vocos
10
+ from pydub import AudioSegment, silence
11
+ from model import CFM, UNetT, DiT, MMDiT
12
+ from cached_path import cached_path
13
+ from model.utils import (
14
+ load_checkpoint,
15
+ get_tokenizer,
16
+ convert_char_to_pinyin,
17
+ save_spectrogram,
18
+ )
19
+ from transformers import pipeline
20
+ import librosa
21
+ import click
22
+ import soundfile as sf
23
+
24
+ try:
25
+ import spaces
26
+ USING_SPACES = True
27
+ except ImportError:
28
+ USING_SPACES = False
29
+
30
+ def gpu_decorator(func):
31
+ if USING_SPACES:
32
+ return spaces.GPU(func)
33
+ else:
34
+ return func
35
+
36
+
37
+
38
+ SPLIT_WORDS = [
39
+ "but", "however", "nevertheless", "yet", "still",
40
+ "therefore", "thus", "hence", "consequently",
41
+ "moreover", "furthermore", "additionally",
42
+ "meanwhile", "alternatively", "otherwise",
43
+ "namely", "specifically", "for example", "such as",
44
+ "in fact", "indeed", "notably",
45
+ "in contrast", "on the other hand", "conversely",
46
+ "in conclusion", "to summarize", "finally"
47
+ ]
48
+
49
+ device = (
50
+ "cuda"
51
+ if torch.cuda.is_available()
52
+ else "mps" if torch.backends.mps.is_available() else "cpu"
53
+ )
54
+
55
+ print(f"Using {device} device")
56
+
57
+ pipe = pipeline(
58
+ "automatic-speech-recognition",
59
+ model="openai/whisper-large-v3-turbo",
60
+ torch_dtype=torch.float16,
61
+ device=device,
62
+ )
63
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
64
+
65
+ # --------------------- Settings -------------------- #
66
+
67
+ target_sample_rate = 24000
68
+ n_mel_channels = 100
69
+ hop_length = 256
70
+ target_rms = 0.1
71
+ nfe_step = 32 # 16, 32
72
+ cfg_strength = 2.0
73
+ ode_method = "euler"
74
+ sway_sampling_coef = -1.0
75
+ speed = 1.0
76
+ # fix_duration = 27 # None or float (duration in seconds)
77
+ fix_duration = None
78
+
79
+
80
+ def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
81
+ ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
82
+ # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
83
+ vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
84
+ model = CFM(
85
+ transformer=model_cls(
86
+ **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
87
+ ),
88
+ mel_spec_kwargs=dict(
89
+ target_sample_rate=target_sample_rate,
90
+ n_mel_channels=n_mel_channels,
91
+ hop_length=hop_length,
92
+ ),
93
+ odeint_kwargs=dict(
94
+ method=ode_method,
95
+ ),
96
+ vocab_char_map=vocab_char_map,
97
+ ).to(device)
98
+
99
+ model = load_checkpoint(model, ckpt_path, device, use_ema = True)
100
+
101
+ return model
102
+
103
+
104
+ # load models
105
+ F5TTS_model_cfg = dict(
106
+ dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
107
+ )
108
+ E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
109
+
110
+ F5TTS_ema_model = load_model(
111
+ "F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
112
+ )
113
+ E2TTS_ema_model = load_model(
114
+ "E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
115
+ )
116
+
117
+ def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
118
+ if len(text.encode('utf-8')) <= max_chars:
119
+ return [text]
120
+ if text[-1] not in ['。', '.', '!', '!', '?', '?']:
121
+ text += '.'
122
+
123
+ sentences = re.split('([。.!?!?])', text)
124
+ sentences = [''.join(i) for i in zip(sentences[0::2], sentences[1::2])]
125
+
126
+ batches = []
127
+ current_batch = ""
128
+
129
+ def split_by_words(text):
130
+ words = text.split()
131
+ current_word_part = ""
132
+ word_batches = []
133
+ for word in words:
134
+ if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
135
+ current_word_part += word + ' '
136
+ else:
137
+ if current_word_part:
138
+ # Try to find a suitable split word
139
+ for split_word in split_words:
140
+ split_index = current_word_part.rfind(' ' + split_word + ' ')
141
+ if split_index != -1:
142
+ word_batches.append(current_word_part[:split_index].strip())
143
+ current_word_part = current_word_part[split_index:].strip() + ' '
144
+ break
145
+ else:
146
+ # If no suitable split word found, just append the current part
147
+ word_batches.append(current_word_part.strip())
148
+ current_word_part = ""
149
+ current_word_part += word + ' '
150
+ if current_word_part:
151
+ word_batches.append(current_word_part.strip())
152
+ return word_batches
153
+
154
+ for sentence in sentences:
155
+ if len(current_batch.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
156
+ current_batch += sentence
157
+ else:
158
+ # If adding this sentence would exceed the limit
159
+ if current_batch:
160
+ batches.append(current_batch)
161
+ current_batch = ""
162
+
163
+ # If the sentence itself is longer than max_chars, split it
164
+ if len(sentence.encode('utf-8')) > max_chars:
165
+ # First, try to split by colon
166
+ colon_parts = sentence.split(':')
167
+ if len(colon_parts) > 1:
168
+ for part in colon_parts:
169
+ if len(part.encode('utf-8')) <= max_chars:
170
+ batches.append(part)
171
+ else:
172
+ # If colon part is still too long, split by comma
173
+ comma_parts = re.split('[,,]', part)
174
+ if len(comma_parts) > 1:
175
+ current_comma_part = ""
176
+ for comma_part in comma_parts:
177
+ if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
178
+ current_comma_part += comma_part + ','
179
+ else:
180
+ if current_comma_part:
181
+ batches.append(current_comma_part.rstrip(','))
182
+ current_comma_part = comma_part + ','
183
+ if current_comma_part:
184
+ batches.append(current_comma_part.rstrip(','))
185
+ else:
186
+ # If no comma, split by words
187
+ batches.extend(split_by_words(part))
188
+ else:
189
+ # If no colon, split by comma
190
+ comma_parts = re.split('[,,]', sentence)
191
+ if len(comma_parts) > 1:
192
+ current_comma_part = ""
193
+ for comma_part in comma_parts:
194
+ if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
195
+ current_comma_part += comma_part + ','
196
+ else:
197
+ if current_comma_part:
198
+ batches.append(current_comma_part.rstrip(','))
199
+ current_comma_part = comma_part + ','
200
+ if current_comma_part:
201
+ batches.append(current_comma_part.rstrip(','))
202
+ else:
203
+ # If no comma, split by words
204
+ batches.extend(split_by_words(sentence))
205
+ else:
206
+ current_batch = sentence
207
+
208
+ if current_batch:
209
+ batches.append(current_batch)
210
+
211
+ return batches
212
+
213
+ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence, progress=gr.Progress()):
214
+ if exp_name == "F5-TTS":
215
+ ema_model = F5TTS_ema_model
216
+ elif exp_name == "E2-TTS":
217
+ ema_model = E2TTS_ema_model
218
+
219
+ audio, sr = ref_audio
220
+ if audio.shape[0] > 1:
221
+ audio = torch.mean(audio, dim=0, keepdim=True)
222
+
223
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
224
+ if rms < target_rms:
225
+ audio = audio * target_rms / rms
226
+ if sr != target_sample_rate:
227
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
228
+ audio = resampler(audio)
229
+ audio = audio.to(device)
230
+
231
+ generated_waves = []
232
+ spectrograms = []
233
+
234
+ for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
235
+ # Prepare the text
236
+ if len(ref_text[-1].encode('utf-8')) == 1:
237
+ ref_text = ref_text + " "
238
+ text_list = [ref_text + gen_text]
239
+ final_text_list = convert_char_to_pinyin(text_list)
240
+
241
+ # Calculate duration
242
+ ref_audio_len = audio.shape[-1] // hop_length
243
+ zh_pause_punc = r"。,、;:?!"
244
+ ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
245
+ gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
246
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
247
+
248
+ # inference
249
+ with torch.inference_mode():
250
+ generated, _ = ema_model.sample(
251
+ cond=audio,
252
+ text=final_text_list,
253
+ duration=duration,
254
+ steps=nfe_step,
255
+ cfg_strength=cfg_strength,
256
+ sway_sampling_coef=sway_sampling_coef,
257
+ )
258
+
259
+ generated = generated[:, ref_audio_len:, :]
260
+ generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
261
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
262
+ if rms < target_rms:
263
+ generated_wave = generated_wave * rms / target_rms
264
+
265
+ # wav -> numpy
266
+ generated_wave = generated_wave.squeeze().cpu().numpy()
267
+
268
+ generated_waves.append(generated_wave)
269
+ spectrograms.append(generated_mel_spec[0].cpu().numpy())
270
+
271
+ # Combine all generated waves
272
+ final_wave = np.concatenate(generated_waves)
273
+
274
+ # Remove silence
275
+ if remove_silence:
276
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
277
+ sf.write(f.name, final_wave, target_sample_rate)
278
+ aseg = AudioSegment.from_file(f.name)
279
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
280
+ non_silent_wave = AudioSegment.silent(duration=0)
281
+ for non_silent_seg in non_silent_segs:
282
+ non_silent_wave += non_silent_seg
283
+ aseg = non_silent_wave
284
+ aseg.export(f.name, format="wav")
285
+ final_wave, _ = torchaudio.load(f.name)
286
+ final_wave = final_wave.squeeze().cpu().numpy()
287
+
288
+ # Create a combined spectrogram
289
+ combined_spectrogram = np.concatenate(spectrograms, axis=1)
290
+
291
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
292
+ spectrogram_path = tmp_spectrogram.name
293
+ save_spectrogram(combined_spectrogram, spectrogram_path)
294
+
295
+ return (target_sample_rate, final_wave), spectrogram_path
296
+
297
+ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_split_words=''):
298
+ if not custom_split_words.strip():
299
+ custom_words = [word.strip() for word in custom_split_words.split(',')]
300
+ global SPLIT_WORDS
301
+ SPLIT_WORDS = custom_words
302
+
303
+ print(gen_text)
304
+
305
+ gr.Info("Converting audio...")
306
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
307
+ aseg = AudioSegment.from_file(ref_audio_orig)
308
+
309
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
310
+ non_silent_wave = AudioSegment.silent(duration=0)
311
+ for non_silent_seg in non_silent_segs:
312
+ non_silent_wave += non_silent_seg
313
+ aseg = non_silent_wave
314
+
315
+ audio_duration = len(aseg)
316
+ if audio_duration > 15000:
317
+ gr.Warning("Audio is over 15s, clipping to only first 15s.")
318
+ aseg = aseg[:15000]
319
+ aseg.export(f.name, format="wav")
320
+ ref_audio = f.name
321
+
322
+ if not ref_text.strip():
323
+ gr.Info("No reference text provided, transcribing reference audio...")
324
+ ref_text = pipe(
325
+ ref_audio,
326
+ chunk_length_s=30,
327
+ batch_size=128,
328
+ generate_kwargs={"task": "transcribe"},
329
+ return_timestamps=False,
330
+ )["text"].strip()
331
+ gr.Info("Finished transcription")
332
+ else:
333
+ gr.Info("Using custom reference text...")
334
+
335
+ # Split the input text into batches
336
+ audio, sr = torchaudio.load(ref_audio)
337
+ max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
338
+ gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)
339
+ print('ref_text', ref_text)
340
+ for i, gen_text in enumerate(gen_text_batches):
341
+ print(f'gen_text {i}', gen_text)
342
+
343
+ gr.Info(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches")
344
+ return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence)
345
+
346
+ def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, exp_name, remove_silence):
347
+ # Split the script into speaker blocks
348
+ speaker_pattern = re.compile(f"^({re.escape(speaker1_name)}|{re.escape(speaker2_name)}):", re.MULTILINE)
349
+ speaker_blocks = speaker_pattern.split(script)[1:] # Skip the first empty element
350
+
351
+ generated_audio_segments = []
352
+
353
+ for i in range(0, len(speaker_blocks), 2):
354
+ speaker = speaker_blocks[i]
355
+ text = speaker_blocks[i+1].strip()
356
+
357
+ # Determine which speaker is talking
358
+ if speaker == speaker1_name:
359
+ ref_audio = ref_audio1
360
+ ref_text = ref_text1
361
+ elif speaker == speaker2_name:
362
+ ref_audio = ref_audio2
363
+ ref_text = ref_text2
364
+ else:
365
+ continue # Skip if the speaker is neither speaker1 nor speaker2
366
+
367
+ # Generate audio for this block
368
+ audio, _ = infer(ref_audio, ref_text, text, exp_name, remove_silence)
369
+
370
+ # Convert the generated audio to a numpy array
371
+ sr, audio_data = audio
372
+
373
+ # Save the audio data as a WAV file
374
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
375
+ sf.write(temp_file.name, audio_data, sr)
376
+ audio_segment = AudioSegment.from_wav(temp_file.name)
377
+
378
+ generated_audio_segments.append(audio_segment)
379
+
380
+ # Add a short pause between speakers
381
+ pause = AudioSegment.silent(duration=500) # 500ms pause
382
+ generated_audio_segments.append(pause)
383
+
384
+ # Concatenate all audio segments
385
+ final_podcast = sum(generated_audio_segments)
386
+
387
+ # Export the final podcast
388
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
389
+ podcast_path = temp_file.name
390
+ final_podcast.export(podcast_path, format="wav")
391
+
392
+ return podcast_path
393
+
394
+ def parse_speechtypes_text(gen_text):
395
+ # Pattern to find (Emotion)
396
+ pattern = r'\((.*?)\)'
397
+
398
+ # Split the text by the pattern
399
+ tokens = re.split(pattern, gen_text)
400
+
401
+ segments = []
402
+
403
+ current_emotion = 'Regular'
404
+
405
+ for i in range(len(tokens)):
406
+ if i % 2 == 0:
407
+ # This is text
408
+ text = tokens[i].strip()
409
+ if text:
410
+ segments.append({'emotion': current_emotion, 'text': text})
411
+ else:
412
+ # This is emotion
413
+ emotion = tokens[i].strip()
414
+ current_emotion = emotion
415
+
416
+ return segments
417
+
418
+ def update_speed(new_speed):
419
+ global speed
420
+ speed = new_speed
421
+ return f"Speed set to: {speed}"
422
+
423
+ with gr.Blocks() as app_credits:
424
+ gr.Markdown("""
425
+ # Credits
426
+ * [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
427
+ * [RootingInLoad](https://github.com/RootingInLoad) for the podcast generation
428
+ """)
429
+ with gr.Blocks() as app_tts:
430
+ gr.Markdown("# Batched TTS")
431
+ ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
432
+ gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
433
+ model_choice = gr.Radio(
434
+ choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
435
+ )
436
+ generate_btn = gr.Button("Synthesize", variant="primary")
437
+ with gr.Accordion("Advanced Settings", open=False):
438
+ ref_text_input = gr.Textbox(
439
+ label="Reference Text",
440
+ info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.",
441
+ lines=2,
442
+ )
443
+ remove_silence = gr.Checkbox(
444
+ label="Remove Silences",
445
+ info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
446
+ value=True,
447
+ )
448
+ split_words_input = gr.Textbox(
449
+ label="Custom Split Words",
450
+ info="Enter custom words to split on, separated by commas. Leave blank to use default list.",
451
+ lines=2,
452
+ )
453
+ speed_slider = gr.Slider(
454
+ label="Speed",
455
+ minimum=0.3,
456
+ maximum=2.0,
457
+ value=speed,
458
+ step=0.1,
459
+ info="Adjust the speed of the audio.",
460
+ )
461
+ speed_slider.change(update_speed, inputs=speed_slider)
462
+
463
+ audio_output = gr.Audio(label="Synthesized Audio")
464
+ spectrogram_output = gr.Image(label="Spectrogram")
465
+
466
+ generate_btn.click(
467
+ infer,
468
+ inputs=[
469
+ ref_audio_input,
470
+ ref_text_input,
471
+ gen_text_input,
472
+ model_choice,
473
+ remove_silence,
474
+ split_words_input,
475
+ ],
476
+ outputs=[audio_output, spectrogram_output],
477
+ )
478
+
479
+ with gr.Blocks() as app_podcast:
480
+ gr.Markdown("# Podcast Generation")
481
+ speaker1_name = gr.Textbox(label="Speaker 1 Name")
482
+ ref_audio_input1 = gr.Audio(label="Reference Audio (Speaker 1)", type="filepath")
483
+ ref_text_input1 = gr.Textbox(label="Reference Text (Speaker 1)", lines=2)
484
+
485
+ speaker2_name = gr.Textbox(label="Speaker 2 Name")
486
+ ref_audio_input2 = gr.Audio(label="Reference Audio (Speaker 2)", type="filepath")
487
+ ref_text_input2 = gr.Textbox(label="Reference Text (Speaker 2)", lines=2)
488
+
489
+ script_input = gr.Textbox(label="Podcast Script", lines=10,
490
+ placeholder="Enter the script with speaker names at the start of each block, e.g.:\nSean: How did you start studying...\n\nMeghan: I came to my interest in technology...\nIt was a long journey...\n\nSean: That's fascinating. Can you elaborate...")
491
+
492
+ podcast_model_choice = gr.Radio(
493
+ choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
494
+ )
495
+ podcast_remove_silence = gr.Checkbox(
496
+ label="Remove Silences",
497
+ value=True,
498
+ )
499
+ generate_podcast_btn = gr.Button("Generate Podcast", variant="primary")
500
+ podcast_output = gr.Audio(label="Generated Podcast")
501
+
502
+ def podcast_generation(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence):
503
+ return generate_podcast(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence)
504
+
505
+ generate_podcast_btn.click(
506
+ podcast_generation,
507
+ inputs=[
508
+ script_input,
509
+ speaker1_name,
510
+ ref_audio_input1,
511
+ ref_text_input1,
512
+ speaker2_name,
513
+ ref_audio_input2,
514
+ ref_text_input2,
515
+ podcast_model_choice,
516
+ podcast_remove_silence,
517
+ ],
518
+ outputs=podcast_output,
519
+ )
520
+
521
+ def parse_emotional_text(gen_text):
522
+ # Pattern to find (Emotion)
523
+ pattern = r'\((.*?)\)'
524
+
525
+ # Split the text by the pattern
526
+ tokens = re.split(pattern, gen_text)
527
+
528
+ segments = []
529
+
530
+ current_emotion = 'Regular'
531
+
532
+ for i in range(len(tokens)):
533
+ if i % 2 == 0:
534
+ # This is text
535
+ text = tokens[i].strip()
536
+ if text:
537
+ segments.append({'emotion': current_emotion, 'text': text})
538
+ else:
539
+ # This is emotion
540
+ emotion = tokens[i].strip()
541
+ current_emotion = emotion
542
+
543
+ return segments
544
+
545
+ with gr.Blocks() as app_emotional:
546
+ # New section for emotional generation
547
+ gr.Markdown(
548
+ """
549
+ # Multiple Speech-Type Generation
550
+ This section allows you to upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the "Add Speech Type" button. Enter your text in the format shown below, and the system will generate speech using the appropriate emotions. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
551
+ **Example Input:**
552
+ (Regular) Hello, I'd like to order a sandwich please. (Surprised) What do you mean you're out of bread? (Sad) I really wanted a sandwich though... (Angry) You know what, darn you and your little shop, you suck! (Whisper) I'll just go back home and cry now. (Shouting) Why me?!
553
+ """
554
+ )
555
+
556
+ gr.Markdown("Upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button.")
557
+
558
+ # Regular speech type (mandatory)
559
+ with gr.Row():
560
+ regular_name = gr.Textbox(value='Regular', label='Speech Type Name', interactive=False)
561
+ regular_audio = gr.Audio(label='Regular Reference Audio', type='filepath')
562
+ regular_ref_text = gr.Textbox(label='Reference Text (Regular)', lines=2)
563
+
564
+ # Additional speech types (up to 9 more)
565
+ max_speech_types = 10
566
+ speech_type_names = []
567
+ speech_type_audios = []
568
+ speech_type_ref_texts = []
569
+ speech_type_delete_btns = []
570
+
571
+ for i in range(max_speech_types - 1):
572
+ with gr.Row():
573
+ name_input = gr.Textbox(label='Speech Type Name', visible=False)
574
+ audio_input = gr.Audio(label='Reference Audio', type='filepath', visible=False)
575
+ ref_text_input = gr.Textbox(label='Reference Text', lines=2, visible=False)
576
+ delete_btn = gr.Button("Delete", variant="secondary", visible=False)
577
+ speech_type_names.append(name_input)
578
+ speech_type_audios.append(audio_input)
579
+ speech_type_ref_texts.append(ref_text_input)
580
+ speech_type_delete_btns.append(delete_btn)
581
+
582
+ # Button to add speech type
583
+ add_speech_type_btn = gr.Button("Add Speech Type")
584
+
585
+ # Keep track of current number of speech types
586
+ speech_type_count = gr.State(value=0)
587
+
588
+ # Function to add a speech type
589
+ def add_speech_type_fn(speech_type_count):
590
+ if speech_type_count < max_speech_types - 1:
591
+ speech_type_count += 1
592
+ # Prepare updates for the components
593
+ name_updates = []
594
+ audio_updates = []
595
+ ref_text_updates = []
596
+ delete_btn_updates = []
597
+ for i in range(max_speech_types - 1):
598
+ if i < speech_type_count:
599
+ name_updates.append(gr.update(visible=True))
600
+ audio_updates.append(gr.update(visible=True))
601
+ ref_text_updates.append(gr.update(visible=True))
602
+ delete_btn_updates.append(gr.update(visible=True))
603
+ else:
604
+ name_updates.append(gr.update())
605
+ audio_updates.append(gr.update())
606
+ ref_text_updates.append(gr.update())
607
+ delete_btn_updates.append(gr.update())
608
+ else:
609
+ # Optionally, show a warning
610
+ # gr.Warning("Maximum number of speech types reached.")
611
+ name_updates = [gr.update() for _ in range(max_speech_types - 1)]
612
+ audio_updates = [gr.update() for _ in range(max_speech_types - 1)]
613
+ ref_text_updates = [gr.update() for _ in range(max_speech_types - 1)]
614
+ delete_btn_updates = [gr.update() for _ in range(max_speech_types - 1)]
615
+ return [speech_type_count] + name_updates + audio_updates + ref_text_updates + delete_btn_updates
616
+
617
+ add_speech_type_btn.click(
618
+ add_speech_type_fn,
619
+ inputs=speech_type_count,
620
+ outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
621
+ )
622
+
623
+ # Function to delete a speech type
624
+ def make_delete_speech_type_fn(index):
625
+ def delete_speech_type_fn(speech_type_count):
626
+ # Prepare updates
627
+ name_updates = []
628
+ audio_updates = []
629
+ ref_text_updates = []
630
+ delete_btn_updates = []
631
+
632
+ for i in range(max_speech_types - 1):
633
+ if i == index:
634
+ name_updates.append(gr.update(visible=False, value=''))
635
+ audio_updates.append(gr.update(visible=False, value=None))
636
+ ref_text_updates.append(gr.update(visible=False, value=''))
637
+ delete_btn_updates.append(gr.update(visible=False))
638
+ else:
639
+ name_updates.append(gr.update())
640
+ audio_updates.append(gr.update())
641
+ ref_text_updates.append(gr.update())
642
+ delete_btn_updates.append(gr.update())
643
+
644
+ speech_type_count = max(0, speech_type_count - 1)
645
+
646
+ return [speech_type_count] + name_updates + audio_updates + ref_text_updates + delete_btn_updates
647
+
648
+ return delete_speech_type_fn
649
+
650
+ for i, delete_btn in enumerate(speech_type_delete_btns):
651
+ delete_fn = make_delete_speech_type_fn(i)
652
+ delete_btn.click(
653
+ delete_fn,
654
+ inputs=speech_type_count,
655
+ outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
656
+ )
657
+
658
+ # Text input for the prompt
659
+ gen_text_input_emotional = gr.Textbox(label="Text to Generate", lines=10)
660
+
661
+ # Model choice
662
+ model_choice_emotional = gr.Radio(
663
+ choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
664
+ )
665
+
666
+ with gr.Accordion("Advanced Settings", open=False):
667
+ remove_silence_emotional = gr.Checkbox(
668
+ label="Remove Silences",
669
+ value=True,
670
+ )
671
+
672
+ # Generate button
673
+ generate_emotional_btn = gr.Button("Generate Emotional Speech", variant="primary")
674
+
675
+ # Output audio
676
+ audio_output_emotional = gr.Audio(label="Synthesized Audio")
677
+
678
+ def generate_emotional_speech(
679
+ regular_audio,
680
+ regular_ref_text,
681
+ gen_text,
682
+ *args,
683
+ ):
684
+ num_additional_speech_types = max_speech_types - 1
685
+ speech_type_names_list = args[:num_additional_speech_types]
686
+ speech_type_audios_list = args[num_additional_speech_types:2 * num_additional_speech_types]
687
+ speech_type_ref_texts_list = args[2 * num_additional_speech_types:3 * num_additional_speech_types]
688
+ model_choice = args[3 * num_additional_speech_types]
689
+ remove_silence = args[3 * num_additional_speech_types + 1]
690
+
691
+ # Collect the speech types and their audios into a dict
692
+ speech_types = {'Regular': {'audio': regular_audio, 'ref_text': regular_ref_text}}
693
+
694
+ for name_input, audio_input, ref_text_input in zip(speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list):
695
+ if name_input and audio_input:
696
+ speech_types[name_input] = {'audio': audio_input, 'ref_text': ref_text_input}
697
+
698
+ # Parse the gen_text into segments
699
+ segments = parse_speechtypes_text(gen_text)
700
+
701
+ # For each segment, generate speech
702
+ generated_audio_segments = []
703
+ current_emotion = 'Regular'
704
+
705
+ for segment in segments:
706
+ emotion = segment['emotion']
707
+ text = segment['text']
708
+
709
+ if emotion in speech_types:
710
+ current_emotion = emotion
711
+ else:
712
+ # If emotion not available, default to Regular
713
+ current_emotion = 'Regular'
714
+
715
+ ref_audio = speech_types[current_emotion]['audio']
716
+ ref_text = speech_types[current_emotion].get('ref_text', '')
717
+
718
+ # Generate speech for this segment
719
+ audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, "")
720
+ sr, audio_data = audio
721
+
722
+ generated_audio_segments.append(audio_data)
723
+
724
+ # Concatenate all audio segments
725
+ if generated_audio_segments:
726
+ final_audio_data = np.concatenate(generated_audio_segments)
727
+ return (sr, final_audio_data)
728
+ else:
729
+ gr.Warning("No audio generated.")
730
+ return None
731
+
732
+ generate_emotional_btn.click(
733
+ generate_emotional_speech,
734
+ inputs=[
735
+ regular_audio,
736
+ regular_ref_text,
737
+ gen_text_input_emotional,
738
+ ] + speech_type_names + speech_type_audios + speech_type_ref_texts + [
739
+ model_choice_emotional,
740
+ remove_silence_emotional,
741
+ ],
742
+ outputs=audio_output_emotional,
743
+ )
744
+
745
+ # Validation function to disable Generate button if speech types are missing
746
+ def validate_speech_types(
747
+ gen_text,
748
+ regular_name,
749
+ *args
750
+ ):
751
+ num_additional_speech_types = max_speech_types - 1
752
+ speech_type_names_list = args[:num_additional_speech_types]
753
+
754
+ # Collect the speech types names
755
+ speech_types_available = set()
756
+ if regular_name:
757
+ speech_types_available.add(regular_name)
758
+ for name_input in speech_type_names_list:
759
+ if name_input:
760
+ speech_types_available.add(name_input)
761
+
762
+ # Parse the gen_text to get the speech types used
763
+ segments = parse_emotional_text(gen_text)
764
+ speech_types_in_text = set(segment['emotion'] for segment in segments)
765
+
766
+ # Check if all speech types in text are available
767
+ missing_speech_types = speech_types_in_text - speech_types_available
768
+
769
+ if missing_speech_types:
770
+ # Disable the generate button
771
+ return gr.update(interactive=False)
772
+ else:
773
+ # Enable the generate button
774
+ return gr.update(interactive=True)
775
+
776
+ gen_text_input_emotional.change(
777
+ validate_speech_types,
778
+ inputs=[gen_text_input_emotional, regular_name] + speech_type_names,
779
+ outputs=generate_emotional_btn
780
+ )
781
+ with gr.Blocks() as app:
782
+ gr.Markdown(
783
+ """
784
+ # E2/F5 TTS
785
+ This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models:
786
+ * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
787
+ * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
788
+ The checkpoints support English and Chinese.
789
+ If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt.
790
+ **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
791
+ """
792
+ )
793
+ gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"])
794
+
795
+ @click.command()
796
+ @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
797
+ @click.option("--host", "-H", default=None, help="Host to run the app on")
798
+ @click.option(
799
+ "--share",
800
+ "-s",
801
+ default=False,
802
+ is_flag=True,
803
+ help="Share the app via Gradio share link",
804
+ )
805
+ @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
806
+ def main(port, host, share, api):
807
+ global app
808
+ print(f"Starting app...")
809
+ app.queue(api_open=api).launch(
810
+ server_name=host, server_port=port, share=share, show_api=api
811
+ )
812
+
813
+
814
+ if __name__ == "__main__":
815
+ main()