# CH2_UE / app.py
# Author: gmustafa413 — commit 9d60267 ("Create app.py", verified)
import os
import gradio as gr
import numpy as np
import google.generativeai as genai
import faiss
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import warnings
# Suppress all Python warnings so third-party deprecation noise stays out of the console.
warnings.filterwarnings("ignore")

# Configuration
MODEL_NAME = "all-MiniLM-L6-v2"  # SentenceTransformer model used for embeddings
# NOTE(review): "gemini-pro" has been deprecated by Google; confirm it is still served.
GENAI_MODEL = "models/gemini-pro"  # Gemini model used to generate answers
DATASET_NAME = "midrees2806/7K_Dataset"  # Hugging Face dataset indexed for retrieval
CHUNK_SIZE = 500  # NOTE(review): not referenced in this file; texts are indexed whole
TOP_K = 3  # number of retrieved passages inserted into the prompt

# Initialize Gemini. The key is read from the environment (e.g. a Space secret)
# rather than hard-coded: a key committed to source is readable by anyone with
# access to the repository and must be treated as leaked.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    # Not fatal here: request errors are caught and shown to the user later.
    print("GEMINI_API_KEY is not set; Gemini API calls will fail until it is configured.")
genai.configure(api_key=GEMINI_API_KEY)
class GeminiRAGSystem:
    """Retrieval-augmented QA over a Hugging Face dataset, answered by Gemini.

    Texts from the dataset are embedded with a SentenceTransformer model and
    stored in a FAISS ``IndexFlatL2``. For each query the nearest texts are
    put into a prompt that is sent to the Gemini model.

    Dataset loading runs synchronously inside ``__init__``. A failure there does
    not raise; the message is kept in ``loading_error`` and reported by
    ``generate_response``.
    """

    def __init__(self):
        """Load the embedding model, then download, embed and index the dataset.

        Raises:
            RuntimeError: if the embedding model cannot be initialized.
        """
        self.index = None  # FAISS index, set only once fully built
        self.chunks = []  # indexed texts; position i matches index vector i
        self.dataset_loaded = False
        self.loading_error = None  # error message if dataset loading failed
        # Initialize embedding model
        try:
            self.embedding_model = SentenceTransformer(MODEL_NAME)
        except Exception as e:
            raise RuntimeError(f"Failed to initialize embedding model: {str(e)}") from e
        # Load dataset
        self.load_dataset()

    def load_dataset(self):
        """Load the dataset synchronously and build the FAISS index.

        Uses the ``text`` column, falling back to ``context``, and keeps at most
        the first 1000 rows. Null or blank entries are skipped. Any failure is
        recorded in ``loading_error`` and printed; nothing is raised.
        """
        try:
            # NOTE(review): force_redownload bypasses the local cache on every
            # start; confirm this is intended.
            dataset = load_dataset(
                DATASET_NAME,
                split='train',
                download_mode="force_redownload",
            )
            if 'text' in dataset.features:
                column = 'text'
            elif 'context' in dataset.features:
                column = 'context'
            else:
                raise ValueError("Dataset must have 'text' or 'context' field")
            # Slice rows first so only the first 1000 values are materialized,
            # not the whole column. Then drop entries the encoder cannot embed,
            # or that would be useless as context.
            rows = dataset[:1000][column]
            self.chunks = [t for t in rows if isinstance(t, str) and t.strip()]
            if not self.chunks:
                raise ValueError(f"Dataset column '{column}' contains no usable text")
            embeddings = self.embedding_model.encode(
                self.chunks,
                show_progress_bar=False,
                convert_to_numpy=True,
            )
            index = faiss.IndexFlatL2(embeddings.shape[1])
            index.add(embeddings.astype('float32'))
            # Publish the index only after it is complete.
            self.index = index
            self.dataset_loaded = True
            self.loading_error = None
        except Exception as e:
            self.loading_error = str(e)
            print(f"Dataset loading failed: {str(e)}")

    def get_relevant_context(self, query: str) -> str:
        """Return up to TOP_K indexed texts nearest to ``query``, joined by blank lines.

        Returns "" if no index has been built or the search fails.
        """
        if self.index is None:
            return ""
        try:
            query_embed = self.embedding_model.encode(
                [query],
                convert_to_numpy=True,
            ).astype('float32')
            # Never request more neighbours than there are stored vectors.
            k = min(TOP_K, self.index.ntotal)
            if k <= 0:
                return ""
            _, indices = self.index.search(query_embed, k=k)
            # FAISS pads missing results with -1. An `i < len(...)` check alone
            # lets -1 through, and Python would then return the *last* chunk.
            return "\n\n".join(
                self.chunks[i] for i in indices[0] if 0 <= i < len(self.chunks)
            )
        except Exception as e:
            print(f"Search error: {str(e)}")
            return ""

    def generate_response(self, query: str) -> str:
        """Answer ``query`` from retrieved context using Gemini.

        Returns a status or error string instead of raising when:
        - the dataset is not loaded;
        - no context is found;
        - the API call fails.
        """
        if not self.dataset_loaded:
            if self.loading_error:
                return f"⚠️ Dataset loading failed: {self.loading_error}"
            return "⚠️ System initializing..."
        context = self.get_relevant_context(query)
        if not context:
            return "No relevant context found"
        # Built explicitly so source indentation never leaks into the prompt.
        prompt = (
            "Answer based on this context:\n"
            f"{context}\n"
            f"Question: {query}\n"
            "Answer concisely:"
        )
        try:
            model = genai.GenerativeModel(GENAI_MODEL)
            response = model.generate_content(prompt)
            # .text can raise, e.g. when no text part is returned; handled below.
            return response.text
        except Exception as e:
            return f"⚠️ API Error: {str(e)}"
# Initialize system. The constructor loads the dataset synchronously, so when
# it returns the load has either succeeded or failed. Nothing is still pending,
# and the status must say which.
try:
    rag_system = GeminiRAGSystem()
except Exception as e:
    init_status = f"❌ Initialization failed: {str(e)}"
    rag_system = None
else:
    if rag_system.dataset_loaded:
        init_status = "✅ System ready"
    else:
        # Same wording that generate_response uses for this condition.
        init_status = f"⚠️ Dataset loading failed: {rag_system.loading_error or 'unknown error'}"
# Create interface
with gr.Blocks(title="Chatbot") as app:
    gr.Markdown("# Chatbot")
    chatbot = gr.Chatbot(height=500)
    query = gr.Textbox(label="Your question", placeholder="Ask something...")
    submit_btn = gr.Button("Submit")
    clear_btn = gr.Button("Clear")
    # Shows the startup result computed when the module was loaded.
    status = gr.Textbox(label="Status", value=init_status)

    def respond(message, chat_history):
        """Return the chat history with a ``(message, answer)`` pair appended.

        A missing history is treated as empty. Blank messages leave the
        history unchanged so Gemini is never called with an empty question.
        """
        history = list(chat_history or [])
        if not message or not message.strip():
            return history
        if rag_system is None:
            return history + [(message, "System initialization failed")]
        response = rag_system.generate_response(message)
        return history + [(message, response)]

    def clear_chat():
        """Return an empty history to clear the chat window."""
        return []

    submit_btn.click(respond, [query, chatbot], [chatbot])
    query.submit(respond, [query, chatbot], [chatbot])
    clear_btn.click(clear_chat, outputs=chatbot)
def main() -> None:
    """Start the Gradio server for ``app``, requesting a public share link."""
    app.launch(share=True)


if __name__ == "__main__":
    main()