Duy-NM committed on
Commit 6fcb961 · 1 Parent(s): 70ea763
Files changed (1)
  1. app.py +33 -62
app.py CHANGED
@@ -8,10 +8,12 @@ from __future__ import annotations
 
 import gradio as gr
 import numpy as np
-import torch
-import torchaudio
-from huggingface_hub import hf_hub_download
-from seamless_communication.models.inference.translator import Translator
+# import torch
+
+
+from gradio_client import Client
+
+client = Client("https://facebook-seamless-m4t.hf.space/")
 
 DESCRIPTION = """
@@ -290,78 +292,47 @@ T2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES
 
 # Download sample input audio files
 filenames = ["assets/sample_input.mp3", "assets/sample_input_2.mp3"]
-for filename in filenames:
-    hf_hub_download(
-        repo_id="facebook/seamless_m4t",
-        repo_type="space",
-        filename=filename,
-        local_dir=".",
-    )
+# for filename in filenames:
+#     hf_hub_download(
+#         repo_id="facebook/seamless_m4t",
+#         repo_type="space",
+#         filename=filename,
+#         local_dir=".",
+#     )
 
 AUDIO_SAMPLE_RATE = 16000.0
 MAX_INPUT_AUDIO_LENGTH = 60  # in seconds
 DEFAULT_TARGET_LANGUAGE = "French"
 
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-translator = Translator(
-    model_name_or_card="seamlessM4T_large",
-    vocoder_name_or_card="vocoder_36langs",
-    device=device,
-    dtype=torch.float16 if "cuda" in device.type else torch.float32,
-)
+# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-
-def predict(
+def api_predict(
     task_name: str,
     audio_source: str,
     input_audio_mic: str | None,
     input_audio_file: str | None,
     input_text: str | None,
     source_language: str | None,
-    target_language: str,
-) -> tuple[tuple[int, np.ndarray] | None, str]:
-    task_name = task_name.split()[0]
-    source_language_code = (
-        LANGUAGE_NAME_TO_CODE[source_language] if source_language else None
-    )
-    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
-
-    if task_name in ["S2ST", "S2TT", "ASR"]:
-        if audio_source == "microphone":
-            input_data = input_audio_mic
-        else:
-            input_data = input_audio_file
-
-        arr, org_sr = torchaudio.load(input_data)
-        new_arr = torchaudio.functional.resample(
-            arr, orig_freq=org_sr, new_freq=AUDIO_SAMPLE_RATE
-        )
-        max_length = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
-        if new_arr.shape[1] > max_length:
-            new_arr = new_arr[:, :max_length]
-            gr.Warning(
-                f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds is used."
-            )
-        torchaudio.save(input_data, new_arr, sample_rate=int(AUDIO_SAMPLE_RATE))
-    else:
-        input_data = input_text
-    text_out, wav, sr = translator.predict(
-        input=input_data,
-        task_str=task_name,
-        tgt_lang=target_language_code,
-        src_lang=source_language_code,
-        ngram_filtering=True,
-    )
-    if task_name in ["S2ST", "T2ST"]:
-        return (sr, wav.cpu().detach().numpy()), text_out
-    else:
-        return None, text_out
+    target_language: str,):
+
+    audio_out, text_out = client.predict(task_name,
+                                         audio_source,
+                                         input_audio_mic,
+                                         input_audio_file,
+                                         input_text,
+                                         source_language,
+                                         target_language,
+                                         api_name="/run")
+    return audio_out, text_out
+
 
 
 def process_s2st_example(
     input_audio_file: str, target_language: str
 ) -> tuple[tuple[int, np.ndarray] | None, str]:
-    return predict(
+    return api_predict(
         task_name="S2ST",
         audio_source="file",
         input_audio_mic=None,
@@ -375,7 +346,7 @@ def process_s2st_example(
 def process_s2tt_example(
     input_audio_file: str, target_language: str
 ) -> tuple[tuple[int, np.ndarray] | None, str]:
-    return predict(
+    return api_predict(
         task_name="S2TT",
         audio_source="file",
         input_audio_mic=None,
@@ -389,7 +360,7 @@ def process_s2tt_example(
 def process_t2st_example(
     input_text: str, source_language: str, target_language: str
 ) -> tuple[tuple[int, np.ndarray] | None, str]:
-    return predict(
+    return api_predict(
         task_name="T2ST",
         audio_source="",
         input_audio_mic=None,
@@ -403,7 +374,7 @@ def process_t2st_example(
 def process_t2tt_example(
     input_text: str, source_language: str, target_language: str
 ) -> tuple[tuple[int, np.ndarray] | None, str]:
-    return predict(
+    return api_predict(
         task_name="T2TT",
         audio_source="",
         input_audio_mic=None,
@@ -417,7 +388,7 @@ def process_t2tt_example(
 def process_asr_example(
     input_audio_file: str, target_language: str
 ) -> tuple[tuple[int, np.ndarray] | None, str]:
-    return predict(
+    return api_predict(
         task_name="ASR",
         audio_source="file",
         input_audio_mic=None,
@@ -705,7 +676,7 @@ with gr.Blocks(css=css) as demo:
     )
 
     btn.click(
-        fn=predict,
+        fn=api_predict,
        inputs=[
            task_name,
            audio_source,
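
The commit swaps local seamlessM4T inference for a remote call to the hosted Space via gradio_client. Below is a minimal standalone sketch of the same pattern; the endpoint name "/run" and the seven positional inputs come from the diff above, while the task label, file path, and language values are illustrative assumptions rather than values from the commit.

# Sketch: query the facebook/seamless_m4t Space directly, mirroring api_predict.
from gradio_client import Client

client = Client("https://facebook-seamless-m4t.hf.space/")

# predict() returns one value per output component; this endpoint has an audio
# output and a text output, so it returns a 2-tuple.
audio_out, text_out = client.predict(
    "S2TT",                     # task_name (the full UI label may differ)
    "file",                     # audio_source: "file" or "microphone"
    None,                       # input_audio_mic
    "assets/sample_input.mp3",  # input_audio_file (illustrative path)
    None,                       # input_text (unused for speech tasks)
    None,                       # source_language (optional for speech input)
    "French",                   # target_language
    api_name="/run",
)
print(text_out)  # translated text; audio_out may be None for text-only tasks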
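
One behavioral difference: the deleted local path resampled input audio to AUDIO_SAMPLE_RATE and trimmed it to MAX_INPUT_AUDIO_LENGTH before inference, whereas api_predict now uploads the file unmodified. If that guard is still wanted before upload, it can be reinstated from the deleted lines; a sketch, assuming torchaudio stays available (the helper name trim_audio is hypothetical):

import torchaudio

AUDIO_SAMPLE_RATE = 16000.0
MAX_INPUT_AUDIO_LENGTH = 60  # in seconds

def trim_audio(path: str) -> str:
    # Reconstructed from the deleted preprocessing: resample to 16 kHz and
    # keep only the first 60 seconds, overwriting the file in place.
    arr, org_sr = torchaudio.load(path)
    arr = torchaudio.functional.resample(
        arr, orig_freq=int(org_sr), new_freq=int(AUDIO_SAMPLE_RATE)
    )
    max_length = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
    if arr.shape[1] > max_length:
        arr = arr[:, :max_length]  # drop samples past the 60-second limit
    torchaudio.save(path, arr, sample_rate=int(AUDIO_SAMPLE_RATE))
    return path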