import os import sys import random import gradio as gr from langchain_community.document_loaders import PyPDFLoader from langchain_text_splitters import RecursiveCharacterTextSplitter from langchain_huggingface import HuggingFaceEmbeddings from langchain_community.vectorstores import FAISS from langchain.chains import RetrievalQA from langchain_groq import ChatGroq from langchain_core.prompts import PromptTemplate from langchain_core.output_parsers import StrOutputParser from langchain_core.runnables import RunnablePassthrough # Initialize the FAISS vector store vector_store = None # Sample PDF file sample_filenames = ["User Guide.pdf", "Installation.pdf", ] desc = """
RAG is an approach that combines retrieval-based and generative LLM models to improve the accuracy and relevance of generated text. It works by first retrieving relevant documents from an external knowledge source (like PDF files) and then using a LLM model to produce responses based on both the input query and the retrieved content. This method enhances factual correctness and allows the model to access up-to-date or domain-specific information without retraining.
Choose the PDF files and click Load and Index Documents button below to upload and index the files. It could take some time depends on the size of files. Once you see the message "PDF(s) indexed successfully!" in the below Indexing Status, go to the Chatbot tab to ask any relevant questios.
""" desc_sample = """Alternatively, click the button below to load a User Guide and an Installation for a smoke alarm device into the vector database. It could take a couple of minutes to process. Once you see the message "PDF(s) indexed successfully!" in the below Indexing Status, go to the Chatbot tab to ask any relevant questions about the device.
""" gui_css=""" .gradio-container { font-family: 'Inter', sans-serif; border-radius: 12px; overflow: hidden; } .panel { border-radius: 8px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); } .gr-button { border-radius: 8px; padding: 10px 20px; font-weight: bold; box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1); transition: all 0.2s ease-in-out; } .gr-button:hover { transform: translateY(-2px); box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15); } .gr-textbox textarea { border-radius: 8px; } .gr-slider { padding: 10px 0; } .gr-tabitem { padding: 20px; } """ sample_button = "Load and Index Sample PDF Files" examples_questions = [["How long is the lifespan of this smoke alarm?"], ["How often should I change the battery?"], ["Where should I install the smoke alarm in my home?"], ["How do I test if the smoke alarm is working?"], ["What should I do if the smoke alarm keeps beeping?"], ["Can this smoke alarm detect carbon monoxide too?"], ["How do I clean the smoke alarm properly?"], ["What type of battery does this smoke alarm use?"], ["How loud is the smoke alarm when it goes off?"], ["Can I install this smoke alarm on a wall instead of a ceiling?"], ] template = \ """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say you don't know because no relevant information in the provided documents, don't try to make up an answer. {context} Question: {question} Answer: """ # Function to handle PDF upload and indexing def load_pdf(files): global vector_store documents = [] # Load the PDFs for file in files: loader = PyPDFLoader(file.name) documents.extend(loader.load()) # print(f"{file} is processed!") # Split the documents into chunks text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64) texts = text_splitter.split_documents(documents) # Embed the chunks # embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2" embedding_model_name = "bert-base-uncased" # embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name, encode_kwargs={"normalize_embeddings": True}) embeddings = HuggingFaceEmbeddings() # Store the embeddings in the vector store vector_store = FAISS.from_documents(texts, embeddings) return "PDF(s) indexed successfully!" def format_docs(docs): return "\n\n".join(doc.page_content for doc in docs) def generate_response(query, history, model, temperature, max_tokens, top_p, seed): if vector_store is None: return "Please upload and index a PDF at the Indexing tab.", "" if seed == 0: seed = random.randint(1, 100000) retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 16}) llm = ChatGroq(groq_api_key=os.environ.get("GROQ_API_KEY"), model=model) custom_rag_prompt = PromptTemplate.from_template(template) docs = retriever.invoke(query) relevant_info = format_docs(docs) rag_chain = ( {"context": retriever | format_docs, "question": RunnablePassthrough()} | custom_rag_prompt | llm | StrOutputParser() ) response = rag_chain.invoke(query) return response, relevant_info template = """ You are a helpful AI assistant. Use the following context to answer the question. If you don't know the answer, just say that you don't know, don't try to make up an answer. {context} Question: {question} """ # --- Gradio Interface using gr.Blocks() --- with gr.Blocks(theme=gr.themes.Soft(), css=gui_css) as demo: with gr.Tab("Indexing"): with gr.Row(): gr.Markdown(desc) with gr.Row(): with gr.Column(): gr.Markdown(desc_pdf_upload) pdf_files = gr.File(label="Upload PDF Documents", file_types=[".pdf"], interactive=True, file_count="multiple") load_button = gr.Button("Load and Index Documents", variant="secondary") with gr.Column(): gr.Markdown(desc_sample) sample_files = gr.File( label="Sample PDF Files", file_count="multiple", file_types=[".pdf"], value=sample_filenames, visible=True, interactive=False ) sample_button = gr.Button(sample_button) with gr.Row(): index_output = gr.Textbox(label="Indexing Status") sample_button.click(load_pdf, inputs=sample_files, outputs=index_output) load_button.click(load_pdf, inputs=pdf_files, outputs=index_output) with gr.Tab("Chatbot"): with gr.Row(): with gr.Column(scale=2): # Chatbot component chatbot = gr.Chatbot( show_label=False, show_share_button=False, show_copy_button=True, layout="panel", height=500, # Set a fixed height for the chatbot avatar_images=( "https://placehold.co/60x60/FFD700/000000?text=U", # User avatar "https://placehold.co/60x60/6366F1/FFFFFF?text=AI" # Bot avatar ) ) # Message input textbox msg = gr.Textbox( label="Your Message", placeholder="Type your message here...", show_copy_button=True, container=False # Prevent it from being wrapped in a default container ) with gr.Row(): submit_btn = gr.Button("Send", variant="primary") clear_btn = gr.ClearButton() # Will be configured below gr.Examples( examples=examples_questions, inputs=[msg], outputs=[msg], # Update the message input with the example label="Quick Examples", cache_examples=False, ) with gr.Column(scale=1): gr.Markdown("### LLM Settings") model_name = gr.Dropdown(label="Model Name", choices=[ "llama-3.3-70b-versatile", "llama-3.1-8b-instant", "llama3-70b-8192", "llama3-8b-8192", "whisper-large-v3", "whisper-large-v3-turbo", "meta-llama/Llama-Guard-4-12B", "gemma2-9b-it" ], value="llama-3.3-70b-versatile", interactive=True ) temperature_slider = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.01, label="Temperature", interactive=True) max_tokens_slider = gr.Slider(minimum=10, maximum=2000, value=500, step=10, label="Max Tokens", interactive=True) top_p_slider = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.01, label="Top P", interactive=True) seed_number = gr.Number(minimum=0, maximum=100000, value=0, step=1, label="Seed", precision=0, interactive=True) gr.Markdown("### Retrieved Information") # Textbox for relevant_info relevant_info_textbox = gr.Textbox( label="Retrieved Information", interactive=False, # Not editable by the user lines=20, show_copy_button=True, autoscroll=True, container=True # Ensure it has a container for styling ) # --- Event Handling --- # This function acts as a wrapper to process inputs and distribute outputs def process_chat_and_info(message, chat_history, model, temp, max_tok, top_p_val, seed_val): # Call your generate_response function which returns two values bot_message, retrieved_info = generate_response( message, chat_history, model, temp, max_tok, top_p_val, seed_val ) # Update the chat history for the chatbot component chat_history.append((message, bot_message)) # Return values in the order of the outputs list return chat_history, retrieved_info, "" # Clear the message input after sending # Bind the `process_chat_and_info` function to the submit event of the message textbox msg.submit( fn=process_chat_and_info, inputs=[msg, chatbot, model_name, temperature_slider, max_tokens_slider, top_p_slider, seed_number], outputs=[chatbot, relevant_info_textbox, msg], # Order matters here: chatbot, relevant_info, then msg queue=False # Set to True if you expect heavy load ) # Bind the `process_chat_and_info` function to the click event of the send button submit_btn.click( fn=process_chat_and_info, inputs=[msg, chatbot, model_name, temperature_slider, max_tokens_slider, top_p_slider, seed_number], outputs=[chatbot, relevant_info_textbox, msg], # Order matters here queue=False # Set to True if you expect heavy load ) # Configure the clear button to clear both the chatbot and the relevant_info_textbox clear_btn.add([msg, chatbot, relevant_info_textbox]) demo.launch(server_name="0.0.0.0", server_port=7860)