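# Gradio app: chat with an uploaded PDF via retrieval-augmented generation
# (LangChain + Chroma vector store + a Hugging Face inference endpoint).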
import gradio as gr
import os
import re
from pathlib import Path

import chromadb
from unidecode import unidecode
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint

# Read the Hugging Face token from the environment (e.g. a Space secret);
# only set it explicitly when present, to avoid a TypeError on None.
api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if api_token:
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_token

# Default LLM served through the HF inference endpoint.
list_llm = ["mistralai/Mistral-7B-Instruct-v0.2"]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]
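
# Load the PDF pages and split them into overlapping chunks for retrieval.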
def load_doc(list_file_path, chunk_size=600, chunk_overlap=40):
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits
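
# Embed the chunks and index them in an in-memory (ephemeral) Chroma collection.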
def create_db(splits, collection_name):
    embedding = HuggingFaceEmbeddings()
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
    )
    return vectordb
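
# Build the conversational RAG chain around a Hugging Face inference endpoint.
# Note: top_k here is the endpoint's token-sampling parameter, not the number
# of retrieved chunks (the retriever uses its own default k).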
def initialize_llmchain(llm_model, vector_db, progress=gr.Progress()):
    progress(0.1, desc="Initializing HF Hub...")
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        temperature=0.7,
        max_new_tokens=1024,
        top_k=3,
    )
    progress(0.75, desc="Defining buffer memory...")
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )
    retriever = vector_db.as_retriever()
    progress(0.8, desc="Defining retrieval chain...")
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    progress(0.9, desc="Done!")
    return qa_chain
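
# Derive a Chroma-compatible collection name from the file name
# (3-63 ASCII characters, starting and ending with an alphanumeric).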
def create_collection_name(filepath):
    collection_name = Path(filepath).stem
    collection_name = collection_name.replace(" ", "-")
    collection_name = unidecode(collection_name)
    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
    collection_name = collection_name[:50]
    if len(collection_name) < 3:
        collection_name = collection_name + 'xyz'
    if not collection_name[0].isalnum():
        collection_name = 'A' + collection_name[1:]
    if not collection_name[-1].isalnum():
        collection_name = collection_name[:-1] + 'Z'
    return collection_name
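
# Full pipeline for an uploaded file: name the collection, chunk the PDF,
# build the vector store, then wire up the QA chain.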
def initialize_all(file_obj, progress=gr.Progress()):
    file_path = [file_obj.name]
    progress(0.1, desc="Creating collection name...")
    collection_name = create_collection_name(file_path[0])
    progress(0.25, desc="Loading document...")
    doc_splits = load_doc(file_path)
    progress(0.5, desc="Generating vector database...")
    vector_db = create_db(doc_splits, collection_name)
    progress(0.75, desc="Initializing LLM...")
    qa_chain = initialize_llmchain(list_llm[0], vector_db, progress)
    if qa_chain is None:
        raise gr.Error("Failed to initialize QA chain. Please check the configuration.")
    progress(1.0, desc="Initialization complete!")
    return qa_chain, "Initialization complete!"
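
# Flatten the Gradio (user, bot) history into plain "User:/Assistant:" lines.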
def format_chat_history(message, chat_history):
    # `message` is accepted for call-site symmetry but is not part of the
    # formatted history (the chain receives the current question separately).
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history
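
# Handle one chat turn: query the chain, clean the answer, and surface the
# top source chunk with its (1-indexed) page number.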
def conversation(qa_chain, message, history):
    if qa_chain is None:
        raise gr.Error("QA chain is not initialized. Please upload a PDF document first.")
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    # Some instruct models echo the prompt scaffold; keep only the final answer.
    if "Helpful Answer:" in response_answer:
        response_answer = response_answer.split("Helpful Answer:")[-1]
    response_sources = response["source_documents"]
    if response_sources:
        response_source1 = response_sources[0].page_content.strip()
        # PyPDFLoader page metadata is 0-indexed; convert for display.
        response_source1_page = response_sources[0].metadata.get("page", -1) + 1
    else:
        response_source1, response_source1_page = "", 0
    # Append the new exchange instead of overwriting the whole history.
    new_history = history + [(message, response_answer)]
    return gr.update(value=""), new_history, response_source1, response_source1_page
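
# Assemble the Gradio UI and wire the event handlers.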
def demo():
    with gr.Blocks(theme="base") as demo:
        qa_chain = gr.State()
        gr.Markdown(
            """<center><h2>PDF-based chatbot</h2></center>
            <h3>Ask any questions about your PDF document</h3>"""
        )
        document = gr.File(height=100, file_types=["pdf"], label="Upload your PDF document")
        init_status = gr.Textbox(label="Initialization status", interactive=False)
        chatbot = gr.Chatbot(height=300)
        with gr.Accordion("Advanced - Document references", open=False):
            with gr.Row():
                doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
                source1_page = gr.Number(label="Page", scale=1)
        msg = gr.Textbox(placeholder="Type message (e.g. 'What is this document about?')", container=True)
        submit_btn = gr.Button("Submit message")
        clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
        # Build the QA chain as soon as a document is uploaded; report status
        # in a dedicated textbox rather than an anonymous, unlabeled one.
        document.upload(initialize_all, inputs=document, outputs=[qa_chain, init_status])
        msg.submit(conversation, inputs=[qa_chain, msg, chatbot], outputs=[msg, chatbot, doc_source1, source1_page])
        submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[msg, chatbot, doc_source1, source1_page])
        clear_btn.click(lambda: [None, "", 0], inputs=None, outputs=[chatbot, doc_source1, source1_page])
    demo.queue().launch(debug=True)


if __name__ == "__main__":
    demo()