Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
import numpy as np
|
4 |
+
import google.generativeai as genai
|
5 |
+
import faiss
|
6 |
+
from sentence_transformers import SentenceTransformer
|
7 |
+
from datasets import load_dataset
|
8 |
+
import warnings
|
9 |
+
|
10 |
+
# Suppress warnings
|
11 |
+
warnings.filterwarnings("ignore")
|
12 |
+
|
13 |
+
# Configuration
|
14 |
+
MODEL_NAME = "all-MiniLM-L6-v2"
|
15 |
+
GENAI_MODEL = "models/gemini-pro" # Updated model path
|
16 |
+
DATASET_NAME = "midrees2806/7K_Dataset"
|
17 |
+
CHUNK_SIZE = 500
|
18 |
+
TOP_K = 3
|
19 |
+
|
20 |
+
# Initialize Gemini - PUT YOUR API KEY HERE (for testing only)
|
21 |
+
GEMINI_API_KEY = "AIzaSyASrFvE3gFPigihza0JTuALzZmBx0Kc3d0" # ⚠️ Replace with your actual key
|
22 |
+
genai.configure(api_key=GEMINI_API_KEY)
|
23 |
+
|
24 |
+
class GeminiRAGSystem:
|
25 |
+
def __init__(self):
|
26 |
+
self.index = None
|
27 |
+
self.chunks = []
|
28 |
+
self.dataset_loaded = False
|
29 |
+
self.loading_error = None
|
30 |
+
|
31 |
+
# Initialize embedding model
|
32 |
+
try:
|
33 |
+
self.embedding_model = SentenceTransformer(MODEL_NAME)
|
34 |
+
except Exception as e:
|
35 |
+
raise RuntimeError(f"Failed to initialize embedding model: {str(e)}")
|
36 |
+
|
37 |
+
# Load dataset
|
38 |
+
self.load_dataset()
|
39 |
+
|
40 |
+
def load_dataset(self):
|
41 |
+
"""Load dataset synchronously"""
|
42 |
+
try:
|
43 |
+
dataset = load_dataset(
|
44 |
+
DATASET_NAME,
|
45 |
+
split='train',
|
46 |
+
download_mode="force_redownload"
|
47 |
+
)
|
48 |
+
|
49 |
+
if 'text' in dataset.features:
|
50 |
+
self.chunks = dataset['text'][:1000]
|
51 |
+
elif 'context' in dataset.features:
|
52 |
+
self.chunks = dataset['context'][:1000]
|
53 |
+
else:
|
54 |
+
raise ValueError("Dataset must have 'text' or 'context' field")
|
55 |
+
|
56 |
+
embeddings = self.embedding_model.encode(
|
57 |
+
self.chunks,
|
58 |
+
show_progress_bar=False,
|
59 |
+
convert_to_numpy=True
|
60 |
+
)
|
61 |
+
self.index = faiss.IndexFlatL2(embeddings.shape[1])
|
62 |
+
self.index.add(embeddings.astype('float32'))
|
63 |
+
|
64 |
+
self.dataset_loaded = True
|
65 |
+
except Exception as e:
|
66 |
+
self.loading_error = str(e)
|
67 |
+
print(f"Dataset loading failed: {str(e)}")
|
68 |
+
|
69 |
+
def get_relevant_context(self, query: str) -> str:
|
70 |
+
"""Retrieve most relevant chunks"""
|
71 |
+
if not self.index:
|
72 |
+
return ""
|
73 |
+
|
74 |
+
try:
|
75 |
+
query_embed = self.embedding_model.encode(
|
76 |
+
[query],
|
77 |
+
convert_to_numpy=True
|
78 |
+
).astype('float32')
|
79 |
+
|
80 |
+
_, indices = self.index.search(query_embed, k=TOP_K)
|
81 |
+
return "\n\n".join([self.chunks[i] for i in indices[0] if i < len(self.chunks)])
|
82 |
+
except Exception as e:
|
83 |
+
print(f"Search error: {str(e)}")
|
84 |
+
return ""
|
85 |
+
|
86 |
+
def generate_response(self, query: str) -> str:
|
87 |
+
"""Generate response with robust error handling"""
|
88 |
+
if not self.dataset_loaded:
|
89 |
+
if self.loading_error:
|
90 |
+
return f"⚠️ Dataset loading failed: {self.loading_error}"
|
91 |
+
return "⚠️ System initializing..."
|
92 |
+
|
93 |
+
context = self.get_relevant_context(query)
|
94 |
+
if not context:
|
95 |
+
return "No relevant context found"
|
96 |
+
|
97 |
+
prompt = f"""Answer based on this context:
|
98 |
+
{context}
|
99 |
+
|
100 |
+
Question: {query}
|
101 |
+
Answer concisely:"""
|
102 |
+
|
103 |
+
try:
|
104 |
+
model = genai.GenerativeModel(GENAI_MODEL)
|
105 |
+
response = model.generate_content(prompt)
|
106 |
+
return response.text
|
107 |
+
except Exception as e:
|
108 |
+
return f"⚠️ API Error: {str(e)}"
|
109 |
+
|
110 |
+
# Initialize system
|
111 |
+
try:
|
112 |
+
rag_system = GeminiRAGSystem()
|
113 |
+
init_status = "✅ System ready" if rag_system.dataset_loaded else f"⚠️ Initializing... {rag_system.loading_error or ''}"
|
114 |
+
except Exception as e:
|
115 |
+
init_status = f"❌ Initialization failed: {str(e)}"
|
116 |
+
rag_system = None
|
117 |
+
|
118 |
+
# Create interface
|
119 |
+
with gr.Blocks(title="Chatbot") as app:
|
120 |
+
gr.Markdown("# Chatbot")
|
121 |
+
|
122 |
+
chatbot = gr.Chatbot(height=500)
|
123 |
+
query = gr.Textbox(label="Your question", placeholder="Ask something...")
|
124 |
+
submit_btn = gr.Button("Submit")
|
125 |
+
clear_btn = gr.Button("Clear")
|
126 |
+
status = gr.Textbox(label="Status", value=init_status)
|
127 |
+
|
128 |
+
def respond(message, chat_history):
|
129 |
+
if not rag_system:
|
130 |
+
return chat_history + [(message, "System initialization failed")]
|
131 |
+
response = rag_system.generate_response(message)
|
132 |
+
return chat_history + [(message, response)]
|
133 |
+
|
134 |
+
def clear_chat():
|
135 |
+
return []
|
136 |
+
|
137 |
+
submit_btn.click(respond, [query, chatbot], [chatbot])
|
138 |
+
query.submit(respond, [query, chatbot], [chatbot])
|
139 |
+
clear_btn.click(clear_chat, outputs=chatbot)
|
140 |
+
|
141 |
+
if __name__ == "__main__":
|
142 |
+
app.launch(share=True)
|