EduTechTeam commited on
Commit
afbb015
·
verified ·
1 Parent(s): 4c51435

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +282 -0
app.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import shutil
import os

from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFaceEndpoint
from langchain_openai import ChatOpenAI
import torch
import fitz

# google.colab exists only inside a Colab runtime; importing it unconditionally
# crashes the app on HuggingFace Spaces or a local machine. `userdata` is not
# referenced anywhere else in this file, so a guarded import is safe.
try:
    from google.colab import userdata  # noqa: F401
except ImportError:
    userdata = None

from dotenv import load_dotenv, set_key
# Load API keys (OPENAI_API_KEY, HUGGINGFACEHUB_API_TOKEN) from the local .env file.
load_dotenv(dotenv_path=".env")

# Candidate LLM backends: two HuggingFace-hosted instruct models plus OpenAI's gpt-4o-mini.
list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2","gpt-4o-mini"]
# Short display names for the UI radio buttons (basename of each repo id).
list_llm_simple = [os.path.basename(llm) for llm in list_llm]
# Load and split PDF documents
def load_doc():
    """Load every PDF in the local ``pdfs`` directory and split it into chunks.

    Returns:
        list: Document chunks produced by ``RecursiveCharacterTextSplitter``
        (chunk_size=200, chunk_overlap=64), ready for embedding.
    """
    path = "pdfs"
    # Fix: the original built "/content/pdfs/{file}" — a Google-Colab absolute
    # path — while uploads are copied into the relative "pdfs" directory, so
    # loading failed anywhere outside Colab. Join against `path` instead.
    loaders = [PyPDFLoader(os.path.join(path, file)) for file in os.listdir(path)]

    pages = []
    for loader in loaders:
        pages.extend(loader.load())

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=200,
        chunk_overlap=64,
    )
    return text_splitter.split_documents(pages)
# Create vector database
def create_db(splits):
    """Embed the given document splits and index them in an in-memory FAISS store."""
    return FAISS.from_documents(splits, HuggingFaceEmbeddings())
# Initialize langchain LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Build a ConversationalRetrievalChain over `vector_db` for the chosen model.

    Args:
        llm_model: Full model id from ``list_llm``.
        temperature: Sampling temperature.
        max_tokens: Max new tokens to generate.
        top_k: top-k sampling parameter (HuggingFace endpoints only).
        vector_db: FAISS store used as the retriever.
        progress: Gradio progress tracker (kept for the UI hook).

    Returns:
        ConversationalRetrievalChain with buffer memory and source documents.
    """
    if llm_model == "gpt-4o-mini":
        # ChatOpenAI reads OPENAI_API_KEY from the environment on its own; the
        # original re-assigned it via os.getenv(), which raised TypeError when
        # the key was unset.
        llm = ChatOpenAI(
            model_name="gpt-4o-mini",
            temperature=temperature,
            max_tokens=max_tokens,
        )
    else:
        # Both HuggingFace-hosted models (Llama-3 and Mistral) used identical,
        # duplicated branches in the original; one endpoint config covers both.
        llm = HuggingFaceEndpoint(
            repo_id=llm_model,
            huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
            temperature=temperature,
            max_new_tokens=max_tokens,
            top_k=top_k,
        )

    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True,
    )

    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=vector_db.as_retriever(),
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    return qa_chain
# Initialize database
def initialize_database(list_file_obj, progress=gr.Progress()):
    """Copy the uploaded PDFs into ./pdfs, then build the FAISS vector database.

    Args:
        list_file_obj: Gradio file objects for the uploaded PDFs.
        progress: Gradio progress tracker (kept for the UI hook).

    Returns:
        tuple: (vector_db, status_message)
    """
    # exist_ok replaces the racy exists()/mkdir() pair of the original.
    os.makedirs("pdfs", exist_ok=True)
    # NOTE(review): PDFs from earlier uploads are never removed, so they are
    # re-indexed into every new database — confirm this is intended.
    for file_obj in list_file_obj:
        shutil.copy(file_obj.name, "pdfs")
    # Load documents and create splits
    doc_splits = load_doc()
    # Create the vector database from the splits
    vector_db = create_db(doc_splits)
    return vector_db, "Database created!"
# Initialize LLM
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Resolve the selected radio index to a model id and build the QA chain.

    Returns:
        tuple: (qa_chain, status_message)
    """
    llm_name = list_llm[llm_option]
    print("llm_name: ",llm_name)
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "QA chain initialized. Chatbot is ready!"
def format_chat_history(message, chat_history):
    """Flatten (user, bot) history pairs into alternating labeled lines.

    ``message`` (the current user input) is accepted for interface
    compatibility but is not used in the formatting.
    """
    return [
        line
        for user_turn, bot_turn in chat_history
        for line in (f"User: {user_turn}", f"Assistant: {bot_turn}")
    ]
def conversation(qa_chain, message, history):
    """Run one chat turn through the QA chain and update the UI components.

    Returns:
        tuple: (qa_chain, cleared input box, new history,
                ref1 text, ref1 page, ref2 text, ref2 page, ref3 text, ref3 page)
    """
    formatted_chat_history = format_chat_history(message, history)
    # Generate response using QA chain
    response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    # Some HF models echo their prompt template; keep only the final answer part.
    if "Helpful Answer:" in response_answer:
        response_answer = response_answer.split("Helpful Answer:")[-1]
    sources = response["source_documents"]
    # Fix: the original indexed sources[0..2] unconditionally and raised
    # IndexError when the retriever returned fewer than three documents
    # (e.g. a very small corpus). Pad so the UI always gets three references.
    ref_texts, ref_pages = [], []
    for i in range(3):
        if i < len(sources):
            ref_texts.append(sources[i].page_content.strip())
            # LangChain page metadata is zero-based; display one-based.
            ref_pages.append(sources[i].metadata["page"] + 1)
        else:
            ref_texts.append("")
            ref_pages.append(0)
    # Append user message and response to chat history
    new_history = history + [(message, response_answer)]
    return (qa_chain, gr.update(value=""), new_history,
            ref_texts[0], ref_pages[0], ref_texts[1], ref_pages[1], ref_texts[2], ref_pages[2])
def setup_gradio_interface():
    """Build the two-tab Gradio Blocks app: API-key entry plus RAG PDF chatbot.

    Returns:
        gr.Blocks: the assembled (not yet launched) interface.
    """
    with gr.Blocks(theme=gr.themes.Default(primary_hue="red", secondary_hue="pink", neutral_hue="sky")) as demo:
        # Per-session state: the FAISS store and the conversational QA chain.
        vector_db = gr.State()
        qa_chain = gr.State()
        gr.HTML("<center><h1>RAG PDF Chatbot</h1><center>")
        gr.Markdown("""<b>Query your PDF documents!</b> This AI agent is designed to perform retrieval augmented generation (RAG) on PDF documents.\
<b>Please do not upload confidential documents.</b>
""")

        # Defined once here; the original file carried a second, identical
        # definition of this function further down, which shadowed this one.
        def set_env_vars(openai_key, huggingface_token):
            """Store the provided API keys as env vars and persist them to .env."""
            if openai_key:
                os.environ["OPENAI_API_KEY"] = openai_key
                set_key(".env", "OPENAI_API_KEY", openai_key)
            if huggingface_token:
                os.environ["HUGGINGFACEHUB_API_TOKEN"] = huggingface_token
                set_key(".env", "HUGGINGFACEHUB_API_TOKEN", huggingface_token)
            return "Environment variables set successfully!"

        # --- Tab 1: API-key entry ---
        with gr.Tab("帳號輸入"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("<b>Step 1 - Input OpenAI API Key</b>")
                    with gr.Row():
                        openai_key_input = gr.Textbox(
                            label="OpenAI API Key",
                            placeholder="Enter your OpenAI API Key",
                            value=os.getenv("OPENAI_API_KEY", ""),
                            type="password",
                        )
                with gr.Column():
                    gr.Markdown("<b>Step 2 - Input HuggingFaceHub API Token</b>")
                    with gr.Row():
                        huggingface_token_input = gr.Textbox(
                            label="HuggingFaceHub API Token",
                            placeholder="Enter your HuggingFaceHub API Key",
                            value=os.getenv("HUGGINGFACEHUB_API_TOKEN", ""),
                            type="password",
                        )
            submit_button = gr.Button("Submit")
            status_output = gr.Label()

        # --- Tab 2: RAG chatbot ---
        with gr.Tab("對話機器人"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("<b>Step 1 - Upload PDF documents and Initialize RAG pipeline</b>")
                    with gr.Row():
                        document = gr.Files(height=300, file_count="multiple", label="Upload PDF documents")
                    with gr.Row():
                        db_btn = gr.Button("Create vector database")
                    with gr.Row():
                        db_progress = gr.Textbox(value="Not initialized", show_label=False)
                    gr.Markdown("<style>body { font-size: 16px; }</style><b>Select Large Language Model (LLM) and input parameters</b>")
                    with gr.Row():
                        # type="index" makes the radio emit the list position,
                        # which initialize_LLM maps back into list_llm.
                        llm_btn = gr.Radio(list_llm_simple, label="Available LLMs", value=list_llm_simple[0], type="index")
                    with gr.Row():
                        with gr.Accordion("LLM input parameters", open=False):
                            with gr.Row():
                                slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.5, step=0.1, label="Temperature", info="Controls randomness in token generation", interactive=True)
                            with gr.Row():
                                slider_maxtokens = gr.Slider(minimum=128, maximum=9192, value=4096, step=128, label="Max New Tokens", info="Maximum number of tokens to be generated", interactive=True)
                            with gr.Row():
                                slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k", info="Number of tokens to select the next token from", interactive=True)
                    with gr.Row():
                        qachain_btn = gr.Button("Initialize Question Answering Chatbot")
                    with gr.Row():
                        llm_progress = gr.Textbox(value="Not initialized", show_label=False)

                with gr.Column():
                    gr.Markdown("<b>Step 2 - Chat with your Document</b>")
                    chatbot = gr.Chatbot(height=505)
                    with gr.Accordion("Relevent context from the source document", open=False):
                        with gr.Row():
                            doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
                            source1_page = gr.Number(label="Page", scale=1)
                        with gr.Row():
                            doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
                            source2_page = gr.Number(label="Page", scale=1)
                        with gr.Row():
                            doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
                            source3_page = gr.Number(label="Page", scale=1)
                    with gr.Row():
                        msg = gr.Textbox(placeholder="Ask a question", container=True)
                    with gr.Row():
                        submit_btn = gr.Button("Submit")
                        clear_btn = gr.ClearButton([msg, chatbot], value="Clear")

        # Preprocessing events: build the DB, then the QA chain (which also
        # resets the chat panel and the three reference boxes).
        db_btn.click(initialize_database,
                     inputs=[document],
                     outputs=[vector_db, db_progress])
        qachain_btn.click(initialize_LLM,
                          inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
                          outputs=[qa_chain, llm_progress]).then(
            lambda: [None, "", 0, "", 0, "", 0],
            inputs=None,
            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False)

        # Chatbot events: Enter key and the Submit button behave identically.
        msg.submit(conversation,
                   inputs=[qa_chain, msg, chatbot],
                   outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
                   queue=False)
        submit_btn.click(conversation,
                         inputs=[qa_chain, msg, chatbot],
                         outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
                         queue=False)
        clear_btn.click(lambda: [None, "", 0, "", 0, "", 0],
                        inputs=None,
                        outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
                        queue=False)

        # Bind the key-submission button to the env-var setter.
        submit_button.click(
            set_env_vars,
            inputs=[openai_key_input, huggingface_token_input],
            outputs=[status_output]
        )

    return demo
# Build the Gradio UI and start serving (debug=True surfaces tracebacks in the console).
demo = setup_gradio_interface()
demo.launch(debug=True)