# NOTE: the three lines that were here ("Spaces:" / "Sleeping" / "Sleeping")
# were Hugging Face Spaces page residue, not source code.
import os
import warnings

import faiss
import google.generativeai as genai
import gradio as gr
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

# Suppress warnings
warnings.filterwarnings("ignore")
# Configuration
MODEL_NAME = "all-MiniLM-L6-v2"          # sentence-transformers embedding model
GENAI_MODEL = "models/gemini-pro"        # Gemini model path
DATASET_NAME = "midrees2806/7K_Dataset"  # HF dataset with 'text' or 'context' field
CHUNK_SIZE = 500                         # target chunk size (currently unused here)
TOP_K = 3                                # number of retrieved chunks per query

# SECURITY: never hard-code API keys in source (the original had a literal key
# committed here). Read it from the environment instead; on HF Spaces, set it
# as a repository secret named GEMINI_API_KEY.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
else:
    # Don't configure with an empty key; make the misconfiguration visible.
    print("⚠️ GEMINI_API_KEY is not set; Gemini API calls will fail.")
class GeminiRAGSystem:
    """Retrieval-augmented QA: embed dataset chunks into a FAISS L2 index,
    retrieve the TOP_K nearest chunks for a query, and ask Gemini to answer
    from that context only.
    """

    def __init__(self):
        # index/chunks are populated by load_dataset(); index stays None
        # until embeddings have been built successfully.
        self.index = None
        self.chunks = []
        self.dataset_loaded = False
        self.loading_error = None
        # Initialize embedding model — fail fast, nothing works without it.
        try:
            self.embedding_model = SentenceTransformer(MODEL_NAME)
        except Exception as e:
            raise RuntimeError(f"Failed to initialize embedding model: {str(e)}")
        # Load dataset
        self.load_dataset()

    def load_dataset(self):
        """Load the dataset synchronously and build the FAISS index.

        On failure, records the message in self.loading_error instead of
        raising, so the UI can surface it via generate_response().
        """
        try:
            # Use the local HF cache; the original forced a full re-download
            # on every cold start (download_mode="force_redownload").
            dataset = load_dataset(DATASET_NAME, split='train')
            if 'text' in dataset.features:
                self.chunks = list(dataset['text'][:1000])
            elif 'context' in dataset.features:
                self.chunks = list(dataset['context'][:1000])
            else:
                raise ValueError("Dataset must have 'text' or 'context' field")
            embeddings = self.embedding_model.encode(
                self.chunks,
                show_progress_bar=False,
                convert_to_numpy=True,
            )
            self.index = faiss.IndexFlatL2(embeddings.shape[1])
            self.index.add(embeddings.astype('float32'))
            self.dataset_loaded = True
        except Exception as e:
            self.loading_error = str(e)
            print(f"Dataset loading failed: {str(e)}")

    def get_relevant_context(self, query: str) -> str:
        """Return the TOP_K most similar chunks joined by blank lines.

        Returns "" when the index is not built or the search fails.
        """
        if self.index is None:  # explicit: truthiness of a faiss index is unreliable
            return ""
        try:
            query_embed = self.embedding_model.encode(
                [query],
                convert_to_numpy=True,
            ).astype('float32')
            _, indices = self.index.search(query_embed, k=TOP_K)
            # FAISS pads with -1 when fewer than k vectors exist; the original
            # `i < len(...)` filter accepted -1 and returned the LAST chunk.
            return "\n\n".join(
                self.chunks[i] for i in indices[0] if 0 <= i < len(self.chunks)
            )
        except Exception as e:
            print(f"Search error: {str(e)}")
            return ""

    def generate_response(self, query: str) -> str:
        """Answer `query` from retrieved context via Gemini.

        Never raises: every failure path returns a user-facing warning string.
        """
        if not self.dataset_loaded:
            if self.loading_error:
                return f"⚠️ Dataset loading failed: {self.loading_error}"
            return "⚠️ System initializing..."
        context = self.get_relevant_context(query)
        if not context:
            return "No relevant context found"
        prompt = f"""Answer based on this context:
{context}
Question: {query}
Answer concisely:"""
        try:
            model = genai.GenerativeModel(GENAI_MODEL)
            response = model.generate_content(prompt)
            return response.text
        except Exception as e:
            return f"⚠️ API Error: {str(e)}"
# Build the singleton RAG system; any failure is surfaced in the UI status line
# rather than crashing the app at import time.
rag_system = None
try:
    rag_system = GeminiRAGSystem()
except Exception as e:
    init_status = f"❌ Initialization failed: {str(e)}"
else:
    if rag_system.dataset_loaded:
        init_status = "✅ System ready"
    else:
        init_status = f"⚠️ Initializing... {rag_system.loading_error or ''}"
# Gradio front-end: a minimal chat UI wired to the RAG system.
with gr.Blocks(title="Chatbot") as app:
    gr.Markdown("# Chatbot")
    chatbot = gr.Chatbot(height=500)
    query = gr.Textbox(label="Your question", placeholder="Ask something...")
    submit_btn = gr.Button("Submit")
    clear_btn = gr.Button("Clear")
    status = gr.Textbox(label="Status", value=init_status)

    def respond(message, chat_history):
        """Append a (question, answer) pair to the chat history."""
        if rag_system is None:
            reply = "System initialization failed"
        else:
            reply = rag_system.generate_response(message)
        return chat_history + [(message, reply)]

    def clear_chat():
        """Reset the chat history to empty."""
        return []

    # Both the button and pressing Enter in the textbox submit a question.
    submit_btn.click(respond, [query, chatbot], [chatbot])
    query.submit(respond, [query, chatbot], [chatbot])
    clear_btn.click(clear_chat, outputs=chatbot)

if __name__ == "__main__":
    app.launch(share=True)