File size: 3,670 Bytes
3ec9224
5be8df6
 
 
 
 
d4b9831
af00b58
 
 
1ef8d7c
 
af00b58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ef8d7c
af00b58
 
5be8df6
 
 
af00b58
1ef8d7c
af00b58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4b9831
af00b58
 
 
 
 
 
 
 
 
 
 
 
 
d4b9831
af00b58
 
d4b9831
 
 
af00b58
 
 
 
 
 
 
 
 
7da7eb1
af00b58
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import gradio as gr
import os
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline, HuggingFaceHub
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from pathlib import Path
import chromadb
from transformers import AutoTokenizer
import transformers
import torch
import tqdm
import accelerate

# Default LLM model (Hugging Face Hub repository id).
chosen_llm_model = "mistralai/Mistral-7B-Instruct-v0.2"

# Default chunk size and overlap used when splitting documents into chunks.
chunk_size = 600
chunk_overlap = 40

# Default model configuration.
llm_temperature = 0.7
max_tokens = 1024
top_k = 3

# NOTE(review): this module previously ran `accelerated(initialize_database)()`
# here. `accelerated` is never imported or defined (only the `accelerate`
# module is imported), and `initialize_database` is not defined before this
# point, so importing the module always raised NameError. Build the vector
# database from the user's uploaded documents instead (e.g. from the
# document-upload handler), once such a helper exists.

# NOTE(review): the helper functions used by `demo` (`initialize_LLM`,
# `conversation`) are not defined in this file; they must be added here,
# at module level, before the app can run.

def demo():
    """Build the PDF chatbot Gradio UI and launch it.

    The page offers a multi-file PDF upload, a chat window with a message box,
    Submit and Clear buttons, and an "Advanced" accordion showing up to three
    document references (text plus page number).

    Relies on the module-level helpers ``initialize_LLM`` and ``conversation``
    and on the default ``chosen_llm_model`` setting. Launches the app with
    ``debug=True``.
    """
    # Number of (reference text, page) pairs shown in the accordion.
    num_references = 3

    # Named `app` rather than `demo` so it does not shadow this function.
    with gr.Blocks(theme="base") as app:
        qa_chain = gr.State()  # Holds the QA chain produced by initialize_LLM.
        # Event inputs must be Gradio components, not plain Python values, so
        # the default model id is wrapped in a State.
        llm_model = gr.State(chosen_llm_model)

        gr.Markdown(
            """
            <center><h2>PDF-based chatbot (powered by LangChain and open-source LLMs)</h2></center>
            <h3>Ask any questions about your PDF documents, along with follow-ups</h3>
            <b>Note:</b> This AI assistant performs retrieval-augmented generation from your PDF documents. \
            When generating answers, it takes past questions into account (via conversational memory), and includes document references for clarity purposes.
            <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate an output.<br>
            """
        )

        with gr.Row():
            document = gr.Files(
                height=100,
                file_count="multiple",
                # Gradio expects extensions with a leading dot.
                file_types=[".pdf"],
                interactive=True,
                label="Upload your PDF documents (single or multiple)",
            )

        with gr.Row():
            chatbot = gr.Chatbot(height=300)

        with gr.Accordion("Advanced - Document references", open=False):
            # Flat list ordered as [text 1, page 1, text 2, page 2, ...].
            reference_outputs = []
            for ref_idx in range(1, num_references + 1):
                with gr.Row():
                    reference_outputs.append(
                        gr.Textbox(
                            label=f"Reference {ref_idx}",
                            lines=2,
                            container=True,
                            scale=20,
                        )
                    )
                    reference_outputs.append(gr.Number(label="Page", scale=1))

        with gr.Row():
            msg = gr.Textbox(placeholder="Type message", container=True)

        with gr.Row():
            submit_btn = gr.Button("Submit")
            clear_btn = gr.ClearButton([msg, chatbot])

        # Initialize the default QA chain when documents are uploaded, and keep
        # it in `qa_chain` so the chat handlers receive it.
        # NOTE(review): assumes initialize_LLM takes just the model id and
        # returns only the QA chain; adjust inputs/outputs to its real signature.
        document.upload(initialize_LLM, inputs=[llm_model], outputs=[qa_chain])

        # Chatbot events: the Enter key and the Submit button share one handler.
        # NOTE(review): no outputs are wired, so Gradio discards whatever
        # `conversation` returns and the chat window is never updated; add
        # outputs (e.g. chatbot and reference_outputs) matching its return values.
        msg.submit(conversation, inputs=[qa_chain, msg, chatbot])
        submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot])

        # Reset the chat history and every reference text/page pair.
        clear_btn.click(
            lambda: [None] + ["", 0] * num_references,
            inputs=None,
            outputs=[chatbot, *reference_outputs],
        )

    app.launch(debug=True)


if __name__ == "__main__":
    demo()