# translator / app.py
# breadlicker45's picture
# Update app.py
# b72a9d6 verified
import gradio as gr
import torch
from transformers import MBart50Tokenizer, MBartForConditionalGeneration, MBartTokenizer
# Load the model and its tokenizer once, at import time, so that every request
# reuses the same objects.
#
# The checkpoint is an mBART-50 model and has to be paired with the mBART-50
# tokenizer. MBartTokenizer implements the original mBART input layout
# (language code appended after the text), whereas mBART-50 expects the
# language code as a prefix; mixing the two feeds the encoder malformed input.
_CHECKPOINT = "facebook/mbart-large-50-many-to-many-mmt"
model = MBartForConditionalGeneration.from_pretrained(_CHECKPOINT)
tokenizer = MBart50Tokenizer.from_pretrained(_CHECKPOINT)
# Display name (shown in the UI dropdowns) -> language-code token handed to the
# tokenizer as source language and to the model as forced first output token.
language_codes = {
    "Arabic": "ar_AR",
    "Czech": "cs_CZ",
    "German": "de_DE",
    "English": "en_XX",
    "Spanish": "es_XX",
    "Estonian": "et_EE",
    "Finnish": "fi_FI",
    "French": "fr_XX",
    "Gujarati": "gu_IN",
    "Hindi": "hi_IN",
    "Italian": "it_IT",
    "Japanese": "ja_XX",
    "Kazakh": "kk_KZ",
    "Korean": "ko_KR",
    "Lithuanian": "lt_LT",
    "Latvian": "lv_LV",
    "Burmese": "my_MM",
    "Nepali": "ne_NP",
    "Dutch": "nl_XX",
    "Romanian": "ro_RO",
    "Russian": "ru_RU",
    "Sinhala": "si_LK",
    "Turkish": "tr_TR",
    "Vietnamese": "vi_VN",
    "Chinese": "zh_CN",
}
def translate(text, src_lang, tgt_lang):
    """Translate ``text`` from ``src_lang`` to ``tgt_lang`` with the loaded model.

    Args:
        text: The text to translate.
        src_lang: Display name of the source language (a key of
            ``language_codes``).
        tgt_lang: Display name of the target language (a key of
            ``language_codes``).

    Returns:
        The translated string. This function never raises for bad input or
        model failures; instead it returns:

        * a prompt asking for input when ``text`` is missing, not a string,
          or blank;
        * ``text`` unchanged when source and target language are equal;
        * a message starting with ``"Translation error: "`` when a language
          name is unknown or tokenization/generation fails.
    """
    # Treat None / non-string input like blank input instead of letting
    # ``.strip()`` fail with an opaque attribute error.
    if not isinstance(text, str) or not text.strip():
        return "Please enter some text to translate."
    if src_lang == tgt_lang:
        return text

    # Resolve both language names up front so an unknown name produces a
    # readable message rather than a bare KeyError repr.
    try:
        src_code = language_codes[src_lang]
        tgt_code = language_codes[tgt_lang]
    except (KeyError, TypeError) as err:
        return f"Translation error: unsupported language {err}"

    try:
        # NOTE(review): this mutates the shared module-level tokenizer;
        # presumably requests are handled one at a time -- confirm.
        tokenizer.src_lang = src_code

        # Input beyond 512 tokens is truncated, not rejected.
        encoded = tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )

        # Inference only: skip gradient tracking.
        with torch.no_grad():
            generated_tokens = model.generate(
                **encoded,
                # Force the first generated token to be the target language code.
                forced_bos_token_id=tokenizer.lang_code_to_id[tgt_code],
                max_length=512,
                num_beams=5,
                length_penalty=1.0,
            )

        # A single input was encoded, so take the single decoded sequence.
        return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    except Exception as e:
        # UI boundary: report the failure in the output box instead of crashing.
        return f"Translation error: {e}"
# Build the Gradio UI: a text box plus source/target language selectors wired
# to ``translate``, then start serving it.
def _language_dropdown(label, default):
    """Return a dropdown offering every language name in alphabetical order."""
    return gr.Dropdown(
        choices=sorted(language_codes.keys()),
        label=label,
        value=default,
    )


_input_components = [
    gr.Textbox(label="Input Text", placeholder="Enter text to translate..."),
    _language_dropdown("Source Language", "English"),
    _language_dropdown("Target Language", "Spanish"),
]

# Each example row is (input text, source language, target language).
_example_rows = [
    ["Hello, how are you?", "English", "Spanish"],
    ["Bonjour, comment allez-vous?", "French", "English"],
]

demo = gr.Interface(
    fn=translate,
    inputs=_input_components,
    outputs=gr.Textbox(label="Translated Text"),
    title="Multilingual Translation with MBart",
    description="Translate text between multiple languages using the MBart model.",
    examples=_example_rows,
)
demo.launch()