import gradio as gr
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import torch

# Load the model and tokenizer
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
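
# Optionally run inference on a GPU when one is available; the app falls back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)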

# Map display names to mBART-50 language codes (a subset of the 50 languages the model supports)
language_codes = {
    "Arabic": "ar_AR",
    "Czech": "cs_CZ",
    "German": "de_DE",
    "English": "en_XX",
    "Spanish": "es_XX",
    "Estonian": "et_EE",
    "Finnish": "fi_FI",
    "French": "fr_XX",
    "Gujarati": "gu_IN",
    "Hindi": "hi_IN",
    "Italian": "it_IT",
    "Japanese": "ja_XX",
    "Kazakh": "kk_KZ",
    "Korean": "ko_KR",
    "Lithuanian": "lt_LT",
    "Latvian": "lv_LV",
    "Burmese": "my_MM",
    "Nepali": "ne_NP",
    "Dutch": "nl_XX",
    "Romanian": "ro_RO",
    "Russian": "ru_RU",
    "Sinhala": "si_LK",
    "Turkish": "tr_TR",
    "Vietnamese": "vi_VN",
    "Chinese": "zh_CN",
}

def translate(text, src_lang, tgt_lang):
    try:
        if not text.strip():
            return "Please enter some text to translate."
        
        if src_lang == tgt_lang:
            return text
        
        # Set the source language
        tokenizer.src_lang = language_codes[src_lang]
        
        # Tokenize the input text
        encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
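        # Move the tokenized inputs onto the same device as the model.
        encoded = encoded.to(device)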
        
        # Generate translation
        with torch.no_grad():
            generated_tokens = model.generate(
                **encoded,
                forced_bos_token_id=tokenizer.convert_tokens_to_ids(language_codes[tgt_lang]),
                max_length=512,
                num_beams=5,
                length_penalty=1.0
            )
        
        # Decode the generated tokens
        translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
        
        return translation
    except Exception as e:
        return f"Translation error: {str(e)}"

# Create the Gradio interface
demo = gr.Interface(
    fn=translate,
    inputs=[
        gr.Textbox(label="Input Text", placeholder="Enter text to translate..."),
        gr.Dropdown(choices=sorted(language_codes.keys()), label="Source Language", value="English"),
        gr.Dropdown(choices=sorted(language_codes.keys()), label="Target Language", value="Spanish"),
    ],
    outputs=gr.Textbox(label="Translated Text"),
    title="Multilingual Translation with mBART-50",
    description="Translate text between 25 languages using the mBART-50 many-to-many model.",
    examples=[
        ["Hello, how are you?", "English", "Spanish"],
        ["Bonjour, comment allez-vous?", "French", "English"],
    ]
)

demo.launch()