import asyncio
import logging
import os
import tempfile
import uuid

import streamlit as st
from deep_translator import GoogleTranslator
from huggingface_hub import InferenceClient
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from tenacity import retry, stop_after_attempt, wait_exponential

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def initialize_session_state():
    """Populate ``st.session_state`` with every key the app relies on.

    Only missing keys are created, so existing state survives Streamlit
    reruns. Values are built lazily through factories so the (expensive)
    embedding model is only constructed once per session.
    """
    factories = {
        'generated': list,  # assistant answers, parallel to 'past'
        'past': list,  # user questions
        'memory': lambda: ConversationBufferMemory(memory_key="chat_history", return_messages=True),
        'vector_store': lambda: None,  # FAISS index, created on first upload
        # Can use CUDA if you want on your device
        'embeddings': lambda: HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            model_kwargs={'device': 'cpu'},
        ),
        'translation_states': dict,  # "translate_<id>" -> bool (show translation?)
        'message_ids': list,  # stable ids for assistant messages, parallel to 'generated'
        'is_loading': lambda: False,
        'processed_files': set,  # keys of uploads already indexed (see _file_key)
    }
    for key, factory in factories.items():
        if key not in st.session_state:
            st.session_state[key] = factory()


async def process_pdf(file):
    """Load one uploaded PDF and split it into overlapping text chunks.

    ``PyPDFLoader`` needs a filesystem path, so the upload is written to a
    temporary file that is always removed afterwards, even if loading fails.

    Returns:
        The list of LangChain documents produced by the text splitter.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
        # getvalue() returns the whole upload regardless of the current stream
        # position; read() would return b"" if the buffer had been consumed.
        temp_file.write(file.getvalue())
        temp_file_path = temp_file.name
    try:
        loader = PyPDFLoader(temp_file_path)
        documents = await asyncio.to_thread(loader.load)
    finally:
        os.remove(temp_file_path)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    return await asyncio.to_thread(text_splitter.split_documents, documents)


async def extract_text_from_pdfs(uploaded_files):
    """Process all uploads concurrently and return their chunks as one flat list."""
    results = await asyncio.gather(*(process_pdf(file) for file in uploaded_files))
    return [chunk for result in results for chunk in result]


@st.cache_data(show_spinner=False)
def translate_text(text, dest_language='ar'):
    """Translate ``text`` into ``dest_language`` (Arabic by default); cached per input."""
    translator = GoogleTranslator(source='auto', target=dest_language)
    return translator.translate(text)


def update_vector_store(new_text_chunks):
    """Add ``new_text_chunks`` to the session's FAISS index, creating it if needed."""
    if not new_text_chunks:
        # FAISS.from_documents cannot build an index from zero documents
        # (e.g. an image-only PDF with no extractable text).
        return
    store = st.session_state['vector_store']
    if store is None:
        st.session_state['vector_store'] = FAISS.from_documents(
            new_text_chunks, embedding=st.session_state['embeddings']
        )
    else:
        store.add_documents(new_text_chunks)


@st.cache_resource
def get_hf_client():
    """Return a process-wide InferenceClient for the Mistral model.

    The API token is read from the ``HF_TOKEN`` environment variable rather
    than being hard-coded in source, where it would leak via version control.
    """
    return InferenceClient(
        "mistralai/Mistral-Nemo-Instruct-2407",
        token=os.environ.get("HF_TOKEN"),
    )


def retrieve_relevant_chunks(query, max_tokens=1000):
    """Return up to 5 similar chunks joined by newlines, or ``None``.

    Chunks are taken in search-result order until adding the next one would
    exceed ``max_tokens`` (approximated as whitespace-separated words).
    Returns ``None`` when no index exists or nothing fits the budget.
    """
    store = st.session_state['vector_store']
    if store is None:
        return None
    search_results = store.similarity_search_with_score(query, k=5)
    relevant_chunks = []
    total_tokens = 0
    for doc, _score in search_results:
        chunk_tokens = len(doc.page_content.split())
        if total_tokens + chunk_tokens > max_tokens:
            break
        relevant_chunks.append(doc.page_content)
        total_tokens += chunk_tokens
    return "\n".join(relevant_chunks) if relevant_chunks else None


@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
def generate_response(query, conversation_context, relevant_chunk=None):
    """Ask the model to answer ``query`` and return the stripped streamed reply.

    When ``relevant_chunk`` is given the prompt is grounded in the document
    text; otherwise the prior conversation is prepended. Errors are logged
    and re-raised so tenacity can retry (up to 3 attempts).
    """
    client = get_hf_client()
    if relevant_chunk:
        full_query = f"Based on the following information:\n{relevant_chunk}\n\nAnswer the question: {query}"
    else:
        full_query = f"{conversation_context}\nUser: {query}"
    parts = []
    try:
        for message in client.chat_completion(
            messages=[{"role": "user", "content": full_query}],
            max_tokens=800,
            stream=True,
            temperature=0.3,
        ):
            # Streamed chunks (typically the last one) may carry content=None.
            parts.append(message.choices[0].delta.content or "")
    except Exception as e:
        logging.error("Error generating response: %s", e)
        raise
    return "".join(parts).strip()


def display_chat_interface():
    """Render the chat history, each answer with a "Translate to Arabic" toggle."""
    for i, (question, answer) in enumerate(zip(st.session_state['past'], st.session_state['generated'])):
        with st.chat_message("user"):
            st.text(question)
        with st.chat_message("assistant"):
            st.markdown(answer)
            # Give every answer a stable id so its button key and toggle state
            # survive reruns.
            if i >= len(st.session_state['message_ids']):
                message_id = str(uuid.uuid4())
                st.session_state['message_ids'].append(message_id)
            else:
                message_id = st.session_state['message_ids'][i]
            translate_key = f"translate_{message_id}"
            st.session_state['translation_states'].setdefault(translate_key, False)
            # The click is handled entirely by the on_click callback.
            st.button("Translate to Arabic", key=f"btn_{translate_key}",
                      on_click=toggle_translation, args=(translate_key,))
            if st.session_state['translation_states'][translate_key]:
                with st.spinner("Translating..."):
                    translated_text = translate_text(answer)
                st.markdown(f"**Translated:** \n\n {translated_text}")


def toggle_translation(translate_key):
    """Button callback: flip whether the translation for ``translate_key`` is shown."""
    states = st.session_state['translation_states']
    states[translate_key] = not states[translate_key]


def get_conversation_context(max_tokens=2000):
    """Return the most recent answered exchanges as "User:/Assistant:" lines.

    Exchanges are added newest-first until the word budget ``max_tokens``
    would be exceeded, then emitted in chronological order.
    """
    # zip() stops at the shorter list, so a question that has been appended to
    # 'past' but not yet answered is excluded and every question stays paired
    # with its own answer.
    pairs = list(zip(st.session_state['past'], st.session_state['generated']))
    selected = []
    total_tokens = 0
    for past, generated in reversed(pairs):
        user_message = f"User: {past}\n"
        assistant_message = f"Assistant: {generated}\n"
        message_tokens = len(user_message.split()) + len(assistant_message.split())
        if total_tokens + message_tokens > max_tokens:
            break
        selected.append(user_message + assistant_message)
        total_tokens += message_tokens
    return "".join(reversed(selected))


def validate_input(user_input):
    """Return ``(is_valid, error_message)`` for a chat message (max 500 chars)."""
    if not user_input or not user_input.strip():
        return False, "Please enter a valid question or command."
    if len(user_input) > 500:
        return False, "Your input is too long. Please limit your question to 500 characters."
    return True, ""


def process_user_input(user_input):
    """Validate a question, answer it, record the exchange and trigger a rerun."""
    user_input = user_input.rstrip()
    is_valid, error_message = validate_input(user_input)
    if not is_valid:
        st.error(error_message)
        return
    st.session_state['past'].append(user_input)
    with st.chat_message("user"):
        st.text(user_input)
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        message_placeholder.markdown("⏳ Thinking...")
        # Retrieval is inside the try as well: any failure must still produce
        # an answer, otherwise 'past' and 'generated' fall out of step.
        try:
            relevant_chunk = retrieve_relevant_chunks(user_input)
            conversation_context = get_conversation_context()
            output = generate_response(user_input, conversation_context, relevant_chunk)
        except Exception as e:
            logging.error("Failed to generate response: %s", e)
            output = "I apologize, but I'm having trouble processing your request at the moment. Please try again later."
        message_placeholder.markdown(output)
    st.session_state['generated'].append(output)
    st.session_state['memory'].save_context({"input": user_input}, {"output": output})
    message_id = str(uuid.uuid4())
    st.session_state['message_ids'].append(message_id)
    st.session_state['translation_states'][f"translate_{message_id}"] = False
    # The rerun lets display_chat_interface render this exchange together with
    # its translate button; anything drawn here after this point is discarded.
    st.rerun()


def _file_key(file):
    """Identify an uploaded file across Streamlit reruns by name and size."""
    return (file.name, file.size)


def main():
    """Streamlit entry point: index newly uploaded PDFs, then run the chat UI."""
    initialize_session_state()
    st.title("Chat with PDF Using Mistral AI")
    uploaded_files = st.sidebar.file_uploader("Upload your PDF files", type="pdf", accept_multiple_files=True)
    if uploaded_files:
        processed = st.session_state['processed_files']
        # Streamlit reruns this script on every interaction; only index files
        # not seen before, otherwise each rerun re-embeds and duplicates chunks.
        new_files = [f for f in uploaded_files if _file_key(f) not in processed]
        if new_files:
            with st.spinner("Processing PDF files..."):
                # asyncio.run creates and always closes its own event loop.
                new_text_chunks = asyncio.run(extract_text_from_pdfs(new_files))
                update_vector_store(new_text_chunks)
            processed.update(_file_key(f) for f in new_files)
        st.success("PDF files uploaded and processed successfully.")
    display_chat_interface()
    user_input = st.chat_input("Ask about your PDF(s)")
    if user_input:
        process_user_input(user_input)


if __name__ == "__main__":
    main()