Shreyas094 committed on
Commit
63b644a
·
verified ·
1 Parent(s): 6fac185

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -20
app.py CHANGED
@@ -29,33 +29,41 @@ from langchain_core.documents import Document
29
 
30
  huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
31
 
32
- # Download necessary NLTK data
33
- nltk.download('punkt')
34
- nltk.download('averaged_perceptron_tagger')
35
- class ContextDrivenChatbot:
36
- def __init__(self, history_size=5):
 
 
 
37
  self.history = []
38
  self.history_size = history_size
39
- self.vectorizer = TfidfVectorizer()
40
- nltk.download('punkt', quiet=True)
41
- nltk.download('averaged_perceptron_tagger', quiet=True)
42
 
43
  def add_to_history(self, text):
44
  self.history.append(text)
45
  if len(self.history) > self.history_size:
46
  self.history.pop(0)
 
 
 
 
 
 
 
47
 
48
  def get_context(self):
49
  return " ".join(self.history)
50
 
51
  def is_follow_up_question(self, question):
52
- tokens = word_tokenize(question.lower())
53
  follow_up_indicators = set(['it', 'this', 'that', 'these', 'those', 'he', 'she', 'they', 'them'])
54
- return any(token in follow_up_indicators for token in tokens)
55
 
56
  def extract_topics(self, text):
57
- tokens = nltk.pos_tag(word_tokenize(text))
58
- return [word for word, pos in tokens if pos.startswith('NN')]
59
 
60
  def get_most_relevant_context(self, question):
61
  if not self.history:
@@ -64,11 +72,12 @@ class ContextDrivenChatbot:
64
  # Create a combined context from history
65
  combined_context = self.get_context()
66
 
67
- # Vectorize the context and the question
68
- vectors = self.vectorizer.fit_transform([combined_context, question])
 
69
 
70
  # Calculate similarity
71
- similarity = cosine_similarity(vectors[0], vectors[1])[0][0]
72
 
73
  # If similarity is low, it might be a new topic
74
  if similarity < 0.3: # This threshold can be adjusted
@@ -91,7 +100,7 @@ class ContextDrivenChatbot:
91
  # Add the new question to history
92
  self.add_to_history(question)
93
 
94
- return contextualized_question, topics
95
 
96
  def load_document(file: NamedTemporaryFile) -> List[Document]:
97
  """Loads and splits the document into pages."""
@@ -262,7 +271,7 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
262
  max_attempts = 3
263
  context_reduction_factor = 0.7
264
 
265
- contextualized_question, topics = chatbot.process_question(question)
266
 
267
  if web_search:
268
  search_results = google_search(contextualized_question)
@@ -282,12 +291,13 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
282
  context_str = "\n".join([f"Source: {doc.metadata['source']}\nContent: {doc.page_content}" for doc in web_docs])
283
 
284
  prompt_template = """
285
- Answer the question based on the following web search results and conversation context:
286
  Web Search Results:
287
  {context}
288
  Conversation Context: {conv_context}
289
  Current Question: {question}
290
  Topics: {topics}
 
291
  If the web search results don't contain relevant information, state that the information is not available in the search results.
292
  Provide a summarized and direct answer to the question without mentioning the web search or these instructions.
293
  Do not include any source information in your answer.
@@ -298,7 +308,8 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
298
  context=context_str,
299
  conv_context=chatbot.get_context(),
300
  question=question,
301
- topics=", ".join(topics)
 
302
  )
303
 
304
  full_response = generate_chunked_response(model, formatted_prompt)
@@ -415,7 +426,7 @@ with gr.Blocks() as demo:
415
  repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.0, step=0.1)
416
  web_search_checkbox = gr.Checkbox(label="Enable Web Search", value=False)
417
 
418
- context_driven_chatbot = ContextDrivenChatbot()
419
 
420
  def chat(question, history, temperature, top_p, repetition_penalty, web_search):
421
  answer = ask_question(question, temperature, top_p, repetition_penalty, web_search, context_driven_chatbot)
 
29
 
30
  huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
31
 
32
+ # Load spaCy model
33
+ nlp = spacy.load("en_core_web_sm")
34
+
35
+ # Load SentenceTransformer model
36
+ sentence_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
37
+
38
+ class EnhancedContextDrivenChatbot:
39
+ def __init__(self, history_size=10):
40
  self.history = []
41
  self.history_size = history_size
42
+ self.entity_tracker = {}
 
 
43
 
44
  def add_to_history(self, text):
45
  self.history.append(text)
46
  if len(self.history) > self.history_size:
47
  self.history.pop(0)
48
+
49
+ # Update entity tracker
50
+ doc = nlp(text)
51
+ for ent in doc.ents:
52
+ if ent.label_ not in self.entity_tracker:
53
+ self.entity_tracker[ent.label_] = set()
54
+ self.entity_tracker[ent.label_].add(ent.text)
55
 
56
  def get_context(self):
57
  return " ".join(self.history)
58
 
59
  def is_follow_up_question(self, question):
60
+ doc = nlp(question.lower())
61
  follow_up_indicators = set(['it', 'this', 'that', 'these', 'those', 'he', 'she', 'they', 'them'])
62
+ return any(token.text in follow_up_indicators for token in doc)
63
 
64
  def extract_topics(self, text):
65
+ doc = nlp(text)
66
+ return [chunk.text for chunk in doc.noun_chunks]
67
 
68
  def get_most_relevant_context(self, question):
69
  if not self.history:
 
72
  # Create a combined context from history
73
  combined_context = self.get_context()
74
 
75
+ # Get embeddings
76
+ context_embedding = sentence_model.encode([combined_context])[0]
77
+ question_embedding = sentence_model.encode([question])[0]
78
 
79
  # Calculate similarity
80
+ similarity = cosine_similarity([context_embedding], [question_embedding])[0][0]
81
 
82
  # If similarity is low, it might be a new topic
83
  if similarity < 0.3: # This threshold can be adjusted
 
100
  # Add the new question to history
101
  self.add_to_history(question)
102
 
103
+ return contextualized_question, topics, self.entity_tracker
104
 
105
  def load_document(file: NamedTemporaryFile) -> List[Document]:
106
  """Loads and splits the document into pages."""
 
271
  max_attempts = 3
272
  context_reduction_factor = 0.7
273
 
274
+ contextualized_question, topics, entity_tracker = chatbot.process_question(question)
275
 
276
  if web_search:
277
  search_results = google_search(contextualized_question)
 
291
  context_str = "\n".join([f"Source: {doc.metadata['source']}\nContent: {doc.page_content}" for doc in web_docs])
292
 
293
  prompt_template = """
294
+ Answer the question based on the following web search results, conversation context, and entity information:
295
  Web Search Results:
296
  {context}
297
  Conversation Context: {conv_context}
298
  Current Question: {question}
299
  Topics: {topics}
300
+ Entity Information: {entities}
301
  If the web search results don't contain relevant information, state that the information is not available in the search results.
302
  Provide a summarized and direct answer to the question without mentioning the web search or these instructions.
303
  Do not include any source information in your answer.
 
308
  context=context_str,
309
  conv_context=chatbot.get_context(),
310
  question=question,
311
+ topics=", ".join(topics),
312
+ entities=json.dumps(entity_tracker)
313
  )
314
 
315
  full_response = generate_chunked_response(model, formatted_prompt)
 
426
  repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.0, step=0.1)
427
  web_search_checkbox = gr.Checkbox(label="Enable Web Search", value=False)
428
 
429
+ context_driven_chatbot = EnhancedContextDrivenChatbot()
430
 
431
  def chat(question, history, temperature, top_p, repetition_penalty, web_search):
432
  answer = ask_question(question, temperature, top_p, repetition_penalty, web_search, context_driven_chatbot)