AbdalrhmanRi commited on
Commit
b420626
·
verified ·
1 Parent(s): 72ee2ff

Upload Mistral-Nemo-RAG.py

Browse files
Files changed (1) hide show
  1. Mistral-Nemo-RAG.py +252 -0
Mistral-Nemo-RAG.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from huggingface_hub import InferenceClient
3
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
4
+ from langchain.vectorstores import FAISS
5
+ from langchain.embeddings import HuggingFaceEmbeddings
6
+ from langchain.memory import ConversationBufferMemory
7
+ from langchain.document_loaders import PyPDFLoader
8
+ import os
9
+ import tempfile
10
+ from deep_translator import GoogleTranslator
11
+ import asyncio
12
+ import uuid
13
+ import logging
14
+ from tenacity import retry, stop_after_attempt, wait_exponential
15
+
16
+ # Set up logging
17
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
18
+
19
+
20
+ def initialize_session_state():
21
+ if 'generated' not in st.session_state:
22
+ st.session_state['generated'] = []
23
+ if 'past' not in st.session_state:
24
+ st.session_state['past'] = []
25
+ if 'memory' not in st.session_state:
26
+ st.session_state['memory'] = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
27
+ if 'vector_store' not in st.session_state:
28
+ st.session_state['vector_store'] = None
29
+ if 'embeddings' not in st.session_state:
30
+ st.session_state['embeddings'] = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
31
+ model_kwargs={'device': 'cuda'}) # Can use CPU if you want
32
+ if 'translation_states' not in st.session_state:
33
+ st.session_state['translation_states'] = {}
34
+ if 'message_ids' not in st.session_state:
35
+ st.session_state['message_ids'] = []
36
+ if 'is_loading' not in st.session_state:
37
+ st.session_state['is_loading'] = False
38
+
39
+
40
+ async def process_pdf(file):
41
+ with tempfile.NamedTemporaryFile(delete=False) as temp_file:
42
+ temp_file.write(file.read())
43
+ temp_file_path = temp_file.name
44
+
45
+ loader = PyPDFLoader(temp_file_path)
46
+ text = await asyncio.to_thread(loader.load)
47
+ os.remove(temp_file_path)
48
+
49
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
50
+ text_chunks = await asyncio.to_thread(text_splitter.split_documents, text)
51
+ return text_chunks
52
+
53
+
54
+ async def extract_text_from_pdfs(uploaded_files):
55
+ tasks = [process_pdf(file) for file in uploaded_files]
56
+ results = await asyncio.gather(*tasks)
57
+ return [chunk for result in results for chunk in result]
58
+
59
+
60
+ @st.cache_data(show_spinner=False)
61
+ def translate_text(text, dest_language='ar'):
62
+ translator = GoogleTranslator(source='auto', target=dest_language)
63
+ translation = translator.translate(text)
64
+ return translation
65
+
66
+
67
+ def update_vector_store(new_text_chunks):
68
+ if st.session_state['vector_store']:
69
+ st.session_state['vector_store'].add_documents(new_text_chunks)
70
+ else:
71
+ st.session_state['vector_store'] = FAISS.from_documents(new_text_chunks,
72
+ embedding=st.session_state['embeddings'])
73
+
74
+
75
+ @st.cache_resource
76
+ def get_hf_client():
77
+ return InferenceClient(
78
+ "mistralai/Mistral-Nemo-Instruct-2407",
79
+ token="hf_********************************",
80
+ )
81
+
82
+
83
+ def retrieve_relevant_chunks(query, max_tokens=1000):
84
+ if st.session_state['vector_store']:
85
+ search_results = st.session_state['vector_store'].similarity_search_with_score(query, k=5)
86
+ relevant_chunks = []
87
+ total_tokens = 0
88
+ for doc, score in search_results:
89
+ chunk_tokens = len(doc.page_content.split())
90
+ if total_tokens + chunk_tokens > max_tokens:
91
+ break
92
+ relevant_chunks.append(doc.page_content)
93
+ total_tokens += chunk_tokens
94
+ return "\n".join(relevant_chunks) if relevant_chunks else None
95
+ return None
96
+
97
+
98
+ @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
99
+ def generate_response(query, conversation_context, relevant_chunk=None):
100
+ client = get_hf_client()
101
+ if relevant_chunk:
102
+ full_query = f"Based on the following information:\n{relevant_chunk}\n\nAnswer the question: {query}"
103
+ else:
104
+ full_query = f"{conversation_context}\nUser: {query}"
105
+
106
+ response = ""
107
+ try:
108
+ for message in client.chat_completion(
109
+ messages=[{"role": "user", "content": full_query}],
110
+ max_tokens=800,
111
+ stream=True,
112
+ temperature=0.3
113
+ ):
114
+ response += message.choices[0].delta.content
115
+ except Exception as e:
116
+ logging.error(f"Error generating response: {e}")
117
+ raise
118
+
119
+ return response.strip()
120
+
121
+
122
+ def display_chat_interface():
123
+ for i in range(len(st.session_state['generated'])):
124
+ with st.chat_message("user"):
125
+ st.text(st.session_state["past"][i])
126
+
127
+ with st.chat_message("assistant"):
128
+ st.markdown(st.session_state['generated'][i])
129
+
130
+ if i >= len(st.session_state['message_ids']):
131
+ message_id = str(uuid.uuid4())
132
+ st.session_state['message_ids'].append(message_id)
133
+ else:
134
+ message_id = st.session_state['message_ids'][i]
135
+
136
+ translate_key = f"translate_{message_id}"
137
+
138
+ if translate_key not in st.session_state['translation_states']:
139
+ st.session_state['translation_states'][translate_key] = False
140
+
141
+ if st.button(f"Translate to Arabic", key=f"btn_{translate_key}", on_click=toggle_translation,
142
+ args=(translate_key,)):
143
+ pass
144
+
145
+ if st.session_state['translation_states'][translate_key]:
146
+ with st.spinner("Translating..."):
147
+ translated_text = translate_text(st.session_state['generated'][i])
148
+ st.markdown(f"**Translated:** \n\n {translated_text}")
149
+
150
+
151
+ def toggle_translation(translate_key):
152
+ st.session_state['translation_states'][translate_key] = not st.session_state['translation_states'][translate_key]
153
+
154
+
155
+ def get_conversation_context(max_tokens=2000):
156
+ context = []
157
+ total_tokens = 0
158
+ for past, generated in zip(reversed(st.session_state['past']), reversed(st.session_state['generated'])):
159
+ user_message = f"User: {past}\n"
160
+ assistant_message = f"Assistant: {generated}\n"
161
+ message_tokens = len(user_message.split()) + len(assistant_message.split())
162
+
163
+ if total_tokens + message_tokens > max_tokens:
164
+ break
165
+
166
+ context.insert(0, user_message)
167
+ context.insert(1, assistant_message)
168
+ total_tokens += message_tokens
169
+
170
+ return "".join(context)
171
+
172
+
173
+ def validate_input(user_input):
174
+ if not user_input or not user_input.strip():
175
+ return False, "Please enter a valid question or command."
176
+ if len(user_input) > 500:
177
+ return False, "Your input is too long. Please limit your question to 500 characters."
178
+ return True, ""
179
+
180
+
181
+ def process_user_input(user_input):
182
+ user_input = user_input.rstrip()
183
+
184
+ is_valid, error_message = validate_input(user_input)
185
+ if not is_valid:
186
+ st.error(error_message)
187
+ return
188
+
189
+ st.session_state['past'].append(user_input)
190
+
191
+ with st.chat_message("user"):
192
+ st.text(user_input)
193
+
194
+ with st.chat_message("assistant"):
195
+ message_placeholder = st.empty()
196
+ message_placeholder.markdown("⏳ Thinking...")
197
+
198
+ relevant_chunk = retrieve_relevant_chunks(user_input)
199
+ conversation_context = get_conversation_context()
200
+
201
+ try:
202
+ output = generate_response(user_input, conversation_context, relevant_chunk)
203
+ except Exception as e:
204
+ logging.error(f"Failed to generate response after retries: {e}")
205
+ output = "I apologize, but I'm having trouble processing your request at the moment. Please try again later."
206
+
207
+ message_placeholder.empty()
208
+ message_placeholder.markdown(output)
209
+ st.session_state['generated'].append(output)
210
+ st.session_state['memory'].save_context({"input": user_input}, {"output": output})
211
+
212
+ message_id = str(uuid.uuid4())
213
+ st.session_state['message_ids'].append(message_id)
214
+
215
+ translate_key = f"translate_{message_id}"
216
+ st.session_state['translation_states'][translate_key] = False
217
+
218
+ if st.button(f"Translate to Arabic", key=f"btn_{translate_key}", on_click=toggle_translation,
219
+ args=(translate_key,)):
220
+ pass
221
+
222
+ if st.session_state['translation_states'][translate_key]:
223
+ with st.spinner("Translating..."):
224
+ translated_text = translate_text(output)
225
+ st.markdown(f"**Translated:** \n\n {translated_text}")
226
+
227
+ st.rerun()
228
+
229
+
230
+ def main():
231
+ initialize_session_state()
232
+ st.title("Chat with PDF Using Mistral AI")
233
+
234
+ uploaded_files = st.sidebar.file_uploader("Upload your PDF files", type="pdf", accept_multiple_files=True)
235
+
236
+ if uploaded_files:
237
+ with st.spinner("Processing PDF files..."):
238
+ loop = asyncio.new_event_loop()
239
+ asyncio.set_event_loop(loop)
240
+ new_text_chunks = loop.run_until_complete(extract_text_from_pdfs(uploaded_files))
241
+ update_vector_store(new_text_chunks)
242
+ st.success("PDF files uploaded and processed successfully.")
243
+
244
+ display_chat_interface()
245
+
246
+ user_input = st.chat_input("Ask about your PDF(s)")
247
+ if user_input:
248
+ process_user_input(user_input)
249
+
250
+
251
+ if __name__ == "__main__":
252
+ main()