import gradio as gr
import os
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline, HuggingFaceHub
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from pathlib import Path
import chromadb
from transformers import AutoTokenizer
import transformers
import torch
import tqdm
import accelerate
# Default LLM model
chosen_llm_model = "mistralai/Mistral-7B-Instruct-v0.2"
# Default text-splitting parameters (in characters)
chunk_size = 600
chunk_overlap = 40
# Default LLM generation parameters
llm_temperature = 0.7  # sampling temperature (higher = more varied output)
max_tokens = 1024      # maximum number of newly generated tokens
top_k = 3              # top-k parameter passed to the model
# The vector database is built when documents are uploaded (see the upload
# event inside demo() below), since the files to index are only known at that point.

# Helper functions
# ... (your existing functions here)
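
# For reference, minimal sketches of what these helpers might look like, built
# only from the imports and defaults above. These are illustrations, not the
# definitive implementations: the signatures, embedding model, prompt handling,
# and progress reporting of your actual functions may differ.

def initialize_database(list_file_obj):
    """Load the uploaded PDFs, split them into chunks, and build a Chroma vector store."""
    # gr.Files may yield plain paths or tempfile objects with a .name attribute
    paths = [f if isinstance(f, str) else f.name for f in list_file_obj]
    pages = []
    for path in paths:
        pages.extend(PyPDFLoader(path).load())
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    doc_splits = splitter.split_documents(pages)
    embeddings = HuggingFaceEmbeddings()  # defaults to a sentence-transformers model
    return Chroma.from_documents(documents=doc_splits, embedding=embeddings)


def initialize_LLM(llm_model, list_file_obj):
    """Build the vector store, then assemble a conversational retrieval chain."""
    vector_db = initialize_database(list_file_obj)
    # Free HF inference endpoint; requires HUGGINGFACEHUB_API_TOKEN in the environment
    llm = HuggingFaceHub(
        repo_id=llm_model,
        model_kwargs={
            "temperature": llm_temperature,
            "max_new_tokens": max_tokens,
            "top_k": top_k,
        },
    )
    memory = ConversationBufferMemory(
        memory_key="chat_history", output_key="answer", return_messages=True
    )
    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever=vector_db.as_retriever(),
        memory=memory,
        return_source_documents=True,
    )


def conversation(qa_chain, message, history):
    """Run one chat turn and surface up to three source chunks with page numbers."""
    # The attached memory supplies chat_history, so only the question is passed
    response = qa_chain({"question": message})
    answer = response["answer"]
    sources = list(response["source_documents"][:3])
    refs = []
    for doc in sources + [None] * (3 - len(sources)):  # pad to three references
        if doc is None:
            refs += ["", 0]
        else:
            refs += [doc.page_content.strip(), doc.metadata.get("page", 0) + 1]
    # Matches the outputs wired up in demo(): chain, cleared textbox, chat
    # history, and three (reference, page) pairs
    return (qa_chain, "", history + [(message, answer)], *refs)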
def demo():
    with gr.Blocks(theme="base") as demo:
        qa_chain = gr.State()  # Stores the initialized QA chain across turns
        collection_name = gr.State()
        gr.Markdown(
            """
PDF-based chatbot (powered by LangChain and open-source LLMs)
Ask any question about your uploaded PDF documents, including follow-up questions.
Note: This AI assistant performs retrieval-augmented generation over your PDF documents. \
It takes past questions into account via conversational memory and includes document references with each answer for traceability.
Warning: This Space uses Hugging Face's free CPU Basic hardware, so some steps and the LLMs below (free inference endpoints) can take a while to generate an output.
"""
        )
        with gr.Row():
            document = gr.Files(
                height=100,
                file_count="multiple",
                file_types=["pdf"],
                interactive=True,
                label="Upload your PDF documents (single or multiple)",
            )
        with gr.Row():
            chatbot = gr.Chatbot(height=300)
        with gr.Accordion("Advanced - Document references", open=False):
            with gr.Row():
                doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
                source1_page = gr.Number(label="Page", scale=1)
            with gr.Row():
                doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
                source2_page = gr.Number(label="Page", scale=1)
            with gr.Row():
                doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
                source3_page = gr.Number(label="Page", scale=1)
        with gr.Row():
            msg = gr.Textbox(placeholder="Type message", container=True)
        with gr.Row():
            submit_btn = gr.Button("Submit")
            clear_btn = gr.ClearButton([msg, chatbot])
        # Build the vector database and default QA chain when documents are
        # uploaded. gr.Files fires an `upload` event (there is no `uploaded`
        # event); plain Python values such as the model name must be wrapped in
        # gr.State to be passed as inputs, and the returned chain is stored in
        # the qa_chain state.
        document.upload(
            initialize_LLM,
            inputs=[gr.State(chosen_llm_model), document],
            outputs=[qa_chain],
        )
        # Chatbot events. `conversation` is assumed to return the updated chain,
        # a cleared message box, the chat history, and three source references
        # with page numbers (see the sketch above); without an `outputs` list,
        # Gradio discards the handler's return values.
        chat_outputs = [qa_chain, msg, chatbot,
                        doc_source1, source1_page,
                        doc_source2, source2_page,
                        doc_source3, source3_page]
        msg.submit(conversation, inputs=[qa_chain, msg, chatbot], outputs=chat_outputs)
        submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=chat_outputs)
        clear_btn.click(
            lambda: [None, "", 0, "", 0, "", 0],
            inputs=None,
            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
        )
    demo.launch(debug=True)
if __name__ == "__main__":
    demo()