# Spaces: Sleeping
# (page-header residue from the web capture of this file; not part of the program)
import json
import os
import re
from concurrent.futures import ThreadPoolExecutor

import faiss
import gradio as gr
import numpy as np
import requests
import torch
from langdetect import detect
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers import AutoTokenizer
# Configuration
# SECURITY: the Groq API key is a secret and must not live in source control.
# It is read from the GROQ_API_KEY environment variable; an empty string means
# "not configured". Any key that was previously committed to this file should
# be treated as compromised and revoked.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
# Sentence-transformers checkpoint used for both the embedder and the tokenizer.
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# JSON dataset fetched over HTTP when the knowledge base is (re)loaded.
DATASET_URL = "https://huggingface.co/datasets/midrees2806/7K_Dataset/resolve/main/University_of_Education_Lahore_FAQ.json"
# Character-length threshold above which free-text items are split into chunks.
CHUNK_SIZE = 512
# Upper bound sent as `max_tokens` in the chat-completion request.
MAX_TOKENS = 4096
WORKERS = 4
# Number of chunks passed to the embedding model per `encode` call.
EMBEDDING_BATCH_SIZE = 32

# Load the embedding model (and its matching tokenizer) once at import time.
model = SentenceTransformer(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
class UniversityKnowledgeBase:
    """Text chunks from the university dataset plus a FAISS L2 index over them.

    ``load_dataset`` downloads the dataset, turns it into text chunks, embeds
    them with the module-level ``model`` and builds the index.
    ``find_relevant_context`` embeds a query and returns the nearest chunks.

    Attributes:
        index: ``faiss.IndexFlatL2`` holding one vector per entry of ``chunks``.
        chunks: list of chunk strings, positionally aligned with ``index``.
        loaded: True once a load has completed successfully.
        total_chunks: ``len(chunks)`` as of the last successful load.
    """

    # Splits on whitespace that follows '.' or '?', except after patterns that
    # look like abbreviations/initials (e.g. "e.g." or "Dr.").
    _SENTENCE_SPLIT = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s')

    def __init__(self):
        self.index = faiss.IndexFlatL2(model.get_sentence_embedding_dimension())
        self.chunks = []
        self.loaded = False
        self.total_chunks = 0

    @classmethod
    def _split_text(cls, text, chunk_size):
        """Split *text* into sentence-aligned pieces of roughly *chunk_size* characters."""
        if not text:
            # An empty chunk carries no information and only pollutes retrieval.
            return []
        if len(text) <= chunk_size:
            return [text]

        pieces = []
        current = ""
        for sentence in cls._SENTENCE_SPLIT.split(text):
            if len(current) + len(sentence) < chunk_size:
                current += " " + sentence
            else:
                if current:
                    pieces.append(current.strip())
                # A single sentence longer than chunk_size is kept whole.
                current = sentence
        if current:
            pieces.append(current.strip())
        return pieces

    @classmethod
    def _build_chunks(cls, data, chunk_size=None):
        """Convert an iterable of dataset items into a list of chunk strings.

        Items with string ``question`` and ``answer`` fields become one Q&A
        chunk; items with a string ``text`` field are split by ``_split_text``.
        Anything else (non-dict items, non-string fields) is skipped so that a
        single malformed record cannot abort the whole load.
        """
        if chunk_size is None:
            chunk_size = CHUNK_SIZE

        chunks = []
        for item in data:
            if not isinstance(item, dict):
                continue
            if 'question' in item and 'answer' in item:
                question, answer = item['question'], item['answer']
                if isinstance(question, str) and isinstance(answer, str):
                    chunks.append(
                        f"QUESTION: {question.strip()}\nANSWER: {answer.strip()}\n"
                    )
            elif 'text' in item:
                text = item['text']
                if isinstance(text, str):
                    chunks.extend(cls._split_text(text.strip(), chunk_size))
        return chunks

    def load_dataset(self):
        """Download, chunk, embed and index the dataset.

        The new chunks and index are built in local variables and only swapped
        onto the instance after everything succeeded. A failed (re)load
        therefore leaves the previously loaded state intact, and a successful
        reload replaces the index instead of appending to it.

        Returns:
            A human-readable status string (success or error); never raises.
        """
        try:
            print("\n" + "="*50)
            print("Loading University of Education, Lahore dataset...")
            print("="*50 + "\n")

            # Fetch dataset with error handling
            response = requests.get(DATASET_URL, timeout=30)
            if response.status_code != 200:
                raise RuntimeError(f"Failed to fetch dataset. HTTP Status: {response.status_code}")

            # Parse JSON content. Every JSON decode error raised by requests
            # (stdlib json or simplejson) is a ValueError subclass.
            try:
                data = response.json()
            except ValueError as exc:
                raise ValueError("Invalid JSON format in dataset") from exc

            if not isinstance(data, list):
                raise ValueError("Dataset format is invalid. Expected a list of Q&A pairs.")

            # Process all content with progress tracking
            with tqdm(data, desc="Processing dataset") as progress_bar:
                chunks = self._build_chunks(progress_bar)

            total = len(chunks)
            if total == 0:
                raise ValueError("No valid content found in the dataset")

            print(f"\nSuccessfully processed {total} knowledge chunks from dataset")

            # Generate embeddings in batches with progress tracking. tqdm takes
            # its total from len(range(...)), which is exact even when `total`
            # is a multiple of the batch size.
            print("\nGenerating embeddings...")
            embeddings = []
            for start in tqdm(range(0, total, EMBEDDING_BATCH_SIZE),
                              desc="Creating embeddings"):
                batch = chunks[start:start + EMBEDDING_BATCH_SIZE]
                batch_embeddings = model.encode(
                    batch,
                    convert_to_tensor=True,
                    show_progress_bar=False
                ).cpu().numpy().astype('float32')
                embeddings.append(batch_embeddings)

            # Build a fresh index so vector positions always match `chunks`.
            new_index = faiss.IndexFlatL2(model.get_sentence_embedding_dimension())
            new_index.add(np.concatenate(embeddings))

            # Commit the new state only now that every step succeeded.
            self.index = new_index
            self.chunks = chunks
            self.total_chunks = total
            self.loaded = True
            return f"✅ Successfully loaded {self.total_chunks} knowledge chunks from University dataset"
        except Exception as e:
            import traceback
            traceback.print_exc()
            return f"❌ Error loading dataset: {str(e)}"

    def find_relevant_context(self, query, k=5):
        """Return up to *k* chunks nearest to *query*, joined by a separator.

        Args:
            query: the user's question as text.
            k: maximum number of chunks to return.

        Returns:
            The selected chunks joined with ``"\\n\\n---\\n\\n"``, or None when
            nothing is loaded, ``k`` is not positive, no chunk was found, or
            retrieval failed.
        """
        if not self.loaded or not self.chunks:
            return None
        if k <= 0:
            return None

        try:
            # Generate query embedding
            query_embedding = model.encode(
                [query], convert_to_tensor=True
            ).cpu().numpy().astype('float32')

            # Search with a higher k initially, then de-duplicate.
            _, indices = self.index.search(query_embedding, k * 2)

            # dict.fromkeys keeps first-seen order, i.e. nearest first.
            unique_indices = list(dict.fromkeys(int(i) for i in indices[0]))

            # The range check drops out-of-range ids, including the negative
            # placeholder ids the index returns when it has fewer than k*2 vectors.
            selected_chunks = [
                self.chunks[idx]
                for idx in unique_indices[:k]
                if 0 <= idx < len(self.chunks)
            ]
            return "\n\n---\n\n".join(selected_chunks) if selected_chunks else None
        except Exception as e:
            # Best-effort boundary: retrieval problems must not crash the chat handler.
            print(f"Context retrieval error: {str(e)}")
            return None
# Single module-level knowledge base instance; it starts empty and is filled
# by a later call to its load_dataset() method.
knowledge_base = UniversityKnowledgeBase()
def detect_language(text):
    """Classify *text* as "English", "Urdu" or "Roman Urdu".

    Roman Urdu is recognised first through a small list of common function
    words, matched as whole words. Substring matching would fire on ordinary
    English ("how" contains "ho", "apply" contains "ap"). Everything else is
    delegated to ``langdetect.detect``.

    Args:
        text: the user's message.

    Returns:
        One of "English", "Urdu", "Roman Urdu". Falls back to "English" when
        the input is not a string or detection fails.
    """
    try:
        text = text.lower().strip()

        # Roman Urdu detection on whole ASCII words only.
        # NOTE: "main" is also an English word, so "main campus" still matches.
        roman_urdu_keywords = {
            'hai', 'ho', 'hain', 'ka', 'ki', 'ke', 'main', 'tum', 'ap', 'kyun', 'kya',
        }
        words = set(re.findall(r"[a-z]+", text))
        if words & roman_urdu_keywords:
            return "Roman Urdu"

        # Standard detection
        lang = detect(text)
        if lang == "ur":
            return "Urdu"
        if lang == "hi":
            # Hindi/Urdu handling: non-ASCII text is treated as Urdu, pure
            # ASCII text as romanised Urdu.
            return "Urdu" if not text.isascii() else "Roman Urdu"
        return "English"
    except Exception:
        # Best effort: any failure (non-string input, detector error on empty
        # or feature-less text) degrades to the default language.
        return "English"
def get_groq_response(context, user_query, language="English",
                      model_name="mixtral-8x7b-32768"):
    """Ask the Groq chat-completions API to answer *user_query* from *context*.

    Args:
        context: retrieved knowledge-base text placed in the user message.
        user_query: the user's question.
        language: "English", "Urdu" or "Roman Urdu"; selects the system prompt.
            Unknown values fall back to the English prompt.
        model_name: Groq model id to request. The default is unchanged from
            the original hard-coded value.
            NOTE(review): hosted model ids get retired over time — confirm the
            default is still served before relying on it.

    Returns:
        The assistant's reply text, or None when the key is missing, the
        request fails, the API answers with a non-200 status, or the response
        does not contain a usable message.
    """
    if not GROQ_API_KEY:
        # Fail fast with a clear message instead of sending an unauthenticated request.
        print("API Request Failed: GROQ_API_KEY is not configured")
        return None

    headers = {
        "Authorization": f"Bearer {GROQ_API_KEY}",
        "Content-Type": "application/json"
    }

    # Language-specific system prompts. The literals deliberately start at
    # column 0 so that no source indentation leaks into the prompt text.
    system_prompts = {
        "Urdu": """
آپ یونیورسٹی آف ایجوکیشن، لاہور کا سرکاری چیٹ بوٹ ہیں۔ درج ذیل معلومات کی بنیاد پر درست جواب دیں۔
اگر جواب دستیاب نہ ہو تو کہیں:
"معذرت، یہ معلومات دستیاب نہیں۔ براہ کرم یونیورسٹی کی ویب سائٹ دیکھیں۔"
""",
        "Roman Urdu": """
Aap University of Education, Lahore ka chatbot hain. Diye gaye context ke hisab se jawab dein.
Agar jawab nahin mila to kehain:
"Maazrat, yeh maloomat mojood nahin. University ki website check karein."
""",
        "English": """
You are the official chatbot of University of Education, Lahore.
Answer STRICTLY based on the provided context. If the answer isn't available, say:
"I'm sorry, this information isn't available. Please check the university website."
"""
    }

    payload = {
        "model": model_name,
        "messages": [
            {"role": "system", "content": system_prompts.get(language, system_prompts["English"])},
            {"role": "user", "content": f"University Context:\n{context}\n\nQuestion: {user_query}"}
        ],
        "temperature": 0.1,  # Low temperature for factual accuracy
        "max_tokens": MAX_TOKENS,
        "top_p": 0.9
    }

    try:
        response = requests.post(
            "https://api.groq.com/openai/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=30
        )
        if response.status_code != 200:
            print(f"API Error {response.status_code}: {response.text[:200]}")
            return None
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        # RequestException: network/timeout problems; ValueError: body is not JSON.
        print(f"API Request Failed: {str(e)}")
        return None

    # Walk the response defensively: a missing key or an empty `choices` list
    # yields None instead of an IndexError/AttributeError.
    choices = data.get("choices") if isinstance(data, dict) else None
    first = choices[0] if isinstance(choices, list) and choices else None
    message = first.get("message") if isinstance(first, dict) else None
    content = message.get("content") if isinstance(message, dict) else None
    if not content:
        print("API Request Failed: response contained no message content")
        return None
    return content
def chatbot_response(user_input, chat_history):
    """Answer one user message and return the extended chat history.

    Flow: validate input -> detect language -> retrieve context from the
    knowledge base -> ask the LLM -> fall back to a localized apology when
    either retrieval or the API call yields nothing.

    Args:
        user_input: the text typed by the user. ``None`` is treated as empty.
        chat_history: list of ``(user, bot)`` pairs so far. ``None`` is treated
            as an empty history.

    Returns:
        A new list: the previous history plus one ``(user_input, reply)`` pair.
        The list passed in is not mutated.
    """
    # The UI may hand over None for an empty conversation or textbox.
    history = list(chat_history) if chat_history else []
    if not isinstance(user_input, str):
        user_input = ""

    if not user_input.strip():
        return history + [(user_input, "Please enter a valid question.")]

    # Detect language
    language = detect_language(user_input)

    # Retrieve relevant context (more chunks for better accuracy)
    context = knowledge_base.find_relevant_context(user_input, k=5)

    # Handle no context found
    if not context:
        error_messages = {
            "Urdu": "معذرت، یہ معلومات دستیاب نہیں۔ براہ کرم یونیورسٹی کی ویب سائٹ دیکھیں۔",
            "Roman Urdu": "Maazrat, yeh maloomat mojood nahin. University ki website check karein.",
            "English": "I'm sorry, this information isn't available. Please check the university website."
        }
        return history + [(user_input, error_messages.get(language, error_messages["English"]))]

    # Generate response
    response = get_groq_response(context, user_input, language)

    # Fallback if API fails
    if not response:
        fallback_messages = {
            "Urdu": "معذرت، نظام میں عارضی خرابی ہے۔ بعد میں کوشش کریں۔",
            "Roman Urdu": "Maazrat, system mein masla hai. Baad mein koshish karein.",
            "English": "Sorry, there's a temporary system issue. Please try again later."
        }
        response = fallback_messages.get(language, fallback_messages["English"])

    return history + [(user_input, response)]
# Gradio Interface
# Layout: page header, then a row with a status/reload column (1/3 width) next
# to the chat column (2/3 width); event wiring follows the layout.
with gr.Blocks(title="University of Education ChatBot", theme=gr.themes.Soft()) as app:

    def _reload_knowledge_base():
        """Re-run the dataset load and return its status string."""
        return knowledge_base.load_dataset()

    def _blank_question():
        """Value used to empty the question textbox after a submit."""
        return ""

    def _empty_conversation():
        """Value used to reset the chat history."""
        return []

    gr.Markdown("""
<div style='text-align: center;'>
<h1>University of Education, Lahore</h1>
<h2>Official Information ChatBot</h2>
<p>Ask any question about the university in English, Urdu, or Roman Urdu</p>
</div>
""")

    # Initialize dataset (runs once while the UI is being built)
    load_status = knowledge_base.load_dataset()

    with gr.Row():
        # Left column: knowledge-base status and manual reload
        with gr.Column(scale=1):
            gr.Markdown("### Knowledge Base Status")
            status = gr.Textbox(
                label="Dataset Status",
                value=load_status,
                interactive=False,
                lines=2
            )
            reload_btn = gr.Button("🔄 Reload Knowledge Base", variant="secondary")
            gr.Markdown("""
**Note:** This chatbot answers strictly based on the official University of Education, Lahore dataset.
""")

        # Right column: conversation, input box and action buttons
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                height=500,
                label="Conversation History",
                bubble_full_width=False
            )
            question = gr.Textbox(
                label="Your Question",
                placeholder="Type your question about the university...",
                lines=2,
                max_lines=5
            )
            with gr.Row():
                ask_btn = gr.Button("Ask Question", variant="primary")
                clear_btn = gr.Button("Clear Conversation", variant="secondary")

    # Event handlers
    reload_btn.click(
        fn=_reload_knowledge_base,
        inputs=None,
        outputs=status,
        queue=False
    )

    # Answer the question, then clear the input box.
    ask_event = ask_btn.click(
        fn=chatbot_response,
        inputs=[question, chatbot],
        outputs=chatbot,
        queue=True
    )
    ask_event.then(_blank_question, None, question)

    clear_btn.click(
        fn=_empty_conversation,
        inputs=None,
        outputs=chatbot,
        queue=False
    )
# Launch the application only when executed as a script.
# Binds to all interfaces on port 7860.
if __name__ == "__main__":
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )