# UE_ChatBot / app.py
# Hugging Face Space file by gmustafa413 - commit 911a038 ("Update app.py"), 6.08 kB.
import os
import gradio as gr
import numpy as np
import google.generativeai as genai
import faiss
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
from dotenv import load_dotenv
import threading
# Load environment variables
# Runs before GeminiRAGSystem reads GEMINI_API_KEY via os.getenv().
# NOTE(review): presumably this picks the key up from a local .env file - confirm.
load_dotenv()
# Configuration
# Name passed to SentenceTransformer() to build the embedding model.
MODEL_NAME = "all-MiniLM-L6-v2"
# Name passed to genai.GenerativeModel() when answering a question.
GENAI_MODEL = "gemini-pro"
# Dataset identifier passed to datasets.load_dataset().
DATASET_NAME = "midrees2806/7K_Dataset"
# NOTE(review): CHUNK_SIZE is not referenced anywhere in the visible code.
CHUNK_SIZE = 500
# Number of nearest neighbours requested from the FAISS index per query.
TOP_K = 3
class GeminiRAGSystem:
    """Retrieval-augmented question answering backed by FAISS and Gemini.

    On construction the embedding model is created synchronously, Gemini is
    configured when ``GEMINI_API_KEY`` is set, and the dataset is downloaded,
    embedded and indexed on a daemon thread.

    Attributes:
        index: FAISS index over the chunk embeddings, or ``None`` until the
            background load has finished.
        chunks: Text chunks whose positions match the rows of ``index``.
        dataset_loaded: ``True`` once ``index`` and ``chunks`` are usable.
        loading_error: Message of the exception that aborted the background
            load, or ``None``.
        gemini_api_key: Value of the ``GEMINI_API_KEY`` environment variable.
    """

    # Upper bound on the number of dataset rows that are embedded and indexed.
    MAX_CHUNKS = 1000

    def __init__(self):
        """Create the embedding model and start loading the dataset.

        Raises:
            RuntimeError: If the sentence-transformer model cannot be created.
        """
        self.index = None
        self.chunks = []
        self.dataset_loaded = False
        self.loading_error = None
        self.gemini_api_key = os.getenv("GEMINI_API_KEY")

        # Initialize embedding model
        try:
            self.embedding_model = SentenceTransformer(MODEL_NAME)
        except Exception as e:
            raise RuntimeError(
                f"Failed to initialize embedding model: {str(e)}"
            ) from e

        # Configure Gemini
        if self.gemini_api_key:
            genai.configure(api_key=self.gemini_api_key)

        # Start dataset loading in background
        self.load_dataset_in_background()

    def load_dataset_in_background(self):
        """Load, embed and index the dataset on a daemon thread.

        Any failure is recorded in ``self.loading_error`` and printed; it is
        never raised to the caller.
        """

        def load_task():
            try:
                # Load dataset directly
                dataset = load_dataset(
                    DATASET_NAME,
                    split='train',
                    download_mode="force_redownload"  # Fixes extraction error
                )

                # Pick the text column
                if 'text' in dataset.features:
                    column = 'text'
                elif 'context' in dataset.features:
                    column = 'context'
                else:
                    raise ValueError("Dataset must have 'text' or 'context' field")

                # Keep only usable strings: the encoder and the later
                # "\n\n".join() both require non-empty text.
                raw_rows = list(dataset[column][:self.MAX_CHUNKS])
                chunks = [
                    row for row in raw_rows
                    if isinstance(row, str) and row.strip()
                ]
                if not chunks:
                    raise ValueError(
                        f"Dataset column '{column}' contains no usable text"
                    )

                # Create embeddings
                embeddings = self.embedding_model.encode(
                    chunks,
                    show_progress_bar=False,
                    convert_to_numpy=True
                )
                index = faiss.IndexFlatL2(embeddings.shape[1])
                index.add(embeddings.astype('float32'))

                # Publish chunks and index together, and only flip the ready
                # flag last, so readers never see a half-built state.
                self.chunks = chunks
                self.index = index
                self.dataset_loaded = True
            except Exception as e:
                # Thread boundary: record the failure for the UI instead of
                # letting the exception die silently with the thread.
                self.loading_error = str(e)
                print(f"Dataset loading failed: {str(e)}")

        # Start the loading thread
        threading.Thread(target=load_task, daemon=True).start()

    def get_relevant_context(self, query: str, top_k: "int | None" = None) -> str:
        """Retrieve the chunks most similar to ``query``.

        Args:
            query: The user question to embed and search for.
            top_k: Number of neighbours to request; defaults to the
                module-level ``TOP_K`` when ``None``.

        Returns:
            The matching chunks joined by blank lines, or ``""`` when the
            index is not ready, nothing matches, or the search fails.
        """
        if self.index is None or not self.chunks:
            return ""
        k = TOP_K if top_k is None else top_k
        try:
            query_embed = self.embedding_model.encode(
                [query],
                convert_to_numpy=True
            ).astype('float32')
            _, indices = self.index.search(query_embed, k=k)
            total = len(self.chunks)
            # FAISS pads missing neighbours with -1; without the lower bound
            # a -1 would index the *last* chunk and pass as a real hit.
            return "\n\n".join(
                self.chunks[i] for i in indices[0] if 0 <= i < total
            )
        except Exception as e:
            print(f"Search error: {str(e)}")
            return ""

    def generate_response(self, query: str) -> str:
        """Answer ``query`` from retrieved context using Gemini.

        Args:
            query: The user question.

        Returns:
            The model's answer, or a human-readable status/error message when
            the dataset is not ready, the API key is missing, the query is
            blank, no context is found, or the API call fails.
        """
        if not self.dataset_loaded:
            if self.loading_error:
                return f"⚠️ Dataset loading failed: {self.loading_error}"
            return "⚠️ Dataset is still loading, please wait..."
        if not self.gemini_api_key:
            return "🔑 Please set your Gemini API key in environment variables"
        # Skip the embedding and API round-trip for blank input.
        if not query or not query.strip():
            return "⚠️ Please enter a question"

        context = self.get_relevant_context(query)
        if not context:
            return "No relevant context found"

        prompt = (
            "Answer based on this context:\n"
            f"{context}\n"
            f"Question: {query}\n"
            "Answer concisely:"
        )
        try:
            model = genai.GenerativeModel(GENAI_MODEL)
            response = model.generate_content(prompt)
            return response.text
        except Exception as e:
            return f"⚠️ API Error: {str(e)}"
# Initialize system
# Module-level instance used by the status and chat callbacks. Constructing it
# creates the embedding model and starts the background dataset load.
try:
    rag_system = GeminiRAGSystem()
except Exception as e:
    # Chain explicitly so the original failure stays attached as __cause__.
    raise RuntimeError(f"System initialization failed: {str(e)}") from e
# Create interface
with gr.Blocks(title="UE Chatbot") as app:
    gr.Markdown("# UE 24 Hour Service")

    with gr.Row():
        chatbot = gr.Chatbot(
            height=500,
            label="Chat History",
            # avatar_images is a flat (user, bot) pair; the bot entry must be
            # a single path/URL, not a nested tuple.
            avatar_images=(
                None,
                "https://huggingface.co/spaces/groq/Groq-LLM/resolve/main/groq_logo.png",
            ),
            bubble_full_width=False,
        )

    with gr.Row():
        query = gr.Textbox(
            label="Your question",
            placeholder="Ask your question...",
            scale=4,
        )
        submit_btn = gr.Button("Submit", variant="primary", scale=1)

    with gr.Row():
        clear_btn = gr.Button("Clear Chat", variant="secondary")

    # Status indicator
    status = gr.Textbox(
        label="System Status",
        value="Initializing...",
        interactive=False,
    )

    # Update status periodically
    def update_status():
        """Return the dataset-loading state as text for the status box."""
        if rag_system.loading_error:
            return f"Error: {rag_system.loading_error}"
        return "Ready" if rag_system.dataset_loaded else "Loading dataset..."

    app.load(update_status, None, status, every=1)

    # Event handlers
    def respond(message, chat_history):
        """Append the (message, answer) pair to the chat history.

        Args:
            message: Text taken from the question box.
            chat_history: Current list of (user, bot) tuples; may be None.

        Returns:
            A tuple of ("", updated history); the empty string clears the
            question box.
        """
        history = [] if chat_history is None else chat_history
        # Ignore blank submissions so they do not add empty chat turns.
        if not message or not message.strip():
            return "", history
        try:
            reply = rag_system.generate_response(message)
        except Exception as e:
            reply = f"Error: {str(e)}"
        history.append((message, reply))
        return "", history

    def clear_chat():
        """Return an empty history to reset the chatbot component."""
        return []

    submit_btn.click(respond, [query, chatbot], [query, chatbot])
    query.submit(respond, [query, chatbot], [query, chatbot])
    clear_btn.click(clear_chat, outputs=chatbot)
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    # NOTE(review): share=True presumably requests a public share link from
    # Gradio - confirm this is wanted in the deployment environment.
    app.launch(share=True)