import gradio as gr
import os

from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceHub

from pathlib import Path
import chromadb
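
# NOTE: HuggingFaceHub authenticates via the HUGGINGFACEHUB_API_TOKEN
# environment variable; it must be set before any of the models below
# can be initialized.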

# List of available LLM models
list_llm = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.1",
            "google/gemma-7b-it", "google/gemma-2b-it",
            "HuggingFaceH4/zephyr-7b-beta", "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2",
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct", "tiiuae/falcon-7b-instruct",
            "google/flan-t5-xxl"
]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

# Load PDF document and create doc splits
def load_doc(list_file_path, chunk_size, chunk_overlap):
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits
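
# Example (hypothetical paths): split two PDFs into ~600-character chunks with
# a 40-character overlap, matching the UI defaults below:
#   splits = load_doc(["paper1.pdf", "paper2.pdf"], chunk_size=600, chunk_overlap=40)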

# Create vector database
def create_db(splits, collection_name):
    embedding = HuggingFaceEmbeddings()
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name
    )
    return vectordb
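
# Note: chromadb.EphemeralClient() keeps the collection in memory only, so the
# database is rebuilt on every upload. Example (hypothetical name):
#   vector_db = create_db(splits, collection_name="my-paper")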

# Initialize langchain LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
        model_kwargs = {"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
    elif llm_model == "microsoft/phi-2":
        raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
    elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
        model_kwargs = {"temperature": temperature, "max_new_tokens": 250, "top_k": top_k}
    else:
        model_kwargs = {"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
    
    llm = HuggingFaceHub(
        repo_id=llm_model,
        model_kwargs=model_kwargs
    )

    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True
    )

    retriever = vector_db.as_retriever()
    
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False
    )

    progress(0.9, desc="Done!")
    return qa_chain
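
# Usage sketch: with the memory attached above, the chain only needs the
# question; "chat_history" is supplied automatically from memory:
#   result = qa_chain({"question": "What is this document about?"})
#   print(result["answer"], result["source_documents"][0].metadata)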

# Build the vector database and a default QA chain from the uploaded PDFs
def initialize_demo(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
    list_file_path = [file.name for file in list_file_obj if file is not None]
    collection_name = Path(list_file_path[0]).stem.replace(" ", "-")[:50]
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
    vector_db = create_db(doc_splits, collection_name)
    qa_chain = initialize_llmchain(
        list_llm[0],  # Using Mistral-7B-Instruct-v0.2 as the LLM model
        0.7,  # Temperature
        1024,  # Max Tokens
        3,  # Top K
        vector_db,
        progress  # status callback; passing the db_progress textbox value here would crash
    )
    return vector_db, collection_name, qa_chain, "Complete!"

def upload_file(file_obj):
    list_file_path = []
    for file in file_obj:
        if file is not None:
            file_path = file.name
            list_file_path.append(file_path)
    return list_file_path
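
# The source file is truncated before any chat handler is defined, so the
# following is a minimal sketch of one. It assumes the standard
# ConversationalRetrievalChain output keys ("answer", "source_documents")
# and PyPDFLoader's "page" metadata; adapt as needed.
def conversation(qa_chain, message, history):
    # The attached ConversationBufferMemory supplies "chat_history" automatically
    response = qa_chain({"question": message})
    answer = response["answer"]
    # Keep up to three retrieved chunks as citations for the UI
    sources = response["source_documents"][:3]
    sources += [None] * (3 - len(sources))
    contents = [s.page_content.strip() if s else "" for s in sources]
    # PyPDFLoader pages are 0-indexed; display them 1-indexed
    pages = [s.metadata.get("page", 0) + 1 if s else 0 for s in sources]
    new_history = history + [(message, answer)]
    # Return order matches the outputs list wired to submit_btn below
    return (qa_chain, "", new_history,
            contents[0], pages[0], contents[1], pages[1], contents[2], pages[2])
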

def demo():
    with gr.Blocks(theme="base") as demo:
        vector_db = gr.State()
        collection_name = gr.State()
        qa_chain = gr.State()
        
        with gr.Tab("Step 1 - Document pre-processing"):
            document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
            slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
            slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
            db_progress = gr.Textbox(label="Vector database initialization", value="None")
            db_btn = gr.Button("Generate vector database...")

        with gr.Tab("Step 2 - QA chain initialization"):
            llm_progress = gr.Textbox(value="None", label="QA chain initialization")
            qachain_btn = gr.Button("Initialize question-answering chain...")

        with gr.Tab("Step 3 - Conversation with chatbot"):
            chatbot = gr.Chatbot(height=300)
            doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
            source1_page = gr.Number(label="Page", scale=1)
            doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
            source2_page = gr.Number(label="Page", scale=1)
            doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
            source3_page = gr.Number(label="Page", scale=1)
            msg = gr.Textbox(placeholder="Type message", container=True)
            submit_btn = gr.Button("Submit")
            clear_btn = gr.ClearButton([msg, chatbot])

        # Step 1: build the vector DB (and a default QA chain) on upload or button click
        document.upload(initialize_demo, inputs=[document, slider_chunk_size, slider_chunk_overlap], outputs=[vector_db, collection_name, qa_chain, db_progress])
        db_btn.click(initialize_demo, inputs=[document, slider_chunk_size, slider_chunk_overlap], outputs=[vector_db, collection_name, qa_chain, db_progress])
        # Step 2: re-initialize the QA chain from the stored vector DB with default settings
        qachain_btn.click(
            lambda db: (initialize_llmchain(list_llm[0], 0.7, 1024, 3, db), "Complete!"),
            inputs=[vector_db],
            outputs=[qa_chain, llm_progress],
        )
        # Step 3: run the (sketched) conversation handler; its return order
        # matches this outputs list
        submit_btn.click(conversation,
                         inputs=[qa_chain, msg, chatbot],
                         outputs=[qa_chain, msg, chatbot,
                                  doc_source1, source1_page,
                                  doc_source2, source2_page,
                                  doc_source3, source3_page])
    demo.queue().launch()

if __name__ == "__main__":
    demo()