Phoenix21 commited on
Commit
1c73b9c
·
verified ·
1 Parent(s): 892745c

adding smolagent

Browse files
Files changed (1) hide show
  1. app.py +54 -15
app.py CHANGED
@@ -11,6 +11,9 @@ from sentence_transformers import SentenceTransformer, util, CrossEncoder
11
  from langchain.llms.base import LLM
12
  import google.generativeai as genai
13
 
 
 
 
14
  ###############################################################################
15
  # 1) Logging Setup
16
  ###############################################################################
@@ -199,7 +202,36 @@ class QuestionSanityChecker:
199
  sanity_checker = QuestionSanityChecker(llm)
200
 
201
  ###############################################################################
202
- # 7) Answer Expansion
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  ###############################################################################
204
  class AnswerExpander:
205
  def __init__(self, llm: GeminiLLM):
@@ -230,7 +262,7 @@ class AnswerExpander:
230
  answer_expander = AnswerExpander(llm)
231
 
232
  ###############################################################################
233
- # 8) Query Handling
234
  ###############################################################################
235
  def handle_query(query: str) -> str:
236
  if not query or not isinstance(query, str) or len(query.strip()) == 0:
@@ -247,14 +279,26 @@ def handle_query(query: str) -> str:
247
  if not retrieved:
248
  return "I'm sorry, I couldn't find an answer to your question."
249
 
250
- # Optional: Check similarity threshold (if still desired)
251
  top_score = retrieved[0][1] # Assuming the list is sorted descending
252
  similarity_threshold = 0.3 # Adjust this threshold based on empirical results
253
 
254
  if top_score < similarity_threshold:
255
- return "I'm sorry, I didn't understand your question. Could you please rephrase it?"
 
 
 
 
 
 
 
 
 
 
 
 
256
 
257
- # Proceed with answer expansion
258
  responses = [ans[0] for ans in retrieved]
259
  expanded_answer = answer_expander.expand(query, responses)
260
  return expanded_answer
@@ -264,18 +308,12 @@ def handle_query(query: str) -> str:
264
  return "An error occurred while processing your request."
265
 
266
  ###############################################################################
267
- # 9) Gradio Interface
268
  ###############################################################################
269
  def gradio_interface(query: str):
270
  try:
271
  response = handle_query(query)
272
- formatted_response = (
273
- f"**Daily Wellness AI**\n\n"
274
- f"{response}\n\n"
275
- "Disclaimer: This is general wellness information, "
276
- "not a substitute for professional medical advice.\n\n"
277
- "Wishing you a calm and wonderful day!"
278
- )
279
  return formatted_response
280
  except Exception as e:
281
  logger.error(f"Error in Gradio interface: {e}")
@@ -296,13 +334,14 @@ interface = gr.Interface(
296
  examples=[
297
  "What is box breathing and how does it help reduce anxiety?",
298
  "Provide a daily wellness schedule incorporating box breathing techniques.",
299
- "What are some tips for maintaining good posture while working at a desk?"
 
300
  ],
301
  allow_flagging="never"
302
  )
303
 
304
  ###############################################################################
305
- # 10) Launch Gradio
306
  ###############################################################################
307
  if __name__ == "__main__":
308
  try:
 
11
  from langchain.llms.base import LLM
12
  import google.generativeai as genai
13
 
14
+ # Import smolagents components
15
+ from smolagents import CodeAgent, LiteLLMModel, DuckDuckGoSearchTool, ManagedAgent
16
+
17
  ###############################################################################
18
  # 1) Logging Setup
19
  ###############################################################################
 
202
  sanity_checker = QuestionSanityChecker(llm)
203
 
204
  ###############################################################################
205
+ # 7) smolagents Integration: GROQ Model and Web Search
206
+ ###############################################################################
207
+ # Initialize the smolagents' LiteLLMModel with GROQ model
208
+ smol_model = LiteLLMModel("groq/llama3-8b-8192")
209
+
210
+ # Instantiate the DuckDuckGo search tool
211
+ search_tool = DuckDuckGoSearchTool()
212
+
213
+ # Create the web agent with the search tool
214
+ web_agent = CodeAgent(
215
+ tools=[search_tool],
216
+ model=smol_model
217
+ )
218
+
219
+ # Define the managed web agent
220
+ managed_web_agent = ManagedAgent(
221
+ agent=web_agent,
222
+ name="web_search",
223
+ description="Runs a web search for you. Provide your query as an argument."
224
+ )
225
+
226
+ # Create the manager agent with managed web agent and additional tools if needed
227
+ manager_agent = CodeAgent(
228
+ tools=[], # Add additional tools here if required
229
+ model=smol_model,
230
+ managed_agents=[managed_web_agent]
231
+ )
232
+
233
+ ###############################################################################
234
+ # 8) Answer Expansion
235
  ###############################################################################
236
  class AnswerExpander:
237
  def __init__(self, llm: GeminiLLM):
 
262
  answer_expander = AnswerExpander(llm)
263
 
264
  ###############################################################################
265
+ # 9) Query Handling
266
  ###############################################################################
267
  def handle_query(query: str) -> str:
268
  if not query or not isinstance(query, str) or len(query.strip()) == 0:
 
279
  if not retrieved:
280
  return "I'm sorry, I couldn't find an answer to your question."
281
 
282
+ # Check similarity threshold
283
  top_score = retrieved[0][1] # Assuming the list is sorted descending
284
  similarity_threshold = 0.3 # Adjust this threshold based on empirical results
285
 
286
  if top_score < similarity_threshold:
287
+ # Perform web search using manager_agent
288
+ logger.info("Similarity score below threshold. Performing web search.")
289
+ web_search_response = manager_agent.run(query)
290
+ logger.debug(f"Web search response: {web_search_response}")
291
+
292
+ # Optionally, process the web_search_response if needed
293
+ # For simplicity, return the web search response directly
294
+ return (
295
+ f"**Daily Wellness AI**\n\n"
296
+ f"{web_search_response}\n\n"
297
+ "Disclaimer: This information is retrieved from the web and is not a substitute for professional medical advice.\n\n"
298
+ "Wishing you a calm and wonderful day!"
299
+ )
300
 
301
+ # Proceed with answer expansion using retrieved_answers
302
  responses = [ans[0] for ans in retrieved]
303
  expanded_answer = answer_expander.expand(query, responses)
304
  return expanded_answer
 
308
  return "An error occurred while processing your request."
309
 
310
  ###############################################################################
311
+ # 10) Gradio Interface
312
  ###############################################################################
313
  def gradio_interface(query: str):
314
  try:
315
  response = handle_query(query)
316
+ formatted_response = response # Response is already formatted
 
 
 
 
 
 
317
  return formatted_response
318
  except Exception as e:
319
  logger.error(f"Error in Gradio interface: {e}")
 
334
  examples=[
335
  "What is box breathing and how does it help reduce anxiety?",
336
  "Provide a daily wellness schedule incorporating box breathing techniques.",
337
+ "What are some tips for maintaining good posture while working at a desk?",
338
+ "Who is the CEO of Hugging Face?" # Example of an out-of-context question
339
  ],
340
  allow_flagging="never"
341
  )
342
 
343
  ###############################################################################
344
+ # 11) Launch Gradio
345
  ###############################################################################
346
  if __name__ == "__main__":
347
  try: