Spaces:
Sleeping
Sleeping
commit
Browse files- translate.py +59 -0
translate.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import streamlit as st
import torch
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
### Getting the Languages supported ####
# Display names and mBART-50 language codes, as "Name (code)" pairs.
# NOTE(review): list copied from the mBART-50 model card; codes must match
# tokenizer.lang_code_to_id keys exactly.
LanguageCovered = "Arabic (ar_AR), Czech (cs_CZ), German (de_DE), English (en_XX), Spanish (es_XX), Estonian (et_EE), Finnish (fi_FI), French (fr_XX), Gujarati (gu_IN), Hindi (hi_IN), Italian (it_IT), Japanese (ja_XX), Kazakh (kk_KZ), Korean (ko_KR), Lithuanian (lt_LT), Latvian (lv_LV), Burmese (my_MM), Nepali (ne_NP), Dutch (nl_XX), Romanian (ro_RO), Russian (ru_RU), Sinhala (si_LK), Turkish (tr_TR), Vietnamese (vi_VN), Chinese (zh_CN), Afrikaans (af_ZA), Azerbaijani (az_AZ), Bengali (bn_IN), Persian (fa_IR), Hebrew (he_IL), Croatian (hr_HR), Indonesian (id_ID), Georgian (ka_GE), Khmer (km_KH), Macedonian (mk_MK), Malayalam (ml_IN), Mongolian (mn_MN), Marathi (mr_IN), Polish (pl_PL), Pashto (ps_AF), Portuguese (pt_XX), Swedish (sv_SE), Swahili (sw_KE), Tamil (ta_IN), Telugu (te_IN), Thai (th_TH), Tagalog (tl_XX), Ukrainian (uk_UA), Urdu (ur_PK), Xhosa (xh_ZA), Galician (gl_ES), Slovene (sl_SI)"

# Parse each "Name (code)" entry into parallel name/code lists, then build the
# name -> code lookup used by the sidebar selectboxes.
languages = []
codes = []
for entry in LanguageCovered.split(","):
    name, _, code = entry.strip().partition(" ")
    languages.append(name)
    codes.append(code.strip("()"))
lang_dict = dict(zip(languages, codes))
# mBART-50 many-to-many multilingual machine-translation checkpoint.
model_name = "facebook/mbart-large-50-many-to-many-mmt"

# Load tokenizer and model once at startup (fetched/cached by transformers).
tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
model = MBartForConditionalGeneration.from_pretrained(model_name)

# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
def translate_text(text, source_lang, target_lang):
    """Translate `text` from `source_lang` to `target_lang` with mBART-50.

    Args:
        text: Source-language text to translate.
        source_lang: mBART language code of the input, e.g. "en_XX".
        target_lang: mBART language code of the output, e.g. "fr_XX".

    Returns:
        The translated text as a single string.
    """
    # mBART takes the source language via the tokenizer, not the model call.
    tokenizer.src_lang = source_lang
    encoded_text = tokenizer(text, return_tensors="pt").to(device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    # forced_bos_token_id makes generation start in the target language.
    with torch.no_grad():
        generated_tokens = model.generate(
            **encoded_text,
            forced_bos_token_id=tokenizer.lang_code_to_id[target_lang],
        )

    # Decode the single generated sequence, dropping special tokens.
    translated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
    return translated_text
# --- Streamlit UI ---------------------------------------------------------
st.markdown("### Language Translator")

# Resolved mBART codes for the chosen languages ('' until a pick resolves).
source_language, target_language = '', ''

# Sidebar pickers for the two languages.
source = st.sidebar.selectbox('Source Language', languages)
target = st.sidebar.selectbox('Target Language', languages)

# Echo each resolved language code in the main area.
if source:
    source_language = lang_dict.get(source)
    st.write(source_language)
if target:
    target_language = lang_dict.get(target)
    st.write(target_language)

# Input form: a text area plus a submit button.
with st.form(key="myForm"):
    user_text = st.text_area("Enter your text")
    submitted = st.form_submit_button("Submit", type='primary')

# Translate only when the form was submitted with text and both codes resolved.
if submitted and user_text and source_language and target_language:
    with st.spinner(f"{source} to {target} translating"):
        st.write(translate_text(user_text, source_language, target_language))