Shreyas094 committed on
Commit
ced5a78
·
verified ·
1 Parent(s): 303be9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -42
app.py CHANGED
@@ -97,48 +97,52 @@ class Agent1:
97
  return questions
98
 
99
  def update_context(self, query: str):
100
- tokens = nltk.pos_tag(word_tokenize(query))
101
- noun_phrases = []
102
- current_phrase = []
103
-
104
- for word, tag in tokens:
105
- if tag.startswith('NN') or tag.startswith('JJ'):
106
- current_phrase.append(word)
107
- else:
108
- if current_phrase:
109
- noun_phrases.append(' '.join(current_phrase))
110
- current_phrase = []
111
-
112
- if current_phrase:
113
- noun_phrases.append(' '.join(current_phrase))
114
-
115
- if noun_phrases:
116
- self.context['main_topic'] = noun_phrases[0]
117
- self.context['related_topics'] = noun_phrases[1:]
 
118
 
119
  def apply_context(self, query: str) -> str:
120
- words = word_tokenize(query.lower())
121
-
122
- if (len(words) <= 5 or
123
- any(word in self.pronouns for word in words) or
124
- (self.context.get('main_topic') and self.context['main_topic'].lower() not in query.lower())):
125
-
126
- new_query_parts = []
127
- main_topic_added = False
128
 
129
- for word in words:
130
- if word in self.pronouns and self.context.get('main_topic'):
131
- new_query_parts.append(self.context['main_topic'])
132
- main_topic_added = True
133
- else:
134
- new_query_parts.append(word)
 
 
 
135
 
136
- if not main_topic_added and self.context.get('main_topic'):
137
- new_query_parts.append(f"of {self.context['main_topic']}")
138
 
139
- query = ' '.join(new_query_parts)
 
140
 
141
- return query
142
 
143
  def process(self, user_input: str) -> tuple[List[str], Dict[str, List[Dict[str, str]]]]:
144
  self.update_context(user_input)
@@ -306,13 +310,15 @@ def google_search(term, num_results=5, lang="en", timeout=5, safe="active", ssl_
306
 
307
  return all_results
308
 
309
- def ask_question(question, temperature, top_p, repetition_penalty, web_search):
310
  if not question:
311
  return "Please enter a question."
312
 
 
 
 
313
  model = get_model(temperature, top_p, repetition_penalty)
314
  embed = get_embeddings()
315
- agent1 = Agent1() # Create Agent1 without passing a model
316
 
317
  if os.path.exists("faiss_database"):
318
  database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
@@ -322,8 +328,11 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
322
  max_attempts = 3
323
  context_reduction_factor = 0.7
324
 
 
 
 
325
  if web_search:
326
- queries, search_results = agent1.process(question)
327
  all_answers = []
328
 
329
  for query in queries:
@@ -395,7 +404,7 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
395
  return "No documents available. Please upload documents or enable web search to answer questions."
396
 
397
  retriever = database.as_retriever()
398
- relevant_docs = retriever.get_relevant_documents(question)
399
  context_str = "\n".join([doc.page_content for doc in relevant_docs])
400
 
401
  if attempt > 0:
@@ -413,7 +422,7 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
413
  """
414
 
415
  prompt_val = ChatPromptTemplate.from_template(prompt_template)
416
- formatted_prompt = prompt_val.format(context=context_str, question=question)
417
 
418
  full_response = generate_chunked_response(model, formatted_prompt)
419
 
@@ -466,8 +475,10 @@ with gr.Blocks() as demo:
466
  repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.0, step=0.1)
467
  web_search_checkbox = gr.Checkbox(label="Enable Web Search", value=False)
468
 
 
 
469
  def chat(question, history, temperature, top_p, repetition_penalty, web_search):
470
- answer = ask_question(question, temperature, top_p, repetition_penalty, web_search)
471
  history.append((question, answer))
472
  return "", history
473
 
 
97
  return questions
98
 
99
  def update_context(self, query: str):
100
+ tokens = nltk.pos_tag(word_tokenize(query))
101
+ noun_phrases = []
102
+ current_phrase = []
103
+
104
+ for word, tag in tokens:
105
+ if tag.startswith('NN') or tag.startswith('JJ'):
106
+ current_phrase.append(word)
107
+ else:
108
+ if current_phrase:
109
+ noun_phrases.append(' '.join(current_phrase))
110
+ current_phrase = []
111
+
112
+ if current_phrase:
113
+ noun_phrases.append(' '.join(current_phrase))
114
+
115
+ if noun_phrases:
116
+ self.context['main_topic'] = noun_phrases[0]
117
+ self.context['related_topics'] = noun_phrases[1:]
118
+ self.context['last_query'] = query
119
 
120
  def apply_context(self, query: str) -> str:
121
+ words = word_tokenize(query.lower())
122
+
123
+ if (len(words) <= 5 or
124
+ any(word in self.pronouns for word in words) or
125
+ (self.context.get('main_topic') and self.context['main_topic'].lower() not in query.lower())):
126
+
127
+ new_query_parts = []
128
+ main_topic_added = False
129
 
130
+ for word in words:
131
+ if word in self.pronouns and self.context.get('main_topic'):
132
+ new_query_parts.append(self.context['main_topic'])
133
+ main_topic_added = True
134
+ else:
135
+ new_query_parts.append(word)
136
+
137
+ if not main_topic_added and self.context.get('main_topic'):
138
+ new_query_parts.append(f"in the context of {self.context['main_topic']}")
139
 
140
+ query = ' '.join(new_query_parts)
 
141
 
142
+ if self.context.get('last_query'):
143
+ query = f"{self.context['last_query']} and now {query}"
144
 
145
+ return query
146
 
147
  def process(self, user_input: str) -> tuple[List[str], Dict[str, List[Dict[str, str]]]]:
148
  self.update_context(user_input)
 
310
 
311
  return all_results
312
 
313
+ def ask_question(question, temperature, top_p, repetition_penalty, web_search, agent1=None):
314
  if not question:
315
  return "Please enter a question."
316
 
317
+ if agent1 is None:
318
+ agent1 = Agent1()
319
+
320
  model = get_model(temperature, top_p, repetition_penalty)
321
  embed = get_embeddings()
 
322
 
323
  if os.path.exists("faiss_database"):
324
  database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
 
328
  max_attempts = 3
329
  context_reduction_factor = 0.7
330
 
331
+ agent1.update_context(question)
332
+ contextualized_question = agent1.apply_context(question)
333
+
334
  if web_search:
335
+ queries, search_results = agent1.process(contextualized_question)
336
  all_answers = []
337
 
338
  for query in queries:
 
404
  return "No documents available. Please upload documents or enable web search to answer questions."
405
 
406
  retriever = database.as_retriever()
407
+ relevant_docs = retriever.get_relevant_documents(contextualized_question)
408
  context_str = "\n".join([doc.page_content for doc in relevant_docs])
409
 
410
  if attempt > 0:
 
422
  """
423
 
424
  prompt_val = ChatPromptTemplate.from_template(prompt_template)
425
+ formatted_prompt = prompt_val.format(context=context_str, question=contextualized_question)
426
 
427
  full_response = generate_chunked_response(model, formatted_prompt)
428
 
 
475
  repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.0, step=0.1)
476
  web_search_checkbox = gr.Checkbox(label="Enable Web Search", value=False)
477
 
478
+ agent1 = Agent1()
479
+
480
  def chat(question, history, temperature, top_p, repetition_penalty, web_search):
481
+ answer = ask_question(question, temperature, top_p, repetition_penalty, web_search, agent1)
482
  history.append((question, answer))
483
  return "", history
484