vishwask committed
Commit 036be0e · verified · 1 Parent(s): bdc42f3

Update app.py

Files changed (1)
  1. app.py +69 -3
app.py CHANGED
@@ -126,7 +126,9 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
     progress(0.9, desc="Done!")
     return qa_chain

-def start(llm_model, temperature, max_tokens, top_k, vector_db, list_file_obj, chunk_size, chunk_overlap):
+def start(llm_model, temperature, max_tokens, top_k,
+          vector_db, list_file_obj, chunk_size, chunk_overlap,
+          qa_chain, message, history):
     # HuggingFaceHub uses HF inference endpoints
     # Use of trust_remote_code as model_kwargs
     # Warning: langchain issue
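The widened start() signature threads the conversational state (qa_chain, message, history) into the handler, and the second hunk below relies on a format_chat_history helper that this diff does not touch. As a sketch only, here is the shape such a helper plausibly has, converting Gradio's (user, bot) history tuples into flat strings for the chain; the body is an assumption, not the committed implementation:

# Hypothetical sketch: app.py's real format_chat_history is not shown in this diff.
def format_chat_history(message, chat_history):
    # Gradio's Chatbot history is a list of (user_message, bot_message) tuples.
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history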
@@ -174,7 +176,71 @@ def start(llm_model, temperature, max_tokens, top_k, vector_db, list_file_obj, c
     # Create or load vector database
     vector_db = create_db(doc_splits, collection_name)

-    return qa_chain, vector_db, collection_name
-
+    formatted_chat_history = format_chat_history(message, history)
+    # print("formatted_chat_history", formatted_chat_history)
+
+    # Generate response using QA chain
+    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
+    response_answer = response["answer"]
+    if response_answer.find("Helpful Answer:") != -1:
+        response_answer = response_answer.split("Helpful Answer:")[-1]
+    response_sources = response["source_documents"]
+    response_source1 = response_sources[0].page_content.strip()
+    response_source2 = response_sources[1].page_content.strip()
+    response_source3 = response_sources[2].page_content.strip()
+    # Langchain sources are zero-based
+    response_source1_page = response_sources[0].metadata["page"] + 1
+    response_source2_page = response_sources[1].metadata["page"] + 1
+    response_source3_page = response_sources[2].metadata["page"] + 1
+    # print('chat response: ', response_answer)
+    # print('DB source', response_sources)
+
+    # Append user message and response to chat history
+    new_history = history + [(message, response_answer)]
+
+    return qa_chain, vector_db, collection_name, new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
+
+def demo():
+    with gr.Blocks(theme="base") as demo:
+        vector_db = gr.State()
+        qa_chain = gr.State()
+        collection_name = gr.State()
+
+        chatbot = gr.Chatbot(height=300)
+        with gr.Accordion("Advanced - Document references", open=False):
+            with gr.Row():
+                doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
+                source1_page = gr.Number(label="Page", scale=1)
+            with gr.Row():
+                doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
+                source2_page = gr.Number(label="Page", scale=1)
+            with gr.Row():
+                doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
+                source3_page = gr.Number(label="Page", scale=1)
+        with gr.Row():
+            msg = gr.Textbox(placeholder="Type message", container=True)
+        with gr.Row():
+            submit_btn = gr.Button("Submit")
+            clear_btn = gr.ClearButton([msg, chatbot])
+
+        msg.submit(start,
+                   inputs=[llm_model, temperature, max_tokens, top_k,
+                           vector_db, list_file_obj, chunk_size, chunk_overlap,
+                           qa_chain, msg, chatbot],
+                   outputs=[qa_chain, msg, chatbot, doc_source1, source1_page,
+                            doc_source2, source2_page,
+                            doc_source3, source3_page],
+                   queue=False)
+        submit_btn.click(conversation,
+                         inputs=[qa_chain, msg, chatbot],
+                         outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
+                         queue=False)
+        clear_btn.click(lambda: [None, "", 0, "", 0, "", 0],
+                        inputs=None,
+                        outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
+                        queue=False)
+
+    demo.queue().launch(debug=True)
+
+if __name__ == "__main__":
+    demo()
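The new start() body calls qa_chain as a callable taking "question" and "chat_history" keys and reads "answer" and "source_documents" from the result, which matches LangChain's ConversationalRetrievalChain interface. A minimal sketch of building such a chain, assuming a retriever-capable vector_db and a placeholder llm (the committed initialize_llmchain() may differ):

# Sketch only: mirrors the call pattern used in start(), not the committed initialize_llmchain().
from langchain.chains import ConversationalRetrievalChain

def build_qa_chain(llm, vector_db, top_k=3):
    retriever = vector_db.as_retriever(search_kwargs={"k": top_k})
    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        return_source_documents=True,  # required for response["source_documents"]
    )

# Usage, as in the diff:
# response = qa_chain({"question": message, "chat_history": formatted_chat_history})
# response["answer"]; response["source_documents"][0].metadata["page"]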
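demo() round-trips qa_chain, vector_db, and collection_name through gr.State so objects built at upload time survive across chat turns. A self-contained sketch of that state round-trip, with a hypothetical echo handler standing in for start():

import gradio as gr

def respond(counter, message, history):
    # State values arrive as plain Python objects and are returned updated.
    counter = (counter or 0) + 1
    history = history + [(message, f"echo #{counter}: {message}")]
    return counter, "", history  # updated state, cleared textbox, new chat history

with gr.Blocks() as demo:
    counter = gr.State()  # plays the role of qa_chain / vector_db above
    chatbot = gr.Chatbot(height=300)
    msg = gr.Textbox(placeholder="Type message")
    msg.submit(respond, inputs=[counter, msg, chatbot], outputs=[counter, msg, chatbot])

if __name__ == "__main__":
    demo.queue().launch()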