import os

import gradio as gr
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import ConversationalRetrievalChain
from langchain_community.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory

api_token = os.getenv("HF_TOKEN")

# Available LLMs (Hugging Face Inference endpoints)
list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]


# Load and split PDF documents
def load_doc(list_file_path):
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits


# Create vector database
def create_db(splits):
    embeddings = HuggingFaceEmbeddings()
    vectordb = FAISS.from_documents(splits, embeddings)
    return vectordb


# Initialize the conversational retrieval chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )
    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    return qa_chain


# Build the vector database from the uploaded PDFs and resolve the selected LLM
def initialize_database_and_llm(list_file_obj, llm_option):
    list_file_path = [x.name for x in list_file_obj if x is not None]
    doc_splits = load_doc(list_file_path)
    vector_db = create_db(doc_splits)
    llm_name = list_llm[list_llm_simple.index(llm_option)]
    return vector_db, llm_name


# Handle chatbot responses
def respond(message, history: list[tuple[str, str]], list_file_obj, llm_option, max_tokens, temperature, top_k):
    # Lazily build the vector database and QA chain on the first message
    if not hasattr(respond, "qa_chain"):
        if not list_file_obj:
            return "Please upload at least one PDF document first."
        vector_db, llm_name = initialize_database_and_llm(list_file_obj, llm_option)
        respond.qa_chain = initialize_llmchain(llm_name, temperature, max_tokens, top_k, vector_db)

    # Format the visible chat history (the chain's memory also tracks it)
    formatted_chat_history = []
    for user_message, bot_message in history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    formatted_chat_history.append(f"User: {message}")

    # Generate a response using the QA chain
    response = respond.qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    if "Helpful Answer:" in response_answer:
        response_answer = response_answer.split("Helpful Answer:")[-1]
    return response_answer


# CSS for styling the interface
css = """
body {
    background-color: #06688E;  /* Dark blue background */
    color: white;               /* Text color for better visibility */
}
.gr-button {
    background-color: #42B3CE !important;  /* Light blue button color */
    color: black !important;               /* Black text for contrast */
    border: none !important;
    padding: 8px 16px !important;
    border-radius: 5px !important;
}
.gr-button:hover {
    background-color: #e0e0e0 !important;  /* Slightly lighter button on hover */
}
.gr-slider-container {
    color: white !important;  /* Slider labels in white */
}
"""

# Gradio interface: the additional inputs appear in the chat's settings accordion
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Files(file_count="multiple", file_types=[".pdf"], label="Upload PDF documents"),
        gr.Radio(list_llm_simple, label="Available LLMs", value=list_llm_simple[0]),
        gr.Slider(minimum=128, maximum=8192, value=4096, step=128, label="Max new tokens"),
        gr.Slider(minimum=0.01, maximum=1.0, value=0.5, step=0.1, label="Temperature"),
        gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Top-k"),
    ],
    css=css,
    title="RAG PDF Chatbot",
    description="Query your PDF documents using a Retrieval Augmented Generation (RAG) chatbot.",
)

if __name__ == "__main__":
    demo.launch(share=True)