# UE_ChatBot / app.py
import os
import gradio as gr
import numpy as np
import google.generativeai as genai
import faiss
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import warnings
# Suppress warnings
warnings.filterwarnings("ignore")
# Configuration - read the Gemini API key from the environment rather than hardcoding it
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")  # set GEMINI_API_KEY before launching
MODEL_NAME = "all-MiniLM-L6-v2"
GENAI_MODEL = "gemini-pro"  # NOTE: legacy alias; newer API versions may require e.g. "gemini-1.5-flash"
DATASET_NAME = "midrees2806/7K_Dataset"
CHUNK_SIZE = 500
TOP_K = 3
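# NOTE: CHUNK_SIZE is defined but never referenced below; dataset rows are
# indexed as-is rather than re-chunked to 500 characters.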
# Initialize Gemini. client_options.api_endpoint expects a bare host name,
# not a full ...:generateContent URL.
genai.configure(
    api_key=GEMINI_API_KEY,
    transport='rest',  # Force REST API
    client_options={
        'api_endpoint': "generativelanguage.googleapis.com"
    }
)
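# Quick smoke test for the configuration (a sketch, left commented out; assumes
# the key is set and the model alias is live on this API version):
#   model = genai.GenerativeModel(GENAI_MODEL)
#   print(model.generate_content("ping").text)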
class GeminiRAGSystem:
def __init__(self):
self.index = None
self.chunks = []
self.dataset_loaded = False
self.loading_error = None
print("Initializing embedding model...")
try:
self.embedding_model = SentenceTransformer(MODEL_NAME)
print("Embedding model initialized successfully")
except Exception as e:
error_msg = f"Failed to initialize embedding model: {str(e)}"
print(error_msg)
            raise RuntimeError(error_msg) from e  # chain to preserve the original traceback
print("Loading dataset...")
self.load_dataset()
def load_dataset(self):
"""Load dataset with detailed error handling"""
try:
print(f"Downloading dataset: {DATASET_NAME}")
            dataset = load_dataset(
                DATASET_NAME,
                split='train',
                download_mode="force_redownload"  # NOTE: re-downloads on every restart; drop this kwarg to use the cache
            )
print("Dataset downloaded successfully")
if 'text' in dataset.features:
self.chunks = dataset['text'][:1000]
print(f"Loaded {len(self.chunks)} text chunks")
elif 'context' in dataset.features:
self.chunks = dataset['context'][:1000]
print(f"Loaded {len(self.chunks)} context chunks")
else:
raise ValueError("Dataset must have 'text' or 'context' field")
print("Creating embeddings...")
embeddings = self.embedding_model.encode(
self.chunks,
show_progress_bar=False,
convert_to_numpy=True
)
print(f"Created embeddings with shape {embeddings.shape}")
self.index = faiss.IndexFlatL2(embeddings.shape[1])
self.index.add(embeddings.astype('float32'))
print("FAISS index created successfully")
self.dataset_loaded = True
print("Dataset loading complete")
except Exception as e:
error_msg = f"Dataset loading failed: {str(e)}"
print(error_msg)
self.loading_error = error_msg
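    # A common retrieval variant (not what this class does): cosine similarity
    # via L2-normalized vectors and an inner-product index. A minimal sketch,
    # assuming the same `embeddings` array built above:
    #   faiss.normalize_L2(embeddings)               # in-place normalization
    #   index = faiss.IndexFlatIP(embeddings.shape[1])
    #   index.add(embeddings.astype('float32'))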
def get_relevant_context(self, query: str) -> str:
"""Retrieve context with debugging"""
        if self.index is None:
print("No index available for search")
return ""
try:
print(f"Processing query: {query}")
query_embed = self.embedding_model.encode(
[query],
convert_to_numpy=True
).astype('float32')
print("Query embedded successfully")
            distances, indices = self.index.search(query_embed, k=TOP_K)
            print(f"Search results - distances: {distances}, indices: {indices}")
            # FAISS pads results with -1 when fewer than k neighbors exist, so guard both bounds
            context = "\n\n".join([self.chunks[i] for i in indices[0] if 0 <= i < len(self.chunks)])
print(f"Context length: {len(context)} characters")
return context
except Exception as e:
print(f"Search error: {str(e)}")
return ""
def generate_response(self, query: str) -> str:
"""Generate response with detailed error handling"""
if not self.dataset_loaded:
msg = f"⚠️ Dataset loading failed: {self.loading_error}" if self.loading_error else "⚠️ System initializing..."
print(msg)
return msg
print(f"\n{'='*40}\nNew Query: {query}\n{'='*40}")
context = self.get_relevant_context(query)
if not context:
print("No relevant context found")
return "No relevant context found"
prompt = f"""Answer based on this context:
{context}
Question: {query}
Answer concisely:"""
print(f"\nPrompt sent to Gemini:\n{prompt}\n")
try:
model = genai.GenerativeModel(GENAI_MODEL)
response = model.generate_content(
prompt,
generation_config=genai.types.GenerationConfig(
temperature=0.3,
max_output_tokens=1000
)
)
print(f"Raw API response: {response}")
if response.candidates and response.candidates[0].content.parts:
answer = response.candidates[0].content.parts[0].text
print(f"Answer: {answer}")
return answer
print("⚠️ Empty response from API")
return "⚠️ No response from API"
except Exception as e:
error_msg = f"⚠️ API Error: {str(e)}"
print(error_msg)
return error_msg
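# Programmatic usage sketch (hypothetical query string; bypasses the Gradio UI):
#   rag = GeminiRAGSystem()
#   print(rag.generate_response("What programs does the university offer?"))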
# Initialize system with verbose logging
print("Initializing RAG system...")
try:
rag_system = GeminiRAGSystem()
init_status = "✅ System ready" if rag_system.dataset_loaded else f"⚠️ Initializing... {rag_system.loading_error or ''}"
print(init_status)
except Exception as e:
init_status = f"❌ Initialization failed: {str(e)}"
print(init_status)
rag_system = None
# Create interface with enhanced debugging
with gr.Blocks(title="Document Chatbot") as app:
gr.Markdown("# Document Chatbot with Gemini")
with gr.Row():
chatbot = gr.Chatbot(height=500, label="Chat History")
with gr.Row():
query = gr.Textbox(label="Your question", placeholder="Ask about the documents...")
with gr.Row():
submit_btn = gr.Button("Submit", variant="primary")
clear_btn = gr.Button("Clear", variant="secondary")
status = gr.Textbox(label="System Status", value=init_status, interactive=False)
    def respond(message, chat_history):
        print(f"\n{'='*40}\nUser Query: {message}\n{'='*40}")
        if not rag_system:
            error_msg = "System initialization failed"
            print(error_msg)
            return chat_history + [(message, error_msg)], ""
        response = rag_system.generate_response(message)
        # Return the updated history and clear the input box
        return chat_history + [(message, response)], ""
def clear_chat():
print("Chat cleared")
return []
    submit_btn.click(respond, [query, chatbot], [chatbot, query])
    query.submit(respond, [query, chatbot], [chatbot, query])
clear_btn.click(clear_chat, outputs=chatbot)
if __name__ == "__main__":
print("Launching Gradio interface...")
app.launch(debug=True)
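# Example launch (assumes the key is exported in the shell first):
#   export GEMINI_API_KEY="..."
#   python app.py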