# UOE_ChatBot / app.py
# Committed by gmustafa413 — commit e25e79d ("Update app.py", verified).
# app.py
import os

import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import gradio as gr
import torch
from tqdm import tqdm
from groq import Groq
# Download the corpus from the Hugging Face Hub and echo one row so the
# record layout is visible in the startup logs. Rows are expected to carry a
# 'text' field (read further down when building embeddings).
dataset = load_dataset("midrees2806/7K_Dataset")
print("Dataset sample:", dataset['train'][0])

# Sentence-embedding model, placed on the GPU when CUDA is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SentenceTransformer('all-MiniLM-L6-v2').to(device)
# Batch processing function
def generate_embeddings_batch(texts):
return model.encode(texts, batch_size=1024, convert_to_tensor=True, device=device).cpu().numpy()
# Embed every training row in chunks of `batch_size` (so tqdm can report
# progress), then stack the per-chunk arrays into a single matrix whose row i
# is the embedding of dataset['train'][i].
train_dataset = dataset['train']
texts = [row['text'] for row in train_dataset]
batch_size = 1024
chunked_embeddings = np.vstack([
    generate_embeddings_batch(texts[start:start + batch_size])
    for start in tqdm(range(0, len(texts), batch_size), desc="Generating embeddings")
])
# Exact (brute-force) FAISS index using L2 distance. Vectors are added in
# dataset order, so search label i maps back to dataset['train'][i].
_, dimension = chunked_embeddings.shape
index = faiss.IndexFlatL2(dimension)
index.add(chunked_embeddings)
# Initialize Groq client.
# The API key is read from the environment instead of being hard-coded: a key
# committed to source control is exposed to everyone who can read the repo.
# The key that was previously embedded here must be treated as leaked and revoked.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    # Fail at startup with an actionable message rather than on the first chat request.
    raise RuntimeError(
        "GROQ_API_KEY environment variable is not set; "
        "configure it (e.g. as a deployment secret) before starting the app."
    )
client = Groq(api_key=GROQ_API_KEY)
def get_groq_response(query):
    """Answer ``query`` with the Groq LLM, grounded in retrieved dataset rows.

    The texts of the FAISS matches for ``query`` are joined (one per line) into
    a context block, and the prompt instructs the model to answer only from
    that context. Any exception raised during retrieval, the API call, or
    reading the reply is printed and replaced with a fixed apology string, so
    the caller always receives text.
    """
    try:
        retrieved = search_in_faiss(query)
        context = "\n".join(text for text, _distance in retrieved)
        # Assembled line by line; the leading "" yields the prompt's initial newline.
        prompt_lines = [
            "",
            "You are an expert assistant for University of Education Lahore and its sub-campuses ONLY.",
            "You must ONLY use the following context to answer questions. If the answer isn't in the context,",
            "say \"I don't have information about that in the University of Education Lahore dataset.\"",
            "Context:",
            context,
            f"Question: {query}",
            "Answer:",
        ]
        prompt = "\n".join(prompt_lines)
        completion = client.chat.completions.create(
            model="llama3-70b-8192",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1024,
            temperature=0.3,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
    except Exception as err:
        print(f"Error in Groq response: {err}")
        return "I encountered an error while processing your request."
def search_in_faiss(query, k=3):
    """Return the dataset rows nearest to ``query`` in embedding space.

    Args:
        query: Free-text question, embedded with the module-level model.
        k: Maximum number of neighbours to retrieve (default 3, the value
            previously hard-coded).

    Returns:
        A list of up to ``k`` ``(text, distance)`` tuples in the order FAISS
        returns them (nearest first). ``distance`` is the (squared) L2
        distance reported by ``IndexFlatL2``: smaller means more similar.
    """
    query_embedding = model.encode([query], convert_to_tensor=True, device=device).cpu().numpy()
    distances, indices = index.search(query_embedding, k=k)
    train_split = dataset['train']
    # FAISS pads results with label -1 when the index holds fewer than k
    # vectors; without this filter, -1 would silently index the last dataset row.
    return [
        (train_split[int(idx)]['text'], float(dist))
        for idx, dist in zip(indices[0], distances[0])
        if idx >= 0
    ]
def respond(message, chat_history, show_sources=False):
    """Gradio callback: answer ``message`` and append the exchange to the chat.

    Args:
        message: The user's question from the textbox.
        chat_history: Existing list of ``(user, bot)`` pairs shown in the
            Chatbot. ``None`` is treated as an empty history.
        show_sources: When True, prefix the answer with the retrieved dataset
            rows and their FAISS distances. Defaults to False, which keeps the
            previously displayed output (the model answer only). Before, the
            sources text was built and then discarded.

    Returns:
        ``("", new_history)``: the empty string clears the textbox. On failure
        the bot turn contains the error message instead of an answer.
    """
    chat_history = chat_history or []
    try:
        parts = []
        if show_sources:
            # Extra retrieval only when the sources are displayed;
            # get_groq_response runs its own search to build the prompt.
            parts.append("**Relevant Information from Dataset:**\n\n")
            for text, distance in search_in_faiss(message):
                # IndexFlatL2 reports distances (lower = closer), not similarities.
                parts.append(f"- {text} (Distance: {distance:.4f})\n\n")
        parts.append("\n**Model Response:**\n\n" + get_groq_response(message))
        bot_response = "".join(parts)
    except Exception as e:
        print(f"Error: {str(e)}")
        return "", chat_history + [(message, f"Error processing request: {str(e)}")]
    return "", chat_history + [(message, bot_response)]
# Gradio UI: header text, chat window, input controls, and event wiring.
# NOTE(review): the original indentation was lost; this layout assumes the
# textbox and Submit button share one row and the Clear button sits below it.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# <center>UoE Chatbot</center>")
    gr.Markdown("<center>University of Education Lahore Information Bot</center>")
    gr.Markdown("<center>This bot only answers questions about University of Education Lahore and its sub-campuses</center>")
    chatbot = gr.Chatbot(height=500, bubble_full_width=False)
    with gr.Row():
        msg = gr.Textbox(label="Type your message here...", placeholder="Ask about University of Education Lahore...", scale=7)
        submit_btn = gr.Button("Submit", variant="primary")
    clear_btn = gr.Button("Clear Chat")
    # Pressing Enter in the textbox and clicking Submit both call `respond`,
    # which returns ("", updated_history); the "" output empties the textbox.
    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    submit_btn.click(respond, [msg, chatbot], [msg, chatbot])
    # Clearing sets the Chatbot value to None; queue=False bypasses the event
    # queue for this lightweight update.
    clear_btn.click(lambda: None, None, chatbot, queue=False)
# Start the web server (runs at import time; there is no __main__ guard).
demo.launch()