Shreyas094 committed on
Commit
8b01918
1 Parent(s): 0f075d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -80
app.py CHANGED
@@ -210,69 +210,103 @@ def google_search(term, num_results=5, lang="en", timeout=5, safe="active", ssl_
210
  print(f"Result {i}:")
211
  print(f" Link: {result['link']}")
212
  if result['text']:
213
- print(f" Text: {result['text'][:100]}...") # Display the first 100 characters of the text for brevity
214
  else:
215
- print(" No text extracted")
 
 
 
 
 
 
216
  return all_results
217
 
218
- def process_question(question, documents, history, temperature, top_p, repetition_penalty, enable_web_search):
219
  global conversation_history
220
 
221
- embeddings = get_embeddings()
 
222
 
223
- # Check the memory database for similar questions
224
- for prev_question, prev_answer in memory_database.items():
225
- similarity = get_similarity(question, prev_question)
226
- if similarity > 0.8:
227
- return prev_answer
228
-
229
- # Retrieve relevant documents from the vector store
230
- if os.path.exists("faiss_database"):
231
- db = FAISS.load_local("faiss_database", embeddings, allow_dangerous_deserialization=True)
232
- relevant_docs = db.similarity_search(question, k=3)
233
  else:
234
- relevant_docs = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
- # Perform web search if enabled and no relevant documents found
237
- if enable_web_search and len(relevant_docs) == 0:
238
- web_search_results = google_search(question, num_results=5)
239
- web_docs = [Document(page_content=res["text"] or "", metadata={"source": res["link"]}) for res in web_search_results if res["text"]]
240
 
241
- if web_docs:
242
- # Update the FAISS vector store with new documents
243
- create_or_update_database(web_docs, embeddings)
244
- db = FAISS.load_local("faiss_database", embeddings, allow_dangerous_deserialization=True)
245
- relevant_docs = db.similarity_search(question, k=3)
 
246
 
247
- context = "\n\n".join([doc.page_content for doc in relevant_docs])
 
248
 
249
- if is_related_to_history(question, history):
250
- context = "None"
251
- else:
252
- history_text = "\n".join([f"Q: {h['question']}\nA: {h['answer']}" for h in history]) if history else "None"
253
- context = context if context else "None"
254
 
255
- prompt_text = ChatPromptTemplate(
256
- input_variables=["history", "context", "question"],
257
- template=prompt
258
- ).format(history=history_text, context=context, question=question)
259
 
260
- model = get_model(temperature, top_p, repetition_penalty)
261
- answer = generate_chunked_response(model, prompt_text)
262
-
263
- conversation_history = manage_conversation_history(question, answer, history)
264
- memory_database[question] = answer
265
 
266
  return answer
267
 
268
- def process_uploaded_file(file, is_recursive):
269
- if is_recursive:
270
- data = load_and_split_document_recursive(file)
271
- else:
272
- data = load_and_split_document_basic(file)
273
- embeddings = get_embeddings()
274
- create_or_update_database(data, embeddings)
275
- return "File processed and data added to the vector database."
 
 
 
 
 
 
 
 
276
 
277
  def extract_db_to_excel():
278
  embed = get_embeddings()
@@ -303,43 +337,47 @@ def export_memory_db_to_excel():
303
 
304
  return excel_path
305
 
 
306
  with gr.Blocks() as demo:
307
- with gr.Row():
308
- pdf_file = gr.File(label="Upload PDF")
309
- with gr.Row():
310
- recursive_check = gr.Checkbox(label="Use Recursive Text Splitter")
311
- upload_button = gr.Button("Upload and Process")
312
- with gr.Row():
313
- upload_output = gr.Textbox(label="Upload Output")
314
 
315
  with gr.Row():
316
- question = gr.Textbox(label="Your Question")
317
- with gr.Row():
318
- temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Temperature")
319
- top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Top P")
320
- repetition_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=1.0, label="Repetition Penalty")
321
- web_search_check = gr.Checkbox(label="Enable Web Search")
322
- with gr.Row():
323
- ask_button = gr.Button("Ask")
324
- with gr.Row():
325
- answer = gr.Textbox(label="Answer")
326
 
327
  with gr.Row():
328
- clear_button = gr.Button("Clear Cache")
329
- with gr.Row():
330
- clear_output = gr.Textbox(label="Clear Output")
331
-
332
- with gr.Row():
333
- export_db_button = gr.Button("Export Database to Excel")
334
- export_db_output = gr.Textbox(label="Export Output")
335
- with gr.Row():
336
- export_memory_button = gr.Button("Export Memory DB to Excel")
337
- export_memory_output = gr.Textbox(label="Export Output")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
 
339
- upload_button.click(process_uploaded_file, [pdf_file, recursive_check], upload_output)
340
- ask_button.click(process_question, [question, pdf_file, conversation_history, temperature, top_p, repetition_penalty, web_search_check], answer)
341
- clear_button.click(clear_cache, [], clear_output)
342
- export_db_button.click(extract_db_to_excel, [], export_db_output)
343
- export_memory_button.click(export_memory_db_to_excel, [], export_memory_output)
344
 
345
- demo.launch()
 
 
210
  print(f"Result {i}:")
211
  print(f" Link: {result['link']}")
212
  if result['text']:
213
+ print(f" Text: {result['text'][:100]}...") # Print first 100 characters
214
  else:
215
+ print(" Text: None")
216
+ print("End of search results")
217
+
218
+ if not all_results:
219
+ print("No search results found. Returning a default message.")
220
+ return [{"link": None, "text": "No information found in the web search results."}]
221
+
222
  return all_results
223
 
224
+ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
225
  global conversation_history
226
 
227
+ if not question:
228
+ return "Please enter a question."
229
 
230
+ if question in memory_database and not web_search:
231
+ answer = memory_database[question]
 
 
 
 
 
 
 
 
232
  else:
233
+ model = get_model(temperature, top_p, repetition_penalty)
234
+ embed = get_embeddings()
235
+
236
+ if web_search:
237
+ search_results = google_search(question)
238
+ context_str = "\n".join([result["text"] for result in search_results if result["text"]])
239
+
240
+ # Convert web search results to Document format
241
+ web_docs = [Document(page_content=result["text"], metadata={"source": result["link"]}) for result in search_results if result["text"]]
242
+
243
+ # Create a temporary FAISS database for web search results
244
+ temp_database = FAISS.from_documents(web_docs, embed)
245
+
246
+ retriever = temp_database.as_retriever()
247
+ relevant_docs = retriever.get_relevant_documents(question)
248
+ context_str = "\n".join([doc.page_content for doc in relevant_docs])
249
+
250
+ prompt_template = """
251
+ Answer the question based on the following web search results:
252
+ Web Search Results:
253
+ {context}
254
+ Current Question: {question}
255
+ If the web search results don't contain relevant information, state that the information is not available in the search results.
256
+ Provide a concise and direct answer to the question without mentioning the web search or these instructions:
257
+ """
258
+ prompt_val = ChatPromptTemplate.from_template(prompt_template)
259
+ formatted_prompt = prompt_val.format(context=context_str, question=question)
260
+ else:
261
+ # Check if the FAISS database exists
262
+ if os.path.exists("faiss_database"):
263
+ database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
264
+ else:
265
+ return "No FAISS database found. Please upload documents to create the vector store."
266
 
267
+ history_str = "\n".join([f"Q: {item['question']}\nA: {item['answer']}" for item in conversation_history])
 
 
 
268
 
269
+ if is_related_to_history(question, conversation_history):
270
+ context_str = "No additional context needed. Please refer to the conversation history."
271
+ else:
272
+ retriever = database.as_retriever()
273
+ relevant_docs = retriever.get_relevant_documents(question)
274
+ context_str = "\n".join([doc.page_content for doc in relevant_docs])
275
 
276
+ prompt_val = ChatPromptTemplate.from_template(prompt)
277
+ formatted_prompt = prompt_val.format(history=history_str, context=context_str, question=question)
278
 
279
+ answer = generate_chunked_response(model, formatted_prompt)
280
+ answer = re.split(r'Question:|Current Question:', answer)[-1].strip()
 
 
 
281
 
282
+ # Remove any remaining prompt instructions from the answer
283
+ answer_lines = answer.split('\n')
284
+ answer = '\n'.join(line for line in answer_lines if not line.startswith('If') and not line.startswith('Provide'))
 
285
 
286
+ if not web_search:
287
+ memory_database[question] = answer
288
+
289
+ if not web_search:
290
+ conversation_history = manage_conversation_history(question, answer, conversation_history)
291
 
292
  return answer
293
 
294
+ def update_vectors(files, use_recursive_splitter):
295
+ if not files:
296
+ return "Please upload at least one PDF file."
297
+
298
+ embed = get_embeddings()
299
+ total_chunks = 0
300
+
301
+ for file in files:
302
+ if use_recursive_splitter:
303
+ data = load_and_split_document_recursive(file)
304
+ else:
305
+ data = load_and_split_document_basic(file)
306
+ create_or_update_database(data, embed)
307
+ total_chunks += len(data)
308
+
309
+ return f"Vector store updated successfully. Processed {total_chunks} chunks from {len(files)} files."
310
 
311
  def extract_db_to_excel():
312
  embed = get_embeddings()
 
337
 
338
  return excel_path
339
 
340
+ # Gradio interface
341
  with gr.Blocks() as demo:
342
+ gr.Markdown("# Chat with your PDF documents")
 
 
 
 
 
 
343
 
344
  with gr.Row():
345
+ file_input = gr.Files(label="Upload your PDF documents", file_types=[".pdf"])
346
+ update_button = gr.Button("Update Vector Store")
347
+ use_recursive_splitter = gr.Checkbox(label="Use Recursive Text Splitter", value=False)
348
+
349
+ update_output = gr.Textbox(label="Update Status")
350
+ update_button.click(update_vectors, inputs=[file_input, use_recursive_splitter], outputs=update_output)
 
 
 
 
351
 
352
  with gr.Row():
353
+ with gr.Column(scale=2):
354
+ chatbot = gr.Chatbot(label="Conversation")
355
+ question_input = gr.Textbox(label="Ask a question about your documents")
356
+ submit_button = gr.Button("Submit")
357
+ with gr.Column(scale=1):
358
+ temperature_slider = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
359
+ top_p_slider = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.9, step=0.1)
360
+ repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.0, step=0.1)
361
+ web_search_checkbox = gr.Checkbox(label="Enable Web Search", value=False)
362
+
363
+ def chat(question, history):
364
+ answer = ask_question(question, temperature_slider.value, top_p_slider.value, repetition_penalty_slider.value, web_search_checkbox.value)
365
+ history.append((question, answer))
366
+ return "", history
367
+
368
+ submit_button.click(chat, inputs=[question_input, chatbot], outputs=[question_input, chatbot])
369
+
370
+ extract_button = gr.Button("Extract Database to Excel")
371
+ excel_output = gr.File(label="Download Excel File")
372
+ extract_button.click(extract_db_to_excel, inputs=[], outputs=excel_output)
373
+
374
+ export_memory_button = gr.Button("Export Memory Database to Excel")
375
+ memory_excel_output = gr.File(label="Download Memory Excel File")
376
+ export_memory_button.click(export_memory_db_to_excel, inputs=[], outputs=memory_excel_output)
377
 
378
+ clear_button = gr.Button("Clear Cache")
379
+ clear_output = gr.Textbox(label="Cache Status")
380
+ clear_button.click(clear_cache, inputs=[], outputs=clear_output)
 
 
381
 
382
+ if __name__ == "__main__":
383
+ demo.launch()