import gradio as gr
import os

from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory


from pathlib import Path
import chromadb
from unidecode import unidecode

import re

# HuggingFaceEndpoint reads HUGGINGFACEHUB_API_TOKEN from the environment;
# fail early with a clear message instead of assigning None into os.environ.
if not os.getenv("HUGGINGFACEHUB_API_TOKEN"):
    raise RuntimeError("HUGGINGFACEHUB_API_TOKEN environment variable is not set.")
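# LLM(s) served via the Hugging Face Inference API; short display names are derived below.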
list_llm = ["mistralai/Mistral-7B-Instruct-v0.2"]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

def load_doc(list_file_path, chunk_size=600, chunk_overlap=40):
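    """Load the given PDFs and split their pages into overlapping text chunks."""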
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap)
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits

def create_db(splits, collection_name):
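    """Embed the chunks and index them in an ephemeral (in-memory) Chroma collection."""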
    embedding = HuggingFaceEmbeddings()
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
    )
    return vectordb

def initialize_llmchain(llm_model, vector_db, progress=gr.Progress()):
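    """Build a ConversationalRetrievalChain around a Hugging Face Inference endpoint with buffer memory."""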
    progress(0.1, desc="Initializing HF Hub...")
    llm = HuggingFaceEndpoint(
        repo_id=llm_model, 
        temperature=0.7,
        max_new_tokens=1024,
        top_k=3,
    )
    
    progress(0.75, desc="Defining buffer memory...")
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True
    )
    retriever = vector_db.as_retriever()
    progress(0.8, desc="Defining retrieval chain...")
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff", 
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    progress(0.9, desc="Done!")
    return qa_chain

def create_collection_name(filepath):
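    """Sanitize the file name into a valid Chroma collection name (3-50 ASCII chars, alphanumeric at both ends)."""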
    collection_name = Path(filepath).stem
    collection_name = collection_name.replace(" ","-") 
    collection_name = unidecode(collection_name)
    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
    collection_name = collection_name[:50]
    if len(collection_name) < 3:
        collection_name = collection_name + 'xyz'
    if not collection_name[0].isalnum():
        collection_name = 'A' + collection_name[1:]
    if not collection_name[-1].isalnum():
        collection_name = collection_name[:-1] + 'Z'
    return collection_name

def initialize_all(file_obj, progress=gr.Progress()):
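    """Full pipeline for an uploaded PDF: split, embed, index, then build the QA chain."""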
    file_path = [file_obj.name]
    progress(0.1, desc="Creating collection name...")
    collection_name = create_collection_name(file_path[0])
    progress(0.25, desc="Loading document...")
    doc_splits = load_doc(file_path)
    progress(0.5, desc="Generating vector database...")
    vector_db = create_db(doc_splits, collection_name)
    progress(0.75, desc="Initializing LLM...")
    qa_chain = initialize_llmchain(list_llm[0], vector_db, progress)
    if qa_chain is None:
        raise gr.Error("Failed to initialize QA chain. Please check the configuration.")
    progress(1.0, desc="Initialization complete!")
    return qa_chain, "Initialization complete!"

def format_chat_history(chat_history):
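    """Flatten Gradio's (user, bot) history pairs into the plain strings the chain expects."""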
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history

def conversation(qa_chain, message, history):
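    """Answer a user message with the QA chain and expose the top source passage and its page."""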
    if qa_chain is None:
        return "QA chain is not initialized. Please upload the PDF and initialize again.", history, "", ""
    formatted_chat_history = format_chat_history(history)
    response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    if response_answer.find("Helpful Answer:") != -1:
        response_answer = response_answer.split("Helpful Answer:")[-1]
    response_sources = response["source_documents"]
    response_source1 = response_sources[0].page_content.strip() if response_sources else ""
    # PyPDFLoader page metadata is zero-based, so display it one-based
    response_source1_page = response_sources[0].metadata["page"] + 1 if response_sources and "page" in response_sources[0].metadata else "N/A"
    # Append the new exchange instead of overwriting the conversation history
    return gr.update(value=""), history + [(message, response_answer)], response_source1, response_source1_page

def demo():
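    """Assemble the Gradio Blocks UI and wire up the event handlers."""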
    with gr.Blocks(theme="base") as demo:
        qa_chain = gr.State()
        
        gr.Markdown(
        """<center><h2>PDF-based chatbot</center></h2>
        <h3>Ask any questions about your PDF document</h3>""")
        
        document = gr.File(height=100, file_types=["pdf"], label="Upload your PDF document")
        init_status = gr.Textbox(label="Initialization status", interactive=False)
        chatbot = gr.Chatbot(height=300)
        
        with gr.Accordion("Advanced - Document references", open=False):
            with gr.Row():
                doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
                source1_page = gr.Number(label="Page", scale=1)
                
        msg = gr.Textbox(placeholder="Type message (e.g. 'What is this document about?')", container=True)
        submit_btn = gr.Button("Submit message")
        clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
            
        document.upload(initialize_all, inputs=document, outputs=[qa_chain, init_status])
        msg.submit(conversation, inputs=[qa_chain, msg, chatbot], outputs=[msg, chatbot, doc_source1, source1_page])
        submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[msg, chatbot, doc_source1, source1_page])
        clear_btn.click(lambda: [None, "", 0], inputs=None, outputs=[chatbot, doc_source1, source1_page])

    demo.queue().launch(debug=True)

if __name__ == "__main__":
    demo()