Upload Mistral-Nemo-RAG.py
Mistral-Nemo-RAG.py  +252 -0
Mistral-Nemo-RAG.py
ADDED
@@ -0,0 +1,252 @@
import streamlit as st
from huggingface_hub import InferenceClient
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.document_loaders import PyPDFLoader
import os
import tempfile
from deep_translator import GoogleTranslator
import asyncio
import uuid
import logging
from tenacity import retry, stop_after_attempt, wait_exponential

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

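# ---------------------------------------------------------------------------
# Chat-with-PDF RAG app: uploaded PDFs are split into overlapping chunks,
# embedded with sentence-transformers/all-MiniLM-L6-v2, and indexed in FAISS.
# For each question, the top-matching chunks are passed to
# mistralai/Mistral-Nemo-Instruct-2407 through the Hugging Face InferenceClient,
# and answers can optionally be translated to Arabic with deep_translator.
#
# To run locally (a sketch; the package list is inferred from the imports
# above and is not pinned by this upload):
#   pip install streamlit huggingface_hub langchain faiss-cpu \
#       sentence-transformers pypdf deep-translator tenacity
#   streamlit run Mistral-Nemo-RAG.py
# ---------------------------------------------------------------------------
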
def initialize_session_state():
    if 'generated' not in st.session_state:
        st.session_state['generated'] = []
    if 'past' not in st.session_state:
        st.session_state['past'] = []
    if 'memory' not in st.session_state:
        st.session_state['memory'] = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    if 'vector_store' not in st.session_state:
        st.session_state['vector_store'] = None
    if 'embeddings' not in st.session_state:
        st.session_state['embeddings'] = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
                                                               model_kwargs={'device': 'cuda'})  # Can use CPU if you want
    if 'translation_states' not in st.session_state:
        st.session_state['translation_states'] = {}
    if 'message_ids' not in st.session_state:
        st.session_state['message_ids'] = []
    if 'is_loading' not in st.session_state:
        st.session_state['is_loading'] = False

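# PDF ingestion: PyPDFLoader yields one Document per page; the splitter then
# cuts pages into ~1000-character chunks with a 200-character overlap.
# asyncio.to_thread keeps the blocking load/split work off the event loop so
# several uploaded PDFs can be processed concurrently via asyncio.gather.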
async def process_pdf(file):
    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
        temp_file.write(file.read())
        temp_file_path = temp_file.name

    loader = PyPDFLoader(temp_file_path)
    text = await asyncio.to_thread(loader.load)
    os.remove(temp_file_path)

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    text_chunks = await asyncio.to_thread(text_splitter.split_documents, text)
    return text_chunks


async def extract_text_from_pdfs(uploaded_files):
    tasks = [process_pdf(file) for file in uploaded_files]
    results = await asyncio.gather(*tasks)
    return [chunk for result in results for chunk in result]

@st.cache_data(show_spinner=False)
def translate_text(text, dest_language='ar'):
    translator = GoogleTranslator(source='auto', target=dest_language)
    translation = translator.translate(text)
    return translation


def update_vector_store(new_text_chunks):
    if st.session_state['vector_store']:
        st.session_state['vector_store'].add_documents(new_text_chunks)
    else:
        st.session_state['vector_store'] = FAISS.from_documents(new_text_chunks,
                                                                embedding=st.session_state['embeddings'])

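# st.cache_resource keeps a single InferenceClient for the life of the app
# process. The access token below is masked; on Spaces it would typically be
# stored as a secret and read at runtime (an assumption, not shown in this
# file), e.g. token=os.environ.get("HF_TOKEN").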
@st.cache_resource
def get_hf_client():
    return InferenceClient(
        "mistralai/Mistral-Nemo-Instruct-2407",
        token="hf_********************************",
    )

def retrieve_relevant_chunks(query, max_tokens=1000):
    if st.session_state['vector_store']:
        search_results = st.session_state['vector_store'].similarity_search_with_score(query, k=5)
        relevant_chunks = []
        total_tokens = 0
        for doc, score in search_results:
            # Whitespace word count serves as a rough proxy for the token budget.
            chunk_tokens = len(doc.page_content.split())
            if total_tokens + chunk_tokens > max_tokens:
                break
            relevant_chunks.append(doc.page_content)
            total_tokens += chunk_tokens
        return "\n".join(relevant_chunks) if relevant_chunks else None
    return None

@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
def generate_response(query, conversation_context, relevant_chunk=None):
    client = get_hf_client()
    if relevant_chunk:
        full_query = f"Based on the following information:\n{relevant_chunk}\n\nAnswer the question: {query}"
    else:
        full_query = f"{conversation_context}\nUser: {query}"

    response = ""
    try:
        for message in client.chat_completion(
            messages=[{"role": "user", "content": full_query}],
            max_tokens=800,
            stream=True,
            temperature=0.3
        ):
            # Some streamed chunks carry no text (delta.content is None), so guard against them.
            response += message.choices[0].delta.content or ""
    except Exception as e:
        logging.error(f"Error generating response: {e}")
        raise

    return response.strip()

def display_chat_interface():
    for i in range(len(st.session_state['generated'])):
        with st.chat_message("user"):
            st.text(st.session_state["past"][i])

        with st.chat_message("assistant"):
            st.markdown(st.session_state['generated'][i])

            if i >= len(st.session_state['message_ids']):
                message_id = str(uuid.uuid4())
                st.session_state['message_ids'].append(message_id)
            else:
                message_id = st.session_state['message_ids'][i]

            translate_key = f"translate_{message_id}"

            if translate_key not in st.session_state['translation_states']:
                st.session_state['translation_states'][translate_key] = False

            if st.button("Translate to Arabic", key=f"btn_{translate_key}", on_click=toggle_translation,
                         args=(translate_key,)):
                pass

            if st.session_state['translation_states'][translate_key]:
                with st.spinner("Translating..."):
                    translated_text = translate_text(st.session_state['generated'][i])
                    st.markdown(f"**Translated:** \n\n {translated_text}")


def toggle_translation(translate_key):
    st.session_state['translation_states'][translate_key] = not st.session_state['translation_states'][translate_key]

def get_conversation_context(max_tokens=2000):
    context = []
    total_tokens = 0
    for past, generated in zip(reversed(st.session_state['past']), reversed(st.session_state['generated'])):
        user_message = f"User: {past}\n"
        assistant_message = f"Assistant: {generated}\n"
        message_tokens = len(user_message.split()) + len(assistant_message.split())

        if total_tokens + message_tokens > max_tokens:
            break

        context.insert(0, user_message)
        context.insert(1, assistant_message)
        total_tokens += message_tokens

    return "".join(context)


def validate_input(user_input):
    if not user_input or not user_input.strip():
        return False, "Please enter a valid question or command."
    if len(user_input) > 500:
        return False, "Your input is too long. Please limit your question to 500 characters."
    return True, ""

def process_user_input(user_input):
    user_input = user_input.rstrip()

    is_valid, error_message = validate_input(user_input)
    if not is_valid:
        st.error(error_message)
        return

    st.session_state['past'].append(user_input)

    with st.chat_message("user"):
        st.text(user_input)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        message_placeholder.markdown("⏳ Thinking...")

        relevant_chunk = retrieve_relevant_chunks(user_input)
        conversation_context = get_conversation_context()

        try:
            output = generate_response(user_input, conversation_context, relevant_chunk)
        except Exception as e:
            logging.error(f"Failed to generate response after retries: {e}")
            output = "I apologize, but I'm having trouble processing your request at the moment. Please try again later."

        message_placeholder.empty()
        message_placeholder.markdown(output)
        st.session_state['generated'].append(output)
        st.session_state['memory'].save_context({"input": user_input}, {"output": output})

        message_id = str(uuid.uuid4())
        st.session_state['message_ids'].append(message_id)

        translate_key = f"translate_{message_id}"
        st.session_state['translation_states'][translate_key] = False

        if st.button("Translate to Arabic", key=f"btn_{translate_key}", on_click=toggle_translation,
                     args=(translate_key,)):
            pass

        if st.session_state['translation_states'][translate_key]:
            with st.spinner("Translating..."):
                translated_text = translate_text(output)
                st.markdown(f"**Translated:** \n\n {translated_text}")

    st.rerun()

def main():
    initialize_session_state()
    st.title("Chat with PDF Using Mistral AI")

    uploaded_files = st.sidebar.file_uploader("Upload your PDF files", type="pdf", accept_multiple_files=True)

    if uploaded_files:
        with st.spinner("Processing PDF files..."):
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            new_text_chunks = loop.run_until_complete(extract_text_from_pdfs(uploaded_files))
            update_vector_store(new_text_chunks)
        st.success("PDF files uploaded and processed successfully.")

    display_chat_interface()

    user_input = st.chat_input("Ask about your PDF(s)")
    if user_input:
        process_user_input(user_input)


if __name__ == "__main__":
    main()