Shreyas094 commited on
Commit
9acbbf9
·
verified ·
1 Parent(s): 7038b6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -7
app.py CHANGED
@@ -45,6 +45,8 @@ llama_parser = LlamaParse(
45
  language="en",
46
  )
47
 
 
 
48
  def load_document(file: NamedTemporaryFile, parser: str = "llamaparse") -> List[Document]:
49
  """Loads and splits the document into pages."""
50
  if parser == "pypdf":
@@ -66,6 +68,7 @@ def get_embeddings():
66
  return HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
67
 
68
  def update_vectors(files, parser):
 
69
  if not files:
70
  return "Please upload at least one PDF file."
71
 
@@ -77,6 +80,7 @@ def update_vectors(files, parser):
77
  data = load_document(file, parser)
78
  all_data.extend(data)
79
  total_chunks += len(data)
 
80
 
81
  if os.path.exists("faiss_database"):
82
  database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
@@ -214,10 +218,11 @@ def retry_last_response(history, use_web_search, model, temperature, num_calls):
214
 
215
  return chatbot_interface(last_user_msg, history, use_web_search, model, temperature, num_calls)
216
 
217
- def respond(message, history, model, temperature, num_calls, use_web_search):
218
  logging.info(f"User Query: {message}")
219
  logging.info(f"Model Used: {model}")
220
  logging.info(f"Search Type: {'Web Search' if use_web_search else 'PDF Search'}")
 
221
 
222
  try:
223
  if use_web_search:
@@ -231,10 +236,20 @@ def respond(message, history, model, temperature, num_calls, use_web_search):
231
  if os.path.exists("faiss_database"):
232
  database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
233
  retriever = database.as_retriever()
234
- relevant_docs = retriever.get_relevant_documents(message)
 
 
 
 
 
 
 
 
235
  context_str = "\n".join([doc.page_content for doc in relevant_docs])
236
  else:
237
  context_str = "No documents available."
 
 
238
 
239
  if model == "@cf/meta/llama-3.1-8b-instruct":
240
  # Use Cloudflare API
@@ -244,7 +259,7 @@ def respond(message, history, model, temperature, num_calls, use_web_search):
244
  yield partial_response
245
  else:
246
  # Use Hugging Face API
247
- for partial_response in get_response_from_pdf(message, model, num_calls=num_calls, temperature=temperature):
248
  first_line = partial_response.split('\n')[0] if partial_response else ''
249
  logging.info(f"Generated Response (first line): {first_line}")
250
  yield partial_response
@@ -253,7 +268,7 @@ def respond(message, history, model, temperature, num_calls, use_web_search):
253
  if "microsoft/Phi-3-mini-4k-instruct" in model:
254
  logging.info("Falling back to Mistral model due to Phi-3 error")
255
  fallback_model = "mistralai/Mistral-7B-Instruct-v0.3"
256
- yield from respond(message, history, fallback_model, temperature, num_calls, use_web_search)
257
  else:
258
  yield f"An error occurred with the {model} model: {str(e)}. Please try again or select a different model."
259
 
@@ -344,7 +359,7 @@ After writing the document, please provide a list of sources used in your respon
344
  main_content += chunk
345
  yield main_content, "" # Yield partial main content without sources
346
 
347
- def get_response_from_pdf(query, model, num_calls=3, temperature=0.2):
348
  embed = get_embeddings()
349
  if os.path.exists("faiss_database"):
350
  database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
@@ -354,7 +369,11 @@ def get_response_from_pdf(query, model, num_calls=3, temperature=0.2):
354
 
355
  retriever = database.as_retriever()
356
  relevant_docs = retriever.get_relevant_documents(query)
357
- context_str = "\n".join([doc.page_content for doc in relevant_docs])
 
 
 
 
358
 
359
  if model == "@cf/meta/llama-3.1-8b-instruct":
360
  # Use Cloudflare API with the retrieved context
@@ -392,6 +411,15 @@ css = """
392
  """
393
 
394
  # Define the checkbox outside the demo block
 
 
 
 
 
 
 
 
 
395
  use_web_search = gr.Checkbox(label="Use Web Search", value=False)
396
 
397
  demo = gr.ChatInterface(
@@ -400,7 +428,8 @@ demo = gr.ChatInterface(
400
  gr.Dropdown(choices=MODELS, label="Select Model", value=MODELS[0]),
401
  gr.Slider(minimum=0.1, maximum=1.0, value=0.2, step=0.1, label="Temperature"),
402
  gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of API Calls"),
403
- use_web_search # Add this line to include the checkbox
 
404
  ],
405
  title="AI-powered Web Search and PDF Chat Assistant",
406
  description="Chat with your PDFs or use web search to answer questions.",
 
45
  language="en",
46
  )
47
 
48
+ uploaded_documents = []
49
+
50
  def load_document(file: NamedTemporaryFile, parser: str = "llamaparse") -> List[Document]:
51
  """Loads and splits the document into pages."""
52
  if parser == "pypdf":
 
68
  return HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
69
 
70
  def update_vectors(files, parser):
71
+ global uploaded_documents
72
  if not files:
73
  return "Please upload at least one PDF file."
74
 
 
80
  data = load_document(file, parser)
81
  all_data.extend(data)
82
  total_chunks += len(data)
83
+ uploaded_documents.append({"name": file.name, "selected": True})
84
 
85
  if os.path.exists("faiss_database"):
86
  database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
 
218
 
219
  return chatbot_interface(last_user_msg, history, use_web_search, model, temperature, num_calls)
220
 
221
+ def respond(message, history, model, temperature, num_calls, use_web_search, selected_docs):
222
  logging.info(f"User Query: {message}")
223
  logging.info(f"Model Used: {model}")
224
  logging.info(f"Search Type: {'Web Search' if use_web_search else 'PDF Search'}")
225
+ logging.info(f"Selected Documents: {selected_docs}")
226
 
227
  try:
228
  if use_web_search:
 
236
  if os.path.exists("faiss_database"):
237
  database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
238
  retriever = database.as_retriever()
239
+
240
+ # Filter relevant documents based on user selection
241
+ all_relevant_docs = retriever.get_relevant_documents(message)
242
+ relevant_docs = [doc for doc in all_relevant_docs if doc.metadata["source"] in selected_docs]
243
+
244
+ if not relevant_docs:
245
+ yield "No relevant information found in the selected documents. Please try selecting different documents or rephrasing your query."
246
+ return
247
+
248
  context_str = "\n".join([doc.page_content for doc in relevant_docs])
249
  else:
250
  context_str = "No documents available."
251
+ yield "No documents available. Please upload PDF documents to answer questions."
252
+ return
253
 
254
  if model == "@cf/meta/llama-3.1-8b-instruct":
255
  # Use Cloudflare API
 
259
  yield partial_response
260
  else:
261
  # Use Hugging Face API
262
+ for partial_response in get_response_from_pdf(message, model, selected_docs, num_calls=num_calls, temperature=temperature):
263
  first_line = partial_response.split('\n')[0] if partial_response else ''
264
  logging.info(f"Generated Response (first line): {first_line}")
265
  yield partial_response
 
268
  if "microsoft/Phi-3-mini-4k-instruct" in model:
269
  logging.info("Falling back to Mistral model due to Phi-3 error")
270
  fallback_model = "mistralai/Mistral-7B-Instruct-v0.3"
271
+ yield from respond(message, history, fallback_model, temperature, num_calls, use_web_search, selected_docs)
272
  else:
273
  yield f"An error occurred with the {model} model: {str(e)}. Please try again or select a different model."
274
 
 
359
  main_content += chunk
360
  yield main_content, "" # Yield partial main content without sources
361
 
362
+ def get_response_from_pdf(query, model, selected_docs, num_calls=3, temperature=0.2):
363
  embed = get_embeddings()
364
  if os.path.exists("faiss_database"):
365
  database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
 
369
 
370
  retriever = database.as_retriever()
371
  relevant_docs = retriever.get_relevant_documents(query)
372
+
373
+ # Filter relevant_docs based on selected documents
374
+ filtered_docs = [doc for doc in relevant_docs if doc.metadata["source"] in selected_docs]
375
+
376
+ context_str = "\n".join([doc.page_content for doc in filtered_docs])
377
 
378
  if model == "@cf/meta/llama-3.1-8b-instruct":
379
  # Use Cloudflare API with the retrieved context
 
411
  """
412
 
413
  # Define the checkbox outside the demo block
414
+ def display_documents():
415
+ return gr.CheckboxGroup(
416
+ choices=[doc["name"] for doc in uploaded_documents],
417
+ value=[doc["name"] for doc in uploaded_documents if doc["selected"]],
418
+ label="Select documents to query"
419
+ )
420
+
421
+ document_selector = gr.CheckboxGroup(label="Select documents to query")
422
+
423
  use_web_search = gr.Checkbox(label="Use Web Search", value=False)
424
 
425
  demo = gr.ChatInterface(
 
428
  gr.Dropdown(choices=MODELS, label="Select Model", value=MODELS[0]),
429
  gr.Slider(minimum=0.1, maximum=1.0, value=0.2, step=0.1, label="Temperature"),
430
  gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of API Calls"),
431
+ use_web_search,
432
+ document_selector # Add this line to include the document selector
433
  ],
434
  title="AI-powered Web Search and PDF Chat Assistant",
435
  description="Chat with your PDFs or use web search to answer questions.",