Phoenix21 commited on
Commit
a004b34
·
verified ·
1 Parent(s): 9afbd21

Updated app.py with multiple feature

Browse files
Files changed (1) hide show
  1. app.py +259 -165
app.py CHANGED
@@ -1,214 +1,308 @@
1
- # Install necessary libraries in Colab
2
- # !pip install datasets langchain_community smolagents chardet gradio pandas nltk sklearn
3
 
4
- # Import required modules
5
  import os
6
  import getpass
7
  import pandas as pd
8
  import chardet
9
- import re
10
- from langchain.docstore.document import Document
11
- from langchain.text_splitter import RecursiveCharacterTextSplitter
12
- from langchain_community.retrievers import BM25Retriever
13
- from smolagents import CodeAgent, HfApiModel, DuckDuckGoSearchTool , Tool ,LiteLLMModel
14
- import gradio as gr
15
  import logging
16
- from nltk.corpus import words
17
- from sklearn.feature_extraction.text import TfidfVectorizer
18
- from sklearn.metrics.pairwise import cosine_similarity
 
 
 
 
 
 
 
 
 
19
 
 
20
  # Set up logging
21
- logging.basicConfig(level=logging.INFO)
 
22
  logger = logging.getLogger("Daily Wellness AI Guru")
23
 
24
- # Securely input the GROQ API key
25
- if 'GROQ_API_KEY' not in os.environ or not os.environ['GROQ_API_KEY']:
26
- os.environ['GROQ_API_KEY'] = getpass.getpass('Enter GROQ_API_KEY: ')
 
 
 
 
 
27
  else:
28
- print("GROQ_API_KEY is already set.")
29
-
30
- # Load NLTK word list for valid word checks
31
- try:
32
- english_words = set(words.words())
33
- except LookupError:
34
- import nltk
35
- nltk.download('words')
36
- english_words = set(words.words())
37
-
38
- # Define allowed topics for health and wellness
39
- ALLOWED_TOPICS = [
40
- "mental health",
41
- "physical health",
42
- "fitness",
43
- "nutrition",
44
- "exercise",
45
- "mindfulness",
46
- "sleep",
47
- "stress management",
48
- "wellness",
49
- "relaxation",
50
- "healthy lifestyle",
51
- "self-care",
52
- "meditation",
53
- "diet",
54
- "hydration",
55
- "breathing techniques",
56
- "yoga",
57
- "stress relief",
58
- "emotional health",
59
- "spiritual health",
60
- "healthy habits"
61
- ]
62
-
63
- def is_valid_input(query):
64
- """
65
- Validate the user's input question.
66
- """
67
- if not query or query.strip() == "":
68
- return False, "Input cannot be empty. Please provide a meaningful question."
69
-
70
- if len(query.strip()) < 2:
71
- return False, "Input is too short. Please provide more context or details."
72
-
73
- # Check for valid words
74
- words_in_text = re.findall(r'\b\w+\b', query.lower())
75
- recognized_words = [word for word in words_in_text if word in english_words]
76
-
77
- if not recognized_words:
78
- return False, "Input appears unclear. Please use valid words in your question."
79
-
80
- return True, "Valid input."
81
 
82
- def similarity_search(query, corpus, threshold=0.2):
83
- """
84
- Perform similarity search using TF-IDF and cosine similarity.
85
- """
86
- vectorizer = TfidfVectorizer()
87
- tfidf_matrix = vectorizer.fit_transform(corpus + [query])
88
- query_vector = tfidf_matrix[-1]
89
- similarities = cosine_similarity(query_vector, tfidf_matrix[:-1]).flatten()
90
- max_similarity = max(similarities)
91
- if max_similarity >= threshold:
92
- most_similar_idx = similarities.argmax()
93
- return True, corpus[most_similar_idx], max_similarity
94
- return False, None, max_similarity
95
-
96
- # Load and process the AIChatbot.csv file
97
  def load_csv(file_path):
98
  """
99
- Load and process a CSV file into a list of documents.
100
  """
101
  try:
 
102
  with open(file_path, 'rb') as f:
103
  result = chardet.detect(f.read())
104
  encoding = result['encoding']
 
 
105
  data = pd.read_csv(file_path, encoding=encoding)
106
- questions = data['Question'].dropna().tolist()
107
- documents = [
108
- Document(page_content=row.to_string(index=False), metadata={"source": file_path})
109
- for _, row in data.iterrows()
110
- ]
111
- logger.info(f"Loaded {len(documents)} documents from {file_path}")
112
- return documents, questions
 
 
 
 
 
 
 
113
  except Exception as e:
114
  logger.error(f"Error loading CSV file: {e}")
115
  return [], []
116
 
 
117
  # Load the AIChatbot.csv file
118
- file_path = "AIChatbot.csv" # Ensure this file is uploaded to your environment
119
- source_docs, corpus_questions = load_csv(file_path)
120
- if not source_docs:
121
- raise ValueError(f"Failed to load documents from {file_path}. Please check the file.")
122
-
123
- # Split documents into manageable chunks
124
- text_splitter = RecursiveCharacterTextSplitter(
125
- chunk_size=500,
126
- chunk_overlap=50,
127
- add_start_index=True,
128
- strip_whitespace=True,
129
- separators=["\n\n", "\n", ".", " ", ""],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  )
131
- docs_processed = text_splitter.split_documents(source_docs)
132
- logger.info(f"Split documents into {len(docs_processed)} chunks.")
133
 
134
- # Define the retriever tool
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  class RetrieverTool(Tool):
136
- name = "retriever"
137
- description = "Uses semantic search to retrieve the parts of chatbot documentation most relevant to the query."
138
  inputs = {
139
  "query": {
140
  "type": "string",
141
- "description": "The query to perform. Use an affirmative tone rather than a question."
142
  }
143
  }
144
  output_type = "string"
145
 
146
- def __init__(self, docs, **kwargs):
147
- super().__init__(**kwargs)
148
- self.retriever = BM25Retriever.from_documents(docs, k=10)
149
-
150
- def forward(self, query: str) -> str:
151
- assert isinstance(query, str), "Search query must be a string."
152
- docs = self.retriever.invoke(query)
153
- if docs:
154
- return docs[0].page_content.strip()
155
- else:
156
- return "No relevant information found."
157
-
158
- retriever_tool = RetrieverTool(docs_processed)
159
-
160
- # Define DuckDuckGoSearchTool
161
- duckduckgo_search_tool = DuckDuckGoSearchTool()
162
-
163
- # Define the improved custom prompt
164
- custom_prompt = """
165
- You are Daily Wellness AI Guru, a friendly and knowledgeable assistant here to simplify wellness. Your goal is to provide clear, concise, and actionable answers to the user's health and wellness-related questions. Mention how Daily Wellness AI offers tailored solutions for day-to-day wellness tasks. Use a warm and friendly tone to make the user feel at ease.
166
-
167
- When answering:
168
- 1. Address the user warmly with "Hello! This is Daily Wellness AI Guru."
169
- 2. Highlight the key points in an easy-to-understand manner.
170
- 3. Include practical examples, tips, or short guides where relevant.
171
- 4. Format the response for clarity using markdown (e.g., numbered lists, bullet points).
172
- 5. Reinforce how Daily Wellness AI helps simplify wellness through AI-powered solutions.
173
- 6. End with an engaging and polite closing remark that invites further questions.
174
- """
175
-
176
- # Define the agent using smolagents
177
- model = LiteLLMModel("groq/llama3-8b-8192") # Ensure the model is available
178
- agent = CodeAgent(
179
- tools=[retriever_tool, duckduckgo_search_tool], model=model, max_iterations=4, verbose=True
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  )
181
 
182
- # Gradio interface for interacting with the RAG pipeline
 
 
183
  def gradio_interface(query):
184
  try:
185
- # Validate input
186
- is_valid, message = is_valid_input(query)
187
- if not is_valid:
188
- return message
189
-
190
- # Perform similarity search
191
- similar, similar_question, similarity_score = similarity_search(query, corpus_questions, threshold=0.2)
192
- if similar:
193
- response = agent.run(f"{custom_prompt}\n\nQuestion: {query}")
194
- return response.strip()
195
- else:
196
- response = duckduckgo_search_tool.invoke(query)
197
- return f"{response.strip()}\n\nRemember, Daily Wellness AI is here to simplify wellness with AI-powered solutions. Feel free to ask more questions!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  except Exception as e:
199
- logger.error(f"Error during query processing: {e}")
200
  return "**An error occurred while processing your request. Please try again later.**"
201
 
202
-
203
- # Create the Gradio interface
 
204
  interface = gr.Interface(
205
  fn=gradio_interface,
206
- inputs=gr.Textbox(label="Enter your question", placeholder="e.g., How does box breathing help reduce anxiety?"),
207
- outputs=gr.Markdown(label="Answer"),
 
 
 
208
  title="Daily Wellness AI Guru Chatbot",
209
- description="Ask health and wellness questions. Get actionable, friendly advice from your wellness companion.",
 
 
 
 
210
  theme="compact"
211
  )
212
 
 
 
 
 
213
  if __name__ == "__main__":
214
- interface.launch(debug=True)
 
1
+ # app.py
 
2
 
 
3
  import os
4
  import getpass
5
  import pandas as pd
6
  import chardet
 
 
 
 
 
 
7
  import logging
8
+ import gradio as gr
9
+
10
+ from sentence_transformers import SentenceTransformer, util, CrossEncoder
11
+ from langchain_community.retrievers import BM25Retriever
12
+ from smolagents import (
13
+ CodeAgent,
14
+ HfApiModel,
15
+ DuckDuckGoSearchTool,
16
+ Tool,
17
+ ManagedAgent,
18
+ LiteLLMModel
19
+ )
20
 
21
+ # --------------------------------------------------------------------------------
22
  # Set up logging
23
+ # --------------------------------------------------------------------------------
24
+ logging.basicConfig(level=logging.DEBUG)
25
  logger = logging.getLogger("Daily Wellness AI Guru")
26
 
27
+ # --------------------------------------------------------------------------------
28
+ # Ensure Hugging Face API Token
29
+ # --------------------------------------------------------------------------------
30
+ # In a Hugging Face Space, you can set HF_API_TOKEN as a secret variable.
31
+ # If it's not set, you could prompt for it locally, but in Spaces,
32
+ # you typically wouldn't do getpass. We'll leave the logic here as fallback.
33
+ if 'HF_API_TOKEN' not in os.environ or not os.environ['HF_API_TOKEN']:
34
+ os.environ['HF_API_TOKEN'] = getpass.getpass('Enter your Hugging Face API Token: ')
35
  else:
36
+ print("HF_API_TOKEN is already set.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
+ # --------------------------------------------------------------------------------
39
+ # CSV Loading and Processing
40
+ # --------------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
41
  def load_csv(file_path):
42
  """
43
+ Load and process a CSV file into two lists: questions and answers.
44
  """
45
  try:
46
+ # Detect the encoding of the file
47
  with open(file_path, 'rb') as f:
48
  result = chardet.detect(f.read())
49
  encoding = result['encoding']
50
+
51
+ # Load the CSV using the detected encoding
52
  data = pd.read_csv(file_path, encoding=encoding)
53
+
54
+ # Validate that the required columns are present
55
+ if 'Question' not in data.columns or 'Answers' not in data.columns:
56
+ raise ValueError("The CSV file must contain 'Question' and 'Answers' columns.")
57
+
58
+ # Drop any rows with missing values in 'Question' or 'Answers'
59
+ data = data.dropna(subset=['Question', 'Answers'])
60
+
61
+ # Extract questions and answers
62
+ questions = data['Question'].tolist()
63
+ answers = data['Answers'].tolist()
64
+
65
+ logger.info(f"Loaded {len(questions)} questions and {len(answers)} answers from {file_path}")
66
+ return questions, answers
67
  except Exception as e:
68
  logger.error(f"Error loading CSV file: {e}")
69
  return [], []
70
 
71
+ # --------------------------------------------------------------------------------
72
  # Load the AIChatbot.csv file
73
+ # --------------------------------------------------------------------------------
74
+ file_path = "AIChatbot.csv" # Ensure this file is in the same directory as app.py
75
+ corpus_questions, corpus_answers = load_csv(file_path)
76
+
77
+ if not corpus_questions:
78
+ raise ValueError(f"Failed to load questions from {file_path}.")
79
+
80
+ # --------------------------------------------------------------------------------
81
+ # Embedding Model
82
+ # --------------------------------------------------------------------------------
83
+ embedding_model_name = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
84
+ embedding_model = SentenceTransformer(embedding_model_name)
85
+ logger.info(f"Loaded sentence embedding model: {embedding_model_name}")
86
+
87
+ # Encode Questions (for retrieval)
88
+ question_embeddings = embedding_model.encode(corpus_questions, convert_to_tensor=True)
89
+
90
+ # --------------------------------------------------------------------------------
91
+ # Cross-Encoder for Re-Ranking
92
+ # --------------------------------------------------------------------------------
93
+ cross_encoder_model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
94
+ cross_encoder = CrossEncoder(cross_encoder_model_name)
95
+ logger.info(f"Loaded cross-encoder model: {cross_encoder_model_name}")
96
+
97
+ # --------------------------------------------------------------------------------
98
+ # Retrieval + Re-ranking Class
99
+ # --------------------------------------------------------------------------------
100
+ class EmbeddingRetriever:
101
+ def __init__(self, questions, answers, embeddings, model, cross_encoder):
102
+ self.questions = questions
103
+ self.answers = answers
104
+ self.embeddings = embeddings
105
+ self.model = model
106
+ self.cross_encoder = cross_encoder
107
+
108
+ def retrieve(self, query, top_k=3):
109
+ # Compute query embedding
110
+ query_embedding = self.model.encode(query, convert_to_tensor=True)
111
+ scores = util.pytorch_cos_sim(query_embedding, self.embeddings)[0].cpu().tolist()
112
+
113
+ # Combine data
114
+ scored_data = list(zip(self.questions, self.answers, scores))
115
+ # Sort by best scores
116
+ scored_data = sorted(scored_data, key=lambda x: x[2], reverse=True)
117
+ # Take top_k
118
+ top_candidates = scored_data[:top_k]
119
+
120
+ # Cross-encode re-rank
121
+ cross_inputs = [[query, candidate[0]] for candidate in top_candidates]
122
+ cross_scores = self.cross_encoder.predict(cross_inputs)
123
+
124
+ reranked = sorted(
125
+ zip(top_candidates, cross_scores),
126
+ key=lambda x: x[1],
127
+ reverse=True
128
+ )
129
+
130
+ # The best candidate
131
+ best_candidate = reranked[0][0] # (question, answer, score)
132
+ best_answer = best_candidate[1]
133
+ return best_answer
134
+
135
+ retriever = EmbeddingRetriever(
136
+ questions=corpus_questions,
137
+ answers=corpus_answers,
138
+ embeddings=question_embeddings,
139
+ model=embedding_model,
140
+ cross_encoder=cross_encoder
141
  )
 
 
142
 
143
+ # --------------------------------------------------------------------------------
144
+ # Simple Answer Expander (Without custom sampling parameters)
145
+ # --------------------------------------------------------------------------------
146
+ class AnswerExpander:
147
+ def __init__(self, model: HfApiModel):
148
+ self.model = model
149
+
150
+ def expand(self, question: str, short_answer: str) -> str:
151
+ """
152
+ Prompt the LLM to provide a more creative, brand-aligned answer.
153
+ """
154
+ prompt = (
155
+ "You are Daily Wellness AI, a friendly and creative wellness expert. "
156
+ "The user has a question about well-being. Provide an encouraging, day-to-day "
157
+ "wellness perspective. Be gentle, uplifting, and brand-aligned.\n\n"
158
+ f"Question: {question}\n"
159
+ f"Current short answer: {short_answer}\n\n"
160
+ "Please rephrase and expand with more detail, wellness tips, daily-life "
161
+ "applications, and an optimistic tone. Keep it informal, friendly, and end "
162
+ "with a short inspirational note.\n"
163
+ )
164
+ try:
165
+ expanded_answer = self.model.run(prompt)
166
+ return expanded_answer.strip()
167
+ except Exception as e:
168
+ logger.error(f"Failed to expand answer: {e}")
169
+ return short_answer
170
+
171
+ # NOTE: We are using a basic HfApiModel here (no custom sampling).
172
+ expander_model = HfApiModel()
173
+ answer_expander = AnswerExpander(expander_model)
174
+
175
+ # --------------------------------------------------------------------------------
176
+ # Enhanced Retriever Tool
177
+ # --------------------------------------------------------------------------------
178
+ from smolagents import Tool
179
+
180
  class RetrieverTool(Tool):
181
+ name = "retriever_tool"
182
+ description = "Uses semantic search + cross-encoder re-ranking to retrieve the best answer."
183
  inputs = {
184
  "query": {
185
  "type": "string",
186
+ "description": "User query for retrieving relevant information.",
187
  }
188
  }
189
  output_type = "string"
190
 
191
+ def __init__(self, retriever, expander):
192
+ super().__init__()
193
+ self.retriever = retriever
194
+ self.expander = expander
195
+
196
+ def forward(self, query):
197
+ best_answer = self.retriever.retrieve(query, top_k=3)
198
+ if best_answer:
199
+ # If short, expand it
200
+ if len(best_answer.strip()) < 80:
201
+ logger.info("Answer is short. Expanding with LLM.")
202
+ best_answer = self.expander.expand(query, best_answer)
203
+ return best_answer
204
+ return "No relevant information found."
205
+
206
+ retriever_tool = RetrieverTool(retriever, answer_expander)
207
+
208
+ # --------------------------------------------------------------------------------
209
+ # DuckDuckGo (Web) Fallback
210
+ # --------------------------------------------------------------------------------
211
+ search_tool = DuckDuckGoSearchTool()
212
+
213
+ # --------------------------------------------------------------------------------
214
+ # Managed Agents
215
+ # --------------------------------------------------------------------------------
216
+ from smolagents import ManagedAgent, CodeAgent, LiteLLMModel
217
+
218
+ retriever_agent = ManagedAgent(
219
+ agent=CodeAgent(tools=[retriever_tool], model=LiteLLMModel("groq/llama3-8b-8192")),
220
+ name="retriever_agent",
221
+ description="Retrieves answers from the local knowledge base (CSV file)."
222
+ )
223
+
224
+ web_agent = ManagedAgent(
225
+ agent=CodeAgent(tools=[search_tool], model=HfApiModel()),
226
+ name="web_search_agent",
227
+ description="Performs web searches if the local knowledge base doesn't have an answer."
228
+ )
229
+
230
+ # --------------------------------------------------------------------------------
231
+ # Manager Agent to Orchestrate
232
+ # --------------------------------------------------------------------------------
233
+ manager_agent = CodeAgent(
234
+ tools=[],
235
+ model=HfApiModel(),
236
+ managed_agents=[retriever_agent, web_agent],
237
+ verbose=True
238
  )
239
 
240
+ # --------------------------------------------------------------------------------
241
+ # Gradio Interface
242
+ # --------------------------------------------------------------------------------
243
  def gradio_interface(query):
244
  try:
245
+ logger.info(f"User query: {query}")
246
+
247
+ # 1) Query local knowledge base
248
+ retriever_response = retriever_tool.forward(query)
249
+ if retriever_response != "No relevant information found.":
250
+ logger.info("Provided answer from local DB (possibly expanded).")
251
+ return (
252
+ f"Hello! This is **Daily Wellness AI**.\n\n"
253
+ f"{retriever_response}\n\n"
254
+ "Disclaimer: This is general wellness information, "
255
+ "not a substitute for professional medical advice.\n\n"
256
+ "Wishing you a calm and wonderful day!"
257
+ )
258
+
259
+ # 2) Fallback to Web if no relevant local info
260
+ logger.info("Falling back to web search.")
261
+ web_response = web_agent.run(query)
262
+ if web_response:
263
+ logger.info("Response retrieved from the web.")
264
+ return (
265
+ f"Hello! This is **Daily Wellness AI**.\n\n"
266
+ f"{web_response.strip()}\n\n"
267
+ "Disclaimer: This is general wellness information, "
268
+ "not a substitute for professional medical advice.\n\n"
269
+ "Wishing you a calm and wonderful day!"
270
+ )
271
+
272
+ # 3) Default fallback
273
+ logger.info("No response found from any source.")
274
+ return (
275
+ "Hello! This is **Daily Wellness AI**.\n\n"
276
+ "I'm sorry, I couldn't find an answer to your question. "
277
+ "Please try rephrasing or ask something else.\n\n"
278
+ "Take care, and have a wonderful day!"
279
+ )
280
  except Exception as e:
281
+ logger.error(f"Error processing query: {e}")
282
  return "**An error occurred while processing your request. Please try again later.**"
283
 
284
+ # --------------------------------------------------------------------------------
285
+ # Launch Gradio App
286
+ # --------------------------------------------------------------------------------
287
  interface = gr.Interface(
288
  fn=gradio_interface,
289
+ inputs=gr.Textbox(
290
+ label="Ask Daily Wellness AI",
291
+ placeholder="e.g., What is box breathing?"
292
+ ),
293
+ outputs=gr.Markdown(label="Answer from Daily Wellness AI"),
294
  title="Daily Wellness AI Guru Chatbot",
295
+ description=(
296
+ "Ask wellness-related questions to get detailed, creative answers from "
297
+ "our knowledge base—expanded by an LLM if needed—or from the web. "
298
+ "We aim to bring calm and positivity to your day."
299
+ ),
300
  theme="compact"
301
  )
302
 
303
+ def main():
304
+ interface.launch(server_name="0.0.0.0", server_port=7860, debug=True)
305
+
306
+ # If running in a local environment, we can also just call main()
307
  if __name__ == "__main__":
308
+ main()