"""Gradio chat front-end for a small retrieval-augmented Gemini chatbot.

At import time the module loads a Hugging Face dataset, embeds up to
``MAX_CHUNKS`` of its entries with a SentenceTransformer model, and stores
the vectors in an in-memory FAISS L2 index.  Each user question is embedded,
the ``TOP_K`` nearest entries are used as context, and Gemini is asked to
answer from that context.
"""

import asyncio
import os
import time

import faiss
import google.generativeai as genai
import gradio as gr
import numpy as np
from datasets import load_dataset
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer

# Load environment variables
load_dotenv()

# Configuration
MODEL_NAME = "all-MiniLM-L6-v2"
GENAI_MODEL = "gemini-pro"
DATASET_NAME = "midrees2806/7K_Dataset"
CHUNK_SIZE = 500
TOP_K = 3
# Upper bound on how many dataset entries are embedded and indexed.
MAX_CHUNKS = 1000


class GeminiRAGSystem:
    """Retrieval-augmented question answering backed by FAISS and Gemini.

    Attributes are populated by ``__init__`` / ``load_dataset``:

    * ``index`` - FAISS index over the chunk embeddings, or ``None`` if
      loading has not succeeded.
    * ``chunks`` - the text entries, in the same order as the index rows.
    * ``dataset_loaded`` - ``True`` once the index has been built.
    * ``loading_error`` - string form of the exception that aborted loading,
      or ``None``.
    * ``gemini_api_key`` - API key read from the environment, or ``None``.
    """

    def __init__(self):
        """Create the embedding model, configure Gemini and load the dataset.

        Raises:
            RuntimeError: if the SentenceTransformer model cannot be created.
        """
        self.index = None
        self.chunks = []
        self.dataset_loaded = False
        self.loading_error = None
        # The key must come from the environment (or .env via load_dotenv);
        # a secret must never be written into the source file.
        # GOOGLE_API_KEY is accepted as a fallback name.
        self.gemini_api_key = (
            os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
        )

        # Initialize embedding model
        try:
            self.embedding_model = SentenceTransformer(MODEL_NAME)
        except Exception as e:
            raise RuntimeError(
                f"Failed to initialize embedding model: {str(e)}"
            ) from e

        # Configure Gemini
        if self.gemini_api_key:
            genai.configure(api_key=self.gemini_api_key)

        # Start dataset loading (synchronous: returns once done or failed)
        self.load_dataset()

    def load_dataset(self):
        """Load dataset synchronously and build the FAISS index.

        Never raises: on failure the message is stored in
        ``self.loading_error`` and printed, and ``dataset_loaded`` stays
        ``False``.
        """
        try:
            # NOTE(review): force_redownload bypasses the local cache on
            # every start-up; kept as-is, confirm that this is intended.
            dataset = load_dataset(
                DATASET_NAME,
                split='train',
                download_mode="force_redownload"
            )

            # Process dataset
            if 'text' in dataset.features:
                raw_entries = dataset['text'][:MAX_CHUNKS]
            elif 'context' in dataset.features:
                raw_entries = dataset['context'][:MAX_CHUNKS]
            else:
                raise ValueError("Dataset must have 'text' or 'context' field")

            # Drop missing / blank rows so the encoder only sees real text
            # and ``chunks`` stays aligned with the index rows.
            self.chunks = [
                entry for entry in raw_entries
                if isinstance(entry, str) and entry.strip()
            ]
            if not self.chunks:
                raise ValueError("Dataset contains no usable text entries")

            # Create embeddings
            embeddings = self.embedding_model.encode(
                self.chunks,
                show_progress_bar=False,
                convert_to_numpy=True
            )
            self.index = faiss.IndexFlatL2(embeddings.shape[1])
            self.index.add(embeddings.astype('float32'))

            self.dataset_loaded = True
        except Exception as e:
            self.loading_error = str(e)
            print(f"Dataset loading failed: {str(e)}")

    def get_relevant_context(self, query: str) -> str:
        """Retrieve most relevant chunks for ``query``.

        Returns:
            The nearest chunks joined by blank lines, or ``""`` when no index
            is available, nothing matches, or the search fails.
        """
        if self.index is None:
            return ""

        try:
            # Never ask for more neighbours than the index holds.
            k = min(TOP_K, self.index.ntotal)
            if k <= 0:
                return ""

            query_embed = self.embedding_model.encode(
                [query],
                convert_to_numpy=True
            ).astype('float32')

            _, indices = self.index.search(query_embed, k=k)
            # FAISS marks "no neighbour" with -1; without the lower bound a
            # -1 would silently select the last chunk via negative indexing.
            return "\n\n".join(
                self.chunks[i] for i in indices[0]
                if 0 <= i < len(self.chunks)
            )
        except Exception as e:
            print(f"Search error: {str(e)}")
            return ""

    def generate_response(self, query: str) -> str:
        """Generate response with robust error handling.

        Returns a user-facing string in every case: the model's answer, or a
        message describing why no answer could be produced.
        """
        if not self.dataset_loaded:
            if self.loading_error:
                return f"⚠️ Dataset loading failed: {self.loading_error}"
            return "⚠️ Dataset is still loading, please wait..."

        if not self.gemini_api_key:
            return "🔑 Please set your Gemini API key in environment variables"

        context = self.get_relevant_context(query)
        if not context:
            return "No relevant context found"

        prompt = (
            "Answer based on this context:\n"
            f"{context}\n\n"
            f"Question: {query}\n"
            "Answer concisely:"
        )

        try:
            model = genai.GenerativeModel(GENAI_MODEL)
            response = model.generate_content(prompt)
            return response.text
        except Exception as e:
            return f"⚠️ API Error: {str(e)}"


def _initial_status(system: GeminiRAGSystem) -> str:
    """Return the text for the status box from the system's load state."""
    if system.dataset_loaded:
        return "Ready"
    if system.loading_error:
        # Loading is synchronous, so an error here is final, not "loading".
        return f"Dataset loading failed: {system.loading_error}"
    return "Loading dataset..."


# Initialize system
try:
    rag_system = GeminiRAGSystem()
except Exception as e:
    raise RuntimeError(f"System initialization failed: {str(e)}") from e

# Create interface
with gr.Blocks(title="UE Chatbot") as app:
    gr.Markdown("# UE 24 Hour Service")

    with gr.Row():
        chatbot = gr.Chatbot(
            height=500,
            # (user avatar, bot avatar) - a flat pair, not a nested tuple.
            avatar_images=(
                None,
                "https://huggingface.co/spaces/groq/Groq-LLM/resolve/main/groq_logo.png",
            ),
            bubble_full_width=False
        )

    with gr.Row():
        query = gr.Textbox(
            label="Your question",
            placeholder="Ask your question...",
            scale=4
        )
        submit_btn = gr.Button("Submit", variant="primary", scale=1)

    with gr.Row():
        clear_btn = gr.Button("Clear Chat", variant="secondary")

    # Status indicator
    status = gr.Textbox(
        label="System Status",
        value=_initial_status(rag_system),
        interactive=False
    )

    # Event handlers
    def respond(message, chat_history):
        """Answer ``message`` and append the exchange to ``chat_history``.

        Returns ``("", chat_history)`` so the input box is cleared and the
        chat window refreshed.
        """
        chat_history = chat_history or []
        # Ignore blank submissions instead of spending an API call on them.
        if not message or not message.strip():
            return "", chat_history
        try:
            response = rag_system.generate_response(message)
            chat_history.append((message, response))
            return "", chat_history
        except Exception as e:
            chat_history.append((message, f"Error: {str(e)}"))
            return "", chat_history

    def clear_chat():
        """Return an empty history to reset the chat window."""
        return []

    submit_btn.click(respond, [query, chatbot], [query, chatbot])
    query.submit(respond, [query, chatbot], [query, chatbot])
    clear_btn.click(clear_chat, outputs=chatbot)

if __name__ == "__main__":
    app.launch(share=True)