Update app.py
app.py CHANGED
@@ -44,17 +44,18 @@ def load_spacy_model():
 nlp = load_spacy_model()
 
 class EnhancedContextDrivenChatbot:
-    def __init__(self, history_size=10,
+    def __init__(self, history_size: int = 10, max_history_chars: int = 5000):
         self.history = []
         self.history_size = history_size
+        self.max_history_chars = max_history_chars
         self.entity_tracker = {}
         self.conversation_context = ""
-        self.model =
+        self.model = None
         self.last_instructions = None
 
-    def add_to_history(self, text):
+    def add_to_history(self, text: str):
         self.history.append(text)
-        if len(self.history) > self.history_size:
+        while len(' '.join(self.history)) > self.max_history_chars or len(self.history) > self.history_size:
             self.history.pop(0)
 
         # Update entity tracker
@@ -221,6 +222,28 @@ def get_model(temperature, top_p, repetition_penalty):
         huggingfacehub_api_token=huggingface_token
     )
 
+MAX_PROMPT_CHARS = 24000  # Adjust based on your model's limitations
+
+def chunk_text(text: str, max_chunk_size: int = 1000) -> List[str]:
+    chunks = []
+    current_chunk = ""
+    for sentence in re.split(r'(?<=[.!?])\s+', text):
+        if len(current_chunk) + len(sentence) > max_chunk_size:
+            chunks.append(current_chunk.strip())
+            current_chunk = sentence
+        else:
+            current_chunk += " " + sentence
+    if current_chunk:
+        chunks.append(current_chunk.strip())
+    return chunks
+
+def get_most_relevant_chunks(question: str, chunks: List[str], top_k: int = 3) -> List[str]:
+    question_embedding = sentence_model.encode([question])[0]
+    chunk_embeddings = sentence_model.encode(chunks)
+    similarities = cosine_similarity([question_embedding], chunk_embeddings)[0]
+    top_indices = np.argsort(similarities)[-top_k:]
+    return [chunks[i] for i in top_indices]
+
 def generate_chunked_response(model, prompt, max_tokens=1000, max_chunks=5):
     full_response = ""
     for i in range(max_chunks):
@@ -329,115 +352,51 @@ def estimate_tokens(text):
     # Rough estimate: 1 token ~= 4 characters
     return len(text) // 4
 
-def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot):
-    if not question:
-        return "Please enter a question."
-
+def ask_question(question: str, temperature: float, top_p: float, repetition_penalty: float, web_search: bool, chatbot: EnhancedContextDrivenChatbot) -> str:
     model = get_model(temperature, top_p, repetition_penalty)
-
-    # Update the chatbot's model
     chatbot.model = model
 
-    embed = get_embeddings()
-
-    if os.path.exists("faiss_database"):
-        database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
-    else:
-        database = None
-
-    max_attempts = 5
-    context_reduction_factor = 0.7
-    max_tokens = 32000  # Maximum tokens allowed by the model
-
     if web_search:
         contextualized_question, topics, entity_tracker, instructions = chatbot.process_question(question)
-        serializable_entity_tracker = {k: list(v) for k, v in entity_tracker.items()}
-
         search_results = google_search(contextualized_question, num_results=3)
-
-        for
-        Answer the question based on the following web search results, conversation context, entity information, and user instructions:
-        Web Search Results:
-        {{context}}
-        Conversation Context: {{conv_context}}
-        Current Question: {{question}}
-        Topics: {{topics}}
-        Entity Information: {{entities}}
-        {instruction_prompt}
-        Provide a concise and relevant answer to the question.
-        """
-
-        prompt_val = ChatPromptTemplate.from_template(prompt_template)
-
-        # Start with full context and progressively reduce if necessary
-        current_context = context_str
-        current_conv_context = chatbot.get_context()
-        current_topics = topics
-        current_entities = serializable_entity_tracker
-
-        while True:
-            formatted_prompt = prompt_val.format(
-                context=current_context,
-                conv_context=current_conv_context,
-                question=question,
-                topics=", ".join(current_topics),
-                entities=json.dumps(current_entities)
-            )
-
-            # Estimate token count
-            estimated_tokens = estimate_tokens(formatted_prompt)
-
-            if estimated_tokens <= max_tokens - 1000:  # Leave 1000 tokens for the model's response
-                break
-
-            # Reduce context if estimated token count is too high
-            current_context = current_context[:int(len(current_context) * context_reduction_factor)]
-            current_conv_context = current_conv_context[:int(len(current_conv_context) * context_reduction_factor)]
-            current_topics = current_topics[:max(1, int(len(current_topics) * context_reduction_factor))]
-            current_entities = {k: v[:max(1, int(len(v) * context_reduction_factor))] for k, v in current_entities.items()}
-
-            if len(current_context) + len(current_conv_context) + len(str(current_topics)) + len(str(current_entities)) < 100:
-                raise ValueError("Context reduced too much. Unable to process the query.")
-
-            full_response = generate_chunked_response(model, formatted_prompt, max_tokens=1000)
-            answer = extract_answer(full_response, instructions)
-            all_answers.append(answer)
+
+        context_chunks = []
+        for result in search_results:
+            if result["text"]:
+                context_chunks.extend(chunk_text(result["text"]))
+
+        relevant_chunks = get_most_relevant_chunks(question, context_chunks)
+
+        prompt_parts = [
+            f"Question: {question}",
+            f"Conversation Context: {chatbot.get_context()[-1000:]}",  # Last 1000 characters
+            "Relevant Web Search Results:"
+        ]
+
+        for chunk in relevant_chunks:
+            if len(' '.join(prompt_parts)) + len(chunk) < MAX_PROMPT_CHARS:
+                prompt_parts.append(chunk)
+            else:
                 break
-
-        answer
-
+
+        if instructions:
+            prompt_parts.append(f"User Instructions: {instructions}")
+
+        prompt_template = """
+        Answer the question based on the following information:
+        {context}
+        Provide a concise and relevant answer to the question.
+        """
+
+        formatted_prompt = prompt_template.format(context='\n'.join(prompt_parts))
+
+        # Generate response using the model
+        full_response = generate_chunked_response(model, formatted_prompt, max_tokens=1000)
+        answer = extract_answer(full_response, instructions)
+
         # Update chatbot context with the answer
         chatbot.add_to_history(answer)
 
         return answer
 
     else:  # PDF document chat
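The retrieval step this commit adds is self-contained enough to try outside the Space. Below is a minimal standalone sketch of the same chunk-and-rank idea; it assumes sentence-transformers, scikit-learn, and numpy are installed, and the model name "all-MiniLM-L6-v2" plus the sample text are illustrative stand-ins, not necessarily what app.py loads as sentence_model.

# Minimal sketch of the chunk-and-rank retrieval helpers from this commit.
import re
from typing import List

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Illustrative model choice; app.py presumably builds its own sentence_model.
sentence_model = SentenceTransformer("all-MiniLM-L6-v2")

def chunk_text(text: str, max_chunk_size: int = 1000) -> List[str]:
    # Split on sentence boundaries and pack sentences into ~max_chunk_size chunks.
    chunks, current_chunk = [], ""
    for sentence in re.split(r'(?<=[.!?])\s+', text):
        if len(current_chunk) + len(sentence) > max_chunk_size:
            chunks.append(current_chunk.strip())
            current_chunk = sentence
        else:
            current_chunk += " " + sentence
    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks

def get_most_relevant_chunks(question: str, chunks: List[str], top_k: int = 3) -> List[str]:
    # Embed the question and every chunk, then keep the top_k most similar chunks.
    question_embedding = sentence_model.encode([question])[0]
    chunk_embeddings = sentence_model.encode(chunks)
    similarities = cosine_similarity([question_embedding], chunk_embeddings)[0]
    top_indices = np.argsort(similarities)[-top_k:]
    return [chunks[i] for i in top_indices]

if __name__ == "__main__":
    document = ("Python was created by Guido van Rossum. It was first released in 1991. "
                "Python emphasizes readability. The language is widely used for scripting and data work.")
    chunks = chunk_text(document, max_chunk_size=80)
    print(get_most_relevant_chunks("Who created Python?", chunks, top_k=2))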
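The new ask_question path also swaps the old progressive context-reduction loop for a simple character budget. The sketch below shows that assembly step in isolation; build_prompt is a hypothetical helper introduced only for illustration, and MAX_PROMPT_CHARS mirrors the constant added in the diff.

# Sketch of the character-budget prompt assembly used by the new web-search branch.
from typing import List, Optional

MAX_PROMPT_CHARS = 24000  # rough cap on total prompt size, in characters

def build_prompt(question: str, conv_context: str, relevant_chunks: List[str],
                 instructions: Optional[str] = None) -> str:
    prompt_parts = [
        f"Question: {question}",
        f"Conversation Context: {conv_context[-1000:]}",  # keep only the last 1000 characters
        "Relevant Web Search Results:",
    ]
    for chunk in relevant_chunks:
        # Stop adding chunks once the running prompt would exceed the budget.
        if len(' '.join(prompt_parts)) + len(chunk) < MAX_PROMPT_CHARS:
            prompt_parts.append(chunk)
        else:
            break
    if instructions:
        prompt_parts.append(f"User Instructions: {instructions}")
    template = (
        "Answer the question based on the following information:\n"
        "{context}\n"
        "Provide a concise and relevant answer to the question."
    )
    return template.format(context='\n'.join(prompt_parts))

# Example call with toy inputs:
print(build_prompt("Who won the 2022 World Cup?",
                   "earlier turns discussed football",
                   ["Argentina won the 2022 FIFA World Cup, beating France on penalties."]))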