File size: 4,854 Bytes
b87f806
 
902fe41
bc3c4e5
 
902fe41
 
af50095
27a6371
 
b87f806
902fe41
b87f806
902fe41
 
64e4082
 
902fe41
af514b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0e9961
 
 
af514b7
 
 
 
af50095
85ac5f6
af514b7
 
 
27a6371
af514b7
 
 
 
 
27a6371
 
b3ae10a
64e4082
b3ae10a
64e4082
 
 
902fe41
b3ae10a
 
27a6371
64e4082
 
 
 
 
 
 
 
 
 
 
902fe41
 
b87f806
 
902fe41
b87f806
902fe41
b87f806
902fe41
b87f806
902fe41
 
b87f806
af514b7
 
 
27a6371
 
1000a42
27a6371
cdad274
 
 
2206df7
902fe41
64e4082
 
 
9625bf1
27a6371
 
af514b7
27a6371
cdad274
 
 
 
af514b7
b3ae10a
64e4082
 
 
 
 
 
 
 
 
 
 
b87f806
27a6371
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import gradio as gr
from huggingface_hub import InferenceClient
from langchain.chains import RetrievalQA
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.llms import HuggingFaceHub
from langchain.chains import ConversationalRetrievalChain
from langchain_unstructured import UnstructuredLoader
import camelot
from pathlib import Path

# Load the HuggingFace language model and embeddings
# NOTE(review): `client` is constructed here but never used below —
# responses go through HuggingFaceHub inside `respond` instead; confirm
# whether this InferenceClient is still needed.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
# Sentence-transformer embeddings used to index uploaded documents in FAISS.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# Module-level RAG state, populated by `update_documents` on upload and
# read by `respond`. None until the first document is ingested.
vector_store = None
retriever = None

def parse_page_input(page_input):
    """Parse a page-selection string into a sorted list of page numbers.

    Accepts comma-separated entries that are either single pages ("3") or
    inclusive ranges ("5-7"). Invalid entries are silently skipped so that
    free-form user input never raises.

    Args:
        page_input: e.g. "1, 2, 5-7".

    Returns:
        Sorted list of unique page numbers (ints); empty list if nothing
        parses.
    """
    pages = set()
    for part in page_input.split(","):
        part = part.strip()
        if '-' in part:  # Handle ranges
            try:
                # Split inside the try: entries like "1-2-3" unpack to too
                # many values and must be skipped, not raise ValueError.
                start, end = part.split('-')
                pages.update(range(int(start), int(end) + 1))
            except ValueError:
                continue  # Skip invalid ranges
        else:  # Handle individual pages
            try:
                pages.add(int(part))
            except ValueError:
                continue  # Skip invalid page numbers
    return sorted(pages)  # Return a sorted list of pages

def extract_text_from_pdf(filepath, pages):
    """Extract text content from selected pages of a document.

    Args:
        filepath: Path to the file to load.
        pages: Page-selection string (see `parse_page_input`), e.g. "1, 5-7".
            An empty selection means "all pages", matching the behavior of
            `extract_tables_from_pdf`.

    Returns:
        Page contents joined with newlines.
    """
    chunk_size = 1000  # characters per chunk passed to the loader
    overlap = 100      # characters of overlap between consecutive chunks
    loader = UnstructuredLoader([filepath], chunk_size=chunk_size, overlap=overlap)
    pages_to_load = parse_page_input(pages)  # Parse the input for page numbers

    # Filter pages according to user input.
    pages_data = []
    for doc in loader.lazy_load():
        # Document.metadata is a dict, so use .get() — attribute access
        # (doc.metadata.page_number) would raise AttributeError. A missing
        # key yields None, which simply fails the membership test below.
        page_number = doc.metadata.get("page_number")
        # Empty selection ⇒ include everything (consistent with the
        # '1-end' default used for table extraction).
        if not pages_to_load or page_number in pages_to_load:
            pages_data.append(doc.page_content)

    return "\n".join(pages_data)

def extract_tables_from_pdf(filepath, pages):
    """Extract tables from the given PDF and render each as plain text.

    Args:
        filepath: Path to the PDF file.
        pages: Camelot page-spec string (e.g. "1,3-5"); an empty/falsy
            value selects every page via '1-end'.

    Returns:
        List of string renderings (one per detected table), without the
        DataFrame index column.
    """
    page_spec = pages if pages else '1-end'
    extracted = camelot.read_pdf(filepath, pages=page_spec)
    return [tbl.df.to_string(index=False) for tbl in extracted]

def update_documents(text_input):
    """Index newline-separated text into the global FAISS vector store.

    Each non-blank line becomes one document. Rebuilds `vector_store` and
    `retriever` from scratch (previous contents are replaced).

    Args:
        text_input: Raw text; split on newlines.

    Returns:
        Human-readable status message for the UI.
    """
    global vector_store, retriever
    # Drop blank/whitespace-only lines: they add noise to retrieval and an
    # all-empty list would make FAISS.from_texts fail.
    documents = [line for line in text_input.split("\n") if line.strip()]
    if not documents:
        return "No non-empty documents found; vector store unchanged."
    vector_store = FAISS.from_texts(documents, embeddings)
    retriever = vector_store.as_retriever()
    return f"{len(documents)} documents successfully added to the vector store."

# Lazily-built conversational RAG chain; created on the first question
# after documents have been indexed.
rag_chain = None

def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Answer a chat message using the conversational retrieval chain.

    Args:
        message: The user's current question.
        history: Chat history as (user, assistant) pairs from gr.ChatInterface,
            forwarded to the chain as `chat_history`.
        system_message, max_tokens, temperature, top_p: Accepted to match the
            ChatInterface additional_inputs signature.
            NOTE(review): these are currently NOT forwarded to the model —
            wire them into the LLM config if they should take effect.

    Returns:
        The chain's answer string, or a prompt to upload documents first.
    """
    global rag_chain, retriever

    # Guard: retrieval is impossible until documents have been indexed.
    if retriever is None:
        return "Please upload or enter documents before asking a question."

    # Build the chain once and reuse it across turns.
    if rag_chain is None:
        rag_chain = ConversationalRetrievalChain.from_llm(
            HuggingFaceHub(repo_id="HuggingFaceH4/zephyr-7b-beta"),
            retriever=retriever
        )

    # (Removed dead code: a conversation_history list of role dicts was
    # assembled here but never passed to anything.)
    response = rag_chain({"question": message, "chat_history": history})
    return response['answer']

def upload_file(filepath, pages):
    """Handle a file upload: extract text/tables, index the text, update UI.

    Args:
        filepath: Path of the uploaded file.
        pages: Page-selection string from the pages textbox.

    Returns:
        Component updates: hide the upload button, show a download button
        for the file, and a status string with the table count.
    """
    extracted_text = extract_text_from_pdf(filepath, pages)
    extracted_tables = extract_tables_from_pdf(filepath, pages)

    # Index the extracted text so `respond` can retrieve from it.
    update_documents(extracted_text)

    download_label = f"Download {Path(filepath).name}"
    return [
        gr.UploadButton(visible=False),
        gr.DownloadButton(label=download_label, value=filepath, visible=True),
        f"{len(extracted_tables)} tables extracted.",
    ]

# Gradio interface setup: file upload/download row, a status textbox,
# and a RAG-backed chat interface.
demo = gr.Blocks()

with demo:
    with gr.Row():
        u = gr.UploadButton("Upload a file", file_count="single")
        # Hidden until a file has been processed by `upload_file`.
        d = gr.DownloadButton("Download the file", visible=False)
        page_input = gr.Textbox(label="Pages to Parse (e.g., 1, 2, 5-7)", placeholder="Enter page numbers or ranges")

    # Create a Textbox for the status message
    status_output = gr.Textbox(label="Status", visible=True)

    # On upload, run `upload_file(filepath, pages)` and update the three
    # output components (upload button, download button, status text).
    u.upload(upload_file, [u, page_input], [u, d, status_output])
    
    with gr.Row():
        # The additional inputs are passed to `respond` after (message,
        # history); see the NOTE in `respond` about whether they take effect.
        chat = gr.ChatInterface(
            respond,
            additional_inputs=[
                gr.Textbox(value="You are a helpful assistant.", label="System message"),
                gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
                gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
                gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
            ],
        )

if __name__ == "__main__":
    demo.launch()