Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
import numpy as np | |
import google.generativeai as genai | |
import faiss | |
from sentence_transformers import SentenceTransformer | |
from datasets import load_dataset | |
from dotenv import load_dotenv | |
import threading | |
# Load environment variables
# (done before anything below reads os.environ, e.g. GEMINI_API_KEY)
load_dotenv()

# Configuration
MODEL_NAME = "all-MiniLM-L6-v2"  # name handed to SentenceTransformer(...) for embeddings
GENAI_MODEL = "gemini-pro"  # model id handed to genai.GenerativeModel(...)
DATASET_NAME = "midrees2806/7K_Dataset"  # dataset id handed to load_dataset(...)
CHUNK_SIZE = 500  # NOTE(review): not referenced anywhere in the visible code
TOP_K = 3  # number of nearest neighbours requested from the FAISS index per query
class GeminiRAGSystem:
    """Retrieval-augmented question answering backed by FAISS and Gemini.

    On construction the embedding model is loaded synchronously and the
    dataset is fetched, embedded and indexed on a background daemon thread.
    Callers can poll ``dataset_loaded`` / ``loading_error`` to learn when the
    system is ready, and call ``generate_response`` to answer a question.

    Attributes:
        index: FAISS index over the chunk embeddings, or ``None`` until built.
        chunks: Text passages, positionally aligned with the index rows.
        dataset_loaded: ``True`` once both ``chunks`` and ``index`` are ready.
        loading_error: Message of the exception that aborted loading, if any.
        gemini_api_key: Value of the ``GEMINI_API_KEY`` environment variable.
    """

    # Upper bound on the number of dataset rows that get embedded and indexed.
    _MAX_CHUNKS = 1000

    def __init__(self):
        """Load the embedding model, configure Gemini and start indexing.

        Raises:
            RuntimeError: If the sentence-transformer model cannot be loaded.
        """
        self.index = None
        self.chunks = []
        self.dataset_loaded = False
        self.loading_error = None
        self.gemini_api_key = os.getenv("GEMINI_API_KEY")

        # Initialize embedding model
        try:
            self.embedding_model = SentenceTransformer(MODEL_NAME)
        except Exception as e:
            raise RuntimeError(
                f"Failed to initialize embedding model: {str(e)}"
            ) from e

        # Configure Gemini (skipped when no key is set; generate_response
        # reports the missing key to the user instead).
        if self.gemini_api_key:
            genai.configure(api_key=self.gemini_api_key)

        # Start dataset loading in background
        self.load_dataset_in_background()

    def load_dataset_in_background(self):
        """Load dataset in a background thread.

        The worker downloads the dataset, embeds up to ``_MAX_CHUNKS`` rows
        and builds an L2 FAISS index. Failures are recorded in
        ``loading_error`` and printed rather than raised.
        """

        def load_task():
            try:
                # Load dataset directly
                dataset = load_dataset(
                    DATASET_NAME,
                    split='train',
                    download_mode="force_redownload"  # Fixes extraction error
                )

                # Process dataset: pick whichever supported text column exists.
                if 'text' in dataset.features:
                    column = 'text'
                elif 'context' in dataset.features:
                    column = 'context'
                else:
                    raise ValueError("Dataset must have 'text' or 'context' field")

                # Slice rows first so only the retained rows are materialised,
                # then copy into a plain list.
                chunks = list(dataset[:self._MAX_CHUNKS][column])
                if not chunks:
                    raise ValueError("Dataset contains no rows to index")

                # Create embeddings
                embeddings = self.embedding_model.encode(
                    chunks,
                    show_progress_bar=False,
                    convert_to_numpy=True
                )
                index = faiss.IndexFlatL2(embeddings.shape[1])
                index.add(embeddings.astype('float32'))

                # Publish only once everything is built, so a failed load
                # never leaves chunks populated without a matching index.
                self.chunks = chunks
                self.index = index
                self.dataset_loaded = True
            except Exception as e:
                self.loading_error = str(e)
                print(f"Dataset loading failed: {str(e)}")

        # Start the loading thread
        threading.Thread(target=load_task, daemon=True).start()

    def get_relevant_context(self, query: str) -> str:
        """Retrieve most relevant chunks.

        Args:
            query: The user question to embed and search for.

        Returns:
            The retrieved chunks joined by blank lines, or ``""`` when the
            index is not ready, nothing valid was found, or the search failed.
        """
        if self.index is None:
            return ""
        try:
            query_embed = self.embedding_model.encode(
                [query],
                convert_to_numpy=True
            ).astype('float32')
            _, indices = self.index.search(query_embed, k=TOP_K)

            # FAISS pads missing neighbours with -1. Without the lower bound
            # a -1 would silently select the *last* chunk via negative
            # indexing, so both bounds are checked.
            chunk_count = len(self.chunks)
            hits = [
                self.chunks[int(i)]
                for i in indices[0]
                if 0 <= int(i) < chunk_count
            ]
            return "\n\n".join(hits)
        except Exception as e:
            print(f"Search error: {str(e)}")
            return ""

    def generate_response(self, query: str) -> str:
        """Generate response with robust error handling.

        Args:
            query: The user question.

        Returns:
            The model's answer, or a human-readable status/error message when
            the dataset is not ready, the API key is missing, the query is
            blank, no context was retrieved, or the Gemini call failed.
        """
        if not self.dataset_loaded:
            if self.loading_error:
                return f"⚠️ Dataset loading failed: {self.loading_error}"
            return "⚠️ Dataset is still loading, please wait..."

        if not self.gemini_api_key:
            return "🔑 Please set your Gemini API key in environment variables"

        # Avoid embedding and sending a blank question to the API.
        if not query or not query.strip():
            return "Please enter a question."

        context = self.get_relevant_context(query)
        if not context:
            return "No relevant context found"

        prompt = f"""Answer based on this context:
{context}
Question: {query}
Answer concisely:"""

        try:
            model = genai.GenerativeModel(GENAI_MODEL)
            response = model.generate_content(prompt)
            return response.text
        except Exception as e:
            return f"⚠️ API Error: {str(e)}"
# Initialize system
# Built once at import time; the UI callbacks below all share this instance.
try:
    rag_system = GeminiRAGSystem()
except Exception as e:
    # Chain explicitly so the traceback marks the original failure as the cause.
    raise RuntimeError(f"System initialization failed: {str(e)}") from e
# Create interface
with gr.Blocks(title="UE Chatbot") as app:
    gr.Markdown("# UE 24 Hour Service")

    with gr.Row():
        chatbot = gr.Chatbot(
            height=500,
            label="Chat History",
            # avatar_images is a flat (user, bot) pair; the bot avatar must be
            # the URL itself, not a nested tuple.
            avatar_images=(
                None,
                "https://huggingface.co/spaces/groq/Groq-LLM/resolve/main/groq_logo.png",
            ),
            bubble_full_width=False,
        )

    with gr.Row():
        query = gr.Textbox(
            label="Your question",
            placeholder="Ask your question...",
            scale=4,
        )
        submit_btn = gr.Button("Submit", variant="primary", scale=1)

    with gr.Row():
        clear_btn = gr.Button("Clear Chat", variant="secondary")

    # Status indicator
    status = gr.Textbox(
        label="System Status",
        value="Initializing...",
        interactive=False,
    )

    # Update status periodically
    def update_status():
        """Return a one-line description of the dataset loading state."""
        if rag_system.loading_error:
            return f"Error: {rag_system.loading_error}"
        return "Ready" if rag_system.dataset_loaded else "Loading dataset..."

    app.load(update_status, None, status, every=1)

    # Event handlers
    def respond(message, chat_history):
        """Answer ``message`` and append the exchange to the chat history.

        Returns:
            A pair ``("", history)``: the empty string clears the textbox and
            ``history`` is the updated list of ``(user, bot)`` tuples.
        """
        # Tolerate a missing history and avoid mutating the caller's list.
        history = list(chat_history or [])

        # Ignore blank submissions instead of recording an empty exchange.
        if not message or not message.strip():
            return "", history

        try:
            reply = rag_system.generate_response(message)
        except Exception as e:
            # Last-resort boundary: surface the failure in the chat window.
            reply = f"Error: {str(e)}"

        history.append((message, reply))
        return "", history

    def clear_chat():
        """Return an empty history to reset the chatbot component."""
        return []

    submit_btn.click(respond, [query, chatbot], [query, chatbot])
    query.submit(respond, [query, chatbot], [query, chatbot])
    clear_btn.click(clear_chat, outputs=chatbot)
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to the local URL.
    app.launch(share=True)