import gradio as gr
import os
from pathlib import Path

import accelerate
import chromadb
import torch
import tqdm
import transformers
from transformers import AutoTokenizer

# Integration packages live in langchain_community since LangChain 0.1
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline, HuggingFaceHub
# Chains, splitters, and memory still come from the core langchain package
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import ConversationalRetrievalChain, ConversationChain
from langchain.memory import ConversationBufferMemory

# LLM model and generation parameters
chosen_llm_model = "mistralai/Mistral-7B-Instruct-v0.2"
llm_temperature = 0.7
max_tokens = 1024
top_k = 3

# Text-splitting parameters: chunk size and overlap (in characters)
chunk_size = 600
chunk_overlap = 40

def initialize_database():
    """
    Initialize the vector database (assumed to be ChromaDB).
    Modify this function based on your specific database needs.
    """
    # Replace with your ChromaDB connection and schema creation logic
    # ...
    pass


# Initialize the vector database up front, once the function is defined
initialize_database()
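
# A minimal sketch of what building the vector store could look like, assuming
# PDFs are loaded with PyPDFLoader, split with RecursiveCharacterTextSplitter,
# and embedded into a Chroma collection. The function name, default collection
# name, and default embedding model are illustrative, not from the original app.
def create_db_from_pdfs(pdf_paths, collection_name="pdf-chat"):
    pages = []
    for path in pdf_paths:
        # Each PDF is loaded page by page, keeping page numbers in metadata
        pages.extend(PyPDFLoader(path).load())
    # Split pages into overlapping chunks using the globals defined above
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    chunks = splitter.split_documents(pages)
    # Embed the chunks (HuggingFaceEmbeddings defaults to a sentence-transformers model)
    return Chroma.from_documents(
        documents=chunks,
        embedding=HuggingFaceEmbeddings(),
        collection_name=collection_name,
    )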

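# The upload handler below references initialize_LLM(), which is not defined in
# this file. A minimal sketch, assuming the free HuggingFaceHub inference
# endpoint is used and that uploaded file objects expose a local path; the
# original app's chain construction may differ.
def initialize_LLM(file_objs):
    # gr.Files may yield plain paths or tempfile-like objects depending on version
    paths = [getattr(f, "name", f) for f in file_objs]
    vector_db = create_db_from_pdfs(paths)
    llm = HuggingFaceHub(
        repo_id=chosen_llm_model,
        model_kwargs={
            "temperature": llm_temperature,
            "max_new_tokens": max_tokens,
            "top_k": top_k,
        },
    )
    # Conversational memory so follow-up questions keep context
    memory = ConversationBufferMemory(
        memory_key="chat_history", output_key="answer", return_messages=True
    )
    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vector_db.as_retriever(),
        memory=memory,
        return_source_documents=True,
    )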

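# conversation() is also referenced but not defined here. A sketch of the
# handler shape the events below expect: it takes the chain state, the user
# message, and the chat history, and returns the updated state plus the three
# document references shown in the accordion.
def conversation(qa_chain, message, history):
    # The chain's memory supplies past turns, so only the question is passed
    response = qa_chain.invoke({"question": message})
    answer = response["answer"]
    sources = response.get("source_documents", [])
    # Pad to exactly three references so every output component gets a value
    refs = [(doc.page_content, doc.metadata.get("page", 0)) for doc in sources[:3]]
    while len(refs) < 3:
        refs.append(("", 0))
    history = history + [(message, answer)]
    return (
        qa_chain, "", history,
        refs[0][0], refs[0][1],
        refs[1][0], refs[1][1],
        refs[2][0], refs[2][1],
    )
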
def demo():
    with gr.Blocks(theme="base") as demo:
        qa_chain = gr.State()  # Store the initialized QA chain
        collection_name = gr.State()

        gr.Markdown(
            """
            <center><h2>PDF-based chatbot (powered by LangChain and open-source LLMs)</h2></center>
            <h3>Ask any questions about your PDF documents, along with follow-ups</h3>
            <b>Note:</b> This AI assistant performs retrieval-augmented generation from your PDF documents. \
            When generating answers, it takes past questions into account (conversational memory) and cites document references for transparency.
            <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps, and the LLMs used below (free inference endpoints), can take some time to generate an output.<br>
            """
        )

        with gr.Row():
            document = gr.Files(
                height=100,
                file_count="multiple",
                file_types=["pdf"],
                interactive=True,
                label="Upload your PDF documents (single or multiple)",
            )

        with gr.Row():
            chatbot = gr.Chatbot(height=300)

        with gr.Accordion("Advanced - Document references", open=False):
            with gr.Row():
                doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
                source1_page = gr.Number(label="Page", scale=1)
            with gr.Row():
                doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
                source2_page = gr.Number(label="Page", scale=1)
            with gr.Row():
                doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
                source3_page = gr.Number(label="Page", scale=1)

        with gr.Row():
            msg = gr.Textbox(placeholder="Type message", container=True)

        with gr.Row():
            submit_btn = gr.Button("Submit")
            clear_btn = gr.ClearButton([msg, chatbot])

        # Build the QA chain when documents are uploaded, storing it in state
        document.upload(initialize_LLM, inputs=[document], outputs=[qa_chain])

        # Chatbot events: run the QA chain, then update the chat and references
        chat_outputs = [qa_chain, msg, chatbot, doc_source1, source1_page,
                        doc_source2, source2_page, doc_source3, source3_page]
        msg.submit(conversation, inputs=[qa_chain, msg, chatbot], outputs=chat_outputs)
        submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=chat_outputs)
        clear_btn.click(lambda: [None, "", 0, "", 0, "", 0], inputs=None,
                        outputs=[chatbot, doc_source1, source1_page, doc_source2,
                                 source2_page, doc_source3, source3_page])
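
    # Launch the app; queue() is assumed here so slow free-tier LLM calls don't
    # time out, a common pattern in Hugging Face Spaces demos
    demo.queue().launch()


if __name__ == "__main__":
    demo()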