Spaces:
Sleeping
Sleeping
commit
Browse files- translate.py +59 -0
translate.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
mport streamlit as st
|
2 |
+
import torch
|
3 |
+
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
|
4 |
+
|
5 |
+
|
6 |
+
### Getting the Languages supported ####
|
7 |
+
LanguageCovered = "Arabic (ar_AR), Czech (cs_CZ), German (de_DE), English (en_XX), Spanish (es_XX), Estonian (et_EE), Finnish (fi_FI), French (fr_XX), Gujarati (gu_IN), Hindi (hi_IN), Italian (it_IT), Japanese (ja_XX), Kazakh (kk_KZ), Korean (ko_KR), Lithuanian (lt_LT), Latvian (lv_LV), Burmese (my_MM), Nepali (ne_NP), Dutch (nl_XX), Romanian (ro_RO), Russian (ru_RU), Sinhala (si_LK), Turkish (tr_TR), Vietnamese (vi_VN), Chinese (zh_CN), Afrikaans (af_ZA), Azerbaijani (az_AZ), Bengali (bn_IN), Persian (fa_IR), Hebrew (he_IL), Croatian (hr_HR), Indonesian (id_ID), Georgian (ka_GE), Khmer (km_KH), Macedonian (mk_MK), Malayalam (ml_IN), Mongolian (mn_MN), Marathi (mr_IN), Polish (pl_PL), Pashto (ps_AF), Portuguese (pt_XX), Swedish (sv_SE), Swahili (sw_KE), Tamil (ta_IN), Telugu (te_IN), Thai (th_TH), Tagalog (tl_XX), Ukrainian (uk_UA), Urdu (ur_PK), Xhosa (xh_ZA), Galician (gl_ES), Slovene (sl_SI"
|
8 |
+
LanguageCovered = LanguageCovered.split(",")
|
9 |
+
languages_list = [a.strip() for a in LanguageCovered]
|
10 |
+
languages_list = [a.split(" ") for a in languages_list]
|
11 |
+
languages = [a[0] for a in languages_list]
|
12 |
+
codes = [a[1] for a in languages_list]
|
13 |
+
codes = [a.replace('(', '') for a in codes]
|
14 |
+
codes = [a.replace(')', '') for a in codes]
|
15 |
+
lang_dict = dict(zip(languages, codes))
|
16 |
+
|
17 |
+
model_name = "facebook/mbart-large-50-many-to-many-mmt"
|
18 |
+
|
19 |
+
# tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
|
20 |
+
# model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
|
21 |
+
tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
|
22 |
+
model = MBartForConditionalGeneration.from_pretrained(model_name)
|
23 |
+
|
24 |
+
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
|
25 |
+
model = model.to(device)
|
26 |
+
|
27 |
+
|
28 |
+
def translate_text(text, source_lang, target_lang):
|
29 |
+
tokenizer.src_lang = source_lang
|
30 |
+
encoded_text = tokenizer(text, return_tensors="pt").to(device)
|
31 |
+
|
32 |
+
generated_tokens = model.generate(**encoded_text, forced_bos_token_id=tokenizer.lang_code_to_id[target_lang])
|
33 |
+
|
34 |
+
#Decode the output
|
35 |
+
translated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
|
36 |
+
return translated_text
|
37 |
+
|
38 |
+
|
39 |
+
st.markdown("### Language Translator")
|
40 |
+
source_language = ''
|
41 |
+
target_language = ''
|
42 |
+
source = st.sidebar.selectbox('Source Language', languages)
|
43 |
+
if source:
|
44 |
+
source_language = lang_dict.get(source)
|
45 |
+
st.write(source_language)
|
46 |
+
|
47 |
+
target = st.sidebar.selectbox('Target Language', languages)
|
48 |
+
if target:
|
49 |
+
target_language = lang_dict.get(target)
|
50 |
+
st.write(target_language)
|
51 |
+
|
52 |
+
with st.form(key="myForm"):
|
53 |
+
text = st.text_area("Enter your text")
|
54 |
+
submit = st.form_submit_button("Submit", type='primary')
|
55 |
+
|
56 |
+
if submit and text and source_language and target_language:
|
57 |
+
with st.spinner(f"{source} to {target} translating"):
|
58 |
+
translation = translate_text(text, source_language, target_language)
|
59 |
+
st.write(translation)
|