|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
import gradio as gr |
|
import numpy as np |
|
|
|
|
|
|
|
from gradio_client import Client |
|
|
|
# Client for the hosted Space; it is constructed at import time and
# api_predict() sends every request through it.
client = Client("https://facebook-seamless-m4t.hf.space/")
|
|
|
# Markdown header passed to gr.Markdown at the top of the page.
# English gloss of the Vietnamese text: "The app can convert speech or text
# into speech or text in another language. SM4T currently supports 94
# different languages."
DESCRIPTION = """

# SM4T

Ứng dụng có thể chuyển đổi giọng nói hoặc chữ viết sang giọng nói hoặc chữ viết của một ngôn ngữ khác.
\nHiện tại SM4T đã hỗ trợ 94 ngôn ngữ khác nhau.

"""
|
|
|
# Dropdown labels have the form "<ABBREV> (<expansion>)". The UI callbacks
# recover the abbreviation with ``label.split()[0]``, so the abbreviation
# must stay the first whitespace-separated token of every label.
TASK_NAMES = [
    f"{abbreviation} ({expansion})"
    for abbreviation, expansion in (
        ("S2ST", "Speech to Speech translation"),
        ("S2TT", "Speech to Text translation"),
        ("T2ST", "Text to Speech translation"),
        ("T2TT", "Text to Text translation"),
        ("ASR", "Automatic Speech Recognition"),
    )
]
|
|
|
|
|
# Three-letter language code -> English display name shown in the dropdowns.
language_code_to_name = {
    "afr": "Afrikaans",
    "amh": "Amharic",
    "arb": "Modern Standard Arabic",
    "ary": "Moroccan Arabic",
    "arz": "Egyptian Arabic",
    "asm": "Assamese",
    "ast": "Asturian",
    "azj": "North Azerbaijani",
    "bel": "Belarusian",
    "ben": "Bengali",
    "bos": "Bosnian",
    "bul": "Bulgarian",
    "cat": "Catalan",
    "ceb": "Cebuano",
    "ces": "Czech",
    "ckb": "Central Kurdish",
    "cmn": "Mandarin Chinese",
    "cym": "Welsh",
    "dan": "Danish",
    "deu": "German",
    "ell": "Greek",
    "eng": "English",
    "est": "Estonian",
    "eus": "Basque",
    "fin": "Finnish",
    "fra": "French",
    "gaz": "West Central Oromo",
    "gle": "Irish",
    "glg": "Galician",
    "guj": "Gujarati",
    "heb": "Hebrew",
    "hin": "Hindi",
    "hrv": "Croatian",
    "hun": "Hungarian",
    "hye": "Armenian",
    "ibo": "Igbo",
    "ind": "Indonesian",
    "isl": "Icelandic",
    "ita": "Italian",
    "jav": "Javanese",
    "jpn": "Japanese",
    "kam": "Kamba",
    "kan": "Kannada",
    "kat": "Georgian",
    "kaz": "Kazakh",
    "kea": "Kabuverdianu",
    "khk": "Halh Mongolian",
    "khm": "Khmer",
    "kir": "Kyrgyz",
    "kor": "Korean",
    "lao": "Lao",
    "lit": "Lithuanian",
    "ltz": "Luxembourgish",
    "lug": "Ganda",
    "luo": "Luo",
    "lvs": "Standard Latvian",
    "mai": "Maithili",
    "mal": "Malayalam",
    "mar": "Marathi",
    "mkd": "Macedonian",
    "mlt": "Maltese",
    "mni": "Meitei",
    "mya": "Burmese",
    "nld": "Dutch",
    "nno": "Norwegian Nynorsk",
    "nob": "Norwegian Bokm\u00e5l",
    "npi": "Nepali",
    "nya": "Nyanja",
    "oci": "Occitan",
    "ory": "Odia",
    "pan": "Punjabi",
    "pbt": "Southern Pashto",
    "pes": "Western Persian",
    "pol": "Polish",
    "por": "Portuguese",
    "ron": "Romanian",
    "rus": "Russian",
    "slk": "Slovak",
    "slv": "Slovenian",
    "sna": "Shona",
    "snd": "Sindhi",
    "som": "Somali",
    "spa": "Spanish",
    "srp": "Serbian",
    "swe": "Swedish",
    "swh": "Swahili",
    "tam": "Tamil",
    "tel": "Telugu",
    "tgk": "Tajik",
    "tgl": "Tagalog",
    "tha": "Thai",
    "tur": "Turkish",
    "ukr": "Ukrainian",
    "urd": "Urdu",
    "uzn": "Northern Uzbek",
    "vie": "Vietnamese",
    "xho": "Xhosa",
    "yor": "Yoruba",
    "yue": "Cantonese",
    "zlm": "Colloquial Malay",
    "zsm": "Standard Malay",
    "zul": "Zulu",
}

# Reverse lookup: display name -> language code.
LANGUAGE_NAME_TO_CODE = {
    name: code for code, name in language_code_to_name.items()
}
|
|
|
|
|
|
|
# Codes offered as text source languages (kept in alphabetical order).
text_source_language_codes = [
    "afr", "amh", "arb", "ary", "arz", "asm", "azj", "bel", "ben", "bos",
    "bul", "cat", "ceb", "ces", "ckb", "cmn", "cym", "dan", "deu", "ell",
    "eng", "est", "eus", "fin", "fra", "gaz", "gle", "glg", "guj", "heb",
    "hin", "hrv", "hun", "hye", "ibo", "ind", "isl", "ita", "jav", "jpn",
    "kan", "kat", "kaz", "khk", "khm", "kir", "kor", "lao", "lit", "lug",
    "luo", "lvs", "mai", "mal", "mar", "mkd", "mlt", "mni", "mya", "nld",
    "nno", "nob", "npi", "nya", "ory", "pan", "pbt", "pes", "pol", "por",
    "ron", "rus", "slk", "slv", "sna", "snd", "som", "spa", "srp", "swe",
    "swh", "tam", "tel", "tgk", "tgl", "tha", "tur", "ukr", "urd", "uzn",
    "vie", "yor", "yue", "zsm", "zul",
]
|
# Display names for the text source languages, sorted alphabetically.
TEXT_SOURCE_LANGUAGE_NAMES = sorted(
    language_code_to_name[code] for code in text_source_language_codes
)
|
|
|
|
|
|
|
# Codes offered as speech-output target languages ("eng" first, the rest
# in alphabetical order).
s2st_target_language_codes = [
    "eng",
    "arb", "ben", "cat", "ces", "cmn", "cym", "dan", "deu", "est", "fin",
    "fra", "hin", "ind", "ita", "jpn", "kor", "mlt", "nld", "pes", "pol",
    "por", "ron", "rus", "slk", "spa", "swe", "swh", "tel", "tgl", "tha",
    "tur", "ukr", "urd", "uzn", "vie",
]
|
# Display names for the speech-output targets, sorted alphabetically.
S2ST_TARGET_LANGUAGE_NAMES = sorted(
    language_code_to_name[code] for code in s2st_target_language_codes
)

# Text-output tasks reuse the text source language list (same list object).
S2TT_TARGET_LANGUAGE_NAMES = T2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES
|
|
|
|
|
# Sample audio clips; the same paths appear in the example rows of the UI.
filenames = [
    "assets/sample_input.mp3",
    "assets/sample_input_2.mp3",
]

# NOTE(review): the two audio constants are not referenced in the visible
# code; presumably Hz and seconds respectively - confirm before relying on them.
AUDIO_SAMPLE_RATE = 16_000.0
MAX_INPUT_AUDIO_LENGTH = 60

# Target language preselected in the dropdown and restored on task change.
DEFAULT_TARGET_LANGUAGE = "French"
|
|
|
|
|
|
|
def api_predict(
    task_name: str,
    audio_source: str,
    input_audio_mic: str | None,
    input_audio_file: str | None,
    input_text: str | None,
    source_language: str | None,
    target_language: str,
):
    """Forward one request to the remote ``/run`` endpoint.

    All seven arguments are passed positionally, in declaration order, to
    ``client.predict``; nothing is validated or transformed locally.

    Returns:
        A 2-tuple ``(audio_out, text_out)`` unpacked from the remote result.
    """
    remote_args = (
        task_name,
        audio_source,
        input_audio_mic,
        input_audio_file,
        input_text,
        source_language,
        target_language,
    )
    # Unpacking enforces that the endpoint answered with exactly two values.
    speech_result, text_result = client.predict(*remote_args, api_name="/run")
    return speech_result, text_result
|
|
|
|
|
|
|
|
|
|
|
def process_s2st_example(
    input_audio_file: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    """Run the ``S2ST`` task on an example audio file.

    NOTE(review): the result is whatever ``api_predict`` returns, unchanged;
    confirm it really matches the annotated return type.
    """
    result = api_predict(
        task_name="S2ST",
        # The example supplies a file and a target language only.
        audio_source="file",
        input_audio_file=input_audio_file,
        target_language=target_language,
        # Inputs this example does not provide.
        input_audio_mic=None,
        input_text=None,
        source_language=None,
    )
    return result
|
|
|
|
|
def process_s2tt_example(
    input_audio_file: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    """Run the ``S2TT`` task on an example audio file.

    NOTE(review): the result is whatever ``api_predict`` returns, unchanged;
    confirm it really matches the annotated return type.
    """
    result = api_predict(
        task_name="S2TT",
        # The example supplies a file and a target language only.
        audio_source="file",
        input_audio_file=input_audio_file,
        target_language=target_language,
        # Inputs this example does not provide.
        input_audio_mic=None,
        input_text=None,
        source_language=None,
    )
    return result
|
|
|
|
|
def process_t2st_example(
    input_text: str, source_language: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    """Run the ``T2ST`` task on an example sentence.

    NOTE(review): the result is whatever ``api_predict`` returns, unchanged;
    confirm it really matches the annotated return type.
    """
    result = api_predict(
        task_name="T2ST",
        # The example supplies text plus both languages.
        input_text=input_text,
        source_language=source_language,
        target_language=target_language,
        # No audio input: empty source selector and no audio paths.
        audio_source="",
        input_audio_mic=None,
        input_audio_file=None,
    )
    return result
|
|
|
|
|
def process_t2tt_example(
    input_text: str, source_language: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    """Run the ``T2TT`` task on an example sentence.

    NOTE(review): the result is whatever ``api_predict`` returns, unchanged;
    confirm it really matches the annotated return type.
    """
    result = api_predict(
        task_name="T2TT",
        # The example supplies text plus both languages.
        input_text=input_text,
        source_language=source_language,
        target_language=target_language,
        # No audio input: empty source selector and no audio paths.
        audio_source="",
        input_audio_mic=None,
        input_audio_file=None,
    )
    return result
|
|
|
|
|
def process_asr_example(
    input_audio_file: str, target_language: str
) -> tuple[tuple[int, np.ndarray] | None, str]:
    """Run the ``ASR`` task on an example audio file.

    NOTE(review): the result is whatever ``api_predict`` returns, unchanged;
    confirm it really matches the annotated return type.
    """
    result = api_predict(
        task_name="ASR",
        # The example supplies a file and a target language only.
        audio_source="file",
        input_audio_file=input_audio_file,
        target_language=target_language,
        # Inputs this example does not provide.
        input_audio_mic=None,
        input_text=None,
        source_language=None,
    )
    return result
|
|
|
|
|
def update_audio_ui(audio_source: str) -> tuple[dict, dict]:
    """Switch between the microphone and file-upload audio inputs.

    Args:
        audio_source: Selected radio value; ``"microphone"`` shows the
            microphone input, any other value shows the file input.

    Returns:
        Updates for ``(microphone input, file input)``. Exactly one is
        visible, and both have their value reset to ``None``.
    """
    use_microphone = audio_source == "microphone"
    mic_update = gr.update(visible=use_microphone, value=None)
    file_update = gr.update(visible=not use_microphone, value=None)
    return mic_update, file_update
|
|
|
|
|
def update_input_ui(task_name: str) -> tuple[dict, dict, dict, dict]:
    """Reconfigure the input widgets for the selected task.

    Args:
        task_name: Task label; only its first whitespace-separated token
            (e.g. ``"S2ST"``) is used.

    Returns:
        Updates for ``(audio box, input text, source language, target
        language)``. Speech-input tasks show the audio box; text-input tasks
        show the text box and source-language dropdown. The target dropdown
        is always visible, gets the task's language choices, and is reset to
        ``DEFAULT_TARGET_LANGUAGE``.

    Raises:
        ValueError: If the task abbreviation is not recognised.
    """
    task = task_name.split()[0]
    if task not in ("S2ST", "S2TT", "T2ST", "T2TT", "ASR"):
        raise ValueError(f"Unknown task: {task}")

    takes_text = task in ("T2ST", "T2TT")

    # Speech-output tasks use the speech target list; the others use the
    # list named after their own output kind.
    if task in ("S2ST", "T2ST"):
        target_choices = S2ST_TARGET_LANGUAGE_NAMES
    elif task == "T2TT":
        target_choices = T2TT_TARGET_LANGUAGE_NAMES
    else:  # "S2TT" and "ASR"
        target_choices = S2TT_TARGET_LANGUAGE_NAMES

    return (
        gr.update(visible=not takes_text),
        gr.update(visible=takes_text),
        gr.update(visible=takes_text),
        gr.update(
            visible=True,
            choices=target_choices,
            value=DEFAULT_TARGET_LANGUAGE,
        ),
    )
|
|
|
|
|
def update_output_ui(task_name: str) -> tuple[dict, dict]:
    """Reset the output widgets for the selected task.

    Args:
        task_name: Task label; only its first whitespace-separated token
            is used.

    Returns:
        Updates for ``(output audio, output text)``. Both values are
        cleared; the audio output is visible only for ``S2ST`` and ``T2ST``.

    Raises:
        ValueError: If the task abbreviation is not recognised.
    """
    task = task_name.split()[0]
    if task not in ("S2ST", "T2ST", "S2TT", "T2TT", "ASR"):
        raise ValueError(f"Unknown task: {task}")

    produces_speech = task in ("S2ST", "T2ST")
    audio_update = gr.update(visible=produces_speech, value=None)
    text_update = gr.update(value=None)
    return audio_update, text_update
|
|
|
|
|
def update_example_ui(task_name: str) -> tuple[dict, dict, dict, dict, dict]:
    """Show only the example row that belongs to the selected task.

    Args:
        task_name: Task label; only its first whitespace-separated token
            is used.

    Returns:
        Five visibility updates, ordered ``S2ST, S2TT, T2ST, T2TT, ASR``.
        An unrecognised task hides every row rather than raising.
    """
    selected = task_name.split()[0]
    row_tasks = ("S2ST", "S2TT", "T2ST", "T2TT", "ASR")
    return tuple(gr.update(visible=selected == row_task) for row_task in row_tasks)
|
|
|
|
|
# Custom stylesheet handed to gr.Blocks(css=...); centres the page heading.
# NOTE(review): the '#'-prefixed lines look like an attempt to comment out
# the .contain rule, but '#' does not start a comment in CSS (comments are
# /* ... */) - confirm the rule is ignored as intended.
css = """
h1 {
text-align: center;
}

#.contain {
# max-width: 730px;
# margin: auto;
# padding-top: 1.5rem;
#}
"""
|
|
|
# UI definition: task and language selectors, audio/text inputs, outputs,
# one example row per task, and the event wiring between them.
# NOTE(review): SOURCE carried no indentation for this block; the nesting
# below is reconstructed - confirm it against the intended layout.
with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Group():
        # Initial state corresponds to the first task (S2ST): audio input
        # visible, text input and source language hidden.
        task_name = gr.Dropdown(
            label="Task",
            choices=TASK_NAMES,
            value=TASK_NAMES[0],
        )
        with gr.Row():
            source_language = gr.Dropdown(
                label="Source language",
                choices=TEXT_SOURCE_LANGUAGE_NAMES,
                value="English",
                visible=False,
            )
            target_language = gr.Dropdown(
                label="Target language",
                choices=S2ST_TARGET_LANGUAGE_NAMES,
                value=DEFAULT_TARGET_LANGUAGE,
            )
        with gr.Row() as audio_box:
            audio_source = gr.Radio(
                label="Audio source",
                choices=["file", "microphone"],
                value="file",
            )
            # Two audio widgets; update_audio_ui shows exactly one of them.
            input_audio_mic = gr.Audio(
                label="Input speech",
                type="filepath",
                source="microphone",
                visible=False,
            )
            input_audio_file = gr.Audio(
                label="Input speech",
                type="filepath",
                source="upload",
                visible=True,
            )
        input_text = gr.Textbox(label="Input text", visible=False)
        with gr.Row():
            btn = gr.Button("Translate")
            # Clears only the two audio inputs.
            btn_clean = gr.ClearButton([input_audio_mic, input_audio_file])

        with gr.Column():
            output_audio = gr.Audio(
                label="Translated speech",
                autoplay=False,
                streaming=False,
                type="numpy",
            )
            output_text = gr.Textbox(label="Translated text")

    # One example row per task; update_example_ui toggles their visibility.
    with gr.Row(visible=True) as s2st_example_row:
        s2st_examples = gr.Examples(
            examples=[
                ["assets/sample_input.mp3", "French"],
                ["assets/sample_input.mp3", "Mandarin Chinese"],
                ["assets/sample_input_2.mp3", "Hindi"],
                ["assets/sample_input_2.mp3", "Spanish"],
            ],
            inputs=[input_audio_file, target_language],
            outputs=[output_audio, output_text],
            fn=process_s2st_example,
        )
    with gr.Row(visible=False) as s2tt_example_row:
        s2tt_examples = gr.Examples(
            examples=[
                ["assets/sample_input.mp3", "French"],
                ["assets/sample_input.mp3", "Mandarin Chinese"],
                ["assets/sample_input_2.mp3", "Hindi"],
                ["assets/sample_input_2.mp3", "Spanish"],
            ],
            inputs=[input_audio_file, target_language],
            outputs=[output_audio, output_text],
            fn=process_s2tt_example,
        )
    with gr.Row(visible=False) as t2st_example_row:
        t2st_examples = gr.Examples(
            examples=[
                ["My favorite animal is the elephant.", "English", "French"],
                ["My favorite animal is the elephant.", "English", "Mandarin Chinese"],
                [
                    "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
                    "English",
                    "Hindi",
                ],
                [
                    "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
                    "English",
                    "Spanish",
                ],
            ],
            inputs=[input_text, source_language, target_language],
            outputs=[output_audio, output_text],
            fn=process_t2st_example,
        )
    with gr.Row(visible=False) as t2tt_example_row:
        t2tt_examples = gr.Examples(
            examples=[
                ["My favorite animal is the elephant.", "English", "French"],
                ["My favorite animal is the elephant.", "English", "Mandarin Chinese"],
                [
                    "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
                    "English",
                    "Hindi",
                ],
                [
                    "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
                    "English",
                    "Spanish",
                ],
            ],
            inputs=[input_text, source_language, target_language],
            outputs=[output_audio, output_text],
            fn=process_t2tt_example,
        )
    with gr.Row(visible=False) as asr_example_row:
        asr_examples = gr.Examples(
            examples=[
                ["assets/sample_input.mp3", "English"],
                ["assets/sample_input_2.mp3", "English"],
            ],
            inputs=[input_audio_file, target_language],
            outputs=[output_audio, output_text],
            fn=process_asr_example,
        )

    # Switching the audio source swaps which audio widget is visible.
    audio_source.change(
        fn=update_audio_ui,
        inputs=audio_source,
        outputs=[
            input_audio_mic,
            input_audio_file,
        ],
        queue=False,
        api_name=False,
    )
    # Changing the task updates the inputs, then the outputs, then the
    # example rows, as three chained steps.
    task_name.change(
        fn=update_input_ui,
        inputs=task_name,
        outputs=[
            audio_box,
            input_text,
            source_language,
            target_language,
        ],
        queue=False,
        api_name=False,
    ).then(
        fn=update_output_ui,
        inputs=task_name,
        outputs=[output_audio, output_text],
        queue=False,
        api_name=False,
    ).then(
        fn=update_example_ui,
        inputs=task_name,
        outputs=[
            s2st_example_row,
            s2tt_example_row,
            t2st_example_row,
            t2tt_example_row,
            asr_example_row,
        ],
        queue=False,
        api_name=False,
    )

    # The Translate button forwards the current widget values to
    # api_predict; the handler is registered under the API name "run".
    btn.click(
        fn=api_predict,
        inputs=[
            task_name,
            audio_source,
            input_audio_mic,
            input_audio_file,
            input_text,
            source_language,
            target_language,
        ],
        outputs=[output_audio, output_text],
        api_name="run",
    )
|
|
|
# Script entry point: enable request queuing, then start the app.
if __name__ == "__main__":
    queued_demo = demo.queue()
    queued_demo.launch()
|
|