basildarwazeh commited on
Commit
2ebbe14
·
verified ·
1 Parent(s): b4e3d76

Upload 4 files

Browse files
Files changed (4) hide show
  1. .gitattributes +35 -35
  2. README.md +12 -13
  3. ap1.py +174 -0
  4. requirements.txt +9 -0
.gitattributes CHANGED
@@ -1,35 +1,35 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,12 @@
1
- ---
2
- title: Test
3
- emoji: 🏢
4
- colorFrom: purple
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 4.31.5
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: RAG PDF Chatbot
3
+ emoji: 📚
4
+ colorFrom: gray
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 4.31.0
8
+ app_file: ap1.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
ap1.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ api_token = os.getenv("HF_TOKEN")
4
+
5
+ from langchain_community.vectorstores import FAISS
6
+ from langchain_community.document_loaders import PyPDFLoader
7
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
8
+ from langchain_community.vectorstores import Chroma
9
+ from langchain.chains import ConversationalRetrievalChain
10
+ from langchain_community.embeddings import HuggingFaceEmbeddings
11
+ from langchain_community.llms import HuggingFacePipeline
12
+ from langchain.chains import ConversationChain
13
+ from langchain.memory import ConversationBufferMemory
14
+ from langchain_community.llms import HuggingFaceEndpoint
15
+ import torch
16
+
17
# Candidate chat LLMs served via the Hugging Face Inference API.
list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
# Short display names (repo basename only) shown in the UI radio selector;
# the selector returns an index into list_llm.
list_llm_simple = [os.path.basename(llm) for llm in list_llm]
19
+
20
# Load and split PDF document
def load_doc(list_file_path):
    """Read every PDF in *list_file_path* and chunk the pages.

    Each file is loaded page-by-page with PyPDFLoader, then all pages are
    split into overlapping character chunks (1024 chars, 64 overlap) so
    each chunk fits comfortably in the embedding model's context.

    Returns the list of langchain Document chunks.
    """
    pages = []
    for pdf_path in list_file_path:
        pages.extend(PyPDFLoader(pdf_path).load())
    splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    return splitter.split_documents(pages)
32
+
33
# Create vector database
def create_db(splits):
    """Embed the document chunks and index them in an in-memory FAISS store.

    Uses the default HuggingFaceEmbeddings model; returns the FAISS
    vector store ready to act as a retriever.
    """
    return FAISS.from_documents(splits, HuggingFaceEmbeddings())
38
+
39
# Initialize langchain LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Wire up a ConversationalRetrievalChain for the chosen model.

    Parameters
    ----------
    llm_model : str
        HF Hub repo id of the chat model.
    temperature, max_tokens, top_k :
        Sampling parameters forwarded to the inference endpoint.
    vector_db :
        FAISS store used as the retriever.
    progress :
        Gradio progress tracker (unused here; kept for interface parity).

    Returns the initialized QA chain.
    """
    endpoint_llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )

    # Conversation memory keyed the way ConversationalRetrievalChain expects;
    # output_key='answer' is required because the chain also returns sources.
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True
    )

    return ConversationalRetrievalChain.from_llm(
        endpoint_llm,
        retriever=vector_db.as_retriever(),
        chain_type="stuff",
        memory=chat_memory,
        return_source_documents=True,
        verbose=False,
    )
65
+
66
# Initialize database
def initialize_database(list_file_obj, progress=gr.Progress()):
    """Build the FAISS vector database from the uploaded PDF files.

    Parameters
    ----------
    list_file_obj :
        List of uploaded file objects from gr.Files (may be None when
        the user clicks the button without uploading anything).
    progress :
        Gradio progress tracker (unused; kept for interface parity).

    Returns (vector_db, status_message); vector_db is None on failure.
    """
    # Fix: gr.Files yields None when nothing was uploaded — the original
    # comprehension raised TypeError instead of reporting the problem.
    if not list_file_obj:
        return None, "No document uploaded. Please add at least one PDF first."
    list_file_path = [x.name for x in list_file_obj if x is not None]
    if not list_file_path:
        return None, "No document uploaded. Please add at least one PDF first."
    doc_splits = load_doc(list_file_path)
    vector_db = create_db(doc_splits)
    return vector_db, "Database created!"
72
+
73
# Initialize LLM
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Resolve the selected model index and build the QA chain.

    *llm_option* is the integer index emitted by the UI radio selector
    (type="index"); it is mapped back to the full repo id via list_llm.
    Returns (qa_chain, status_message).
    """
    llm_name = list_llm[llm_option]
    print("llm_name: ", llm_name)
    chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
    return chain, "QA chain initialized. Chatbot is ready!"
79
+
80
def format_chat_history(message, chat_history):
    """Flatten (user, assistant) pairs into 'User:'/'Assistant:' lines.

    *message* (the pending user turn) is accepted for interface
    compatibility but is not included in the formatted history.
    """
    lines = []
    for user_turn, bot_turn in chat_history:
        lines += [f"User: {user_turn}", f"Assistant: {bot_turn}"]
    return lines
86
+
87
def conversation(qa_chain, message, history):
    """Run one chat turn through the QA chain and unpack the UI outputs.

    Returns a 9-tuple: (qa_chain, cleared msg box, updated history,
    ref1 text, ref1 page, ref2 text, ref2 page, ref3 text, ref3 page).

    Fix: the original indexed source_documents[0..2] unconditionally,
    raising IndexError whenever the retriever returned fewer than three
    chunks (small documents), and KeyError when 'page' metadata was
    missing. Missing references are now padded with empty text / page 0.
    """
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    # Some models echo the prompt scaffold; keep only the final answer part.
    if "Helpful Answer:" in response_answer:
        response_answer = response_answer.split("Helpful Answer:")[-1]
    sources = response["source_documents"]
    ref_texts, ref_pages = [], []
    for i in range(3):
        if i < len(sources):
            ref_texts.append(sources[i].page_content.strip())
            # PDF pages are 0-based in metadata; display 1-based. Missing
            # 'page' metadata falls back to 0 rather than raising.
            ref_pages.append(sources[i].metadata.get("page", -1) + 1)
        else:
            ref_texts.append("")
            ref_pages.append(0)
    new_history = history + [(message, response_answer)]
    return (qa_chain, gr.update(value=""), new_history,
            ref_texts[0], ref_pages[0],
            ref_texts[1], ref_pages[1],
            ref_texts[2], ref_pages[2])
102
+
103
def upload_file(file_obj):
    """Return the local temp-file paths of the uploaded file objects."""
    return [uploaded.name for uploaded in file_obj]
109
+
110
def demo():
    """Build and launch the Gradio UI for the RAG PDF chatbot.

    Left column: document upload, vector-DB creation, LLM selection and
    sampling parameters. Right column: the chatbot plus the retrieved
    source passages backing each answer. Blocks in launch() until the
    server stops (debug=True).
    """
    with gr.Blocks(theme=gr.themes.Default(primary_hue="sky")) as demo:
        # Per-session state: the FAISS store and the initialized QA chain.
        vector_db = gr.State()
        qa_chain = gr.State()
        gr.HTML("<center><h1>RAG PDF Chatbot</h1><center>")
        gr.Markdown("""<b>Query your PDF documents!</b> This AI agent is designed to perform retrieval augmented generation (RAG) on PDF documents. The app is hosted on Hugging Face Hub for the sole purpose of demonstration. <b>Please do not upload confidential documents.</b>""")

        with gr.Row():
            with gr.Column(scale=86):
                # Step 1: upload PDFs and build the vector database.
                gr.Markdown("<b>Step 1 - Upload PDF documents and Initialize RAG pipeline</b>")
                with gr.Row():
                    document = gr.Files(height=300, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload PDF documents")
                with gr.Row():
                    db_btn = gr.Button("Create vector database")
                with gr.Row():
                    db_progress = gr.Textbox(value="Not initialized", show_label=False)

                gr.Markdown("<style>body { font-size: 16px; }</style><b>Select Large Language Model (LLM) and input parameters</b>")
                with gr.Row():
                    # type="index" makes the radio emit an int index into list_llm.
                    llm_btn = gr.Radio(list_llm_simple, label="Available LLMs", value=list_llm_simple[0], type="index")
                with gr.Row():
                    with gr.Accordion("LLM input parameters", open=False):
                        with gr.Row():
                            slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.5, step=0.1, label="Temperature", info="Controls randomness in token generation", interactive=True)
                        with gr.Row():
                            slider_maxtokens = gr.Slider(minimum=128, maximum=9192, value=4096, step=128, label="Max New Tokens", info="Maximum number of tokens to be generated", interactive=True)
                        with gr.Row():
                            slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k", info="Number of tokens to select the next token from", interactive=True)
                with gr.Row():
                    qachain_btn = gr.Button("Initialize Question Answering Chatbot")
                with gr.Row():
                    llm_progress = gr.Textbox(value="Not initialized", show_label=False)

            with gr.Column(scale=200):
                # Step 2: chat panel plus the retrieved source references.
                gr.Markdown("<b>Step 2 - Chat with your Document</b>")
                chatbot = gr.Chatbot(height=505)
                with gr.Accordion("Relevant context from the source document", open=False):
                    with gr.Row():
                        doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
                        source1_page = gr.Number(label="Page", scale=1)
                    with gr.Row():
                        doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
                        source2_page = gr.Number(label="Page", scale=1)
                    with gr.Row():
                        doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
                        source3_page = gr.Number(label="Page", scale=1)
                with gr.Row():
                    msg = gr.Textbox(placeholder="Ask a question", container=True)
                with gr.Row():
                    submit_btn = gr.Button("Submit")
                    clear_btn = gr.ClearButton([msg, chatbot], value="Clear")

        # Preprocessing events
        db_btn.click(initialize_database, inputs=[document], outputs=[vector_db, db_progress])
        # After (re)initializing the chain, reset the chat and reference panels.
        qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)

        # Chatbot events: Enter key and Submit button run the same handler.
        msg.submit(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
        submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
        clear_btn.click(lambda:[None,"",0,"",0,"",0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)

    demo.queue().launch(debug=True)

if __name__ == "__main__":
    demo()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ sentence-transformers
4
+ langchain
5
+ langchain-community
6
+ tqdm
7
+ accelerate
8
+ pypdf
9
+ faiss-cpu