NguyenNhatSakura committed on
Commit f7f1184 · verified · 1 Parent(s): 1839023

Update app.py

Files changed (1)
  1. app.py +354 -258
app.py CHANGED
@@ -1,61 +1,113 @@
-import csv
-import datetime
 import os
-import re
-import time
-import uuid
-from io import StringIO
 
 import gradio as gr
-import spaces
 import torch
 import torchaudio
-from huggingface_hub import HfApi, hf_hub_download, snapshot_download
 from TTS.tts.configs.xtts_config import XttsConfig
 from TTS.tts.models.xtts import Xtts
-from vinorm import TTSnorm
 
-# download for mecab
-os.system("python -m unidic download")
 
-HF_TOKEN = os.environ.get("HF_TOKEN")
-api = HfApi(token=HF_TOKEN)
 
-# This will trigger downloading model
-print("Downloading if not downloaded viXTTS")
-checkpoint_dir = "model/"
-repo_id = "capleaf/viXTTS"
-use_deepspeed = False
 
-os.makedirs(checkpoint_dir, exist_ok=True)
 
-required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
-files_in_dir = os.listdir(checkpoint_dir)
-if not all(file in files_in_dir for file in required_files):
-    snapshot_download(
-        repo_id=repo_id,
-        repo_type="model",
-        local_dir=checkpoint_dir,
     )
-    hf_hub_download(
-        repo_id="coqui/XTTS-v2",
-        filename="speakers_xtts.pth",
-        local_dir=checkpoint_dir,
     )
-
-xtts_config = os.path.join(checkpoint_dir, "config.json")
-config = XttsConfig()
-config.load_json(xtts_config)
-MODEL = Xtts.init_from_config(config)
-MODEL.load_checkpoint(
-    config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed
-)
-if torch.cuda.is_available():
-    MODEL.cuda()
-
-supported_languages = config.languages
-if not "vi" in supported_languages:
-    supported_languages.append("vi")
 
 
 def normalize_vietnamese_text(text):
@@ -89,225 +141,269 @@ def calculate_keep_len(text, lang):
     return -1
 
 
-@spaces.GPU
-def predict(
-    prompt,
-    language,
-    audio_file_pth,
-    normalize_text=True,
-):
-    if language not in supported_languages:
-        metrics_text = gr.Warning(
-            f"Language you put {language} in is not in is not in our Supported Languages, please choose from dropdown"
         )
 
-        return (None, metrics_text)
-
-    speaker_wav = audio_file_pth
-
-    if len(prompt) < 2:
-        metrics_text = gr.Warning("Please give a longer prompt text")
-        return (None, metrics_text)
 
-    if len(prompt) > 250:
-        metrics_text = gr.Warning(
-            str(len(prompt))
-            + " characters.\n"
-            + "Your prompt is too long, please keep it under 250 characters\n"
-            + "Văn bản quá dài, vui lòng giữ dưới 250 ký tự."
-        )
-        return (None, metrics_text)
-
-    try:
-        metrics_text = ""
-        t_latent = time.time()
-
-        try:
-            (
-                gpt_cond_latent,
-                speaker_embedding,
-            ) = MODEL.get_conditioning_latents(
-                audio_path=speaker_wav,
-                gpt_cond_len=30,
-                gpt_cond_chunk_len=4,
-                max_ref_length=60,
-            )
-
-        except Exception as e:
-            print("Speaker encoding error", str(e))
-            metrics_text = gr.Warning(
-                "It appears something wrong with reference, did you unmute your microphone?"
-            )
-            return (None, metrics_text)
-
-        prompt = re.sub("([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2\2", prompt)
-
-        if normalize_text and language == "vi":
-            prompt = normalize_vietnamese_text(prompt)
-
-        print("I: Generating new audio...")
-        t0 = time.time()
-        out = MODEL.inference(
-            prompt,
-            language,
-            gpt_cond_latent,
-            speaker_embedding,
-            repetition_penalty=5.0,
-            temperature=0.75,
             enable_text_splitting=True,
         )
-        inference_time = time.time() - t0
-        print(f"I: Time to generate audio: {round(inference_time*1000)} milliseconds")
-        metrics_text += (
-            f"Time to generate audio: {round(inference_time*1000)} milliseconds\n"
-        )
-        real_time_factor = (time.time() - t0) / out["wav"].shape[-1] * 24000
-        print(f"Real-time factor (RTF): {real_time_factor}")
-        metrics_text += f"Real-time factor (RTF): {real_time_factor:.2f}\n"
-
-        # Temporary hack for short sentences
-        keep_len = calculate_keep_len(prompt, language)
-        out["wav"] = out["wav"][:keep_len]
-
-        torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
-
-    except RuntimeError as e:
-        if "device-side assert" in str(e):
-            # cannot do anything on cuda device side error, need tor estart
-            print(
-                f"Exit due to: Unrecoverable exception caused by language:{language} prompt:{prompt}",
-                flush=True,
-            )
-            gr.Warning("Unhandled Exception encounter, please retry in a minute")
-            print("Cuda device-assert Runtime encountered need restart")
-
-            error_time = datetime.datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
-            error_data = [
-                error_time,
-                prompt,
-                language,
-                audio_file_pth,
-            ]
-            error_data = [str(e) if type(e) != str else e for e in error_data]
-            print(error_data)
-            print(speaker_wav)
-            write_io = StringIO()
-            csv.writer(write_io).writerows([error_data])
-            csv_upload = write_io.getvalue().encode()
-
-            filename = error_time + "_" + str(uuid.uuid4()) + ".csv"
-            print("Writing error csv")
-            error_api = HfApi()
-            error_api.upload_file(
-                path_or_fileobj=csv_upload,
-                path_in_repo=filename,
-                repo_id="coqui/xtts-flagged-dataset",
-                repo_type="dataset",
-            )
-
-            # speaker_wav
-            print("Writing error reference audio")
-            speaker_filename = error_time + "_reference_" + str(uuid.uuid4()) + ".wav"
-            error_api = HfApi()
-            error_api.upload_file(
-                path_or_fileobj=speaker_wav,
-                path_in_repo=speaker_filename,
-                repo_id="coqui/xtts-flagged-dataset",
-                repo_type="dataset",
-            )
-
-            # HF Space specific.. This error is unrecoverable need to restart space
-            space = api.get_space_runtime(repo_id=repo_id)
-            if space.stage != "BUILDING":
-                api.restart_space(repo_id=repo_id)
-            else:
-                print("TRIED TO RESTART but space is building")
-
-        else:
-            if "Failed to decode" in str(e):
-                print("Speaker encoding error", str(e))
-                metrics_text = gr.Warning(
-                    metrics_text="It appears something wrong with reference, did you unmute your microphone?"
                 )
-            else:
-                print("RuntimeError: non device-side assert error:", str(e))
-                metrics_text = gr.Warning(
-                    "Something unexpected happened please retry again."
                 )
-        return (None, metrics_text)
-    return ("output.wav", metrics_text)
-
-
-with gr.Blocks(analytics_enabled=False) as demo:
-    with gr.Row():
-        with gr.Column():
-            # placeholder to align the image
-            pass
-
-    with gr.Row():
-        with gr.Column():
-            input_text_gr = gr.Textbox(
-                label="Text Prompt (Văn bản cần đọc)",
-                info="Mỗi câu nên từ 10 từ trở lên. Tối đa 250 ký tự (khoảng 2 - 3 câu).",
-                value="Xin chào, tôi là một mô hình chuyển đổi văn bản thành giọng nói tiếng Việt.",
-            )
-            language_gr = gr.Dropdown(
-                label="Language (Ngôn ngữ)",
-                choices=[
-                    "vi",
-                    "en",
-                    "es",
-                    "fr",
-                    "de",
-                    "it",
-                    "pt",
-                    "pl",
-                    "tr",
-                    "ru",
-                    "nl",
-                    "cs",
-                    "ar",
-                    "zh-cn",
-                    "ja",
-                    "ko",
-                    "hu",
-                    "hi",
-                ],
-                max_choices=1,
-                value="vi",
-            )
-            normalize_text = gr.Checkbox(
-                label="Chuẩn hóa văn bản tiếng Việt",
-                info="Normalize Vietnamese text",
-                value=True,
-            )
-            ref_gr = gr.Audio(
-                label="Reference Audio (Giọng mẫu)",
-                type="filepath",
-                value="model/samples/nu-luu-loat.wav",
-            )
-            tts_button = gr.Button(
-                "Đọc Ngay",
-                elem_id="send-btn",
-                visible=True,
-                variant="primary",
-            )
-
-        with gr.Column():
-            audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
-            out_text_gr = gr.Text(label="Metrics")
-
-    tts_button.click(
-        predict,
-        [
-            input_text_gr,
-            language_gr,
-            ref_gr,
-            normalize_text,
-        ],
-        outputs=[audio_gr, out_text_gr],
-        api_name="predict",
-    )
 
-demo.queue()
-demo.launch(debug=True, show_api=True, share=True)
+import argparse
+import hashlib
+import logging
 import os
+import string
+import subprocess
+import sys
+import tempfile
+from datetime import datetime
 
 import gradio as gr
+import soundfile as sf
 import torch
 import torchaudio
+from huggingface_hub import hf_hub_download, snapshot_download
+from underthesea import sent_tokenize
+from unidecode import unidecode
+from vinorm import TTSnorm
+
 from TTS.tts.configs.xtts_config import XttsConfig
 from TTS.tts.models.xtts import Xtts
 
+XTTS_MODEL = None
+SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+MODEL_DIR = os.path.join(SCRIPT_DIR, "model")
+OUTPUT_DIR = os.path.join(SCRIPT_DIR, "output")
+FILTER_SUFFIX = "_DeepFilterNet3.wav"
+os.makedirs(OUTPUT_DIR, exist_ok=True)
+
 
+def clear_gpu_cache():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
 
 
+def load_model(checkpoint_dir="model/", repo_id="capleaf/viXTTS", use_deepspeed=False):
+    global XTTS_MODEL
+    clear_gpu_cache()
+    os.makedirs(checkpoint_dir, exist_ok=True)
 
+    required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
+    files_in_dir = os.listdir(checkpoint_dir)
+    if not all(file in files_in_dir for file in required_files):
+        yield f"Missing model files! Downloading from {repo_id}..."
+        snapshot_download(
+            repo_id=repo_id,
+            repo_type="model",
+            local_dir=checkpoint_dir,
+        )
+        hf_hub_download(
+            repo_id="coqui/XTTS-v2",
+            filename="speakers_xtts.pth",
+            local_dir=checkpoint_dir,
+        )
+        yield f"Model download finished..."
+
+    xtts_config = os.path.join(checkpoint_dir, "config.json")
+    config = XttsConfig()
+    config.load_json(xtts_config)
+    XTTS_MODEL = Xtts.init_from_config(config)
+    yield "Loading model..."
+    XTTS_MODEL.load_checkpoint(
+        config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed
     )
+    if torch.cuda.is_available():
+        XTTS_MODEL.cuda()
+
+    print("Model Loaded!")
+    yield "Model Loaded!"
+
+
+# Define dictionaries to store cached results
+cache_queue = []
+speaker_audio_cache = {}
+filter_cache = {}
+conditioning_latents_cache = {}
+
+
+def invalidate_cache(cache_limit=50):
+    """Invalidate the cache for the oldest key"""
+    if len(cache_queue) > cache_limit:
+        key_to_remove = cache_queue.pop(0)
+        print("Invalidating cache", key_to_remove)
+        if os.path.exists(key_to_remove):
+            os.remove(key_to_remove)
+        if os.path.exists(key_to_remove.replace(".wav", "_DeepFilterNet3.wav")):
+            os.remove(key_to_remove.replace(".wav", "_DeepFilterNet3.wav"))
+        if key_to_remove in filter_cache:
+            del filter_cache[key_to_remove]
+        if key_to_remove in conditioning_latents_cache:
+            del conditioning_latents_cache[key_to_remove]
+
+
+def generate_hash(data):
+    hash_object = hashlib.md5()
+    hash_object.update(data)
+    return hash_object.hexdigest()
+
+
+def get_file_name(text, max_char=50):
+    filename = text[:max_char]
+    filename = filename.lower()
+    filename = filename.replace(" ", "_")
+    filename = filename.translate(
+        str.maketrans("", "", string.punctuation.replace("_", ""))
     )
+    filename = unidecode(filename)
+    current_datetime = datetime.now().strftime("%m%d%H%M%S")
+    filename = f"{current_datetime}_{filename}"
+    return filename
 
 
 def normalize_vietnamese_text(text):
     return -1
 
 
+def run_tts(lang, tts_text, speaker_audio_file, use_deepfilter, normalize_text):
+    global filter_cache, conditioning_latents_cache, cache_queue
+
+    if XTTS_MODEL is None:
+        return "You need to run the previous step to load the model!!", None
+
+    if not speaker_audio_file:
+        return "You need to provide reference audio!!!", None
+
+    # Use the file name as the key, since it's supposed to be unique 💀
+    speaker_audio_key = speaker_audio_file
+    if speaker_audio_key not in cache_queue:
+        cache_queue.append(speaker_audio_key)
+        invalidate_cache()
+
+    # Check if filtered reference is cached
+    if use_deepfilter and speaker_audio_key in filter_cache:
+        print("Using filter cache...")
+        speaker_audio_file = filter_cache[speaker_audio_key]
+    elif use_deepfilter:
+        print("Running filter...")
+        subprocess.run(
+            [
+                "deepFilter",
+                speaker_audio_file,
+                "-o",
+                os.path.dirname(speaker_audio_file),
+            ]
         )
+        filter_cache[speaker_audio_key] = speaker_audio_file.replace(
+            ".wav", FILTER_SUFFIX
+        )
+        speaker_audio_file = filter_cache[speaker_audio_key]
+
+    # Check if conditioning latents are cached
+    cache_key = (
+        speaker_audio_key,
+        XTTS_MODEL.config.gpt_cond_len,
+        XTTS_MODEL.config.max_ref_len,
+        XTTS_MODEL.config.sound_norm_refs,
+    )
+    if cache_key in conditioning_latents_cache:
+        print("Using conditioning latents cache...")
+        gpt_cond_latent, speaker_embedding = conditioning_latents_cache[cache_key]
+    else:
+        print("Computing conditioning latents...")
+        gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
+            audio_path=speaker_audio_file,
+            gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
+            max_ref_length=XTTS_MODEL.config.max_ref_len,
+            sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
+        )
+        conditioning_latents_cache[cache_key] = (gpt_cond_latent, speaker_embedding)
 
+    if normalize_text and lang == "vi":
+        tts_text = normalize_vietnamese_text(tts_text)
 
+    # Split text by sentence
+    if lang in ["ja", "zh-cn"]:
+        sentences = tts_text.split("。")
+    else:
+        sentences = sent_tokenize(tts_text)
+
+    from pprint import pprint
+
+    pprint(sentences)
+
+    wav_chunks = []
+    for sentence in sentences:
+        if sentence.strip() == "":
+            continue
+        wav_chunk = XTTS_MODEL.inference(
+            text=sentence,
+            language=lang,
+            gpt_cond_latent=gpt_cond_latent,
+            speaker_embedding=speaker_embedding,
+            # The following values are carefully chosen for viXTTS
+            temperature=0.3,
+            length_penalty=1.0,
+            repetition_penalty=10.0,
+            top_k=30,
+            top_p=0.85,
             enable_text_splitting=True,
         )
+
+        keep_len = calculate_keep_len(sentence, lang)
+        wav_chunk["wav"] = wav_chunk["wav"][:keep_len]
+
+        wav_chunks.append(torch.tensor(wav_chunk["wav"]))
+
+    out_wav = torch.cat(wav_chunks, dim=0).unsqueeze(0)
+    gr_audio_id = os.path.basename(os.path.dirname(speaker_audio_file))
+    out_path = os.path.join(OUTPUT_DIR, f"{get_file_name(tts_text)}_{gr_audio_id}.wav")
+    print("Saving output to ", out_path)
+    torchaudio.save(out_path, out_wav, 24000)
+
+    return "Speech generated!", out_path
+
+
+# Define a logger to redirect
+class Logger:
+    def __init__(self, filename="log.out"):
+        self.log_file = filename
+        self.terminal = sys.stdout
+        self.log = open(self.log_file, "w")
+
+    def write(self, message):
+        self.terminal.write(message)
+        self.log.write(message)
+
+    def flush(self):
+        self.terminal.flush()
+        self.log.flush()
+
+    def isatty(self):
+        return False
+
+
+# Redirect stdout and stderr to a file
+sys.stdout = Logger()
+sys.stderr = sys.stdout
+
+
+logging.basicConfig(
+    level=logging.ERROR,
+    format="%(asctime)s [%(levelname)s] %(message)s",
+    handlers=[logging.StreamHandler(sys.stdout)],
+)
+
+
+def read_logs():
+    sys.stdout.flush()
+    with open(sys.stdout.log_file, "r") as f:
+        return f.read()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="""viXTTS inference demo\n\n""",
+        formatter_class=argparse.RawTextHelpFormatter,
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        help="Port to run the gradio demo. Default: 5003",
+        default=5003,
+    )
+
+    parser.add_argument(
+        "--model_dir",
+        type=str,
+        help="Path to the checkpoint directory. This directory must contain 4 files: model.pth, config.json, vocab.json and speakers_xtts.pth",
+        default=None,
+    )
+
+    parser.add_argument(
+        "--reference_audio",
+        type=str,
+        help="Path to the reference audio file.",
+        default=None,
+    )
+
+    args = parser.parse_args()
+    if args.model_dir:
+        MODEL_DIR = os.path.abspath(args.model_dir)
+
+    REFERENCE_AUDIO = os.path.join(SCRIPT_DIR, "assets", "sample_female.wav")
+    if args.reference_audio:
+        REFERENCE_AUDIO = os.path.abspath(args.reference_audio)
+
+    with gr.Blocks() as demo:
+        intro = """
+        # viXTTS Inference Demo
+        Visit viXTTS on HuggingFace: [viXTTS](https://huggingface.co/capleaf/viXTTS)
+        """
+        gr.Markdown(intro)
+        with gr.Row():
+            with gr.Column() as col1:
+                repo_id = gr.Textbox(
+                    label="HuggingFace Repo ID",
+                    value="capleaf/viXTTS",
                 )
+                checkpoint_dir = gr.Textbox(
+                    label="viXTTS model directory",
+                    value=MODEL_DIR,
                 )
 
+                use_deepspeed = gr.Checkbox(
+                    value=True, label="Use DeepSpeed for faster inference"
+                )
+
+                progress_load = gr.Label(label="Progress:")
+                load_btn = gr.Button(
+                    value="Step 1 - Load viXTTS model", variant="primary"
+                )
+
+            with gr.Column() as col2:
+                speaker_reference_audio = gr.Audio(
+                    label="Speaker reference audio:",
+                    value=REFERENCE_AUDIO,
+                    type="filepath",
+                )
+
+                tts_language = gr.Dropdown(
+                    label="Language",
+                    value="vi",
+                    choices=[
+                        "vi",
+                        "en",
+                        "es",
+                        "fr",
+                        "de",
+                        "it",
+                        "pt",
+                        "pl",
+                        "tr",
+                        "ru",
+                        "nl",
+                        "cs",
+                        "ar",
+                        "zh",
+                        "hu",
+                        "ko",
+                        "ja",
+                    ],
+                )
+
+                use_filter = gr.Checkbox(
+                    label="Denoise Reference Audio",
+                    value=True,
+                )
+
+                normalize_text = gr.Checkbox(
+                    label="Normalize Input Text",
+                    value=True,
+                )
+
+                tts_text = gr.Textbox(
+                    label="Input Text.",
+                    value="Xin chào, tôi là một công cụ chuyển đổi văn bản thành giọng nói tiếng Việt được phát triển bởi nhóm Nón lá.",
+                )
+                tts_btn = gr.Button(value="Step 2 - Inference", variant="primary")
+
+            with gr.Column() as col3:
+                progress_gen = gr.Label(label="Progress:")
+                tts_output_audio = gr.Audio(label="Generated Audio.")
+
+        load_btn.click(
+            fn=load_model,
+            inputs=[checkpoint_dir, repo_id, use_deepspeed],
+            outputs=[progress_load],
+        )
+
+        tts_btn.click(
+            fn=run_tts,
+            inputs=[
+                tts_language,
+                tts_text,
+                speaker_reference_audio,
+                use_filter,
+                normalize_text,
+            ],
+            outputs=[progress_gen, tts_output_audio],
+        )
+
+    demo.launch(share=True, debug=False, server_port=args.port, server_name="0.0.0.0")
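
For reference, the new two-step flow (Step 1: load the model, Step 2: inference) can also be driven without the Gradio UI. The snippet below is a minimal sketch, not part of the commit: it assumes app.py is importable from the repository root, that the checkpoint can be fetched into model/, and that the reference WAV path and example sentence are placeholders to substitute with your own.

# Minimal sketch: call the new load_model / run_tts helpers directly.
from app import load_model, run_tts

# load_model is a generator; it yields progress strings while it downloads
# any missing checkpoint files and loads the model (onto GPU if available).
for status in load_model(
    checkpoint_dir="model/", repo_id="capleaf/viXTTS", use_deepspeed=False
):
    print(status)

# run_tts returns a status message and the path of the synthesized WAV file.
status, wav_path = run_tts(
    lang="vi",
    tts_text="Xin chào, đây là một câu ví dụ.",  # placeholder input text
    speaker_audio_file="assets/sample_female.wav",  # placeholder reference audio
    use_deepfilter=False,  # skip the DeepFilterNet denoising pass
    normalize_text=True,  # apply Vietnamese text normalization (vinorm)
)
print(status, wav_path)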