adding smolagent
Browse files
app.py
CHANGED
@@ -11,6 +11,9 @@ from sentence_transformers import SentenceTransformer, util, CrossEncoder
|
|
11 |
from langchain.llms.base import LLM
|
12 |
import google.generativeai as genai
|
13 |
|
|
|
|
|
|
|
14 |
###############################################################################
|
15 |
# 1) Logging Setup
|
16 |
###############################################################################
|
@@ -199,7 +202,36 @@ class QuestionSanityChecker:
|
|
199 |
sanity_checker = QuestionSanityChecker(llm)
|
200 |
|
201 |
###############################################################################
|
202 |
-
# 7)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
###############################################################################
|
204 |
class AnswerExpander:
|
205 |
def __init__(self, llm: GeminiLLM):
|
@@ -230,7 +262,7 @@ class AnswerExpander:
|
|
230 |
answer_expander = AnswerExpander(llm)
|
231 |
|
232 |
###############################################################################
|
233 |
-
#
|
234 |
###############################################################################
|
235 |
def handle_query(query: str) -> str:
|
236 |
if not query or not isinstance(query, str) or len(query.strip()) == 0:
|
@@ -247,14 +279,26 @@ def handle_query(query: str) -> str:
|
|
247 |
if not retrieved:
|
248 |
return "I'm sorry, I couldn't find an answer to your question."
|
249 |
|
250 |
-
#
|
251 |
top_score = retrieved[0][1] # Assuming the list is sorted descending
|
252 |
similarity_threshold = 0.3 # Adjust this threshold based on empirical results
|
253 |
|
254 |
if top_score < similarity_threshold:
|
255 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
256 |
|
257 |
-
# Proceed with answer expansion
|
258 |
responses = [ans[0] for ans in retrieved]
|
259 |
expanded_answer = answer_expander.expand(query, responses)
|
260 |
return expanded_answer
|
@@ -264,18 +308,12 @@ def handle_query(query: str) -> str:
|
|
264 |
return "An error occurred while processing your request."
|
265 |
|
266 |
###############################################################################
|
267 |
-
#
|
268 |
###############################################################################
|
269 |
def gradio_interface(query: str):
|
270 |
try:
|
271 |
response = handle_query(query)
|
272 |
-
formatted_response =
|
273 |
-
f"**Daily Wellness AI**\n\n"
|
274 |
-
f"{response}\n\n"
|
275 |
-
"Disclaimer: This is general wellness information, "
|
276 |
-
"not a substitute for professional medical advice.\n\n"
|
277 |
-
"Wishing you a calm and wonderful day!"
|
278 |
-
)
|
279 |
return formatted_response
|
280 |
except Exception as e:
|
281 |
logger.error(f"Error in Gradio interface: {e}")
|
@@ -296,13 +334,14 @@ interface = gr.Interface(
|
|
296 |
examples=[
|
297 |
"What is box breathing and how does it help reduce anxiety?",
|
298 |
"Provide a daily wellness schedule incorporating box breathing techniques.",
|
299 |
-
"What are some tips for maintaining good posture while working at a desk?"
|
|
|
300 |
],
|
301 |
allow_flagging="never"
|
302 |
)
|
303 |
|
304 |
###############################################################################
|
305 |
-
#
|
306 |
###############################################################################
|
307 |
if __name__ == "__main__":
|
308 |
try:
|
|
|
11 |
from langchain.llms.base import LLM
|
12 |
import google.generativeai as genai
|
13 |
|
14 |
+
# Import smolagents components
|
15 |
+
from smolagents import CodeAgent, LiteLLMModel, DuckDuckGoSearchTool, ManagedAgent
|
16 |
+
|
17 |
###############################################################################
|
18 |
# 1) Logging Setup
|
19 |
###############################################################################
|
|
|
202 |
sanity_checker = QuestionSanityChecker(llm)
|
203 |
|
204 |
###############################################################################
|
205 |
+
# 7) smolagents Integration: GROQ Model and Web Search
|
206 |
+
###############################################################################
|
207 |
+
# Initialize the smolagents' LiteLLMModel with GROQ model
|
208 |
+
smol_model = LiteLLMModel("groq/llama3-8b-8192")
|
209 |
+
|
210 |
+
# Instantiate the DuckDuckGo search tool
|
211 |
+
search_tool = DuckDuckGoSearchTool()
|
212 |
+
|
213 |
+
# Create the web agent with the search tool
|
214 |
+
web_agent = CodeAgent(
|
215 |
+
tools=[search_tool],
|
216 |
+
model=smol_model
|
217 |
+
)
|
218 |
+
|
219 |
+
# Define the managed web agent
|
220 |
+
managed_web_agent = ManagedAgent(
|
221 |
+
agent=web_agent,
|
222 |
+
name="web_search",
|
223 |
+
description="Runs a web search for you. Provide your query as an argument."
|
224 |
+
)
|
225 |
+
|
226 |
+
# Create the manager agent with managed web agent and additional tools if needed
|
227 |
+
manager_agent = CodeAgent(
|
228 |
+
tools=[], # Add additional tools here if required
|
229 |
+
model=smol_model,
|
230 |
+
managed_agents=[managed_web_agent]
|
231 |
+
)
|
232 |
+
|
233 |
+
###############################################################################
|
234 |
+
# 8) Answer Expansion
|
235 |
###############################################################################
|
236 |
class AnswerExpander:
|
237 |
def __init__(self, llm: GeminiLLM):
|
|
|
262 |
answer_expander = AnswerExpander(llm)
|
263 |
|
264 |
###############################################################################
|
265 |
+
# 9) Query Handling
|
266 |
###############################################################################
|
267 |
def handle_query(query: str) -> str:
|
268 |
if not query or not isinstance(query, str) or len(query.strip()) == 0:
|
|
|
279 |
if not retrieved:
|
280 |
return "I'm sorry, I couldn't find an answer to your question."
|
281 |
|
282 |
+
# Check similarity threshold
|
283 |
top_score = retrieved[0][1] # Assuming the list is sorted descending
|
284 |
similarity_threshold = 0.3 # Adjust this threshold based on empirical results
|
285 |
|
286 |
if top_score < similarity_threshold:
|
287 |
+
# Perform web search using manager_agent
|
288 |
+
logger.info("Similarity score below threshold. Performing web search.")
|
289 |
+
web_search_response = manager_agent.run(query)
|
290 |
+
logger.debug(f"Web search response: {web_search_response}")
|
291 |
+
|
292 |
+
# Optionally, process the web_search_response if needed
|
293 |
+
# For simplicity, return the web search response directly
|
294 |
+
return (
|
295 |
+
f"**Daily Wellness AI**\n\n"
|
296 |
+
f"{web_search_response}\n\n"
|
297 |
+
"Disclaimer: This information is retrieved from the web and is not a substitute for professional medical advice.\n\n"
|
298 |
+
"Wishing you a calm and wonderful day!"
|
299 |
+
)
|
300 |
|
301 |
+
# Proceed with answer expansion using retrieved_answers
|
302 |
responses = [ans[0] for ans in retrieved]
|
303 |
expanded_answer = answer_expander.expand(query, responses)
|
304 |
return expanded_answer
|
|
|
308 |
return "An error occurred while processing your request."
|
309 |
|
310 |
###############################################################################
|
311 |
+
# 10) Gradio Interface
|
312 |
###############################################################################
|
313 |
def gradio_interface(query: str):
|
314 |
try:
|
315 |
response = handle_query(query)
|
316 |
+
formatted_response = response # Response is already formatted
|
|
|
|
|
|
|
|
|
|
|
|
|
317 |
return formatted_response
|
318 |
except Exception as e:
|
319 |
logger.error(f"Error in Gradio interface: {e}")
|
|
|
334 |
examples=[
|
335 |
"What is box breathing and how does it help reduce anxiety?",
|
336 |
"Provide a daily wellness schedule incorporating box breathing techniques.",
|
337 |
+
"What are some tips for maintaining good posture while working at a desk?",
|
338 |
+
"Who is the CEO of Hugging Face?" # Example of an out-of-context question
|
339 |
],
|
340 |
allow_flagging="never"
|
341 |
)
|
342 |
|
343 |
###############################################################################
|
344 |
+
# 11) Launch Gradio
|
345 |
###############################################################################
|
346 |
if __name__ == "__main__":
|
347 |
try:
|