Shreyas094 commited on
Commit
5c1c2c2
·
verified ·
1 Parent(s): fcfb7cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -50
app.py CHANGED
@@ -41,7 +41,11 @@ class Agent1:
41
  chain = LLMChain(llm=self.model, prompt=rephrase_prompt)
42
  response = chain.run(query=user_input).strip()
43
 
44
- return [q.strip() for q in response.split('\n') if q.strip()]
 
 
 
 
45
 
46
  def process(self, user_input: str) -> Dict[str, List[Dict[str, str]]]:
47
  queries = self.rephrase_and_split(user_input)
@@ -50,39 +54,6 @@ class Agent1:
50
  results[query] = google_search(query)
51
  return results
52
 
53
- class Agent2:
54
- def __init__(self, model):
55
- self.model = model
56
-
57
- def validate_response(self, user_query: str, response: str) -> bool:
58
- validation_prompt = PromptTemplate(
59
- input_variables=["query", "response"],
60
- template="""
61
- Evaluate if the following response fully answers the user's query.
62
- User query: {query}
63
- Response: {response}
64
-
65
- Does the response fully answer the query? Answer with Yes or No:"""
66
- )
67
-
68
- chain = LLMChain(llm=self.model, prompt=validation_prompt)
69
- result = chain.run(query=user_query, response=response).strip().lower()
70
- return result == 'yes'
71
-
72
- def generate_follow_up_query(self, user_query: str, response: str) -> str:
73
- follow_up_prompt = PromptTemplate(
74
- input_variables=["query", "response"],
75
- template="""
76
- The following response did not fully answer the user's query.
77
- User query: {query}
78
- Response: {response}
79
-
80
- Generate a follow-up query to get more relevant information:"""
81
- )
82
-
83
- chain = LLMChain(llm=self.model, prompt=follow_up_prompt)
84
- return chain.run(query=user_query, response=response).strip()
85
-
86
  def load_document(file: NamedTemporaryFile) -> List[Document]:
87
  """Loads and splits the document into pages."""
88
  loader = PyPDFLoader(file.name)
@@ -270,7 +241,6 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
270
  model = get_model(temperature, top_p, repetition_penalty)
271
  embed = get_embeddings()
272
  agent1 = Agent1(model)
273
- agent2 = Agent2(model)
274
 
275
  if os.path.exists("faiss_database"):
276
  database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
@@ -279,7 +249,6 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
279
 
280
  max_attempts = 3
281
  context_reduction_factor = 0.7
282
- agent2_max_attempts = 2
283
 
284
  for attempt in range(max_attempts):
285
  try:
@@ -350,20 +319,6 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
350
  else:
351
  answer = full_response.strip()
352
 
353
- for agent2_attempt in range(agent2_max_attempts):
354
- if agent2.validate_response(question, answer):
355
- break
356
-
357
- if agent2_attempt < agent2_max_attempts - 1:
358
- follow_up_query = agent2.generate_follow_up_query(question, answer)
359
- follow_up_results = agent1.process(follow_up_query)
360
- follow_up_docs = [Document(page_content=result["text"], metadata={"source": result["link"], "query": follow_up_query}) for results in follow_up_results.values() for result in results if result["text"]]
361
- database.add_documents(follow_up_docs)
362
- context_str += "\n" + "\n".join([f"Follow-up Query: {doc.metadata['query']}\nSource: {doc.metadata['source']}\nContent: {doc.page_content}" for doc in follow_up_docs])
363
- formatted_prompt = prompt_val.format(context=context_str, question=question)
364
- full_response = generate_chunked_response(model, formatted_prompt)
365
- answer = full_response.strip()
366
-
367
  if web_search:
368
  sources = set(doc.metadata['source'] for doc in web_docs)
369
  sources_section = "\n\nSources:\n" + "\n".join(f"- {source}" for source in sources)
 
41
  chain = LLMChain(llm=self.model, prompt=rephrase_prompt)
42
  response = chain.run(query=user_input).strip()
43
 
44
+ # Remove any lines that contain instructions or explanations
45
+ rephrased_queries = [q.strip() for q in response.split('\n') if q.strip() and not q.startswith("Rephrase") and "query" not in q.lower()]
46
+
47
+ # If no valid rephrased queries, return the original input
48
+ return rephrased_queries if rephrased_queries else [user_input]
49
 
50
  def process(self, user_input: str) -> Dict[str, List[Dict[str, str]]]:
51
  queries = self.rephrase_and_split(user_input)
 
54
  results[query] = google_search(query)
55
  return results
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def load_document(file: NamedTemporaryFile) -> List[Document]:
58
  """Loads and splits the document into pages."""
59
  loader = PyPDFLoader(file.name)
 
241
  model = get_model(temperature, top_p, repetition_penalty)
242
  embed = get_embeddings()
243
  agent1 = Agent1(model)
 
244
 
245
  if os.path.exists("faiss_database"):
246
  database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
 
249
 
250
  max_attempts = 3
251
  context_reduction_factor = 0.7
 
252
 
253
  for attempt in range(max_attempts):
254
  try:
 
319
  else:
320
  answer = full_response.strip()
321
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
  if web_search:
323
  sources = set(doc.metadata['source'] for doc in web_docs)
324
  sources_section = "\n\nSources:\n" + "\n".join(f"- {source}" for source in sources)