Shreyas094 committed
Commit 3450cd7 · verified · 1 Parent(s): a6a5ca5

Update app.py

Files changed (1):
  app.py +63 -104
app.py CHANGED
@@ -44,17 +44,18 @@ def load_spacy_model():
 nlp = load_spacy_model()
 
 class EnhancedContextDrivenChatbot:
-    def __init__(self, history_size=10, model=None):
+    def __init__(self, history_size: int = 10, max_history_chars: int = 5000):
         self.history = []
         self.history_size = history_size
+        self.max_history_chars = max_history_chars
         self.entity_tracker = {}
         self.conversation_context = ""
-        self.model = model
+        self.model = None
         self.last_instructions = None
 
-    def add_to_history(self, text):
+    def add_to_history(self, text: str):
         self.history.append(text)
-        if len(self.history) > self.history_size:
+        while len(' '.join(self.history)) > self.max_history_chars or len(self.history) > self.history_size:
             self.history.pop(0)
 
         # Update entity tracker
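The add_to_history change above replaces the fixed entry-count check with a loop that also enforces a total character budget, so a single long turn can now evict several short ones. A minimal self-contained sketch of the new trimming rule, with the constructor defaults lifted into module constants for illustration:

HISTORY_SIZE = 10         # mirrors history_size=10
MAX_HISTORY_CHARS = 5000  # mirrors max_history_chars=5000
history: list[str] = []

def add_to_history(text: str) -> None:
    history.append(text)
    # Evict oldest turns until both the entry cap and the character budget hold.
    while len(' '.join(history)) > MAX_HISTORY_CHARS or len(history) > HISTORY_SIZE:
        history.pop(0)

for i in range(12):
    add_to_history(f"turn {i}: " + "x" * 600)
print(len(history), len(' '.join(history)))  # stays within 10 entries and 5000 characters

Note that the loop also empties the history entirely if a single entry alone exceeds the budget, which matches the committed logic.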
@@ -221,6 +222,28 @@ def get_model(temperature, top_p, repetition_penalty):
         huggingfacehub_api_token=huggingface_token
     )
 
+MAX_PROMPT_CHARS = 24000  # Adjust based on your model's limitations
+
+def chunk_text(text: str, max_chunk_size: int = 1000) -> List[str]:
+    chunks = []
+    current_chunk = ""
+    for sentence in re.split(r'(?<=[.!?])\s+', text):
+        if len(current_chunk) + len(sentence) > max_chunk_size:
+            chunks.append(current_chunk.strip())
+            current_chunk = sentence
+        else:
+            current_chunk += " " + sentence
+    if current_chunk:
+        chunks.append(current_chunk.strip())
+    return chunks
+
+def get_most_relevant_chunks(question: str, chunks: List[str], top_k: int = 3) -> List[str]:
+    question_embedding = sentence_model.encode([question])[0]
+    chunk_embeddings = sentence_model.encode(chunks)
+    similarities = cosine_similarity([question_embedding], chunk_embeddings)[0]
+    top_indices = np.argsort(similarities)[-top_k:]
+    return [chunks[i] for i in top_indices]
+
 def generate_chunked_response(model, prompt, max_tokens=1000, max_chunks=5):
     full_response = ""
     for i in range(max_chunks):
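In the hunk above, chunk_text splits page text on sentence-ending punctuation and get_most_relevant_chunks ranks the chunks by cosine similarity between sentence embeddings; both rely on module-level names (re, List, np, sentence_model, cosine_similarity) defined elsewhere in app.py. A self-contained sketch of the ranking step, where the encoder name is an illustrative assumption rather than something this commit pins down:

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

sentence_model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed encoder

question = "How are embeddings compared?"
chunks = [
    "Cosine similarity compares two vectors by the angle between them.",
    "FAISS persists an index to disk with save_local.",
    "The weather was pleasant in June.",
]

question_embedding = sentence_model.encode([question])[0]  # shape (dim,)
chunk_embeddings = sentence_model.encode(chunks)           # shape (n_chunks, dim)
similarities = cosine_similarity([question_embedding], chunk_embeddings)[0]
top_indices = np.argsort(similarities)[-2:]                # best two, ascending by score
print([chunks[i] for i in top_indices])

As in the committed helper, argsort leaves the highest-scoring chunk last, so callers that care about order may want to reverse the slice.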
@@ -329,115 +352,51 @@ def estimate_tokens(text):
     # Rough estimate: 1 token ~= 4 characters
     return len(text) // 4
 
-def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot):
-    if not question:
-        return "Please enter a question."
-
+def ask_question(question: str, temperature: float, top_p: float, repetition_penalty: float, web_search: bool, chatbot: EnhancedContextDrivenChatbot) -> str:
     model = get_model(temperature, top_p, repetition_penalty)
-
-    # Update the chatbot's model
     chatbot.model = model
 
-    embed = get_embeddings()
-
-    if os.path.exists("faiss_database"):
-        database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
-    else:
-        database = None
-
-    max_attempts = 5
-    context_reduction_factor = 0.7
-    max_tokens = 32000 # Maximum tokens allowed by the model
-
     if web_search:
         contextualized_question, topics, entity_tracker, instructions = chatbot.process_question(question)
-        serializable_entity_tracker = {k: list(v) for k, v in entity_tracker.items()}
-
         search_results = google_search(contextualized_question, num_results=3)
-        all_answers = []
-
-        for attempt in range(max_attempts):
-            try:
-                web_docs = [Document(page_content=result["text"], metadata={"source": result["link"]}) for result in search_results if result["text"]]
-
-                if database is None:
-                    database = FAISS.from_documents(web_docs, embed)
-                else:
-                    database.add_documents(web_docs)
-
-                database.save_local("faiss_database")
-
-                context_str = "\n".join([f"Source: {doc.metadata['source']}\nContent: {doc.page_content}" for doc in web_docs])
-
-                instruction_prompt = f"User Instructions: {instructions}\n" if instructions else ""
-
-                prompt_template = f"""
-                Answer the question based on the following web search results, conversation context, entity information, and user instructions:
-                Web Search Results:
-                {{context}}
-                Conversation Context: {{conv_context}}
-                Current Question: {{question}}
-                Topics: {{topics}}
-                Entity Information: {{entities}}
-                {instruction_prompt}
-                Provide a concise and relevant answer to the question.
-                """
-
-                prompt_val = ChatPromptTemplate.from_template(prompt_template)
-
-                # Start with full context and progressively reduce if necessary
-                current_context = context_str
-                current_conv_context = chatbot.get_context()
-                current_topics = topics
-                current_entities = serializable_entity_tracker
-
-                while True:
-                    formatted_prompt = prompt_val.format(
-                        context=current_context,
-                        conv_context=current_conv_context,
-                        question=question,
-                        topics=", ".join(current_topics),
-                        entities=json.dumps(current_entities)
-                    )
-
-                    # Estimate token count
-                    estimated_tokens = estimate_tokens(formatted_prompt)
-
-                    if estimated_tokens <= max_tokens - 1000: # Leave 1000 tokens for the model's response
-                        break
-
-                    # Reduce context if estimated token count is too high
-                    current_context = current_context[:int(len(current_context) * context_reduction_factor)]
-                    current_conv_context = current_conv_context[:int(len(current_conv_context) * context_reduction_factor)]
-                    current_topics = current_topics[:max(1, int(len(current_topics) * context_reduction_factor))]
-                    current_entities = {k: v[:max(1, int(len(v) * context_reduction_factor))] for k, v in current_entities.items()}
-
-                    if len(current_context) + len(current_conv_context) + len(str(current_topics)) + len(str(current_entities)) < 100:
-                        raise ValueError("Context reduced too much. Unable to process the query.")
-
-                full_response = generate_chunked_response(model, formatted_prompt, max_tokens=1000)
-                answer = extract_answer(full_response, instructions)
-                all_answers.append(answer)
+
+        context_chunks = []
+        for result in search_results:
+            if result["text"]:
+                context_chunks.extend(chunk_text(result["text"]))
+
+        relevant_chunks = get_most_relevant_chunks(question, context_chunks)
+
+        prompt_parts = [
+            f"Question: {question}",
+            f"Conversation Context: {chatbot.get_context()[-1000:]}", # Last 1000 characters
+            "Relevant Web Search Results:"
+        ]
+
+        for chunk in relevant_chunks:
+            if len(' '.join(prompt_parts)) + len(chunk) < MAX_PROMPT_CHARS:
+                prompt_parts.append(chunk)
+            else:
                 break
-
-            except ValueError as ve:
-                print(f"Error in ask_question (attempt {attempt + 1}): {ve}")
-                if attempt == max_attempts - 1:
-                    all_answers.append(f"I apologize, but I'm having trouble processing the query due to its length or complexity. Could you please try asking a more specific or shorter question?")
-
-            except Exception as e:
-                print(f"Error in ask_question (attempt {attempt + 1}): {e}")
-                if attempt == max_attempts - 1:
-                    all_answers.append(f"I apologize, but an unexpected error occurred. Please try again with a different question or check your internet connection.")
-
-        answer = "\n\n".join(all_answers)
-        sources = set(doc.metadata['source'] for doc in web_docs)
-        sources_section = "\n\nSources:\n" + "\n".join(f"- {source}" for source in sources)
-        answer += sources_section
-
+
+        if instructions:
+            prompt_parts.append(f"User Instructions: {instructions}")
+
+        prompt_template = """
+        Answer the question based on the following information:
+        {context}
+        Provide a concise and relevant answer to the question.
+        """
+
+        formatted_prompt = prompt_template.format(context='\n'.join(prompt_parts))
+
+        # Generate response using the model
+        full_response = generate_chunked_response(model, formatted_prompt, max_tokens=1000)
+        answer = extract_answer(full_response, instructions)
+
         # Update chatbot context with the answer
         chatbot.add_to_history(answer)
-
+
         return answer
 
     else: # PDF document chat
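The rewritten web-search branch drops the old retry-and-shrink loop (and the FAISS writes on this path) in favor of a single greedy pass: ranked chunks are appended until the joined prompt would cross MAX_PROMPT_CHARS. A standalone sketch of that packing step; pack_prompt is a hypothetical name used only here, not part of the commit:

MAX_PROMPT_CHARS = 24000  # same ceiling as the constant introduced above

def pack_prompt(question: str, conv_context: str, ranked_chunks: list[str],
                budget: int = MAX_PROMPT_CHARS) -> str:
    # Fixed preamble first, then chunks in relevance order while they fit.
    parts = [
        f"Question: {question}",
        f"Conversation Context: {conv_context[-1000:]}",  # last 1000 chars, as in the diff
        "Relevant Web Search Results:",
    ]
    for chunk in ranked_chunks:
        if len(' '.join(parts)) + len(chunk) < budget:
            parts.append(chunk)
        else:
            break  # the first oversized chunk ends packing, mirroring the diff
    return '\n'.join(parts)

One consequence of the greedy break, visible in the diff as well: a single oversized chunk stops packing even if later, smaller chunks would still fit.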
 