# Spaces: Running
import gradio as gr
import torch
from transformers import (
    MBart50Tokenizer,
    MBartForConditionalGeneration,
    MBartTokenizer,
)
# Load the model and tokenizer once at import time so every request reuses them.
#
# The checkpoint is an mBART-50 model, so it is paired with the mBART-50
# tokenizer. That tokenizer formats the source as "[lang_code] X [eos]",
# whereas MBartTokenizer appends the language code after the text
# ("X [eos] [lang_code]"), which is not the layout this checkpoint expects.
MODEL_NAME = "facebook/mbart-large-50-many-to-many-mmt"

model = MBartForConditionalGeneration.from_pretrained(MODEL_NAME)
tokenizer = MBart50Tokenizer.from_pretrained(MODEL_NAME)
# Maps each language name offered in the UI to the language code that is
# handed to the tokenizer (source side) and used to pick the forced BOS token
# (target side).
language_codes = dict(
    Arabic="ar_AR",
    Czech="cs_CZ",
    German="de_DE",
    English="en_XX",
    Spanish="es_XX",
    Estonian="et_EE",
    Finnish="fi_FI",
    French="fr_XX",
    Gujarati="gu_IN",
    Hindi="hi_IN",
    Italian="it_IT",
    Japanese="ja_XX",
    Kazakh="kk_KZ",
    Korean="ko_KR",
    Lithuanian="lt_LT",
    Latvian="lv_LV",
    Burmese="my_MM",
    Nepali="ne_NP",
    Dutch="nl_XX",
    Romanian="ro_RO",
    Russian="ru_RU",
    Sinhala="si_LK",
    Turkish="tr_TR",
    Vietnamese="vi_VN",
    Chinese="zh_CN",
)
def translate(text, src_lang, tgt_lang):
    """Translate ``text`` from ``src_lang`` into ``tgt_lang``.

    Args:
        text: The text to translate. ``None``, non-string values and
            whitespace-only strings are all treated as empty input.
        src_lang: Source language name; must be a key of ``language_codes``.
        tgt_lang: Target language name; must be a key of ``language_codes``.

    Returns:
        The translated string. When the input is empty, a prompt asking for
        text is returned; when both languages are equal, ``text`` is returned
        unchanged; when something goes wrong, a string starting with
        ``"Translation error:"`` is returned instead of raising, so the UI
        always has something to display.
    """
    # Guard against a cleared/None textbox value as well as blank input.
    if not isinstance(text, str) or not text.strip():
        return "Please enter some text to translate."

    # Nothing to do when source and target are the same language.
    if src_lang == tgt_lang:
        return text

    # Resolve both language names up front so a missing or unknown selection
    # produces a clear message rather than a bare KeyError text.
    try:
        src_code = language_codes[src_lang]
        tgt_code = language_codes[tgt_lang]
    except (KeyError, TypeError):
        return (
            "Translation error: unsupported language selection "
            f"({src_lang!r} -> {tgt_lang!r})."
        )

    try:
        # The tokenizer reads the source language from this attribute.
        tokenizer.src_lang = src_code

        # Inputs longer than 512 tokens are truncated.
        encoded = tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )

        # Inference only: no gradients are needed.
        with torch.no_grad():
            generated_tokens = model.generate(
                **encoded,
                # Forcing the first generated token selects the output language.
                forced_bos_token_id=tokenizer.lang_code_to_id[tgt_code],
                max_length=512,
                num_beams=5,
                length_penalty=1.0,
            )

        return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        return f"Translation error: {e}"
def _language_dropdown(label, default):
    """Build a dropdown listing every language name in alphabetical order."""
    return gr.Dropdown(
        choices=sorted(language_codes),
        label=label,
        value=default,
    )


# Input widgets, in the same order as the parameters of ``translate``.
_input_components = [
    gr.Textbox(label="Input Text", placeholder="Enter text to translate..."),
    _language_dropdown("Source Language", "English"),
    _language_dropdown("Target Language", "Spanish"),
]

# Sample rows: (text, source language, target language).
_example_rows = [
    ["Hello, how are you?", "English", "Spanish"],
    ["Bonjour, comment allez-vous?", "French", "English"],
]

# Assemble the Gradio interface around ``translate`` and start serving it.
demo = gr.Interface(
    fn=translate,
    inputs=_input_components,
    outputs=gr.Textbox(label="Translated Text"),
    title="Multilingual Translation with MBart",
    description="Translate text between multiple languages using the MBart model.",
    examples=_example_rows,
)
demo.launch()