reach-vb (HF staff) and hysts (HF staff) committed
Commit 4936337
1 parent: df28e5a

Add examples and description (#2)

- Add examples and description (4ec7561ba4cb2c54cae93fd0e488e9968db8b902)


Co-authored-by: hysts <[email protected]>

Files changed (1):
  1. app.py (+153, -8)
app.py CHANGED
 
@@ -14,7 +14,16 @@ from lang_list import (
     TEXT_SOURCE_LANGUAGE_NAMES,
 )
 
-DESCRIPTION = "# SeamlessM4T"
+DESCRIPTION = """# SeamlessM4T
+
+[SeamlessM4T](https://github.com/facebookresearch/seamless_communication) is designed to provide high-quality
+translation, allowing people from different linguistic communities to communicate effortlessly through speech and text.
+
+This unified model enables multiple tasks like Speech-to-Speech (S2ST), Speech-to-Text (S2TT), Text-to-Speech (T2ST)
+translation and more, without relying on multiple separate models.
+"""
+
+CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES") == "1"
 
 TASK_NAMES = [
     "S2ST (Speech to Speech translation)",
 
@@ -23,10 +32,8 @@ TASK_NAMES = [
     "T2TT (Text to Text translation)",
     "ASR (Automatic Speech Recognition)",
 ]
-
 AUDIO_SAMPLE_RATE = 16000.0
 MAX_INPUT_AUDIO_LENGTH = 60  # in seconds
-
 DEFAULT_TARGET_LANGUAGE = "French"
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
@@ -41,14 +48,14 @@ translator = Translator(
 def predict(
     task_name: str,
     audio_source: str,
-    input_audio_mic: str,
-    input_audio_file: str,
-    input_text: str,
-    source_language: str,
+    input_audio_mic: str | None,
+    input_audio_file: str | None,
+    input_text: str | None,
+    source_language: str | None,
     target_language: str,
 ) -> tuple[tuple[int, np.ndarray] | None, str]:
     task_name = task_name.split()[0]
-    source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
+    source_language_code = LANGUAGE_NAME_TO_CODE.get(source_language, None)
     target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
 
     if task_name in ["S2ST", "S2TT", "ASR"]:
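Switching from `LANGUAGE_NAME_TO_CODE[source_language]` to `.get(source_language, None)` is what lets the example helpers below pass `source_language=None` for the speech tasks: indexing would raise a KeyError on None, while `.get` quietly returns the default. A self-contained illustration (the two-entry mapping here is a hypothetical stand-in; the real table comes from lang_list):

LANGUAGE_NAME_TO_CODE = {"French": "fra", "English": "eng"}  # illustrative subset

assert LANGUAGE_NAME_TO_CODE.get("French", None) == "fra"
assert LANGUAGE_NAME_TO_CODE.get(None, None) is None  # no KeyError, unlike [...] indexing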
 
@@ -78,6 +85,66 @@ def predict(
     return None, text_out
 
 
+def process_s2st_example(input_audio_file: str, target_language: str) -> tuple[str, str]:
+    return predict(
+        task_name="S2ST",
+        audio_source="file",
+        input_audio_mic=None,
+        input_audio_file=input_audio_file,
+        input_text=None,
+        source_language=None,
+        target_language=target_language,
+    )
+
+
+def process_s2tt_example(input_audio_file: str, target_language: str) -> tuple[str, str]:
+    return predict(
+        task_name="S2TT",
+        audio_source="file",
+        input_audio_mic=None,
+        input_audio_file=input_audio_file,
+        input_text=None,
+        source_language=None,
+        target_language=target_language,
+    )
+
+
+def process_t2st_example(input_text: str, source_language: str, target_language: str) -> tuple[str, str]:
+    return predict(
+        task_name="T2ST",
+        audio_source="",
+        input_audio_mic=None,
+        input_audio_file=None,
+        input_text=input_text,
+        source_language=source_language,
+        target_language=target_language,
+    )
+
+
+def process_t2tt_example(input_text: str, source_language: str, target_language: str) -> tuple[str, str]:
+    return predict(
+        task_name="T2TT",
+        audio_source="",
+        input_audio_mic=None,
+        input_audio_file=None,
+        input_text=input_text,
+        source_language=source_language,
+        target_language=target_language,
+    )
+
+
+def process_asr_example(input_audio_file: str, target_language: str) -> tuple[str, str]:
+    return predict(
+        task_name="ASR",
+        audio_source="file",
+        input_audio_mic=None,
+        input_audio_file=input_audio_file,
+        input_text=None,
+        source_language=None,
+        target_language=target_language,
+    )
+
+
 def update_audio_ui(audio_source: str) -> tuple[dict, dict]:
     mic = audio_source == "microphone"
     return (
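The five `process_*_example` wrappers exist because `gr.Examples` calls `fn` with one positional argument per component in `inputs`, so each wrapper's signature has to mirror its example row. The three audio-input wrappers differ only in the task name; a hypothetical factory (`make_audio_example_fn` is not part of the commit) could collapse them:

def make_audio_example_fn(task_name: str):
    # Returns a two-argument callable matching inputs=[input_audio_file, target_language].
    def fn(input_audio_file: str, target_language: str):
        return predict(
            task_name=task_name,
            audio_source="file",
            input_audio_mic=None,
            input_audio_file=input_audio_file,
            input_text=None,
            source_language=None,
            target_language=target_language,
        )
    return fn

process_s2st_example = make_audio_example_fn("S2ST")
process_s2tt_example = make_audio_example_fn("S2TT")
process_asr_example = make_audio_example_fn("ASR")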
 
@@ -153,6 +220,17 @@ def update_output_ui(task_name: str) -> tuple[dict, dict]:
         raise ValueError(f"Unknown task: {task_name}")
 
 
+def update_example_ui(task_name: str) -> tuple[dict, dict, dict, dict, dict]:
+    task_name = task_name.split()[0]
+    return (
+        gr.update(visible=task_name == "S2ST"),  # s2st_example_row
+        gr.update(visible=task_name == "S2TT"),  # s2tt_example_row
+        gr.update(visible=task_name == "T2ST"),  # t2st_example_row
+        gr.update(visible=task_name == "T2TT"),  # t2tt_example_row
+        gr.update(visible=task_name == "ASR"),  # asr_example_row
+    )
+
+
 with gr.Blocks(css="style.css") as demo:
     gr.Markdown(DESCRIPTION)
     gr.DuplicateButton(
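`update_example_ui` follows the same pattern as the existing `update_audio_ui`/`update_output_ui` helpers: in Gradio 3.x, `gr.update(...)` returns a dict of property changes that is applied to the matching output component, in order. A minimal standalone sketch of the visibility toggle (component names here are made up):

import gradio as gr

def toggle(show_first: bool) -> tuple[dict, dict]:
    # Each gr.update(...) is routed to one component in `outputs`, in order.
    return gr.update(visible=show_first), gr.update(visible=not show_first)

with gr.Blocks() as demo:
    show_first = gr.Checkbox(value=True, label="Show first row")
    with gr.Row(visible=True) as first_row:
        gr.Markdown("first row")
    with gr.Row(visible=False) as second_row:
        gr.Markdown("second row")
    show_first.change(fn=toggle, inputs=show_first, outputs=[first_row, second_row])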
 
@@ -207,6 +285,61 @@ with gr.Blocks(css="style.css") as demo:
         )
         output_text = gr.Textbox(label="Translated text")
 
+    with gr.Row(visible=True) as s2st_example_row:
+        s2st_examples = gr.Examples(
+            examples=[
+                ["assets/sample_input.mp3", "French"],
+                ["assets/sample_input.mp3", "Mandarin Chinese"],
+            ],
+            inputs=[input_audio_file, target_language],
+            outputs=[output_audio, output_text],
+            fn=process_s2st_example,
+            cache_examples=CACHE_EXAMPLES,
+        )
+    with gr.Row(visible=False) as s2tt_example_row:
+        s2tt_examples = gr.Examples(
+            examples=[
+                ["assets/sample_input.mp3", "French"],
+                ["assets/sample_input.mp3", "Mandarin Chinese"],
+            ],
+            inputs=[input_audio_file, target_language],
+            outputs=[output_audio, output_text],
+            fn=process_s2tt_example,
+            cache_examples=CACHE_EXAMPLES,
+        )
+    with gr.Row(visible=False) as t2st_example_row:
+        t2st_examples = gr.Examples(
+            examples=[
+                ["My favorite animal is the elephant.", "English", "French"],
+                ["My favorite animal is the elephant.", "English", "Mandarin Chinese"],
+            ],
+            inputs=[input_text, source_language, target_language],
+            outputs=[output_audio, output_text],
+            fn=process_t2st_example,
+            cache_examples=CACHE_EXAMPLES,
+        )
+    with gr.Row(visible=False) as t2tt_example_row:
+        t2tt_examples = gr.Examples(
+            examples=[
+                ["My favorite animal is the elephant.", "English", "French"],
+                ["My favorite animal is the elephant.", "English", "Mandarin Chinese"],
+            ],
+            inputs=[input_text, source_language, target_language],
+            outputs=[output_audio, output_text],
+            fn=process_t2tt_example,
+            cache_examples=CACHE_EXAMPLES,
+        )
+    with gr.Row(visible=False) as asr_example_row:
+        asr_examples = gr.Examples(
+            examples=[
+                ["assets/sample_input.mp3", "English"],
+            ],
+            inputs=[input_audio_file, target_language],
+            outputs=[output_audio, output_text],
+            fn=process_asr_example,
+            cache_examples=CACHE_EXAMPLES,
+        )
+
     audio_source.change(
         fn=update_audio_ui,
         inputs=audio_source,
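When `cache_examples` is true, Gradio runs each example row through `fn` once at startup and replays the stored outputs on click, which is why every `gr.Examples` block above supplies both `fn` and `outputs`. A toy example of the same wiring, independent of this app:

import gradio as gr

def shout(text: str) -> str:
    return text.upper()

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    gr.Examples(
        examples=[["hello"], ["bonjour"]],
        inputs=[inp],
        outputs=[out],
        fn=shout,
        cache_examples=True,  # precompute at startup, replay on click
    )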
 
@@ -234,6 +367,18 @@ with gr.Blocks(css="style.css") as demo:
         outputs=[output_audio, output_text],
         queue=False,
         api_name=False,
+    ).then(
+        fn=update_example_ui,
+        inputs=task_name,
+        outputs=[
+            s2st_example_row,
+            s2tt_example_row,
+            t2st_example_row,
+            t2tt_example_row,
+            asr_example_row,
+        ],
+        queue=False,
+        api_name=False,
     )
 
     btn.click(
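The `.then(...)` appended to the `task_name.change(...)` listener chains a second callback that runs after the first one finishes, so the example rows are shown or hidden only after the main input/output UI has been updated. A stripped-down sketch of the chaining pattern:

import gradio as gr

def first_step(x: str) -> str:
    return x.strip()

def second_step(x: str) -> str:
    return f"processed: {x}"

with gr.Blocks() as demo:
    source = gr.Textbox(label="Source")
    middle = gr.Textbox(label="After step 1")
    final = gr.Textbox(label="After step 2")
    # .then() fires once the preceding listener completes.
    source.change(fn=first_step, inputs=source, outputs=middle).then(
        fn=second_step, inputs=middle, outputs=final
    )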