import gradio as gr
from huggingface_hub import InferenceClient
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import HuggingFaceHub
from unstructured.partition.pdf import partition_pdf
import camelot
from pathlib import Path

# Inference client for the hosted Zephyr model (not consumed by the LangChain
# chain below) and the embedding model used to build the FAISS index
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# Module-level state: the FAISS index and its retriever, populated on upload
vector_store = None
retriever = None

def extract_text_from_pdf(filepath):
    # Partition the PDF into elements with unstructured and join their text
    elements = partition_pdf(filename=filepath)
    return "\n".join(el.text for el in elements if el.text)

def extract_tables_from_pdf(filepath):
    # Use camelot to read tables from the PDF
    tables = camelot.read_pdf(filepath, pages='1-end')
    return [table.df.to_string(index=False) for table in tables]
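
# Camelot defaults to its "lattice" parser, which needs ruled table borders.
# For whitespace-separated tables, the "stream" flavor may work better, e.g.:
#
#   camelot.read_pdf(filepath, pages="1-end", flavor="stream")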

def update_documents(text_input):
    global vector_store, retriever
    # Treat each non-empty line as one document; skip blanks so FAISS does not
    # index empty strings
    documents = [line for line in text_input.split("\n") if line.strip()]
    vector_store = FAISS.from_texts(documents, embeddings)
    retriever = vector_store.as_retriever()
    return f"{len(documents)} documents successfully added to the vector store."

rag_chain = None

def respond(message, history, system_message, max_tokens, temperature, top_p):
    global rag_chain, retriever

    if retriever is None:
        return "Please upload or enter documents before asking a question."

    if rag_chain is None:
        # HuggingFaceHub reads the HUGGINGFACEHUB_API_TOKEN environment variable
        rag_chain = ConversationalRetrievalChain.from_llm(
            HuggingFaceHub(repo_id="HuggingFaceH4/zephyr-7b-beta"),
            retriever=retriever,
        )

    # ConversationalRetrievalChain expects chat_history as (user, assistant)
    # pairs; Gradio's tuple-style history matches directly once incomplete
    # turns are skipped. Note: system_message, max_tokens, temperature, and
    # top_p are accepted to match the ChatInterface inputs but are not
    # consumed by the chain.
    chat_history = [(user, assistant) for user, assistant in history if user and assistant]

    response = rag_chain({"question": message, "chat_history": chat_history})
    return response["answer"]
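
# Sketch (an assumption, not wired into the chain): the sampling controls could
# be honored by calling the InferenceClient directly instead of HuggingFaceHub:
#
#   messages = [{"role": "system", "content": system_message},
#               {"role": "user", "content": message}]
#   out = client.chat_completion(messages, max_tokens=max_tokens,
#                                temperature=temperature, top_p=top_p)
#   answer = out.choices[0].message.content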

def upload_file(filepath):
    text = extract_text_from_pdf(filepath)
    tables = extract_tables_from_pdf(filepath)

    # Index the extracted text in the vector store
    update_documents(text)

    return [
        gr.UploadButton(visible=False),
        gr.DownloadButton(label=f"Download {Path(filepath).name}", value=filepath, visible=True),
        f"{len(tables)} tables extracted.",
    ]
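
# The extracted tables are currently only counted, not indexed. A minimal
# sketch of also making them retrievable (hypothetical extension):
#
#   update_documents(text + "\n" + "\n".join(tables))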

# Gradio interface setup
demo = gr.Blocks()

with demo:
    with gr.Row():
        u = gr.UploadButton("Upload a file", file_count="single")
        d = gr.DownloadButton("Download the file", visible=False)
    status = gr.Textbox(label="Status", interactive=False)

    # Event outputs must be components, not strings; route the third return
    # value of upload_file into the status textbox
    u.upload(upload_file, u, [u, d, status])
    
    with gr.Row():
        chat = gr.ChatInterface(
            respond,
            additional_inputs=[
                gr.Textbox(value="You are a helpful assistant.", label="System message"),
                gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
                gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
                gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
            ],
        )

if __name__ == "__main__":
    demo.launch()