Phoenix21 committed · verified
Commit 8634f11 · 1 Parent(s): 44662fa

Create app.py


Removed the Gemini wrapper and used Gemini directly via LiteLLM.

Files changed (1)
  1. app.py +413 -0
app.py ADDED
@@ -0,0 +1,413 @@
# app.py

import os
import re
import pandas as pd
import chardet
import logging
import gradio as gr
import json
import hashlib
import numpy as np
from typing import Optional, List, Dict

from sentence_transformers import SentenceTransformer, util, CrossEncoder

# Import smolagents components
from smolagents import CodeAgent, LiteLLMModel, DuckDuckGoSearchTool, ManagedAgent

###############################################################################
# 1) Logging Setup
###############################################################################
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("Daily Wellness AI")

###############################################################################
# 2) API Key Handling and LiteLLMModel Instantiation
###############################################################################
def clean_api_key(key: str) -> str:
    """Remove non-ASCII characters and strip whitespace from the API key."""
    return ''.join(c for c in key if ord(c) < 128).strip()

gemini_api_key = os.environ.get("GEMINI_API_KEY")
if not gemini_api_key:
    logger.error("GEMINI_API_KEY environment variable not set.")
    raise EnvironmentError("Please set the GEMINI_API_KEY environment variable.")

gemini_api_key = clean_api_key(gemini_api_key)
logger.info("GEMINI API Key loaded successfully.")

# Instantiate the model using LiteLLMModel
llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=gemini_api_key)

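# Note: LiteLLM routes on the provider prefix in model_id, so "gemini/gemini-pro"
# is dispatched to Google's Gemini API. Swapping providers is, illustratively, a
# one-line change (e.g. model_id="groq/llama3-8b-8192" with the matching API key).
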
###############################################################################
# 3) CSV Loading and Processing
###############################################################################
def load_csv(file_path: str):
    try:
        if not os.path.isfile(file_path):
            logger.error(f"CSV file does not exist: {file_path}")
            return [], []

        with open(file_path, 'rb') as f:
            result = chardet.detect(f.read())
        # chardet may return None for the encoding; fall back to UTF-8.
        encoding = result['encoding'] or 'utf-8'

        data = pd.read_csv(file_path, encoding=encoding)
        if 'Question' not in data.columns or 'Answers' not in data.columns:
            raise ValueError("CSV must contain 'Question' and 'Answers' columns.")
        data = data.dropna(subset=['Question', 'Answers'])

        logger.info(f"Loaded {len(data)} entries from {file_path}")
        return data['Question'].tolist(), data['Answers'].tolist()
    except Exception as e:
        logger.error(f"Error loading CSV: {e}")
        return [], []

csv_file_path = "AIChatbot.csv"
corpus_questions, corpus_answers = load_csv(csv_file_path)
if not corpus_questions:
    raise ValueError("Failed to load the knowledge base.")

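# Illustrative AIChatbot.csv layout (hypothetical rows, real column names):
#   Question,Answers
#   "What is box breathing?","Box breathing is a paced 4-4-4-4 breathing technique..."
#   "How much water should I drink daily?","A common guideline is about 2 liters..."
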
###############################################################################
# 4) Sentence Embeddings & Cross-Encoder
###############################################################################
embedding_model_name = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
try:
    embedding_model = SentenceTransformer(embedding_model_name)
    logger.info(f"Loaded embedding model: {embedding_model_name}")
except Exception as e:
    logger.error(f"Failed to load embedding model: {e}")
    raise

try:
    question_embeddings = embedding_model.encode(corpus_questions, convert_to_tensor=True)
    logger.info("Encoded question embeddings successfully.")
except Exception as e:
    logger.error(f"Failed to encode question embeddings: {e}")
    raise

cross_encoder_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
try:
    cross_encoder = CrossEncoder(cross_encoder_name)
    logger.info(f"Loaded cross-encoder model: {cross_encoder_name}")
except Exception as e:
    logger.error(f"Failed to load cross-encoder model: {e}")
    raise

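# The bi-encoder embeds the query and the corpus questions independently (fast,
# used for candidate recall); the cross-encoder reads each (query, question)
# pair jointly and produces a sharper relevance score for re-ranking the few
# recalled candidates.
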
###############################################################################
# 5) Retrieval + Re-Ranking
###############################################################################
class EmbeddingRetriever:
    def __init__(self, questions, answers, embeddings, model, cross_encoder):
        self.questions = questions
        self.answers = answers
        self.embeddings = embeddings
        self.model = model
        self.cross_encoder = cross_encoder

    def retrieve(self, query: str, top_k: int = 3):
        try:
            query_embedding = self.model.encode(query, convert_to_tensor=True)
            scores = util.pytorch_cos_sim(query_embedding, self.embeddings)[0].cpu().tolist()
            scored_data = sorted(zip(self.questions, self.answers, scores), key=lambda x: x[2], reverse=True)[:top_k]

            cross_inputs = [[query, candidate[0]] for candidate in scored_data]
            cross_scores = self.cross_encoder.predict(cross_inputs)

            reranked = sorted(zip(scored_data, cross_scores), key=lambda x: x[1], reverse=True)
            final_retrieved = [(entry[0][1], entry[1]) for entry in reranked]
            logger.debug(f"Retrieved and reranked answers: {final_retrieved}")
            return final_retrieved
        except Exception as e:
            logger.error(f"Error during retrieval: {e}")
            logger.debug("Exception details:", exc_info=True)
            return []

retriever = EmbeddingRetriever(corpus_questions, corpus_answers, question_embeddings, embedding_model, cross_encoder)

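# Illustrative call (hypothetical query and scores):
#   retriever.retrieve("How can I reduce anxiety?", top_k=3)
#   -> [("Try box breathing: inhale 4s, hold 4s, ...", 6.91), ...]
# Each tuple is (answer, cross-encoder score); ms-marco cross-encoder scores are
# unbounded logits, not cosine similarities.
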
###############################################################################
# 6) Sanity Check Tool
###############################################################################
class QuestionSanityChecker:
    def __init__(self, llm):
        self.llm = llm

    def is_relevant(self, question: str) -> bool:
        prompt = (
            f"You are an assistant that determines whether a question is relevant to daily wellness.\n\n"
            f"Question: {question}\n\n"
            f"Is the above question relevant to daily wellness? Respond with 'Yes' or 'No' only."
        )
        try:
            response = self.llm(prompt)
            # Match whole words: a plain substring test would treat e.g. "know"
            # as containing "no".
            tokens = re.findall(r"[a-z]+", response.lower())
            is_yes = 'yes' in tokens
            is_no = 'no' in tokens
            logger.debug(f"Sanity check response: '{response}', interpreted as is_yes={is_yes}, is_no={is_no}")
            if is_yes and not is_no:
                return True
            elif is_no and not is_yes:
                return False
            else:
                logger.warning(f"Sanity check ambiguous response: '{response}'. Defaulting to 'No'.")
                return False
        except Exception as e:
            logger.error(f"Error in sanity check: {e}")
            logger.debug("Exception details:", exc_info=True)
            return False

sanity_checker = QuestionSanityChecker(llm)

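# Expected contract (illustrative):
#   sanity_checker.is_relevant("What is box breathing?")        -> True
#   sanity_checker.is_relevant("Who is the CEO of Hugging Face?") -> False
# Ambiguous model replies fall back to False above.
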
###############################################################################
# 7) smolagents Integration: Web Search Agents
###############################################################################
# The Gemini-backed LiteLLMModel instantiated above (`llm`) also drives the agents.

# Instantiate the DuckDuckGo search tool
search_tool = DuckDuckGoSearchTool()

# Create the web agent with the search tool
web_agent = CodeAgent(
    tools=[search_tool],
    model=llm
)

# Define the managed web agent
managed_web_agent = ManagedAgent(
    agent=web_agent,
    name="web_search",
    description="Runs a web search for you. Provide your query as an argument."
)

# Create the manager agent with the managed web agent and additional tools if needed
manager_agent = CodeAgent(
    tools=[],  # Add additional tools here if required
    model=llm,
    managed_agents=[managed_web_agent]
)

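# Illustrative flow (hypothetical query): manager_agent.run("latest findings on
# box breathing") has the manager delegate to the managed "web_search" agent,
# which calls DuckDuckGoSearchTool and returns a summarized result string.
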
###############################################################################
# 8) Answer Expansion
###############################################################################
class AnswerExpander:
    def __init__(self, llm):
        self.llm = llm

    def expand(self, query: str, retrieved_answers: List[str], detail: bool = False) -> str:
        try:
            # enumerate(..., start=1) already yields 1-based indices.
            reference_block = "\n".join(
                f"- {idx}) {ans}" for idx, ans in enumerate(retrieved_answers, start=1)
            )

            detail_instructions = (
                "Provide a thorough, in-depth explanation, adding relevant tips and context, "
                "while remaining creative and brand-aligned. "
                if detail else
                "Provide a concise response in no more than 4 sentences."
            )

            prompt = (
                f"You are Daily Wellness AI, a friendly wellness expert. Below are multiple "
                f"potential answers retrieved from a local knowledge base. You have a user question.\n\n"
                f"Question: {query}\n\n"
                f"Retrieved Answers:\n{reference_block}\n\n"
                f"Please synthesize these references into a single cohesive, creative, and brand-aligned response. "
                f"{detail_instructions} "
                f"End with a short inspirational note.\n\n"
                "Disclaimer: This is general wellness information, not a substitute for professional medical advice."
            )

            logger.debug(f"Generated prompt for answer expansion: {prompt}")
            response = self.llm(prompt)
            logger.debug(f"Expanded answer: {response}")
            return response.strip()
        except Exception as e:
            logger.error(f"Error expanding answer: {e}")
            logger.debug("Exception details:", exc_info=True)
            return "Sorry, an error occurred while generating a response."

answer_expander = AnswerExpander(llm)

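# Illustrative call (hypothetical inputs):
#   answer_expander.expand(
#       "What is box breathing?",
#       ["Box breathing is a paced breathing technique...", "Inhale 4s, hold 4s..."],
#       detail=False,
#   )
# returns one synthesized answer ending with the disclaimer and a short note.
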
###############################################################################
# 9) Persistent Cache (ADDED)
###############################################################################
CACHE_FILE = "query_cache.json"
SIMILARITY_THRESHOLD_CACHE = 0.8

def load_cache() -> Dict:
    if os.path.isfile(CACHE_FILE):
        try:
            with open(CACHE_FILE, "r", encoding="utf-8") as f:
                return json.load(f)
        except Exception as e:
            logger.error(f"Failed to load cache file: {e}")
            return {}
    return {}

def save_cache(cache_data: Dict):
    try:
        with open(CACHE_FILE, "w", encoding="utf-8") as f:
            json.dump(cache_data, f, ensure_ascii=False, indent=2)
    except Exception as e:
        logger.error(f"Failed to save cache file: {e}")

def compute_hash(text: str) -> str:
    return hashlib.md5(text.encode("utf-8")).hexdigest()

cache_store = load_cache()

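# query_cache.json shape (illustrative values):
# {
#   "<md5 of query>": {
#     "query": "What is box breathing?",
#     "answer": "Box breathing is ...",
#     "embedding": [0.0123, -0.0456, ...]
#   }
# }
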
###############################################################################
# 9.1) Utility to attempt cached retrieval (ADDED)
###############################################################################
def get_cached_answer(query: str) -> Optional[str]:
    if not cache_store:
        return None

    # Keep the query embedding on CPU so it can be compared against the
    # CPU numpy arrays loaded from the cache file.
    query_embedding = embedding_model.encode(query, convert_to_tensor=True).cpu()

    best_score = 0.0
    best_answer = None

    for cached_q, cache_data in cache_store.items():
        stored_embedding = np.array(cache_data["embedding"], dtype=np.float32)
        score = util.pytorch_cos_sim(query_embedding, stored_embedding)[0].item()
        if score > best_score:
            best_score = score
            best_answer = cache_data["answer"]

    if best_score >= SIMILARITY_THRESHOLD_CACHE:
        logger.info(f"Cache hit! Similarity: {best_score:.2f}, returning cached answer.")
        return best_answer
    return None

def store_in_cache(query: str, answer: str):
    query_embedding = embedding_model.encode(query, convert_to_tensor=True).cpu().tolist()
    cache_key = compute_hash(query)
    cache_store[cache_key] = {
        "query": query,
        "answer": answer,
        "embedding": query_embedding
    }
    save_cache(cache_store)

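# Cache lookup is a linear scan with cosine similarity over every stored
# embedding; at SIMILARITY_THRESHOLD_CACHE = 0.8, a near-paraphrase of an
# earlier query reuses its answer. Fine for small caches; a vector index would
# be the usual next step if the cache grows large.
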
###############################################################################
# 10) Query Handling
###############################################################################
def handle_query(query: str, detail: bool = False) -> str:
    if not query or not isinstance(query, str) or len(query.strip()) == 0:
        return "Please provide a valid question."

    try:
        is_relevant = sanity_checker.is_relevant(query)
        if not is_relevant:
            return "Your question seems out of context or not related to daily wellness. Please ask a wellness-related question."

        retrieved = retriever.retrieve(query)
        cached_answer = get_cached_answer(query)

        if not retrieved:
            if cached_answer:
                logger.info("No relevant entries found in knowledge base. Returning cached answer.")
                return cached_answer
            return "I'm sorry, I couldn't find an answer to your question."

        # NOTE: this is the cross-encoder relevance score (an unbounded logit),
        # not a cosine similarity, so the threshold is model-specific.
        top_score = retrieved[0][1]
        similarity_threshold = 0.3

        if top_score < similarity_threshold:
            logger.info("Similarity score below threshold. Performing web search.")
            web_search_response = manager_agent.run(query)
            logger.debug(f"Web search response: {web_search_response}")

            if cached_answer:
                blend_prompt = (
                    f"Combine the following previous answer with the new web results to create a more creative and accurate response. "
                    f"Do not include any of the previous prompt or instructions in your response. "
                    f"Add positivity and conclude with a short inspirational note.\n\n"
                    f"Previous Answer:\n{cached_answer}\n\n"
                    f"Web Results:\n{web_search_response}"
                )
                final_answer = llm(blend_prompt).strip()
            else:
                final_answer = (
                    f"**Daily Wellness AI**\n\n"
                    f"{web_search_response}\n\n"
                    "Disclaimer: This information is retrieved from the web and is not a substitute for professional medical advice.\n\n"
                    "Wishing you a calm and wonderful day!"
                )

            store_in_cache(query, final_answer)
            return final_answer

        responses = [ans for ans, score in retrieved]

        if cached_answer:
            blend_prompt = (
                f"Combine the previous answer with the newly retrieved answers to enhance creativity and accuracy. "
                f"Do not include any of the previous prompt or instructions in your response. "
                f"Add new insights, creativity, and conclude with a short inspirational note.\n\n"
                f"Previous Answer:\n{cached_answer}\n\n"
                f"New Retrieved Answers:\n" + "\n".join(f"- {r}" for r in responses)
            )
            final_answer = llm(blend_prompt).strip()
        else:
            final_answer = answer_expander.expand(query, responses, detail=detail)

        store_in_cache(query, final_answer)
        return final_answer

    except Exception as e:
        logger.error(f"Error handling query: {e}")
        logger.debug("Exception details:", exc_info=True)
        return "An error occurred while processing your request."

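# Decision flow in summary: sanity check -> knowledge-base retrieval -> if the
# top re-ranked score is weak, fall back to the web-search agent (blending with
# any cached answer) -> otherwise expand/blend the retrieved answers -> cache
# the final answer and return it.
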
###############################################################################
# 11) Gradio Interface
###############################################################################
def gradio_interface(query: str, detail: bool):
    try:
        return handle_query(query, detail=detail)
    except Exception as e:
        logger.error(f"Error in Gradio interface: {e}")
        logger.debug("Exception details:", exc_info=True)
        return "**An error occurred while processing your request. Please try again later.**"

interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(
            lines=2,
            placeholder="e.g., What is box breathing?",
            label="Ask Daily Wellness AI"
        ),
        gr.Checkbox(
            label="In-Depth Answer?",
            value=False,
            info="Check for a longer, more detailed response."
        )
    ],
    outputs=gr.Markdown(label="Answer from Daily Wellness AI"),
    title="Daily Wellness AI",
    description="Ask wellness-related questions and receive synthesized, creative answers. Optionally request a more in-depth response.",
    theme="default",
    examples=[
        ["What is box breathing and how does it help reduce anxiety?", True],
        ["Provide a daily wellness schedule incorporating box breathing techniques.", False],
        ["What are some tips for maintaining good posture while working at a desk?", True],
        ["Who is the CEO of Hugging Face?", False]
    ],
    allow_flagging="never"
)

###############################################################################
# 12) Launch Gradio
###############################################################################
if __name__ == "__main__":
    try:
        interface.launch(server_name="0.0.0.0", server_port=7860, debug=False, share=True)
    except Exception as e:
        logger.error(f"Failed to launch Gradio interface: {e}")
        logger.debug("Exception details:", exc_info=True)