Update app.py

app.py CHANGED
@@ -2,8 +2,8 @@ import os
 import gradio as gr
 import numpy as np
 import google.generativeai as genai
-from sentence_transformers import SentenceTransformer
 import faiss
+from sentence_transformers import SentenceTransformer
 from datasets import load_dataset
 from dotenv import load_dotenv
 
@@ -13,50 +13,44 @@ load_dotenv()
 # Configuration
 MODEL_NAME = "all-MiniLM-L6-v2"
 GENAI_MODEL = "gemini-pro"
-
-DATASET_LINK = "https://datasets-server.huggingface.co/rows?dataset=midrees2806%2F7K_Dataset&config=default&split=train&offset=100&length=100"
+DATASET_NAME = "midrees2806/7K_Dataset"  # Direct dataset name
 CHUNK_SIZE = 500
 TOP_K = 3
 
-# Initialize models
-try:
-    embedding_model = SentenceTransformer(MODEL_NAME)
-except Exception as e:
-    raise RuntimeError(f"Failed to initialize embedding model: {str(e)}")
-
 class GeminiRAGSystem:
     def __init__(self):
         self.index = None
         self.chunks = []
         self.dataset_loaded = False
         self.gemini_api_key = os.getenv("AIzaSyASrFvE3gFPigihza0JTuALzZmBx0Kc3d0")
+
+        # Initialize embedding model
+        try:
+            self.embedding_model = SentenceTransformer(MODEL_NAME)
+        except Exception as e:
+            raise RuntimeError(f"Failed to initialize embedding model: {str(e)}")
+
+        # Configure Gemini
         if self.gemini_api_key:
             genai.configure(api_key=self.gemini_api_key)
 
-    def load_dataset(self, dataset_link):
-        """Load dataset from Hugging Face"""
+    def load_dataset(self):
+        """Load dataset from Hugging Face"""
         try:
-            # Extract dataset name from URL
-            dataset_name = dataset_link.split("datasets/")[-1].split("/")[0].strip()
-            if not dataset_name:
-                raise ValueError("Invalid dataset URL format")
-
             with gr.Progress() as progress:
                 progress(0.1, desc="📦 Downloading dataset...")
-                dataset = load_dataset(dataset_name, split='train')
+                dataset = load_dataset(DATASET_NAME, split='train')
 
                 progress(0.5, desc="🎨 Processing dataset...")
-                if 'text' in dataset.features:
-                    self.chunks = dataset['text']
-                elif 'context' in dataset.features:
-                    self.chunks = dataset['context']
-                elif 'question' in dataset.features and 'answer' in dataset.features:
-                    self.chunks = [f"Q: {q}\nA: {a}" for q, a in zip(dataset['question'], dataset['answer'])]
+                if 'text' in dataset.features:
+                    self.chunks = dataset['text'][:1000]  # Limit to first 1000 entries
+                elif 'context' in dataset.features:
+                    self.chunks = dataset['context'][:1000]
                 else:
-                    raise ValueError("
+                    raise ValueError("Dataset must have 'text' or 'context' field")
 
                 progress(0.7, desc="🧠 Creating embeddings...")
-                embeddings = embedding_model.encode(self.chunks, show_progress_bar=False)
+                embeddings = self.embedding_model.encode(self.chunks, show_progress_bar=False)
                 self.index = faiss.IndexFlatL2(embeddings.shape[1])
                 self.index.add(embeddings.astype('float32'))
 
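A note on the context line self.gemini_api_key = os.getenv("AIzaSyASrFvE3gFPigihza0JTuALzZmBx0Kc3d0"), which this commit leaves untouched: os.getenv takes the name of an environment variable, so passing what appears to be a literal API key both commits the secret to the repository history and returns None at runtime unless a variable with that exact name happens to exist, which silently disables Gemini. A minimal sketch of the presumably intended pattern, assuming an environment variable called GEMINI_API_KEY (a name this commit does not define):

    import os
    from dotenv import load_dotenv

    load_dotenv()  # copy entries from a local .env file into the environment

    # os.getenv() expects a variable *name*; the key itself stays out of source.
    gemini_api_key = os.getenv("GEMINI_API_KEY")
    if not gemini_api_key:
        raise RuntimeError("GEMINI_API_KEY is not set")

If the committed string is a live key, it should also be revoked and rotated, since removing it in a later commit does not remove it from the history.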
@@ -68,168 +62,74 @@ class GeminiRAGSystem:
             return False
 
     def get_relevant_context(self, query: str) -> str:
-        """Retrieve most relevant chunks"""
-        if not self.index:
+        """Retrieve most relevant chunks"""
+        if not self.index:
             return ""
 
-        query_embed = embedding_model.encode([query])
-        _, indices = self.index.search(query_embed.astype('float32'), k=TOP_K)
+        query_embed = self.embedding_model.encode([query])
+        _, indices = self.index.search(query_embed.astype('float32'), k=TOP_K)
 
         context = []
-        for idx in indices[0]:
+        for idx in indices[0]:
             if idx < len(self.chunks):
-                context.append(self.chunks[idx])
-        return "\n".join(context)
+                context.append(self.chunks[idx])
+        return "\n\n".join(context)
 
     def generate_response(self, query: str) -> str:
-        """Generate response using Gemini"""
+        """Generate response using Gemini"""
         if not self.dataset_loaded:
             return "⚠️ Please load the dataset first"
         if not self.gemini_api_key:
-            return "🔑 Please set your Gemini API key"
+            return "🔑 Please set your Gemini API key"
 
         context = self.get_relevant_context(query)
         if not context:
-            return "No relevant context found"
+            return "No relevant context found"
 
-        prompt = f"""
-        Follow these rules:
-        1. Answer concisely using ONLY the context below
-        2. If the answer isn't in the context, say "I couldn't find this in the dataset"
-        3. Never make up information
-        4. For ambiguous questions, ask for clarification
-        Context:
+        prompt = f"""Answer based on this context:
         {context}
+
         Question: {query}
-        Answer:"""
+        Answer concisely:"""
 
         try:
             model = genai.GenerativeModel(GENAI_MODEL)
             response = model.generate_content(prompt)
             return response.text
         except Exception as e:
-            return f"⚠️ Error: {str(e)}"
+            return f"⚠️ Error: {str(e)}"
 
-# Initialize
+# Initialize system
 rag_system = GeminiRAGSystem()
 
-# Custom CSS
-css = """
-.gradio-container {
-    max-width: 900px !important;
-    margin: auto !important;
-    font-family: 'Inter', sans-serif;
-}
-.dark .gradio-container {
-    background-color: #1e1e2e;
-}
-.message-user {
-    background: #3b82f6;
-    color: white;
-    border-radius: 18px 18px 0 18px;
-    padding: 12px;
-    margin: 8px 0;
-    max-width: 80%;
-    margin-left: auto;
-}
-.message-bot {
-    background: #f3f4f6;
-    color: #111827;
-    border-radius: 18px 18px 18px 0;
-    padding: 12px;
-    margin: 8px 0;
-    max-width: 80%;
-}
-.dark .message-bot {
-    background: #2d3748;
-    color: #f7fafc;
-}
-.progress-bar {
-    height: 6px !important;
-}
-"""
-
-# Chat interface
-with gr.Blocks(css=css, theme=gr.themes.Default()) as app:
-    # Store chat history
-    chat_history = gr.State([])
-
-    gr.Markdown("## UE-ChatBot")
-    gr.Markdown(f"**Dataset:** {DATASET_LINK}")
+# Create interface
+with gr.Blocks(title="RAG Chatbot") as app:
+    gr.Markdown("# UE_ChatBot")
 
     with gr.Row():
-        with gr.Column():
-            gr.Markdown("### ⚙️ Configuration")
-            dataset_url = gr.Textbox(
-                label="Hugging Face Dataset URL",
-                value=DATASET_LINK,
-                interactive=True
-            )
+        with gr.Column():
             load_btn = gr.Button("🔄 Load Dataset", variant="primary")
-            status = gr.Markdown("ℹ️ Click to load dataset")
+            status = gr.Markdown("ℹ️ Click to load dataset")
 
-        with gr.Column():
-            chatbot = gr.Chatbot(
-                avatar_images=(
-                    "https://avatars.githubusercontent.com/u/1561194?v=4",  # User avatar
-                    "https://huggingface.co/spaces/groq/Groq-LLM/resolve/main/groq_logo.png"  # Bot avatar
-                )
-            )
-            query = gr.Textbox(
-                label="Type your question...",
-                placeholder="Ask about the dataset content",
-                autofocus=True
-            )
-            with gr.Row():
-                submit_btn = gr.Button("📤 Submit", variant="primary")
-                clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")
+        with gr.Column():
+            chatbot = gr.Chatbot()
+            query = gr.Textbox(label="Your question", placeholder="Ask about the dataset...")
+            submit_btn = gr.Button("📤 Submit", variant="primary")
 
     # Event handlers
-    def load_dataset(dataset_link):
-        if rag_system.load_dataset(dataset_link):
-            return "✅ Dataset ready! You can now ask questions."
+    def load_dataset():
+        if rag_system.load_dataset():
+            return "✅ Dataset ready! You can now ask questions."
         return "❌ Failed to load dataset"
 
-    def respond(query, history):
-        # Add user message
-        history.append((query, None))
-
-        # Get response
-        response = rag_system.generate_response(query)
-
-        # Update history
-        history[-1] = (query, response)
-        return history, ""
-
-    # Connect components
-    load_btn.click(
-        load_dataset,
-        inputs=dataset_url,
-        outputs=status
-    )
-
-    submit_btn.click(
-        respond,
-        inputs=[query, chat_history],
-        outputs=[chatbot, query]
-    )
-
-    query.submit(
-        respond,
-        inputs=[query, chat_history],
-        outputs=[chatbot, query]
-    )
-
-    clear_btn.click(
-        outputs=chatbot
-    )
+    def respond(message, chat_history):
+        response = rag_system.generate_response(message)
+        chat_history.append((message, response))
+        return "", chat_history
+
+    load_btn.click(load_dataset, outputs=status)
+    submit_btn.click(respond, [query, chatbot], [query, chatbot])
+    query.submit(respond, [query, chatbot], [query, chatbot])
 
 if __name__ == "__main__":
-    app.launch(
+    app.launch(share=True)