File size: 3,951 Bytes
b87f806
 
902fe41
 
 
 
 
b87f806
902fe41
b87f806
 
902fe41
 
 
64e4082
 
 
902fe41
64e4082
 
 
 
 
 
 
 
 
 
 
902fe41
64e4082
 
b87f806
 
 
 
 
 
 
 
 
64e4082
 
 
 
 
 
 
 
 
 
 
 
902fe41
 
 
b87f806
 
902fe41
b87f806
902fe41
b87f806
902fe41
b87f806
64e4082
902fe41
b87f806
64e4082
902fe41
b87f806
b6d84ef
 
 
64e4082
2206df7
 
 
902fe41
64e4082
 
 
9625bf1
 
 
 
 
b6d84ef
 
 
3e39bdd
 
 
 
b6d84ef
64e4082
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b87f806
 
902fe41
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from pathlib import Path

import gradio as gr
from huggingface_hub import InferenceClient
from langchain.chains import ConversationalRetrievalChain, RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.vectorstores import FAISS

# Inference client for the chat model (used as the hosted-model handle;
# the RAG chain below builds its own HuggingFaceHub LLM for the same repo).
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

# Sentence-transformer embeddings used to index documents for retrieval.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# Populated by update_documents() once the user supplies text; respond()
# refuses to answer while retriever is still None.
vector_store = None
retriever = None

def update_documents(text_input):
    """(Re)build the FAISS vector store from newline-separated documents.

    Args:
        text_input: Raw textbox contents; each non-blank line is treated as
            one document.

    Returns:
        A human-readable status string for the UI.
    """
    global vector_store, retriever
    # Split on newlines and drop blank/whitespace-only lines — otherwise empty
    # strings get indexed as documents, and an empty textbox would make
    # FAISS.from_texts raise on an empty list.
    documents = [line.strip() for line in text_input.split("\n") if line.strip()]

    if not documents:
        return "No documents found in the input; vector store not updated."

    # Rebuild the vector store from scratch with the new documents.
    vector_store = FAISS.from_texts(documents, embeddings)

    # Point the retriever at the freshly built store.
    retriever = vector_store.as_retriever()
    return f"{len(documents)} documents successfully added to the vector store."

# Lazily-created ConversationalRetrievalChain; built on the first call to
# respond() once a retriever exists.
rag_chain = None

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Answer a chat message using the RAG chain over the indexed documents.

    Args:
        message: The user's current question.
        history: Gradio chat history as (user, assistant) tuples.
        system_message: UI-provided system prompt. NOTE(review): not consumed
            by ConversationalRetrievalChain here — kept for interface
            compatibility with the ChatInterface additional_inputs.
        max_tokens, temperature, top_p: Sampling controls from the UI.
            NOTE(review): likewise not currently forwarded to the LLM.

    Returns:
        The chain's answer string, or an instruction to add documents first.
    """
    global rag_chain, retriever

    if retriever is None:
        return "Please upload or enter documents before asking a question."

    # (Re)create the chain when it does not exist yet, or when the retriever
    # has been swapped out by a later update_documents() call — the previous
    # code cached the chain forever, so it kept querying the stale retriever.
    if rag_chain is None or rag_chain.retriever is not retriever:
        rag_chain = ConversationalRetrievalChain.from_llm(
            HuggingFaceHub(repo_id="HuggingFaceH4/zephyr-7b-beta"),
            retriever=retriever,
        )

    # ConversationalRetrievalChain expects complete (human, ai) pairs; drop
    # any turn where either side is empty/None.
    chat_history = [(user, bot) for user, bot in history if user and bot]

    # Retrieve documents and generate the response.
    response = rag_chain({"question": message, "chat_history": chat_history})

    return response["answer"]

def upload_file(filepath):
    """Swap the upload button for a download button once a file is uploaded.

    Args:
        filepath: Path to the uploaded temp file, as provided by gr.UploadButton.

    Returns:
        Updated [UploadButton, DownloadButton] components: hide upload, show
        a download button labelled with the file's name.
    """
    # Path comes from the top-of-file pathlib import; the original code used
    # it without importing it, which raised NameError on the first upload.
    name = Path(filepath).name
    return [gr.UploadButton(visible=False), gr.DownloadButton(label=f"Download {name}", value=filepath, visible=True)]

def download_file():
    """After a download, restore the upload button and hide the download button."""
    show_upload = gr.UploadButton(visible=True)
    hide_download = gr.DownloadButton(visible=False)
    return [show_upload, hide_download]

# Gradio interface setup
demo = gr.Blocks()

with demo:
    with gr.Row():
        # Input box for the user to paste documents, one per line.
        doc_input = gr.Textbox(
            lines=10, placeholder="Enter your documents here, one per line.", label="Input Documents"
        )
        # Button that indexes the pasted documents. The original code had this
        # commented out but still referenced `upload_button` below, which
        # raised NameError at import time and prevented the app from starting.
        upload_button = gr.Button("Upload Documents")
        with gr.Row():
            u = gr.UploadButton("Upload a file", file_count="single")
            d = gr.DownloadButton("Download the file", visible=False)

    # Toggle between upload/download buttons as files move through the UI.
    u.upload(upload_file, u, [u, d])
    d.click(download_file, None, [u, d])

    # Status output for indexing results — created inside the Blocks context
    # so it is actually rendered (the original built it inline in the
    # .click() call, outside the layout).
    status_box = gr.Textbox(label="Status")

    with gr.Row():
        # Chat interface backed by the RAG respond() callback.
        chat = gr.ChatInterface(
            respond,
            additional_inputs=[
                gr.Textbox(value="You are a helpful assistant.", label="System message"),
                gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
                gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
                gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
            ],
        )

    # Bind the button to (re)build the document vector store.
    upload_button.click(update_documents, inputs=[doc_input], outputs=status_box)

if __name__ == "__main__":
    demo.launch()