vishwask committed (verified)
Commit 8cea364 · 1 Parent(s): 2f2c64e

Update app.py

Files changed (1)
  1. app.py +149 -73
app.py CHANGED
@@ -21,22 +21,22 @@ import tqdm
 import accelerate
 
 
-#Set parameters
-
-llm_model = 'mistralai/Mixtral-8x7B-Instruct-v0.1'
-list_file_obj = '/home/user/app/pdfs/'
-chunk_size = 1024
-chunk_overlap = 128
-temperature = 0.1
-max_tokens = 6000
-top_k = 3
 
-
-def load_doc(list_file_path):
     # Processing for one document only
     # loader = PyPDFLoader(file_path)
     # pages = loader.load()
-    loaders = [PyPDFLoader(list_file_obj+x) for x in list_file_path]
     pages = []
     for loader in loaders:
         pages.extend(loader.load())
@@ -48,7 +48,6 @@ def load_doc(list_file_path):
     return doc_splits
 
 
-
 # Create vector database
 def create_db(splits, collection_name):
     embedding = HuggingFaceEmbeddings()
@@ -62,6 +61,7 @@ def create_db(splits, collection_name):
     )
     return vectordb
 
 # Load vector database
 def load_db():
     embedding = HuggingFaceEmbeddings()
@@ -70,20 +70,99 @@ def load_db():
         embedding_function=embedding)
     return vectordb
 
 
 # Initialize database
-def initialize_database(list_file_obj):
     # Create list of documents (when valid)
-    #list_file_path = [x.name for x in list_file_obj if x is not None]
-    list_file_path = os.listdir(list_file_obj)
     # Create collection_name for vector database
     collection_name = Path(list_file_path[0]).stem
     # Fix potential issues from naming convention
     ## Remove space
     collection_name = collection_name.replace(" ","-")
     ## Limit lenght to 50 characters
     collection_name = collection_name[:50]
-    print(collection_name)
     ## Enforce start and end as alphanumeric character
     if not collection_name[0].isalnum():
         collection_name[0] = 'A'
@@ -91,32 +170,24 @@ def initialize_database(list_file_obj):
         collection_name[-1] = 'Z'
     # print('list_file_path: ', list_file_path)
     print('Collection name: ', collection_name)
     # Load document and create splits
-    doc_splits = load_doc(list_file_path)
     # Create or load vector database
     # global vector_db
     vector_db = create_db(doc_splits, collection_name)
-    return vector_db, collection_name
 
 
-def initialize_llmchain(vector_db):
-    # Initialize langchain LLM chain
-    llm = HuggingFaceHub(repo_id = llm_model,model_kwargs={"temperature": temperature,
-        "max_new_tokens": max_tokens,
-        "top_k": top_k,
-        "load_in_8bit": True})
-    retriever=vector_db.as_retriever()
-    memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
-    qa_chain = ConversationalRetrievalChain.from_llm(llm,retriever=retriever,chain_type="stuff",
-        memory=memory,return_source_documents=True,verbose=False)
-
-    return qa_chain
-
-def initialize_LLM(vector_db):
     # print("llm_option",llm_option)
-    llm_name = llm_model
-    qa_chain = initialize_llmchain(vector_db)
-    return qa_chain
 
 
 def format_chat_history(message, chat_history):
     formatted_chat_history = []
@@ -124,6 +195,7 @@ def format_chat_history(message, chat_history):
         formatted_chat_history.append(f"User: {user_message}")
         formatted_chat_history.append(f"Assistant: {bot_message}")
     return formatted_chat_history
 
 def conversation(qa_chain, message, history):
     formatted_chat_history = format_chat_history(message, history)
@@ -148,57 +220,60 @@ def conversation(qa_chain, message, history):
     # Append user message and response to chat history
     new_history = history + [(message, response_answer)]
     # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
-    return qa_chain, new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
 
 def demo():
-    with gr.Blocks() as demo:
         vector_db = gr.State()
         qa_chain = gr.State()
         collection_name = gr.State()
 
         chatbot = gr.Chatbot(height=300)
-        with gr.Accordion("References", open=True):
             with gr.Row():
-                doc_source1 = gr.Textbox(label="Reference 1", lines=5, container=True, scale=20)
                 source1_page = gr.Number(label="Page", scale=1)
             with gr.Row():
-                doc_source2 = gr.Textbox(label="Reference 2", lines=5, container=True, scale=20)
                 source2_page = gr.Number(label="Page", scale=1)
             with gr.Row():
-                doc_source3 = gr.Textbox(label="Reference 3", lines=5, container=True, scale=20)
                 source3_page = gr.Number(label="Page", scale=1)
-        with gr.Row():
-            msg = gr.Textbox(placeholder="Type message", container=True)
-        with gr.Row():
-            #db_btn = gr.Button('Initialize database')
-            #qachain_btn = gr.Button('Start chatbot')
-            submit_btn = gr.Button("Submit")
-            clear_btn = gr.ClearButton([msg, chatbot])
-
-        # document = list_file_obj
-        vector_db, collection_name = initialize_database(list_file_obj)
-        print(collection_name)
-        # #upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
-        # db_btn.click(initialize_database, \
-        #              inputs=[document], \
-        #              outputs=[vector_db, collection_name])
-
-        # qachain_btn.click(initialize_LLM, \
-        #                   inputs=[vector_db], \
-        #                   outputs=[qa_chain]).then(lambda:[0], \
-        #                   inputs=None, \
-        #                   outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
-        #                   queue=False)
-
-        #qachain = initialize_LLM(vector_db)
-        llm = HuggingFaceHub(repo_id = llm_model,model_kwargs={"temperature": temperature,
-            "max_new_tokens": max_tokens,
-            "top_k": top_k,
-            "load_in_8bit": True})
-        retriever=vector_db.as_retriever()
-        memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
-        qa_chain = ConversationalRetrievalChain.from_llm(llm,retriever=retriever,chain_type="stuff",
-            memory=memory,return_source_documents=True,verbose=False)
         # Chatbot events
         msg.submit(conversation, \
             inputs=[qa_chain, msg, chatbot], \
@@ -214,5 +289,6 @@ def demo():
             queue=False)
     demo.queue().launch(debug=True)
 
 if __name__ == "__main__":
     demo()
app.py (after change)

 import accelerate
 
 
 
+# default_persist_directory = './chroma_HF/'
+list_llm = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.1", \
+    "google/gemma-7b-it","google/gemma-2b-it", \
+    "HuggingFaceH4/zephyr-7b-beta", "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2", \
+    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct", "tiiuae/falcon-7b-instruct", \
+    "google/flan-t5-xxl"
+]
+list_llm_simple = [os.path.basename(llm) for llm in list_llm]
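For reference, os.path.basename keeps only the segment after the last "/", so list_llm_simple holds the short display names of the repos listed in list_llm. A minimal standalone check (illustrative only):

    import os

    list_llm = ["mistralai/Mistral-7B-Instruct-v0.2", "google/flan-t5-xxl"]
    list_llm_simple = [os.path.basename(llm) for llm in list_llm]
    print(list_llm_simple)  # ['Mistral-7B-Instruct-v0.2', 'flan-t5-xxl']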
 
+# Load PDF document and create doc splits
+def load_doc(list_file_path, chunk_size, chunk_overlap):
     # Processing for one document only
     # loader = PyPDFLoader(file_path)
     # pages = loader.load()
+    loaders = [PyPDFLoader(x) for x in list_file_path]
     pages = []
     for loader in loaders:
         pages.extend(loader.load())
 
     return doc_splits
 
 
 # Create vector database
 def create_db(splits, collection_name):
     embedding = HuggingFaceEmbeddings()
 
     )
     return vectordb
 
+
 # Load vector database
 def load_db():
     embedding = HuggingFaceEmbeddings()
 
         embedding_function=embedding)
     return vectordb
 
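load_doc() now receives chunk_size and chunk_overlap as arguments, but the splitter call itself sits in lines that fall outside the displayed hunks. The sketch below shows how these two parameters are typically wired to LangChain's RecursiveCharacterTextSplitter; it is an illustration under that assumption, not the committed lines:

    from langchain.text_splitter import RecursiveCharacterTextSplitter

    # Assumed splitter setup: chunk_size/chunk_overlap control how the loaded
    # PDF pages are cut into overlapping chunks before embedding.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    doc_splits = text_splitter.split_documents(pages)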
 
+
+# Initialize langchain LLM chain
+def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+    progress(0.1, desc="Initializing HF tokenizer...")
+    # HuggingFacePipeline uses local model
+    # Note: it will download model locally...
+    # tokenizer=AutoTokenizer.from_pretrained(llm_model)
+    # progress(0.5, desc="Initializing HF pipeline...")
+    # pipeline=transformers.pipeline(
+    #     "text-generation",
+    #     model=llm_model,
+    #     tokenizer=tokenizer,
+    #     torch_dtype=torch.bfloat16,
+    #     trust_remote_code=True,
+    #     device_map="auto",
+    #     # max_length=1024,
+    #     max_new_tokens=max_tokens,
+    #     do_sample=True,
+    #     top_k=top_k,
+    #     num_return_sequences=1,
+    #     eos_token_id=tokenizer.eos_token_id
+    # )
+    # llm = HuggingFacePipeline(pipeline=pipeline, model_kwargs={'temperature': temperature})
 
+    # HuggingFaceHub uses HF inference endpoints
+    progress(0.5, desc="Initializing HF Hub...")
+    # Use of trust_remote_code as model_kwargs
+    # Warning: langchain issue
+    # URL: https://github.com/langchain-ai/langchain/issues/6080
+    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
+        llm = HuggingFaceHub(
+            repo_id=llm_model,
+            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
+        )
+    elif llm_model == "microsoft/phi-2":
+        raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
+        llm = HuggingFaceHub(
+            repo_id=llm_model,
+            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
+        )
+    elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
+        llm = HuggingFaceHub(
+            repo_id=llm_model,
+            model_kwargs={"temperature": temperature, "max_new_tokens": 250, "top_k": top_k}
+        )
+    elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
+        raise gr.Error("Llama-2-7b-chat-hf model requires a Pro subscription...")
+        llm = HuggingFaceHub(
+            repo_id=llm_model,
+            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
+        )
+    else:
+        llm = HuggingFaceHub(
+            repo_id=llm_model,
+            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
+            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
+        )
+
+    progress(0.75, desc="Defining buffer memory...")
+    memory = ConversationBufferMemory(
+        memory_key="chat_history",
+        output_key='answer',
+        return_messages=True
+    )
+    # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
+    retriever=vector_db.as_retriever()
+    progress(0.8, desc="Defining retrieval chain...")
+    qa_chain = ConversationalRetrievalChain.from_llm(
+        llm,
+        retriever=retriever,
+        chain_type="stuff",
+        memory=memory,
+        # combine_docs_chain_kwargs={"prompt": your_prompt})
+        return_source_documents=True,
+        #return_generated_question=False,
+        verbose=False,
+    )
+    progress(0.9, desc="Done!")
+    return qa_chain
+
+
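Because the chain is built with return_source_documents=True and a memory whose output_key is 'answer', a call returns both the generated answer and the retrieved chunks. A minimal usage sketch, assuming LangChain's standard ConversationalRetrievalChain call convention (the question text is made up for illustration):

    # Query a chain produced by initialize_llmchain(); illustrative only.
    response = qa_chain({"question": "What does the report say about revenue?",
                         "chat_history": []})
    answer = response["answer"]              # generated answer text
    sources = response["source_documents"]   # retrieved chunks as Documents
    print(answer)
    print(sources[0].page_content[:200], sources[0].metadata.get("page"))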
 # Initialize database
+def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
     # Create list of documents (when valid)
+    list_file_path = [x.name for x in list_file_obj if x is not None]
     # Create collection_name for vector database
+    progress(0.1, desc="Creating collection name...")
     collection_name = Path(list_file_path[0]).stem
     # Fix potential issues from naming convention
     ## Remove space
     collection_name = collection_name.replace(" ","-")
     ## Limit lenght to 50 characters
     collection_name = collection_name[:50]
     ## Enforce start and end as alphanumeric character
     if not collection_name[0].isalnum():
         collection_name[0] = 'A'
 
         collection_name[-1] = 'Z'
     # print('list_file_path: ', list_file_path)
     print('Collection name: ', collection_name)
+    progress(0.25, desc="Loading document...")
     # Load document and create splits
+    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
     # Create or load vector database
+    progress(0.5, desc="Generating vector database...")
     # global vector_db
     vector_db = create_db(doc_splits, collection_name)
+    progress(0.9, desc="Done!")
+    return vector_db, collection_name, "Complete!"
 
 
+def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
     # print("llm_option",llm_option)
+    llm_name = list_llm[llm_option]
+    print("llm_name: ",llm_name)
+    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
+    return qa_chain, "Complete!"
+
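Note that initialize_LLM() uses llm_option as an integer index into list_llm, so the component bound to it must return a position rather than a label. In Gradio that is usually a Radio or Dropdown created with type="index"; the actual definition is not visible in these hunks, so the line below is only an assumed sketch:

    # Hypothetical selector consistent with list_llm[llm_option]; not in the visible hunks.
    llm_btn = gr.Radio(list_llm_simple, label="Available LLMs", value=list_llm_simple[0], type="index")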
 
 def format_chat_history(message, chat_history):
     formatted_chat_history = []
 
         formatted_chat_history.append(f"User: {user_message}")
         formatted_chat_history.append(f"Assistant: {bot_message}")
     return formatted_chat_history
+
 
 def conversation(qa_chain, message, history):
     formatted_chat_history = format_chat_history(message, history)
 
     # Append user message and response to chat history
     new_history = history + [(message, response_answer)]
     # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
+    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
+
+
+def upload_file(file_obj):
+    list_file_path = []
+    for idx, file in enumerate(file_obj):
+        file_path = file_obj.name
+        list_file_path.append(file_path)
+        # print(file_path)
+    # initialize_database(file_path, progress)
+    return list_file_path
+
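As written, upload_file() appends file_obj.name on every pass through the loop rather than file.name, so multiple uploads would all record the same path. A corrected sketch (an illustration, not part of the commit):

    def upload_file(file_obj):
        # Collect the temporary path of each uploaded file.
        list_file_path = [file.name for file in file_obj]
        return list_file_path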
 
 def demo():
+    with gr.Blocks(theme="base") as demo:
         vector_db = gr.State()
         qa_chain = gr.State()
         collection_name = gr.State()
 
+
+        document = gr.Files(value = '/home/user/app/pdfs/Annual-Report-2022-2023-English_1.pdf',height=100,
+                            file_count="multiple", file_types=["pdf"], interactive=True, visible=False,
+                            label="Upload your PDF documents (single or multiple)")
         chatbot = gr.Chatbot(height=300)
+        with gr.Accordion("Advanced - Document references", open=False):
             with gr.Row():
+                doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
                 source1_page = gr.Number(label="Page", scale=1)
             with gr.Row():
+                doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
                 source2_page = gr.Number(label="Page", scale=1)
             with gr.Row():
+                doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
                 source3_page = gr.Number(label="Page", scale=1)
+        with gr.Row():
+            msg = gr.Textbox(placeholder="Type message", container=True)
+        with gr.Row():
+            db_btn = gr.Button("Generate vector database...")
+            qachain_btn = gr.Button("Initialize question-answering chain...")
+            submit_btn = gr.Button("Submit")
+            clear_btn = gr.ClearButton([msg, chatbot])
+
+        # Preprocessing events
+        #upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
+        db_btn.click(initialize_database, \
+            inputs=[document, slider_chunk_size, slider_chunk_overlap], \
+            outputs=[vector_db, collection_name, db_progress])
+        qachain_btn.click(initialize_LLM, \
+            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], \
+            outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0], \
+            inputs=None, \
+            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
+            queue=False)
+
         # Chatbot events
         msg.submit(conversation, \
             inputs=[qa_chain, msg, chatbot], \
 
             queue=False)
     demo.queue().launch(debug=True)
 
+
 if __name__ == "__main__":
     demo()
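The click handlers above reference several components (slider_chunk_size, slider_chunk_overlap, slider_temperature, slider_maxtokens, slider_topk, db_progress, llm_progress) whose definitions do not appear in any visible hunk of this commit. The block below is only a guess at how such controls are commonly declared in Gradio, to make the wiring readable; the names, ranges, and defaults are assumptions, not the committed code:

    import gradio as gr

    # Hypothetical declarations for the inputs/outputs used by db_btn.click and
    # qachain_btn.click: sliders feed chunking and decoding parameters, and the
    # textboxes display the "Complete!" status strings returned by the handlers.
    with gr.Blocks() as sketch:
        slider_chunk_size = gr.Slider(100, 1000, value=600, step=20, label="Chunk size")
        slider_chunk_overlap = gr.Slider(10, 200, value=40, step=10, label="Chunk overlap")
        slider_temperature = gr.Slider(0.01, 1.0, value=0.7, step=0.1, label="Temperature")
        slider_maxtokens = gr.Slider(224, 4096, value=1024, step=32, label="Max new tokens")
        slider_topk = gr.Slider(1, 10, value=3, step=1, label="top-k samples")
        db_progress = gr.Textbox(label="Vector database initialization", value="None")
        llm_progress = gr.Textbox(label="QA chain initialization", value="None")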