VTechAI committed
Commit 5a19f99 · 1 Parent(s): 1bb45ca

Update app.py

Files changed (1):
  1. app.py +694 -694
app.py CHANGED
@@ -1,694 +1,694 @@
  # Copyright (c) Meta Platforms, Inc. and affiliates
  # All rights reserved.
  #
  # This source code is licensed under the license found in the
  # LICENSE file in the root directory of this source tree.

  from __future__ import annotations

  import gradio as gr
  import numpy as np
  # import torch


  from gradio_client import Client

  client = Client("https://facebook-seamless-m4t.hf.space/")

  DESCRIPTION = """

  # SM4T

  Ứng dụng có thể chuyển đổi giọng nói hoặc chữ viết sang giọng nói hoặc chữ viết của một ngôn ngữ khác.
  \nHiện tại SM4T đã hỗ trợ 94 ngôn ngữ khác nhau.

  """

  TASK_NAMES = [
      "S2ST (Speech to Speech translation)",
      "S2TT (Speech to Text translation)",
      "T2ST (Text to Speech translation)",
      "T2TT (Text to Text translation)",
      "ASR (Automatic Speech Recognition)",
  ]

  # Language dict
  language_code_to_name = {
      "afr": "Afrikaans",
      "amh": "Amharic",
      "arb": "Modern Standard Arabic",
      "ary": "Moroccan Arabic",
      "arz": "Egyptian Arabic",
      "asm": "Assamese",
      "ast": "Asturian",
      "azj": "North Azerbaijani",
      "bel": "Belarusian",
      "ben": "Bengali",
      "bos": "Bosnian",
      "bul": "Bulgarian",
      "cat": "Catalan",
      "ceb": "Cebuano",
      "ces": "Czech",
      "ckb": "Central Kurdish",
      "cmn": "Mandarin Chinese",
      "cym": "Welsh",
      "dan": "Danish",
      "deu": "German",
      "ell": "Greek",
      "eng": "English",
      "est": "Estonian",
      "eus": "Basque",
      "fin": "Finnish",
      "fra": "French",
      "gaz": "West Central Oromo",
      "gle": "Irish",
      "glg": "Galician",
      "guj": "Gujarati",
      "heb": "Hebrew",
      "hin": "Hindi",
      "hrv": "Croatian",
      "hun": "Hungarian",
      "hye": "Armenian",
      "ibo": "Igbo",
      "ind": "Indonesian",
      "isl": "Icelandic",
      "ita": "Italian",
      "jav": "Javanese",
      "jpn": "Japanese",
      "kam": "Kamba",
      "kan": "Kannada",
      "kat": "Georgian",
      "kaz": "Kazakh",
      "kea": "Kabuverdianu",
      "khk": "Halh Mongolian",
      "khm": "Khmer",
      "kir": "Kyrgyz",
      "kor": "Korean",
      "lao": "Lao",
      "lit": "Lithuanian",
      "ltz": "Luxembourgish",
      "lug": "Ganda",
      "luo": "Luo",
      "lvs": "Standard Latvian",
      "mai": "Maithili",
      "mal": "Malayalam",
      "mar": "Marathi",
      "mkd": "Macedonian",
      "mlt": "Maltese",
      "mni": "Meitei",
      "mya": "Burmese",
      "nld": "Dutch",
      "nno": "Norwegian Nynorsk",
      "nob": "Norwegian Bokm\u00e5l",
      "npi": "Nepali",
      "nya": "Nyanja",
      "oci": "Occitan",
      "ory": "Odia",
      "pan": "Punjabi",
      "pbt": "Southern Pashto",
      "pes": "Western Persian",
      "pol": "Polish",
      "por": "Portuguese",
      "ron": "Romanian",
      "rus": "Russian",
      "slk": "Slovak",
      "slv": "Slovenian",
      "sna": "Shona",
      "snd": "Sindhi",
      "som": "Somali",
      "spa": "Spanish",
      "srp": "Serbian",
      "swe": "Swedish",
      "swh": "Swahili",
      "tam": "Tamil",
      "tel": "Telugu",
      "tgk": "Tajik",
      "tgl": "Tagalog",
      "tha": "Thai",
      "tur": "Turkish",
      "ukr": "Ukrainian",
      "urd": "Urdu",
      "uzn": "Northern Uzbek",
      "vie": "Vietnamese",
      "xho": "Xhosa",
      "yor": "Yoruba",
      "yue": "Cantonese",
      "zlm": "Colloquial Malay",
      "zsm": "Standard Malay",
      "zul": "Zulu",
  }
  LANGUAGE_NAME_TO_CODE = {v: k for k, v in language_code_to_name.items()}

  # Source langs: S2ST / S2TT / ASR don't need source lang
  # T2TT / T2ST use this
  text_source_language_codes = [
      "afr",
      "amh",
      "arb",
      "ary",
      "arz",
      "asm",
      "azj",
      "bel",
      "ben",
      "bos",
      "bul",
      "cat",
      "ceb",
      "ces",
      "ckb",
      "cmn",
      "cym",
      "dan",
      "deu",
      "ell",
      "eng",
      "est",
      "eus",
      "fin",
      "fra",
      "gaz",
      "gle",
      "glg",
      "guj",
      "heb",
      "hin",
      "hrv",
      "hun",
      "hye",
      "ibo",
      "ind",
      "isl",
      "ita",
      "jav",
      "jpn",
      "kan",
      "kat",
      "kaz",
      "khk",
      "khm",
      "kir",
      "kor",
      "lao",
      "lit",
      "lug",
      "luo",
      "lvs",
      "mai",
      "mal",
      "mar",
      "mkd",
      "mlt",
      "mni",
      "mya",
      "nld",
      "nno",
      "nob",
      "npi",
      "nya",
      "ory",
      "pan",
      "pbt",
      "pes",
      "pol",
      "por",
      "ron",
      "rus",
      "slk",
      "slv",
      "sna",
      "snd",
      "som",
      "spa",
      "srp",
      "swe",
      "swh",
      "tam",
      "tel",
      "tgk",
      "tgl",
      "tha",
      "tur",
      "ukr",
      "urd",
      "uzn",
      "vie",
      "yor",
      "yue",
      "zsm",
      "zul",
  ]
  TEXT_SOURCE_LANGUAGE_NAMES = sorted(
      [language_code_to_name[code] for code in text_source_language_codes]
  )

  # Target langs:
  # S2ST / T2ST
  s2st_target_language_codes = [
      "eng",
      "arb",
      "ben",
      "cat",
      "ces",
      "cmn",
      "cym",
      "dan",
      "deu",
      "est",
      "fin",
      "fra",
      "hin",
      "ind",
      "ita",
      "jpn",
      "kor",
      "mlt",
      "nld",
      "pes",
      "pol",
      "por",
      "ron",
      "rus",
      "slk",
      "spa",
      "swe",
      "swh",
      "tel",
      "tgl",
      "tha",
      "tur",
      "ukr",
      "urd",
      "uzn",
      "vie",
  ]
  S2ST_TARGET_LANGUAGE_NAMES = sorted(
      [language_code_to_name[code] for code in s2st_target_language_codes]
  )
  # S2TT / ASR
  S2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES
  # T2TT
  T2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES

  # Download sample input audio files
  filenames = ["assets/sample_input.mp3", "assets/sample_input_2.mp3"]
  # for filename in filenames:
  #     hf_hub_download(
  #         repo_id="facebook/seamless_m4t",
  #         repo_type="space",
  #         filename=filename,
  #         local_dir=".",
  #     )

  AUDIO_SAMPLE_RATE = 16000.0
  MAX_INPUT_AUDIO_LENGTH = 60  # in seconds
  DEFAULT_TARGET_LANGUAGE = "French"

  # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  def api_predict(
      task_name: str,
      audio_source: str,
      input_audio_mic: str | None,
      input_audio_file: str | None,
      input_text: str | None,
      source_language: str | None,
      target_language: str,):

      audio_out, text_out = client.predict(task_name,
                                           audio_source,
                                           input_audio_mic,
                                           input_audio_file,
                                           input_text,
                                           source_language,
                                           target_language,
                                           api_name="/run")
      return audio_out, text_out





  def process_s2st_example(
      input_audio_file: str, target_language: str
  ) -> tuple[tuple[int, np.ndarray] | None, str]:
      return api_predict(
          task_name="S2ST",
          audio_source="file",
          input_audio_mic=None,
          input_audio_file=input_audio_file,
          input_text=None,
          source_language=None,
          target_language=target_language,
      )


  def process_s2tt_example(
      input_audio_file: str, target_language: str
  ) -> tuple[tuple[int, np.ndarray] | None, str]:
      return api_predict(
          task_name="S2TT",
          audio_source="file",
          input_audio_mic=None,
          input_audio_file=input_audio_file,
          input_text=None,
          source_language=None,
          target_language=target_language,
      )


  def process_t2st_example(
      input_text: str, source_language: str, target_language: str
  ) -> tuple[tuple[int, np.ndarray] | None, str]:
      return api_predict(
          task_name="T2ST",
          audio_source="",
          input_audio_mic=None,
          input_audio_file=None,
          input_text=input_text,
          source_language=source_language,
          target_language=target_language,
      )


  def process_t2tt_example(
      input_text: str, source_language: str, target_language: str
  ) -> tuple[tuple[int, np.ndarray] | None, str]:
      return api_predict(
          task_name="T2TT",
          audio_source="",
          input_audio_mic=None,
          input_audio_file=None,
          input_text=input_text,
          source_language=source_language,
          target_language=target_language,
      )


  def process_asr_example(
      input_audio_file: str, target_language: str
  ) -> tuple[tuple[int, np.ndarray] | None, str]:
      return api_predict(
          task_name="ASR",
          audio_source="file",
          input_audio_mic=None,
          input_audio_file=input_audio_file,
          input_text=None,
          source_language=None,
          target_language=target_language,
      )


  def update_audio_ui(audio_source: str) -> tuple[dict, dict]:
      mic = audio_source == "microphone"
      return (
          gr.update(visible=mic, value=None),  # input_audio_mic
          gr.update(visible=not mic, value=None),  # input_audio_file
      )


  def update_input_ui(task_name: str) -> tuple[dict, dict, dict, dict]:
      task_name = task_name.split()[0]
      if task_name == "S2ST":
          return (
              gr.update(visible=True),  # audio_box
              gr.update(visible=False),  # input_text
              gr.update(visible=False),  # source_language
              gr.update(
                  visible=True,
                  choices=S2ST_TARGET_LANGUAGE_NAMES,
                  value=DEFAULT_TARGET_LANGUAGE,
              ),  # target_language
          )
      elif task_name == "S2TT":
          return (
              gr.update(visible=True),  # audio_box
              gr.update(visible=False),  # input_text
              gr.update(visible=False),  # source_language
              gr.update(
                  visible=True,
                  choices=S2TT_TARGET_LANGUAGE_NAMES,
                  value=DEFAULT_TARGET_LANGUAGE,
              ),  # target_language
          )
      elif task_name == "T2ST":
          return (
              gr.update(visible=False),  # audio_box
              gr.update(visible=True),  # input_text
              gr.update(visible=True),  # source_language
              gr.update(
                  visible=True,
                  choices=S2ST_TARGET_LANGUAGE_NAMES,
                  value=DEFAULT_TARGET_LANGUAGE,
              ),  # target_language
          )
      elif task_name == "T2TT":
          return (
              gr.update(visible=False),  # audio_box
              gr.update(visible=True),  # input_text
              gr.update(visible=True),  # source_language
              gr.update(
                  visible=True,
                  choices=T2TT_TARGET_LANGUAGE_NAMES,
                  value=DEFAULT_TARGET_LANGUAGE,
              ),  # target_language
          )
      elif task_name == "ASR":
          return (
              gr.update(visible=True),  # audio_box
              gr.update(visible=False),  # input_text
              gr.update(visible=False),  # source_language
              gr.update(
                  visible=True,
                  choices=S2TT_TARGET_LANGUAGE_NAMES,
                  value=DEFAULT_TARGET_LANGUAGE,
              ),  # target_language
          )
      else:
          raise ValueError(f"Unknown task: {task_name}")


  def update_output_ui(task_name: str) -> tuple[dict, dict]:
      task_name = task_name.split()[0]
      if task_name in ["S2ST", "T2ST"]:
          return (
              gr.update(visible=True, value=None),  # output_audio
              gr.update(value=None),  # output_text
          )
      elif task_name in ["S2TT", "T2TT", "ASR"]:
          return (
              gr.update(visible=False, value=None),  # output_audio
              gr.update(value=None),  # output_text
          )
      else:
          raise ValueError(f"Unknown task: {task_name}")


  def update_example_ui(task_name: str) -> tuple[dict, dict, dict, dict, dict]:
      task_name = task_name.split()[0]
      return (
          gr.update(visible=task_name == "S2ST"),  # s2st_example_row
          gr.update(visible=task_name == "S2TT"),  # s2tt_example_row
          gr.update(visible=task_name == "T2ST"),  # t2st_example_row
          gr.update(visible=task_name == "T2TT"),  # t2tt_example_row
          gr.update(visible=task_name == "ASR"),  # asr_example_row
      )


  css = """
  h1 {
    text-align: center;
  }

- .contain {
-   max-width: 730px;
-   margin: auto;
-   padding-top: 1.5rem;
- }
+ #.contain {
+ #  max-width: 730px;
+ #  margin: auto;
+ #  padding-top: 1.5rem;
+ #}
  """

  with gr.Blocks(css=css) as demo:
      gr.Markdown(DESCRIPTION)
      with gr.Group():
          task_name = gr.Dropdown(
              label="Task",
              choices=TASK_NAMES,
              value=TASK_NAMES[0],
          )
          with gr.Row():
              source_language = gr.Dropdown(
                  label="Source language",
                  choices=TEXT_SOURCE_LANGUAGE_NAMES,
                  value="English",
                  visible=False,
              )
              target_language = gr.Dropdown(
                  label="Target language",
                  choices=S2ST_TARGET_LANGUAGE_NAMES,
                  value=DEFAULT_TARGET_LANGUAGE,
              )
          with gr.Row() as audio_box:
              audio_source = gr.Radio(
                  label="Audio source",
                  choices=["file", "microphone"],
                  value="file",
              )
              input_audio_mic = gr.Audio(
                  label="Input speech",
                  type="filepath",
                  source="microphone",
                  visible=False,
              )
              input_audio_file = gr.Audio(
                  label="Input speech",
                  type="filepath",
                  source="upload",
                  visible=True,
              )
          input_text = gr.Textbox(label="Input text", visible=False)
          with gr.Row():
              btn = gr.Button("Translate")
              btn_clean = gr.ClearButton([input_audio_mic, input_audio_file])
          # gr.Markdown("## Text Examples")
          with gr.Column():
              output_audio = gr.Audio(
                  label="Translated speech",
                  autoplay=False,
                  streaming=False,
                  type="numpy",
              )
              output_text = gr.Textbox(label="Translated text")

      with gr.Row(visible=True) as s2st_example_row:
          s2st_examples = gr.Examples(
              examples=[
                  ["assets/sample_input.mp3", "French"],
                  ["assets/sample_input.mp3", "Mandarin Chinese"],
                  ["assets/sample_input_2.mp3", "Hindi"],
                  ["assets/sample_input_2.mp3", "Spanish"],
              ],
              inputs=[input_audio_file, target_language],
              outputs=[output_audio, output_text],
              fn=process_s2st_example,
          )
      with gr.Row(visible=False) as s2tt_example_row:
          s2tt_examples = gr.Examples(
              examples=[
                  ["assets/sample_input.mp3", "French"],
                  ["assets/sample_input.mp3", "Mandarin Chinese"],
                  ["assets/sample_input_2.mp3", "Hindi"],
                  ["assets/sample_input_2.mp3", "Spanish"],
              ],
              inputs=[input_audio_file, target_language],
              outputs=[output_audio, output_text],
              fn=process_s2tt_example,
          )
      with gr.Row(visible=False) as t2st_example_row:
          t2st_examples = gr.Examples(
              examples=[
                  ["My favorite animal is the elephant.", "English", "French"],
                  ["My favorite animal is the elephant.", "English", "Mandarin Chinese"],
                  [
                      "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
                      "English",
                      "Hindi",
                  ],
                  [
                      "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
                      "English",
                      "Spanish",
                  ],
              ],
              inputs=[input_text, source_language, target_language],
              outputs=[output_audio, output_text],
              fn=process_t2st_example,
          )
      with gr.Row(visible=False) as t2tt_example_row:
          t2tt_examples = gr.Examples(
              examples=[
                  ["My favorite animal is the elephant.", "English", "French"],
                  ["My favorite animal is the elephant.", "English", "Mandarin Chinese"],
                  [
                      "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
                      "English",
                      "Hindi",
                  ],
                  [
                      "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
                      "English",
                      "Spanish",
                  ],
              ],
              inputs=[input_text, source_language, target_language],
              outputs=[output_audio, output_text],
              fn=process_t2tt_example,
          )
      with gr.Row(visible=False) as asr_example_row:
          asr_examples = gr.Examples(
              examples=[
                  ["assets/sample_input.mp3", "English"],
                  ["assets/sample_input_2.mp3", "English"],
              ],
              inputs=[input_audio_file, target_language],
              outputs=[output_audio, output_text],
              fn=process_asr_example,
          )

      audio_source.change(
          fn=update_audio_ui,
          inputs=audio_source,
          outputs=[
              input_audio_mic,
              input_audio_file,
          ],
          queue=False,
          api_name=False,
      )
      task_name.change(
          fn=update_input_ui,
          inputs=task_name,
          outputs=[
              audio_box,
              input_text,
              source_language,
              target_language,
          ],
          queue=False,
          api_name=False,
      ).then(
          fn=update_output_ui,
          inputs=task_name,
          outputs=[output_audio, output_text],
          queue=False,
          api_name=False,
      ).then(
          fn=update_example_ui,
          inputs=task_name,
          outputs=[
              s2st_example_row,
              s2tt_example_row,
              t2st_example_row,
              t2tt_example_row,
              asr_example_row,
          ],
          queue=False,
          api_name=False,
      )

      btn.click(
          fn=api_predict,
          inputs=[
              task_name,
              audio_source,
              input_audio_mic,
              input_audio_file,
              input_text,
              source_language,
              target_language,
          ],
          outputs=[output_audio, output_text],
          api_name="run",
      )

  if __name__ == "__main__":
      demo.queue().launch()
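
For reference, the `btn.click(..., api_name="run")` wiring above exposes the same seven-input signature that `api_predict` forwards to the upstream `facebook-seamless-m4t` Space, so this app can itself be driven programmatically. A minimal sketch of calling its `/run` endpoint with `gradio_client` follows; the Space URL is a hypothetical placeholder, substitute the real deployment:

from gradio_client import Client

# Hypothetical URL of a deployed copy of this Space; substitute the real one.
client = Client("https://vtechai-sm4t.hf.space/")

# /run takes the same seven positional inputs that btn.click wires up:
# task name, audio source, mic audio path, uploaded audio path, input text,
# source language, target language.
audio_out, text_out = client.predict(
    "T2TT (Text to Text translation)",      # task_name
    "",                                     # audio_source (unused for text tasks)
    None,                                   # input_audio_mic
    None,                                   # input_audio_file
    "My favorite animal is the elephant.",  # input_text
    "English",                              # source_language
    "French",                               # target_language
    api_name="/run",
)
print(text_out)  # translated text; audio_out carries speech only for S2ST/T2ST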