Shreyas094 committed on
Commit
8b05473
·
verified ·
1 Parent(s): d613eb7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -21
app.py CHANGED
@@ -7,7 +7,7 @@ import requests
7
  import random
8
  import urllib.parse
9
  from tempfile import NamedTemporaryFile
10
- from typing import List
11
  from bs4 import BeautifulSoup
12
  from langchain.prompts import PromptTemplate
13
  from langchain.chains import LLMChain
@@ -17,10 +17,72 @@ from langchain_community.document_loaders import PyPDFLoader
17
  from langchain_core.output_parsers import StrOutputParser
18
  from langchain_community.embeddings import HuggingFaceEmbeddings
19
  from langchain_community.llms import HuggingFaceHub
20
- from langchain_core.documents import Document # Add this line
21
 
22
  huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def load_document(file: NamedTemporaryFile) -> List[Document]:
25
  """Loads and splits the document into pages."""
26
  loader = PyPDFLoader(file.name)
@@ -207,6 +269,8 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
207
 
208
  model = get_model(temperature, top_p, repetition_penalty)
209
  embed = get_embeddings()
 
 
210
 
211
  if os.path.exists("faiss_database"):
212
  database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
@@ -219,16 +283,10 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
219
  for attempt in range(max_attempts):
220
  try:
221
  if web_search:
222
- original_query = question
223
- rephrased_query = rephrase_for_search(original_query, model)
224
- print(f"Original query: {original_query}")
225
- print(f"Rephrased query: {rephrased_query}")
226
-
227
- if rephrased_query == original_query:
228
- print("Warning: Query was not rephrased. Using original query for search.")
229
-
230
- search_results = google_search(rephrased_query)
231
- web_docs = [Document(page_content=result["text"], metadata={"source": result["link"]}) for result in search_results if result["text"]]
232
 
233
  if database is None:
234
  database = FAISS.from_documents(web_docs, embed)
@@ -237,20 +295,17 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
237
 
238
  database.save_local("faiss_database")
239
 
240
- context_str = "\n".join([f"Source: {doc.metadata['source']}\nContent: {doc.page_content}" for doc in web_docs])
241
 
242
  prompt_template = """
243
  Answer the question based on the following web search results:
244
  Web Search Results:
245
  {context}
246
- Original Question: {original_question}
247
- Rephrased Search Query: {rephrased_query}
248
  If the web search results don't contain relevant information, state that the information is not available in the search results.
249
  Provide a concise and direct answer to the original question without mentioning the web search or these instructions.
250
  Do not include any source information in your answer.
251
  """
252
- prompt_val = ChatPromptTemplate.from_template(prompt_template)
253
- formatted_prompt = prompt_val.format(context=context_str, original_question=question, rephrased_query=rephrased_query)
254
  else:
255
  if database is None:
256
  return "No documents available. Please upload documents or enable web search to answer questions."
@@ -259,7 +314,6 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
259
  relevant_docs = retriever.get_relevant_documents(question)
260
  context_str = "\n".join([doc.page_content for doc in relevant_docs])
261
 
262
- # Reduce context if we're not on the first attempt
263
  if attempt > 0:
264
  words = context_str.split()
265
  context_str = " ".join(words[:int(len(words) * context_reduction_factor)])
@@ -273,8 +327,9 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
273
  Provide a concise and direct answer to the question.
274
  Do not include any source information in your answer.
275
  """
276
- prompt_val = ChatPromptTemplate.from_template(prompt_template)
277
- formatted_prompt = prompt_val.format(context=context_str, question=question)
 
278
 
279
  full_response = generate_chunked_response(model, formatted_prompt)
280
 
@@ -294,7 +349,16 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
294
  else:
295
  answer = full_response.strip()
296
 
297
- # Add sources section
 
 
 
 
 
 
 
 
 
298
  if web_search:
299
  sources = set(doc.metadata['source'] for doc in web_docs)
300
  sources_section = "\n\nSources:\n" + "\n".join(f"- {source}" for source in sources)
 
7
  import random
8
  import urllib.parse
9
  from tempfile import NamedTemporaryFile
10
+ from typing import List, Dict
11
  from bs4 import BeautifulSoup
12
  from langchain.prompts import PromptTemplate
13
  from langchain.chains import LLMChain
 
17
  from langchain_core.output_parsers import StrOutputParser
18
  from langchain_community.embeddings import HuggingFaceEmbeddings
19
  from langchain_community.llms import HuggingFaceHub
20
+ from langchain_core.documents import Document
21
 
22
  huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
23
 
24
+ class Agent1:
25
+ def __init__(self, model):
26
+ self.model = model
27
+
28
+ def rephrase_and_split(self, user_input: str) -> List[str]:
29
+ rephrase_prompt = PromptTemplate(
30
+ input_variables=["query"],
31
+ template="""
32
+ Your task is to rephrase the given query into one or more concise, search-engine-friendly formats.
33
+ If the query contains multiple distinct questions, split them.
34
+ Provide ONLY the rephrased queries without any additional text or explanations, one per line.
35
+
36
+ Query: {query}
37
+
38
+ Rephrased queries:"""
39
+ )
40
+
41
+ chain = LLMChain(llm=self.model, prompt=rephrase_prompt)
42
+ response = chain.run(query=user_input).strip()
43
+
44
+ return [q.strip() for q in response.split('\n') if q.strip()]
45
+
46
+ def process(self, user_input: str) -> Dict[str, List[Dict[str, str]]]:
47
+ queries = self.rephrase_and_split(user_input)
48
+ results = {}
49
+ for query in queries:
50
+ results[query] = google_search(query)
51
+ return results
52
+
53
+ class Agent2:
54
+ def __init__(self, model):
55
+ self.model = model
56
+
57
+ def validate_response(self, user_query: str, response: str) -> bool:
58
+ validation_prompt = PromptTemplate(
59
+ input_variables=["query", "response"],
60
+ template="""
61
+ Evaluate if the following response fully answers the user's query.
62
+ User query: {query}
63
+ Response: {response}
64
+
65
+ Does the response fully answer the query? Answer with Yes or No:"""
66
+ )
67
+
68
+ chain = LLMChain(llm=self.model, prompt=validation_prompt)
69
+ result = chain.run(query=user_query, response=response).strip().lower()
70
+ return result == 'yes'
71
+
72
+ def generate_follow_up_query(self, user_query: str, response: str) -> str:
73
+ follow_up_prompt = PromptTemplate(
74
+ input_variables=["query", "response"],
75
+ template="""
76
+ The following response did not fully answer the user's query.
77
+ User query: {query}
78
+ Response: {response}
79
+
80
+ Generate a follow-up query to get more relevant information:"""
81
+ )
82
+
83
+ chain = LLMChain(llm=self.model, prompt=follow_up_prompt)
84
+ return chain.run(query=user_query, response=response).strip()
85
+
86
  def load_document(file: NamedTemporaryFile) -> List[Document]:
87
  """Loads and splits the document into pages."""
88
  loader = PyPDFLoader(file.name)
 
269
 
270
  model = get_model(temperature, top_p, repetition_penalty)
271
  embed = get_embeddings()
272
+ agent1 = Agent1(model)
273
+ agent2 = Agent2(model)
274
 
275
  if os.path.exists("faiss_database"):
276
  database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
 
283
  for attempt in range(max_attempts):
284
  try:
285
  if web_search:
286
+ search_results = agent1.process(question)
287
+ web_docs = []
288
+ for query, results in search_results.items():
289
+ web_docs.extend([Document(page_content=result["text"], metadata={"source": result["link"], "query": query}) for result in results if result["text"]])
 
 
 
 
 
 
290
 
291
  if database is None:
292
  database = FAISS.from_documents(web_docs, embed)
 
295
 
296
  database.save_local("faiss_database")
297
 
298
+ context_str = "\n".join([f"Query: {doc.metadata['query']}\nSource: {doc.metadata['source']}\nContent: {doc.page_content}" for doc in web_docs])
299
 
300
  prompt_template = """
301
  Answer the question based on the following web search results:
302
  Web Search Results:
303
  {context}
304
+ Original Question: {question}
 
305
  If the web search results don't contain relevant information, state that the information is not available in the search results.
306
  Provide a concise and direct answer to the original question without mentioning the web search or these instructions.
307
  Do not include any source information in your answer.
308
  """
 
 
309
  else:
310
  if database is None:
311
  return "No documents available. Please upload documents or enable web search to answer questions."
 
314
  relevant_docs = retriever.get_relevant_documents(question)
315
  context_str = "\n".join([doc.page_content for doc in relevant_docs])
316
 
 
317
  if attempt > 0:
318
  words = context_str.split()
319
  context_str = " ".join(words[:int(len(words) * context_reduction_factor)])
 
327
  Provide a concise and direct answer to the question.
328
  Do not include any source information in your answer.
329
  """
330
+
331
+ prompt_val = ChatPromptTemplate.from_template(prompt_template)
332
+ formatted_prompt = prompt_val.format(context=context_str, question=question)
333
 
334
  full_response = generate_chunked_response(model, formatted_prompt)
335
 
 
349
  else:
350
  answer = full_response.strip()
351
 
352
+ if not agent2.validate_response(question, answer):
353
+ follow_up_query = agent2.generate_follow_up_query(question, answer)
354
+ follow_up_results = agent1.process(follow_up_query)
355
+ follow_up_docs = [Document(page_content=result["text"], metadata={"source": result["link"], "query": follow_up_query}) for results in follow_up_results.values() for result in results if result["text"]]
356
+ database.add_documents(follow_up_docs)
357
+ context_str += "\n" + "\n".join([f"Follow-up Query: {doc.metadata['query']}\nSource: {doc.metadata['source']}\nContent: {doc.page_content}" for doc in follow_up_docs])
358
+ formatted_prompt = prompt_val.format(context=context_str, question=question)
359
+ full_response = generate_chunked_response(model, formatted_prompt)
360
+ answer = full_response.strip()
361
+
362
  if web_search:
363
  sources = set(doc.metadata['source'] for doc in web_docs)
364
  sources_section = "\n\nSources:\n" + "\n".join(f"- {source}" for source in sources)