Temuzin64 commited on
Commit
d59893d
·
verified ·
1 Parent(s): c42f54c
Files changed (1) hide show
  1. translate.py +59 -0
translate.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ mport streamlit as st
2
+ import torch
3
+ from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
4
+
5
+
6
+ ### Getting the Languages supported ####
7
+ LanguageCovered = "Arabic (ar_AR), Czech (cs_CZ), German (de_DE), English (en_XX), Spanish (es_XX), Estonian (et_EE), Finnish (fi_FI), French (fr_XX), Gujarati (gu_IN), Hindi (hi_IN), Italian (it_IT), Japanese (ja_XX), Kazakh (kk_KZ), Korean (ko_KR), Lithuanian (lt_LT), Latvian (lv_LV), Burmese (my_MM), Nepali (ne_NP), Dutch (nl_XX), Romanian (ro_RO), Russian (ru_RU), Sinhala (si_LK), Turkish (tr_TR), Vietnamese (vi_VN), Chinese (zh_CN), Afrikaans (af_ZA), Azerbaijani (az_AZ), Bengali (bn_IN), Persian (fa_IR), Hebrew (he_IL), Croatian (hr_HR), Indonesian (id_ID), Georgian (ka_GE), Khmer (km_KH), Macedonian (mk_MK), Malayalam (ml_IN), Mongolian (mn_MN), Marathi (mr_IN), Polish (pl_PL), Pashto (ps_AF), Portuguese (pt_XX), Swedish (sv_SE), Swahili (sw_KE), Tamil (ta_IN), Telugu (te_IN), Thai (th_TH), Tagalog (tl_XX), Ukrainian (uk_UA), Urdu (ur_PK), Xhosa (xh_ZA), Galician (gl_ES), Slovene (sl_SI"
8
+ LanguageCovered = LanguageCovered.split(",")
9
+ languages_list = [a.strip() for a in LanguageCovered]
10
+ languages_list = [a.split(" ") for a in languages_list]
11
+ languages = [a[0] for a in languages_list]
12
+ codes = [a[1] for a in languages_list]
13
+ codes = [a.replace('(', '') for a in codes]
14
+ codes = [a.replace(')', '') for a in codes]
15
+ lang_dict = dict(zip(languages, codes))
16
+
17
+ model_name = "facebook/mbart-large-50-many-to-many-mmt"
18
+
19
+ # tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
20
+ # model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
21
+ tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
22
+ model = MBartForConditionalGeneration.from_pretrained(model_name)
23
+
24
+ device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
25
+ model = model.to(device)
26
+
27
+
28
+ def translate_text(text, source_lang, target_lang):
29
+ tokenizer.src_lang = source_lang
30
+ encoded_text = tokenizer(text, return_tensors="pt").to(device)
31
+
32
+ generated_tokens = model.generate(**encoded_text, forced_bos_token_id=tokenizer.lang_code_to_id[target_lang])
33
+
34
+ #Decode the output
35
+ translated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
36
+ return translated_text
37
+
38
+
39
+ st.markdown("### Language Translator")
40
+ source_language = ''
41
+ target_language = ''
42
+ source = st.sidebar.selectbox('Source Language', languages)
43
+ if source:
44
+ source_language = lang_dict.get(source)
45
+ st.write(source_language)
46
+
47
+ target = st.sidebar.selectbox('Target Language', languages)
48
+ if target:
49
+ target_language = lang_dict.get(target)
50
+ st.write(target_language)
51
+
52
+ with st.form(key="myForm"):
53
+ text = st.text_area("Enter your text")
54
+ submit = st.form_submit_button("Submit", type='primary')
55
+
56
+ if submit and text and source_language and target_language:
57
+ with st.spinner(f"{source} to {target} translating"):
58
+ translation = translate_text(text, source_language, target_language)
59
+ st.write(translation)