# Spaces: Sleeping
# (hosting-page status text captured together with this file; not part of the program)
import os
import warnings

import faiss
import google.generativeai as genai
import gradio as gr
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
# Suppress warnings
warnings.filterwarnings("ignore")

# Configuration.
# SECURITY: the Gemini API key is read from the GEMINI_API_KEY environment
# variable instead of being hard-coded. Any key that was ever committed to
# source must be treated as leaked and revoked.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
if not GEMINI_API_KEY:
    # Do not crash at import time: the UI can still start and will surface the
    # API error on the first request.
    print("⚠️ GEMINI_API_KEY environment variable is not set; Gemini requests will fail")

MODEL_NAME = "all-MiniLM-L6-v2"  # sentence-transformers embedding model
GENAI_MODEL = "gemini-pro"  # Gemini model used for answer generation
DATASET_NAME = "midrees2806/7K_Dataset"  # Hugging Face dataset to index
CHUNK_SIZE = 500  # not referenced by the code visible in this file
TOP_K = 3  # number of nearest chunks retrieved per query
# Initialize Gemini.
# The previous configuration passed a full method URL
# (".../v1beta/models/gemini-pro:generateContent") as `api_endpoint`. That
# option expects only the service host; the client library appends the API
# version, model and method path itself. The library's default endpoint is
# already the Generative Language API host, so no override is needed.
genai.configure(
    api_key=GEMINI_API_KEY,
    transport='rest',  # Force REST API
)
class GeminiRAGSystem:
    """Retrieval-augmented question answering over a Hugging Face dataset.

    On construction the sentence-embedding model is loaded, up to the first
    1000 rows of the dataset's ``text`` (or ``context``) column are embedded,
    and a FAISS L2 index is built over them. A query is answered by
    retrieving the ``TOP_K`` nearest chunks and sending them to Gemini as
    context for the question.
    """

    def __init__(self):
        """Load the embedding model, then try to load and index the dataset.

        Raises:
            RuntimeError: if the embedding model cannot be initialised.
                Dataset failures do not raise; they are recorded in
                ``loading_error`` and ``dataset_loaded`` stays ``False``.
        """
        self.index = None  # FAISS index; None until load_dataset() succeeds
        self.chunks = []  # texts aligned with the index rows
        self.dataset_loaded = False
        self.loading_error = None  # human-readable reason of the last load failure
        print("Initializing embedding model...")
        try:
            self.embedding_model = SentenceTransformer(MODEL_NAME)
            print("Embedding model initialized successfully")
        except Exception as e:
            error_msg = f"Failed to initialize embedding model: {str(e)}"
            print(error_msg)
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(error_msg) from e
        print("Loading dataset...")
        self.load_dataset()

    def load_dataset(self):
        """Download the dataset, embed its text and build the FAISS index.

        Never raises: any failure is printed and stored in
        ``self.loading_error``. ``self.chunks`` and ``self.index`` are only
        replaced once both have been built, so they always stay aligned.
        """
        try:
            print(f"Downloading dataset: {DATASET_NAME}")
            # Inside the method the bare name resolves to the module-level
            # `datasets.load_dataset` function, not to this method.
            dataset = load_dataset(
                DATASET_NAME,
                split='train',
                download_mode="force_redownload"
            )
            print("Dataset downloaded successfully")

            if 'text' in dataset.features:
                column = 'text'
            elif 'context' in dataset.features:
                column = 'context'
            else:
                raise ValueError("Dataset must have 'text' or 'context' field")

            # Keep only non-blank strings: missing or non-text rows cannot be
            # embedded and would otherwise fail the whole load.
            rows = list(dataset[column][:1000])
            chunks = [row for row in rows if isinstance(row, str) and row.strip()]
            print(f"Loaded {len(chunks)} {column} chunks")
            if not chunks:
                raise ValueError(f"Dataset column '{column}' contains no usable text")

            print("Creating embeddings...")
            embeddings = self.embedding_model.encode(
                chunks,
                show_progress_bar=False,
                convert_to_numpy=True
            )
            print(f"Created embeddings with shape {embeddings.shape}")

            index = faiss.IndexFlatL2(embeddings.shape[1])
            index.add(embeddings.astype('float32'))
            print("FAISS index created successfully")

            # Publish the new state only after every step succeeded.
            self.chunks = chunks
            self.index = index
            self.loading_error = None
            self.dataset_loaded = True
            print("Dataset loading complete")
        except Exception as e:
            error_msg = f"Dataset loading failed: {str(e)}"
            print(error_msg)
            self.loading_error = error_msg

    def get_relevant_context(self, query: str) -> str:
        """Return the nearest chunks to *query*, joined by blank lines.

        Returns an empty string when no index is available or when the
        search fails (the error is printed, not raised).
        """
        if self.index is None:
            print("No index available for search")
            return ""
        try:
            print(f"Processing query: {query}")
            query_embed = self.embedding_model.encode(
                [query],
                convert_to_numpy=True
            ).astype('float32')
            print("Query embedded successfully")
            distances, indices = self.index.search(query_embed, k=TOP_K)
            print(f"Search results - distances: {distances}, indices: {indices}")
            # FAISS pads the result with -1 when fewer than k neighbours
            # exist. A bare `i < len(...)` check lets -1 through and would
            # silently return the *last* chunk, so bound the index below too.
            hits = [
                self.chunks[i]
                for i in indices[0]
                if 0 <= i < len(self.chunks)
            ]
            context = "\n\n".join(hits)
            print(f"Context length: {len(context)} characters")
            return context
        except Exception as e:
            print(f"Search error: {str(e)}")
            return ""

    def generate_response(self, query: str) -> str:
        """Answer *query* with Gemini, grounded in the retrieved context.

        Always returns a string: the model's answer, or a user-facing
        message when the dataset is not loaded, the query is blank, no
        context is found, or the API call fails.
        """
        if not self.dataset_loaded:
            msg = (
                f"⚠️ Dataset loading failed: {self.loading_error}"
                if self.loading_error
                else "⚠️ System initializing..."
            )
            print(msg)
            return msg

        # A blank question would still trigger a retrieval and a paid API call.
        if not query or not query.strip():
            msg = "⚠️ Please enter a question"
            print(msg)
            return msg

        print(f"\n{'='*40}\nNew Query: {query}\n{'='*40}")
        context = self.get_relevant_context(query)
        if not context:
            print("No relevant context found")
            return "No relevant context found"

        # Built with explicit newlines so the prompt layout does not depend
        # on source indentation.
        prompt = (
            "Answer based on this context:\n"
            f"{context}\n"
            f"Question: {query}\n"
            "Answer concisely:"
        )
        print(f"\nPrompt sent to Gemini:\n{prompt}\n")
        try:
            model = genai.GenerativeModel(GENAI_MODEL)
            response = model.generate_content(
                prompt,
                generation_config=genai.types.GenerationConfig(
                    temperature=0.3,
                    max_output_tokens=1000
                )
            )
            print(f"Raw API response: {response}")
            if response.candidates and response.candidates[0].content.parts:
                answer = response.candidates[0].content.parts[0].text
                print(f"Answer: {answer}")
                return answer
            print("⚠️ Empty response from API")
            return "⚠️ No response from API"
        except Exception as e:
            error_msg = f"⚠️ API Error: {str(e)}"
            print(error_msg)
            return error_msg
# Build the shared RAG system once at import time and derive the status line
# shown in the UI. A constructor failure is reported, not raised: the app
# still starts with `rag_system` set to None.
print("Initializing RAG system...")
try:
    rag_system = GeminiRAGSystem()
    if rag_system.dataset_loaded:
        init_status = "✅ System ready"
    else:
        init_status = f"⚠️ Initializing... {rag_system.loading_error or ''}"
    print(init_status)
except Exception as exc:
    init_status = f"❌ Initialization failed: {exc}"
    print(init_status)
    rag_system = None
# Gradio UI: chat history, question box, action buttons and a read-only
# status line, plus the callbacks that connect them to the RAG system.
# NOTE(review): the source had lost its indentation; the status box is assumed
# to sit below the button row rather than inside it — confirm.
with gr.Blocks(title="Document Chatbot") as app:
    gr.Markdown("# Document Chatbot with Gemini")

    with gr.Row():
        chatbot = gr.Chatbot(height=500, label="Chat History")
    with gr.Row():
        query = gr.Textbox(
            label="Your question",
            placeholder="Ask about the documents...",
        )
    with gr.Row():
        submit_btn = gr.Button("Submit", variant="primary")
        clear_btn = gr.Button("Clear", variant="secondary")
    status = gr.Textbox(label="System Status", value=init_status, interactive=False)

    def respond(message, chat_history):
        """Return the chat history extended with one (question, answer) turn."""
        banner = "=" * 40
        print(f"\n{banner}\nUser Query: {message}\n{banner}")
        if rag_system:
            reply = rag_system.generate_response(message)
        else:
            # The RAG system could not be constructed at startup.
            reply = "System initialization failed"
            print(reply)
        return chat_history + [(message, reply)]

    def clear_chat():
        """Reset the chat history to an empty list."""
        print("Chat cleared")
        return []

    # The button and the Enter key in the textbox trigger the same handler.
    chat_wiring = dict(fn=respond, inputs=[query, chatbot], outputs=[chatbot])
    submit_btn.click(**chat_wiring)
    query.submit(**chat_wiring)
    clear_btn.click(clear_chat, outputs=chatbot)
# Script entry point: start the Gradio server only when run directly.
if __name__ == "__main__":
    print("Launching Gradio interface...")
    app.launch(debug=True)