# SM4T / app.py — Hugging Face Space "VTechAI", commit 5a19f99
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import gradio as gr
import numpy as np
# import torch
from gradio_client import Client
# Remote Seamless-M4T Space that performs the actual inference; this app is a
# thin UI proxy that forwards every request to its "/run" endpoint.
client = Client("https://facebook-seamless-m4t.hf.space/")
# Markdown banner shown at the top of the page (Vietnamese, user-facing
# runtime text — deliberately left untranslated).
DESCRIPTION = """
# SM4T
Ứng dụng có thể chuyển đổi giọng nói hoặc chữ viết sang giọng nói hoặc chữ viết của một ngôn ngữ khác.
\nHiện tại SM4T đã hỗ trợ 94 ngôn ngữ khác nhau.
"""
# Labels for the task dropdown; handlers key off the first token
# (e.g. "S2ST") via task_name.split()[0].
TASK_NAMES = [
    "S2ST (Speech to Speech translation)",
    "S2TT (Speech to Text translation)",
    "T2ST (Text to Speech translation)",
    "T2TT (Text to Text translation)",
    "ASR (Automatic Speech Recognition)",
]
# Mapping of Seamless-M4T 3-letter language codes to human-readable names.
# The UI works with the names; codes are recovered via LANGUAGE_NAME_TO_CODE.
language_code_to_name = {
    "afr": "Afrikaans",
    "amh": "Amharic",
    "arb": "Modern Standard Arabic",
    "ary": "Moroccan Arabic",
    "arz": "Egyptian Arabic",
    "asm": "Assamese",
    "ast": "Asturian",
    "azj": "North Azerbaijani",
    "bel": "Belarusian",
    "ben": "Bengali",
    "bos": "Bosnian",
    "bul": "Bulgarian",
    "cat": "Catalan",
    "ceb": "Cebuano",
    "ces": "Czech",
    "ckb": "Central Kurdish",
    "cmn": "Mandarin Chinese",
    "cym": "Welsh",
    "dan": "Danish",
    "deu": "German",
    "ell": "Greek",
    "eng": "English",
    "est": "Estonian",
    "eus": "Basque",
    "fin": "Finnish",
    "fra": "French",
    "gaz": "West Central Oromo",
    "gle": "Irish",
    "glg": "Galician",
    "guj": "Gujarati",
    "heb": "Hebrew",
    "hin": "Hindi",
    "hrv": "Croatian",
    "hun": "Hungarian",
    "hye": "Armenian",
    "ibo": "Igbo",
    "ind": "Indonesian",
    "isl": "Icelandic",
    "ita": "Italian",
    "jav": "Javanese",
    "jpn": "Japanese",
    "kam": "Kamba",
    "kan": "Kannada",
    "kat": "Georgian",
    "kaz": "Kazakh",
    "kea": "Kabuverdianu",
    "khk": "Halh Mongolian",
    "khm": "Khmer",
    "kir": "Kyrgyz",
    "kor": "Korean",
    "lao": "Lao",
    "lit": "Lithuanian",
    "ltz": "Luxembourgish",
    "lug": "Ganda",
    "luo": "Luo",
    "lvs": "Standard Latvian",
    "mai": "Maithili",
    "mal": "Malayalam",
    "mar": "Marathi",
    "mkd": "Macedonian",
    "mlt": "Maltese",
    "mni": "Meitei",
    "mya": "Burmese",
    "nld": "Dutch",
    "nno": "Norwegian Nynorsk",
    "nob": "Norwegian Bokm\u00e5l",
    "npi": "Nepali",
    "nya": "Nyanja",
    "oci": "Occitan",
    "ory": "Odia",
    "pan": "Punjabi",
    "pbt": "Southern Pashto",
    "pes": "Western Persian",
    "pol": "Polish",
    "por": "Portuguese",
    "ron": "Romanian",
    "rus": "Russian",
    "slk": "Slovak",
    "slv": "Slovenian",
    "sna": "Shona",
    "snd": "Sindhi",
    "som": "Somali",
    "spa": "Spanish",
    "srp": "Serbian",
    "swe": "Swedish",
    "swh": "Swahili",
    "tam": "Tamil",
    "tel": "Telugu",
    "tgk": "Tajik",
    "tgl": "Tagalog",
    "tha": "Thai",
    "tur": "Turkish",
    "ukr": "Ukrainian",
    "urd": "Urdu",
    "uzn": "Northern Uzbek",
    "vie": "Vietnamese",
    "xho": "Xhosa",
    "yor": "Yoruba",
    "yue": "Cantonese",
    "zlm": "Colloquial Malay",
    "zsm": "Standard Malay",
    "zul": "Zulu",
}
# Reverse lookup: display name -> 3-letter code (names are unique in the map).
LANGUAGE_NAME_TO_CODE = dict(
    zip(language_code_to_name.values(), language_code_to_name.keys())
)
# Source langs: S2ST / S2TT / ASR don't need source lang
# T2TT / T2ST use this
# NOTE(review): this is a strict subset of language_code_to_name (e.g. ast,
# kam, kea, ltz, oci, xho, zlm are absent) — presumably those languages are
# not supported as text source/target; confirm against the backend Space.
text_source_language_codes = [
    "afr",
    "amh",
    "arb",
    "ary",
    "arz",
    "asm",
    "azj",
    "bel",
    "ben",
    "bos",
    "bul",
    "cat",
    "ceb",
    "ces",
    "ckb",
    "cmn",
    "cym",
    "dan",
    "deu",
    "ell",
    "eng",
    "est",
    "eus",
    "fin",
    "fra",
    "gaz",
    "gle",
    "glg",
    "guj",
    "heb",
    "hin",
    "hrv",
    "hun",
    "hye",
    "ibo",
    "ind",
    "isl",
    "ita",
    "jav",
    "jpn",
    "kan",
    "kat",
    "kaz",
    "khk",
    "khm",
    "kir",
    "kor",
    "lao",
    "lit",
    "lug",
    "luo",
    "lvs",
    "mai",
    "mal",
    "mar",
    "mkd",
    "mlt",
    "mni",
    "mya",
    "nld",
    "nno",
    "nob",
    "npi",
    "nya",
    "ory",
    "pan",
    "pbt",
    "pes",
    "pol",
    "por",
    "ron",
    "rus",
    "slk",
    "slv",
    "sna",
    "snd",
    "som",
    "spa",
    "srp",
    "swe",
    "swh",
    "tam",
    "tel",
    "tgk",
    "tgl",
    "tha",
    "tur",
    "ukr",
    "urd",
    "uzn",
    "vie",
    "yor",
    "yue",
    "zsm",
    "zul",
]
# Alphabetical display names for the text-language dropdowns.
TEXT_SOURCE_LANGUAGE_NAMES = sorted(
    [language_code_to_name[code] for code in text_source_language_codes]
)
# Target langs:
# S2ST / T2ST
# Languages for which the model can produce *speech* output (a smaller set
# than the text targets).
s2st_target_language_codes = [
    "eng",
    "arb",
    "ben",
    "cat",
    "ces",
    "cmn",
    "cym",
    "dan",
    "deu",
    "est",
    "fin",
    "fra",
    "hin",
    "ind",
    "ita",
    "jpn",
    "kor",
    "mlt",
    "nld",
    "pes",
    "pol",
    "por",
    "ron",
    "rus",
    "slk",
    "spa",
    "swe",
    "swh",
    "tel",
    "tgl",
    "tha",
    "tur",
    "ukr",
    "urd",
    "uzn",
    "vie",
]
S2ST_TARGET_LANGUAGE_NAMES = sorted(
    [language_code_to_name[code] for code in s2st_target_language_codes]
)
# S2TT / ASR
# Text-output tasks reuse the full text-language list.
S2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES
# T2TT
T2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES
# Sample input audio files referenced by the example rows below; the download
# step is disabled because the assets are assumed to ship with this Space.
filenames = ["assets/sample_input.mp3", "assets/sample_input_2.mp3"]
# for filename in filenames:
#     hf_hub_download(
#         repo_id="facebook/seamless_m4t",
#         repo_type="space",
#         filename=filename,
#         local_dir=".",
#     )
# NOTE(review): AUDIO_SAMPLE_RATE and MAX_INPUT_AUDIO_LENGTH are never read in
# this file — likely leftovers from the original local-inference app.
AUDIO_SAMPLE_RATE = 16000.0
MAX_INPUT_AUDIO_LENGTH = 60  # in seconds
DEFAULT_TARGET_LANGUAGE = "French"
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def api_predict(
    task_name: str,
    audio_source: str,
    input_audio_mic: str | None,
    input_audio_file: str | None,
    input_text: str | None,
    source_language: str | None,
    target_language: str,
):
    """Forward one translation request to the remote Space's "/run" endpoint.

    All arguments are passed through unchanged; returns the (audio, text)
    pair produced by the backend.
    """
    result = client.predict(
        task_name,
        audio_source,
        input_audio_mic,
        input_audio_file,
        input_text,
        source_language,
        target_language,
        api_name="/run",
    )
    translated_audio, translated_text = result
    return translated_audio, translated_text
def process_s2st_example(
    input_audio_file: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    """Example handler: speech file -> translated speech via the remote API."""
    request = dict(
        task_name="S2ST",
        audio_source="file",
        input_audio_mic=None,
        input_audio_file=input_audio_file,
        input_text=None,
        source_language=None,
        target_language=target_language,
    )
    return api_predict(**request)
def process_s2tt_example(
    input_audio_file: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    """Example handler: speech file -> translated text via the remote API."""
    # Positional order matches api_predict's signature.
    return api_predict(
        "S2TT",
        "file",
        None,
        input_audio_file,
        None,
        None,
        target_language,
    )
def process_t2st_example(
    input_text: str, source_language: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    """Example handler: text -> translated speech via the remote API."""
    kwargs = {
        "task_name": "T2ST",
        "audio_source": "",
        "input_audio_mic": None,
        "input_audio_file": None,
        "input_text": input_text,
        "source_language": source_language,
        "target_language": target_language,
    }
    return api_predict(**kwargs)
def process_t2tt_example(
    input_text: str, source_language: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    """Example handler: text -> translated text via the remote API."""
    # Positional order matches api_predict's signature.
    return api_predict(
        "T2TT",
        "",
        None,
        None,
        input_text,
        source_language,
        target_language,
    )
def process_asr_example(
    input_audio_file: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    """Example handler: transcribe a speech file (ASR) via the remote API."""
    args = ("ASR", "file", None, input_audio_file, None, None, target_language)
    return api_predict(*args)
def update_audio_ui(audio_source: str) -> tuple[dict, dict]:
    """Toggle mic vs. file widgets and clear both values on switch."""
    use_mic = audio_source == "microphone"
    mic_update = gr.update(visible=use_mic, value=None)      # input_audio_mic
    file_update = gr.update(visible=not use_mic, value=None)  # input_audio_file
    return mic_update, file_update
def update_input_ui(task_name: str) -> tuple[dict, dict, dict, dict]:
    """Show/hide input widgets and reload target choices for the chosen task.

    Returns updates for (audio_box, input_text, source_language,
    target_language). Raises ValueError for an unrecognized task.
    """
    task = task_name.split()[0]
    # task -> (audio box visible, text box visible, source dropdown visible,
    #          target-language choices)
    layouts = {
        "S2ST": (True, False, False, S2ST_TARGET_LANGUAGE_NAMES),
        "S2TT": (True, False, False, S2TT_TARGET_LANGUAGE_NAMES),
        "T2ST": (False, True, True, S2ST_TARGET_LANGUAGE_NAMES),
        "T2TT": (False, True, True, T2TT_TARGET_LANGUAGE_NAMES),
        "ASR": (True, False, False, S2TT_TARGET_LANGUAGE_NAMES),
    }
    if task not in layouts:
        raise ValueError(f"Unknown task: {task}")
    audio_visible, text_visible, source_visible, target_choices = layouts[task]
    return (
        gr.update(visible=audio_visible),   # audio_box
        gr.update(visible=text_visible),    # input_text
        gr.update(visible=source_visible),  # source_language
        gr.update(
            visible=True,
            choices=target_choices,
            value=DEFAULT_TARGET_LANGUAGE,
        ),  # target_language
    )
def update_output_ui(task_name: str) -> tuple[dict, dict]:
    """Show the audio player only for tasks that produce speech output.

    Both outputs are cleared. Raises ValueError for an unrecognized task.
    """
    task = task_name.split()[0]
    if task not in ("S2ST", "S2TT", "T2ST", "T2TT", "ASR"):
        raise ValueError(f"Unknown task: {task}")
    speech_out = task in ("S2ST", "T2ST")
    return (
        gr.update(visible=speech_out, value=None),  # output_audio
        gr.update(value=None),  # output_text
    )
def update_example_ui(task_name: str) -> tuple[dict, dict, dict, dict, dict]:
    """Reveal only the example row that matches the selected task.

    Returns visibility updates for (s2st, s2tt, t2st, t2tt, asr) rows,
    in that order.
    """
    task = task_name.split()[0]
    rows = ("S2ST", "S2TT", "T2ST", "T2TT", "ASR")
    return tuple(gr.update(visible=task == row) for row in rows)
# Custom stylesheet injected into gr.Blocks. The original disabled the
# .contain rule with Python-style "#" markers, which is not CSS comment
# syntax; a proper /* */ comment keeps the stylesheet well-formed while
# leaving the rule disabled.
css = """
h1 {
  text-align: center;
}
/*
.contain {
  max-width: 730px;
  margin: auto;
  padding-top: 1.5rem;
}
*/
"""
# ---------------------------------------------------------------------------
# UI definition. All widgets live inside this Blocks context; the event
# wiring at the bottom connects them to the remote-API handlers above.
# ---------------------------------------------------------------------------
with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Group():
        # Task selector drives which inputs/outputs/examples are visible
        # (see update_input_ui / update_output_ui / update_example_ui).
        task_name = gr.Dropdown(
            label="Task",
            choices=TASK_NAMES,
            value=TASK_NAMES[0],
        )
        with gr.Row():
            # Source language is hidden for the speech tasks and shown for
            # T2ST / T2TT by update_input_ui.
            source_language = gr.Dropdown(
                label="Source language",
                choices=TEXT_SOURCE_LANGUAGE_NAMES,
                value="English",
                visible=False,
            )
            target_language = gr.Dropdown(
                label="Target language",
                choices=S2ST_TARGET_LANGUAGE_NAMES,
                value=DEFAULT_TARGET_LANGUAGE,
            )
        with gr.Row() as audio_box:
            # Radio toggles between the mic and file widgets (update_audio_ui).
            audio_source = gr.Radio(
                label="Audio source",
                choices=["file", "microphone"],
                value="file",
            )
            # NOTE(review): gr.Audio(source=...) is Gradio 3.x API; Gradio 4
            # renamed it to `sources=[...]` — confirm the pinned version.
            input_audio_mic = gr.Audio(
                label="Input speech",
                type="filepath",
                source="microphone",
                visible=False,
            )
            input_audio_file = gr.Audio(
                label="Input speech",
                type="filepath",
                source="upload",
                visible=True,
            )
        input_text = gr.Textbox(label="Input text", visible=False)
        with gr.Row():
            btn = gr.Button("Translate")
            # NOTE(review): only the audio inputs are cleared here; input_text
            # is left untouched — confirm whether that is intentional.
            btn_clean = gr.ClearButton([input_audio_mic, input_audio_file])
        # gr.Markdown("## Text Examples")
        with gr.Column():
            output_audio = gr.Audio(
                label="Translated speech",
                autoplay=False,
                streaming=False,
                type="numpy",
            )
            output_text = gr.Textbox(label="Translated text")
    # One example row per task; update_example_ui shows only the active one.
    with gr.Row(visible=True) as s2st_example_row:
        s2st_examples = gr.Examples(
            examples=[
                ["assets/sample_input.mp3", "French"],
                ["assets/sample_input.mp3", "Mandarin Chinese"],
                ["assets/sample_input_2.mp3", "Hindi"],
                ["assets/sample_input_2.mp3", "Spanish"],
            ],
            inputs=[input_audio_file, target_language],
            outputs=[output_audio, output_text],
            fn=process_s2st_example,
        )
    with gr.Row(visible=False) as s2tt_example_row:
        s2tt_examples = gr.Examples(
            examples=[
                ["assets/sample_input.mp3", "French"],
                ["assets/sample_input.mp3", "Mandarin Chinese"],
                ["assets/sample_input_2.mp3", "Hindi"],
                ["assets/sample_input_2.mp3", "Spanish"],
            ],
            inputs=[input_audio_file, target_language],
            outputs=[output_audio, output_text],
            fn=process_s2tt_example,
        )
    with gr.Row(visible=False) as t2st_example_row:
        t2st_examples = gr.Examples(
            examples=[
                ["My favorite animal is the elephant.", "English", "French"],
                ["My favorite animal is the elephant.", "English", "Mandarin Chinese"],
                [
                    "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
                    "English",
                    "Hindi",
                ],
                [
                    "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
                    "English",
                    "Spanish",
                ],
            ],
            inputs=[input_text, source_language, target_language],
            outputs=[output_audio, output_text],
            fn=process_t2st_example,
        )
    with gr.Row(visible=False) as t2tt_example_row:
        t2tt_examples = gr.Examples(
            examples=[
                ["My favorite animal is the elephant.", "English", "French"],
                ["My favorite animal is the elephant.", "English", "Mandarin Chinese"],
                [
                    "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
                    "English",
                    "Hindi",
                ],
                [
                    "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
                    "English",
                    "Spanish",
                ],
            ],
            inputs=[input_text, source_language, target_language],
            outputs=[output_audio, output_text],
            fn=process_t2tt_example,
        )
    with gr.Row(visible=False) as asr_example_row:
        asr_examples = gr.Examples(
            examples=[
                ["assets/sample_input.mp3", "English"],
                ["assets/sample_input_2.mp3", "English"],
            ],
            inputs=[input_audio_file, target_language],
            outputs=[output_audio, output_text],
            fn=process_asr_example,
        )
    # --- Event wiring ---
    # Swap mic/file widgets when the audio source changes.
    audio_source.change(
        fn=update_audio_ui,
        inputs=audio_source,
        outputs=[
            input_audio_mic,
            input_audio_file,
        ],
        queue=False,
        api_name=False,
    )
    # Changing the task re-lays-out inputs, then outputs, then example rows.
    task_name.change(
        fn=update_input_ui,
        inputs=task_name,
        outputs=[
            audio_box,
            input_text,
            source_language,
            target_language,
        ],
        queue=False,
        api_name=False,
    ).then(
        fn=update_output_ui,
        inputs=task_name,
        outputs=[output_audio, output_text],
        queue=False,
        api_name=False,
    ).then(
        fn=update_example_ui,
        inputs=task_name,
        outputs=[
            s2st_example_row,
            s2tt_example_row,
            t2st_example_row,
            t2tt_example_row,
            asr_example_row,
        ],
        queue=False,
        api_name=False,
    )
    # Main action: forward every input straight to the remote Space.
    btn.click(
        fn=api_predict,
        inputs=[
            task_name,
            audio_source,
            input_audio_mic,
            input_audio_file,
            input_text,
            source_language,
            target_language,
        ],
        outputs=[output_audio, output_text],
        api_name="run",
    )
if __name__ == "__main__":
    # queue() enables Gradio's request queue so slow remote calls are serialized
    # instead of timing out.
    demo.queue().launch()