harupurito committed on
Commit
e1f1b19
·
verified ·
1 Parent(s): 26326c9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -0
app.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from deep_translator import GoogleTranslator
3
+ from streamlit_mic_recorder import speech_to_text
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
5
+ from sentence_transformers import SentenceTransformer, util
6
+ import json
7
+ import time
8
+ st.set_page_config(layout="wide")
9
# Display name -> two-letter code. main() passes these codes to the speech
# recogniser, and translate_text_multimodel() passes them to GoogleTranslator.
# Insertion order is the order the names appear in the sidebar selectbox.
language_dict = {
    "English": "en",
    "Hindi": "hi",
    "Bengali": "bn",
    "Gujarati": "gu",
    "Marathi": "mr",
    "Telugu": "te",
    "Tamil": "ta",
    "Punjabi": "pa",
    "Odia": "or",
    "Nepali": "ne",
    "Malayalam": "ml",
}

# Display name -> NLLB language tag, used as src_lang / tgt_lang for the
# NLLB pipeline. Insertion order is the order in which main() produces the
# per-language translations of a new message.
nllb_langs = {
    "English": "eng_Latn",
    "Hindi": "hin_Deva",
    "Punjabi": "pan_Guru",
    "Odia": "ory_Orya",
    "Bengali": "ben_Beng",
    "Telugu": "tel_Telu",
    "Tamil": "tam_Taml",
    "Nepali": "npi_Deva",
    "Marathi": "mar_Deva",
    "Malayalam": "mal_Mlym",
    "Gujarati": "guj_Gujr",
}

# JSON file that holds the chat history (read by load_messages, written by
# save_messages).
CHAT_FILE = "chat_data.json"
22
+
23
@st.cache_resource
def load_nllb_model():
    """Build the NLLB translation pipeline.

    The tokenizer and the seq2seq model are both loaded from the
    ``facebook/nllb-200-distilled-600M`` checkpoint and wrapped in a
    ``transformers`` ``'translation'`` pipeline.

    Returns:
        The pipeline object; callers invoke it with ``src_lang`` and
        ``tgt_lang`` keyword arguments (see ``translate_text_multimodel``).
    """
    checkpoint = "facebook/nllb-200-distilled-600M"
    return pipeline(
        'translation',
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
        model=AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )
29
+
30
@st.cache_resource
def load_sentence_model():
    """Return a ``SentenceTransformer`` built from ``google/muril-base-cased``.

    The returned encoder produces the embeddings that
    ``translate_text_multimodel`` compares by cosine similarity.
    """
    encoder = SentenceTransformer("google/muril-base-cased")
    return encoder


# Module-level handles read by translate_text_multimodel(). Both loaders are
# wrapped in st.cache_resource above.
translator_nllb = load_nllb_model()
sentence_model = load_sentence_model()
36
+
37
def load_messages():
    """Read the chat history from ``CHAT_FILE``.

    Returns:
        The list of message dicts stored in the file. An empty list is
        returned when the file does not exist, cannot be decoded as UTF-8,
        is not valid JSON, or holds something other than a JSON array.
    """
    try:
        # Explicit encoding: the platform default is not guaranteed to be
        # UTF-8, and the history holds text in many scripts.
        with open(CHAT_FILE, "r", encoding="utf-8") as file:
            data = json.load(file)
    except (FileNotFoundError, UnicodeDecodeError, json.JSONDecodeError):
        return []
    # Callers iterate over and append to the result, so anything that is not
    # a list (e.g. a hand-edited file containing an object) is treated like a
    # corrupt file instead of failing later in the UI code.
    return data if isinstance(data, list) else []
43
+
44
def save_messages(messages):
    """Write *messages* to ``CHAT_FILE`` as JSON, replacing it atomically.

    The data is first written to a temporary file in the same directory and
    then moved over ``CHAT_FILE`` with ``os.replace``. A reader therefore sees
    either the previous file or the complete new one, never a truncated file,
    and a failed dump leaves the existing history untouched.

    Args:
        messages: JSON-serialisable list of message dicts.

    Raises:
        TypeError: If *messages* contains a value ``json`` cannot serialise.
        OSError: If the temporary file cannot be written or moved into place.
    """
    import os
    import tempfile

    # Same directory as the target so os.replace stays on one filesystem.
    target_dir = os.path.dirname(os.path.abspath(CHAT_FILE))
    fd, tmp_path = tempfile.mkstemp(dir=target_dir, prefix=".chat_", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as file:
            json.dump(messages, file)
        os.replace(tmp_path, CHAT_FILE)
    except BaseException:
        # Do not leave stray temp files behind; the original error is what
        # the caller needs to see.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
47
+
48
def translate_text_multimodel(text, source_lang_name, target_lang_name):
    """Translate *text* with NLLB and Google Translate and return the closer one.

    Both candidate translations and the source text are embedded with
    ``sentence_model``; the candidate whose embedding has the higher cosine
    similarity to the source embedding is returned (NLLB wins ties).

    Args:
        text: Text to translate.
        source_lang_name: Key of ``nllb_langs`` naming the language of *text*.
        target_lang_name: Key of ``nllb_langs`` and ``language_dict`` naming
            the output language.

    Returns:
        The selected translation. *text* is returned unchanged when it is
        empty or blank, or when source and target language are the same.
        The NLLB translation is returned when Google Translate raises or
        yields no text.

    Raises:
        KeyError: If either language name is unknown.
    """
    import logging

    logger = logging.getLogger(__name__)

    # Resolve every code up front so an unknown name fails before any model
    # or network work is done.
    source_nllb = nllb_langs[source_lang_name]
    target_nllb = nllb_langs[target_lang_name]
    target_google = language_dict[target_lang_name]

    # Nothing to translate: skip two model calls and a network round trip,
    # and keep the speaker's own wording for their own language.
    if not text or not text.strip() or source_lang_name == target_lang_name:
        return text

    # NLLB translation
    translation_nllb = translator_nllb(
        text, src_lang=source_nllb, tgt_lang=target_nllb
    )[0]['translation_text']

    # Google translation is network-backed and best-effort: a failure there
    # must not discard the NLLB result that is already available.
    try:
        translation_google = GoogleTranslator(
            source='auto', target=target_google
        ).translate(text)
    except Exception:
        logger.warning(
            "Google translation to %s failed; using NLLB", target_lang_name,
            exc_info=True,
        )
        translation_google = None
    if not translation_google:
        return translation_nllb

    # Cosine similarity of each candidate against the source text
    embedding_original = sentence_model.encode(text, convert_to_tensor=True)
    embedding_nllb = sentence_model.encode(translation_nllb, convert_to_tensor=True)
    embedding_google = sentence_model.encode(translation_google, convert_to_tensor=True)

    cosine_score_nllb = util.cos_sim(embedding_original, embedding_nllb).item()
    cosine_score_google = util.cos_sim(embedding_original, embedding_google).item()

    # Keep the candidate closer to the source; ties go to NLLB.
    if cosine_score_nllb >= cosine_score_google:
        logger.debug("selected NLLB translation for %s", target_lang_name)
        return translation_nllb
    logger.debug("selected Google translation for %s", target_lang_name)
    return translation_google
73
+
74
def main():
    """Render the chat page: user setup, history, speech input, and refresh.

    Flow of one script run:
      1. Ask for a name and a language in the sidebar; stop until a name is
         entered.
      2. Refresh the history from ``CHAT_FILE`` and show each message in the
         selected language.
      3. If the speech widget returned text, translate it into every language
         in ``nllb_langs``, append it to the history, save, and rerun.
      4. Otherwise wait one second and rerun to pick up new messages.

    NOTE(review): the indentation of the trailing ``time.sleep(1)`` /
    ``st.rerun()`` was not recoverable from the source; it is treated here as
    function-level polling — confirm against the original file.
    """
    st.title("Multilingual Chat Application with Speech Input")

    # Sidebar for user setup
    st.sidebar.header("User Setup")
    username = st.sidebar.text_input("Enter your name:")
    language = st.sidebar.selectbox("Choose your language:", list(language_dict.keys()))

    if not username:
        st.warning("Please enter your name to start chatting.")
        return

    user_lang_code = language_dict[language]

    # Re-read the shared file on every run so messages saved by other
    # sessions show up. load_messages() returns [] for a missing or
    # unreadable file, so an empty result only replaces the session copy when
    # there is no session copy yet.
    latest = load_messages()
    if latest or "messages" not in st.session_state:
        st.session_state["messages"] = latest

    # Display chat history
    st.subheader("Chat Room")
    for msg in st.session_state["messages"]:
        # Fall back to the original text for messages stored without a
        # translation for the selected language.
        shown = msg.get('translations', {}).get(language, msg.get('text', ''))
        with st.chat_message(msg['name']):
            st.write(f"{msg['name']} ({msg['lang']}): {shown}")

    # Speech input integration
    st.subheader("Speak your message")
    spoken_text = speech_to_text(
        language=user_lang_code,
        use_container_width=True,
        just_once=True,
        key='speech_input',
    )

    if spoken_text:
        st.write(f"You said: {spoken_text}")
        translations = {
            lang: translate_text_multimodel(spoken_text, language, lang)
            for lang in nllb_langs
        }
        new_message = {
            "user": username,
            "name": username,
            "lang": language,
            "text": spoken_text,
            "translations": translations,
        }
        # Translating takes a while; re-read the file so messages saved by
        # other sessions in the meantime are not overwritten by a stale copy.
        history = load_messages() or st.session_state["messages"]
        history.append(new_message)
        st.session_state["messages"] = history
        save_messages(history)
        st.rerun()

    # Poll: wait briefly, then rerun the script to refresh the history.
    time.sleep(1)
    st.rerun()


if __name__ == "__main__":
    main()